In [2]:
import torch 
import torch.nn as nn 
import torch.optim as optim 
from PIL import Image
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.utils import save_image


In [3]:
# No standalone VGG19 load here: the pretrained backbone is built by the `VGG`
# wrapper class and assigned to `model` right before training. Loading
# `models.vgg19(pretrained=True).features` in this cell only produced a second
# copy of the weights that was overwritten before ever being used.



In [None]:
class VGG (nn.Module):
    """Frozen feature extractor returning selected intermediate activations.

    By default it wraps the first 29 layers of torchvision's pretrained
    ``vgg19().features`` and returns the outputs of layers 0, 5, 10, 19 and 28.
    In torchvision's VGG19 layer ordering these should be conv1_1, conv2_1,
    conv3_1, conv4_1 and conv5_1; re-verify if the backbone changes.

    Args:
        chosen_features: Optional iterable of layer indices (str or int) whose
            outputs are returned. Defaults to ``['0', '5', '10', '19', '28']``.
            Stored on ``self.chosen_features`` as strings, as before.
        backbone: Optional iterable ``nn.Module`` (e.g. ``nn.Sequential``) to use
            instead of the pretrained VGG19. Its parameters are frozen in place.
    """

    def __init__(self, chosen_features=None, backbone=None):
        super().__init__()
        if chosen_features is None:
            chosen_features = ['0', '5', '10', '19', '28']
        self.chosen_features = [str(i) for i in chosen_features]
        if backbone is None:
            # Only layers up to the deepest chosen index are ever evaluated;
            # with the default indices this is exactly features[:29].
            deepest = max((int(i) for i in self.chosen_features), default=-1)
            backbone = models.vgg19(pretrained=True).features[:deepest + 1]
        self.model = backbone
        # The network is a fixed feature extractor: only the input image is
        # optimized, so don't compute/accumulate gradients for its weights.
        # Gradients still flow back to the input tensor.
        self.model.requires_grad_(False)

    def forward(self, x):
        """Run ``x`` through the backbone and collect the chosen activations.

        Returns:
            List of activation tensors, in layer order, one for every chosen
            index that the backbone reaches.
        """
        chosen = {str(i) for i in self.chosen_features}
        features = []
        for i, layer in enumerate(self.model):
            if len(features) == len(chosen):
                break  # deeper layers cannot contribute any further output
            x = layer(x)
            if str(i) in chosen:
                features.append(x)
        return features

In [None]:
# Use the GPU when available, but fall back to CPU so the notebook still runs
# (slowly) on machines without CUDA instead of crashing on `.to("cuda")`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Content and style images are resized to the same square size. The training
# loop compares their feature maps directly, so the spatial shapes must match.
img_size = 256

# NOTE(review): ToTensor() yields values in [0, 1] with no ImageNet mean/std
# normalization. The pretrained VGG weights presumably expect normalized
# inputs. Consider transforms.Normalize (undone before save_image); confirm.
loader = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
])

In [None]:
def load_image(path):
    """Load an image file as a batched float tensor on ``device``.

    The image is converted to RGB first, so grayscale, palette or RGBA files
    still produce the 3 channels the VGG backbone expects. It is then resized
    and converted by the module-level ``loader``, and given a leading batch
    dimension of 1. The file handle is closed once the pixels have been read.

    Args:
        path: Path to the image file.

    Returns:
        Tensor of shape (1, 3, H, W), where H and W are set by ``loader``.
    """
    with Image.open(path) as image:
        tensor = loader(image.convert("RGB")).unsqueeze(0)
    return tensor.to(device)

In [None]:
# Content image (its structure is preserved) and style image (its textures are
# transferred), loaded in that order.
original_img, style_img = (
    load_image(image_path) for image_path in ("building.jpg", "van_gogh_style.jpeg")
)

In [None]:
# The optimization starts from a copy of the content image. This copy is the
# only tensor handed to the optimizer, so it must track gradients.
generated = torch.clone(original_img)
generated.requires_grad_(True)

In [None]:
# Hyperparams

# Number of optimizer updates, and the Adam step size applied to the pixels of
# the generated image.
total_steps, learning_rate = 6000, 0.001

# Weights of the two cost-function terms:
#   total_loss = alpha * content loss + beta * style loss
alpha, beta = 1, 0.01

In [None]:
# Adam updates only the pixels of `generated`. The VGG feature extractor is
# moved to `device` and switched to eval mode.
optimizer = optim.Adam(params=[generated], lr=learning_rate)
model = VGG().to(device)
model.eval()

In [None]:
def _gram_matrix(feature):
    """Gram matrix M @ M.T of a batch-size-1 feature map flattened to (C, H*W)."""
    _, channel, height, width = feature.shape
    flat = feature.view(channel, height * width)
    return flat.mm(flat.t())


# The content and style targets don't depend on `generated`. Compute them once,
# without building autograd graphs, instead of re-running VGG on both images
# at every step.
with torch.no_grad():
    original_img_features = model(original_img)
    style_features = model(style_img)
    style_grams = [_gram_matrix(feature) for feature in style_features]

save_every = 200  # steps between progress prints / snapshots

for step in range(total_steps):
    generated_features = model(generated)

    style_loss = 0
    original_loss = 0

    for gen_feature, orig_feature, style_gram in zip(generated_features, original_img_features, style_grams):
        # Content loss: mean squared difference of the activations.
        original_loss += torch.mean((gen_feature - orig_feature) ** 2)
        # Style loss: mean squared difference of the Gram matrices (M @ M.T).
        G = _gram_matrix(gen_feature)
        style_loss += torch.mean((G - style_gram) ** 2)

    total_loss = alpha * original_loss + beta * style_loss
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    # Also report/save on the final step. Otherwise the last
    # (total_steps - 1) % save_every updates would never be written to disk.
    if step % save_every == 0 or step == total_steps - 1:
        print(f"step {step}: total loss {total_loss.item():.4f}")
        save_image(generated.detach(), "generated.png")

