In [1]:
import torch 
import torch.nn as nn
import torch.utils.data as data
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import os
from skimage import io
import torchvision.models as models

In [15]:
# Training configuration: everything a reader might want to tune lives here.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
num_epochs = 20
batch_size = 8
learning_rate = 1e-4
# Weight of the style loss relative to the content loss
# (`loss = content_loss + lamda * style_loss`); spelled `lamda` because `lambda` is a keyword.
lamda = 0.01
# Directories of content / style images; file names are appended directly, so keep the trailing '/'.
content_path = '../input/style-transfer/CelebAMask-HQ/CelebAMask-HQ/CelebA-HQ-img/'
style_path = '../input/style-transfer/style/style/'
# Seed torch's global RNG so weight init, RandomCrop and DataLoader shuffling are reproducible.
seed = 0
torch.manual_seed(seed)

In [16]:
def _list_images(root):
    """Return the sorted names of the regular, non-hidden files in ``root``.

    Sorting makes the index -> file mapping deterministic (``os.listdir`` order is
    arbitrary). Skipping sub-directories and dot-files (e.g. ``.DS_Store``) keeps
    non-image entries away from ``io.imread``.
    """
    return sorted(
        name for name in os.listdir(root)
        if not name.startswith('.') and os.path.isfile(os.path.join(root, name))
    )


def _read_rgb(path):
    """Read an image with ``io.imread`` and coerce it to an H x W x 3 array.

    ``imread`` can return H x W arrays for grayscale files and H x W x 2 / x 4 arrays
    for files with an alpha channel. Those would turn into tensors with the wrong
    channel count, which cannot be batched with RGB images or fed to a 3-channel
    encoder.
    """
    img = io.imread(path)
    if img.ndim == 2:
        img = img[:, :, None]
    if img.shape[2] < 3:
        # Gray (optionally + alpha): replicate the first channel, drop alpha.
        img = img[:, :, :1].repeat(3, axis=2)
    elif img.shape[2] > 3:
        # RGBA: drop alpha.
        img = img[:, :, :3]
    return img


class ContentDataset(data.Dataset):
    """Content images read from a directory (default: the global ``content_path``).

    Each item is ``transform(img)`` where ``img`` is the H x W x 3 image array.

    Args:
        transform: callable applied to each image array (default ``ToTensor()``).
        root: image directory; ``None`` means ``content_path`` (backward compatible).
    """

    def __init__(self, transform=transforms.ToTensor(), root=None):
        super().__init__()
        self.root = content_path if root is None else root
        self.content_images = _list_images(self.root)
        self.transform = transform

    def __getitem__(self, index):
        path = os.path.join(self.root, self.content_images[index])
        return self.transform(_read_rgb(path))

    def __len__(self):
        return len(self.content_images)


class StyleDataset(data.Dataset):
    """Style images read from a directory (default: the global ``style_path``).

    Each item is ``transform(img)`` where ``img`` is the H x W x 3 image array.

    Args:
        transform: callable applied to each image array (default ``ToTensor()``).
        root: image directory; ``None`` means ``style_path`` (backward compatible).
    """

    def __init__(self, transform=transforms.ToTensor(), root=None):
        super().__init__()
        self.root = style_path if root is None else root
        self.style_images = _list_images(self.root)
        self.transform = transform

    def __getitem__(self, index):
        path = os.path.join(self.root, self.style_images[index])
        return self.transform(_read_rgb(path))

    def __len__(self):
        return len(self.style_images)

# Shared preprocessing: image array -> tensor (uint8 is scaled to [0, 1]),
# resize the shorter side to 512, then take a random 256x256 crop.
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Resize(512),
                                transforms.RandomCrop(256)])
content_dataset = ContentDataset(transform=transform)
style_dataset = StyleDataset(transform=transform)
# drop_last=True: the training loop zips these two loaders and AdaIN asserts that
# content and style batches have the same size, so a trailing partial batch from
# either loader would abort training with an AssertionError.
content_loader = data.DataLoader(dataset=content_dataset, batch_size=batch_size,
                                 shuffle=True, drop_last=True)
style_loader = data.DataLoader(dataset=style_dataset, batch_size=batch_size,
                               shuffle=True, drop_last=True)

