In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import transforms
import torchvision.models as models
import numpy as np
import pandas as pd
import cv2
from PIL import Image
from matplotlib import pyplot as plt
import pickle
import os
import time
from copy import deepcopy

In [0]:
# Load torchvision's VGG19 with pretrained weights (the weights file is
# downloaded on first use) and display the module. The indices printed for
# `features` are the ones the NST class below enumerates over when selecting
# its `content_layer` / `style_layers`.
vgg19 = models.vgg19(pretrained=True)
vgg19

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/checkpoints/vgg19-dcbb9e9d.pth


HBox(children=(FloatProgress(value=0.0, max=574673361.0), HTML(value='')))




VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padd

# Neural Style Transfer

In [0]:
class NST(nn.Module):
  """Neural style transfer on top of the `features` stack of a pretrained VGG19.

  The network weights are never trained: `inference` optimises the pixels of an
  input image so that its activations match the content image at one layer
  (content loss) and the Gram matrices of the style image at several layers
  (style loss). Layer numbers are positions inside `vgg19.features`.
  """

  def __init__(self, max_pool=False):
    """Build the feature extractor.

    Args:
      max_pool: when True the original ``MaxPool2d`` layers are kept; when
        False (default) each one is replaced by
        ``AvgPool2d(kernel_size=2, stride=2)``.
    """
    super(NST, self).__init__()
    vgg19 = models.vgg19(pretrained=True)
    self.vgg19features = vgg19.features
    for index, layer in list(enumerate(self.vgg19features)):
      if isinstance(layer, nn.ReLU):
        # An in-place ReLU overwrites its input, i.e. the output of the
        # preceding layer. Those intermediate outputs are handed back to the
        # caller by forward(), so they must stay intact.
        self.vgg19features[index] = nn.ReLU(inplace=False)
      elif isinstance(layer, nn.MaxPool2d) and not max_pool:
        self.vgg19features[index] = nn.AvgPool2d(kernel_size=2, stride=2)
    # Only the input image is optimised (see inference), so the VGG weights
    # never need gradients; freezing them avoids computing and storing them.
    for parameter in self.vgg19features.parameters():
      parameter.requires_grad_(False)
    print(self.vgg19features)

  def _collect_activations(self, image, layer_indices):
    """Run `image` through the network and keep the outputs of selected layers.

    Args:
      image: batched input tensor fed to the first layer.
      layer_indices: iterable of positions in `self.vgg19features` whose
        outputs should be kept. Positions beyond the last layer are ignored.

    Returns:
      dict mapping each captured layer index to that layer's output for the
      first batch element (`output[0]`).
    """
    wanted = set(layer_indices)
    captured = {}
    if not wanted:
      return captured
    last_needed = max(wanted)
    activation = image
    for index, layer in enumerate(self.vgg19features):
      if index > last_needed:
        # Nothing after the deepest requested layer is needed.
        break
      activation = layer(activation)
      if index in wanted:
        captured[index] = activation[0]
    return captured

  def forward(self, image, content_layer=21, style_layers=(1, 6, 11, 20, 29), content_reconstruction=False, style_reconstruction=False):
    """Extract either the content or the style representation of `image`.

    Args:
      image: batched input tensor.
      content_layer: layer index whose output is the content representation.
      style_layers: layer indices whose outputs form the style representation.
      content_reconstruction: return the content representation. Takes
        precedence when both flags are set.
      style_reconstruction: return the style representation.

    Returns:
      * content mode: output of `content_layer` for the first batch element;
      * style mode: list of outputs (first batch element) of the requested
        style layers, ordered by increasing layer index;
      * None when neither flag is set.

    Raises:
      ValueError: in content mode, if `content_layer` is not a valid layer
        index of the network.
    """
    if content_reconstruction:
      captured = self._collect_activations(image, [content_layer])
      if content_layer not in captured:
        raise ValueError(
          "content_layer %r is not a layer index of the feature network (0..%d)"
          % (content_layer, len(self.vgg19features) - 1))
      return captured[content_layer]

    if style_reconstruction:
      captured = self._collect_activations(image, style_layers)
      return [captured[index] for index in sorted(captured)]

    return None

  def inference(self, content_image, style_image, white_noise_image, device, alpha=1, beta=1e2, content_layer=21, style_layers=(1, 6, 11, 20, 29), epochs=10, required_loss=0.01, content_reconstruction=True, style_reconstruction=True):
    """Optimise `white_noise_image` towards the content/style targets.

    Args:
      content_image: image providing the content target (input to `preprocess`).
      style_image: image providing the style target (input to `preprocess`).
      white_noise_image: starting image whose pixels are optimised.
      device: torch device the tensors are moved to.
      alpha: weight of the content loss.
      beta: weight of the style loss.
      content_layer: layer index used for the content loss.
      style_layers: layer indices used for the style loss (equally weighted).
      epochs: number of `LBFGS.step` calls; the intermediate image is shown
        after each one.
      required_loss, content_reconstruction, style_reconstruction: accepted for
        backward compatibility; they are not used.

    Returns:
      The optimised image, converted back by `postprocess`.
    """
    print("before normalize: ")
    for picture in (content_image, style_image, white_noise_image):
      plt.imshow(picture); plt.show()

    content_image_tensor = self.preprocess(content_image).unsqueeze(0).to(device)
    style_image_tensor = self.preprocess(style_image).unsqueeze(0).to(device)
    generated = self.preprocess(white_noise_image).unsqueeze(0).to(device)
    generated = generated.requires_grad_()

    optimizer = optim.LBFGS([generated])

    # The targets are constants of the optimisation: computing them without
    # autograd keeps them out of the graph, so every closure call builds a
    # fresh graph and backward() needs no retain_graph.
    with torch.no_grad():
      content_target = self(content_image_tensor, content_layer=content_layer, content_reconstruction=True)
      style_targets = self._collect_activations(style_image_tensor, style_layers)
    needed_layers = [content_layer] + list(style_targets)

    def closure():
      # LBFGS may evaluate this several times within a single step().
      optimizer.zero_grad()
      # One pass yields both the content and the style activations.
      activations = self._collect_activations(generated, needed_layers)
      content_loss = self.contentLoss(content_target, activations[content_layer])
      if style_targets:
        weight = 1 / len(style_targets)
        style_loss = sum(weight * self.styleLoss(target, activations[index])
                         for index, target in style_targets.items())
      else:
        style_loss = 0
      total_loss = alpha*content_loss + beta*style_loss
      print("total loss: ", total_loss.item())
      total_loss.backward()
      return total_loss

    for epoch in range(epochs):
      print("EPOCH NO: ", epoch+1)
      optimizer.step(closure)
      # postprocess works on a copy, so the optimised tensor is left untouched.
      snapshot = self.postprocess(generated[0].detach().cpu())
      plt.imshow(snapshot); plt.show();

    return self.postprocess(generated.squeeze(0).detach().cpu())

  def preprocess(self, image, img_size=256):
    """Turn an image into the tensor the network is fed.

    Steps: resize (`img_size`, as interpreted by `transforms.Resize`),
    convert to a tensor, reorder the channels to BGR, subtract the per-channel
    ImageNet mean and scale by 255.

    Args:
      image: image accepted by `transforms.Resize` / `transforms.ToTensor`
        (the notebook passes PIL images).
      img_size: target size handed to `transforms.Resize`.

    Returns:
      Un-batched tensor; callers add the batch dimension themselves.
    """
    pipeline = transforms.Compose([
      transforms.Resize(img_size),  # replaces the deprecated transforms.Scale
      transforms.ToTensor(),
      transforms.Lambda(lambda x: x[torch.LongTensor([2, 1, 0])]),  # turn to BGR
      transforms.Normalize(mean=[0.40760392, 0.45795686, 0.48501961],  # subtract imagenet mean
                           std=[1, 1, 1]),
      transforms.Lambda(lambda x: x.mul_(255)),
    ])
    return pipeline(image)

  def postprocess(self, image):
    """Invert `preprocess` and return a PIL image.

    The input tensor is copied first, so the caller's tensor is not modified.

    Args:
      image: un-batched tensor in the layout produced by `preprocess`.

    Returns:
      PIL image with values clipped to the displayable [0, 1] range before
      conversion.
    """
    undo = transforms.Compose([
      transforms.Lambda(lambda x: x.mul_(1./255)),
      transforms.Normalize(mean=[-0.40760392, -0.45795686, -0.48501961],  # add imagenet mean
                           std=[1, 1, 1]),
      transforms.Lambda(lambda x: x[torch.LongTensor([2, 1, 0])]),  # turn to RGB
    ])
    # Work on a detached copy: the pipeline scales in place.
    tnsr = undo(image.detach().clone())
    tnsr = tnsr.clamp(0, 1)
    return transforms.ToPILImage()(tnsr)

  def contentLoss(self, content_image_representation, white_noise_image_representation):
    """Mean squared error between the two content representations."""
    return F.mse_loss(content_image_representation, white_noise_image_representation)

  def styleLoss(self, style_layer_outs, white_layer_outs):
    """Mean squared error between the Gram matrices of two 3-D activations.

    Args:
      style_layer_outs: activation of the style image at one layer, 3-D.
      white_layer_outs: activation of the generated image at the same layer.
    """
    def gramMatrix(x):
      # Flatten the two trailing dimensions and correlate the rows, then
      # normalise by the total number of elements.
      d1, d2, d3 = x.shape
      flat = x.reshape(d1, d2*d3)
      return torch.mm(flat, flat.t()).div(d1*d2*d3)

    return F.mse_loss(gramMatrix(style_layer_outs), gramMatrix(white_layer_outs))

  def normalize(self, image):
    """Scale a channel-first image by 1/255 and standardise the first 3 channels.

    Not used by `inference`; kept as a helper. The input is not modified
    (the division creates a new array/tensor).
    """
    img = image/255.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    for channel in range(3):
      img[channel, :, :] -= mean[channel]
      img[channel, :, :] /= std[channel]

    return img

  def denormalize(self, image):
    """Invert `normalize` on a copy of `image` and rescale by 255.

    Not used by `inference`; kept as a helper.
    """
    mean = torch.tensor([0.485, 0.456, 0.406])
    std = torch.tensor([0.229, 0.224, 0.225])

    img = deepcopy(image)

    for channel in range(3):
      img[channel, :, :] *= std[channel]
      img[channel, :, :] += mean[channel]

    img *= 255.
    return img



