In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.utils import save_image
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
class customvgg19(nn.Module):
  """Frozen, pretrained VGG19 feature extractor for the style-transfer losses.

  Wraps the first 29 layers (indices 0-28) of ``vgg19().features`` and, in
  ``forward``, collects the activations of the layers listed in
  ``chosen_features``.
  """

  def __init__(self):
    super().__init__()
    # Indices into vgg19().features whose outputs forward() collects; in the
    # printed architecture these are Conv2d layers.
    self.chosen_features = [0, 5, 10, 19, 28]
    # Request the ImageNet weights explicitly: `weights=True` is only accepted
    # through torchvision's deprecated legacy-`pretrained` shim (with a warning).
    self.model = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features[:29]
    # The network is used as a fixed feature extractor; only the generated image
    # is optimized. Freezing the weights stops every backward() from computing
    # gradients for all VGG parameters and piling them up in .grad (the
    # optimizer never zeroes them), and eval() makes the inference use explicit.
    self.model.requires_grad_(False)
    self.model.eval()

  def forward(self, x):
    """Run ``x`` through the truncated VGG19 and return the chosen activations.

    Args:
      x: image batch in the normalized space produced by ``loader``
        (presumably (N, 3, H, W) -- TODO confirm for other callers).

    Returns:
      List with one tensor per index in ``chosen_features``, in layer order.
    """
    features = []
    for layer_num, layer in enumerate(self.model):
      x = layer(x)
      if layer_num in self.chosen_features:
        # NOTE(review): the layer following a chosen Conv2d is ReLU(inplace=True)
        # (visible for 0, 5 and 10 in the printout; presumably also 19). It
        # overwrites this same tensor on the next iteration, so those stored
        # activations end up post-ReLU. Index 28 is the last layer kept, so its
        # activation stays pre-ReLU.
        features.append(x)
    return features

In [None]:
# `device` was previously only defined in a later cell, so this cell raised
# NameError on a fresh kernel (Restart & Run All). Define it before first use;
# note the call parentheses on is_available().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
custom = customvgg19().to(device)
custom



