In [1]:
import torch.nn as nn
import torch


def gramMatrix(input):
    """Return the normalised Gram matrix of a batch of feature maps.

    Args:
        input: 4-D tensor unpacked as ``(a, b, c, d)`` (presumably batch, channels,
            height, width). Each of the ``a`` samples is flattened to a
            ``(b, c*d)`` matrix ``F``. Non-contiguous tensors are accepted.

    Returns:
        Tensor of shape ``(a, b, b)`` holding ``F @ F^T / (b * c * d)`` for each sample.
    """
    a, b, c, d = input.size()
    # reshape (not view): view raises on non-contiguous inputs such as transposed
    # feature maps, while reshape still returns a view whenever the layout allows it.
    features = input.reshape(a, b, c * d)
    G = torch.bmm(features, features.transpose(1, 2))
    return G.div(b * c * d)


class GramMatrixLayer(nn.Module):
    """Pass-through layer that records the Gram matrix of the activations it sees.

    Intended to be inserted into a network as a tap. ``forward`` stores
    ``gramMatrix(input)`` on ``self.gramMatrix`` and returns ``input`` unchanged,
    so downstream layers receive the original activations. ``self.gramMatrix`` is
    ``None`` until the first forward pass.
    """

    def __init__(self):
        super(GramMatrixLayer, self).__init__()
        # Initialised here so that reading the attribute before any forward pass
        # yields None instead of raising AttributeError.
        self.gramMatrix = None

    def forward(self, input):
        # Not detached: the stored matrix remains part of input's autograd graph.
        self.gramMatrix = gramMatrix(input)
        return input

In [8]:
def convertModel(cnn, gramMatrixLayers, testing=False):
    """Rebuild the first child of ``cnn`` as a named ``nn.Sequential`` with Gram-matrix taps.

    The layers of ``next(iter(cnn.children()))`` (``features`` for torchvision VGG) are
    re-added under the names ``conv{block}_{n}``, ``relu{block}_{n}`` and ``pool{block}``.
    Here ``block`` counts max-pool stages starting at 1 and ``n`` counts ReLUs within the
    current block. A :class:`GramMatrixLayer` is inserted right after every layer whose
    name is in ``gramMatrixLayers``. Every layer after the last such tap is dropped.

    Conv and pool modules are shared with ``cnn`` (not copied). ReLUs are replaced by
    fresh out-of-place ``nn.ReLU`` instances.

    Args:
        cnn: network whose first child iterates over Conv2d / ReLU / MaxPool2d layers.
        gramMatrixLayers: layer names (e.g. ``'relu1_1'``) to tap.
        testing: if truthy, print ``cnn`` before and the resulting model after conversion.

    Returns:
        ``(cnn, model, gram_matrices)``:
            - the input network;
            - the trimmed ``nn.Sequential`` (empty if no taps were requested);
            - the inserted ``GramMatrixLayer`` instances, in network order.

    Raises:
        RuntimeError: if a layer that is not Conv2d / ReLU / MaxPool2d is encountered.
        ValueError: if a name in ``gramMatrixLayers`` does not match any layer.
    """
    layers = next(iter(cnn.children()))

    if testing:
        print(cnn)

    model = nn.Sequential()
    gram_matrices = []
    tapped = set()  # requested names that matched a layer
    keep = 0        # number of leading modules to keep: up to and including the last tap

    block, n = 1, 1  # block: 1-based max-pool stage; n: 1-based ReLU index within block
    for layer in layers:
        if isinstance(layer, nn.Conv2d):
            name = 'conv{}_{}'.format(block, n)
        elif isinstance(layer, nn.ReLU):
            name = 'relu{}_{}'.format(block, n)
            n += 1
            # Swap in an out-of-place ReLU so no activation observed by an inserted
            # tap (or needed by autograd) can be overwritten in place.
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = 'pool{}'.format(block)
            block += 1
            n = 1
        else:
            raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))

        model.add_module(name, layer)

        if name in gramMatrixLayers:
            # Derive the tap's name from the tapped layer's (unique) name. ReLU taps
            # keep the original 'gram_matrix<block>_<n>' form. Other layers get
            # 'gram_matrix_<name>' so that, e.g., conv1_2 can no longer reuse
            # relu1_1's tap name; add_module would silently replace the earlier tap.
            if name.startswith('relu'):
                tap_name = 'gram_matrix' + name[len('relu'):]
            else:
                tap_name = 'gram_matrix_' + name
            gram_matrix = GramMatrixLayer()
            model.add_module(tap_name, gram_matrix)
            gram_matrices.append(gram_matrix)
            tapped.add(name)
            keep = len(model)

    missing = [name for name in gramMatrixLayers if name not in tapped]
    if missing:
        raise ValueError('Unknown layer name(s) in gramMatrixLayers: {}'.format(missing))

    # Trim everything after the last tap: later layers cannot affect any recorded
    # Gram matrix. With no taps, keep == 0 and the model is empty.
    model = model[:keep]

    if testing:
        print("\n", model)

    return cnn, model, gram_matrices

In [4]:
if __name__ == '__main__':
    # Demo: tap the first ReLU of each of VGG19's five convolutional blocks and
    # print the original network and the converted, trimmed model.
    from torchvision import models

    gramMatrixLayers = ['relu{}_1'.format(stage) for stage in range(1, 6)]
    # NOTE(review): newer torchvision versions deprecate `pretrained=` in favour of
    # `weights=`; kept as-is for the torchvision version this notebook was run with.
    vgg19 = models.vgg19(pretrained=True)
    vgg19, model, gram_matrices = convertModel(vgg19, gramMatrixLayers, testing=True)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (17): ReLU(inplace)

In [5]:
# import torch

# M1 = torch.rand(32, 2, 3)
# print(M1)
# a, b, c = M1.size()
# M1 = M1.view(a*b, c)
# print(M1)
# # print(M1.t())

# print(torch.mm(M1.t(), M1))

# print("\n")

# M2 = torch.rand(64, 3)
# # print(M2)
# # print(M2.t())

# print(torch.mm(M2.t(), M2))

# print("\n")

# M3 = torch.rand(64, 3)
# # print(M3)
# # print(M3.t())

# print(torch.mm(M3.t(), M3))

In [6]:
# # M3 = torch.cat((M1, M2), 0)
# # M4 = torch.cat((torch.rand(0), M3), 0)

# arr = []
# arr.append(M1)
# arr.append(M2)
# arr.append(M3)

# b = torch.Tensor(3, 64, 3)
# torch.stack(arr, out=b)

# # print(b)
# # print(b.transpose(1, 2))

# print(torch.bmm(b.transpose(1, 2), b))
# # print(M1)
# # print(M2)

In [7]:
# print("conv{}_{}".format(1, 2))