In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import requests
from io import BytesIO
import json

import torch
import torchvision

import copy

In [None]:
# Download the example image and decode it with PIL.
image_url = 'https://cdn.pixabay.com/photo/2017/09/13/15/38/airplane-2745898_960_720.jpg'
#image_url = 'https://media.healthdirect.org.au/images/general/primary/baby-sleeping-AMCBWB.jpg'
image_response = requests.get(image_url, timeout=30)
# Fail loudly on an HTTP error instead of handing an error page to PIL.
image_response.raise_for_status()
# Force three colour channels so grayscale/RGBA files still match the network input.
image = Image.open(BytesIO(image_response.content)).convert('RGB')
plt.axis('off')
plt.imshow(image)

In [None]:
# Load the pretrained VGG19, then switch it to evaluation mode.
# eval() returns the module itself, so re-binding the name keeps the same object.
vgg19 = torchvision.models.vgg19(pretrained=True)
vgg19 = vgg19.eval()

In [None]:
#Getting the labels
!git clone https://github.com/anishathalye/imagenet-simple-labels.git

labels_json = open('imagenet-simple-labels/imagenet-simple-labels.json')
labels = json.load(labels_json)
labels

In [None]:
# Preprocessing pipeline: resize to a fixed square, then convert the PIL image to a tensor.
input_size = 128  # side length, in pixels, of the square fed to the network

transform_to_input = torchvision.transforms.Compose([
    torchvision.transforms.Resize((input_size, input_size)),
    torchvision.transforms.ToTensor(),
])

# Add a leading batch dimension of size 1.
input_image = torch.unsqueeze(transform_to_input(image), 0)

In [None]:
# Classify the image with the full VGG19.
# The pretrained torchvision weights were trained on inputs normalised with these
# per-channel ImageNet statistics, so apply them before the forward pass.
imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1)
imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1)

# Inference only: no autograd graph is needed for the class scores.
with torch.no_grad():
    prediction = vgg19((input_image - imagenet_mean) / imagenet_std)
prediction

In [None]:
# Checking the top 5 predictions.
# Rank the scores once instead of repeatedly taking the argmax and overwriting the winner.
# A stable sort on the negated scores lists ties in ascending index order, which is the
# order repeated argmax would have produced.
top_k = 5
class_scores = prediction.detach().numpy().squeeze(0)
for index in np.argsort(-class_scores, kind='stable')[:top_k]:
    print('class: {} with score: {}'.format(labels[index], class_scores[index]))

# We can reverse it !

Suppose we know the outputs of the convolutional neural network for some unknown image x.

We can approximate x from its neural network output by finding an image y that minimizes the error (for some error function) between the output of y and the output of x.

In [None]:
# Convolutional layers, counted from 1 in network order, after which the reconstruction
# error is measured when the network is assembled.
layers_list = [4] #The layers we are measuring the errors at. List because we can use more than one layer.

In [None]:
#Let's prepare a new network, by inserting additional layers that measure the error after some convolutional layers.

#The additional Layer
class layer_error(torch.nn.Module):
    """Pass-through layer that records how far its input is from a fixed target.

    The mean squared error between the incoming activation and `target` is stored
    in `self.error` on every forward pass; the activation itself is returned
    unchanged, so the layer can be inserted anywhere in a network.
    """

    def __init__(self, target):
        super().__init__()
        # Detached: the target is a constant, not something to back-propagate into.
        self.target = target.detach()
        # Placeholder value until the first forward pass overwrites it.
        self.error = torch.zeros(1, dtype=torch.float32)

    def forward(self, input):
        mse = torch.nn.functional.mse_loss(input, self.target)
        self.error = mse
        # Hand the activation on untouched so the rest of the network is unaffected.
        return input
    
# Work on a private copy of the convolutional part of VGG19, in evaluation mode.
vgg19_features = copy.deepcopy(vgg19.features.eval())
# Only images are optimised in this notebook, never the network weights. Freezing the
# parameters stops autograd from computing and accumulating a gradient for every VGG
# weight on each backward pass.
for parameter in vgg19_features.parameters():
    parameter.requires_grad_(False)
vgg19_features

In [None]:
class VGG_Normalization(torch.nn.Module):
    """Normalise an image tensor with fixed per-channel mean and standard deviation.

    The statistics are shaped (3, 1, 1) so they broadcast over (3, H, W) and
    (N, 3, H, W) inputs. They are registered as buffers, so they follow the module
    through `.to(...)`, `.cuda()` and `.double()` instead of staying behind as
    plain float32 CPU tensors.
    """

    def __init__(self):
        super(VGG_Normalization, self).__init__()
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1))

    def forward(self, img):
        """Return `(img - mean) / std`, broadcast over the channel dimension."""
        return (img - self.mean) / self.std

