In [None]:
import time
import math
from pathlib import Path

import matplotlib.pyplot as plt
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision

In [None]:
class AdaIN(nn.Module):
    """Adaptive Instance Normalization.

    Re-normalises every channel of ``contents`` so that its per-sample,
    per-channel mean and standard deviation (taken over the two spatial
    dimensions) equal those of ``styles``::

        AdaIN(x, y) = sigma(y) * (x - mu(x)) / sigma(x) + mu(y)
    """

    # Added to the variance before the square root so that a constant
    # (zero-variance) channel does not cause a division by zero and the
    # gradient of sqrt stays finite.
    eps = 1.0e-5

    def forward(self, contents, styles):
        """
        Args:
            contents (torch.Tensor) : 4-dimensional tensor (N, C, H, W)
            styles (torch.Tensor)   : 4-dimensional tensor (N, C, H, W)

        Returns:
            torch.Tensor: tensor with the shape of ``contents`` whose channel
            statistics follow ``styles``.
        """
        c_mean = contents.mean(dim=(2, 3), keepdim=True)
        s_mean = styles.mean(dim=(2, 3), keepdim=True)
        # The normalisation is defined with standard deviations; scaling by
        # the raw variances (as the previous code did) squares the ratio and
        # yields features whose spread does not match the style.
        c_std = (contents.var(dim=(2, 3), keepdim=True) + self.eps).sqrt()
        s_std = (styles.var(dim=(2, 3), keepdim=True) + self.eps).sqrt()
        return s_std * (contents - c_mean) / c_std + s_mean

    
class VGGEncoder(nn.Module):
    """Feature extractor built from the first 21 layers of a pre-trained VGG19.

    The forward pass runs those layers in order and records the activation
    produced right after layers 1, 6, 11 and 20.
    (Presumably relu1_1, relu2_1, relu3_1 and relu4_1 of torchvision's VGG19
    -- confirm against the torchvision layer listing.)
    """

    def __init__(self):
        super().__init__()
        # Fetches the torchvision hub model; needs network access or a warm
        # hub cache the first time it runs.
        vgg19 = torch.hub.load('pytorch/vision:v0.6.0', 'vgg19', pretrained=True)
        self.features = vgg19.features[:21]

    def forward(self, inputs):
        """Return the four recorded activations, shallowest first.

        Args:
            inputs (torch.Tensor): image batch fed to the first layer.

        Returns:
            list[torch.Tensor]: activations after layers 1, 6, 11 and 20.
        """
        tap_points = (1, 6, 11, 20)
        collected = []
        activation = inputs
        for position in range(21):
            activation = self.features[position](activation)
            if position in tap_points:
                collected.append(activation)
        return collected
    
class VGGDecoder(nn.Module):
    """Convolutional decoder mapping 512-channel features to a 3-channel image.

    The network is a flat ``nn.Sequential`` of reflect-padded 3x3
    convolutions, each followed by a ReLU, with three nearest-neighbour 2x
    upsampling steps in between, so the spatial size grows by a factor of 8.
    """

    def __init__(self):
        super().__init__()
        # Each tuple is (in_channels, out_channels) of a conv + ReLU pair;
        # ``None`` marks a 2x nearest-neighbour upsampling step.
        layout = [
            (512, 512),
            None,
            (512, 256),
            (256, 256),
            (256, 256),
            (256, 256),
            None,
            (256, 128),
            (128, 128),
            None,
            (128, 64),
            (64, 3),
        ]
        modules = []
        for entry in layout:
            if entry is None:
                modules.append(nn.Upsample(scale_factor=(2, 2), mode="nearest"))
                continue
            in_channels, out_channels = entry
            modules.append(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=(3, 3),
                    stride=(1, 1),
                    padding=(1, 1),
                    padding_mode="reflect",
                )
            )
            modules.append(nn.ReLU())
        self.decoder = nn.Sequential(*modules)

    def forward(self, x):
        """Run ``x`` through the whole decoder stack and return the result."""
        return self.decoder(x)
    
