In [None]:
# %matplotlib inline
# %reload_ext autoreload
# %autoreload 2

# Image style transfer with deep learning


In this notebook, we will experiment with a pretrained VGG convolutional neural network in order to understand style transfer. <br>
By the end, we should be able to take the style of an image A and apply it to an image B.
<img src="https://dmtyylqvwgyxw.cloudfront.net/instances/132/uploads/images/custom_image/image/1581/normal_Slide11.JPG?v=1508001718" alt="Drawing" style="width: 600px;"/>
<br>
This work is based on <a href="https://arxiv.org/pdf/1508.06576.pdf"> this paper</a>
<br>

In [None]:
#We import our libraries
from fastai.conv_learner import *
import torch
from pathlib import Path
from scipy import ndimage
torch.cuda.set_device(0)
from torchvision import models
torch.backends.cudnn.benchmark=True

We will download a pretrained VGG16 network (stored as `m_vgg`). The following cell may take some time to run, as it downloads and loads the network.

In [None]:
# Load fastai's VGG16 (`True` = pretrained weights, per the markdown above),
# move it to the GPU and switch it to eval mode.
m_vgg = (vgg16(True)).cuda().eval()
# Freeze the network: only the input image is optimised, never the weights.
set_trainable(m_vgg, False)

# Part one: Proof that we can reconstruct an image from its convolution activations

VGG is a convolutional neural network created in 2014. It is made of blocks of 3×3 convolutions, each followed by a BatchNorm and a ReLU, with a max-pool at the end of each block. To have a look at its architecture, run the cell below:

In [None]:
# Display the layer-by-layer structure of the network.
m_vgg

What we will prove first is that it is possible to reconstruct an image from its convolution output using backpropagation. Let's take an image (a tree, in my case):

In [None]:
# Content image. NOTE(review): absolute local path, so adjust it for your machine.
img_fn = "/home/remi/Desktop/tree.jpg"
img = open_image(img_fn)
plt.imshow(img);

We preprocess our image so that we can fit it to the network: <br>
(I am running this notebook on my laptop, so I use a small image)

In [None]:
# Build fastai's VGG16 transforms for an sz x sz input. Only the validation
# transform is used, to preprocess the content image for the network.
sz=288
trn_tfms,val_tfms = tfms_from_model(vgg16, sz)
img_tfm = val_tfms(img)
img_tfm.shape

Then we will generate a noise image, to which we will apply the gradients computed from the convolution activations.

In [None]:
# Uniform noise in [0, 1) with the same shape as the content image; this is the
# starting point of the optimisation. NOTE(review): no random seed is set.
opt_img = np.random.uniform(0, 1, size=img.shape).astype(np.float32)
plt.imshow(opt_img);

In [None]:
# Smooth the noise with a median filter: window 8x8 over the first two axes and
# size 1 on the last axis (assumed to be channels), so channels are not mixed.
opt_img = scipy.ndimage.filters.median_filter(opt_img, [8,8,1])
plt.imshow(opt_img);

In [None]:
# Preprocess the noise like the content image, halve it, and add a batch axis.
# requires_grad=True makes the image itself the parameter being optimised.
# NOTE(review): here the halving is applied after val_tfms; get_opt() below halves before.
opt_img = val_tfms(opt_img)/2
opt_img_v = V(opt_img[None], requires_grad=True)
opt_img_v.shape

In [None]:
# Keep only the first 8 child layers of VGG16 as a separate, truncated model.
# It is bound to `m_vgg2` because that is the name the following cells (target
# activations and `actn_loss`) call; the full `m_vgg` is left untouched.
# NOTE(review): with fastai's BatchNorm VGG16, layer 8 should be the first
# convolution of the second block; confirm against the `m_vgg` printout above.
m_vgg2 = nn.Sequential(*children(m_vgg)[:8])