In [None]:
# Assemble the reconstruction network: input normalisation first, then the VGG layers
# in order, with an error-measuring layer spliced in after each selected convolution.
normalization = VGG_Normalization()
vgg_with_layer_errors = torch.nn.Sequential(normalization)

layer_errors = []
layer_number = 0
conv_layer_number = 0

for layer in vgg19_features.children():
    layer_number += 1
    # ReLUs are swapped for fresh out-of-place instances; every other layer is reused as is.
    vgg_with_layer_errors.add_module(
        str(layer_number),
        torch.nn.ReLU(inplace=False) if isinstance(layer, torch.nn.ReLU) else layer,
    )

    if isinstance(layer, torch.nn.Conv2d):
        conv_layer_number += 1
        if conv_layer_number in layers_list:
            layer_number += 1
            # The target is what the network built so far outputs for the original image.
            new_layer_error = layer_error(vgg_with_layer_errors(input_image))
            vgg_with_layer_errors.add_module(str(layer_number), new_layer_error)
            layer_errors.append(new_layer_error)

    # Stop once every requested layer has its error module; deeper layers are not needed.
    if len(layer_errors) == len(layers_list):
        break

vgg_with_layer_errors

In [None]:
# Start the reconstruction from Gaussian noise with the same shape as the target image.
approximate_image = torch.randn(input_image.shape)

In [None]:
# The image itself is the only optimised variable, so it must track gradients.
approximate_image.requires_grad_()
# Alternative optimiser:
#optimizer = torch.optim.Adam([approximate_image], lr=0.01)
optimizer = torch.optim.LBFGS([approximate_image])

In [None]:
#Helping function to show images from tensors.

image_extractor = torchvision.transforms.ToPILImage()

def image_show(image_tensor):
    """Display an image tensor inline with matplotlib.

    Parameters
    ----------
    image_tensor : torch.Tensor
        Image with a leading batch dimension of size 1 (removed with `squeeze(0)`).
        Values are expected in [0, 1]. Anything outside that range is clamped for
        display only, because the conversion to an 8-bit image would otherwise
        wrap out-of-range values. The input tensor itself is not modified.
    """
    # detach() drops the autograd history; clamp() returns a new tensor.
    image = image_tensor.detach().squeeze(0).clamp(0, 1)
    image = image_extractor(image)

    plt.axis('off')
    plt.imshow(image)
    plt.show()

In [None]:
# Show the random starting point of the reconstruction.
image_show(approximate_image)

In [None]:
# Show the preprocessed image that the reconstruction should approach.
image_show(input_image)

In [None]:
import time

# Keep the starting point plus periodic snapshots so progress can be inspected later.
approximate_images = [approximate_image.clone().detach()]
number_of_iterations = 100


def closure():
    """Re-evaluate the reconstruction error and its gradient for the optimizer.

    `optimizer.step` may call this more than once per step. The summed error is also
    published through the module-level `error` so the loop below can report it.
    """
    global error
    # Keep the candidate inside the valid pixel range before every evaluation.
    approximate_image.data.clamp_(0, 1)
    # The forward pass is run for its side effect: each error layer stores its own MSE.
    vgg_with_layer_errors(approximate_image)
    error = torch.Tensor([0]).type(torch.FloatTensor)
    for layer in layer_errors:
        error += layer.error
    optimizer.zero_grad()
    error.backward()

    return error


begin = time.time()

for step in range(number_of_iterations):

    optimizer.step(closure)

    if step % 10 == 0:
        print('step: {} ,error = {:4f}'.format(step, error.item()))
        image_show(approximate_image)
        approximate_images.append(approximate_image.clone().detach())

end = time.time()
# Elapsed time is end minus begin; the operands were swapped, which printed a negative duration.
print('Time elapsed: {}'.format(end - begin))


# Style Transfer.

Let's define the vgg <b>style</b> of an image $X$ at layer $n$ to be the <b>Gram matrix</b> of the output of the layer $n$ once the image $X$ is fed into vgg.

The <b>Gram matrix</b> $G$ of a sequence of vectors $v_1,...,v_k\in\mathbb{R}^d$ is defined as


$$G=\begin{pmatrix}
 \left \langle v_1,v_1 \right \rangle & \left \langle v_1,v_2 \right \rangle &  \cdots & \left \langle v_1,v_k \right \rangle\\ 
 \left \langle v_2,v_1 \right \rangle & \left \langle v_2,v_2 \right \rangle &  \cdots & \left \langle v_2,v_k \right \rangle\\
 \vdots & \ddots&  &\vdots\\
 \vdots & & \ddots &\vdots\\
 \left \langle v_k,v_1 \right \rangle & \left \langle v_k,v_2 \right \rangle &  \cdots & \left \langle v_k,v_k \right \rangle\\
