In [1]:
# necessary imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms 
from torchvision.utils import save_image
from PIL import Image

In [2]:
# going to use the convolutional part (`.features`) of an ImageNet-pretrained VGG19 network.
# torchvision >= 0.13 replaced the deprecated `pretrained=True` flag with the `weights=`
# enum; fall back to the legacy flag on older releases so this cell runs on either.
if hasattr(models, "VGG19_Weights"):
    model = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features
else:
    model = models.vgg19(pretrained=True).features
print(model)

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU(inplace=True)
  (18): MaxPoo

In [3]:
class VGG(nn.Module):
    """Run a sequential feature extractor and collect selected intermediate activations.

    Args:
        layers: indices (int or str) into `backbone` whose outputs are collected.
            The default 0, 5, 10, 19, 28 are, in torchvision's VGG19 `features`,
            the first Conv2d of each of the five convolutional blocks.
        backbone: the sequential module to run. Defaults to the module-level
            `model` loaded earlier in the notebook (the original behaviour).

    NOTE: in VGG19 each of those Conv2d layers is immediately followed by
    `ReLU(inplace=True)`. The collected tensor is the same object that ReLU then
    overwrites in place, so the returned activations hold post-ReLU values, not
    raw convolution outputs.
    """

    def __init__(self, layers=('0', '5', '10', '19', '28'), backbone=None):
        super(VGG, self).__init__()
        # Stored as strings to keep the attribute's original format.
        self.features = [str(idx) for idx in layers]
        self.vgg_model = model if backbone is None else backbone

    def forward(self, x):
        """Pass `x` through the backbone; return the collected activations in layer order."""
        required_layer = []
        for layer_num, layer in enumerate(self.vgg_model):
            x = layer(x)
            if str(layer_num) in self.features:
                required_layer.append(x)
        return required_layer
    

In [4]:
# Use the GPU when one is available; fall back to CPU so the notebook still runs on
# machines without CUDA (a hard-coded 'cuda' fails at the first `.to(device)`).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


In [5]:
def load_image(image_name):
    """Load an image file as a batched tensor on the global `device`.

    Args:
        image_name: path to the image file.

    Returns:
        The global `loader` transform applied to the RGB image, with a leading
        batch dimension added, moved to `device`. With the Resize + ToTensor
        loader defined in this notebook, that is a (1, 3, H, W) float tensor.
    """
    # `with` closes the underlying file handle. convert("RGB") drops alpha channels
    # and expands grayscale, so VGG always receives exactly 3 input channels.
    with Image.open(image_name) as img:
        image = loader(img.convert("RGB"))
    # unsqueeze(0) adds the batch dimension: (batch_size, num_channels, height, width).
    return image.unsqueeze(0).to(device)

In [6]:
image_size = 256  # side length (pixels) that every image passed through `loader` is resized to
loader = transforms.Compose([
    # Use image_size so that changing the constant actually changes the resize.
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),  # PIL image -> float tensor (C, H, W) with values in [0, 1]
    # NOTE: no ImageNet mean/std normalisation is applied here, although torchvision's
    # pretrained models document that they expect normalised inputs.
])

In [7]:
# Paths of the content photo and of the style reference image.
CONTENT_PATH = 'ayush.jpg'
STYLE_PATH = 'style7.jpeg'

content_img = load_image(CONTENT_PATH)
style_img = load_image(STYLE_PATH)

# The generated image is the tensor being optimised. It starts as a copy of the
# content image rather than random noise. clone() leaves content_img untouched
# while gen_img is updated by the optimizer.
gen_img = content_img.clone().requires_grad_(True)

In [8]:
# hyperparameters
num_steps = 1001
learning_rate = 1e-3
alpha = 1
beta = 0.01
optimizer = optim.Adam([gen_img], lr=learning_rate) # our only parameter is generated image, so passed it in           

In [9]:
# defining the model
# .eval() only switches layers such as dropout/batch-norm to inference mode; it does
# NOT stop autograd from tracking the VGG weights. requires_grad_(False) freezes them,
# so backward() computes only the gradient w.r.t. the generated image instead of also
# accumulating unused gradients on every VGG parameter.
nst_model = VGG().to(device).eval().requires_grad_(False)

In [10]:
# implementation
def gram_matrix(feats):
    """Return the (C, C) Gram matrix of a feature map of shape (1, C, H, W).

    `view` reshapes without copying (the tensor must be contiguous). Flattening to
    (C, H*W) only matches the element count when the batch size is 1.
    """
    _, num_channels, height, width = feats.shape
    flat = feats.view(num_channels, height * width)
    return flat.mm(flat.t())


# The content and style images never change, so their features (and the style Gram
# matrices) are loop-invariant: compute them once, without building autograd graphs.
with torch.no_grad():
    content_features = nst_model(content_img)
    style_features = nst_model(style_img)
    style_grams = [gram_matrix(feats) for feats in style_features]

for step in range(1, num_steps):
    gen_features = nst_model(gen_img)
    content_loss = 0
    style_loss = 0

    # One term per collected layer; zip pairs up the activations of the same layer.
    for content_feats, style_gram, gen_feats in zip(content_features, style_grams, gen_features):
        content_loss += torch.mean((content_feats - gen_feats) ** 2)
        style_loss += torch.mean((gram_matrix(gen_feats) - style_gram) ** 2)

    total_loss = alpha * content_loss + beta * style_loss
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    if step % 200 == 0:
        print(f"step {step}: total loss {total_loss.item():.4f}")
        save_image(gen_img, "gen.png")
    

tensor(54411.0391, device='cuda:0', grad_fn=<AddBackward0>)
tensor(28733.1582, device='cuda:0', grad_fn=<AddBackward0>)
tensor(19752.3594, device='cuda:0', grad_fn=<AddBackward0>)
tensor(14869.4121, device='cuda:0', grad_fn=<AddBackward0>)
tensor(11759.9873, device='cuda:0', grad_fn=<AddBackward0>)
