In [1]:
from model.began import Generator128
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from collections import deque

import torchvision.utils as vutils

import os
import os.path

import argparse

import random

In [2]:
# Load the legacy generator checkpoint onto the CPU: renaming keys needs no GPU,
# and `load_state_dict` copies each tensor onto whatever device the model's
# parameters live on, so no second copy of the weights is kept in GPU memory.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source.
state_dict = torch.load('./trained_model/gen_208000.pth', map_location='cpu')

In [3]:
# Inspect the parameter names stored in the checkpoint.
state_dict.keys()

odict_keys(['l0.weight', 'l0.bias', 'l1.weight', 'l1.bias', 'l2.weight', 'l2.bias', 'l3.weight', 'l3.bias', 'l4.weight', 'l4.bias', 'l5.weight', 'l5.bias', 'l6.weight', 'l6.bias', 'l7.weight', 'l7.bias', 'l8.weight', 'l8.bias', 'l10.weight', 'l10.bias', 'l11.weight', 'l11.bias', 'l9.weight', 'l9.bias'])

In [4]:
class View(nn.Module):
    """Module wrapper around ``Tensor.view`` so a reshape can sit inside a layer list.

    Args:
        shape: target dimensions handed to ``Tensor.view`` (``-1`` is allowed
            wherever ``view`` itself allows it).
    """

    def __init__(self, shape):
        super(View, self).__init__()
        self.shape = shape

    def forward(self, input):
        """Return ``input`` viewed with the stored target shape."""
        target_dims = tuple(self.shape)
        return input.view(target_dims)

class ConvBlock(nn.Module):
    """A single 2-D convolution followed by an ELU non-linearity.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        k: kernel size.
        s: stride.
        p: padding.
        b: whether the convolution has a bias term.
    """

    def __init__(self, in_ch, out_ch, k, s, p, b):
        super(ConvBlock, self).__init__()
        convolution = nn.Conv2d(in_ch, out_ch, kernel_size=k, stride=s, padding=p, bias=b)
        activation = nn.ELU()
        # Keep the attribute name `net` and the conv at index 0: parameter
        # names in saved state dicts are `net.0.weight` / `net.0.bias`.
        self.net = nn.Sequential(convolution, activation)

    def forward(self, x):
        """Apply convolution then ELU to ``x``."""
        out = self.net(x)
        return out


