# Download Dataset

In [0]:
!mkdir data_faces && wget https://s3-us-west-1.amazonaws.com/udacity-dlnfd/datasets/celeba.zip

--2020-06-08 07:20:55--  https://s3-us-west-1.amazonaws.com/udacity-dlnfd/datasets/celeba.zip
Resolving s3-us-west-1.amazonaws.com (s3-us-west-1.amazonaws.com)... 52.219.120.176
Connecting to s3-us-west-1.amazonaws.com (s3-us-west-1.amazonaws.com)|52.219.120.176|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1443490838 (1.3G) [application/zip]
Saving to: ‘celeba.zip’


2020-06-08 07:21:28 (41.8 MB/s) - ‘celeba.zip’ saved [1443490838/1443490838]



In [0]:
import zipfile

with zipfile.ZipFile("celeba.zip","r") as zip_ref:
    zip_ref.extractall("data_faces/")

In [0]:
import os
root = 'data_faces/img_align_celeba'
img_list = os.listdir(root)
print(len(img_list))

202599


# Download Code

In [0]:
!git clone https://github.com/eriklindernoren/Fast-Neural-Style-Transfer.git

Cloning into 'Fast-Neural-Style-Transfer'...
remote: Enumerating objects: 94, done.[K
remote: Total 94 (delta 0), reused 0 (delta 0), pack-reused 94[K
Unpacking objects: 100% (94/94), done.


In [0]:
!pip install av

Collecting av
[?25l  Downloading https://files.pythonhosted.org/packages/9e/62/9a992be76f8e13ce0e3a24a838191b546805545116f9fc869bd11bd21b5f/av-8.0.2-cp36-cp36m-manylinux2010_x86_64.whl (36.9MB)
[K     |████████████████████████████████| 36.9MB 81kB/s 
[?25hInstalling collected packages: av
Successfully installed av-8.0.2


# Training

In [0]:

import os
import sys
import random
from PIL import Image
import numpy as np
import torch
import glob
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.utils import save_image
from src.models import TransformerNet, VGG16
from src.utils import *

In [0]:
dataset_path = "data_faces"
style_image="src/images/styles/starry_night.jpg"
epochs=1
batch_size=4
image_size=256
style_size=256
lambda_content=1e5
lambda_style=1e10
lr=1e-3
checkpoint_model=""
checkpoint_interval=2000
sample_interval=1000

In [0]:
style_name = style_image.split("/")[-1].split(".")[0]
os.makedirs(f"images/outputs/{style_name}-training", exist_ok=True)
os.makedirs(f"checkpoints", exist_ok=True)

In [0]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [0]:
# Create dataloader for the training data
train_dataset = datasets.ImageFolder(dataset_path, train_transform(image_size))
dataloader = DataLoader(train_dataset, batch_size=batch_size)

In [0]:
# Defines networks
transformer = TransformerNet().to(device)
vgg = VGG16(requires_grad=False).to(device)

In [0]:
# Load checkpoint model if specified
if checkpoint_model:
    print(True)
    transformer.load_state_dict(torch.load(checkpoint_model))

In [0]:
# Define optimizer and loss
optimizer = Adam(transformer.parameters(), lr)
l2_loss = torch.nn.MSELoss().to(device)

In [0]:
# Load style image
style = style_transform(style_size)(Image.open(style_image))
style = style.repeat(batch_size, 1, 1, 1).to(device)

In [0]:
# Extract style features
features_style = vgg(style)
gram_style = [gram_matrix(y) for y in features_style]

In [0]:
# Sample 8 images for visual evaluation of the model
image_samples = []
for path in random.sample(glob.glob(f"{dataset_path}/*/*.jpg"), 8):
    image_samples += [style_transform(image_size)(Image.open(path))]
image_samples = torch.stack(image_samples)

In [0]:
def save_sample(batches_done):
    """ Evaluates the model and saves image samples """
    transformer.eval()
    with torch.no_grad():
        output = transformer(image_samples.to(device))
    image_grid = denormalize(torch.cat((image_samples.cpu(), output.cpu()), 2))
    save_image(image_grid, f"images/outputs/{style_name}-training/{batches_done}.jpg", nrow=4)
    transformer.train()

In [0]:
for epoch in range(epochs):
    epoch_metrics = {"content": [], "style": [], "total": []}
    for batch_i, (images, _) in enumerate(dataloader):
        optimizer.zero_grad()

        images_original = images.to(device)
        images_transformed = transformer(images_original)

        # Extract features
        features_original = vgg(images_original)
        features_transformed = vgg(images_transformed)

        # Compute content loss as MSE between features
        content_loss = lambda_content * l2_loss(features_transformed.relu2_2, features_original.relu2_2)

        # Compute style loss as MSE between gram matrices
        style_loss = 0
        for ft_y, gm_s in zip(features_transformed, gram_style):
            gm_y = gram_matrix(ft_y)
            style_loss += l2_loss(gm_y, gm_s[: images.size(0), :, :])
        style_loss *= lambda_style

        total_loss = content_loss + style_loss
        total_loss.backward()
        optimizer.step()

        epoch_metrics["content"] += [content_loss.item()]
        epoch_metrics["style"] += [style_loss.item()]
        epoch_metrics["total"] += [total_loss.item()]

        sys.stdout.write(
          "\r[Epoch %d/%d] [Batch %d/%d] [Content: %.2f (%.2f) Style: %.2f (%.2f) Total: %.2f (%.2f)]"
          % (
              epoch + 1,
              epochs,
              batch_i,
              len(train_dataset),
              content_loss.item(),
              np.mean(epoch_metrics["content"]),
              style_loss.item(),
              np.mean(epoch_metrics["style"]),
              total_loss.item(),
              np.mean(epoch_metrics["total"]),
          )
        )

        batches_done = epoch * len(dataloader) + batch_i + 1
        if batches_done % sample_interval == 0:
            save_sample(batches_done)

        if checkpoint_interval > 0 and batches_done % checkpoint_interval == 0:
            style_name = os.path.basename(style_image).split(".")[0]
            torch.save(transformer.state_dict(), f"checkpoints/{style_name}_{batches_done}.pth")


[Epoch 1/1] [Batch 4036/202599] [Content: 838704.31 (879576.79) Style: 166427.11 (491089.53) Total: 1005131.44 (1370666.32)]

KeyboardInterrupt: ignored

In [0]:
!wget https://i2.wp.com/www.agendavisual.es/wp-content/uploads/2013/11/Shipwreck_turner.jpg

--2020-06-08 08:42:03--  https://i2.wp.com/www.agendavisual.es/wp-content/uploads/2013/11/Shipwreck_turner.jpg
Resolving i2.wp.com (i2.wp.com)... 192.0.77.2
Connecting to i2.wp.com (i2.wp.com)|192.0.77.2|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 379606 (371K) [image/jpeg]
Saving to: ‘Shipwreck_turner.jpg’


2020-06-08 08:42:04 (7.29 MB/s) - ‘Shipwreck_turner.jpg’ saved [379606/379606]



In [0]:
import shutil
output_filename = "checkpoints"
dir_name = "checkpoints"
shutil.make_archive(output_filename, 'zip', dir_name)

'/content/checkpoints.zip'