<a href="https://colab.research.google.com/github/sharon200102/ImageStyleTransfer/blob/main/Image_Style_Transfer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
# basic initialization
import torchvision.models as models
import torch
import torch.nn as nn
from torchvision import transforms
import numpy as np
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import save_image
# Index of the CUDA device to use; get_device() falls back to the CPU when CUDA is unavailable.
GPU = 0
# Iteration counter of the optimisation loop; also used as the TensorBoard step and in the saved image names.
epoch = 0
# The paths to the content and style images
content_img_path = Path('/content/Neckarfront_Tübingen_Mai_2017.jpg')
style_img_path = Path('/content/starry-night-1093721_960_720.jpg')
# TensorBoard writer; the same directory is passed to `%tensorboard --logdir` in the last cell.
writer = SummaryWriter('/content/imag_style_transfer_exp')
# Key into ImageStyleRepresentor.supported_pretrained_modles ('vgg16' or 'vgg19').
pre_trained_model_name = 'vgg19'
# When True the optimised image starts from random noise, otherwise from a copy of the content image.
start_from_noise = False
# parameters defined by the article
# early_stopping() ends the run once the last `patience` losses contain no value lower than the one before them.
patience = 100
# Weights of the content loss (alpha) and of the style loss (beta) in the total loss.
alpha =1
beta=1000
# Index into the model's `features` layers whose output is used as the content representation.
content_representation_layer = 21
# Indices into the model's `features` layers whose Gram matrices form the style representation.
style_representation_layres = [0,5,10,19,28]
# Per-layer weights of the style loss, paired positionally with the style layers above.
weights = [0.2]*5
# Learning rate of the Adam optimiser that updates the image.
lr = 0.005

def get_device(gpu=GPU):
    """Return the torch device computations should run on.

    Args:
        gpu: Index of the CUDA device to select when CUDA is available.

    Returns:
        The ``cuda:<gpu>`` device if CUDA is available, otherwise the CPU device.
    """
    if not torch.cuda.is_available():
        return torch.device('cpu')
    return torch.device(f"cuda:{gpu}")

# The basic loss functions

def square_differncess(a,b,scalar:float=1):
  """
  Return the sum of squared element-wise differences between a and b, times scalar.

  a and b must support element-wise subtraction and multiplication, and the
  result must provide ``.sum()`` (e.g. two tensors of compatible shape).
  """
  delta = a - b
  return (delta * delta).sum() * scalar

def content_loss(p,x):
  """Return half of the summed squared differences between the representations p and x."""
  return square_differncess(p, x, scalar=0.5)

def element_style_loss(g,a):
  """
  Return the style loss contribution of a single layer.

  The summed squared difference between g and a is scaled by
  1 / (4 * g.shape[0]**2 * a.shape[0]**2).

  NOTE(review): both factors are the leading dimension of the matrices passed in.
  If g and a are square Gram matrices, that is the filter count used twice; the
  spatial size of the feature map does not enter the normalisation. Confirm this
  is the intended scaling.
  """
  rows_g = g.shape[0]
  rows_a = a.shape[0]
  normaliser = 1 / (4 * rows_g**2 * rows_a**2)
  return square_differncess(g, a, normaliser)

def style_loss(g_seq,a_seq,weights):
    """Return the weighted sum of the per-layer style losses.

    g_seq and a_seq are paired element by element and passed to
    element_style_loss; each resulting loss is multiplied by the weight at the
    same position. Pairing stops at the shortest of the sequences.
    """
    per_layer_losses = [element_style_loss(g, a) for g, a in zip(g_seq, a_seq)]
    total = 0
    for layer_loss, layer_weight in zip(per_layer_losses, weights):
        total = total + layer_loss * layer_weight
    return total

class unsqueezeTransform():
  """Callable transform that inserts a dimension of size one into a tensor."""

  def __init__(self,dim = 0):
    # Position at which the new dimension is inserted.
    self.dim = dim

  def __call__(self, sample):
    """Return sample with an extra size-one dimension at position self.dim."""
    expanded = torch.unsqueeze(sample, self.dim)
    return expanded