In [0]:
def imageLoad(path, size=256):
  """Load an image file as a square RGB PIL image.

  Args:
    path: location of the image file.
    size: edge length in pixels; the image is resized to (size, size).

  Returns:
    A PIL image in RGB mode. Converting to RGB guarantees three channels, which
    the channel reordering and 3-value mean subtraction in NST.preprocess
    rely on (grayscale, palette or RGBA files would otherwise break there).
  """
  # The context manager closes the underlying file once the pixel data has
  # been copied into the converted image.
  with Image.open(path) as img:
    return img.convert("RGB").resize((size, size))

def imageSave(img, path, fmt='PNG'):
  """Write `img` to `path` by calling its `save` method.

  Args:
    img: object exposing `save(path, format=...)` (the notebook passes PIL images).
    path: destination file path.
    fmt: value forwarded as the `format` keyword of `save`; defaults to 'PNG'.
  """
  save_options = dict(format=fmt)
  img.save(path, **save_options)



In [0]:
# Project root (presumably a mounted Google Drive folder) and the three
# sub-folders holding content inputs, style inputs and generated results.
path_to_sml = "/content/drive/My Drive/SML Project"
path_to_content_images, path_to_style_images, path_to_nst_images = (
  os.path.join(path_to_sml, folder_name)
  for folder_name in ("Content images", "Style images", "NST images")
)