\end{pmatrix}$$

where $\left \langle v,u \right \rangle$ is "the inner product" between vectors $v$ and $u$ (here the vectors will in fact be matrices).

Notice that if $A$ is the matrix whose rows are $v_1,...,v_k$ then $G=AA^{T}$

In [None]:
#A function that returns Gram matrix.
def Gram_matrix(tensor_4d):
    """Return the normalised Gram matrix of a single feature map.

    Parameters
    ----------
    tensor_4d : torch.Tensor
        Activations of shape (1, C, H, W). The batch dimension must be 1.

    Returns
    -------
    torch.Tensor
        The (C, C) matrix A @ A.T divided by C*H*W, where row i of A is
        channel i flattened to length H*W.
    """
    size = tensor_4d.size()
    # reshape (rather than view) also accepts non-contiguous inputs such as
    # permuted or sliced feature maps; for contiguous inputs it is still a view.
    vectors_matrix = tensor_4d.reshape(size[1], size[2]*size[3]) #size[0]=1
    G = torch.mm(vectors_matrix, vectors_matrix.t())
    return G/(size[1]*size[2]*size[3]) #normalization

As we saw in the previous section, we can reconstruct an image by knowing the vgg outputs at any convolutional layer.
To transfer the <b>style</b> we want the reconstructed image to have a Gram matrix close to the Gram matrix of the image we want to take the style from.

Hence we will just adjust the loss layer we defined previously

In [None]:
#The additional Layer adjusted for style transfer
class layer_error_style(torch.nn.Module):
    """Pass-through layer that records content and style errors of its input.

    On every forward pass it stores:
      * `content_error` - MSE between the input and the content activation,
      * `style_error`   - `style_weight` times the MSE between the Gram matrix of
                          the input and the Gram matrix of the style activation,
      * `error`         - the sum of the two.
    The input is returned unchanged.
    """
    def __init__(self, content, style, style_weight = 10000): # as suggested in the paper, either 1000 or 10000.
        super(layer_error_style, self).__init__()
        self.content = content.detach()
        self.style = style.detach()
        # The style target is constant, so its Gram matrix is computed once here
        # instead of on every forward pass.
        self.style_gram = Gram_matrix(self.style)
        self.error = torch.Tensor([0]).type(torch.FloatTensor)
        self.content_error = torch.Tensor([0]).type(torch.FloatTensor)
        self.style_error = torch.Tensor([0]).type(torch.FloatTensor)
        self.style_weight = style_weight

    def forward(self, input):
        self.content_error = torch.nn.functional.mse_loss(input,self.content)
        self.style_error = self.style_weight*torch.nn.functional.mse_loss(Gram_matrix(input), self.style_gram)
        self.error = self.content_error + self.style_error
        return input #To not mess the network

We write the same things we wrote previously

In [None]:
#Just some random images from google search :)
def _download_rgb_image(url):
    """Download `url` and return it as a three-channel RGB PIL image.

    Raises `requests.HTTPError` on an HTTP error status, rather than passing an
    error page on to PIL.
    """
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    # Force three colour channels so grayscale/RGBA/CMYK files still match the network input.
    return Image.open(BytesIO(response.content)).convert('RGB')

content_url = 'https://upload.wikimedia.org/wikipedia/commons/d/d7/Leonhard_Euler.jpg'
content = _download_rgb_image(content_url)
style_url = 'https://upload.wikimedia.org/wikipedia/en/8/8f/Pablo_Picasso%2C_1909-10%2C_Figure_dans_un_Fauteuil_%28Seated_Nude%2C_Femme_nue_assise%29%2C_oil_on_canvas%2C_92.1_x_73_cm%2C_Tate_Modern%2C_London.jpg'
style = _download_rgb_image(style_url)

In [None]:
# Preprocess both images and add a leading batch dimension of size 1.
content_image = torch.unsqueeze(transform_to_input(content), 0)
style_image = torch.unsqueeze(transform_to_input(style), 0)
# Alternative starting point: a copy of the content image.
#style_transferred_image = content_image.clone().detach() #Initial Value
style_transferred_image = torch.randn(content_image.shape)

# Preview the content image, the style image and the noisy starting point, in that order.
for shown_tensor in (content_image, style_image, style_transferred_image):
    image_show(shown_tensor)

