In [21]:
import torch
from torchvision import transforms, datasets, models
import os
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torchvision.utils import save_image

In [2]:
# Pick the compute device. PyTorch's device type for NVIDIA GPUs is "cuda";
# "gpu" is not a recognised device string and makes `.to(device)` raise as
# soon as a CUDA device is actually available.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
# Show which device string was selected.
device

'cpu'

In [10]:
def load_image(path):
    """Load an image file as a batched tensor on the active device.

    Args:
        path: Filesystem path of the image to open.

    Returns:
        The image run through the module-level ``loader`` pipeline, with a
        leading batch dimension of size 1 added, moved to ``device``.
    """
    # Use a context manager so the underlying file handle is closed promptly.
    with Image.open(path) as img:
        # PNG files are often RGBA, palette or greyscale. The module-level
        # `loader` normalises with 3-channel statistics, so force 3-channel RGB
        # here to avoid a channel-count mismatch. convert() returns a fully
        # loaded copy that stays valid after the file is closed.
        rgb_image = img.convert("RGB")
    batched = loader(rgb_image).unsqueeze(0)
    return batched.to(device)

In [12]:
# Side length, in pixels, of the square that every input image is resized to.
imsize = 356

# Preprocessing pipeline applied to each PIL image:
#   1. resize to (imsize, imsize)
#   2. convert to a tensor
#   3. normalise each of the three channels with the given mean / std
# NOTE(review): the mean/std values look like the standard ImageNet statistics
# expected by torchvision's pretrained models -- confirm.
loader = transforms.Compose(
    [
        transforms.Resize((imsize, imsize)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [17]:
# Content image (its feature maps are matched by the content/"target" loss)
# and style image (its Gram matrices are matched by the style loss).
original = load_image("target.png")
style = load_image("style.jpg")

# The image being optimised starts as a copy of the content image; it is the
# only tensor that needs gradients. requires_grad_ works in place.
generated = original.clone()
generated.requires_grad_(True)

In [16]:
class VGG(nn.Module):
    """Frozen, truncated VGG-19 used purely as a feature extractor.

    ``forward`` runs the input through the first 29 layers of the pretrained
    VGG-19 ``features`` stack and returns the activations produced at the
    layer indices listed in ``chosen_features``.
    """

    def __init__(self):
        super().__init__()
        # chosen layers : conv1_1, conv2_1, conv3_1, conv4_1, conv5_1
        self.chosen_features = ["0", "5", "10", "19", "28"]
        # Nothing after index 28 is ever read, so the stack is cut there.
        self.model = models.vgg19(pretrained=True).features[:29]
        # The network weights are never optimised here. Freezing them stops
        # every backward pass from computing and storing gradients for all
        # VGG parameters, which is pure overhead.
        for param in self.model.parameters():
            param.requires_grad_(False)

    def forward(self, x):
        """Return a list with the activation captured at each chosen layer.

        Args:
            x: Input tensor fed to the first layer of ``self.model``.

        Returns:
            list: One tensor per index in ``chosen_features``, in layer order.
        """
        # NOTE(review): torchvision's VGG appears to use in-place ReLU layers,
        # so a tensor captured at a conv index may be overwritten by the ReLU
        # that follows it -- confirm whether pre- or post-activation features
        # are intended.
        features = []
        for layer_num, layer in enumerate(self.model):
            x = layer(x)
            if str(layer_num) in self.chosen_features:
                features.append(x)
        return features

In [18]:
# Build the feature extractor, move it to the selected device and switch it to
# evaluation mode. Both .to() and .eval() return the module itself.
model = VGG()
model = model.to(device)
model = model.eval()

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to C:\Users\prabh/.cache\torch\hub\checkpoints\vgg19-dcbb9e9d.pth


  0%|          | 0.00/548M [00:00<?, ?B/s]

In [20]:
# Optimisation settings.
total_steps = 6_000      # number of gradient steps taken on the generated image
learning_rate = 1e-3     # Adam step size
alpha = 1                # weight of the content ("target") loss
beta = 1e-2              # weight of the style loss

# Only the generated image is handed to the optimiser.
optimizer = optim.Adam(params=[generated], lr=learning_rate)

In [None]:
def _gram_matrix(feature):
    """Return the (channel x channel) Gram matrix of a batch-of-one feature map."""
    _, channel, height, width = feature.shape
    flat = feature.view(channel, height * width)
    return flat.mm(flat.t())


# The content and style images never change, so their features (and the style
# Gram matrices) are computed once, without autograd bookkeeping, instead of
# being recomputed on every optimisation step.
with torch.no_grad():
    original_img_features = model(original)
    style_features = model(style)
    style_grams = [_gram_matrix(sty_feature) for sty_feature in style_features]

# Inverse of the input normalisation, needed to write a viewable image.
# These values must match the mean/std given to transforms.Normalize in `loader`.
_norm_mean = torch.tensor([0.485, 0.456, 0.406], device=generated.device).view(1, 3, 1, 1)
_norm_std = torch.tensor([0.229, 0.224, 0.225], device=generated.device).view(1, 3, 1, 1)

for step in range(total_steps):
    generated_features = model(generated)
    style_loss, target_loss = 0, 0
    for gen_feature, orig_feature, style_gram in zip(
        generated_features, original_img_features, style_grams
    ):
        # Content loss: mean squared difference between feature maps.
        target_loss += torch.mean((gen_feature - orig_feature) ** 2)
        # Style loss: mean squared difference between Gram matrices.
        style_loss += torch.mean((_gram_matrix(gen_feature) - style_gram) ** 2)

    total_loss = alpha * target_loss + beta * style_loss
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    if step % 200 == 0:
        # .item() prints a plain number rather than a tensor repr with grad_fn.
        print(f"step {step}: total_loss = {total_loss.item():.4f}")
        # `generated` lives in normalised space; undo the normalisation and
        # clamp to [0, 1] so the saved file has correct colours.
        with torch.no_grad():
            snapshot = (generated * _norm_std + _norm_mean).clamp(0, 1)
        save_image(snapshot, 'generated.png')

tensor(60728180., grad_fn=<AddBackward0>)


In [27]:
# Display the total_loss value currently held in the kernel, i.e. the one most
# recently assigned by the optimisation loop.
total_loss

tensor(60936712., grad_fn=<AddBackward0>)