In [1]:
import os

import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo

import torchvision.models as models

In [15]:
def get_mean_var(t):
    '''
    Per-sample, per-channel mean and variance of a 4-D feature map.

    Args:
        t: tensor whose four axes are unpacked as (n_batch, n_ch, h, w).

    Returns:
        (t_mean, t_var): the mean and the variance (as computed by
        Tensor.var with its default settings) over the h*w positions of
        every (sample, channel) pair, each expanded to the shape of `t`.

    Raises:
        ValueError: from the unpacking below if `t` does not have exactly
        four dimensions.
    '''
    n_batch, n_ch, h, w = t.size()

    # .contiguous() is a no-op for contiguous tensors and makes the view
    # legal for non-contiguous ones (e.g. transposed or sliced feature
    # maps), where a bare .view() would raise.
    t_flat = t.contiguous().view(n_batch, n_ch, h * w)
    t_mean = t_flat.mean(dim=2)
    t_var = t_flat.var(dim=2)

    # broadcast the (n_batch, n_ch) statistics back to the input's shape
    t_mean = t_mean.view(n_batch, n_ch, 1, 1).expand_as(t)
    t_var = t_var.view(n_batch, n_ch, 1, 1).expand_as(t)

    return t_mean, t_var

In [17]:
class AdaInstanceNormalization(nn.Module):

    '''

    Adaptive instance normalization: style transfer in 'feature space'.

    Combines the content image's feature maps with the style image's feature
    maps channel-wise: every content channel is normalized to zero mean and
    unit standard deviation over its spatial positions, then re-scaled and
    shifted with the standard deviation and mean of the matching style
    channel:

        out = s_std * (c - c_mean) / c_std + s_mean

    Open question kept from the original notes: could a similar effect be
    obtained by just adding the feature maps together, or linearly combining
    them in some way?

    '''

    def __init__(self, eps=1e-6):
        super(AdaInstanceNormalization, self).__init__()
        # Added to each variance before the square root so that a constant
        # channel (zero variance) does not cause a division by zero.
        self.eps = eps

    def forward(self, c, s):
        '''
        c - content image's feature map
        s - style image's feature map

        Both are passed to get_mean_var, which unpacks them as
        (n_batch, n_ch, h, w). Returns a tensor shaped like `c`.
        The spatial sizes of `c` and `s` may differ.
        '''
        c_mean, c_var = get_mean_var(c)
        s_mean, s_var = get_mean_var(s)

        # get_mean_var hands the statistics back expanded to the shape of its
        # input. Keep a single spatial position per channel so the style
        # statistics broadcast against a content map of any spatial size.
        s_mean = s_mean[:, :, :1, :1]
        s_var = s_var[:, :, :1, :1]

        # Normalization and re-scaling use the standard deviation, not the
        # variance.
        c_std = (c_var + self.eps).sqrt()
        s_std = (s_var + self.eps).sqrt()

        return s_std * (c - c_mean) / c_std + s_mean
        

In [2]:
class EncoderLayer(nn.Module):
    '''
    - construct partial VGG-19 model (through relu_4_1)
    - only the architecture is built here; the weights are filled in by a
      separate function
    '''

    def __init__(self, batch_norm):
        '''
        batch_norm - forwarded to torchvision's make_layers as its
                     `batch_norm` argument
        '''
        super(EncoderLayer, self).__init__()

        # torchvision keeps the VGG layer layouts in a module-level dict that
        # is called `cfg` in older releases and `cfgs` in later ones; accept
        # either so the constructor does not depend on the installed version.
        layouts = getattr(models.vgg, 'cfg', None)
        if layouts is None:
            layouts = models.vgg.cfgs

        # constructing partial architecture of VGG-19 (through relu_4_1):
        # the first 12 entries of the 'E' layout
        conf = layouts['E'][:12]
        self.features = models.vgg.make_layers(conf, batch_norm=batch_norm)

    def forward(self, x):
        '''Run x through the truncated VGG-19 stack and return the result.'''
        return self.features(x)
        

In [29]:
# Note: a freshly constructed nn.Module already exposes a state_dict (its parameters and buffers), so its keys can be read before any weights are loaded.

