### Import Libraries




In [26]:
import torchvision.models as models
from torchvision import transforms
import torch
import torch.nn as nn
from PIL import Image
from torchsummary import summary
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim

In [27]:
# Target height and width handed to the resize transform.
IMG_HT, IMG_WT = 256, 256

# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


### Load the pretrained VGG16 feature extractor

In [28]:
# ImageNet-pretrained VGG16 moved to the compute device; only its convolutional
# `features` stack is used below.
# NOTE(review): torchvision >= 0.13 deprecates `pretrained=True` in favour of
# `weights=...`; switch once the minimum supported torchvision allows it.
vgg16 = models.vgg16(pretrained=True).to(device)
vgg_features = vgg16.features

# Freeze every weight: the optimisers below only ever update the generated image.
vgg_features.requires_grad_(False)



### Load the style and content images

In [29]:
# Preprocessing shared by the style and content images: resize to a fixed
# IMG_HT x IMG_WT, then convert the PIL image into a tensor.
transform = transforms.Compose(
    [
        transforms.Resize(size=(IMG_HT, IMG_WT)),
        transforms.ToTensor(),
    ]
)

In [30]:
def load_rgb_image(path):
    """Open the image at *path*, force it to RGB and apply ``transform``.

    The file is opened in a context manager so its handle is released once the
    pixels are decoded. ``convert('RGB')`` guarantees a 3-channel result even for
    grayscale, palette or RGBA files, which would otherwise produce a tensor with
    the wrong number of channels for the network and for the RGB preview.
    """
    with Image.open(path) as img:
        return transform(img.convert('RGB'))


style_image = load_rgb_image('wave.jpg')
content_image = load_rgb_image('chicago.jpg')

In [None]:
# Show the style and content images side by side, without axis ticks.
fig, ax = plt.subplots(nrows=1, ncols=2)
for panel, picture in zip(ax, (style_image, content_image)):
    # The tensors are channel-first; imshow wants channels last.
    panel.imshow(picture.permute(1, 2, 0))
    panel.axis('off')

In [32]:
# Prepend a batch axis of size 1 and move both images to the compute device.
style_image = style_image.unsqueeze(0).to(device)
content_image = content_image.unsqueeze(0).to(device)

### Register forward hooks that store activations

In [33]:
# Shared buffer filled by the forward hooks below. The loss helpers clear() and
# refill this same list object, so it must not be rebound after the hooks exist.
activation_store = []


def store_activations(store_list):
    """Return a forward hook that appends each module output to ``store_list``."""

    def hook(module, inputs, output):
        # Forward hooks are called as (module, inputs, output); only the output is kept.
        store_list.append(output)

    return hook


# Hook every Conv2d so a single forward pass appends one activation per conv
# layer, in network order.
# NOTE(review): torchvision's VGG `features` appear to follow each conv with
# ReLU(inplace=True); if so, each stored tensor is overwritten in place with its
# ReLU'd value. Confirm before relying on pre-activation values.
for layer in vgg_features:
    if isinstance(layer, nn.Conv2d):
        layer.register_forward_hook(store_activations(activation_store))

# Reference passes: content activations are appended first, then style ones.
# These are fixed targets, so no autograd graph is needed.
with torch.no_grad():
    content_out = vgg_features(content_image)
    style_out = vgg_features(style_image)

In [34]:
print(len(activation_store))

# Each reference pass appended one activation per hooked Conv2d layer (content
# first, then style), so split the store at the conv-layer count rather than at
# a hard-coded index.
num_conv_layers = sum(isinstance(layer, nn.Conv2d) for layer in vgg_features)
if len(activation_store) != 2 * num_conv_layers:
    # e.g. this cell was re-run after a loss helper cleared and refilled the store
    # with a single pass, which would otherwise silently corrupt both splits.
    raise RuntimeError(
        f"expected {2 * num_conv_layers} stored activations (content + style), "
        f"got {len(activation_store)}; re-run the hook/reference cell first"
    )
content_store = activation_store[:num_conv_layers]
style_store = activation_store[num_conv_layers:]

26


#### Content Loss

In [35]:
def get_content_loss(generation_img, num_layer, activation_store):
    """Return the MSE between generated and content activations at one conv index.

    Runs ``generation_img`` through the module-level ``vgg_features`` and
    compares the activation recorded at position ``num_layer`` with the same
    position in the module-level ``content_store``.

    Args:
        generation_img: image tensor being optimised; fed to ``vgg_features``.
        num_layer: index into the per-pass list of hooked conv activations.
        activation_store: the list the forward hooks append to. It is cleared
            and refilled by the forward pass, so it must be that same list object.

    Returns:
        Scalar tensor holding the mean squared error.
    """
    content_act = content_store[num_layer]

    # Drop stale activations so index ``num_layer`` refers to this forward pass.
    activation_store.clear()
    vgg_features(generation_img)  # return value unused; the hooks fill the store
    generated_act = activation_store[num_layer]

    return nn.functional.mse_loss(generated_act, content_act)


