In [240]:
import torch 
import torch.nn as nn 
import torchvision
from torchvision import models,transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt 

In [241]:
# Pick the compute device: Apple MPS if present, else CUDA, else CPU, so the
# notebook also runs on machines without an Apple GPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [242]:
def get_image(path,img_transform,size=(300,300)):
    """Load an image file as a single-image batch on the module-level `device`.

    Args:
        path: Path of the image file to open.
        img_transform: Callable applied to the resized PIL image; its result must
            support ``.unsqueeze(0)`` (e.g. ToTensor + Normalize -> (C, H, W) tensor).
        size: Target size passed to PIL's ``resize`` as (width, height).

    Returns:
        ``img_transform(image)`` with a leading batch axis, moved to ``device``.
    """
    # The context manager closes the file handle; convert('RGB') loads the pixel
    # data and turns grayscale / RGBA / palette images into 3 channels so that a
    # 3-channel Normalize can be applied.
    with Image.open(path) as src:
        image = src.convert('RGB').resize(size, Image.LANCZOS)
    return img_transform(image).unsqueeze(0).to(device)

In [243]:
def get_gram(m):
    """Return the unnormalized Gram matrix of a single feature map.

    Args:
        m: Feature tensor of shape (1, C, H, W).

    Returns:
        A (C, C) tensor ``F @ F.T``, where ``F`` is ``m`` flattened to (C, H*W).
        No scaling is applied; callers normalize the result themselves.

    Raises:
        ValueError: if the batch dimension of ``m`` is not 1.
    """
    b, c, h, w = m.size()
    if b != 1:
        raise ValueError(f"get_gram expects a batch of 1, got shape {tuple(m.shape)}")
    # reshape (unlike view) also accepts non-contiguous feature maps.
    features = m.reshape(c, h * w)
    return torch.mm(features, features.t())


In [244]:
def denormalize_img(inp, mean=(0.485,0.456,0.406), std=(1,1,1)):
    """Undo a per-channel Normalize on a (C, H, W) tensor and return an HWC array.

    Args:
        inp: Image tensor of shape (C, H, W). It may require grad or live on an
            accelerator; it is detached and copied to the CPU first.
        mean: Per-channel mean used by the forward Normalize.
        std: Per-channel std used by the forward Normalize. The defaults match
            the ``img_transform`` built in this notebook (mean-centring only).

    Returns:
        A NumPy array of shape (H, W, C) computed as ``inp * std + mean``,
        clipped to [0, 1].
    """
    arr = inp.detach().cpu().numpy().transpose((1, 2, 0))  # CHW -> HWC
    arr = arr * np.asarray(std) + np.asarray(mean)
    return np.clip(arr, 0, 1)

In [245]:
class FeatureExtractor(nn.Module):
    """VGG16 convolutional trunk that returns selected intermediate activations.

    ``forward`` feeds ``x`` through ``vgg16().features`` in order and collects
    the outputs of the layers whose index is listed in ``selected_layers``.
    """
    def __init__(self, selected_layers=(3, 8, 15, 22)):
        super().__init__()
        # Indices into vgg16().features. In torchvision's VGG16 layout the
        # defaults are the ReLU outputs following conv1_2, conv2_2, conv3_3 and
        # conv4_3.
        self.selected_layers = sorted(int(i) for i in selected_layers)
        self.vgg=models.vgg16(weights='VGG16_Weights.DEFAULT').features

    def forward(self,x):
        """Return the activations of the selected layers, in layer order."""
        layer_feats = []
        if not self.selected_layers:
            return layer_feats
        wanted = set(self.selected_layers)
        last = max(self.selected_layers)
        for idx, layer in enumerate(self.vgg):
            x = layer(x)
            if idx in wanted:
                # NOTE(review): torchvision builds the VGG ReLUs with
                # inplace=True, so selecting a layer whose successor is a ReLU
                # would let that ReLU overwrite the stored activation. The
                # default indices are followed by MaxPool layers.
                layer_feats.append(x)
            if idx >= last:
                # Deeper layers never contribute to the output; skip them.
                break
        return layer_feats

