In [None]:
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchvision
from PIL import Image
from tqdm import tqdm, trange

In [None]:
# Pick the best available accelerator instead of hard-coding "mps", so the
# notebook also runs on CUDA machines and plain CPUs.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Only the convolutional trunk (`.features`) of ImageNet-pretrained VGG-19 is
# used for style transfer. `weights=...IMAGENET1K_V1` is the non-deprecated
# equivalent of the old `pretrained=True` flag.
vgg = torchvision.models.vgg19(
    weights=torchvision.models.VGG19_Weights.IMAGENET1K_V1
).features.to(device)

In [None]:
# Load the content (`test`) and style images as RGB PIL images.
test = Image.open("./test_images/lion.jpeg").convert("RGB")
style = Image.open("./test_images/style.jpeg").convert("RGB")

shape = 224
# Resize the PIL image *before* ToTensor: PIL resizing is always antialiased,
# whereas tensor resizing only antialiases by default in newer torchvision.
# Images stay in [0, 1] (no ImageNet Normalize) so they can be displayed and
# optimised directly.
transf = torchvision.transforms.Compose([
    torchvision.transforms.Resize((shape, shape)),
    torchvision.transforms.ToTensor(),
])

inp = transf(test).to(device)
style = transf(style).to(device)
print(style.shape)

# Show both images side by side; two bare plt.imshow calls would draw on the
# same axes and the second image would hide the first.
fig, (ax_content, ax_style) = plt.subplots(1, 2, figsize=(8, 4))
ax_content.imshow(inp.permute(1, 2, 0).cpu().numpy())
ax_content.set_title("content")
ax_content.axis("off")
ax_style.imshow(style.permute(1, 2, 0).cpu().numpy())
ax_style.set_title("style")
ax_style.axis("off")
plt.show()

# Add the batch dimension expected by the network: (1, 3, shape, shape).
inp = inp.unsqueeze(0)
style = style.unsqueeze(0)


In [None]:
# Sanity check: ToTensor output should peak at (or below) 1.0.
print(inp.max())

In [None]:
class FeatureExtractor(nn.Module):
    """Run a sequential CNN trunk and collect the outputs of selected layers.

    Args:
        model: an iterable container of layers (e.g. ``nn.Sequential``), applied
            in order. Its parameters are frozen (``requires_grad=False``).
        key_features: string indices of the layers whose outputs are returned.
            A tuple default avoids the shared-mutable-default pitfall.
        normalize: if True, standardise the input with the ImageNet channel
            mean/std before the first layer (pretrained torchvision weights were
            trained on such inputs). Defaults to False, the original behaviour.

    Returns (from ``forward``):
        list of tensors, one per recorded layer, in layer order.

    NOTE: if the layer right after a recorded one modifies its input in place
    (e.g. ``nn.ReLU(inplace=True)``, as torchvision's VGG uses), the recorded
    tensor is overwritten, so the collected feature is the post-activation value.
    """

    # ImageNet statistics used to train torchvision's pretrained models.
    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self, model, key_features=('0', '5', '10', '19', '28'), normalize=False):
        super().__init__()
        self.key_features = key_features
        self.model = model
        self.normalize = normalize
        # Freeze the backbone: only the generated image is optimised.
        self.model.requires_grad_(False)
        # Non-persistent buffers follow .to(device) but stay out of state_dict.
        self.register_buffer(
            "norm_mean", torch.tensor(self.IMAGENET_MEAN).view(1, 3, 1, 1), persistent=False
        )
        self.register_buffer(
            "norm_std", torch.tensor(self.IMAGENET_STD).view(1, 3, 1, 1), persistent=False
        )

    def forward(self, x):
        if self.normalize:
            # Out-of-place, so the caller's tensor is never modified.
            x = (x - self.norm_mean) / self.norm_std
        out = []
        for ix, layer in enumerate(self.model):
            x = layer(x)
            if str(ix) in self.key_features:
                out.append(x)
        return out


In [None]:
def grams(x, normalize=False):
    """Return the batched Gram matrix of a feature map.

    Args:
        x: tensor of shape (b, ch, ...) — e.g. (b, ch, h, w) or an already
            flattened (b, ch, h*w); every dimension after the first two is
            flattened into one.
        normalize: if True, divide by ch * (number of flattened positions) so
            layers of different sizes contribute on a comparable scale.
            Defaults to False (unnormalised, the original behaviour).

    Returns:
        tensor of shape (b, ch, ch) whose [n, i, j] entry is the dot product of
        channels i and j of batch element n.
    """
    # reshape (not view) also copes with non-contiguous inputs.
    flat = x.reshape(x.shape[0], x.shape[1], -1)
    gram = flat.bmm(flat.transpose(1, 2))
    if normalize:
        gram = gram / (flat.shape[1] * flat.shape[2])
    return gram


In [None]:
# Optimise the pixels of `gen`, initialised from the content image
# (a torch.randn start is the common alternative). Move to the device *before*
# requires_grad_ so `gen` is always a leaf tensor Adam can optimise.
gen = inp.detach().clone().to(device).requires_grad_(True)
optim = torch.optim.Adam([gen], lr=0.01)
features = FeatureExtractor(vgg).to(device)
features.eval()
losses = []

# Targets are constant: compute them once, without building autograd graphs.
with torch.no_grad():
    features_extracted_source = features(inp)
    features_extracted_style = features(style)
    style_grams_targets = [grams(f) for f in features_extracted_style]
style_weight, content_weight = 0.1, 1.0

for _ in trange(1000):
    content_loss = style_loss = 0.0
    optim.zero_grad()
    features_extracted_generated = features(gen)
    for source_features, gen_features, style_grams in zip(
        features_extracted_source, features_extracted_generated, style_grams_targets
    ):
        # Flatten spatial dims: (b, ch, h, w) -> (b, ch, h*w).
        source_features = source_features.view(source_features.shape[0], source_features.shape[1], -1)
        gen_features = gen_features.view(gen_features.shape[0], gen_features.shape[1], -1)
        gen_grams = grams(gen_features)

        # Index [0]: a single image per batch is assumed.
        content_loss += torch.mean((source_features[0] - gen_features[0]) ** 2)
        style_loss += torch.mean((style_grams[0] - gen_grams[0]) ** 2)
    total_loss = content_weight * content_loss + style_weight * style_loss
    losses.append(total_loss.item())
    total_loss.backward()
    optim.step()

    # Keep pixels in the valid image range [0, 1].
    with torch.no_grad():
        gen.clamp_(0.0, 1.0)
    

In [None]:
# Total (weighted content + style) loss per optimisation step.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(xlabel="iteration", ylabel="total loss", title="Style-transfer optimisation loss")
plt.show()

In [None]:
print(torch.max(gen))
# to_pil_image multiplies by 255 and casts to uint8 without clamping, so any
# value outside [0, 1] may wrap around instead of saturating. Detach, clamp and
# move to CPU first.
img = torchvision.transforms.functional.to_pil_image(
    gen.detach().clamp(0.0, 1.0).squeeze(0).cpu()
)
plt.imshow(img)
plt.axis("off")
# img.save("./test_images/out_3000_0_1.jpg")
print(losses[-1])