In [None]:
%load_ext autoreload
%autoreload 2
import torch
from models.vgg19style import VGG19Style
from lib.fashionpedia_processed import FashionPediaProcessed
from tqdm.autonotebook import tqdm

In [None]:
# load model
# VGG19Style is project-local; later cells index ``model.layers[...]`` and attach forward hooks
# to those layers to read intermediate activations.
# NOTE(review): if VGG19Style does not already freeze its weights, consider disabling their
# gradients (e.g. ``requires_grad_(False)``) so the optimisation loops only compute gradients
# for the generated image — confirm in models/vgg19style.py.
model = VGG19Style()

In [None]:
# load dataset, plus a loader that serves shuffled batches of 16 items
data = FashionPediaProcessed()
data_loader = torch.utils.data.DataLoader(
    data,
    shuffle=True,
    batch_size=16,
)

In [None]:
def total_loss(target, style, generated, alpha, beta,
               content_fn=None, style_fn=None, generated_style=None):
    """Weighted style-transfer objective ``alpha * content_term + beta * style_term``.

    The previous body called ``content_loss`` and ``style_loss``, but no functions with those
    names exist in this notebook (the optimisation cells bind those names to loss tensors), so
    every call failed. The two terms are now computed by ``content_fn`` and ``style_fn``, which
    default to this notebook's ``get_content_loss`` and ``get_style_loss`` (resolved when
    ``total_loss`` is called, so they only need to be defined by then).

    Args:
        target: reference content, first argument of ``content_fn``
            (for ``get_content_loss``: the content image's feature map).
        style: reference style, first argument of ``style_fn``
            (for ``get_style_loss``: the style image's Gram matrices keyed by layer).
        generated: what is being optimised; second argument of ``content_fn``.
        alpha: weight of the content term.
        beta: weight of the style term.
        content_fn: callable ``content_fn(target, generated)`` returning the content loss.
        style_fn: callable ``style_fn(style, generated_style)`` returning the style loss.
        generated_style: second argument of ``style_fn``; defaults to ``generated``. The
            notebook's helpers expect different inputs here (a single feature map for the
            content term, a dict of per-layer features for the style term), so pass both
            when relying on the default loss functions.

    Returns:
        ``alpha * content_fn(target, generated) + beta * style_fn(style, generated_style)``.
    """
    if content_fn is None:
        content_fn = get_content_loss
    if style_fn is None:
        style_fn = get_style_loss
    if generated_style is None:
        generated_style = generated
    return alpha * content_fn(target, generated) + beta * style_fn(style, generated_style)

In [None]:
import matplotlib.pyplot as plt
import json

# Dataset indices of the example items used as style and content references below.
STYLE_INDEX = 7
CONTENT_INDEX = 2

style_img = data[STYLE_INDEX]
content_img = data[CONTENT_INDEX]

# Attribute names in file order; entry k names position k of an item's 'att_oh' flags
# (presumably a multi-hot vector — verify against FashionPediaProcessed).
with open('fashionpedia/selected_attributes.json') as f:
    s_att = list(json.load(f).values())


def show_item(item, title):
    """Print the names of the attributes flagged in ``item['att_oh']`` and plot the item.

    The image is rendered via ``data.invImg(item)`` and shown with ``title``.
    """
    print('attributes:', [s_att[i] for i, v in enumerate(item['att_oh']) if v])
    plt.imshow(data.invImg(item))
    plt.title(title)
    plt.show()


show_item(content_img, 'Content image')
show_item(style_img, 'Style image')

In [None]:
# Start the synthesis from Gaussian noise shaped like the content image. A fixed seed makes the
# starting point, and therefore every later result, reproducible on Restart & Run All.
NOISE_SEED = 0
noise_generator = torch.Generator().manual_seed(NOISE_SEED)
target_img = torch.normal(
    0, 1, content_img['img'].shape, generator=noise_generator
).requires_grad_(True)

# Preview channels-last (assumes the image tensor is (C, H, W)); clamp to the [0, 1] range that
# imshow expects for float RGB data, which it would otherwise clip with a warning.
plt.imshow(target_img.detach().permute(1, 2, 0).clamp(0, 1));

In [None]:
def gram_matrix(tensor: torch.Tensor) -> torch.Tensor:
    """Return the unnormalised Gram matrix of a feature map.

    The input is flattened to ``(depth, -1)`` with ``depth = tensor.shape[0]`` (one row per
    channel for a ``(C, H, W)`` activation) and the result is ``F @ F.T``, of shape
    ``(depth, depth)``.

    ``reshape`` is used rather than ``view`` so non-contiguous inputs (e.g. permuted or
    transposed activations) also work; for contiguous inputs it still returns a view.
    """
    depth = tensor.shape[0]
    features = tensor.reshape(depth, -1)
    return torch.mm(features, features.t())

In [None]:
# Forward hooks copy selected layer outputs into ``activations`` on every forward pass of ``model``.
activations = {}


def set_activation(name):
    """Build a forward hook that stores the hooked layer's output in ``activations[name]``."""
    def hook(_module, _inputs, output):
        activations[name] = output
    return hook


