In [None]:
%load_ext autoreload
%autoreload 2
import json

import matplotlib.pyplot as plt
import torch
from tqdm.autonotebook import tqdm

from lib.fashionpedia_processed import FashionPediaProcessed
from models.vgg19style import VGG19Style, VGG19Style2

In [None]:
# load model
# NOTE(review): 32 is passed straight to VGG19Style2's constructor; presumably
# it is the number of predicted attributes and must match the length of
# batch['att_oh'] used by the training loss — TODO confirm and name the constant.
model = VGG19Style2(32)

In [None]:
# load dataset
data = FashionPediaProcessed()

BATCH_SIZE = 16
SHUFFLE_SEED = 0  # change to get a different (but still reproducible) batch order

# A dedicated, seeded generator makes the shuffle order independent of every
# other consumer of torch's global RNG (model init, re-running cells, ...).
data_loader = torch.utils.data.DataLoader(
    data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)

In [None]:
# Adam over every trainable parameter of the model, learning rate 1e-4.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

# Element-wise binary cross-entropy between model outputs and batch['att_oh'].
# NOTE(review): BCELoss expects inputs already in [0, 1]; if VGG19Style2 returns
# raw logits, torch.nn.BCEWithLogitsLoss is the numerically safer choice — confirm.
loss_fn = torch.nn.BCELoss()

