In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import torch.optim as optim
from torchvision.utils import save_image
import matplotlib.pyplot as plt

In [2]:
model = models.vgg19(pretrained = True).features
print(model)

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU(inplace=True)
  (18): MaxPoo

In [3]:
class VGG(nn.Module):
    def __init__(self):
        super(VGG,self).__init__()
        self.chosen_features = ['0','5', '10', '19','28']
        self.model = models.vgg19(pretrained = True).features[:29]
    def forward(self, x):
        features = []
        for layer_num,layer in enumerate(self.model):
            x = layer(x)
            if str(layer_num) in self.chosen_features:
                features.append(x)
        return features

In [4]:
def load_image(image):
    image = Image.open(image)
    image = loader(image).unsqueeze(0)
    return image.to(device)

In [5]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [6]:
im_size = 356

In [7]:
loader = transforms.Compose([transforms.Resize((im_size,im_size)),
                            transforms.ToTensor()])

In [8]:
original_img = load_image('goldenbridge.jpg')
style_img = load_image('starryimg.jpg')

In [9]:
type(style_img)

torch.Tensor

In [10]:
style_img.is_cuda

True

In [11]:
original_img.is_cuda

True

In [12]:
# generated = torch.randn(original_img.data.shape, device=device, requires_grad=True)
generated = original_img.clone().requires_grad_(True)
model = VGG().to(device).eval()

In [13]:
#hyperparameters
total_steps = 6000
learning_rate = 0.01
alpha = 1
beta = 0.01
optimizer = optim.Adam([generated],lr = learning_rate)

In [14]:
#training
for step in range(total_steps):
    generated_features = model(generated)
    original_img_features = model(original_img)
    style_features = model(style_img)
    style_loss = 0
    original_loss = 0
    for gen_feature, orig_feature, style_feature in zip(generated_features,original_img_features,style_features):
        batch_size, channel, height, width = gen_feature.shape
        original_loss += torch.mean((gen_feature - orig_feature)**2)
        G = gen_feature.view(channel,height*width).mm(gen_feature.view(channel,height*width).t())
        A = style_feature.view(channel,height*width).mm(style_feature.view(channel,height*width).t())
        style_loss = torch.mean((G-A)**2)
    total_loss = alpha*original_loss + beta*style_loss
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    if step%200==0:
        print(total_loss)
        save_image(generated,'generated.png')

tensor(105584.1484, device='cuda:0', grad_fn=<AddBackward0>)
tensor(328.5591, device='cuda:0', grad_fn=<AddBackward0>)
tensor(170.1773, device='cuda:0', grad_fn=<AddBackward0>)
tensor(180.9818, device='cuda:0', grad_fn=<AddBackward0>)
tensor(134.1369, device='cuda:0', grad_fn=<AddBackward0>)
tensor(165.2439, device='cuda:0', grad_fn=<AddBackward0>)
tensor(579.0601, device='cuda:0', grad_fn=<AddBackward0>)
tensor(190.7857, device='cuda:0', grad_fn=<AddBackward0>)
tensor(152.8796, device='cuda:0', grad_fn=<AddBackward0>)
tensor(135.3183, device='cuda:0', grad_fn=<AddBackward0>)
tensor(122.0916, device='cuda:0', grad_fn=<AddBackward0>)
tensor(113.5036, device='cuda:0', grad_fn=<AddBackward0>)
tensor(108.0214, device='cuda:0', grad_fn=<AddBackward0>)
tensor(104.7045, device='cuda:0', grad_fn=<AddBackward0>)
tensor(99.2918, device='cuda:0', grad_fn=<AddBackward0>)
tensor(94.9996, device='cuda:0', grad_fn=<AddBackward0>)
tensor(104.6537, device='cuda:0', grad_fn=<AddBackward0>)
tensor(89.998