<a href="https://colab.research.google.com/github/soumyamalviya92-pixel/prodigytask5/blob/main/task5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from PIL import Image
import matplotlib.pyplot as plt

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)


In [None]:
# Working resolution: 512px with a GPU, 256px otherwise (presumably to keep
# CPU runs tractable).
image_size = 512 if torch.cuda.is_available() else 256

# Resize to a fixed square, then convert to a float tensor (C, H, W).
loader = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
])


def load_image(image_path):
    """Load an image file as a (1, 3, image_size, image_size) tensor on ``device``.

    The image is converted to RGB, resized to ``image_size`` x ``image_size``
    and given a leading batch dimension. The underlying file handle is closed
    before the function returns.
    """
    # `with` releases the file handle; `convert` returns a new, fully loaded image.
    with Image.open(image_path) as image:
        rgb = image.convert('RGB')
    tensor = loader(rgb).unsqueeze(0)
    return tensor.to(device)


content_img = load_image("content.jpg")
style_img = load_image("style.jpg")

# Later cells assume both images share one shape. Raise explicitly rather than
# `assert`, which is stripped when Python runs with -O.
if content_img.size() != style_img.size():
    raise ValueError(
        "content and style images must have the same size, got "
        f"{tuple(content_img.size())} and {tuple(style_img.size())}"
    )


In [None]:
def imshow(tensor, title=None):
    """Draw a single image tensor on the current matplotlib axes.

    Args:
        tensor: image tensor, e.g. of shape (1, C, H, W); a leading batch
            dimension of size 1 is squeezed away. It is detached and copied
            to the CPU first, so an image that is still being optimized
            (``requires_grad=True``) can be passed safely.
        title: optional axes title; omitted when falsy.
    """
    image = tensor.detach().cpu().clone()
    image = image.squeeze(0)
    # ToPILImage scales float tensors by 255 before casting to uint8, so
    # values outside [0, 1] do not map to valid pixels; clamp first.
    image = image.clamp(0, 1)
    image = transforms.ToPILImage()(image)
    plt.imshow(image)
    if title:
        plt.title(title)
    plt.axis('off')

# Show the content and style inputs next to each other.
plt.figure(figsize=(10, 5))
panels = [(content_img, "Content Image"), (style_img, "Style Image")]
for slot, (picture, caption) in enumerate(panels, start=1):
    plt.subplot(1, 2, slot)
    imshow(picture, caption)


In [None]:
# VGG19 convolutional trunk with ImageNet weights, used as a fixed feature
# extractor in eval mode. `weights=` replaces the deprecated `pretrained=True`
# (torchvision >= 0.13); VGG19_Weights.DEFAULT is the same IMAGENET1K_V1
# checkpoint that `pretrained=True` loaded.
cnn = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features.to(device).eval()


In [None]:
class ContentLoss(nn.Module):
    """Pass-through layer that measures distance to a fixed feature map.

    Each forward call stores the mean-squared error between its input and the
    target in ``self.loss`` and returns the input unchanged, so the layer can
    be dropped into an ``nn.Sequential`` without altering the activations.
    """

    def __init__(self, target):
        super(ContentLoss, self).__init__()
        # Detached: the target is a constant, not part of the autograd graph.
        frozen_target = target.detach()
        self.target = frozen_target

    def forward(self, x):
        mse = nn.functional.mse_loss(x, self.target)
        self.loss = mse
        return x


def gram_matrix(x):
    """Return the normalized Gram matrix of a 4-D feature tensor.

    Args:
        x: tensor of shape (b, c, h, w).

    Returns:
        A (b*c, b*c) tensor of inner products between the flattened feature
        maps, divided by b*c*h*w. For a single-image batch (b == 1) this is
        the usual (c, c) channel Gram matrix.
    """
    b, c, h, w = x.size()
    # Fold the batch into the channel axis so b > 1 no longer fails; `reshape`
    # (unlike `view`) also accepts non-contiguous inputs.
    features = x.reshape(b * c, h * w)
    gram = torch.mm(features, features.t())
    return gram.div(b * c * h * w)


class StyleLoss(nn.Module):
    """Pass-through layer that compares Gram matrices against a fixed target.

    The target's Gram matrix is computed once at construction and detached so
    it acts as a constant. Each forward call stores the mean-squared error
    between the input's Gram matrix and that target in ``self.loss`` and
    returns the input unchanged.
    """

    def __init__(self, target):
        super(StyleLoss, self).__init__()
        target_gram = gram_matrix(target)
        self.target = target_gram.detach()

    def forward(self, x):
        self.loss = nn.functional.mse_loss(gram_matrix(x), self.target)
        return x