def imshow(img, title=None):
  """
  Display img with matplotlib, optionally with a title.

  Tensors are made displayable first:
    * they are detached and moved to the CPU, because matplotlib converts its
      input through numpy, which rejects tensors that require grad or live on a GPU;
    * tensors with more than three dimensions have their size-one dimensions
      squeezed away (e.g. a leading batch dimension of one);
    * a channel-first (3|4, H, W) tensor is permuted to (H, W, 3|4), the layout
      matplotlib expects. Tensors whose last dimension is already 3 or 4 are
      left untouched, so inputs that worked before are displayed unchanged.

  Non-tensor inputs are handed to matplotlib as they are.
  """
  if isinstance(img, torch.Tensor):
    img = img.detach().cpu()
    if img.dim() > 3:
      img = torch.squeeze(img)
    is_channel_first = img.dim() == 3 and img.shape[0] in (3, 4) and img.shape[-1] not in (3, 4)
    if is_channel_first:
      img = img.permute(1, 2, 0)

  plt.imshow(img)
  if title:
    plt.title(title)

def early_stopping(loss_buffer:list,patience:int,minimize = True):
  """
  Decide whether the optimisation should stop.

  Looks at the last patience + 1 entries of loss_buffer. Returns True when the
  best of them (the minimum if minimize is truthy, otherwise the maximum) is the
  oldest entry of that window, i.e. none of the last `patience` entries improved
  on it. Returns False while the buffer holds at most `patience` entries.
  """
  if len(loss_buffer) <= patience:
    return False

  recent = loss_buffer[-(patience + 1):]
  pick_best = min if minimize else max
  best = pick_best(recent)
  # list.index returns the first occurrence, so a tie with the oldest entry counts as "no improvement".
  return recent.index(best) == 0



# Image preprocessing pipeline: PIL image -> normalised tensor with a leading batch dimension.
_preprocess_steps = [
    # Resize, then keep the central 224x224 crop.
    transforms.Resize(256),
    transforms.CenterCrop(224),
    # PIL image -> tensor.
    transforms.ToTensor(),
    # Per-channel normalisation with fixed mean / std for three channels.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    # Add the batch dimension at position 0.
    unsqueezeTransform(0),
]
preprocess = transforms.Compose(_preprocess_steps)

In [5]:
# The following class will provide the representation of the content and the style of a given image.

class ImageStyleRepresentor:
  """
  Extract content and style representations of an image from a pre-trained network.

  Only the `features` layers of the selected torchvision model are kept. They are
  put in eval mode and their parameters are frozen, so the network acts as a fixed
  feature extractor: gradients still flow back to the input image, but no
  gradients are computed or stored for the network weights.
  """
  # Maps a supported model name to the torchvision constructor that builds it.
  supported_pretrained_modles = {'vgg16':models.vgg16,'vgg19':models.vgg19}

  def __init__(self,model_name:str = 'vgg16', device=None) -> None:
    """
    Load the pre-trained model and keep its feature layers on `device`.

    Args:
      model_name: key of `supported_pretrained_modles`.
      device: torch device to run on; defaults to the CPU when None.

    Raises:
      KeyError: if `model_name` is not a supported model.
    """
    self.device = device if device is not None else torch.device('cpu')
    pre_trained_model = self.supported_pretrained_modles[model_name](pretrained=True).to(self.device)
    self.features = nn.ModuleList(list(pre_trained_model.features)).eval()
    # The weights are never optimised; freezing them avoids computing a gradient
    # for every network parameter on each backward pass.
    self.features.requires_grad_(False)

  def get_content_representaion(self, img, content_layer_idx:int=0):
    """
    Return the output of feature layer `content_layer_idx` for `img`.

    Args:
      img: input tensor accepted by the first feature layer.
      content_layer_idx: index into `self.features`.

    Raises:
      IndexError: if `content_layer_idx` is not a valid layer index (previously
        this case silently returned None).
    """
    if not 0 <= content_layer_idx < len(self.features):
      raise IndexError(
          f'content_layer_idx must be in [0, {len(self.features) - 1}], got {content_layer_idx}')
    img = img.to(self.device)
    for layer_idx, layer in enumerate(self.features):
      img = layer(img)
      if layer_idx == content_layer_idx:
        break
    return img

  def get_style_representation(self, img, style_layers_idx:list=None):
    """
    Return the Gram matrices of the requested feature layers for `img`.

    Args:
      img: input tensor accepted by the first feature layer.
      style_layers_idx: indices into `self.features`. None or an empty sequence
        yields an empty list.

    Returns:
      A list of Gram matrices ordered by layer index (not by the order of
      `style_layers_idx`); indices beyond the last layer are ignored.
    """
    if style_layers_idx is None or len(style_layers_idx) == 0:
      return []
    wanted = set(style_layers_idx)
    last_needed = max(wanted)

    gram_matrices = []
    img = img.to(self.device)
    for layer_idx, layer in enumerate(self.features):
      img = layer(img)
      if layer_idx in wanted:
        gram_matrices.append(self._gram_multiplication_wrapper(img))
      if layer_idx >= last_needed:
        # Deeper layers cannot contribute any more; skip their forward pass.
        break
    return gram_matrices

  @staticmethod
  def _gram_multiplication(t:torch.Tensor):
    """Return t @ t^T for a 2-D tensor t of shape (n_filters, flattened_representation)."""
    return torch.matmul(t, t.transpose(0, 1))

  @staticmethod
  def _gram_multiplication_wrapper(t:torch.Tensor):
    """
    Flatten a feature map to 2-D and return its Gram matrix.

    t is assumed to be of shape (1, n_filters, m_rep, m_rep): the leading
    dimension is squeezed away (only if it has size one) and the remaining
    trailing dimensions are flattened into one. Tensors with at most two
    dimensions are used as they are.
    """
    if len(t.shape) > 2:
      t = t.squeeze(0)
      t = t.flatten(start_dim=1)
    return ImageStyleRepresentor._gram_multiplication(t)
  

  
    