In [None]:
# Convolutional layers, counted from 1 in network order, after which the content and
# style errors are measured when the style-transfer network is assembled.
layers_list = [1,2,3,4,5] #The layers we are measuring the errors at. List because we can use more than one layer.

In [None]:
# Assemble the style-transfer network: input normalisation first, then the VGG layers
# in order, with a content/style error layer spliced in after each selected convolution.
normalization = VGG_Normalization()
vgg_style_transfer = torch.nn.Sequential(normalization)

layer_errors = []
layer_number = 0
conv_layer_number = 0

for layer in vgg19_features.children():
    layer_number += 1
    # error.backward() complains when inplace=True, so ReLUs are swapped for
    # out-of-place instances; every other layer is reused as is.
    vgg_style_transfer.add_module(
        str(layer_number),
        torch.nn.ReLU(inplace=False) if isinstance(layer, torch.nn.ReLU) else layer,
    )

    if isinstance(layer, torch.nn.Conv2d):
        conv_layer_number += 1
        if conv_layer_number in layers_list:
            layer_number += 1
            # Targets are what the network built so far outputs for each reference image.
            new_layer_error = layer_error_style(
                vgg_style_transfer(content_image),
                vgg_style_transfer(style_image),
                style_weight=100000,  # Try different weights.
            )
            vgg_style_transfer.add_module(str(layer_number), new_layer_error)
            layer_errors.append(new_layer_error)

    # Stop once every requested layer has its error module; deeper layers are not needed.
    if len(layer_errors) == len(layers_list):
        break

vgg_style_transfer

In [None]:
# The generated image is the only optimised variable, so it must track gradients.
style_transferred_image.requires_grad_()
# Alternative optimiser:
#optimizer = torch.optim.Adam([style_transferred_image], lr=0.1)
optimizer = torch.optim.LBFGS([style_transferred_image])

In [None]:
# Keep the starting point plus one snapshot per reported step.
style_transferred_images = [style_transferred_image.clone().detach()]
number_of_iterations = 50


def closure():
    """Re-evaluate the total, content and style errors and the gradient for the optimizer.

    `optimizer.step` may call this more than once per step. The three sums are also
    published through module-level names so the loop below can report them.
    """
    global error
    global content_error
    global style_error
    error = torch.Tensor([0.0])
    content_error = torch.Tensor([0.0])
    style_error = torch.Tensor([0.0])
    # Keep the candidate inside the valid pixel range before every evaluation.
    style_transferred_image.data.clamp_(0,1)
    # The forward pass is run for its side effect: each error layer stores its own errors.
    vgg_style_transfer(style_transferred_image)

    for layer in layer_errors:
        error += layer.error
        content_error += layer.content_error
        style_error += layer.style_error
    optimizer.zero_grad()
    error.backward()

    return error


begin = time.time()
for step in range(number_of_iterations):

    optimizer.step(closure)

    if step % 1 == 0:
        # The arguments follow the order of the placeholders: total, style, content.
        # They were previously passed as content then style, so the two values were mislabelled.
        print('step: {} , error = {:4f} , style error = {:4f} , content error = {:4f}'.format(step, error.item(),
                                                                         style_error.item(), content_error.item()))
        image_show(style_transferred_image)
        style_transferred_images.append(style_transferred_image.clone().detach())

end = time.time()
print('Time elapsed: {}'.format(end-begin))

In [None]:
# Number of stored snapshots: the initial image plus one per recorded step.
len(style_transferred_images)

In [None]:
# Take the most recent snapshot rather than a hard-coded position: the list holds the
# initial image plus one entry per recorded step, so a fixed index such as 100 is out
# of range when fewer steps were run.
# Clamp to [0, 1] first, because the conversion to an 8-bit image would otherwise wrap
# values that the last optimizer step pushed out of range.
final_image = image_extractor(style_transferred_images[-1].squeeze(0).clamp(0, 1))

plt.imshow(final_image)

In [None]:
#Let's try to smoothen the image.
from skimage.restoration import (denoise_tv_chambolle, denoise_bilateral,
                                 denoise_wavelet, estimate_sigma)
from skimage import data, img_as_float

In [None]:
# Convert the final image with scikit-image's img_as_float before handing it to the denoiser.
final_image_float = img_as_float(final_image)

In [None]:
# Preview total-variation denoising at increasing strengths: 0.01, 0.02, ..., 0.10.
# NOTE(review): newer scikit-image releases replace `multichannel=True` with
# `channel_axis=-1` - confirm against the installed version.
for weight_percent in range(1, 11):
    denoised = denoise_tv_chambolle(final_image_float, weight=weight_percent / 100, multichannel=True)
    plt.imshow(denoised)
    plt.axis('off')
    plt.show()