In [1]:
from ipynb.fs.full.Helper import upperTriangle

import torch.nn as nn
import torch

import math


def gramMatrix(input, weight):
    """Return the weighted, normalised Gram matrix of a 4-D feature tensor.

    Args:
        input: 4-D tensor; its size is unpacked as ``(a, b, c, d)``
            (presumably batch, channels, height, width -- confirm with callers).
        weight: non-negative scalar. Its square root is applied to the result
            (presumably so that a squared error taken on the output ends up
            scaled by ``weight`` -- confirm against the loss code).

    Returns:
        Tensor with ``a`` rows: for each sample, the output of
        ``upperTriangle`` applied to the ``b x b`` Gram matrix, flattened,
        divided by ``b * c * d`` and multiplied by ``sqrt(weight)``.

    Raises:
        ValueError: if ``weight`` is negative (raised by ``math.sqrt``) or if
            ``input`` does not have exactly four dimensions (unpacking fails).
    """
    scale = math.sqrt(weight)
    a, b, c, d = input.size()
    # reshape() returns a view whenever view() would have worked, but also
    # accepts non-contiguous inputs (permuted / channels_last tensors) on
    # which view() raises a RuntimeError.
    features = input.reshape(a, b, c * d)
    # Batched product: one (b x b) matrix of feature inner products per sample.
    G = torch.bmm(features, features.transpose(1, 2))
    G = upperTriangle(G)
    G = G.reshape(a, -1)
    return scale * G.div(b * c * d)


class GramMatrixLayer(nn.Module):
    """Pass-through module that records the weighted Gram matrix of its input.

    The layer does not alter the data flowing through it: ``forward`` returns
    its argument unchanged and, as a side effect, stores the result of
    ``gramMatrix(input, weight)`` on the instance.

    Args:
        weight: scalar forwarded to ``gramMatrix`` on every forward pass.
    """

    def __init__(self, weight):
        super(GramMatrixLayer, self).__init__()
        self.weight = weight
        # Most recent Gram matrix; None until forward() has been called once,
        # so reading it early does not raise AttributeError.
        self.gramMatrix = None

    def __repr__(self):
        return "GramMatrixLayer(λ=" + str(self.weight) + ")"

    def forward(self, input):
        """Store the Gram matrix of ``input`` and return ``input`` unchanged."""
        # The stored matrix never takes part in back-propagation, so compute
        # it without recording an autograd graph instead of building the
        # graph and detaching afterwards.
        with torch.no_grad():
            self.gramMatrix = gramMatrix(input, self.weight)
        return input

In [2]:
def convertModel(cnn, gramMatrixLayers, gramMatrixWeights, testing=False):
    """Rebuild the first child of ``cnn`` with GramMatrixLayer taps inserted.

    The first child module of ``cnn`` is iterated layer by layer. Each layer
    is given a name (``conv{block}_{n}``, ``relu{block}_{n}``, ``pool{block}``;
    the block number increases after every max-pool) and copied into a new
    ``nn.Sequential``. After every layer whose name appears in
    ``gramMatrixLayers`` a ``GramMatrixLayer`` is inserted. Layers following
    the last inserted ``GramMatrixLayer`` are dropped.

    Args:
        cnn: module whose first child is an iterable of Conv2d / ReLU /
            MaxPool2d layers.
        gramMatrixLayers: ordered sequence of layer names to tap.
        gramMatrixWeights: weights parallel to ``gramMatrixLayers``;
            ``gramMatrixWeights[k]`` is used for ``gramMatrixLayers[k]``
            regardless of the order in which the layers occur in the network.
        testing: when truthy, print ``cnn`` and the resulting model.

    Returns:
        Tuple ``(cnn, model, gram_matrices)``: the unchanged input network,
        the trimmed ``nn.Sequential`` and the inserted ``GramMatrixLayer``
        objects in network order.

    Raises:
        RuntimeError: if a layer is not Conv2d, ReLU or MaxPool2d.
        IndexError: if a tapped layer has no corresponding weight.
    """
    layers = next(iter(cnn.children()))

    if testing:
        print(cnn)

    # Pair each requested layer name with its weight by list position, so the
    # pairing does not depend on where the layer sits in the network. The
    # first occurrence wins if a name is listed more than once.
    weight_for = {}
    for layer_name, layer_weight in zip(gramMatrixLayers, gramMatrixWeights):
        weight_for.setdefault(layer_name, layer_weight)

    model = nn.Sequential()
    gram_matrices = []

    block = 1     # incremented after every max-pool
    position = 1  # conv/relu index inside the current block
    for layer in layers:
        # gram_name is fixed here, before the counters move on, so that it
        # always refers to the layer being tapped. ReLU taps keep the
        # historical "gram_matrix{block}_{n}" form; conv / pool taps embed the
        # full layer name so they can never collide with one another (a
        # duplicate name passed to add_module silently replaces the module).
        if isinstance(layer, nn.Conv2d):
            name = 'conv{}_{}'.format(block, position)
            gram_name = 'gram_matrix_' + name
        elif isinstance(layer, nn.ReLU):
            name = 'relu{}_{}'.format(block, position)
            gram_name = 'gram_matrix{}_{}'.format(block, position)
            position += 1
            # Use an out-of-place ReLU so the activation does not overwrite
            # the tensor handed to it by the preceding layer.
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = 'pool{}'.format(block)
            gram_name = 'gram_matrix_' + name
            block += 1
            position = 1
        else:
            raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))

        model.add_module(name, layer)

        if name in gramMatrixLayers:
            if name not in weight_for:
                raise IndexError('No weight supplied for layer: {}'.format(name))
            gram_matrix = GramMatrixLayer(weight_for[name])
            model.add_module(gram_name, gram_matrix)
            gram_matrices.append(gram_matrix)

    # Trim everything after the last GramMatrixLayer. If none was inserted the
    # index stays at 0 and only the first layer is kept (unchanged behaviour).
    last_gram = 0
    for pos in range(len(model) - 1, -1, -1):
        if isinstance(model[pos], GramMatrixLayer):
            last_gram = pos
            break

    model = model[:(last_gram + 1)]

    if testing:
        print("\n", model)

    return cnn, model, gram_matrices

In [3]:
if __name__ == '__main__':
    # Demo, run only when the notebook is executed directly (not when it is
    # imported as a module): tap the first ReLU of each of the five VGG-19
    # blocks with unit weight and print the network before and after.
    import torchvision.models as models

    gramMatrixLayers = ['relu{}_1'.format(block) for block in range(1, 6)]
    gramMatrixWeights = [1] * len(gramMatrixLayers)
    vgg19, model, gram_matrices = convertModel(
        models.vgg19(pretrained=True),
        gramMatrixLayers,
        gramMatrixWeights,
        testing=True,
    )

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (17): ReLU(inplace)