In [None]:
# open and preprocess the content and style images
# convert('RGB') guarantees three channels (drops alpha, expands grayscale), which the
# three-channel normalisation inside `preprocess` requires.
style_img = Image.open(style_img_path).convert('RGB')
content_img = Image.open(content_img_path).convert('RGB')

preprocessed_style_img = preprocess(style_img)
preprocessed_content_img = preprocess(content_img)

device = get_device()

if start_from_noise:
  # Sample a random noise and preprocess it.
  imarray = np.random.rand(224,224,3) * 255
  initial_img = preprocess(Image.fromarray(imarray.astype('uint8')).convert('RGB'))
else:
  initial_img = preprocessed_content_img

# The image itself is the optimised variable. Creating the leaf tensor directly on the
# compute device avoids a host -> device copy on every forward pass.
transfered_img = initial_img.clone().detach().to(device).requires_grad_(True)


# Load the pretrained model
isr =ImageStyleRepresentor(pre_trained_model_name,device=device)

# Targets of the optimisation. They are constants, so they are computed without autograd:
# otherwise their graph would be part of every loss and each backward() would need
# retain_graph=True, keeping all intermediate buffers alive.
with torch.no_grad():
  real_content_representation = isr.get_content_representaion(preprocessed_content_img,content_representation_layer)
  real_style_representation = isr.get_style_representation(preprocessed_style_img,style_representation_layres)

# Must mirror the mean / std given to transforms.Normalize in `preprocess`.
_norm_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
_norm_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)


def _save_transfered_img(img, step):
  """Undo the input normalisation and write img to 'transfered_img<step>.png'."""
  # The optimised tensor lives in normalised space; saving it directly would clamp
  # most values and produce a distorted picture.
  restored = img.detach().cpu() * _norm_std + _norm_mean
  save_image(restored.clamp(0, 1), f'transfered_img{step}.png')


loss_buffer = []
optimizer = torch.optim.Adam([transfered_img],lr =lr)


while not early_stopping(loss_buffer,patience):
  optimizer.zero_grad()
  transfered_img_content_representation = isr.get_content_representaion(transfered_img,content_representation_layer)
  transfered_img_style_representation = isr.get_style_representation(transfered_img,style_representation_layres)

  c_loss = content_loss(real_content_representation,transfered_img_content_representation)
  s_loss = style_loss(real_style_representation,transfered_img_style_representation,weights)
  total_loss = alpha*c_loss + beta*s_loss
  print(f'this is epoch number {epoch}, the loss is {total_loss}')

  # Log plain Python floats: no autograd graph or device memory is kept alive, and
  # early_stopping compares numbers instead of device tensors.
  total_loss_value = total_loss.item()
  writer.add_scalar('content_loss',c_loss.item(),epoch)
  writer.add_scalar('style_loss',s_loss.item(),epoch)
  writer.add_scalar('total_loss',total_loss_value,epoch)

  loss_buffer.append(total_loss_value)

  # The graph is rebuilt on every iteration, so it can be freed after backward().
  total_loss.backward()
  optimizer.step()
  if epoch%1000==0:
    _save_transfered_img(transfered_img, epoch)
  epoch+=1

_save_transfered_img(transfered_img, epoch)
# Make sure all logged scalars are on disk before TensorBoard reads the directory.
writer.flush()




In [None]:
# Load the TensorBoard notebook extension and open TensorBoard on the directory
# that the SummaryWriter created in the first cell writes its logs to.
%load_ext tensorboard
%tensorboard --logdir /content/imag_style_transfer_exp