# Remove hooks registered by a previous run of this cell; without this, re-running the cell
# stacks duplicate hooks on the same layers.
for handle in globals().get('hook_handles', []):
    handle.remove()
hook_handles = []

# content layer
hook_handles.append(model.layers[21].register_forward_hook(set_activation('conv_4_2')))

# style layers
style_layers = {0: 'conv_1_1', 5: 'conv_2_1', 10: 'conv_3_1', 19: 'conv_4_1', 28: 'conv_5_1'}
for layer_idx, name in style_layers.items():
    hook_handles.append(model.layers[layer_idx].register_forward_hook(set_activation(name)))

# The reference feature map and Gram matrices are constants of the optimisation. Computing them
# under no_grad keeps them out of autograd, so later backward() calls reach only the generated
# image. With autograd enabled and VGG weights that require grad, the second iteration would
# fail when trying to backward through this already-freed graph again.
with torch.no_grad():
    model(content_img['img'])
    content_img_feature = activations['conv_4_2']

    model(style_img['img'])
    style_img_grams = {i: gram_matrix(activations[name]) for i, name in style_layers.items()}

In [None]:
import torch.nn as nn

def get_content_loss(content_img_feature, target_img_feature):
    """Half the summed squared difference between two feature maps: 0.5 * sum((F - P) ** 2)."""
    diff = content_img_feature - target_img_feature
    return 0.5 * diff.pow(2).sum()

In [None]:
from tqdm import trange

optimizer = torch.optim.Adam([target_img], lr=0.003)
iterations = 3000

# Steps are numbered 1..iterations: exactly ``iterations`` updates run (trange(1, iterations)
# stopped one short) and the final step also reaches the logging branch.
for i in trange(1, iterations + 1):
    optimizer.zero_grad()

    # The forward pass is run for its side effect: the registered hooks refresh ``activations``.
    model(target_img)

    target_content_feature = activations['conv_4_2']

    content_loss = get_content_loss(content_img_feature, target_content_feature)

    content_loss.backward()
    optimizer.step()

    if i % 100 == 0:
        print('loss:', content_loss.item())
        # Clamp to the [0, 1] range imshow expects for float RGB data (it would clip and warn).
        plt.imshow(target_img.detach().permute(1, 2, 0).clamp(0, 1))
        plt.show()

In [None]:
def get_style_loss(style_img_grams, target_img_features, layers=None):
    """Mean per-layer style loss between the generated image and the style image.

    For each layer l, let the generated activation have shape ``(N_l, H, W)`` (channels, height,
    width), let M_l = H * W, and let G_l / A_l be the generated / style Gram matrices. Then

        E_l = sum((G_l - A_l) ** 2) / (4 * N_l**2 * M_l**2)

    which is the per-layer style loss of Gatys et al. (2016). The previous version divided by
    ``4 * H**2 * W**2`` only. That dropped the N_l**2 channel factor, so layers with more
    channels were over-weighted relative to the paper's formula.

    Args:
        style_img_grams: layer key -> Gram matrix of the style image at that layer.
        target_img_features: layer key -> ``(C, H, W)`` activation of the generated image.
        layers: keys to compare; defaults to every key of ``style_img_grams``. The caller in this
            notebook builds both dicts from ``style_layers``, so the default compares the same
            layers as before.

    Returns:
        The mean of E_l over ``layers`` (equal layer weights).

    Raises:
        ValueError: if there are no layers to compare.
    """
    layers = list(style_img_grams if layers is None else layers)
    if not layers:
        raise ValueError('get_style_loss needs at least one style layer')

    loss = 0
    for layer in layers:
        target_feature = target_img_features[layer]
        channels, height, width = target_feature.shape
        target_gram = gram_matrix(target_feature)
        squared_err = torch.sum((target_gram - style_img_grams[layer]) ** 2)
        loss += squared_err / (4 * channels ** 2 * (height * width) ** 2)

    return loss / len(layers)

In [None]:
from tqdm import trange

# Fresh optimiser over the same ``target_img`` tensor, so this continues from whatever it holds
# when the cell runs (the content-only result if the cells above ran in order).
optimizer = torch.optim.Adam([target_img], lr=0.003)
iterations = 12000

# Steps are numbered 1..iterations: exactly ``iterations`` updates run (trange(1, iterations)
# stopped one short) and the final step also reaches the logging branch.
for i in trange(1, iterations + 1):
    optimizer.zero_grad()

    # The forward pass is run for its side effect: the registered hooks refresh ``activations``.
    model(target_img)

    target_style_features = {layer: activations[name] for layer, name in style_layers.items()}

    style_loss = get_style_loss(style_img_grams, target_style_features)

    style_loss.backward()
    optimizer.step()

    if i % 500 == 0:
        print('loss:', style_loss.item())
        # Clamp to the [0, 1] range imshow expects for float RGB data (it would clip and warn).
        plt.imshow(target_img.detach().permute(1, 2, 0).clamp(0, 1))
        plt.show()