In [26]:
# !wget -qnc https://raw.githubusercontent.com/greentfrapp/lucent/master/images/transfer_big_ben.png
# !wget -qnc https://raw.githubusercontent.com/greentfrapp/lucent/master/images/transfer_vangogh.png
# !wget -qnc https://raw.githubusercontent.com/greentfrapp/lucent/master/images/transfer_picasso.png
# !wget -qnc https://pbs.twimg.com/media/EEHUTeHUcAAq_R5.jpg

In [2]:
import torch
from PIL import Image
import numpy as np

from lucent.optvis import render, param, transform, objectives
from lucent.modelzoo import inceptionv1
from lucent.misc.io import show
from lucent.optvis.objectives import wrap_objective

In [3]:
# Pretrained InceptionV1, moved to the GPU and switched to inference mode.
# nn.Module.cuda() and .eval() both return the module itself, so chaining is safe.
model = inceptionv1(pretrained=True).cuda().eval()

In [17]:
def load(path):
    """Load an image file as an RGB float array scaled to [0, 1].

    Args:
        path: path of the image file to read.

    Returns:
        numpy float64 array of shape (H, W, 3): the 8-bit RGB pixels divided by 255.
    """
    # Force 3 channels: grayscale, palette or RGBA files would otherwise give 2-D or
    # 4-channel arrays that cannot be transposed/stacked with the 3-channel optimized
    # image. The `with` block closes the file handle deterministically.
    with Image.open(path) as img:
        return np.array(img.convert('RGB')) / 255

def load_resize(path, size=(512, 512)):
    """Load an image file as an RGB float array resized to `size`, scaled to [0, 1].

    Args:
        path: path of the image file to read.
        size: target ``(width, height)`` passed to PIL's ``resize``; defaults to 512x512.

    Returns:
        numpy float64 array of shape (height, width, 3): 8-bit RGB pixels divided by 255.
    """
    # The `with` block closes the underlying file as soon as the pixels are read.
    with Image.open(path) as img:
        return np.array(img.convert('RGB').resize(size)) / 255

In [34]:
# Content comes from the chest_xray dataset (resized to 512x512 RGB); style is the
# Van Gogh painting at its native resolution.
content_image = load_resize('./chest_xray/val/NORMAL/NORMAL2-IM-1427-0001.jpeg')
style_image = load("transfer_vangogh.png")

# style_transfer_param crops the style image to the content image's height/width,
# so the style image must be at least as large in both dimensions.
print(*(arr.shape for arr in (content_image, style_image)))

show(content_image)
show(style_image)

(512, 512, 3) (512, 645, 3)


In [19]:
# Layers whose Gram matrices are compared for the style loss.
style_layers = 'conv2d2 mixed3a mixed4a mixed4b mixed4c'.split()

# Layer whose raw activations are compared for the content loss.
content_layers = ['mixed3b']

In [20]:
def style_transfer_param(content_image, style_image, decorrelate=True, fft=True):
    """Build a lucent parameterization that batches [transfer, content, style] images.

    Only the transfer image (batch index 0) is optimized. The content and style
    images are constant inputs appended to the batch, so one forward pass yields
    activations for all three.

    Args:
        content_image: channels-last (H, W, C) numpy array; its H and W set the
            size of the generated image.
        style_image: channels-last (H', W', C) numpy array with H' >= H and W' >= W;
            only its top-left H x W region is used.
        decorrelate: forwarded to ``param.image``.
        fft: forwarded to ``param.image``.

    Returns:
        ``(params, inner)``: ``params`` are the optimizable parameters from
        ``param.image``; ``inner()`` returns a float tensor of shape (3, C, H, W).

    Raises:
        ValueError: if the style image is smaller than the content image.
    """
    height, width = content_image.shape[:2]
    if style_image.shape[0] < height or style_image.shape[1] < width:
        raise ValueError(
            f"style_image {tuple(style_image.shape[:2])} must be at least as large as "
            f"content_image {(height, width)}"
        )
    params, image = param.image(height, width, decorrelate=decorrelate, fft=fft)

    # The content/style inputs never change, so convert HWC -> CHW once here rather
    # than on every optimization step.
    content_chw = torch.tensor(np.transpose(content_image, [2, 0, 1])).float()
    style_chw = torch.tensor(np.transpose(style_image[:height, :width, :], [2, 0, 1])).float()
    # Device copies of the constant inputs, keyed by device, so each transfer happens
    # at most once and always matches the device of the generated image.
    on_device = {}

    def inner():
        transfer_input = image()[0]
        device = transfer_input.device
        if device not in on_device:
            on_device[device] = (content_chw.to(device), style_chw.to(device))
        content_input, style_input = on_device[device]
        return torch.stack([transfer_input, content_input, style_input])

    return params, inner