class AdaINNetwork(nn.Module):
    """Encoder -> AdaIN -> decoder style-transfer network.

    Args:
        encoder (nn.Module, optional): module returning a list of feature
            maps; a new ``VGGEncoder`` is built when omitted.
        decoder (nn.Module, optional): module turning the AdaIN output back
            into an image; a new ``VGGDecoder`` is built when omitted.
    """

    def __init__(self, encoder=None, decoder=None):
        super().__init__()
        # The defaults are created here instead of in the signature: default
        # argument values are evaluated once, when the class is defined, which
        # would load VGG at definition time and make every default-constructed
        # network share (and train) the very same module objects.
        self.encoder = VGGEncoder() if encoder is None else encoder
        self.adain = AdaIN()
        self.decoder = VGGDecoder() if decoder is None else decoder

    def forward(self, contents, styles):
        """Stylise ``contents`` with the statistics of ``styles``.

        Returns:
            tuple: ``(output, adain_feat, c_features, s_features)`` where
            ``output`` is the decoded image, ``adain_feat`` the AdaIN result
            computed on the last encoder feature maps, and ``c_features`` /
            ``s_features`` the full encoder outputs for contents and styles.
        """
        c_features = self.encoder(contents)
        s_features = self.encoder(styles)
        # Only the deepest feature map of each input goes through AdaIN.
        adain_feat = self.adain(c_features[-1], s_features[-1])
        output = self.decoder(adain_feat)
        return output, adain_feat, c_features, s_features

In [None]:
class StyleLoss(nn.Module):
    """Combined training loss: one feature-matching term plus statistic terms.

    The result is the L2 distance between ``adain_feature`` and the last
    entry of ``target_features``, plus, for every pair of target/style
    feature maps, the L2 distances between their per-channel spatial means
    and between their per-channel spatial variances.
    """

    @staticmethod
    def _channel_stats(feature):
        # Per-sample, per-channel mean and variance over the spatial dims.
        return feature.mean(dim=(2, 3)), feature.var(dim=(2, 3))

    def forward(self, target_features, adain_feature, style_features):
        """Return the scalar loss tensor.

        Args:
            target_features: sequence of feature maps of the generated image.
            adain_feature: tensor compared against ``target_features[-1]``.
            style_features: sequence of feature maps of the style image,
                paired positionally with ``target_features``.
        """
        loss = torch.dist(adain_feature, target_features[-1], 2)
        for generated, reference in zip(target_features, style_features):
            generated_stats = self._channel_stats(generated)
            reference_stats = self._channel_stats(reference)
            # Mean term first, then variance term.
            for gen_stat, ref_stat in zip(generated_stats, reference_stats):
                loss = loss + torch.dist(gen_stat, ref_stat, 2)
        return loss

In [None]:
class ImageDataset(Dataset):
    """Dataset over the ``*.jpg`` files found directly inside one directory.

    Args:
        image_root_dir (str or pathlib.Path): directory to scan
            (non-recursively) for ``*.jpg`` files.
        transforms (callable, optional): applied to each loaded PIL image
            before it is returned.
    """

    def __init__(self, image_root_dir, transforms=None):
        # ``glob`` yields files in a filesystem-dependent order; sorting makes
        # the index -> image mapping the same on every run and machine.
        self.image_paths = sorted(Path(image_root_dir).glob("*.jpg"))
        self.transforms = transforms

    def expand(self, scale):
        """Repeat the list of image paths ``scale`` times, in place."""
        self.image_paths *= scale

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        """Load image ``index`` as RGB and apply ``transforms`` if any."""
        # The context manager closes the file handle deterministically;
        # ``convert`` returns a fully loaded copy, so ``x`` stays usable.
        # Forcing RGB keeps grayscale / CMYK JPEGs at three channels, so
        # samples can be batched together.
        with Image.open(str(self.image_paths[index])) as image:
            x = image.convert("RGB")

        if self.transforms:
            x = self.transforms(x)
        return x

In [None]:
# A single pre-trained encoder is shared by the network and by the loss
# computation in the training loop; both uses are read-only.
encoder = VGGEncoder()
# Only the decoder is optimised. Freezing the encoder weights stops autograd
# from computing and accumulating gradients on them at every step (gradients
# that optimizer.zero_grad() would never clear). Gradients still flow
# *through* the encoder to the decoder output.
encoder.requires_grad_(False)
net = AdaINNetwork(encoder=encoder)
optimizer = torch.optim.Adam(net.decoder.parameters(), lr=0.001)
criterion = StyleLoss()

