In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from PIL import Image
from torchvision.utils import save_image

In [None]:
# Load the pretrained VGG-19 and keep only its convolutional `features` stack so
# the next cell can show the layer indices. `model` is reassigned to the VGG
# wrapper further down; this binding is for inspection only.
vgg19_full = models.vgg19(pretrained=True)
model = vgg19_full.features
del vgg19_full  # drop the reference so the unused classifier head can be freed

In [None]:
# Display the layer list of the VGG-19 `features` stack; the indices shown here
# are the ones referenced by `VGG.selected_layer` in the class below.
model

In [None]:
class VGG(nn.Module):
    """Truncated pretrained VGG-19 that returns selected intermediate activations.

    Args:
        selected_layers: indices (strings or ints) into ``vgg19().features`` whose
            outputs ``forward`` collects, in layer order. Defaults to 5, 10, 19, 28.

    The convolutional stack is cut right after the highest selected index, and all
    weights are frozen. Only the input image is optimised, so weight gradients
    would be wasted work; gradients still flow through to the input.

    NOTE(review): in torchvision's VGG-19 layout, indices 5/10/19/28 are the
    conv2_1/conv3_1/conv4_1/conv5_1 outputs (before ReLU). Index 0 (conv1_1) is
    not included; confirm this is intentional.
    """

    def __init__(self, selected_layers=('5', '10', '19', '28')):
        super().__init__()

        # Kept as strings because forward() compares against str(index).
        self.selected_layer = [str(idx) for idx in selected_layers]
        if not self.selected_layer:
            raise ValueError("selected_layers must name at least one layer")
        last_layer = max(int(idx) for idx in self.selected_layer)

        # Layers past the last selected one never contribute, so drop them.
        self.model = models.vgg19(pretrained=True).features[:last_layer + 1]
        # Freeze weights: otherwise every backward() accumulates unused .grad buffers.
        self.model.requires_grad_(False)

    def forward(self, x):
        """Run ``x`` through the stack; return the outputs of the selected layers, in order."""
        features = []
        for layer_num, layer in enumerate(self.model):
            x = layer(x)
            if str(layer_num) in self.selected_layer:
                features.append(x)
        return features
                
        

In [None]:
def load_image(image_name):
    """Load an image file as a batched 3-channel tensor on ``device``.

    Uses the global ``loader`` transform and the global ``device``. Both are
    defined in later cells, so call this only after those cells have run.

    Args:
        image_name: path of the image file to open.

    Returns:
        Tensor of shape (1, 3, H, W); H and W are set by ``loader``.
    """
    with Image.open(image_name) as img:
        # Force RGB: grayscale or RGBA/PNG files would otherwise produce 1- or
        # 4-channel tensors, which VGG's 3-channel first conv cannot accept.
        image = loader(img.convert("RGB")).unsqueeze(0)
    return image.to(device=device)

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
image_size = 224

# Square-resize every input so the content, style and generated tensors share
# one shape, then convert the PIL image to a float CHW tensor.
loader = transforms.Compose(
    [transforms.Resize((image_size,) * 2), transforms.ToTensor()]
)

In [None]:
# Download the content image (saved as anna.jpg) and the style image (saved as
# style.jpg) into the working directory; later cells read both files by name.
!wget "https://upload.wikimedia.org/wikipedia/commons/c/cd/Anne_Hathaway_at_MIFF_%28cropped%29.jpg" -O anna.jpg
!wget "https://pbs.twimg.com/media/DU5DVJ_WAAICIMg.jpg" -O style.jpg

In [None]:
# NOTE(review): removed `!wget "https://web.whatsapp.com/25d0ecc7-..." -O pran.jpg`.
# That URL looks like a copied browser `blob:` link from WhatsApp Web, which only
# exists inside the browser session that created it. wget would therefore save an
# error/HTML page (or fail) as pran.jpg. pran.jpg is also not used by any cell in
# this notebook. To use another content image, download it from a public URL
# here and pass its path to load_image.

In [None]:
# Content and style targets, resized by `loader` and moved to `device` by load_image.
original_image = load_image("./anna.jpg")
style_image = load_image('./style.jpg')

# Initialise the generated image as a copy of the content image. It is the only
# tensor handed to the optimizer, so it must track gradients.
generated = torch.clone(original_image).requires_grad_()

In [None]:
# Show the raw content and style images side by side before optimisation.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for ax, path, title in zip(axes, ("./anna.jpg", "./style.jpg"), ("Content image", "Style image")):
    # imshow copies the pixel data immediately, so the file can be closed right after.
    with Image.open(path) as img:
        ax.imshow(img)
    ax.set_title(title)
    ax.axis("off")
plt.show()

In [None]:
# Build the multi-layer feature extractor, move it to `device`, and switch it
# to eval mode. This replaces the inspection-only `model` bound earlier.
model = VGG()
model = model.to(device=device)
model = model.eval()

## Hyperparameters

In [None]:
# Optimisation settings; the loop minimises alpha * content_loss + beta * style_loss.
total_steps = 6000
learning_rate = 0.001
alpha, beta = 1, 0.01

# Only the generated image is optimised; no VGG parameters are passed in.
optimizer = optim.Adam(params=[generated], lr=learning_rate)


In [None]:
def _gram_matrix(feat):
    """Return the (channel x channel) Gram matrix of a single-image feature map.

    Expects ``feat`` of shape (1, channel, height, width). The view below folds
    the batch dimension away, so it only works for a batch size of 1.
    """
    _, channel, height, width = feat.shape
    flat = feat.view(channel, height * width)
    return flat.mm(flat.t())


# The content and style images never change, so their features (and the style
# Gram matrices) are computed once, without building an autograd graph,
# instead of being recomputed on every one of the `total_steps` iterations.
with torch.no_grad():
    original_features = model(original_image)
    style_grams = [_gram_matrix(feat) for feat in model(style_image)]

for step in range(total_steps):
    generated_features = model(generated)

    original_loss = 0
    style_loss = 0
    for gen_feat, orig_feat, style_gram in zip(generated_features, original_features, style_grams):
        # Content term: mean squared difference of the feature maps.
        original_loss += torch.mean((gen_feat - orig_feat) ** 2)
        # Style term: mean squared difference of the (unnormalised) Gram matrices.
        style_loss += torch.mean((_gram_matrix(gen_feat) - style_gram) ** 2)

    total_loss = alpha * original_loss + beta * style_loss
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    if step % 10 == 0:
        print(f"At step : {step} Total loss: {total_loss.item()}")

    # Checkpoint every 200 steps, plus the final step so the last iterate is kept.
    if step % 200 == 0 or step == total_steps - 1:
        save_image(generated, f"{step}_generated.png")

![Generated image after 400 optimisation steps](./400_generated.png)

In [None]:
# The training loop writes checkpoints named "{step}_generated.png" every 200
# steps. Show the last such checkpoint (step 5800 when total_steps == 6000).
last_checkpoint = max(total_steps - 1, 0) // 200 * 200
with Image.open(f"./{last_checkpoint}_generated.png") as final_img:
    plt.imshow(final_img)
plt.title(f"Generated image after {last_checkpoint} steps")
plt.axis("off")
plt.show()