In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

import torchvision.transforms as tf
import torchvision.models as models

In [2]:
vgg = models.vgg19(pretrained = True).features

for param in vgg.parameters():
  param.requires_grad_(False)

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/checkpoints/vgg19-dcbb9e9d.pth


HBox(children=(IntProgress(value=0, max=574673361), HTML(value='')))




In [55]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vgg.to(device)

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU(inplace=True)
  (18): MaxPoo

In [0]:
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

In [0]:
def transformation(img):
  tasks = tf.Compose([tf.Resize(256),
                      tf.ToTensor(),
                      tf.Normalize(mean, std)])
  img = tasks(img)
  img = img.unsqueeze(0)

  return img

In [0]:
content_image = Image.open('AndrewNG.png').convert('RGB')
style_image = Image.open('conmug.png').convert('RGB')

In [0]:
content_image = transformation(content_image).to(device)
style_image = transformation(style_image).to(device)

In [60]:
content_image

tensor([[[[-0.3369, -0.3027,  0.0912,  ...,  0.6392,  0.6049,  0.6049],
          [-0.0801, -0.1314, -0.1143,  ...,  0.7591,  0.7248,  0.7077],
          [ 0.2282,  0.1083,  0.0398,  ...,  0.8447,  0.8104,  0.7933],
          ...,
          [-1.3987, -1.3815, -1.3473,  ..., -2.0323, -2.0323, -2.0323],
          [-1.3644, -1.3815, -1.3302,  ..., -2.0323, -2.0323, -2.0323],
          [-1.3473, -1.3644, -1.3130,  ..., -2.0323, -2.0323, -2.0494]],

         [[-0.1625, -0.0924,  0.3978,  ...,  0.9930,  0.9755,  0.9580],
          [ 0.1176,  0.0651,  0.1352,  ...,  1.0980,  1.0630,  1.0455],
          [ 0.4328,  0.3277,  0.2577,  ...,  1.1681,  1.1506,  1.1331],
          ...,
          [-1.2304, -1.1604, -1.1078,  ..., -1.8782, -1.8782, -1.8782],
          [-1.2129, -1.1429, -1.0903,  ..., -1.8782, -1.8782, -1.8782],
          [-1.1954, -1.1253, -1.0728,  ..., -1.8782, -1.8782, -1.8782]],

         [[ 0.0431,  0.0605,  0.5136,  ...,  1.0539,  1.0365,  1.0017],
          [ 0.3219,  0.2522,  

In [61]:
style_image

tensor([[[[ 1.9749,  1.9578,  1.9407,  ...,  1.8208,  1.9749,  1.7523],
          [ 1.9749,  1.9578,  2.0092,  ...,  1.8379,  1.9235,  1.7865],
          [ 2.0263,  2.0605,  1.9749,  ...,  1.7180,  2.1119,  1.8037],
          ...,
          [ 2.2318,  2.2147,  2.2318,  ...,  1.6838,  1.9407,  1.5639],
          [ 2.2147,  2.2318,  2.2318,  ...,  1.6153,  1.8893,  1.5982],
          [ 2.2489,  2.2489,  2.2489,  ...,  1.5468,  1.6153,  1.7009]],

         [[ 0.2227,  0.2752,  0.1176,  ...,  1.0455,  1.9909,  1.9559],
          [ 0.1352,  0.1877,  0.1001,  ...,  0.9580,  1.8859,  2.0259],
          [ 0.1176,  0.2052,  0.0126,  ...,  1.1506,  2.2010,  2.0784],
          ...,
          [ 2.3936,  2.3936,  2.4111,  ...,  1.8683,  2.0959,  1.7633],
          [ 2.4286,  2.4286,  2.4111,  ...,  1.7808,  2.0609,  1.7633],
          [ 2.4286,  2.4286,  2.4286,  ...,  1.7108,  1.7808,  1.8683]],

         [[-0.4798, -0.4450, -0.5495,  ...,  0.4962,  1.8731,  2.1520],
          [-0.4275, -0.3927, -

In [62]:
style_image.shape

torch.Size([1, 3, 376, 256])

In [0]:
def tensor_to_image(tensor):
  image = tensor.clone().detach()
  image = image.cpu().numpy().squeeze()
  image = image.transpose(1, 2, 0)
  image *= np.array(std) + np.array(mean)
  image = image.clip(0, 1)

In [0]:
# img = tensor_to_image(content_image)
# fig = plt.figure()
# plt.imshow(img)

# img = tensor_to_image(style_image)
# fig = plt.figure()
# plt.imshow(img)

In [0]:
layersofinterest = {'0' : 'conv1_1',
                    '5' : 'conv2_1',
                    '10' : 'conv3_1',
                    '19' : 'conv4_1',
                    '21' : 'conv4_2',
                    '28' : 'conv5_1',}

In [0]:
def apply_mode_extract_features (image, model):
  x = image
  features = {}
  for name, layer in model._modules.items():
    x = layer(x)
    if name in layersofinterest:
      features[layersofinterest[name]] = x
  return features

In [0]:
content_image_features = apply_mode_extract_features(content_image, vgg)
style_image_features = apply_mode_extract_features(style_image, vgg)

In [68]:
content_image_features

{'conv1_1': tensor([[[[0.0000, 0.0000, 0.0000,  ..., 0.4361, 0.4103, 0.1228],
           [0.0000, 0.0000, 0.0000,  ..., 0.4680, 0.4446, 0.1953],
           [0.0000, 0.0000, 0.0000,  ..., 0.5333, 0.5173, 0.2591],
           ...,
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],
 
          [[0.7034, 0.0000, 0.0000,  ..., 1.5021, 1.4700, 2.5907],
           [0.9650, 0.6831, 0.0000,  ..., 0.6751, 0.6886, 2.7227],
           [0.3399, 1.4487, 0.7800,  ..., 0.5562, 0.5914, 2.8459],
           ...,
           [3.7197, 0.1366, 0.1136,  ..., 0.3061, 0.2909, 0.0000],
           [3.6423, 0.1395, 0.0837,  ..., 0.3026, 0.2906, 0.0000],
           [4.5844, 2.5878, 2.4487,  ..., 4.1123, 4.1271, 0.0770]],
 
          [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0

In [69]:
style_image_features

{'conv1_1': tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 4.5078e-01,
            1.3376e+00, 1.1989e+00],
           [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 1.3678e-01,
            1.1633e+00, 1.1482e+00],
           [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 5.4463e-01,
            1.3731e+00, 1.2250e+00],
           ...,
           [1.2838e+00, 2.0255e+00, 2.0098e+00,  ..., 1.4379e+00,
            1.4578e+00, 9.9847e-01],
           [1.3080e+00, 2.0160e+00, 2.0166e+00,  ..., 1.2875e+00,
            1.3666e+00, 9.6418e-01],
           [4.4421e-01, 7.8487e-01, 7.8269e-01,  ..., 2.6968e-01,
            3.5411e-01, 1.5045e-01]],
 
          [[0.0000e+00, 1.5688e+00, 1.7595e+00,  ..., 2.6368e-01,
            1.2795e+00, 5.2859e+00],
           [0.0000e+00, 5.0048e-01, 6.3798e-01,  ..., 0.0000e+00,
            0.0000e+00, 5.3492e+00],
           [0.0000e+00, 3.7080e-01, 5.1937e-01,  ..., 0.0000e+00,
            0.0000e+00, 5.1827e+00],
           ...,
           [0.0000e+00, 2.1

In [0]:
def calculate_gram_matrix(tensor):
  _, channels, height, width = tensor.size()
  tensor = tensor.view(channels, height * width)
  gram_matrix = torch.mm(tensor, tensor.t())
  gram_matrix = gram_matrix.div(channels * height * width)
  return gram_matrix

In [71]:
style_features_gram_matrix = {layer: calculate_gram_matrix(style_image_features[layer]) for layer in style_image_features}
style_features_gram_matrix

{'conv1_1': tensor([[0.0081, 0.0045, 0.0075,  ..., 0.0005, 0.0017, 0.0043],
         [0.0045, 0.0255, 0.0055,  ..., 0.0002, 0.0036, 0.0098],
         [0.0075, 0.0055, 0.0151,  ..., 0.0003, 0.0036, 0.0079],
         ...,
         [0.0005, 0.0002, 0.0003,  ..., 0.0003, 0.0002, 0.0003],
         [0.0017, 0.0036, 0.0036,  ..., 0.0002, 0.0064, 0.0035],
         [0.0043, 0.0098, 0.0079,  ..., 0.0003, 0.0035, 0.0167]],
        device='cuda:0'),
 'conv2_1': tensor([[0.0431, 0.0004, 0.0239,  ..., 0.0072, 0.0039, 0.0166],
         [0.0004, 0.0029, 0.0016,  ..., 0.0017, 0.0006, 0.0005],
         [0.0239, 0.0016, 0.0879,  ..., 0.0120, 0.0026, 0.0210],
         ...,
         [0.0072, 0.0017, 0.0120,  ..., 0.0236, 0.0028, 0.0046],
         [0.0039, 0.0006, 0.0026,  ..., 0.0028, 0.0129, 0.0081],
         [0.0166, 0.0005, 0.0210,  ..., 0.0046, 0.0081, 0.0456]],
        device='cuda:0'),
 'conv3_1': tensor([[0.1293, 0.0270, 0.0148,  ..., 0.0180, 0.0546, 0.0082],
         [0.0270, 0.0926, 0.0301,  ..., 

In [0]:
weights = {'conv1_1': 1.0, 'conv2_1': 0.75, 'conv3_1': 0.35,
           'conv4_1': 0.25, 'conv5_1': 0.15}

In [75]:
target = content_image.clone().to(device)
optimizer = optim.Adam([target], lr = 0.003)
target.requires_grad_(True)

tensor([[[[-0.3369, -0.3027,  0.0912,  ...,  0.6392,  0.6049,  0.6049],
          [-0.0801, -0.1314, -0.1143,  ...,  0.7591,  0.7248,  0.7077],
          [ 0.2282,  0.1083,  0.0398,  ...,  0.8447,  0.8104,  0.7933],
          ...,
          [-1.3987, -1.3815, -1.3473,  ..., -2.0323, -2.0323, -2.0323],
          [-1.3644, -1.3815, -1.3302,  ..., -2.0323, -2.0323, -2.0323],
          [-1.3473, -1.3644, -1.3130,  ..., -2.0323, -2.0323, -2.0494]],

         [[-0.1625, -0.0924,  0.3978,  ...,  0.9930,  0.9755,  0.9580],
          [ 0.1176,  0.0651,  0.1352,  ...,  1.0980,  1.0630,  1.0455],
          [ 0.4328,  0.3277,  0.2577,  ...,  1.1681,  1.1506,  1.1331],
          ...,
          [-1.2304, -1.1604, -1.1078,  ..., -1.8782, -1.8782, -1.8782],
          [-1.2129, -1.1429, -1.0903,  ..., -1.8782, -1.8782, -1.8782],
          [-1.1954, -1.1253, -1.0728,  ..., -1.8782, -1.8782, -1.8782]],

         [[ 0.0431,  0.0605,  0.5136,  ...,  1.0539,  1.0365,  1.0017],
          [ 0.3219,  0.2522,  

In [0]:
# plt.figure()
# plt.imshow(tensor_to_image(target))

In [0]:
for i in range(1, 2000):
  target_features = apply_mode_extract_features(target, vgg)
  content_loss = F.mse_loss(target_features['conv4_2'], content_image_features['conv4_2'])
  style_loss = 0

  for layer in weights:
    target_feature = target_features[layer]

    target_gram_matrix = calculate_gram_matrix(target_feature)
    style_gram_matrix = style_features_gram_matrix[layer]

    layer_loss = F.mse_loss(target_gram_matrix, style_gram_matrix)
    layer_loss *= weights[layer]

    _, channels, height, width = target_feature.shape

    style_loss += layer_loss
  
  total_loss = 1000000 * style_loss + content_loss

  if i % 50 == 0:
    f"Epoch {i}:  Style Loss: {style_loss:.4f}  Content Loss: {content_loss:.4f}"

  optimizer.zero_grad()
  total_loss.backward()
  optimizer.step()

In [0]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize = (20, 10))
ax1.imshow(tensor_to_image(content_image))
ax2.imshow(tensor_to_image(target))