# Random flips and a random 256x256 crop, then conversion to a tensor.
transform = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.RandomVerticalFlip(),
    torchvision.transforms.RandomCrop((256, 256)),
    torchvision.transforms.ToTensor(),
])

batch_size = 8
content_dataset = ImageDataset(Path.home() / "dataset/COCO/train2014/Resized512Color", transform)
style_dataset = ImageDataset(Path.home() / "dataset/ArtWiki/Resized512", transform)
if len(content_dataset) == 0 or len(style_dataset) == 0:
    # Fail with a clear message instead of a ZeroDivisionError below or a
    # sampler error when the loaders are created.
    raise FileNotFoundError(
        "No *.jpg images found: {} content image(s), {} style image(s)".format(
            len(content_dataset), len(style_dataset)))
# Repeat the style images so the style set is at least as long as the content
# set; the training loop zips both loaders and stops at the shorter one.
style_dataset.expand(math.ceil(len(content_dataset) / len(style_dataset)))
content_loader = DataLoader(
    dataset=content_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=2)
# drop_last on both loaders guarantees content and style batches always have
# the same batch dimension.
style_loader = DataLoader(
    dataset=style_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=2)

n_epochs = 10
es_patience = 3
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ``encoder`` is a sub-module of ``net``, so the second call would move it as
# well; the explicit first call just keeps the intent visible.
encoder = encoder.to(device)
net = net.to(device)

In [None]:
# Smallest decrease of the mean epoch loss that counts as an improvement.
es_min_delta = 1.0e-5
# Start from +inf so the first finite epoch always becomes the best one and
# no special case for epoch 0 is needed.
best_loss = float("inf")
patience = es_patience
for epoch in range(n_epochs):
    epoch_loss = 0.0
    iteration = 0
    # zip() ends the epoch as soon as the shorter of the two loaders runs out.
    for contents, styles in zip(content_loader, style_loader):
        # Both batches are expected to be (batch_size, 3, 256, 256) given the
        # RandomCrop((256, 256)) transform used when building the datasets.
        contents = contents.to(device=device, dtype=torch.float32)
        styles = styles.to(device=device, dtype=torch.float32)
        optimizer.zero_grad()
        output, adain_feat, c_features, s_features = net(contents, styles)
        # Re-encode the generated image; its feature list lines up one-to-one
        # with s_features inside the criterion.
        target_features = encoder(output)
        loss = criterion(target_features, adain_feat, s_features)
        # NOTE(review): anomaly mode wraps only the backward pass, so it
        # raises on NaN gradients but cannot report the forward op that
        # produced them. It is a debugging aid that slows training; consider
        # removing it once training is stable.
        with torch.autograd.set_detect_anomaly(True):
            loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        iteration += 1
        # Running mean of the loss, rewritten in place on a single line.
        print("\rloss = {:.3f}".format(epoch_loss / iteration), end="")

    if iteration == 0:
        # Nothing was iterated: report it instead of dividing by zero below.
        raise RuntimeError("The data loaders produced no batches; cannot train.")

    # NOTE(review): torch.save(net, ...) pickles the whole module object, so
    # loading the file needs the same class definitions to be importable.
    model_output_path = "model_{}epoch.pth".format(epoch + 1)
    torch.save(net, model_output_path)
    epoch_loss /= iteration
    print("\nFinish Epoch {} / {}, Loss = {}".format(epoch + 1, n_epochs, epoch_loss))

    # An epoch counts as an improvement only if it beats the best loss by at
    # least es_min_delta. The previous test (epoch_loss - 1e-5 < best_loss)
    # also accepted slightly *worse* epochs, overwriting the best checkpoint
    # and letting best_loss drift upwards.
    if epoch_loss < best_loss - es_min_delta:
        torch.save(net, "model_bestloss.pth")
        patience = es_patience
        best_loss = epoch_loss
    else:
        patience -= 1
        # "<=" also stops when es_patience was configured as 0 or less.
        if patience <= 0:
            print("Eary Stopping at Epoch {}".format(epoch + 1))
            break