In [1]:
import os, time, pickle, argparse
from lib import networks, utils
from lib.edge_promoting import edge_promoting
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import transforms
from torchvision import datasets

### Parameters

In [2]:
#input channel for generator
in_ngc=3
#output channel for generator
out_ngc=3
#input channel for discriminator
in_ndc=3
#output channel for discriminator
out_ndc=1
batch_size=8
ngf=64
ndf=32
#the number of resnet block layer for generator
nb=6
#downsampling times
n_downsampling=2

#input size
input_size=128
train_epoch=5
pre_train_epoch=10
#Discriminator learning rate, default=0.0002
lrD=0.0002
#Generator learning rate, default=0.0002
lrG=0.0002
#lambda for content loss
con_lambda=0.5
#beta1 for Adam optimizer
beta1=0.5
#beta2 for Adam optimizer
beta2=0.999

In [3]:
#change project name when start a different project
project_name='cartoonGAN_1'
result_path = project_name+'_results'
# results save path
if not os.path.isdir(os.path.join(result_path, 'Reconstruction')):
    os.makedirs(os.path.join(result_path, 'Reconstruction'))
if not os.path.isdir(os.path.join(result_path, 'Transfer')):
    os.makedirs(os.path.join(result_path, 'Transfer'))

data_path = 'data'
src_data_path = os.path.join(data_path,'src_data_path')
tgt_data_path = os.path.join(data_path,'clear_blur_tgt_data_path')

#pre-trained VGG19 model path
vgg_model='pre_trained_VGG19_model_path/vgg19.pth'

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.backends.cudnn.enabled:
    torch.backends.cudnn.benchmark = True
device

device(type='cuda')

In [7]:
# edge-promoting
if not os.path.isdir(os.path.join(tgt_data_path, 'newpair')):
    print('edge-promoting start!!')
    edge_promoting(os.path.join(tgt_data_path, 'train'), os.path.join(tgt_data_path, 'newpair'))
else:
    print('edge-promoting already done')

  0%|          | 0/8140 [00:00<?, ?it/s]

edge-promoting start!!


100%|██████████| 8140/8140 [1:21:01<00:00,  1.71it/s]


