In [None]:
import os
import glob
import random
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch import optim
from torch.autograd import Variable
from torch import autograd
from torch.utils.data import Dataset, DataLoader
from torchsummary import summary
import itertools

from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import logging
from tqdm import tqdm
import math
from statistics import mean

In [None]:
def same_seeds(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy and PyTorch (CPU, plus all CUDA devices
    when CUDA is available) with the same value, and switches cuDNN to its
    deterministic, non-benchmarking mode.

    Args:
        seed: Integer seed shared by all generators.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Deterministic kernels only; disable the cuDNN auto-tuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


same_seeds(0)

In [None]:
# Shared preprocessing pipeline; it is the default `tfm` argument of both
# ImgDataset and SampleDataset.
tfm = transforms.Compose([
    # Force every image to 256x256 regardless of its original size.
    transforms.Resize((256, 256)),
    # Random left-right flip (torchvision's default probability).
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # Per channel (x - 0.5) / 0.5; the sample-saving code undoes this with (x + 1) / 2.
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
class ImgDataset(Dataset):
    """Unpaired training set yielding ``(style, content)`` image tensors.

    The dataset is indexed by content (photo) image; each item is paired with
    a style (Monet) image drawn at random, so its length equals the number of
    content images.

    Args:
        tfm: Callable applied to every PIL image before it is returned.
        content_dir: Folder holding the content (photo) images.
        style_dir: Folder holding the style (Monet) images.
    """

    def __init__(self, tfm=tfm,
                 content_dir='./dataset/photo_jpg/',
                 style_dir='./dataset/monet_jpg/'):
        super().__init__()
        self.tfm = tfm
        self.content_list = self._list_files(content_dir)
        self.style_list = self._list_files(style_dir)
        self.len = len(self.content_list)

    @staticmethod
    def _list_files(folder):
        """Return the full paths of the entries of ``folder`` in sorted order."""
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable so that seeding is actually reproducible.
        return [os.path.join(os.getcwd(), folder, name)
                for name in sorted(os.listdir(folder))]

    def _load(self, path):
        """Open ``path``, force 3-channel RGB, and apply the transform."""
        # The context manager closes the file handle; convert("RGB") guards
        # the 3-channel Normalize against grayscale / RGBA / CMYK files.
        with Image.open(path) as img:
            return self.tfm(img.convert("RGB"))

    def __getitem__(self, idx):
        """Return ``(style, content)``: a random style image and content ``idx``."""
        content = self._load(self.content_list[idx])
        style = self._load(random.choice(self.style_list))
        return style, content

    def __len__(self):
        """Number of content images."""
        return self.len

class SampleDataset(Dataset):
    """Small fixed subset of one image folder, used to render sample grids.

    Args:
        path: Folder to read images from.
        tfm: Callable applied to every PIL image before it is returned.
        mx_len: Maximum number of images to keep; fewer are used when the
            folder contains fewer files.
    """

    def __init__(self, path, tfm=tfm, mx_len=64):
        super().__init__()
        self.tfm = tfm
        # Sort first: os.listdir order is arbitrary, and the subset should be
        # the same files on every run so grids are comparable across epochs.
        names = sorted(os.listdir(path))[:mx_len]
        self.file_list = [os.path.join(os.getcwd(), path, name) for name in names]
        # Use the real count, not mx_len, so a folder with fewer than mx_len
        # files does not cause an IndexError in __getitem__.
        self.len = len(self.file_list)

    def __getitem__(self, idx):
        """Return the transformed image at position ``idx``."""
        # convert("RGB") guards the 3-channel Normalize against non-RGB files.
        with Image.open(self.file_list[idx]) as img:
            return self.tfm(img.convert("RGB"))

    def __len__(self):
        """Number of images actually available (at most ``mx_len``)."""
        return self.len

In [None]:
def conv_istn_lrelu(in_dim, out_dim, ist_norm=True):
    """Build a downsampling block: Conv2d -> [InstanceNorm2d] -> LeakyReLU(0.2).

    The convolution uses kernel 4, stride 2, padding 1, which halves even
    spatial dimensions.

    Args:
        in_dim: Number of input channels.
        out_dim: Number of output channels.
        ist_norm: When truthy, insert an affine InstanceNorm2d after the conv.

    Returns:
        An ``nn.Sequential`` with two layers (no norm) or three layers.
    """
    layers = [nn.Conv2d(in_dim, out_dim, kernel_size=4, stride=2, padding=1)]
    if ist_norm:
        layers.append(nn.InstanceNorm2d(out_dim, affine=True))
    layers.append(nn.LeakyReLU(0.2))
    return nn.Sequential(*layers)

def convt_istn_relu(in_dim, out_dim):
    """Build an upsampling block: ConvTranspose2d -> InstanceNorm2d -> ReLU.

    The transposed convolution uses kernel 4, stride 2, padding 1, which
    doubles the spatial dimensions.

    Args:
        in_dim: Number of input channels.
        out_dim: Number of output channels.

    Returns:
        A three-layer ``nn.Sequential``.
    """
    upsample = nn.ConvTranspose2d(in_dim, out_dim, kernel_size=4, stride=2, padding=1)
    norm = nn.InstanceNorm2d(out_dim, affine=True)
    activation = nn.ReLU()
    return nn.Sequential(upsample, norm, activation)

def init_weights(m, std=1.0):
    """Re-initialise a Linear / Conv2d / ConvTranspose2d module in place.

    Weights are drawn from N(0, std^2) and the bias, when present, is zeroed.
    Any other module type is left untouched, so this can be passed directly
    to ``nn.Module.apply``.

    Args:
        m: Module visited by ``apply``.
        std: Standard deviation of the weight distribution. Defaults to 1.0,
            the value previously hard-coded; use ``functools.partial`` to
            apply a different value through ``nn.Module.apply``.
    """
    if not isinstance(m, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
        return
    # nn.init helpers run under no_grad, avoiding direct `.data` mutation.
    nn.init.normal_(m.weight, mean=0.0, std=std)
    if m.bias is not None:
        nn.init.zeros_(m.bias)

In [None]:
class Generator(nn.Module):
    """Plain encoder-decoder image-to-image generator (no skip connections).

    Seven stride-2 conv blocks shrink the input by a factor of 128 per side
    (a 256x256 image becomes 2x2), then seven stride-2 transposed convs grow
    it back. The output has 3 channels squashed by Tanh.
    """

    def __init__(self):
        super().__init__()
        # Channel widths along the encoder and decoder paths.
        down_widths = (3, 64, 128, 256, 256, 256, 256, 256)
        up_widths = (256, 256, 256, 256, 256, 128, 64)

        stages = []
        for idx, (c_in, c_out) in enumerate(zip(down_widths, down_widths[1:])):
            # Only the very first encoder block skips instance norm.
            stages.append(conv_istn_lrelu(c_in, c_out, ist_norm=(idx != 0)))
        for c_in, c_out in zip(up_widths, up_widths[1:]):
            stages.append(convt_istn_relu(c_in, c_out))
        # Final upsampling straight to RGB, without normalisation.
        stages.append(nn.ConvTranspose2d(up_widths[-1], 3,
                                         kernel_size=4, stride=2, padding=1))
        stages.append(nn.Tanh())

        self.ae = nn.Sequential(*stages)
        self.apply(init_weights)

    def forward(self, x):
        """Run ``x`` through the encoder-decoder stack."""
        out = self.ae(x)
        return out

In [None]:
class Discriminator(nn.Module):
    """Convolutional critic producing one unbounded score per image.

    Seven stride-2 conv blocks (all with instance norm) feed a three-layer MLP.
    The first linear layer expects 1024 features, i.e. 256 channels on a 2x2
    map, which is what a 256x256 input yields.
    """

    def __init__(self):
        super().__init__()
        widths = (3, 64, 128, 256, 256, 256, 256, 256)
        conv_blocks = [conv_istn_lrelu(c_in, c_out)
                       for c_in, c_out in zip(widths, widths[1:])]
        self.cnn = nn.Sequential(*conv_blocks)

        head = [
            nn.Linear(1024, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 64),
            nn.LeakyReLU(0.2),
            nn.Linear(64, 1),
        ]
        self.fc = nn.Sequential(*head)
        self.apply(init_weights)

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of raw scores for images ``x``."""
        features = self.cnn(x)
        flat = features.view(features.size(0), -1)
        return self.fc(flat)

In [None]:
class CycleGAN():
    """Trainer for an unpaired Monet <-> photo CycleGAN.

    Holds two generators (``M2P``: Monet -> photo, ``P2M``: photo -> Monet),
    one discriminator per domain (``M_D`` for Monet, ``P_D`` for photo), their
    Adam optimizers and the dataloaders. ``train()`` runs the whole schedule
    and periodically writes checkpoints and sample grids.
    """

    def __init__(self):
        """Build the dataloaders, networks, optimizers and loss functions."""
        # Keep the original preference for the second GPU, but fall back to
        # the first one on single-GPU machines instead of failing on an
        # invalid device ordinal.
        if not torch.cuda.is_available():
            device = "cpu"
        elif torch.cuda.device_count() > 1:
            device = "cuda:1"
        else:
            device = "cuda:0"

        self.config = {
            "name":"cycleGAN",
            "device": device,
            "epoch": 50,
            "batch_size": 16,
            "g_lr": 1e-4,
            "m_d_lr": 1e-4,
            "p_d_lr": 1e-4,
            # Weights of the identity and cycle terms relative to the GAN term.
            # NOTE(review): these were tuned while the adversarial gradient was
            # accidentally cut off; the CycleGAN paper uses far smaller weights
            # (cycle 10, identity 5) -- consider retuning.
            "lambda_id": 1e5,
            "lambda_cycle": 1e5,
        }

        # dataloader
        dataset = ImgDataset()
        self.dataloader = DataLoader(dataset, batch_size=self.config["batch_size"], shuffle=True, num_workers=8)

        photo_dataset = SampleDataset('./dataset/photo_jpg')
        self.photoloader = DataLoader(photo_dataset, batch_size=self.config["batch_size"], shuffle=False, num_workers=8)

        monet_dataset = SampleDataset('./dataset/monet_jpg')
        self.monetloader = DataLoader(monet_dataset, batch_size=self.config["batch_size"], shuffle=False, num_workers=8)

        # models
        self.M2P = Generator().to(self.config["device"])
        self.M_D = Discriminator().to(self.config["device"])
        self.P2M = Generator().to(self.config["device"])
        self.P_D = Discriminator().to(self.config["device"])

        # optimizer (both generators share one optimizer)
        self.opt_G = torch.optim.Adam(itertools.chain(self.M2P.parameters(), self.P2M.parameters()), lr=self.config["g_lr"], betas=(0.5, 0.999))
        self.opt_M_D = torch.optim.Adam(self.M_D.parameters(), lr=self.config["m_d_lr"], betas=(0.5, 0.999))
        self.opt_P_D = torch.optim.Adam(self.P_D.parameters(), lr=self.config["p_d_lr"], betas=(0.5, 0.999))

        # loss
        self.L1 = nn.L1Loss()
        self.L2 = nn.MSELoss()

        os.makedirs('cycle_checkpoints/', exist_ok=True)

    def _discriminator_step(self, disc, opt, real, fake):
        """Run one least-squares update of ``disc``; return the loss as a float.

        ``fake`` must already be detached from the generator graph.
        """
        r_logits = disc(real)
        f_logits = disc(fake)
        loss_D = self.L2(r_logits, torch.ones_like(r_logits)) +\
                    self.L2(f_logits, torch.zeros_like(f_logits))
        # zero_grad also clears gradients that the previous generator step
        # pushed into this discriminator's parameters.
        opt.zero_grad()
        loss_D.backward()
        opt.step()
        return loss_D.item()

    def _save_samples(self, loader, forward, backward, cycle_path, translated_path):
        """Write image grids of ``backward(forward(x))`` and ``forward(x)``.

        Outputs are mapped from [-1, 1] back to [0, 1] before saving. Nothing
        is written when ``loader`` yields no batches.
        """
        cycled, translated = [], []
        # Inference only: skip building autograd graphs.
        with torch.no_grad():
            for batch in loader:
                batch = batch.to(self.config["device"])
                fake = forward(batch)
                translated.append(((fake + 1) / 2).cpu())
                cycled.append(((backward(fake) + 1) / 2).cpu())
        if not cycled:
            return
        # Concatenate once instead of growing a tensor image by image.
        torchvision.utils.save_image(torch.cat(cycled, 0), cycle_path, nrow=8)
        torchvision.utils.save_image(torch.cat(translated, 0), translated_path, nrow=8)

    def train(self):
        """Train for ``config["epoch"]`` epochs.

        Each iteration updates the Monet discriminator, the photo
        discriminator, and then both generators. After every epoch the mean
        losses are printed; every 10th epoch (starting at 0) both generators
        are checkpointed and sample grids are written under ``cycle_logs/``.
        """
        device = self.config["device"]
        for e in range(self.config["epoch"]):
            M_D_loss = []
            P_D_loss = []
            id_loss_list = []
            GAN_loss_list = []
            cycle_loss_list = []

            for m, p in tqdm(self.dataloader):
                m = m.to(device)
                p = p.to(device)

                # Fakes for the discriminator updates carry no generator graph.
                with torch.no_grad():
                    fake_m = self.P2M(p)
                    fake_p = self.M2P(m)

                # Discriminator Monet
                M_D_loss += [self._discriminator_step(self.M_D, self.opt_M_D, m, fake_m)]

                # Discriminator Photo
                P_D_loss += [self._discriminator_step(self.P_D, self.opt_P_D, p, fake_p)]

                # Generator
                m_prime = self.P2M(m)   # identity: a Monet fed to the Monet generator
                p_prime = self.M2P(p)   # identity: a photo fed to the photo generator
                m_f = self.P2M(p)
                p_f = self.M2P(m)
                m_rec = self.P2M(p_f)   # Monet -> photo -> Monet
                p_rec = self.M2P(m_f)   # photo -> Monet -> photo

                # The logits must stay attached to the graph: detaching them
                # (as before) cut the adversarial gradient off from the generators.
                m_logits = self.M_D(m_f)
                p_logits = self.P_D(p_f)

                id_loss = self.L1(m_prime, m) + self.L1(p_prime, p)
                gan_loss = self.L2(m_logits, torch.ones_like(m_logits)) +\
                            self.L2(p_logits, torch.ones_like(p_logits))
                cycle_loss = self.L2(m_rec, m) + self.L2(p_rec, p)

                id_loss_list += [id_loss.item()]
                GAN_loss_list += [gan_loss.item()]
                cycle_loss_list += [cycle_loss.item()]

                loss_G = id_loss * self.config["lambda_id"] + gan_loss + cycle_loss * self.config["lambda_cycle"]
                self.opt_G.zero_grad()
                loss_G.backward()
                self.opt_G.step()
            print(f"M_D: {mean(M_D_loss):.5e}")
            print(f"P_D: {mean(P_D_loss): .5e}")
            print(f"identity loss:{mean(id_loss_list):.5e}")
            print(f"GAN loss:{mean(GAN_loss_list):.5e}")
            print(f"cycle loss:{mean(cycle_loss_list):.5e}")
            if e % 10 == 0:
                torch.save(self.M2P, f"cycle_checkpoints/M2P_{e}.pt")
                # Previously this file also received M2P.
                torch.save(self.P2M, f"cycle_checkpoints/P2M_{e}.pt")
                os.makedirs(f'cycle_logs/epoch_{e}', exist_ok=True)

                # photo -> Monet (p2m) and photo -> Monet -> photo (p2p)
                self._save_samples(self.photoloader, self.P2M, self.M2P,
                                   f"cycle_logs/epoch_{e}/p2p.jpg",
                                   f"cycle_logs/epoch_{e}/p2m.jpg")
                # Monet -> photo (m2p) and Monet -> photo -> Monet (m2m)
                self._save_samples(self.monetloader, self.M2P, self.P2M,
                                   f"cycle_logs/epoch_{e}/m2m.jpg",
                                   f"cycle_logs/epoch_{e}/m2p.jpg")

In [None]:
# Build the trainer (dataloaders, networks, optimizers), then run the full
# training schedule; checkpoints go to cycle_checkpoints/ and sample grids to cycle_logs/.
trainer = CycleGAN()
trainer.train()