In [None]:
# Target activations: run the content image (with a batch axis added) through the
# truncated network `m_vgg2`. VV presumably builds a no-grad (volatile) Variable
# in this fastai version; V re-wraps the result so it can serve as a loss target.
targ_t = m_vgg2(VV(img_tfm[None]))
targ_v = V(targ_t)
targ_t.shape

So what we will do is: take the convolution output of VGG for our tree image and the convolution output of VGG for our noise image, then backpropagate the gradient to train the noise image. Here, our loss function will be a basic mean squared error between the two outputs:

In [None]:
# The training loop below keeps calling optimizer.step until the global closure-call
# counter n_iter exceeds max_iter; progress is printed every show_iter calls.
max_iter = 250
show_iter = 5
# LBFGS over the image pixels only (the network is frozen).
optimizer = optim.LBFGS([opt_img_v], lr=0.5)

In [None]:
def actn_loss(x):
    """Content loss for part one.

    Runs `x` through the truncated network `m_vgg2` and returns 1000 times the
    mean squared error against the target activations `targ_v` (both globals).
    """
    activations = m_vgg2(x)
    return F.mse_loss(activations, targ_v) * 1000

In [None]:
def step(loss_fn):
    """Closure for `optimizer.step`: evaluate `loss_fn` on the image and backpropagate.

    Uses the globals `optimizer`, `opt_img_v` and `show_iter`, and increments the
    global counter `n_iter` on every call (LBFGS may call this several times per step).

    Args:
        loss_fn: callable taking the image Variable and returning a scalar loss.

    Returns:
        The loss, as `optimizer.step` expects from its closure.
    """
    global n_iter
    optimizer.zero_grad()
    loss = loss_fn(opt_img_v)
    loss.backward()
    n_iter += 1
    # NOTE(review): `loss.data[0]` is the old-PyTorch way to read a scalar;
    # newer versions need `loss.item()`.
    if not n_iter % show_iter:
        print(f'Iteration: {n_iter}, loss: {loss.data[0]}')
    return loss

In [None]:
# LBFGS may evaluate the closure several times per optimizer.step, and `step`
# increments n_iter on every evaluation, so the final count can overshoot max_iter.
n_iter=0
while n_iter <= max_iter: optimizer.step(partial(step,actn_loss))

In [None]:
# Convert the optimised image back into a displayable array: move axis 1 (channels,
# assumed NCHW layout) to the end, undo the transform's normalisation, and drop
# the batch axis.
x = val_tfms.denorm(np.rollaxis(to_np(opt_img_v.data),1,4))[0]
plt.figure(figsize=(7,7))
plt.imshow(x);

In [None]:
# Shape of the reconstructed image array.
x.shape

In [None]:
# import scipy.misc
# scipy.misc.imsave('/home/remi/Desktop/content.jpg', x)

# extract the style of an image

Now we will do the same thing, but instead of extracting the content, we will extract the style of the image. To do that, we will compute the Gram matrix of the convolution activations and compare it to the Gram matrix of the noise image's activations, reducing the distance between them with backpropagation on the noise image.

### forward hook
We will use a forward hook to capture the activations for the input image:

In [None]:
class SaveFeatures():
    """Forward-hook helper that records the latest output of a module.

    The hook is registered at construction time, so every later forward pass
    through `m` replaces `self.features` with that pass's output.
    """
    # Latest output of the hooked module; None until a forward pass has run.
    features = None

    def __init__(self, m):
        handle = m.register_forward_hook(self.hook_fn)
        # Keep the handle so close() can unregister the hook.
        self.hook = handle

    def hook_fn(self, module, input, output):
        """Called by PyTorch after the module's forward; stores its output."""
        self.features = output

    def close(self):
        """Unregister the forward hook."""
        self.hook.remove()

In [None]:
# Reload a fresh, full pretrained VGG16 on the GPU (via fastai's to_gpu),
# in eval mode, with all weights frozen.
m_vgg = to_gpu(vgg16(True)).eval()
set_trainable(m_vgg, False)