In [5]:
# data_loader
src_transform = transforms.Compose([
        transforms.Resize((input_size, input_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])
tgt_transform = transforms.Compose([
        transforms.Resize((input_size, 2*input_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

train_loader_src = torch.utils.data.DataLoader(datasets.ImageFolder(src_data_path, src_transform), batch_size=batch_size, shuffle=True, drop_last=True)
train_loader_tgt = torch.utils.data.DataLoader(datasets.ImageFolder(tgt_data_path, tgt_transform), batch_size=batch_size, shuffle=True, drop_last=True)

In [6]:
# network
G = networks.cyclegan_generator1(in_ngc, out_ngc, ngf, nb, n_downsampling)
D = networks.discriminator(in_ndc, out_ndc, ndf)
VGG = networks.VGG19(init_weights=vgg_model, feature_mode=True)
G.to(device)
D.to(device)
VGG.to(device)
G.train()
D.train()
VGG.eval();
# print('---------- Networks initialized -------------')
# utils.print_network(G)
# utils.print_network(D)
# utils.print_network(VGG)
# print('-----------------------------------------------')

In [7]:
# loss
BCE_loss = nn.BCELoss().to(device)
L1_loss = nn.L1Loss().to(device)

In [8]:
# Adam optimizer
G_optimizer = optim.Adam(G.parameters(), lr=lrG, betas=(beta1, beta2))
D_optimizer = optim.Adam(D.parameters(), lr=lrD, betas=(beta1, beta2))
G_scheduler = optim.lr_scheduler.MultiStepLR(optimizer=G_optimizer, milestones=[train_epoch // 2, train_epoch // 4 * 3], gamma=0.1)
D_scheduler = optim.lr_scheduler.MultiStepLR(optimizer=D_optimizer, milestones=[train_epoch // 2, train_epoch // 4 * 3], gamma=0.1)

pre_train_hist = {}
pre_train_hist['Recon_loss'] = []
pre_train_hist['per_epoch_time'] = []
pre_train_hist['total_time'] = []

In [9]:
train_hist = {}
train_hist['Disc_loss'] = []
train_hist['Gen_loss'] = []
train_hist['Con_loss'] = []
train_hist['per_epoch_time'] = []
train_hist['Gen_loss_one_epoch']=[]
train_hist['Disc_loss_one_epoch']=[]
train_hist['Con_loss_one_epoch']=[]

### Load train_hist

In [18]:
#if need to load train history
with open(os.path.join(name + '_results', 'train_hist.pkl'), 'rb') as pickle_file:
    train_hist = pickle.load(pickle_file)

### Load model

In [11]:
#if need to load model
G.load_state_dict(torch.load(os.path.join(name + '_results', 'generator_latest.pkl')))
D.load_state_dict(torch.load(os.path.join(name + '_results', 'discriminator_latest.pkl')))

### Pre-train

In [10]:
#can skip if model is loaded
print('Pre-training start!')
start_time = time.time()
for epoch in range(pre_train_epoch):
    epoch_start_time = time.time()
    Recon_losses = []
    for x, _ in train_loader_src:
        x = x.to(device)

        # train generator G
        G_optimizer.zero_grad()

        x_feature = VGG((x + 1) / 2)
        G_ = G(x)
        G_feature = VGG((G_ + 1) / 2)

        Recon_loss = 10 * L1_loss(G_feature, x_feature.detach())
        Recon_losses.append(Recon_loss.item())
        pre_train_hist['Recon_loss'].append(Recon_loss.item())

        Recon_loss.backward()
        G_optimizer.step()

        break

    per_epoch_time = time.time() - epoch_start_time
    pre_train_hist['per_epoch_time'].append(per_epoch_time)
    print('[%d/%d] - time: %.2f, Recon loss: %.3f' % ((epoch + 1), pre_train_epoch, per_epoch_time, torch.mean(torch.FloatTensor(Recon_losses))))

total_time = time.time() - start_time
pre_train_hist['total_time'].append(total_time)
with open(os.path.join(result_path,  'pre_train_hist.pkl'), 'wb') as f:
    pickle.dump(pre_train_hist, f)

with torch.no_grad():
    G.eval()
    for n, (x, _) in enumerate(train_loader_src):
        x = x.to(device)
        G_recon = G(x)
        result = torch.cat((x[0], G_recon[0]), 2)
        path = os.path.join(result_path, 'Reconstruction', str(n + 1) + '.png')
        plt.imsave(path, (result.cpu().numpy().transpose(1, 2, 0) + 1) / 2)
        if n == 4:
            break

#     for n, (x, _) in enumerate(test_loader_src):
#         x = x.to(device)
#         G_recon = G(x)
#         result = torch.cat((x[0], G_recon[0]), 2)
#         path = os.path.join(result_path, 'Reconstruction', name + '_test_recon_' + str(n + 1) + '.png')
#         plt.imsave(path, (result.cpu().numpy().transpose(1, 2, 0) + 1) / 2)
#         if n == 4:
#             break



Pre-training start!
[1/10] - time: 1.70, Recon loss: 30.367
[2/10] - time: 0.07, Recon loss: 32.073
[3/10] - time: 0.11, Recon loss: 31.662
[4/10] - time: 0.11, Recon loss: 28.156
[5/10] - time: 0.11, Recon loss: 31.613
[6/10] - time: 0.11, Recon loss: 30.025
[7/10] - time: 0.11, Recon loss: 31.384
[8/10] - time: 0.11, Recon loss: 30.552
[9/10] - time: 0.11, Recon loss: 32.460
[10/10] - time: 0.11, Recon loss: 30.517


### Train

In [14]:
#starting_epoch is used to avoid overriding of the previously generated results
starting_epoch = 5

In [15]:
print('training start!')
start_time = time.time()

num_pool = 50
fake_pool = utils.ImagePool(num_pool)

for epoch in range(train_epoch):
    epoch_start_time = time.time()
    G.train()
    G_scheduler.step()
    D_scheduler.step()
    Disc_losses = []
    Gen_losses = []
    Con_losses = []
    for (x, _), (y, _) in zip(train_loader_src, train_loader_tgt):
        e = y[:, :, :, input_size:]
        y = y[:, :, :, :input_size]
        x, y, e = x.to(device), y.to(device), e.to(device)

        # train D
        for param in D.parameters():
            param.requires_grad = True
        D_optimizer.zero_grad()

        D_real = D(y)
        D_real_loss = BCE_loss(D_real, torch.ones(D_real.size(),device=device))

        G_ = G(x)
        G_ = fake_pool.query(G_.detach())
        D_fake = D(G_)
        D_fake_loss = BCE_loss(D_fake, torch.zeros(D_fake.size(),device=device))

        D_edge = D(e)
        D_edge_loss = BCE_loss(D_edge, torch.zeros(D_edge.size(),device=device))

        Disc_loss = D_real_loss + D_fake_loss + D_edge_loss
        Disc_losses.append(Disc_loss.item())
        train_hist['Disc_loss'].append(Disc_loss.item())

        Disc_loss.backward()
        D_optimizer.step()

        # train G
        G_optimizer.zero_grad()
        for param in D.parameters():
            param.requires_grad = False
        G_ = G(x)
        D_fake = D(G_)
        D_fake_loss = BCE_loss(D_fake, torch.ones(D_real.size(),device=device))

        x_feature = VGG((x + 1) / 2)
        G_feature = VGG((G_ + 1) / 2)
        Con_loss = con_lambda * L1_loss(G_feature, x_feature.detach())

        Gen_loss = D_fake_loss + Con_loss

        Gen_losses.append(D_fake_loss.item())
        train_hist['Gen_loss'].append(D_fake_loss.item())
        Con_losses.append(Con_loss.item())
        train_hist['Con_loss'].append(Con_loss.item())

        Gen_loss.backward()
        G_optimizer.step()


    per_epoch_time = time.time() - epoch_start_time
    train_hist['per_epoch_time'].append(per_epoch_time)
    
    Gen_loss_avg = torch.mean(torch.FloatTensor(Gen_losses))
    Con_loss_avg = torch.mean(torch.FloatTensor(Con_losses))
    Disc_loss_avg =  torch.mean(torch.FloatTensor(Disc_losses))
    
    train_hist['Gen_loss_one_epoch'].append(Gen_loss_avg)
    train_hist['Disc_loss_one_epoch'].append(Disc_loss_avg)
    train_hist['Con_loss_one_epoch'].append(Con_loss_avg)
    
    print(
    '[%d/%d] - time: %.2f, Disc loss: %.3f, Gen loss: %.3f, Con loss: %.3f' % ((starting_epoch + epoch + 1), (starting_epoch+train_epoch), per_epoch_time, Disc_loss_avg, Gen_loss_avg, Con_loss_avg))

    with torch.no_grad():
        G.eval()
        for n, (x, _) in enumerate(train_loader_src):
            x = x.to(device)
            G_recon = G(x)
            result = torch.cat((x[0], G_recon[0]), 2)
            path = os.path.join(result_path, 'Transfer', str(starting_epoch+epoch+1) + '_epoch_' + str(n + 1) + '.png')
            plt.imsave(path, (result.cpu().numpy().transpose(1, 2, 0) + 1) / 2)
            if n == 4:
                break

        torch.save(G.state_dict(), os.path.join(result_path, 'generator_latest.pkl'))
        torch.save(D.state_dict(), os.path.join(result_path, 'discriminator_latest.pkl'))
        with open(os.path.join(result_path,  'train_hist.pkl'), 'wb') as f:
            pickle.dump(train_hist, f)

training start!
[6/10] - time: 295.68, Disc loss: 1.224, Gen loss: 0.989, Con loss: 1.457
[7/10] - time: 295.43, Disc loss: 1.208, Gen loss: 1.000, Con loss: 1.451
[8/10] - time: 295.48, Disc loss: 1.193, Gen loss: 1.015, Con loss: 1.441
[9/10] - time: 311.46, Disc loss: 1.175, Gen loss: 1.029, Con loss: 1.429


KeyboardInterrupt: 