def train_one_epoch(net=None, loader=None, opt=None, criterion=None):
    """Run one optimisation pass over the training data.

    Every argument is optional and falls back to the notebook-level object
    playing the same role (``model``, ``data_loader``, ``optimizer``,
    ``loss_fn``), so the original zero-argument call keeps working.

    Args:
        net: module to train; called as ``net(batch['img'])``.
        loader: iterable of batches, each subscriptable by ``'img'`` (model
            input) and ``'att_oh'`` (loss target).
        opt: optimizer that updates ``net``'s parameters.
        criterion: loss called as ``criterion(outputs, batch['att_oh'])``.

    Returns:
        The mean of the per-batch loss values, as a Python float.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    # Conditional expressions only touch the globals when no override is given.
    net = model if net is None else net
    loader = data_loader if loader is None else loader
    opt = optimizer if opt is None else opt
    criterion = loss_fn if criterion is None else criterion

    # Guarantee training behaviour even if the model was last left in eval mode.
    net.train()

    running_loss = 0.
    num_batches = 0
    for batch in tqdm(loader):
        opt.zero_grad()

        outputs = net(batch['img'])

        # Compute the loss and its gradients.
        loss = criterion(outputs, batch['att_oh'])
        loss.backward()

        # Adjust learning weights.
        opt.step()

        # .item() yields a plain float, so no autograd graph is kept alive
        # across batches.
        running_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError('train_one_epoch: the data loader yielded no batches')
    return running_loss / num_batches

In [None]:
EPOCHS = 200

# NOTE(review): reserved for a future validation pass — nothing reads or
# updates it yet because no validation data is loaded in this notebook.
best_test_loss = 1_000_000.

# Average training loss of every epoch, kept so it can be plotted afterwards.
train_losses = []

for epoch in range(EPOCHS):
    print(f'\nEPOCH {epoch + 1}:')

    # Put the model in training mode, then do one pass over the training data.
    model.train(True)
    avg_loss = train_one_epoch()
    train_losses.append(avg_loss)

    # Only the per-batch-averaged training loss is logged; validation is TODO.
    print(f'Training Loss: {avg_loss}')

In [None]:
def content_loss(target, generated, layers=None):
    """Content term: half the sum of per-layer mean-squared errors.

    Args:
        target: indexable by layer key; the content image's feature tensors.
        generated: indexable by layer key; the generated image's feature tensors.
        layers: keys of the layers to compare. Defaults to no layers, in
            which case the loss is ``0.0``.

    Returns:
        ``0.5 * sum(mse(target[l], generated[l]) for l in layers)`` — a scalar
        tensor when ``layers`` is non-empty, otherwise the float ``0.0``.
    """
    # TODO: choose the content layers; none are selected by default.
    layers = [] if layers is None else layers
    total = 0.
    for l in layers:
        # Mean reduction, identical to the default torch.nn.MSELoss().
        # Out-of-place add keeps autograd happy when total becomes a tensor.
        total = total + torch.nn.functional.mse_loss(target[l], generated[l])
    return total * 0.5

In [None]:
def _gram_matrix(features):
    """Return ``(gram, n_maps, map_size)`` for one feature tensor.

    NOTE(review): assumes the last three dims of ``features`` are
    (channels, height, width); any leading dims (e.g. batch) are kept, giving
    one Gram matrix per leading index — TODO confirm against the model output.
    ``gram`` has shape (..., n_maps, n_maps) with n_maps = channels and
    map_size = height * width.
    """
    n_maps = features.shape[-3]
    map_size = features.shape[-2] * features.shape[-1]
    flat = features.reshape(*features.shape[:-3], n_maps, map_size)
    return flat @ flat.transpose(-1, -2), n_maps, map_size


def style_loss(style, generated, layers=None, weights=None):
    """Style term: weighted per-layer distance between Gram matrices.

    For each layer l the contribution is
        w_l * sum((A_l - G_l) ** 2) / (4 * N_l**2 * M_l**2)
    where A_l / G_l are the Gram matrices of the style / generated features,
    N_l is the number of feature maps and M_l is height * width of each map.

    Args:
        style: indexable by layer key; the style image's feature tensors.
        generated: indexable by layer key; the generated image's feature tensors.
        layers: layer keys to compare. Defaults to no layers, giving ``0.0``.
        weights: per-layer weights w_l, same length as ``layers``.
            Defaults to the uniform ``1 / len(layers)``.

    Returns:
        A scalar tensor, or the float ``0.0`` when ``layers`` is empty.

    Raises:
        ValueError: if ``weights`` and ``layers`` differ in length.
    """
    # TODO: choose the style layers; none are selected by default.
    layers = [] if layers is None else list(layers)
    if not layers:
        # Also avoids a 1 / len(layers) division by zero.
        return 0.
    if weights is None:
        weights = [1 / len(layers)] * len(layers)
    elif len(weights) != len(layers):
        raise ValueError(
            f'style_loss: got {len(weights)} weights for {len(layers)} layers')

    total = 0.
    for l, w in zip(layers, weights):
        gram_style, n_maps, map_size = _gram_matrix(style[l])
        gram_generated, _, _ = _gram_matrix(generated[l])
        sq_err = ((gram_style - gram_generated) ** 2).sum()
        total = total + w * sq_err / (4 * n_maps**2 * map_size**2)
    return total

In [None]:
def total_loss(target, style, generated, alpha, beta):
    """Weighted sum of the content and style terms.

    Returns ``alpha * content_loss(target, generated)
    + beta * style_loss(style, generated)``; the content term is evaluated
    first.
    """
    weighted_content = alpha * content_loss(target, generated)
    weighted_style = beta * style_loss(style, generated)
    return weighted_content + weighted_style

In [None]:
# Pick one dataset item and prepare a copy of its image to be optimised.
TARGET_INDEX = 2       # arbitrary sample to preview
N_STEPS = 2000         # planned number of optimisation steps
PREVIEW_EVERY = 500    # show the image every this many steps
ATTRIBUTES_PATH = 'fashionpedia/selected_attributes.json'

target_item = data[TARGET_INDEX]

# clone() so the dataset's own tensor is untouched; requires_grad makes it a
# leaf tensor that an optimiser could later update.
target = target_item['img'].clone().requires_grad_(True)

# Human-readable attribute names, in the JSON object's value order.
# NOTE(review): assumes that order lines up with the positions of
# target_item['att_oh'] — TODO confirm.
with open(ATTRIBUTES_PATH, encoding='utf-8') as f:
    s_att = list(json.load(f).values())

print('attributes:', [s_att[i] for i, v in enumerate(target_item['att_oh']) if v])

for i in range(N_STEPS):
    # TODO: the optimisation step (features, total_loss, backward, step on
    # `target`) is not implemented yet, so every preview shows the same image.
    if i % PREVIEW_EVERY == 0:
        fig, ax = plt.subplots()
        # permute(1, 2, 0) assumes target is (C, H, W); matplotlib wants (H, W, C).
        ax.imshow(target.detach().cpu().permute(1, 2, 0))
        ax.set_title(f'target image, step {i}')
        ax.axis('off')
        plt.show()
        plt.close(fig)  # avoid piling up open figures on non-inline backends
