In [1]:
import tqdm
import torch
import numpy as np
from PIL import Image
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torchvision.utils import save_image
import torchvision.transforms as transforms

In [2]:
# Report which accelerator the kernel sees. Guarded so that the cell also runs
# on CPU-only machines, matching the CPU fallback used when `device` is chosen.
gpu_name = (
    torch.cuda.get_device_name()
    if torch.cuda.is_available()
    else "cpu (no CUDA device available)"
)
gpu_name

'Tesla T4'

In [3]:
class Net(nn.Module):
    """Frozen, truncated VGG-19 used as a feature extractor for style transfer.

    Args:
        layers: Indices into ``vgg19().features`` whose outputs are collected
            by ``forward``. The default ``(0, 5, 10, 19, 28)`` selects the
            first convolution of each of the five VGG blocks. The backbone is
            truncated right after the largest index, so nothing beyond the
            last tapped layer is ever evaluated.

    Raises:
        ValueError: If ``layers`` is empty.

    Note:
        torchvision's VGG uses ``ReLU(inplace=True)``. A tensor collected from
        a convolution is therefore overwritten in place by the ReLU that
        follows it, so by the time ``forward`` returns it holds post-activation
        values. The last tapped layer has no layer after it in the truncated
        stack and is returned as the raw convolution output.
    """

    def __init__(self, layers=(0, 5, 10, 19, 28)):
        super().__init__()
        self.layers = list(layers)
        if not self.layers:
            raise ValueError("layers must contain at least one feature index")

        # `pretrained=True` is deprecated in newer torchvision releases; prefer
        # the explicit weights enum when it exists and fall back otherwise.
        if hasattr(models, "VGG19_Weights"):
            backbone = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1)
        else:
            backbone = models.vgg19(pretrained=True)

        # Keep only the layers up to (and including) the deepest tapped one.
        self.vgg = backbone.features[: max(self.layers) + 1]

        # The network is a fixed feature extractor: only the image is optimised.
        for param in self.vgg.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Run ``x`` through the truncated VGG and return the tapped activations.

        Returns:
            A list with one tensor per index in ``self.layers``, ordered by
            depth in the network.
        """
        features = []
        for i, layer in enumerate(self.vgg):
            x = layer(x)
            if i in self.layers:
                features.append(x)

        return features

In [4]:
def load_image(img, transform, device):
    """Load an image file and return it as a batched tensor on ``device``.

    Args:
        img: Path (or file-like object) accepted by ``PIL.Image.open``.
        transform: Callable applied to the PIL image; its result must support
            ``unsqueeze`` and ``to`` (e.g. a torchvision pipeline ending in
            ``ToTensor``).
        device: Target device for the returned tensor.

    Returns:
        ``transform(image)`` with a leading batch dimension of size 1, moved
        to ``device``.
    """
    # The context manager releases the file handle. convert("RGB") loads the
    # pixel data and guarantees three channels, so grayscale / RGBA / CMYK
    # inputs do not break the 3-channel VGG input layer.
    with Image.open(img) as source:
        rgb = source.convert("RGB")
    return transform(rgb).unsqueeze(0).to(device)

In [5]:
# Side length (in pixels) that both images are resized to.
img_size = 356

# Prefer the GPU when one is visible to PyTorch, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

In [6]:
# Preprocessing shared by the content and style images: resize to a fixed
# img_size x img_size square, then convert the PIL image to a tensor.
_preprocess_steps = [
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocess_steps)

In [7]:
# Content image (what the picture shows) and style image (how it is painted).
original_img = load_image(img="elon_musk.jpg", transform=transform, device=device)
style_img = load_image(img="picasso.jpg", transform=transform, device=device)

# The image being optimised starts as a copy of the content image and is the
# only tensor that receives gradients.
generated_img = original_img.clone()
generated_img.requires_grad_(True)

In [8]:
# Build the feature extractor on the chosen device and switch it to inference
# mode (eval() returns the module itself, so the calls can be chained).
net = Net().to(device).eval()
net

Net(
  (vgg): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(

In [9]:
# Optimisation schedule: number of Adam updates and their step size.
steps, lr = 10000, 1e-3

# Loss weights: alpha scales the content term, beta scales the style term.
alpha, beta = 1, 1e-2

# Adam updates the pixels of generated_img directly.
opt = optim.Adam(params=[generated_img], lr=lr)

In [10]:
def _gram_matrix(feature_map):
    """Return the (C, C) Gram matrix of a feature map.

    The map is viewed as (C, H*W), which is only valid for a batch size of 1.
    """
    _, c, h, w = feature_map.shape
    flat = feature_map.view(c, h * w)
    return flat.mm(flat.t())


# The content features and the style Gram matrices depend only on the fixed
# input images and the frozen network, so they are computed once instead of
# on every step. no_grad() avoids building autograd graphs for them.
with torch.no_grad():
    og_features = net(original_img)
    sty_features = net(style_img)
    sty_grams = [_gram_matrix(s) for s in sty_features]

loop = tqdm.tqdm(range(steps), total=steps, leave=False)
for step in loop:
    # Only the generated image changes between steps.
    gen_features = net(generated_img)

    content_loss = 0
    style_loss = 0

    for o, g, S in zip(og_features, gen_features, sty_grams):
        # Content term: mean squared difference of the raw activations.
        content_loss += torch.mean((g - o) ** 2)

        # Style term: mean squared difference of the Gram matrices.
        G = _gram_matrix(g)
        style_loss += torch.mean((G - S) ** 2)

    loss = alpha * content_loss + beta * style_loss
    opt.zero_grad()
    loss.backward()
    opt.step()

    # Periodic checkpoint: report the loss and dump the current image.
    if step % 1000 == 0:
        print(f" Loss: {np.round(loss.item(), 4)}")
        save_image(generated_img, f"output_{step}.png")
        print("Image Saved!!")
        print("")

  0%|          | 5/10000 [00:00<24:25,  6.82it/s]

 Loss: 1196540.0
Image Saved!!



 10%|█         | 1005/10000 [01:27<14:55, 10.05it/s]

 Loss: 12331.498
Image Saved!!



 20%|██        | 2005/10000 [02:58<13:34,  9.81it/s]

 Loss: 4787.9014
Image Saved!!



 30%|███       | 3005/10000 [04:30<12:03,  9.67it/s]

 Loss: 3210.8015
Image Saved!!



 40%|████      | 4005/10000 [06:03<10:23,  9.62it/s]

 Loss: 2344.7749
Image Saved!!



 50%|█████     | 5005/10000 [07:37<08:39,  9.62it/s]

 Loss: 1789.3899
Image Saved!!



 60%|██████    | 6005/10000 [09:11<06:59,  9.52it/s]

 Loss: 1429.5714
Image Saved!!



 70%|███████   | 7005/10000 [10:45<05:11,  9.62it/s]

 Loss: 1207.6511
Image Saved!!



 80%|████████  | 8005/10000 [12:20<03:26,  9.64it/s]

 Loss: 1073.0446
Image Saved!!



 90%|█████████ | 9005/10000 [13:54<01:44,  9.55it/s]

 Loss: 991.6924
Image Saved!!





In [11]:
# Write the final result; `step` is the last index reached by the training
# loop, so step + 1 is the number of updates that were performed.
final_output_path = f"output_{step+1}.png"
save_image(generated_img, final_output_path)