0. Package Imports

In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

In [2]:
from utils.imageloader import load_images
from utils.dataloader import load_data
from utils.normalize import batch_normalize
from utils.gram_matrix import gram_matrix
from model.VGG16 import VGG16
from model.TransformerNet import TransformerNet

1. Hyperparameter Setting

In [3]:
# Training configuration.
batch_size = 10
num_epoch = 10
learning_rate = 1e-4

# Loss weights and logging cadence for the training loop in section 4.
# The weight values follow the PyTorch fast_neural_style example.
# NOTE(review): tune log_interval to the size of the training set.
content_weight = 1e5   # multiplier on the relu2_2 content loss
style_weight = 1e10    # multiplier on the summed Gram-matrix style loss
log_interval = 100     # print running losses every `log_interval` batches

# Seed torch so weight initialisation and any shuffling are reproducible.
SEED = 42
torch.manual_seed(SEED)

2. Style Images and Train Data Loading

In [4]:
data_root = './data/'

# Load the 'summer' style images.
# NOTE(review): the printed shape shows 100 images even though batch_size is 10.
# Confirm what load_images' third argument actually controls.
style_data = load_images(data_root, 'summer', batch_size)
print(style_data.shape)

# Build the training dataset and its loader, then show the shape of the first
# sample's image tensor.
train_dataset, train_dataloader = load_data(data_root, batch_size)
first_image = train_dataset[0][0]
print(first_image.shape)

torch.Size([100, 3, 256, 256])
torch.Size([3, 256, 256])


3. Style Features and Gram Matrices

In [5]:
# Use the GPU when available, so the models, the style targets and the
# training batches all share a single device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transformer = TransformerNet().to(device)
vgg = VGG16(requires_grad=False).to(device)

# The style targets are constants for the whole run, so compute their VGG
# features and Gram matrices without recording an autograd graph.
# Assumes batch_normalize keeps its output on the input tensor's device. TODO confirm.
with torch.no_grad():
    features_style = vgg(batch_normalize(style_data.to(device)))
    gram_style = [gram_matrix(y) for y in features_style]

4. TransformerNet Training

In [6]:
# Mean-squared error serves as the criterion for both the content term and the
# Gram-matrix style term.
loss_function = nn.MSELoss()

# Only the transformer's weights are optimised; VGG was built with requires_grad=False.
optimizer = optim.Adam(params=transformer.parameters(), lr=learning_rate)

In [7]:
def train(model, train_data, content_weight=1e5, style_weight=1e10, log_interval=100):
    """Train ``model`` for ``num_epoch`` epochs with a content + style (perceptual) loss.

    Processing of each batch:
      1. The batch is stylised by ``model``.
      2. The stylised output and the original input are both normalised with
         ``batch_normalize`` and passed through the frozen ``vgg`` network.
      3. Content loss is the MSE between their ``relu2_2`` features.
      4. Style loss is the sum of MSEs between the Gram matrix of each output
         feature map and the matching precomputed ``gram_style`` target. The
         target is sliced to the current batch size.

    The function relies on these notebook globals:
      - ``num_epoch``, ``loss_function``, ``vgg``, ``gram_style``
      - ``batch_normalize``, ``gram_matrix``
      - ``optimizer``, which must wrap ``model.parameters()``.
    ``vgg`` must live on the same device as ``model``.

    Args:
        model: the image-transformation network being trained.
        train_data: a DataLoader yielding ``(images, labels)`` batches. The
            labels are ignored.
        content_weight: multiplier applied to the content loss.
        style_weight: multiplier applied to the summed style loss.
        log_interval: print running average losses every this many batches.
    """
    import time

    # Train on whichever device holds the model's parameters.
    device = next(model.parameters()).device
    # Move the fixed style targets once, rather than once per batch.
    gram_targets = [gm_s.to(device) for gm_s in gram_style]
    n_samples = len(getattr(train_data, "dataset", train_data))

    for epoch in range(num_epoch):
        model.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0

        for batch_id, (x, _) in enumerate(train_data):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = x.to(device)
            y = model(x)

            y = batch_normalize(y)
            x = batch_normalize(x)

            features_y = vgg(y)
            # The content target only provides reference values; no gradient is needed.
            with torch.no_grad():
                features_x = vgg(x)

            content_loss = content_weight * loss_function(features_y.relu2_2, features_x.relu2_2)

            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_targets):
                gm_y = gram_matrix(ft_y)
                # The final batch can be smaller than batch_size, so trim the target to match.
                style_loss += loss_function(gm_y, gm_s[:n_batch, :, :])
            style_loss *= style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % log_interval == 0:
                msg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), epoch + 1, count, n_samples,
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1)
                )
                print(msg)

5. Test