In [5]:
def mu(x):
    """Per-sample, per-channel spatial mean of a 4-D feature map.

    Args:
        x: tensor of shape (N, C, H, W); need not be contiguous.

    Returns:
        Tensor of shape (N, C, 1, 1), broadcastable against ``x``.
    """
    size = x.size()
    assert (len(size) == 4)
    N, C = size[:2]
    # reshape rather than view: view raises on non-contiguous (e.g. permuted) inputs.
    return x.reshape(N, C, -1).mean(dim=2).view(N, C, 1, 1)

def sigma(x, eps=1e-5):
    """Per-sample, per-channel spatial standard deviation of a 4-D feature map.

    Uses torch's default (unbiased) variance. ``eps`` is added to the variance
    before the square root so the result is strictly positive and safe to divide
    by. Note: with a single spatial element (H * W == 1) the unbiased variance is NaN.

    Args:
        x: tensor of shape (N, C, H, W); need not be contiguous.
        eps: value added to the variance.

    Returns:
        Tensor of shape (N, C, 1, 1), broadcastable against ``x``.
    """
    size = x.size()
    assert (len(size) == 4)
    N, C = size[:2]
    # reshape rather than view: view raises on non-contiguous (e.g. permuted) inputs.
    var = x.reshape(N, C, -1).var(dim=2) + eps
    return var.sqrt().view(N, C, 1, 1)

