In [None]:
from google.colab import drive

# Google Drive mount point. Relative to the working directory, so presumably
# /content/mount on Colab, which is what the dataset paths below assume.
MOUNT_POINT = './mount'
drive.mount(MOUNT_POINT)

In [None]:
import os

import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.utils import make_grid
import torchvision.transforms as transforms
from torchvision.transforms import Resize, RandomCrop, Normalize, ToTensor
from torch.utils.data import Dataset, DataLoader

import itertools
import pandas as pd
import numpy as np
import random
import matplotlib.pyplot as plt
from PIL import Image

In [None]:
# Select the GPU when one is available.
# Every module and batch in this notebook is moved to `device` explicitly, so
# the global default tensor type is deliberately left alone:
# `torch.set_default_tensor_type(torch.cuda.FloatTensor)` is deprecated, and
# it can make `DataLoader(shuffle=True)` raise "Expected a 'cuda' device type
# for generator but found 'cpu'", because the sampler's CPU generator then
# drives a CUDA `randperm`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if device.type == "cuda":
    print("using cuda:", torch.cuda.get_device_name(0))

device

In [None]:
class ImageDataset(Dataset):
    """Unlabelled image-folder dataset that returns normalised RGB tensors.

    Every regular, non-hidden file directly inside `image_dir` is treated as an
    image. Files are indexed in sorted order so that indices are stable across
    runs (`os.listdir` order is arbitrary).

    Args:
        image_dir: directory containing the images.
        size: (height, width), or a single int, of the random crop returned
            by `__getitem__`. Default (256, 256).
        load_size: (height, width), or a single int, that the image is resized
            to before cropping. Defaults to `size` scaled by 286/256, i.e.
            (286, 286) for the default 256x256 crop.
    """

    def __init__(self, image_dir, size=(256, 256), load_size=None):
        super().__init__()
        self.image_dir = image_dir

        if isinstance(size, int):
            size = (size, size)
        size = tuple(size)
        if load_size is None:
            # Keep the 286/256 resize-then-crop ratio of the original hard-coded values.
            load_size = tuple(s * 286 // 256 for s in size)
        elif isinstance(load_size, int):
            load_size = (load_size, load_size)

        # Preprocessing: resize, random crop (augmentation), convert to a
        # [0, 1] tensor, then map each channel to [-1, 1], the same range as
        # the generator's Tanh output.
        self.transform = transforms.Compose([
            Resize(tuple(load_size)),
            RandomCrop(size),
            ToTensor(),
            Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        # Skip hidden entries (e.g. .ipynb_checkpoints, .DS_Store) and sub-directories.
        self.img_idx = [
            name for name in sorted(os.listdir(self.image_dir))
            if not name.startswith('.')
            and os.path.isfile(os.path.join(self.image_dir, name))
        ]

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.img_idx[idx])
        # convert('RGB') so grayscale / RGBA / palette files still yield the 3
        # channels Normalize expects; `with` closes the file handle promptly.
        with Image.open(image_path) as raw:
            img = raw.convert('RGB')
        return self.transform(img)

    def __len__(self):
        return len(self.img_idx)

In [None]:
def unnormalize(image, mean_=0.5, std_=0.5):
    """Undo `Normalize(mean_, std_)` and convert to 8-bit pixel values.

    Args:
        image: a torch tensor (on any device; it is detached and copied to the
            CPU) or a numpy array of normalised values.
        mean_, std_: the mean / std that were used by Normalize. Scalars or
            arrays broadcastable against `image`.

    Returns:
        np.uint8 array with the same shape as `image`, clipped to [0, 255].
    """
    if torch.is_tensor(image):
        image = image.detach().cpu().numpy()
    pixels = (np.asarray(image) * std_ + mean_) * 255
    # Clip before casting: converting out-of-range floats to uint8 wraps around
    # (or is platform-dependent), which shows up as speckle noise in the plots.
    return np.clip(pixels, 0, 255).astype(np.uint8)
    

In [None]:
# Load Dataset
DATA_ROOT = '/content/mount/MyDrive/dataset/Art_GAN'

# Domain X: photos. Domain Y: paintings. The Y dataset keeps the name
# `monet_ds` for the later cells, but it is read from the vangogh folder.
photo_ds = ImageDataset(os.path.join(DATA_ROOT, 'photo/'))
monet_ds = ImageDataset(os.path.join(DATA_ROOT, 'vangogh/'))

In [None]:
# Unpaired training streams: one image per step from each domain,
# reshuffled on every pass through the data.
train_loader_kwargs = dict(batch_size=1, shuffle=True, pin_memory=True)
photo_dl = DataLoader(photo_ds, **train_loader_kwargs)
monet_dl = DataLoader(monet_ds, **train_loader_kwargs)

# Fixed-order loader over photos; train() previews its first batch of two.
test = DataLoader(photo_ds, batch_size=2, shuffle=False, pin_memory=True)

In [None]:
def show_test(fixed_X, G_XtoY, mean_=0.5, std_=0.5):
    """Plot a batch of source images above their translations by `G_XtoY`.

    Args:
        fixed_X: batch tensor of normalised images, e.g. a batch from the
            `test` DataLoader; it is moved to `device` for the generator.
        G_XtoY: generator mapping domain X to domain Y.
        mean_, std_: normalisation constants, forwarded to `unnormalize`.
    """
    # Inference only: no autograd graph is needed for a preview.
    with torch.no_grad():
        fake_Y = G_XtoY(fixed_X.to(device))

    # make_grid tiles the batch into one (C, H, W) image; permute to (H, W, C) for imshow.
    grid_x = make_grid(fixed_X).permute(1, 2, 0).detach().cpu().numpy()
    grid_fake_y = make_grid(fake_Y).permute(1, 2, 0).detach().cpu().numpy()

    # Map normalised values back to 0..255 pixels.
    X, fake_Y = unnormalize(grid_x, mean_, std_), unnormalize(grid_fake_y, mean_, std_)

    # Original (top) vs. X -> Y translation (bottom).
    fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, sharey=True, figsize=(20, 10))
    for ax, img, title in ((ax1, X, 'Original'), (ax2, fake_Y, 'Converted')):
        ax.imshow(img)
        ax.axis('off')
        ax.set_title(title)
    plt.show()
    # Release the figure: this is called repeatedly during training.
    plt.close(fig)

In [None]:
class Resblock(nn.Module):
  """Residual block: two reflection-padded 3x3 conv + InstanceNorm stages plus a skip connection.

  Input and output have the same shape (channels, height, width).

  Args:
    channels: number of input (and output) feature channels. Default 256,
      the width at which the Generator uses these blocks.
  """
  def __init__(self, channels=256):
    super().__init__()

    self.conv1 = nn.Sequential(
        nn.ReflectionPad2d(1),
        nn.Conv2d(channels, channels, kernel_size = 3, bias = False),
        nn.InstanceNorm2d(channels)
    )

    self.conv2 = nn.Sequential(
        nn.ReflectionPad2d(1),
        nn.Conv2d(channels, channels, kernel_size = 3, bias = False),
        nn.InstanceNorm2d(channels)
    )

  def forward(self, inputs):
    output = torch.nn.functional.relu(self.conv1(inputs))
    # NOTE(review): ReLU after the skip addition follows the original ResNet;
    # the CycleGAN reference block returns `inputs + conv2(...)` without it.
    # Confirm this is intended.
    return torch.nn.functional.relu(inputs + self.conv2(output))

In [None]:
# Create Generator

class Generator(nn.Module):
  """ResNet-style CycleGAN generator.

  Layout:
    - 7x7 stem to 64 channels;
    - two stride-2 downsampling convs to 128 and then 256 channels;
    - `n_residual_blocks` Resblocks at 256 channels;
    - two stride-2 transposed convs back to 64 channels;
    - a 7x7 conv to `out_channels`, followed by Tanh, so outputs lie in (-1, 1).

  Height and width are preserved when they are divisible by 4.

  Args:
    n_residual_blocks: number of Resblock stages in the bottleneck (default 9).
    in_channels: channels of the input image (default 3, RGB).
    out_channels: channels of the generated image (default 3, RGB).
  """
  def __init__(self, n_residual_blocks=9, in_channels=3, out_channels=3):
    super().__init__()

    # 7x7 stem with reflection padding instead of zero padding.
    self.conv1 = nn.Sequential(
        nn.ReflectionPad2d(3),
        nn.Conv2d(in_channels, 64, kernel_size = 7, bias = False),
        nn.InstanceNorm2d(64),
        nn.GELU()
    )

    # Two stride-2 convs: H x W -> H/2 x W/2 -> H/4 x W/4.
    self.downsampling = nn.Sequential(
        nn.Conv2d(64, 128, kernel_size = 3, stride = 2, padding = 1, bias = False),
        nn.InstanceNorm2d(128),
        nn.GELU(),

        nn.Conv2d(128, 256, kernel_size = 3, stride = 2, padding = 1, bias = False),
        nn.InstanceNorm2d(256),
        nn.GELU()
    )

    self.resblock = nn.Sequential(*[Resblock() for _ in range(n_residual_blocks)])

    # Mirror of `downsampling`; output_padding=1 restores the exact doubled size.
    self.upsampling = nn.Sequential(
        nn.ConvTranspose2d(256, 128, kernel_size = 3, stride = 2, padding = 1, output_padding = 1, bias = False),
        nn.InstanceNorm2d(128),
        nn.GELU(),

        nn.ConvTranspose2d(128, 64, kernel_size = 3, stride = 2, padding = 1, output_padding = 1, bias = False),
        nn.InstanceNorm2d(64),
        nn.GELU()
    )

    # NOTE(review): no normalisation follows this conv, so bias=False drops its
    # only offset term; kept as-is so existing checkpoints still load.
    self.conv2 = nn.Sequential(
        nn.ReflectionPad2d(3),
        nn.Conv2d(64, out_channels, kernel_size = 7, bias = False),
        nn.Tanh()
    )

  def forward(self, inputs):
    output = self.conv1(inputs)
    output = self.downsampling(output)
    output = self.resblock(output)
    output = self.upsampling(output)
    return self.conv2(output)

In [None]:
# CycleGAN - Discriminator with PatchGAN

class Discriminator(nn.Module):
  """PatchGAN discriminator returning a map of raw (pre-sigmoid) real/fake scores.

  The network deliberately ends with a plain convolution and no nn.Sigmoid:
  the adversarial loss is nn.BCEWithLogitsLoss, which applies the sigmoid
  itself. A second sigmoid would squash the scores into (0.5, 0.73) and
  cripple the gradients.

  Five stride-2 convolutions followed by a 4x4 stride-1 convolution give, for
  example, a (N, 1, 7, 7) score map for 256x256 inputs.

  Args:
    in_channels: channels of the input image (default 3, RGB).
  """
  def __init__(self, in_channels=3):
    super().__init__()
    self.model = nn.Sequential(
        nn.Conv2d(in_channels, 64, kernel_size = 4, padding = 1, stride = 2),
        nn.GELU(),

        nn.Conv2d(64, 128, kernel_size = 4, padding = 1, stride = 2, bias = False),
        nn.InstanceNorm2d(128),
        nn.GELU(),

        nn.Conv2d(128, 256, kernel_size = 4, padding = 1, stride = 2, bias = False),
        nn.InstanceNorm2d(256),
        nn.GELU(),

        nn.Conv2d(256, 512, kernel_size = 4, padding = 1, stride = 2, bias = False),
        nn.InstanceNorm2d(512),
        nn.GELU(),

        # NOTE(review): BatchNorm2d here is inconsistent with the InstanceNorm2d
        # layers above and is degenerate with batch_size=1; kept so existing
        # discriminator weights still load. Consider InstanceNorm2d.
        nn.Conv2d(512, 512, kernel_size = 4, padding = 1, stride = 2, bias = False),
        nn.BatchNorm2d(512),
        nn.GELU(),

        # Raw logits: BCEWithLogitsLoss applies the sigmoid.
        nn.Conv2d(512, 1, kernel_size = 4, padding = 1, stride = 1, bias = False),
    )

  def forward(self, inputs):
    return self.model(inputs)

In [None]:
class replay_buffer():
    """Pool of previously generated images used to feed discriminators.

    Until the pool holds `pool_size` images, each incoming image is stored and
    returned as-is. After that, each image is either returned directly (50%)
    or swapped with a random stored image that is returned instead (50%).
    With `pool_size <= 0` the buffer is disabled and `query` returns its input
    unchanged.

    Args:
        pool_size: maximum number of stored images.
    """

    def __init__(self, pool_size):
        self.pool_size = pool_size
        # Initialise unconditionally so that query() never hits missing attributes.
        self.num_imgs = 0
        self.images = []

    def query(self, images):
        """Return a batch the same size as `images`, mixing in stored images once the pool is full.

        Args:
            images: batch tensor; the first dimension indexes images.

        Returns:
            Tensor of the same shape, detached from any autograd graph, unless
            the buffer is disabled, in which case `images` is returned unchanged.
        """
        if self.pool_size <= 0:
            return images

        to_return = []
        for image in images:
            image = torch.unsqueeze(image.detach(), 0)
            if self.num_imgs < self.pool_size:
                # Buffer not full yet: store the image and return it.
                self.num_imgs += 1
                self.images.append(image)
                to_return.append(image)
            elif random.uniform(0, 1) > 0.5:
                # Return a stored image and replace it with the current one.
                random_id = random.randint(0, self.pool_size - 1)  # randint is inclusive
                to_return.append(self.images[random_id].clone())
                self.images[random_id] = image
            else:
                to_return.append(image)
        return torch.cat(to_return, 0)

In [None]:
def weights_init_normal(m):
    """Redraw convolution weights from a normal distribution N(0, 0.02), in place.

    Meant for `nn.Module.apply`. Only modules whose class name contains 'Conv'
    (Conv2d, ConvTranspose2d, ...) and that have a `weight` attribute are
    touched; all other modules, and all biases, keep their current values.
    """
    if 'Conv' not in type(m).__name__ or not hasattr(m, 'weight'):
        return
    torch.nn.init.normal_(m.weight.data, mean=0.0, std=0.02)

In [None]:
# Set Network
# Build both generators and both discriminators on the selected device.
G_XtoY = Generator().to(device)
G_YtoX = Generator().to(device)
D_X = Discriminator().to(device)
D_Y = Discriminator().to(device)

# (banner, network) pairs, in the order they are initialised and printed.
network_banners = [
    ("                     G_XtoY                    ", G_XtoY),
    ("                     G_YtoX                    ", G_YtoX),
    ("                      D_X                      ", D_X),
    ("                      D_Y                      ", D_Y),
]

# Weight initialization: redraw every conv weight from N(0, 0.02).
for _, network in network_banners:
    network.apply(weights_init_normal)

for banner, network in network_banners:
    print(banner)
    print("-----------------------------------------------")
    print(network)
    print()

In [None]:
# Define loss function
# Weight of the cycle-consistency term; the identity term uses half of it.
lambda_weight = 10

# Adversarial criterion. It expects unnormalised logits, so the discriminator
# must not apply its own sigmoid.
loss_GAN = nn.BCEWithLogitsLoss().to(device)
L1_loss = nn.L1Loss().to(device)


def Cycle_loss(inputs, targets):
    """Cycle-consistency loss: L1 distance scaled by `lambda_weight`."""
    return lambda_weight * L1_loss(inputs, targets)


def G_loss(inputs):
    """Generator adversarial loss: push discriminator scores on fakes toward the 'real' label 1."""
    real_label = torch.ones_like(inputs)
    return loss_GAN(inputs, real_label)


def D_loss(inputs, targets):
    """Discriminator loss, averaged over its real and fake halves.

    Args:
        inputs: discriminator scores on real images (target label 1).
        targets: discriminator scores on generated images (target label 0).
    """
    real_term = loss_GAN(inputs, torch.ones_like(inputs))
    fake_term = loss_GAN(targets, torch.zeros_like(targets))
    return (real_term + fake_term) * 0.5


def Identity_loss(inputs, targets):
    """Identity-mapping loss: L1 distance scaled by half of `lambda_weight`."""
    return lambda_weight * L1_loss(inputs, targets) * 0.5


In [None]:
lr = 0.001
ADAM_BETAS = (0.5, 0.999)

# A single optimizer drives both generators, because their losses are summed
# into one objective in train().
G_optimizer = torch.optim.Adam(
    itertools.chain(G_XtoY.parameters(), G_YtoX.parameters()), lr, betas=ADAM_BETAS
)
D_X_optimizer = torch.optim.Adam(D_X.parameters(), lr, betas=ADAM_BETAS)
D_Y_optimizer = torch.optim.Adam(D_Y.parameters(), lr, betas=ADAM_BETAS)

In [None]:
def _endless(loader):
    """Yield batches from `loader` forever, starting a new pass (re-shuffled if the loader shuffles) when one is exhausted."""
    while True:
        for batch in loader:
            yield batch


def _discriminator_step(D, optimizer, real, fake, buffer):
    """Take one optimiser step on discriminator `D` and return its loss as a detached scalar tensor.

    `fake` is detached, so no gradient reaches the generator, and is routed
    through the replay `buffer`; `D` is scored on the images the buffer returns.
    """
    optimizer.zero_grad()
    pooled_fake = buffer.query(fake.detach())
    loss = D_loss(D(real), D(pooled_fake))
    loss.backward()
    optimizer.step()
    return loss.detach()


def train(photo_dl, monet_dl, test, n_epochs=1000):
    """Train the CycleGAN held in the module globals.

    The globals used are G_XtoY, G_YtoX, D_X, D_Y and their optimizers.

    Args:
        photo_dl: DataLoader over domain X.
        monet_dl: DataLoader over domain Y.
        test: DataLoader whose first batch is previewed with show_test every 10 epochs.
        n_epochs: total number of training *iterations*, each using one batch
            from each domain. Losses are averaged and logged per "epoch" of
            `min(len(photo_dl), len(monet_dl))` iterations.

    Returns:
        (D_losses, G_losses): per-epoch average discriminator / generator losses, as floats.

    Raises:
        ValueError: if either DataLoader yields no batches.
    """
    batch_per_epoch = min(len(photo_dl), len(monet_dl))
    if batch_per_epoch == 0:
        raise ValueError("photo_dl and monet_dl must each yield at least one batch")

    D_losses = []
    G_losses = []
    test_X = next(iter(test))

    # Persistent iterators: a fresh photo/painting pair on every iteration.
    photo_batches = _endless(photo_dl)
    monet_batches = _endless(monet_dl)

    Buffer_XtoY = replay_buffer(50)
    Buffer_YtoX = replay_buffer(50)

    total_epoch = n_epochs // batch_per_epoch
    D_avg_loss = 0.0
    G_avg_loss = 0.0

    for epoch in range(1, n_epochs + 1):
        X = next(photo_batches).to(device)
        Y = next(monet_batches).to(device)

        # Translate once. The D steps use detached copies; the G step reuses
        # these graphs (the generators have not been updated in between).
        fake_X = G_YtoX(Y)
        fake_Y = G_XtoY(X)

        # Discriminator Train
        D_X_loss = _discriminator_step(D_X, D_X_optimizer, X, fake_X, Buffer_YtoX)
        D_Y_loss = _discriminator_step(D_Y, D_Y_optimizer, Y, fake_Y, Buffer_XtoY)
        D_total_loss = (D_X_loss + D_Y_loss) / 2

        # Generator Train
        # D parameters also receive gradients here; that is harmless because
        # each discriminator step zeroes them before use.
        G_optimizer.zero_grad()

        G_YtoX_loss = G_loss(D_X(fake_X))
        G_XtoY_loss = G_loss(D_Y(fake_Y))

        Y_cycle_loss = Cycle_loss(Y, G_XtoY(fake_X))
        X_cycle_loss = Cycle_loss(X, G_YtoX(fake_Y))

        # Identity loss: a generator fed an image already in its target domain
        # should return it unchanged (G_XtoY(Y) ~ Y, G_YtoX(X) ~ X).
        Y_iden_loss = Identity_loss(Y, G_XtoY(Y))
        X_iden_loss = Identity_loss(X, G_YtoX(X))

        G_total_loss = (G_YtoX_loss + G_XtoY_loss
                        + Y_cycle_loss + X_cycle_loss
                        + Y_iden_loss + X_iden_loss)
        G_total_loss.backward()
        G_optimizer.step()

        # Train log: accumulate plain floats so no autograd graph is kept alive.
        D_avg_loss += D_total_loss.item() / batch_per_epoch
        G_avg_loss += G_total_loss.item() / batch_per_epoch

        if epoch % batch_per_epoch == 0:
            D_losses.append(D_avg_loss)
            G_losses.append(G_avg_loss)
            real_epoch = epoch // batch_per_epoch
            print('Epoch [{:5d}/{:5d}] | D_total_loss: {:6.4f} | G_total_loss: {:6.4f}'.format(
                real_epoch, total_epoch, D_avg_loss, G_avg_loss))

            if epoch % (batch_per_epoch * 10) == 0:
                G_XtoY.eval()
                with torch.no_grad():
                    show_test(test_X, G_XtoY)
                # Set the generator back to train mode to continue training.
                G_XtoY.train()

            D_avg_loss = 0.0
            G_avg_loss = 0.0

    return D_losses, G_losses

In [None]:
# train() counts iterations (one photo/painting pair each), so convert the
# desired number of passes over the smaller dataset into iterations.
epoch_true = 100
batch_per_epoch = min(len(loader) for loader in (photo_dl, monet_dl))
n_epochs = batch_per_epoch * epoch_true

In [None]:
%%time

D_losses, G_losses = train(photo_dl=photo_dl, monet_dl=monet_dl, test=test, n_epochs=n_epochs)

In [None]:
# Plot the per-epoch average losses recorded by train().
D_losses = np.array(D_losses)
G_losses = np.array(G_losses)
epochs = np.arange(1, len(D_losses) + 1)  # the first logged value is epoch 1

fig, ax = plt.subplots(figsize=(12, 8))
ax.plot(epochs, D_losses, label='Discriminators', alpha=0.5)
ax.plot(epochs, G_losses, label='Generators', alpha=0.5)
ax.set(title="Training Losses", xlabel="Epoch", ylabel="Average loss")
ax.legend()
plt.show()

In [None]:
def save_checkpoint(model, save_path='/content/mount/MyDrive/dataset/Art_GAN/model/',
                    filename='monet_model.pt'):
  """Save the entire `model` object with torch.save and return the written file path.

  Args:
    model: the module to save.
    save_path: target directory; created if it does not exist.
    filename: name of the checkpoint file inside `save_path`.

  Returns:
    The full path of the saved file.

  NOTE(review): torch.save(model) pickles the module by class reference;
  saving model.state_dict() is more portable. The whole model is kept here in
  case a loader expects torch.load to return the module itself.
  """
  os.makedirs(save_path, exist_ok=True)
  path = os.path.join(save_path, filename)
  torch.save(model, path)  # save the entire model
  return path

In [None]:
save_checkpoint(model=G_XtoY)