In [None]:
# Reconstruct the content image from noise using a single conv layer at a time,
# to visualise what each layer's activations retain.
for num_layer in range(1, 10):
    # Start from half-normal noise and optimise the pixels directly.
    generation = torch.abs(torch.randn(1, 3, IMG_HT, IMG_WT).to(device))
    generation.requires_grad = True
    optimizer = optim.Adam([generation], lr=0.1)
    with torch.no_grad():
        generation.clamp_(0, 1)

    for _ in range(500):
        content_loss = get_content_loss(generation, num_layer, activation_store)

        optimizer.zero_grad()
        content_loss.backward()
        optimizer.step()
        # Project back into [0, 1] after every step (no_grad instead of the
        # discouraged `.data`), so the image displayed below is also in range.
        with torch.no_grad():
            generation.clamp_(0, 1)

    plt.imshow(generation[0].detach().cpu().permute(1, 2, 0))
    plt.title(f"content layer {num_layer}")
    plt.show()


#### Style Loss

In [38]:
def get_gram_matrix(feature_matrix):
    """Return the normalised Gram matrix of a 4-D feature map.

    The ``(b, c, h, w)`` input is flattened to ``(b*c, h*w)``, multiplied by its
    own transpose, and divided by the total element count ``b*c*h*w``.

    Note: batch and channel axes are flattened together, so for ``b > 1`` the
    result is one ``(b*c, b*c)`` matrix that also contains cross-sample terms;
    callers in this notebook pass ``b == 1``.
    """
    b, c, h, w = feature_matrix.shape
    # reshape (not view) so non-contiguous inputs, e.g. permuted tensors, work too.
    features = feature_matrix.reshape(b * c, h * w)
    return torch.mm(features, features.T) / (b * c * h * w)


def get_style_loss(generation_image, num_layers, activation_store):
    """Return the mean Gram-matrix MSE between generated and style activations.

    Runs ``generation_image`` through the module-level ``vgg_features``. For each
    hooked conv index ``1 .. num_layers - 1`` (index 0 is skipped), it compares
    the Gram matrix of the generated activation with that of the matching entry
    in the module-level ``style_store``, then averages the per-layer losses.

    Args:
        generation_image: image tensor being optimised; fed to ``vgg_features``.
        num_layers: exclusive upper bound on the conv indices used; must be >= 2.
        activation_store: the list the forward hooks append to. It is cleared
            and refilled by the forward pass, so it must be that same list object.

    Returns:
        Scalar tensor: the style loss averaged over ``num_layers - 1`` layers.

    Raises:
        ValueError: if ``num_layers < 2``, since no layer would be used.
    """
    if num_layers < 2:
        raise ValueError(
            f"num_layers must be >= 2 so at least one layer is used, got {num_layers}"
        )

    # Drop stale activations so the indices below refer to this forward pass.
    activation_store.clear()
    vgg_features(generation_image)  # return value unused; the hooks fill the store

    loss = 0
    for layer in range(1, num_layers):
        gen_gram = get_gram_matrix(activation_store[layer])
        style_gram = get_gram_matrix(style_store[layer])
        loss += nn.functional.mse_loss(gen_gram, style_gram)
    return loss / (num_layers - 1)

In [None]:
# Synthesise textures from noise using the style loss over a growing number of
# conv layers (indices 1 .. num_layer - 1).
for num_layer in range(2, 10):
    # Start from half-normal noise and optimise the pixels directly.
    generation = torch.abs(torch.randn(1, 3, IMG_HT, IMG_WT).to(device))
    generation.requires_grad = True
    optimizer = optim.Adam([generation], lr=0.1)
    with torch.no_grad():
        generation.clamp_(0, 1)

    for _ in range(500):
        style_loss = get_style_loss(generation, num_layer, activation_store)

        optimizer.zero_grad()
        style_loss.backward()
        optimizer.step()
        # Project back into [0, 1] after every step (no_grad instead of the
        # discouraged `.data`), so the image displayed below is also in range.
        with torch.no_grad():
            generation.clamp_(0, 1)

    plt.imshow(generation[0].detach().cpu().permute(1, 2, 0))
    plt.title(f"style layers 1-{num_layer - 1}")
    plt.show()

#### Style Transfer

In [None]:
# Hyper-parameters for the combined style-transfer optimisation.
STYLE_LAYERS = 7      # get_style_loss averages over conv indices 1 .. STYLE_LAYERS - 1
CONTENT_LAYER = 1     # conv index whose activations the content loss matches
STYLE_WEIGHT = 1e5    # weight of the style loss relative to the content loss
NUM_EPOCHS = 10000
DISPLAY_EVERY = 200   # show the current image every this many steps

# Start from half-normal noise clamped into [0, 1] and optimise the pixels directly.
generation = torch.abs(torch.randn(1, 3, IMG_HT, IMG_WT).to(device))
generation.requires_grad = True
optimizer = optim.Adam([generation], lr=0.1)
with torch.no_grad():
    generation.clamp_(0, 1)

for epoch in range(NUM_EPOCHS):
    # NOTE(review): each loss helper runs its own forward pass through
    # vgg_features, so every step costs two passes over the network.
    style_loss = get_style_loss(generation, STYLE_LAYERS, activation_store)
    content_loss = get_content_loss(generation, CONTENT_LAYER, activation_store)

    loss = STYLE_WEIGHT * style_loss + content_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Clamp after the step (no_grad instead of the discouraged `.data`), so the
    # image displayed below is always within the valid [0, 1] range.
    with torch.no_grad():
        generation.clamp_(0, 1)

    if epoch % DISPLAY_EVERY == 0:
        plt.imshow(generation[0].detach().cpu().permute(1, 2, 0))
        plt.title(f"epoch {epoch}, loss {loss.item():.4g}")
        plt.show()