In [None]:
#List the block we are interested in (=conv output)
block_ends = [i-1 for i,o in enumerate(children(m_vgg))
              if isinstance(o,nn.MaxPool2d)]
block_ends

In [None]:
# Hook the end of the 4th conv block (block_ends[3]); its output is the
# representation reconstructed in this section.
sf = SaveFeatures(children(m_vgg)[block_ends[3]])

In [None]:
def get_opt():
    """Create a smoothed random-noise image to optimise, plus an LBFGS optimiser for it.

    Reads the globals `img` (for the noise shape, assumed H x W x C) and
    `val_tfms` (to preprocess the noise like a real input).

    Returns:
        (opt_img_v, optimizer): the trainable image Variable with a leading batch
        axis, and an `optim.LBFGS` instance over it.
    """
    noise = np.random.uniform(0, 1, size=img.shape).astype(np.float32)
    # 8x8 median window on the first two axes, size 1 on the last (channel) axis.
    smoothed = scipy.ndimage.filters.median_filter(noise, [8, 8, 1])
    batch = val_tfms(smoothed / 2)[None]
    image_var = V(batch, requires_grad=True)
    lbfgs = optim.LBFGS([image_var])
    return image_var, lbfgs

In [None]:
# Fresh noise image and LBFGS optimiser for this section.
opt_img_v, optimizer = get_opt()

In [None]:
# Forward the content image through the full network; the hook in `sf` captures
# the 4th block's output, which is cloned into an independent copy and wrapped as the target.
m_vgg(VV(img_tfm[None]))
targ_v = V(sf.features.clone())
targ_v.shape

We will extract the style of this painting:

In [None]:
# img_fn = "/home/remi/Desktop/lion.jpg"
# lion = open_image(img_fn)
# plt.imshow(lion);

In [None]:
# Re-create the transforms and the preprocessed content image (identical to the
# earlier transform cell).
sz=288
trn_tfms,val_tfms = tfms_from_model(vgg16, sz)
img_tfm = val_tfms(img)
img_tfm.shape

In [None]:
def actn_loss2(x):
    """Hook-based content loss.

    Runs `x` through the full `m_vgg` only to trigger the `sf` hook, then returns
    1000 times the MSE between the captured activation and the global `targ_v`.
    """
    m_vgg(x)  # output discarded: the hook in `sf` records the activation we need
    captured = V(sf.features)
    loss = F.mse_loss(captured, targ_v)
    return loss * 1000

In [None]:
# n_iter=0
# while n_iter <= max_iter: optimizer.step(partial(step,actn_loss2))

In [None]:
# Convert the image Variable back to a displayable array and show it.
# NOTE(review): the optimisation loop in the previous cell is commented out, so this
# shows the un-optimised starting image unless that loop is run.
x = val_tfms.denorm(np.rollaxis(to_np(opt_img_v.data),1,4))[0]
plt.figure(figsize=(7,7))
plt.imshow(x);

In [None]:
# sf.close()

## Style match

In [None]:
# wget https://raw.githubusercontent.com/jeffxtang/fast-style-transfer/master/images/starry_night.jpg
# style_fn = img_fn = "/home/remi/Desktop/lion.jpg"

In [None]:
# style_img = open_image(style_fn)
# style_img.shape, img.shape

In [None]:
# plt.imshow(style_img);

In [None]:
# Load the style image (the lion) under its own name. The content image stays in
# `img`, and `style_img` is the name that scale_match(img, style_img) and
# val_tfms(style_img) in the following cells expect.
# NOTE(review): absolute local path, so adjust it for your machine.
style_fn = "/home/remi/Desktop/lion.jpg"
style_img = open_image(style_fn)

