# In [None]:
# https://medium.com/ai-techsystems/neural-style-transfer-742dca137976

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
import cv2
from PIL import Image
# Preprocessing shared by both inputs: resize to 512, convert to a float
# tensor, then normalise each channel with the mean/std values below.
transform = transforms.Compose([
        transforms.Resize(512),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225])
    ])


def load_image(path):
    """Read the image at *path* and return it as a preprocessed batch of one.

    The file is read with OpenCV (BGR), converted to RGB, wrapped in a PIL
    image and passed through the module-level ``transform``; a leading batch
    dimension is then added.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file. ``cv2.imread``
            signals failure by returning ``None`` rather than raising, which
            would otherwise surface as an opaque error inside ``cvtColor``.
    """
    bgr = cv2.imread(path)
    if bgr is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    return transform(Image.fromarray(rgb)).unsqueeze(0)


content_image = load_image("/home/walke/college/cv/ass2/CV Assignment 2/Q3/content/bear.jpg")
style_image = load_image("/home/walke/college/cv/ass2/CV Assignment 2/Q3/styles/bet-you.jpg")

# Only the convolutional part of VGG-19 is needed. The network is a fixed
# feature extractor, so it is put in eval mode and its weights are frozen.
vgg = models.vgg19(pretrained=True).features.eval()

for param in vgg.parameters():
    param.requires_grad_(False)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

vgg = vgg.to(device)
content_image = content_image.to(device)
style_image = style_image.to(device)

# Start the optimisation from a copy of the content image. It is created
# after the device move so that it is a leaf tensor on the target device
# (calling .to() on a tensor that requires grad yields a non-leaf copy).
generated_image = content_image.clone().detach().requires_grad_(True)
    

# Build a lookup from a readable layer name to the layer's position inside
# ``vgg``. Names are ``conv_N`` / ``relu_N`` / ``pool_N``, numbered per layer
# type in order of appearance.
layers = {}
conv_count = relu_count = pool_count = 1

for i, layer in enumerate(vgg.children()):
    if isinstance(layer, nn.MaxPool2d):
        name = f'pool_{pool_count}'
        pool_count += 1
    elif isinstance(layer, nn.ReLU):
        name = f'relu_{relu_count}'
        relu_count += 1
        # Swap in an out-of-place ReLU so the activation does not overwrite
        # the tensor produced by the layer before it.
        vgg[i] = nn.ReLU(inplace=False)
    elif isinstance(layer, nn.Conv2d):
        name = f'conv_{conv_count}'
        conv_count += 1
    else:
        # Any other layer type has no naming rule here; fail loudly.
        raise RuntimeError(f'Unrecognized layer: {layer.__class__.__name__}')

    layers[name] = i
    
def get_features(image, layers, model=vgg):
    """Run *image* through *model* and collect the requested activations.

    Args:
        image: Tensor fed to the first child module of *model*.
        layers: Mapping of feature name -> index of the child module (in
            ``model.children()`` order) whose output should be captured.
        model: Module whose children are applied one after another.
            Defaults to the module-level ``vgg``.

    Returns:
        Dict mapping each requested name to the output of its layer. Names
        whose index is never reached are absent. An empty *layers* mapping
        yields an empty dict without running the model.
    """
    if not layers:
        # Nothing requested: avoid max() on an empty sequence.
        return {}

    # Invert the mapping once so each layer does a single dict lookup instead
    # of scanning every requested name. Several names may share one index.
    names_by_index = {}
    for name, layer_idx in layers.items():
        names_by_index.setdefault(layer_idx, []).append(name)
    last_needed = max(names_by_index)

    features = {}
    x = image
    for idx, layer in enumerate(model.children()):
        x = layer(x)

        for name in names_by_index.get(idx, ()):
            features[name] = x

        # Layers past the deepest requested one are never used; stop early.
        if idx >= last_needed:
            break

    return features