class Decoder(nn.Module):
    """128x128 generator ("Began.Gen128") stored as a flat, sliceable list of layers.

    Keeping every stage in one ``nn.ModuleList`` lets ``forward`` start (and stop)
    at an arbitrary layer index, so an intermediate activation can be fed in
    directly instead of a latent vector.

    Layer indices (checkpoint key names depend on them):
        0        Linear + View      latent -> (ch, 8, 8)
        1, 2     ConvBlock x2       at 8x8
        3        upsample x2
        4, 5     ConvBlock x2       at 16x16
        6        upsample x2
        7, 8     ConvBlock x2       at 32x32
        9        upsample x2
        10, 11   ConvBlock x2       at 64x64
        12       upsample x2
        13, 14   ConvBlock x2       at 128x128
        15       Conv2d + Tanh      -> (3, 128, 128)

    Attributes:
        ch: channel width of every hidden feature map.
        latent_dim: size of the latent vector consumed by layer 0.
        scale_size: spatial size of the generated image.
        initial_size: spatial size produced by the Linear + View stage.
        layers: the ``nn.ModuleList`` described above.
        input_shapes: ``input_shapes[n]`` is ``(x1_shape, x2_shape)`` where
            ``x1_shape`` is the un-batched input shape expected when calling
            ``forward(..., n_cuts=n)``; ``x2_shape`` is always ``()``. The last
            entry corresponds to skipping the whole network.
    """

    def __init__(self):
        super().__init__()
        self.ch = 128
        self.latent_dim = 64
        self.scale_size = 128
        self.initial_size = 8

        def conv_block():
            # 3x3 conv, stride 1, padding 1, with bias: keeps the spatial size.
            return ConvBlock(self.ch, self.ch, 3, 1, 1, True)

        # `layers` and `input_shapes` are built side by side so that
        # input_shapes[i] is always the input shape of layers[i].
        layers = [
            nn.Sequential(
                nn.Linear(self.latent_dim,
                          self.initial_size**2 * self.ch,
                          bias=True),
                View((-1, self.ch, self.initial_size, self.initial_size)),
            )
        ]
        input_shapes = [((self.latent_dim, ), ())]

        size = self.initial_size
        while True:
            feature_shape = ((self.ch, size, size), ())
            layers.append(conv_block())
            input_shapes.append(feature_shape)
            layers.append(conv_block())
            input_shapes.append(feature_shape)
            if size >= self.scale_size:
                break
            layers.append(nn.UpsamplingNearest2d(scale_factor=2))
            input_shapes.append(feature_shape)
            size *= 2

        # Final projection to RGB.
        layers.append(nn.Sequential(nn.Conv2d(self.ch, 3, 3, 1, 1), nn.Tanh()))
        input_shapes.append(((self.ch, size, size), ()))

        # Skipping the whole net: the "input" is already an image.
        input_shapes.append(((3, size, size), ()))

        self.layers = nn.ModuleList(layers)
        self.input_shapes = input_shapes

        self._check_input_shapes()

    def _check_input_shapes(self):
        """Run one dummy forward pass per cut point and verify the output shape.

        Prints, for each ``n_cuts``, the index followed by the dummy input
        shape(s) and the un-batched output shape.

        Raises:
            RuntimeError: if starting at some cut point does not produce a
                ``(3, scale_size, scale_size)`` image.
        """
        expected = (3, self.scale_size, self.scale_size)
        # Pure shape probing: no autograd graph is needed for these passes.
        with torch.no_grad():
            for n_cuts, (x1_shape, x2_shape) in enumerate(self.input_shapes):
                print(n_cuts)
                x1 = torch.randn(1, *x1_shape)
                x2 = None if n_cuts <= 1 else torch.randn(1, *x2_shape)
                res = self.forward(x1, x2, n_cuts)
                print(x1.shape, () if n_cuts <= 1 else x2.shape, res.shape[1:])
                if tuple(res.shape[1:]) != expected:
                    raise RuntimeError(
                        f'n_cuts={n_cuts}: expected output shape {expected}, '
                        f'got {tuple(res.shape[1:])}')

    def forward(self, z, z2=None, n_cuts=0, end=None):
        """Apply ``layers[n_cuts:end]`` to ``z`` in order.

        Args:
            z: input tensor; must match ``input_shapes[n_cuts][0]`` plus a
                leading batch dimension.
            z2: accepted but not used by this implementation.
            n_cuts: index of the first layer to apply (0 runs the full net).
            end: index one past the last layer to apply; ``None`` means run
                through the final layer.

        Returns:
            The output of the last applied layer (``z`` itself when the slice
            is empty).
        """
        if end is None:
            end = len(self.layers)
        for layer in self.layers[n_cuts:end]:
            z = layer(z)
        return z

    def __str__(self):
        return f'Began.Gen128.latent_dim={self.latent_dim}'

In [5]:
# Build the generator, then move it to the current CUDA device; the call
# returns the module itself, so the cell still displays its repr.
gen = Decoder()
gen.to(torch.device('cuda'))

0
torch.Size([1, 64]) () torch.Size([3, 128, 128])
1
torch.Size([1, 128, 8, 8]) () torch.Size([3, 128, 128])
2
torch.Size([1, 128, 8, 8]) torch.Size([1]) torch.Size([3, 128, 128])
3
torch.Size([1, 128, 8, 8]) torch.Size([1]) torch.Size([3, 128, 128])
4
torch.Size([1, 128, 16, 16]) torch.Size([1]) torch.Size([3, 128, 128])
5
torch.Size([1, 128, 16, 16]) torch.Size([1]) torch.Size([3, 128, 128])
6
torch.Size([1, 128, 16, 16]) torch.Size([1]) torch.Size([3, 128, 128])
7
torch.Size([1, 128, 32, 32]) torch.Size([1]) torch.Size([3, 128, 128])
8
torch.Size([1, 128, 32, 32]) torch.Size([1]) torch.Size([3, 128, 128])
9
torch.Size([1, 128, 32, 32]) torch.Size([1]) torch.Size([3, 128, 128])
10
torch.Size([1, 128, 64, 64]) torch.Size([1]) torch.Size([3, 128, 128])
11
torch.Size([1, 128, 64, 64]) torch.Size([1]) torch.Size([3, 128, 128])
12
torch.Size([1, 128, 64, 64]) torch.Size([1]) torch.Size([3, 128, 128])
13
torch.Size([1, 128, 128, 128]) torch.Size([1]) torch.Size([3, 128, 128])
14
torch.Size