In [None]:
def scale_match(src, targ):
    """Resize `targ` so it covers `src`, then crop it to `src`'s height and width.

    The scale factor is the larger of the two per-axis ratios, so after resizing
    `targ` is at least as tall and as wide as `src`. The top-left h x w region is kept.

    Args:
        src: reference image, indexed as (height, width, channels).
        targ: image to rescale, indexed the same way.

    Returns:
        Array cropped from the resized `targ`, with the height and width of `src`.
    """
    import math
    h, w, _ = src.shape
    sh, sw, _ = targ.shape
    rat = max(h / sh, w / sw)
    # Round up: with int() truncation, float error (e.g. 49 * (1/49) == 0.999...)
    # could leave the resized image one pixel short, making the crop smaller than src.
    new_w, new_h = math.ceil(sw * rat), math.ceil(sh * rat)
    res = cv2.resize(targ, (new_w, new_h))  # cv2 dsize is (width, height)
    return res[:h, :w]

In [None]:
# Resize the style image so it covers the content image, then crop it to the
# content image's height and width.
style = scale_match(img, style_img)
style.shape

In [None]:
# Show the matched style image and compare its shape with the content image.
plt.imshow(style)
style.shape, img.shape

In [None]:
# Fresh noise image and optimiser for the style-only reconstruction.
opt_img_v, optimizer = get_opt()

In [None]:
# Hook the end of every conv block.
sfs = [SaveFeatures(children(m_vgg)[idx]) for idx in block_ends]
# Content targets: the content image's activations at every hooked block.
m_vgg(VV(img_tfm[None]))
targ_vs = [V(o.features.clone()) for o in sfs]
[o.shape for o in targ_vs]  # NOTE(review): not the cell's last expression, so nothing is displayed
# Style targets: the style image's activations at the same blocks.
# NOTE(review): uses `style_img`, not the scale-matched `style` from above.
style_tfm = val_tfms(style_img)
m_vgg(VV(style_tfm[None]))
targ_styles = [V(o.features.clone()) for o in sfs]

In [None]:
# Shapes of the style targets, one per hooked block.
[o.shape for o in targ_styles]

In [None]:
def gram(input):
    """Scaled Gram matrix of a 4-D activation tensor of size (b, c, h, w).

    The tensor is flattened to (b*c, h*w) and multiplied by its own transpose.
    The result is divided by the total element count and multiplied by 1e6
    (presumably to keep loss values in a convenient range).
    """
    b, c, h, w = input.size()
    flat = input.view(b * c, -1)
    products = flat.mm(flat.t())
    return products / input.numel() * 1e6

def gram_mse_loss(input, target):
    """Mean squared error between the Gram matrices of `input` and `target`."""
    g_in, g_tgt = gram(input), gram(target)
    return F.mse_loss(g_in, g_tgt)

In [None]:
def style_loss(x):
    """Style loss: sum over hooked blocks of the Gram-matrix MSE against the style targets.

    Args:
        x: image Variable to evaluate (the optimised image when called from `step`).
            Previously this argument was ignored and the global `opt_img_v` was used.

    Returns:
        Sum of `gram_mse_loss` over the activations captured in `sfs` versus `targ_styles`.
    """
    m_vgg(x)  # forward pass refreshes every SaveFeatures hook in `sfs`
    outs = [V(o.features) for o in sfs]
    losses = [gram_mse_loss(o, s) for o, s in zip(outs, targ_styles)]
    return sum(losses)

In [None]:
# Optimise the noise image to match only the style (Gram-matrix) targets.
n_iter=0
while n_iter <= max_iter: optimizer.step(partial(step,style_loss))

In [None]:
# Convert the style-only reconstruction back to a displayable array and show it.
x = val_tfms.denorm(np.rollaxis(to_np(opt_img_v.data),1,4))[0]
plt.figure(figsize=(7,7))
plt.imshow(x);

In [None]:
# Save the style-only reconstruction.
# NOTE(review): scipy.misc.imsave was removed in SciPy 1.2, so this needs an old
# SciPy; matplotlib's plt.imsave or PIL are alternatives. The path is machine-specific.
scipy.misc.imsave('/home/remi/Desktop/style1.jpg', x)