In [3]:
def build_encoder(model_file, batch_norm=True):
    '''

    initialize partial VGG-19 model and fill with pre-trained weights

    If `model_file` names an existing file, the encoder's state dict is
    loaded from it. Otherwise the full VGG-19 checkpoint is fetched with
    model_zoo.load_url, the entries matching the encoder's own state-dict
    keys are copied into the encoder, and the encoder's state dict is saved
    to `model_file` (or to "encoder.model" when `model_file` is empty).

    model_file - path of the cached encoder weights; may be None or ''
    batch_norm - selects the 'vgg19_bn' checkpoint and a batch-norm encoder
                 when true, plain 'vgg19' otherwise

    returns the EncoderLayer with weights loaded

    '''
    vgg_type = 'vgg19_bn' if batch_norm else 'vgg19'
    enc = EncoderLayer(batch_norm)

    if model_file and os.path.isfile(model_file):
        enc.load_state_dict(torch.load(model_file))
        return enc

    vgg_weights = model_zoo.load_url(models.vgg.model_urls[vgg_type])

    # Keep only the entries the truncated encoder owns. Keys that the
    # checkpoint does not carry are skipped instead of raising KeyError here;
    # the (strict) load_state_dict call below then decides whether what is
    # missing is acceptable.
    # NOTE(review): this is expected to matter for bookkeeping buffers such
    # as BatchNorm's num_batches_tracked on newer PyTorch -- confirm against
    # the installed versions.
    encoder_weights = {
        key: vgg_weights[key]
        for key in enc.state_dict()
        if key in vgg_weights
    }
    enc.load_state_dict(encoder_weights)

    # Cache the extracted weights so the next call can skip the download.
    torch.save(enc.state_dict(), model_file or "encoder.model")

    return enc

In [5]:
class DecoderLayer(nn.Module):
    '''

    A decoder that converts a set of feature maps to an image.
    AdaIN layer --> decoder --> stylized image

    The reverse of VGG-19, with max-pooling replaced by upsampling.

    '''
    def __init__(self):
        super(DecoderLayer, self).__init__()

        # (n, c) -> n consecutive conv blocks with c output channels
        # 'U'    -> 2x upsampling
        conf = [
            (1, 256),
            'U',
            (3, 256),
            'U',
            (1, 128),
            (1, 64),
            'U',
            (1, 64),
            (1, 3)
        ]

        self.features = self._make_layers(conf)

    def _make_layers(self, conf):
        '''
        Turn the layer description `conf` into an nn.Sequential.

        Every (n_layer, n_feat) entry contributes n_layer blocks of
        reflection padding -> 3x3 convolution -> ReLU; every 'U' entry
        contributes one bilinear 2x upsampling layer. The ReLU after the
        very last convolution is removed.
        '''
        layers = []
        # the first convolution takes 512 input channels
        in_channels = 512
        for block in conf:
            if block == 'U':
                # append the module itself: `layers += module` would try to
                # iterate over the module and raise TypeError
                layers.append(nn.Upsample(scale_factor=2, mode='bilinear'))
                continue

            n_layer, n_feat = block
            for _ in range(n_layer):
                layers += [
                    # padding of 1 keeps the spatial size through the 3x3 conv
                    nn.ReflectionPad2d(1),
                    nn.Conv2d(in_channels, n_feat, kernel_size=3, stride=1),
                    nn.ReLU()
                ]
                in_channels = n_feat

        # drop the trailing ReLU so the network ends on a convolution
        layers.pop()

        return nn.Sequential(*layers)


    def forward(self, x):
        '''Decode the feature maps x and return the result.'''
        return self.features(x)
        

In [7]:
def make_decoder(model_file):
    '''

    build a DecoderLayer and load its trained weights from `model_file`

    model_file - path of a saved decoder state dict

    raises ValueError when `model_file` is empty or is not an existing file

    '''

    dec = DecoderLayer()

    weights_available = bool(model_file) and os.path.isfile(model_file)
    if not weights_available:
        raise ValueError('Decoder model not found.')

    saved_state = torch.load(model_file)
    dec.load_state_dict(saved_state)
    return dec