In [0]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU, and move
# the freshly built model there. NST() prints its layer stack on construction.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
nst = NST()
nst = nst.to(device)
print(device)

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU()
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU()
  (4): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU()
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU()
  (9): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU()
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU()
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU()
  (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU()
  (18): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (20): ReLU()
  (21): Conv2d(512, 512

In [0]:
# Load the content image and the style image through imageLoad (which also
# resizes them).
content_image1, style_image1 = (
  imageLoad(os.path.join(path_to_content_images, "content3.jpeg")),
  imageLoad(os.path.join(path_to_style_images, "picasso.jpg")),
)

In [0]:
# Preview the two input images before running the transfer.
plt.imshow(content_image1)
plt.show()
plt.imshow(style_image1)
plt.show()

In [0]:
# Start the optimisation from an independent copy of the content image; it is
# passed to nst.inference as the `white_noise_image` argument in the run cell.
input_image = deepcopy(content_image1)

In [0]:
# Hyper-parameters of the run. They are also embedded in the output filename
# by the save cell, so every one of them must actually reach nst.inference.
alpha = 1e1   # content-loss weight
beta = 1e7    # style-loss weight
content_layer = 21
style_layers = [1,6,11,20,29]
epochs = 100
start_time = time.time()
# content_layer / style_layers were previously configured here but never
# forwarded, so editing them had no effect on the result.
nst_image = nst.inference(content_image1, style_image1, input_image, device, epochs=epochs, alpha=alpha, beta=beta, content_layer=content_layer, style_layers=style_layers)
end_time = time.time()
# Elapsed wall-clock time in seconds and in minutes.
print("time taken: ", (end_time - start_time), (end_time - start_time)/60)

In [0]:
# Show the stylised result followed by the content and style inputs.
plt.imshow(nst_image)
plt.show()
plt.imshow(content_image1)
plt.show()
plt.imshow(style_image1)
plt.show()

In [0]:
# The filename ends in ".jpg", so request JPEG explicitly: imageSave defaults
# to fmt='PNG', which would store PNG data under a JPEG extension.
imageSave(nst_image, os.path.join(path_to_nst_images, "nst_image5"+str((alpha, beta, content_layer, style_layers,epochs))+".jpg"), fmt="JPEG")