In [None]:
# Save the generated image array `x` directly. A bare plt.savefig in its own cell
# saves an empty figure with the inline backend, because the figure drawn in the
# earlier cell has already been closed. Values are clipped to [0, 1] because
# matplotlib rejects float RGB data outside that range.
# NOTE(review): absolute local path, so adjust it for your machine.
plt.imsave('/home/remi/Documents/code/Jupiter/fastai_old/courses/dl2/data/imagenet/data_style/style2.jpg', np.clip(x, 0, 1))

In [None]:
# Remove the forward hooks so later forward passes through m_vgg stop writing into them.
# `sf` (the single hook from the forward-hook section) is closed too. Its own close
# cell is commented out, and using `sf` as the loop variable would overwrite the only
# reference to it. RemovableHandle.remove() ignores already-removed hooks, so
# re-running this cell is safe.
sf.close()
for hook_rec in sfs: hook_rec.close()

## Reproducing the content of one image with the style of a second image

In this part, we will combine the two previous sections: we will take the style of the lion image and the content of the tree image in order to produce a new image.

In [None]:
# from fastai.conv_learner import *
# from pathlib import Path
# from scipy import ndimage
# torch.cuda.set_device(0)
# from torchvision import models
# torch.backends.cudnn.benchmark=True

In [None]:
# Start the combined section from a new, frozen pretrained VGG16 object, so no
# hooks registered on the earlier model are attached to it.
m_vgg = to_gpu(vgg16(True)).eval()
set_trainable(m_vgg, False)

In [None]:
# Content image. NOTE(review): absolute local path, so adjust it for your machine.
img_fn = "/home/remi/Desktop/tree.jpg"
img = open_image(img_fn)
plt.imshow(img);

In [None]:
# Style image. NOTE(review): absolute local path, so adjust it for your machine.
style_fn = "/home/remi/Desktop/lion.jpg"
style_img = open_image(style_fn)
plt.imshow(style_img);

In [None]:
# Transforms for an sz x sz VGG16 input and the preprocessed content image.
sz=288
trn_tfms,val_tfms = tfms_from_model(vgg16, sz)
img_tfm = val_tfms(img)

In [None]:
class SaveFeatures():
    """Capture a module's forward output through a PyTorch forward hook.

    Attributes:
        features: output of the latest forward pass through the hooked module
            (None until the first pass).
        hook: handle returned by `register_forward_hook`; used by `close`.
    """
    features = None

    def __init__(self, m):
        self.hook = m.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        # Overwritten on every forward pass: only the latest activation is kept.
        self.features = output

    def close(self):
        # Detach this recorder so later forward passes are no longer captured.
        self.hook.remove()

In [None]:
def get_opt():
    """Return (trainable noise image, LBFGS optimiser over it).

    The noise is uniform in [0, 1), shaped like the global `img`, median-filtered
    (8x8 window, size 1 on the last axis), halved, and passed through the global
    `val_tfms`. A leading batch axis is then added.
    """
    shape = img.shape
    raw = np.random.uniform(0, 1, size=shape).astype(np.float32)
    start = scipy.ndimage.filters.median_filter(raw, [8, 8, 1])
    var = V(val_tfms(start / 2)[None], requires_grad=True)
    return var, optim.LBFGS([var])

In [None]:
def step(loss_fn):
    """Closure passed to `optimizer.step` (through functools.partial).

    Clears gradients, computes `loss_fn(opt_img_v)`, backpropagates into the image,
    bumps the global `n_iter`, and prints progress every `show_iter` evaluations.
    Returns the loss so LBFGS can use it.
    """
    global n_iter
    optimizer.zero_grad()
    loss = loss_fn(opt_img_v)
    loss.backward()
    n_iter = n_iter + 1
    if n_iter % show_iter != 0:
        return loss
    # NOTE(review): `loss.data[0]` only works on old PyTorch; newer versions use `loss.item()`.
    print(f'Iteration: {n_iter}, loss: {loss.data[0]}')
    return loss

