In [None]:
import torch
import torch.nn as nn
import torchvision
from torchvision import models, transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt 
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def get_image(path, img_transform , size=(720,1280)):
  """Load an image file, resize it and return it as a batched tensor on `device`.

  Args:
    path: Location of the image file; passed straight to `Image.open`.
    img_transform: Callable applied to the resized PIL image. It must return a
      tensor, because the result is `unsqueeze`d and moved with `.to(device)`.
    size: Target size handed to `Image.resize`. PIL interprets this tuple as
      (width, height), so the default produces an image 720 wide and 1280 tall.

  Returns:
    The transformed image with a new leading dimension of size 1, placed on
    the module-level `device`.
  """
  # The context manager releases the file handle once the pixels are copied.
  with Image.open(path) as raw:
    # Force exactly three channels so a 3-channel normalisation downstream
    # does not fail on grayscale, palette or RGBA files.
    image = raw.convert('RGB')
  image = image.resize(size, Image.LANCZOS)
  # Add the batch axis, e.g. C,H,W -> 1,C,H,W for a ToTensor-based transform.
  image = img_transform(image).unsqueeze(0)
  return image.to(device)

def get_gram(m): # m => batch size , C , H , W
  """Return the (C, C) Gram matrix of a single feature map.

  Args:
    m: 4-D tensor laid out as (batch, C, H, W). The batch dimension must be 1,
      because the data is flattened into a single (C, H*W) matrix.

  Returns:
    A (C, C) tensor: the flattened feature matrix multiplied by its transpose.

  Raises:
    RuntimeError: if the element count does not fit a (C, H*W) matrix, i.e.
      when the batch size is not 1.
  """
  _, c, h, w = m.size()
  # reshape (unlike view) also accepts non-contiguous inputs; it only copies
  # when a zero-copy view is impossible.
  flat = m.reshape(c, h*w)
  return torch.mm(flat, flat.t())

def denorm_image(inp):
  """Undo the per-channel mean/std normalisation and return a displayable array.

  Args:
    inp: torch tensor laid out as (C, H, W) with C == 3, normalised as
      (x - mean) / std using the mean/std values listed below.

  Returns:
    A NumPy array of shape (H, W, C) with every value clipped to [0, 1].
  """
  # detach()/cpu() are no-ops for a plain CPU tensor but let the function also
  # accept tensors that still require grad or live on the GPU.
  inp = inp.detach().cpu().numpy().transpose((1,2,0)) # C, H, W -> H, W, C
  mean = np.array([0.485 , 0.456 , 0.406])
  std = np.array([0.229 , 0.224 , 0.225])
  inp = inp*std + mean   # reverse of (x - mean) / std
  # np.clip(a, lo, hi). The previous `inp.clip(inp,0,1)` passed the array as
  # its own lower bound plus two more positionals, which raised a TypeError.
  return np.clip(inp, 0, 1)

class FeatureExtractor(nn.Module):
  """Runs an input through VGG16's `features` stack and collects the
  activations produced at the layer indices listed in `selected_layers`."""

  def __init__(self):
    super(FeatureExtractor,self).__init__()
    # Indices into `vgg16().features`. In torchvision's VGG16 layout these
    # appear to be the last ReLU of each of the first four conv blocks
    # (confirm with `print(self.vgg)`).
    self.selected_layers = [3,8,15,22]
    # NOTE(review): newer torchvision releases deprecate `pretrained=` in
    # favour of `weights=`; kept as-is for the version this notebook ran on.
    self.vgg = models.vgg16(pretrained=True).features

  def forward(self, x):
    """Return the activations of the selected layers, in layer order.

    Args:
      x: Input tensor fed to the first layer of `self.vgg`.

    Returns:
      A list with one tensor per index in `self.selected_layers` that exists
      in `self.vgg` (empty when no layer is selected).
    """
    layer_feats = []
    if not self.selected_layers:
      return layer_feats
    wanted = set(self.selected_layers)
    last_needed = max(wanted)
    for layer_num, layer in enumerate(self.vgg):
      x = layer(x)
      if layer_num in wanted:
        layer_feats.append(x)
      # Nothing after the deepest selected layer is ever returned, so stop
      # instead of paying for the remaining layers.
      if layer_num >= last_needed:
        break
    return layer_feats