In [246]:
# for i,j in models.vgg16(weights='VGG16_Weights.DEFAULT').features._modules.items():
#     print(j)

In [247]:
# vgg.features

In [248]:
# Inputs are only mean-centred (std=1), which is exactly what denormalize_img
# inverts. The usual ImageNet std would be (0.229, 0.224, 0.225).
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(1, 1, 1)),
])

# get_image already returns each batch on `device`.
content_img = get_image('content_image5.jpg', img_transform)
style_img = get_image('style_image4.jpg', img_transform)

# The pixels themselves are optimised, starting from the content image.
generated_img = content_img.clone().requires_grad_(True)
optimizer = torch.optim.Adam([generated_img], lr=0.0007, betas=(0.5, 0.999))

# Frozen feature extractor: only generated_img receives gradient updates.
encoder = FeatureExtractor().to(device)
encoder.requires_grad_(False)

In [249]:
content_weight = 1
style_weight = 100
num_steps = 1000
log_every = 100

# The targets never change (the encoder is frozen and the content/style images
# are fixed), so compute them once without building an autograd graph instead
# of re-running VGG on both images at every step.
with torch.no_grad():
    content_target = encoder(content_img)[-1]
    style_grams = [get_gram(sf) for sf in encoder(style_img)]

for epoch in range(num_steps):
    generated_features = encoder(generated_img)

    # Content loss: MSE between the deepest selected feature maps.
    content_loss = torch.mean((content_target - generated_features[-1])**2)

    # Style loss: per-layer MSE between Gram matrices, divided by C*H*W.
    style_loss = 0
    for gf, gram_sf in zip(generated_features, style_grams):
        _, c, h, w = gf.size()
        gram_gf = get_gram(gf)
        style_loss += torch.mean((gram_gf - gram_sf)**2) / (c * h * w)

    loss = content_weight * content_loss + style_weight * style_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % log_every == 0:
        print("Epoch [{}]\tContent Loss: {:.4f}\tStyle Loss:{:.4f}".format(epoch, content_loss.item(), float(style_loss)))

Epoch [0]	Content Loss: 0.0000	Style Loss:78.9002
Epoch [100]	Content Loss: 0.9436	Style Loss:25.6521
Epoch [200]	Content Loss: 1.2161	Style Loss:14.9854
Epoch [300]	Content Loss: 1.3519	Style Loss:10.2414
Epoch [400]	Content Loss: 1.4237	Style Loss:7.7343
Epoch [500]	Content Loss: 1.4693	Style Loss:6.1720
Epoch [600]	Content Loss: 1.5041	Style Loss:5.0809
Epoch [700]	Content Loss: 1.5302	Style Loss:4.2682
Epoch [800]	Content Loss: 1.5508	Style Loss:3.6306
Epoch [900]	Content Loss: 1.5674	Style Loss:3.1170


In [None]:
# Drop only the batch axis (a bare squeeze() would also drop any other size-1
# axis), then convert to an HWC float array in [0, 1] for display.
inp = denormalize_img(generated_img.detach().cpu().squeeze(0))

In [None]:
# plt.show(inp.all())

In [None]:
import cv2

In [None]:
# cv2.imshow('aa',inp)

In [None]:
# inp.shape

In [None]:
import PIL

In [None]:
from PIL import Image

In [None]:
import torchvision.transforms as T
def showimg(tensor):
    """Convert ``tensor`` to a PIL image and open it with ``Image.show()``.

    ``tensor`` can be anything ``torchvision.transforms.ToPILImage`` accepts; the
    call later in this notebook passes an HWC uint8 NumPy array.
    """
    to_pil = T.ToPILImage()
    to_pil(tensor).show()

In [None]:
# Scale [0, 1] floats to 8-bit. Round instead of truncating, so e.g. 0.999 maps
# to 255 rather than 254. inp is already clipped to [0, 1], so no overflow.
tensor = np.rint(inp * 255.0).astype(np.uint8)
if np.ndim(tensor) > 3:
    # Drop a leading batch axis of size 1 if one is still present.
    assert tensor.shape[0] == 1
    tensor = tensor[0]
showimg(tensor)