In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
import cv2
from PIL import Image
import numpy as np
from tqdm import tqdm

class VGGStyleModel:
    """Neural style transfer (Gatys et al.) on top of a frozen, pretrained VGG-19.

    ``train`` optimises a copy of the content image once per
    ``(content_weight, style_weight)`` pair and writes each result to
    ``generated_image_{idx}.jpg`` in the current working directory.
    """

    # Child index inside ``vgg19().features`` -> name used by the losses.
    LAYERS = {
        0: 'conv1_1',
        5: 'conv2_1',
        10: 'conv3_1',
        19: 'conv4_1',
        21: 'conv4_2',
        28: 'conv5_1',
    }
    CONTENT_LAYER = 'conv4_2'
    STYLE_LAYERS = ('conv1_1', 'conv2_1', 'conv3_1', 'conv4_1', 'conv5_1')

    def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
        """Load VGG-19 feature layers and the ImageNet preprocessing pipeline.

        Args:
            device: Torch device string; defaults to CUDA when available.
        """
        self.device = device
        self.vgg = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features.to(device).eval()
        # VGG is only a fixed feature extractor. Freezing it stops backward()
        # from accumulating (never-cleared) gradients into its weights and keeps
        # the reference features free of an autograd graph.
        self.vgg.requires_grad_(False)
        # ImageNet statistics the pretrained weights were trained with.
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]
        self.transform = transforms.Compose([
            transforms.Resize(512),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.mean, std=self.std)
        ])

    def load_image(self, image_path):
        """Load ``image_path`` as RGB and return a normalised ``(1, 3, H, W)`` tensor on ``self.device``."""
        # Context manager guarantees the underlying file handle is released.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        return self.transform(image).unsqueeze(0).to(self.device)

    def denormalize(self, tensor):
        """Undo the ImageNet normalisation and return an ``(H, W, 3)`` uint8 array.

        Assumes ``tensor`` is ``(1, 3, H, W)`` as produced by ``load_image``.
        """
        image = tensor.clone().detach().cpu().squeeze(0).numpy().transpose(1, 2, 0)
        image = image * self.std + self.mean
        return (np.clip(image, 0, 1) * 255).astype(np.uint8)

    def get_features(self, image, layers):
        """Run ``image`` through ``self.vgg`` and collect selected activations.

        Args:
            image: Normalised input batch (e.g. from ``load_image``).
            layers: Mapping from child index in ``self.vgg`` to the name under
                which that child's output is stored.

        Returns:
            Dict mapping each requested name to its activation tensor.

        Note:
            torchvision's VGG ``features`` uses ``ReLU(inplace=True)``, so a
            tensor recorded at a conv index is overwritten by the ReLU that
            follows it; the stored values are therefore post-activation.
        """
        features = {}
        x = image
        last_idx = max(layers, default=-1)

        for idx, layer in enumerate(self.vgg.children()):
            # Stop once every requested layer has been computed. In-place layers
            # directly after the last one must still run, because they mutate
            # tensors that are already stored in ``features``.
            if idx > last_idx and not getattr(layer, "inplace", False):
                break
            x = layer(x)
            if idx in layers:
                features[layers[idx]] = x
        return features

    def gram_matrix(self, tensor):
        """Return the normalised Gram matrix of a single feature map.

        ``tensor`` must be ``(1, C, H, W)``; batches larger than one are not
        supported. The result is the ``(C, C)`` Gram matrix divided by
        ``C * H * W``.
        """
        _, c, h, w = tensor.size()
        features = tensor.reshape(c, h * w)
        return features @ features.t() / (c * h * w)

    def train(self, content_path, style_path, weights, num_steps=300, lr=0.001):
        """Run style transfer once per weight pair and save each result.

        Args:
            content_path: Path to the content image.
            style_path: Path to the style image.
            weights: Iterable of ``(content_weight, style_weight)`` pairs.
            num_steps: Adam iterations per weight pair.
            lr: Adam learning rate applied to the (normalised) image pixels.

        Side effects:
            Writes ``generated_image_{idx}.jpg`` for each pair, in order.
        """
        content_img = self.load_image(content_path)
        style_img = self.load_image(style_path)

        # The reference targets do not depend on the weight pair or on the
        # image being optimised, so compute them once, outside autograd.
        with torch.no_grad():
            content_target = self.get_features(content_img, self.LAYERS)[self.CONTENT_LAYER]
            style_features = self.get_features(style_img, self.LAYERS)
            style_grams = {
                layer: self.gram_matrix(style_features[layer])
                for layer in self.STYLE_LAYERS
            }

        targets = []
        for content_weight, style_weight in weights:
            # Start each run from the content image itself.
            target_img = content_img.clone().requires_grad_(True)
            optimizer = torch.optim.Adam([target_img], lr=lr)

            for _ in tqdm(range(num_steps), desc="Training with Adam"):
                optimizer.zero_grad()
                generated = self.get_features(target_img, self.LAYERS)

                content_loss = F.mse_loss(generated[self.CONTENT_LAYER], content_target)
                style_loss = sum(
                    F.mse_loss(self.gram_matrix(generated[layer]), style_grams[layer])
                    for layer in self.STYLE_LAYERS
                )

                total_loss = content_weight * content_loss + style_weight * style_loss
                # Targets are graph-free and VGG is frozen, so the graph is
                # rebuilt each step and need not be retained.
                total_loss.backward()
                optimizer.step()

            targets.append(target_img.detach().clone())

        for idx, target in enumerate(targets):
            result = self.denormalize(target)
            Image.fromarray(result).save(f"generated_image_{idx}.jpg")


if __name__ == "__main__":
    # Guarded so importing this module does not load VGG and start training.
    model = VGGStyleModel()

    # (content_weight, style_weight) pairs; one output image is saved per pair.
    weights = [
        (1e4, 1e2),
        (1e3, 1e3),
        (1e5, 1e1),
        (1e4, 1e4),
        (1e2, 1e5),
    ]

    DATA_DIR = "/home/walke/college/cv/ass2/CV Assignment 2/Q3"
    model.train(f"{DATA_DIR}/content/bear.jpg",
                f"{DATA_DIR}/styles/bet-you.jpg",
                weights, num_steps=300)