# Preprocessing: PIL image -> float tensor -> per-channel normalisation.
img_transform = transforms.Compose(
  [
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485 , 0.456 , 0.406) , std=(0.229 , 0.224 , 0.225)),
  ]
)
content_image = get_image('/content/drive/MyDrive/NeuralStyleTransfer/content.jpg', img_transform)
style_image = get_image('/content/drive/MyDrive/NeuralStyleTransfer/style.jpg', img_transform)

# The pixels being optimised start out as a copy of the content image.
generated_image = content_image.clone()
generated_image.requires_grad = True
opt = torch.optim.Adam([generated_image] , lr = 0.003 , betas = (0.5, 0.999))

# The encoder is only a fixed feature extractor: freeze all of its weights.
encoder = FeatureExtractor().to(device)
for p in encoder.parameters():
  p.requires_grad = False

style_weight = 100
cont_weight = 1

# The content/style images and the encoder never change inside the loop, so
# their features (and the style Gram matrices) are computed once up front
# instead of on every one of the 5000 iterations.
with torch.no_grad():
  content_features = encoder(content_image)
  style_features = encoder(style_image)
  style_grams = [get_gram(sf) for sf in style_features]

for epoch in range(5000):
  generated_features = encoder(generated_image)

  # Content loss: MSE between the deepest selected activations.
  cont_loss = torch.mean((content_features[-1] - generated_features[-1])**2)

  # Style loss: MSE between Gram matrices at every selected layer, each term
  # scaled down by that layer's c*h*w.
  style_loss = 0
  for gf, gram_sf in zip(generated_features, style_grams):
    _,c,h,w = gf.size()
    gram_gf = get_gram(gf)
    style_loss += torch.mean((gram_gf - gram_sf)**2)/(c*h*w)

  loss = cont_weight*cont_loss + style_weight*style_loss
  opt.zero_grad()
  loss.backward()
  opt.step()
  if epoch%100 == 0:
    print(f"Epoch is {epoch} , Content Loss is {cont_loss} , Style Loss is {style_loss} , Total Loss is {loss}")
    print("==========================================================================================================")



Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth


  0%|          | 0.00/528M [00:00<?, ?B/s]

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Epoch is 0 , Content Loss is 0.0 , Style Loss is 5281.4091796875 , Total Loss is 528140.9375
Epoch is 100 , Content Loss is 0.994583249092102 , Style Loss is 971.04736328125 , Total Loss is 97105.7265625
Epoch is 200 , Content Loss is 1.210949182510376 , Style Loss is 450.52630615234375 , Total Loss is 45053.83984375
Epoch is 300 , Content Loss is 1.334060788154602 , Style Loss is 290.0861511230469 , Total Loss is 29009.94921875
Epoch is 400 , Content Loss is 1.4217565059661865 , Style Loss is 202.24827575683594 , Total Loss is 20226.25
Epoch is 500 , Content Loss is 1.492169737815857 , Style Loss is 152.50277709960938 , Total Loss is 15251.76953125
Epoch is 600 , Content Loss is 1.5491524934768677 , Style Loss is 122.091064453125 , Total Loss is 12210.6552734375
Epoch is 700 , Content Loss is 1.5977041721343994 , Style Loss is 101.45652770996094 , Total Loss is 10147.25
Epoch is 800 , Content Loss is 1.6428710222244263 , Style Loss is 86.38429260253906 , Total Loss is 8640.072265625
E

In [None]:
# Detach from the graph, move to the CPU and drop the size-1 dimensions, then
# undo the normalisation so matplotlib can render the result.
inp = denorm_image(generated_image.detach().cpu().squeeze())
plt.imshow(inp)

TypeError: ignored