customvgg19(
  (model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1),

In [None]:
# Pick the GPU when one is usable, otherwise fall back to the CPU.
# `is_available` must be *called*: the bare function object is always truthy,
# which would select "cuda" even on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [None]:
# Preprocessing shared by every image fed to VGG19.
img_size = 256
# Per-channel RGB (mean, std): the ImageNet normalization documented for
# torchvision's pretrained models.
stats = (
    (0.485, 0.456, 0.406),
    (0.229, 0.224, 0.225),
)
loader = transforms.Compose(
    [
        transforms.Resize(img_size),      # shorter side -> img_size
        transforms.CenterCrop(img_size),  # square img_size x img_size crop
        transforms.ToTensor(),            # PIL image -> float tensor, channels first
        transforms.Normalize(mean=stats[0], std=stats[1]),
    ]
)

In [None]:
def load_image(image_name):
  """Load an image file as a normalized, batched tensor on ``device``.

  Args:
    image_name: path (or file object) accepted by ``PIL.Image.open``.

  Returns:
    Tensor of shape (1, 3, img_size, img_size) produced by ``loader`` and
    moved to ``device``.
  """
  # The context manager releases the file handle. Forcing RGB matters because
  # grayscale, palette or RGBA files would otherwise become 1- or 4-channel
  # tensors that the 3-channel Normalize in `loader` (and VGG19) reject.
  with Image.open(image_name) as img:
    rgb = img.convert("RGB")
  return loader(rgb).unsqueeze(0).to(device)

In [None]:
# Load both images through the same pipeline so they share size and normalization.
content_img = load_image("content.jpg")
style_img = load_image("style_img.jpg")
# Optimization starts from a copy of the content image, which must track
# gradients so its pixels can be updated.
generated_img = content_img.clone()
generated_img.requires_grad_(True)
content_img.shape

torch.Size([1, 3, 256, 256])

In [None]:
# Preview the content image. `content_img` comes from `load_image`, i.e. a
# normalized (1, 3, H, W) tensor on `device`, so undo Normalize per channel,
# drop the batch axis, move channels last and bring it to the CPU for imshow.
plt.imshow(
    (content_img[0].detach().cpu()
     * torch.tensor(stats[1]).view(3, 1, 1)
     + torch.tensor(stats[0]).view(3, 1, 1)).clamp(0, 1).permute(1, 2, 0).numpy()
)
plt.axis("off");

In [None]:
# Optimization hyper-parameters.
epochs = 6000  # number of optimizer steps (the training loop runs epochs + 1)
lr = 0.01      # Adam learning rate
# Loss weights: total = alpha * content_loss + beta * style_loss.
alpha, beta = 1, 0.01
# Adam updates only the pixels of generated_img.
optimizer = optim.Adam(params=[generated_img], lr=lr)

In [None]:
def gram_matrix(matrix, channel=None, height=None, width=None):
  """Return the (channel x channel) Gram matrix of a single feature map.

  Args:
    matrix: feature tensor holding ``channel * height * width`` elements,
      e.g. a (1, C, H, W) activation (a batch size above 1 does not fit).
    channel, height, width: feature-map sizes; any left as ``None`` is read
      from the last three dimensions of ``matrix``.

  Returns:
    ``F @ F.T`` where ``F`` is ``matrix`` flattened to (channel, height*width).
    The result is not normalized by the number of elements.
  """
  if channel is None or height is None or width is None:
    inferred_c, inferred_h, inferred_w = matrix.shape[-3:]
    channel = inferred_c if channel is None else channel
    height = inferred_h if height is None else height
    width = inferred_w if width is None else width
  # reshape (not view) so non-contiguous feature maps are handled too.
  flat = matrix.reshape(channel, height * width)
  return flat.mm(flat.t())

In [None]:
def denorm(img, mean=None, std=None):
  """Invert ``transforms.Normalize`` channel by channel: ``img * std + mean``.

  Args:
    img: normalized image tensor whose channel axis is the third from last,
      e.g. (3, H, W) or (N, 3, H, W).
    mean, std: per-channel sequences; default to ``stats[0]`` / ``stats[1]``,
      the values ``loader`` normalizes with.

  Returns:
    Tensor with the same shape, dtype and device as ``img``.
  """
  mean = stats[0] if mean is None else mean
  std = stats[1] if std is None else std
  # Shape (C, 1, 1) so each channel gets its own mean/std via broadcasting,
  # rather than reusing channel 0's values for every channel.
  mean = torch.as_tensor(mean, dtype=img.dtype, device=img.device).view(-1, 1, 1)
  std = torch.as_tensor(std, dtype=img.dtype, device=img.device).view(-1, 1, 1)
  return img * std + mean
def show_image(img):
  """Display a normalized image tensor with matplotlib.

  Args:
    img: tensor in the normalized space produced by ``loader``, either
      (3, H, W) or a single-image batch (1, 3, H, W); it may live on the GPU
      and may require grad (e.g. ``generated_img``).

  Raises:
    ValueError: if given a batch with more than one image.
  """
  # imshow needs a CPU (H, W, C) array: detach from autograd, move to the CPU,
  # drop the batch axis, undo the normalization, clip to the displayable
  # [0, 1] range and move channels last.
  picture = img.detach().cpu()
  if picture.dim() == 4:
    if picture.shape[0] != 1:
      raise ValueError(f"show_image expects a single image, got a batch of {picture.shape[0]}")
    picture = picture[0]
  picture = denorm(picture).clamp(0, 1).permute(1, 2, 0)
  plt.imshow(picture.numpy())
  plt.axis("off")

In [None]:
# The content features and style Gram matrices never change during
# optimization, so compute them once, outside the autograd graph, instead of
# re-running VGG on both reference images at every step.
with torch.no_grad():
  content_features = custom(content_img)
  style_grams = [
      gram_matrix(s_feature, s_feature.shape[1], s_feature.shape[2], s_feature.shape[3])
      for s_feature in custom(style_img)
  ]

for epoch in range(epochs+1):
  generated_features = custom(generated_img)

  style_loss = content_loss = 0
  for g_feature, c_feature, style_gram in zip(generated_features, content_features, style_grams):
    _, channel, height, width = g_feature.shape
    content_loss += torch.mean((g_feature - c_feature)**2)

    gen_gram = gram_matrix(g_feature, channel, height, width)
    style_loss += torch.mean((gen_gram - style_gram)**2)

  total_loss = alpha*content_loss + beta*style_loss
  optimizer.zero_grad()
  total_loss.backward()
  optimizer.step()

  if epoch % 100 == 0:
    print(f"total loss for epoch {epoch}: {total_loss.item()}")

# save_image expects pixel values in [0, 1], but generated_img lives in the
# normalized space produced by `loader`; undo Normalize per channel first.
with torch.no_grad():
  img_mean = torch.tensor(stats[0], device=generated_img.device).view(1, -1, 1, 1)
  img_std = torch.tensor(stats[1], device=generated_img.device).view(1, -1, 1, 1)
  save_image((generated_img * img_std + img_mean).clamp(0, 1), "generated.png")



total loss for epoch 0: 2481159.25
total loss for epoch 100: 314996.5625
total loss for epoch 200: 151625.28125
total loss for epoch 300: 88155.828125
total loss for epoch 400: 61226.46875
total loss for epoch 500: 48395.82421875
total loss for epoch 600: 41146.5390625
total loss for epoch 700: 36768.80859375
total loss for epoch 800: 33611.3359375
total loss for epoch 900: 31242.53125
total loss for epoch 1000: 29381.576171875
total loss for epoch 1100: 27959.306640625
total loss for epoch 1200: 26846.6171875
total loss for epoch 1300: 25740.287109375
total loss for epoch 1400: 24909.056640625
total loss for epoch 1500: 24062.900390625
total loss for epoch 1600: 23531.14453125
total loss for epoch 1700: 23050.958984375
total loss for epoch 1800: 22383.439453125
total loss for epoch 1900: 22072.666015625
total loss for epoch 2000: 21471.36328125
total loss for epoch 2100: 21075.458984375
total loss for epoch 2200: 21603.998046875
total loss for epoch 2300: 20438.123046875
total loss fo