Decoder(
  (layers): ModuleList(
    (0): Sequential(
      (0): Linear(in_features=64, out_features=8192, bias=True)
      (1): View()
    )
    (1): ConvBlock(
      (net): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ELU(alpha=1.0)
      )
    )
    (2): ConvBlock(
      (net): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ELU(alpha=1.0)
      )
    )
    (3): UpsamplingNearest2d(scale_factor=2.0, mode=nearest)
    (4): ConvBlock(
      (net): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ELU(alpha=1.0)
      )
    )
    (5): ConvBlock(
      (net): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ELU(alpha=1.0)
      )
    )
    (6): UpsamplingNearest2d(scale_factor=2.0, mode=nearest)
    (7): ConvBlock(
      (net): Sequential(
        (0): Conv2d

In [6]:
# Checkpoint layer prefix -> parameter prefix inside `Decoder.layers`.
# Upsampling layers hold no weights, hence the gaps at 3, 6, 9 and 12.
old_to_new = dict(
    l0='layers.0.0',        # linear
    l1='layers.1.net.0',
    l2='layers.2.net.0',
    l3='layers.4.net.0',
    l4='layers.5.net.0',
    l5='layers.7.net.0',
    l6='layers.8.net.0',
    l7='layers.10.net.0',
    l8='layers.11.net.0',
    l9='layers.15.0',       # final conv to RGB
    l10='layers.13.net.0',
    l11='layers.14.net.0',
)
    
def _rename_state_dict(name_mapping, state_dict):
    # rename layers
    for key in list(state_dict.keys()):
        layer = key.split('.')[0]
        rest = key.split('.')[1]
        new_key = name_mapping[layer] + '.' + rest
        state_dict[new_key] = state_dict[key]
        del state_dict[key]
      
    return state_dict

# Translate the checkpoint's `l<N>` key prefixes to the Decoder's `layers.*`
# parameter names. The dict is modified in place and the same object is returned.
state_dict = _rename_state_dict(old_to_new, state_dict)

In [7]:
# Copy the renamed checkpoint tensors into the Decoder's parameters.
gen.load_state_dict(state_dict)

<All keys matched successfully>

In [10]:
def generative_experiments(obj, n_rows=10, n_cols=10, latent_dim=64,
                           out_path='began_12_gen.png'):
    """Save a grid of images generated along latent-space interpolations.

    Each grid row draws two random latent vectors from U(-1, 1) and walks from
    the first to the second with spherical interpolation in ``n_cols`` evenly
    spaced steps (both endpoints included). All latents are pushed through
    ``obj`` in one batch, without gradient tracking, and the result is written
    to ``out_path``.

    Args:
        obj: callable mapping a ``(N, latent_dim)`` float tensor to a batch of
            images. If it exposes ``parameters()``, the latents are moved to
            the device of its first parameter; otherwise CUDA is used.
        n_rows: number of independent interpolations (grid rows).
        n_cols: number of images per interpolation (grid columns), at least 2.
        latent_dim: dimensionality of the latent vectors.
        out_path: path of the image file to write.

    Raises:
        ValueError: if ``n_cols`` is smaller than 2.
    """
    if n_cols < 2:
        raise ValueError('n_cols must be at least 2 (the two interpolation endpoints)')

    def slerp(val, low, high):
        # Spherical interpolation between `low` and `high` at fraction `val`.
        omega = np.arccos(np.clip(np.dot(low / np.linalg.norm(low),
                                         high / np.linalg.norm(high)), -1, 1))
        so = np.sin(omega)
        if np.isclose(so, 0.0):
            # (Anti-)parallel vectors: the slerp weights degenerate, use LERP.
            return (1.0 - val) * low + val * high
        return np.sin((1.0 - val) * omega) / so * low + np.sin(val * omega) / so * high

    # Follow the model's device instead of assuming a GPU is present.
    try:
        device = next(obj.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device('cuda')

    # Evenly spaced fractions; the endpoints themselves are appended verbatim.
    inner_fractions = np.linspace(0.0, 1.0, n_cols)[1:-1]

    latents = []
    for _ in range(n_rows):
        z_start = np.random.uniform(-1, 1, latent_dim)
        z_end = np.random.uniform(-1, 1, latent_dim)
        latents.append(z_start)
        latents.extend(slerp(t, z_start, z_end) for t in inner_fractions)
        latents.append(z_end)

    z_batch = torch.from_numpy(np.stack(latents)).float().to(device)
    # Inference only: avoid building an autograd graph for the whole grid.
    with torch.no_grad():
        gen_z = obj(z_batch)
    vutils.save_image(gen_z.detach(), out_path, nrow=n_cols, normalize=True)


In [11]:
# Render latent-space interpolations with the loaded generator and save them as a PNG.
generative_experiments(gen)