In [6]:
class AdaIN(nn.Module):
    """Adaptive Instance Normalization layer (no learnable parameters).

    Rescales every (sample, channel) plane of the content features ``x`` so that its
    spatial mean and standard deviation match the corresponding plane of the style
    features ``y``:

        out = (x - mu(x)) / sigma(x) * sigma(y) + mu(y)

    ``x`` and ``y`` must agree on batch size and channel count; their spatial sizes
    may differ. The output has the shape of ``x``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        assert x.size()[:2] == y.size()[:2]
        # mu / sigma return (N, C, 1, 1) tensors that broadcast over x's H and W.
        normalized = (x - mu(x)) / sigma(x)
        return normalized * sigma(y) + mu(y)


class Encoder(nn.Module):
    """Pretrained VGG19 feature extractor truncated after relu4_1.

    ``forward`` returns the activations at relu1_1, relu2_1, relu3_1 and relu4_1,
    which the training loop uses for the AdaIN target, content loss and style loss.

    The network is a fixed 1x1 input-normalization conv followed by the first 21
    modules of torchvision's pretrained ``vgg19().features``. It expects RGB images
    with values in [0, 1] (as produced by ``ToTensor``).
    """

    # ImageNet statistics that torchvision's pretrained VGG expects inputs normalized with.
    _IMAGENET_MEAN = (0.485, 0.456, 0.406)
    _IMAGENET_STD = (0.229, 0.224, 0.225)
    # Indices into self.net of relu1_1, relu2_1, relu3_1, relu4_1
    # (VGG feature index + 1, because of the leading normalization conv).
    _FEATURE_TAPS = (2, 7, 12, 21)

    def __init__(self):
        super().__init__()
        # Initialize the leading 1x1 conv to compute (x - mean) / std per channel.
        # Left at its random init, it is frozen together with the encoder parameters
        # and would scramble the colours seen by VGG. Keeping it as a conv (instead
        # of a plain function) preserves module indices and state_dict keys.
        normalize = nn.Conv2d(3, 3, 1)
        with torch.no_grad():
            mean = torch.tensor(self._IMAGENET_MEAN)
            std = torch.tensor(self._IMAGENET_STD)
            normalize.weight.copy_(torch.diag(1.0 / std).view(3, 3, 1, 1))
            normalize.bias.copy_(-mean / std)

        layers = [normalize]
        for layer in models.vgg19(pretrained=True).features.children():
            layers.append(layer)
            if isinstance(layer, nn.Conv2d):
                # Switch to reflection padding; Conv2d consults padding_mode at forward time.
                layer.padding_mode = 'reflect'

        self.net = nn.Sequential(*layers[:22])

    def forward(self, x):
        """Return (phi_1, phi_2, phi_3, phi_4): the relu1_1 .. relu4_1 activations."""
        taps = []
        for i, layer in enumerate(self.net):
            x = layer(x)
            if i in self._FEATURE_TAPS:
                taps.append(x)
        return tuple(taps)

class Decoder(nn.Module):
    """Maps 512-channel feature maps back to a 3-channel image.

    Three nearest-neighbour 2x upsamplings raise the spatial size 8x. Every 3x3 conv
    uses reflection padding (size-preserving) and is followed by a ReLU, including
    the last one, so outputs are non-negative.
    """

    def __init__(self):
        super().__init__()

        def conv_relu(in_ch, out_ch):
            # Size-preserving 3x3 conv followed by a ReLU.
            return [nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='reflect'), nn.ReLU()]

        def upsample():
            return [nn.Upsample(scale_factor=2)]

        layers = []
        layers += conv_relu(512, 256)
        layers += upsample()
        for _ in range(3):
            layers += conv_relu(256, 256)
        layers += conv_relu(256, 128)
        layers += upsample()
        layers += conv_relu(128, 128)
        layers += conv_relu(128, 64)
        layers += upsample()
        layers += conv_relu(64, 64)
        layers += conv_relu(64, 3)
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out

# Build the three networks (in this order) and move each to the target device.
adain, decoder, encoder = [module.to(device) for module in (AdaIN(), Decoder(), Encoder())]
# Freeze the encoder's parameters: only the decoder is optimized.
encoder.requires_grad_(False)

In [None]:
decoder.train()
mse = nn.MSELoss()
optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)


def _stat_loss(stat, feats_g, feats_s):
    """Sum over layers of the MSE between ``stat`` (``mu`` or ``sigma``) of paired features.

    ``feats_g`` are the re-encoded output features and ``feats_s`` the style features,
    both as returned by ``encoder``.
    """
    return sum(mse(stat(f_g), stat(f_s)) for f_g, f_s in zip(feats_g, feats_s))


def _show_preview(content, style, output):
    """Show the first content, style and stylized image of a batch side by side."""
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    for ax, batch, title in zip(axes, (content, style, output), ('content', 'style', 'output')):
        # Clamp to [0, 1]: the decoder's final ReLU bounds values only from below,
        # and imshow would otherwise clip float images with a warning.
        ax.imshow(batch[0].detach().clamp(0, 1).permute(1, 2, 0).cpu().numpy())
        ax.set_title(title)
        ax.axis('off')
    plt.show()
    plt.close(fig)


# zip() below stops at the shorter loader, so this is the number of steps per epoch.
steps_per_epoch = min(len(content_loader), len(style_loader))
if steps_per_epoch == 0:
    raise RuntimeError('content_loader or style_loader yields no batches; '
                       'check the dataset paths and batch_size')

for epoch in range(num_epochs):
    epoch_loss = 0
    for c, s in zip(content_loader, style_loader):
        c = c.to(device)
        s = s.to(device)
        _, _, _, f_c = encoder(c)
        phi_s = encoder(s)  # (phi_1_s, phi_2_s, phi_3_s, phi_4_s)
        # AdaIN target: content features re-normalized to the style statistics.
        t = adain(f_c, phi_s[-1])
        g = decoder(t)
        phi_g = encoder(g)
        # Content loss: the re-encoded output should reproduce the AdaIN target.
        content_loss = mse(phi_g[-1], t)
        # Style loss: match channel means and stds at every tapped layer.
        style_loss = _stat_loss(mu, phi_g, phi_s) + _stat_loss(sigma, phi_g, phi_s)
        loss = content_loss + lamda * style_loss
        epoch_loss = epoch_loss + loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print('epoch:', epoch, 'loss:', epoch_loss)
    # Preview on the epoch's last batch; no_grad avoids building an unused autograd graph.
    with torch.no_grad():
        _, _, _, f_c = encoder(c)
        _, _, _, phi_4_s = encoder(s)
        g = decoder(adain(f_c, phi_4_s))
    _show_preview(c, s, g)
    torch.save(decoder.state_dict(), './params.pt')
    torch.save(optimizer.state_dict(), './opt.pt')