# Batch positions of the images stacked by style_transfer_param (a convention taken
# from the original Lucid notebook): 0 = optimized transfer image, 1 = content image,
# 2 = style image.
TRANSFER_INDEX, CONTENT_INDEX, STYLE_INDEX = range(3)

In [21]:
def mean_L1(a, b):
    """Return the mean absolute (L1) difference between tensors `a` and `b`."""
    diff = a - b
    return diff.abs().mean()

In [22]:
@wrap_objective()
def activation_difference(layer_names, activation_loss_f=mean_L1, transform_f=None, difference_to=CONTENT_INDEX):
    """Objective: distance between the transfer image's activations and a reference's.

    Args:
        layer_names: names of the layers whose activations are compared.
        activation_loss_f: callable ``(reference, optimized) -> scalar`` applied per
            layer; defaults to ``mean_L1``.
        transform_f: optional callable applied to each layer's (single-image)
            activations before comparison, e.g. ``gram_matrix`` for a style loss.
        difference_to: batch index of the reference image (``CONTENT_INDEX`` by default).

    Returns:
        A callable taking ``T`` (called as ``T(layer_name)`` to fetch that layer's
        batched activations) and returning the sum of the per-layer losses.
    """
    def prepare(acts):
        # Apply the optional transform (e.g. Gram matrix) uniformly to both sides.
        return acts if transform_f is None else transform_f(acts)

    def inner(T):
        losses = []
        for layer_name in layer_names:
            layer_acts = T(layer_name)
            # The reference image is a constant input, so its activations are a fixed
            # target; detaching avoids back-propagating through that branch.
            target = prepare(layer_acts[difference_to].detach())
            # Activations of the optimized image, which change during optimization.
            current = prepare(layer_acts[TRANSFER_INDEX])
            losses.append(activation_loss_f(target, current))
        return sum(losses)

    return inner

In [23]:
def gram_matrix(features, normalize=True):
    """Return the C x C Gram matrix of a (C, H, W) feature map.

    Args:
        features: tensor of shape (C, H, W); it need not be contiguous.
        normalize: if True, divide by the number of spatial positions H * W.

    Returns:
        Tensor of shape (C, C) whose (i, j) entry is the dot product of channels i and j.
    """
    C, H, W = features.shape
    # reshape (not view) so non-contiguous inputs, e.g. sliced or transposed
    # activations, are accepted instead of raising.
    flat = features.reshape(C, H * W)
    gram = flat @ flat.t()
    if normalize:
        gram = gram / (H * W)
    return gram

In [24]:
def param_f():
    # Reads the module-level content/style images each time it is called.
    return style_transfer_param(content_image, style_image)

# Relative weighting of the content and style terms.
content_weight, style_weight = 200, 1

content_obj = activation_difference(content_layers, difference_to=CONTENT_INDEX)
content_obj.description = "Content Loss"

style_obj = activation_difference(
    style_layers,
    transform_f=gram_matrix,
    difference_to=STYLE_INDEX,
)
style_obj.description = "Style Loss"

objective = content_weight * content_obj + style_weight * style_obj

vis = render.render_vis(model, objective, param_f, show_inline=True)

100%|██████████| 512/512 [01:09<00:00,  7.36it/s]


In [37]:
# Swap roles: the Van Gogh painting becomes the content image and an X-ray the style.
content_image = load_resize("transfer_vangogh.png")
style_image = load_resize('./cxr_color/val/NORMAL/NORMAL2-IM-1427-0001.jpeg')

def param_f():
    return style_transfer_param(content_image, style_image)

# Build the objective in this cell rather than reusing whatever `objective` was last
# left in the kernel. These weights match a fresh top-to-bottom run.
# NOTE(review): per the execution counts, In [32] (content_weight = 2000) ran before
# this cell, so the saved output was likely rendered with 2000 - confirm the intent.
content_weight = 200
style_weight = 1
objective = content_weight * content_obj + style_weight * style_obj

vis = render.render_vis(model, objective, param_f, show_inline=True)

100%|██████████| 512/512 [01:09<00:00,  7.41it/s]


In [32]:
# Keep the content_image from the preceding cell; use a pneumonia X-ray as the style,
# weighting content 10x more heavily than in the first run.
style_image = load_resize("./chest_xray/val/PNEUMONIA/person1947_bacteria_4876.jpeg")

def param_f():
    return style_transfer_param(content_image, style_image)

content_weight, style_weight = 2000, 1
objective = content_weight * content_obj + style_weight * style_obj

vis = render.render_vis(model, objective, param_f, show_inline=True)

100%|██████████| 512/512 [01:09<00:00,  7.32it/s]