def content_loss(input_features, target_features):
    """Return the mean squared error between two feature tensors.

    Args:
        input_features: Features of the image being optimised.
        target_features: Features of the reference image.

    Returns:
        Scalar tensor holding the element-wise MSE, averaged over all
        elements.
    """
    loss = F.mse_loss(
        input=input_features,
        target=target_features,
        reduction="mean",
    )
    return loss


def _gram_matrix(features):
    """Return the Gram matrix of a 4-D feature map, scaled by C*H*W.

    The map is flattened to ``(batch * channels, height * width)`` and
    multiplied by its own transpose.
    """
    batch_size, channels, height, width = features.size()
    # reshape() instead of view(): view() raises on non-contiguous inputs
    # (e.g. permuted tensors), reshape() copies only when it has to.
    flat = features.reshape(batch_size * channels, height * width)
    gram = torch.mm(flat, flat.t())
    return gram.div(channels * height * width)


def style_loss(input_features, target_features):
    """Return the MSE between the Gram matrices of two feature maps.

    Args:
        input_features: 4-D feature tensor of the image being optimised.
        target_features: 4-D feature tensor of the style reference.

    Returns:
        Scalar tensor: mean squared difference of the two normalised Gram
        matrices.
    """
    return F.mse_loss(_gram_matrix(input_features), _gram_matrix(target_features))

# Names of the layers whose activations drive each loss term.
content_layers = ['conv_4']
style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']

# Keep only the layers that are actually needed so the forward pass can stop
# at the deepest one.
layers = {name: layers[name] for name in set(content_layers + style_layers)}

# The targets never change during optimisation, so compute them once and
# without recording an autograd graph.
with torch.no_grad():
    content_features = get_features(content_image, layers=layers)
    style_features = get_features(style_image, layers=layers)

content_weight = 1e4
style_weight = 1e2

# Not referenced by the losses in this section; kept in case later code
# relies on the name.
mse_loss = torch.nn.MSELoss()

generated_image = generated_image.clone().detach().requires_grad_(True)  # Ensure it's a leaf tensor

optimizer = torch.optim.LBFGS([generated_image])

# Single-element list so the closure can mutate the counter.
run = [0]


def closure():
    """Evaluate the total loss for the current image and back-propagate it."""
    optimizer.zero_grad()

    # The features must be recomputed from the tensor being optimised on
    # every evaluation; otherwise the loss is not connected to
    # ``generated_image`` and the optimiser sees no gradient.
    generated_features = get_features(generated_image, layers=layers)

    Tcontent_loss = 0
    for layer in content_layers:
        Tcontent_loss += content_loss(generated_features[layer], content_features[layer])

    Tstyle_loss = 0
    for layer in style_layers:
        Tstyle_loss += style_loss(generated_features[layer], style_features[layer])

    total_loss = content_weight * Tcontent_loss + style_weight * Tstyle_loss
    # A fresh graph is built on every call, so nothing needs to be retained.
    total_loss.backward()
    run[0] += 1
    return total_loss


# L-BFGS may call the closure several times per step(), so ``run[0]`` counts
# closure evaluations rather than calls to step().
while run[0] <= 5:
    print(run[0])
    optimizer.step(closure)

# Undo the per-channel normalisation applied when the inputs were loaded so
# the saved image is back in the [0, 1] pixel range before scaling to 8-bit.
norm_mean = torch.tensor([0.485, 0.456, 0.406], device=generated_image.device).view(3, 1, 1)
norm_std = torch.tensor([0.229, 0.224, 0.225], device=generated_image.device).view(3, 1, 1)
final_img = generated_image.detach().squeeze(0) * norm_std + norm_mean
final_img = final_img.clamp(0, 1).permute(1, 2, 0).cpu().numpy()
final_img = (final_img * 255).clip(0, 255).astype('uint8')
Image.fromarray(final_img).save("/home/walke/college/cv/ass2/CV Assignment 2/LBFGS.jpg")