In [None]:
def gram(input):
    """Gram matrix of a (b, c, h, w) activation, divided by its element count and times 1e6.

    Rows of the (b*c, h*w) flattening are dotted with each other, giving a
    (b*c, b*c) matrix.
    """
    batch, channels, _, _ = input.size()
    features = input.view(batch * channels, -1)
    return torch.mm(features, features.t()) / input.numel() * 1e6

def gram_mse_loss(input, target):
    """MSE between `gram(input)` and `gram(target)`."""
    return F.mse_loss(*[gram(t) for t in (input, target)])

In [None]:
# NOTE(review): recreates the same transforms as the earlier transform cell (sz=288).
sz=288
trn_tfms,val_tfms = tfms_from_model(vgg16, sz)

# Draw a new random starting image and its LBFGS optimiser.
opt_img_v, optimizer = get_opt()

The cell above sets a new random starting image.

In [None]:
# Index of the layer just before each MaxPool2d, i.e. the end of each conv block.
block_ends = [i-1 for i,o in enumerate(children(m_vgg))
              if isinstance(o,nn.MaxPool2d)]
block_ends

In [None]:
# Attach a SaveFeatures hook at the end of every conv block, then record the
# content targets (content image) and the style targets (style image).
# Features are cloned so each target is an independent copy of what the hook saw.


sfs = [SaveFeatures(children(m_vgg)[idx]) for idx in block_ends]
m_vgg(VV(img_tfm[None]))
targ_vs = [V(o.features.clone()) for o in sfs]
[o.shape for o in targ_vs]  # NOTE(review): not the cell's last expression, so nothing is displayed
style_tfm = val_tfms(style_img)
m_vgg(VV(style_tfm[None]))
targ_styles = [V(o.features.clone()) for o in sfs]

In [None]:
# The style targets are captured once in the previous cell. Re-capturing them here
# would copy whatever the hooks saw last (after optimisation: the generated image).
def comb_loss(x, content_weight=2e5):
    """Combined style-transfer loss for image `x`.

    Content term: MSE between the first hooked block's activations for `x` and for
    the content image (`targ_vs[0]`), multiplied by `content_weight`.
    Style term: sum over all hooked blocks of the Gram-matrix MSE against the style
    image's activations (`targ_styles`).

    Args:
        x: image Variable being optimised (passed in by `step`).
        content_weight: relative weight of the content term versus the style term
            (default 2e5, equal to the previous hard-coded 100000*2).

    Returns:
        Content term + style term.
    """
    m_vgg(x)  # forward pass refreshes every SaveFeatures hook in `sfs`
    outs = [V(o.features) for o in sfs]
    sty_loss = sum(gram_mse_loss(o, s) for o, s in zip(outs, targ_styles))
    cnt_loss = F.mse_loss(outs[0], targ_vs[0]) * content_weight
    return cnt_loss + sty_loss

In [None]:
# Run the combined content + style optimisation; progress is logged every 10
# closure evaluations.
n_iter=0
max_iter=250
show_iter=10
while n_iter <= max_iter: optimizer.step(partial(step,comb_loss))

In [None]:
# Convert the result back to a displayable array and show it larger, with Lanczos
# interpolation and no axes.
x = val_tfms.denorm(np.rollaxis(to_np(opt_img_v.data),1,4))[0]
plt.figure(figsize=(9,9))
plt.imshow(x, interpolation='lanczos')
plt.axis('off');

In [None]:
# Remove the hooks registered on this model in the cell that built `sfs`.
for sf in sfs: sf.close()

In [None]:
# Save the final style-transfer result.
# NOTE(review): scipy.misc.imsave was removed in SciPy 1.2, so this needs an old
# SciPy; matplotlib's plt.imsave or PIL are alternatives. The path is machine-specific.
scipy.misc.imsave('/home/remi/Desktop/transfered.jpg', x)