In [None]:
# ImageNet channel statistics used to train the torchvision VGG weights. The
# images from ToTensor are in [0, 1] and are standardized with these before
# the first conv layer.
cnn_normalization_mean = [0.485, 0.456, 0.406]
cnn_normalization_std = [0.229, 0.224, 0.225]


class Normalization(nn.Module):
    """Standardize an image batch channel-wise: (img - mean) / std."""

    def __init__(self, mean, std):
        super().__init__()
        # Shape (C, 1, 1) so the statistics broadcast over (B, C, H, W).
        self.register_buffer("mean", torch.tensor(mean).view(-1, 1, 1))
        self.register_buffer("std", torch.tensor(std).view(-1, 1, 1))

    def forward(self, img):
        return (img - self.mean) / self.std


content_layers = ['conv_4']
style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']

content_losses = []
style_losses = []

model = nn.Sequential()
model.add_module(
    "normalization",
    Normalization(cnn_normalization_mean, cnn_normalization_std).to(device),
)
i = 0  # index of the most recent conv layer; used to name the layers after it

for layer in cnn.children():
    if isinstance(layer, nn.Conv2d):
        i += 1
        name = f"conv_{i}"
    elif isinstance(layer, nn.ReLU):
        name = f"relu_{i}"
        # Out-of-place ReLU: an in-place one would overwrite the conv output
        # that an inserted loss layer has saved for the backward pass.
        layer = nn.ReLU(inplace=False)
    elif isinstance(layer, nn.MaxPool2d):
        name = f"pool_{i}"
    else:
        continue

    model.add_module(name, layer)

    if name in content_layers:
        # Targets are constants, so don't build an autograd graph for them.
        with torch.no_grad():
            target = model(content_img)
        content_loss = ContentLoss(target)
        model.add_module("content_loss_" + name, content_loss)
        content_losses.append(content_loss)

    if name in style_layers:
        with torch.no_grad():
            target = model(style_img)
        style_loss = StyleLoss(target)
        model.add_module("style_loss_" + name, style_loss)
        style_losses.append(style_loss)

# Drop every layer after the last loss layer: those layers cost compute but
# contribute nothing to any loss.
last_loss_idx = max(
    (idx for idx, module in enumerate(model)
     if isinstance(module, (ContentLoss, StyleLoss))),
    default=len(model) - 1,
)
model = model[: last_loss_idx + 1]

# Only the input image is optimized. Freezing the VGG weights avoids computing
# and accumulating gradients for them on every backward pass.
model.requires_grad_(False)


In [None]:
# Optimize the pixels of a copy of the content image directly; only
# `input_img` is handed to the optimizer.
input_img = content_img.clone()
optimizer = optim.LBFGS([input_img.requires_grad_()])

style_weight = 1e6
content_weight = 1

# NOTE: LBFGS may call `closure` several times per `step`, so `run[0]` counts
# closure evaluations (not optimizer steps) and can overshoot `epochs` slightly.
epochs = 200
print("Stylizing...")

run = [0]  # a list so the closure can mutate the counter
while run[0] < epochs:

    def closure():
        # Keep pixels in the valid [0, 1] image range; unconstrained updates
        # drift outside it and show up as artifacts when displayed.
        with torch.no_grad():
            input_img.clamp_(0, 1)

        optimizer.zero_grad()
        model(input_img)  # populates .loss on every loss layer
        style_score = style_weight * sum(sl.loss for sl in style_losses)
        content_score = content_weight * sum(cl.loss for cl in content_losses)

        loss = style_score + content_score
        loss.backward()

        if run[0] % 50 == 0:
            # Weighted terms: the raw style loss is tiny and printed as 0.00.
            print(f"Epoch {run[0]} | Style Loss: {style_score.item():.2f} | Content Loss: {content_score.item():.2f}")

        run[0] += 1
        return loss

    optimizer.step(closure)

# The last LBFGS update is applied after the final closure call, so clamp once
# more before the result is displayed.
with torch.no_grad():
    input_img.clamp_(0, 1)


In [None]:
# Display the optimized image on its own square figure.
plt.figure(figsize=(6, 6))
imshow(input_img, title="Stylized Image")
