In [3]:
import torch
from torch import nn
from PIL import Image
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.utils import save_image

In [18]:
# Device-agnostic setup: use the GPU when PyTorch reports one, otherwise the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [7]:
# Convolutional feature stack of an ImageNet-pretrained VGG-19.
# The explicit weights enum replaces the deprecated `pretrained=True` flag and
# selects the same checkpoint. Note that `model` is rebound further down to the
# VGG wrapper instance.
model = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to C:\Users\Sriram Kidambi/.cache\torch\hub\checkpoints\vgg19-dcbb9e9d.pth
100.0%


In [15]:
# Layer indices (as strings) of interest: the first layer plus the layer that follows each maxpool
important_layers = [str(layer_index) for layer_index in (0, 5, 10, 19, 28)]

In [35]:
# Creating the model
class VGG(nn.Module):
    """Frozen VGG-19 feature extractor returning selected intermediate outputs.

    Only the first 29 modules (indices 0-28) of torchvision's
    ``vgg19().features`` are kept, so nothing past the last chosen layer is
    evaluated.
    """

    def __init__(self):
        super().__init__()
        # Indices (as strings) of the layers whose outputs forward() collects.
        self.chosen_features = ["0", "5", "10", "19", "28"]
        # Use the explicit weights enum: passing a bare boolean to `weights`
        # is a deprecated legacy form that resolves to this same checkpoint.
        self.model = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features[:29]

        # The network is only used as a fixed feature extractor. Freezing its
        # weights stops backward() from computing (and accumulating) gradients
        # for parameters that no optimizer ever updates or clears.
        for parameter in self.model.parameters():
            parameter.requires_grad_(False)

    def forward(self, x):
        """Run ``x`` through the layers in order and collect the chosen outputs.

        Args:
            x: Input tensor accepted by the first layer of ``self.model``.

        Returns:
            A list with one tensor per entry of ``self.chosen_features`` that
            matches a layer index, in layer order.
        """
        features = []

        for layer_num, layer in enumerate(self.model):
            x = layer(x)

            # NOTE(review): torchvision's VGG appears to use ReLU(inplace=True);
            # if so, a tensor stored here is overwritten by the following
            # activation. Clone it if the untouched output is required - confirm.
            if str(layer_num) in self.chosen_features:
                features.append(x)

        return features

In [46]:
# Build the feature extractor, move it to the selected device and put it in evaluation mode
model = VGG()
model.to(device)
model.eval()

In [47]:
# Side length (in pixels) every image is resized to before entering the network
IMAGE_SIZE = 356

# Preprocessing pipeline: resize to a fixed square, then convert to a tensor
loader = transforms.Compose(
    [
        transforms.Resize(size=(IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
    ]
)

In [48]:
def load_image(image_name):
    """Load an image file as a batched tensor on ``device``.

    Args:
        image_name: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        The tensor produced by ``loader`` with a leading batch dimension of 1,
        moved to ``device``.
    """
    # The context manager closes the underlying file once the pixels are read.
    # convert("RGB") ensures grayscale, palette and RGBA files all yield a
    # three-channel tensor instead of one whose channel count depends on the file.
    with Image.open(image_name) as source:
        image = loader(source.convert("RGB")).unsqueeze(0)

    return image.to(device)

In [49]:
# Images: the content and style inputs, plus the image that will be optimised
original_image, style_image = (
    load_image(path) for path in ("content_image.jpg", "style_image.jpg")
)

# Start from a copy of the content image and track gradients on its pixels
generated_image = original_image.clone()
generated_image.requires_grad_(True)

In [50]:
# Hyperparameters
epochs, learning_rate = 6000, 0.001  # number of optimisation steps, Adam step size
alpha, beta = 1, 0.01  # weights of the content loss and of the style loss

# Only the generated image's pixels are optimised
optimizer = torch.optim.Adam(params=[generated_image], lr=learning_rate)

In [51]:
def gram_matrix(feature):
    """Return the (channels x channels) Gram matrix of a single-image feature map.

    The map is flattened to (channels, height*width), which is only valid for
    a batch dimension of 1 - load_image() adds exactly one batch entry.
    """
    _, channel, height, width = feature.shape
    flattened = feature.view(channel, height * width)
    return flattened.mm(flattened.t())


# The content and style images never change during optimisation, so their
# feature maps (and the style Gram matrices) are computed once, without
# autograd bookkeeping, instead of being recomputed on every step.
with torch.no_grad():
    original_img_features = model(original_image)
    style_grams = [gram_matrix(style_feature) for style_feature in model(style_image)]

for epoch in range(epochs):
    generated_features = model(generated_image)

    style_loss = 0
    content_loss = 0

    for gen_feature, orig_feature, style_gram in zip(generated_features, original_img_features, style_grams):
        # Content term: mean squared difference between the feature maps
        content_loss += torch.mean((gen_feature - orig_feature) ** 2)

        # Style term: mean squared difference between the Gram matrices
        style_loss += torch.mean((gram_matrix(gen_feature) - style_gram) ** 2)

    total_loss = alpha*content_loss + beta*style_loss
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    # Periodic progress report and checkpoint of the current image
    if epoch % 200 == 0:
        print(f"Total Loss: {total_loss:.3f}")
        save_image(generated_image, "generated_image.png")

# The periodic checkpoint last fires before the final steps have run, so write
# the fully optimised image once more after the loop finishes.
save_image(generated_image, "generated_image.png")


Total Loss: 954707.688
Total Loss: 56942.934
Total Loss: 28854.098
Total Loss: 13032.459
Total Loss: 5351.579
Total Loss: 2984.576
Total Loss: 2295.335
Total Loss: 1971.850
Total Loss: 1755.304
Total Loss: 1592.016
Total Loss: 1462.240
Total Loss: 1355.671
Total Loss: 1265.560
Total Loss: 1187.739
Total Loss: 1120.350
Total Loss: 1061.337
Total Loss: 1008.776
Total Loss: 961.801
Total Loss: 919.513
Total Loss: 880.994
Total Loss: 845.641
Total Loss: 812.725
Total Loss: 782.499
Total Loss: 754.563
Total Loss: 728.747
Total Loss: 704.686
Total Loss: 681.678
Total Loss: 661.146
Total Loss: 642.119
Total Loss: 624.693
