In [1]:
import os, time, pickle
from lib import networks, utils, train_history
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import transforms
from torchvision import datasets

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.backends.cudnn.enabled:
    torch.backends.cudnn.benchmark = True
device

device(type='cuda')

### Hyper-parameters

In [3]:
batch_size=10
input_size=128
train_epoch=20

#input channel for discriminator
in_ngc=3
#output channel for discriminator
out_ngc=3
#generator first layer number of filters
ngf=8
#input channel for discriminator
in_ndc=3
#output channel for discriminator
out_ndc=1
#discriminator first layer number of filters
ndf=32
# resnet layer
nb = 4
# coupled layer
coupled_layer = 1
# decoupled layer
decoupled_layer = 4
#latent vector length
latent_len = 4096

#learning rate, default=0.0002
lrD=0.0001
lrG=0.0002
lrC = 0.0001

#generator lambda
rec_lambda=1
sem_lambda=10
idt_lambda=0.1
dann_lambda=1

#beta for Adam optimizer
beta1=0.5
beta2=0.999

### Folder structure

In [4]:
# results save path
project_name = 'XGAN_9'
result_path = project_name+'_results'
src_result_name = 'G_S'
tgt_result_name = 'G_T'

data_path = 'data'
src_data_path = os.path.join(data_path,'src_data_path_new')
# tgt_data_path = os.path.join(data_path,'tgt_data_path')
tgt_data_path = os.path.join(data_path,'clear_blur_tgt_data_path')

#ensure data folder exists
if not os.path.isdir(os.path.join(result_path, src_result_name)):
    os.makedirs(os.path.join(result_path, src_result_name))
if not os.path.isdir(os.path.join(result_path, tgt_result_name)):
    os.makedirs(os.path.join(result_path, tgt_result_name))
if not os.path.isdir(os.path.join(src_data_path,'train')):
    os.makedirs(os.path.join(src_data_path,'train'))
if not os.path.isdir(os.path.join(tgt_data_path,'train')):
    os.makedirs(os.path.join(tgt_data_path,'train'))

### Load Data

In [5]:
# data_loader
transform = transforms.Compose([
        transforms.Resize((input_size, input_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

tgt_transform = transforms.Compose([
        transforms.Resize((input_size, 2*input_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])
# train_loader_S = utils.data_load(os.path.join('data', src_data), 'train', transform, batch_size, shuffle=True, drop_last=True)
# train_loader_T = utils.data_load(os.path.join('data', tgt_data), 'train', transform, batch_size, shuffle=True, drop_last=True)
train_loader_S = torch.utils.data.DataLoader(datasets.ImageFolder(src_data_path, transform), batch_size=batch_size, shuffle=True, drop_last=True)
train_loader_T = torch.utils.data.DataLoader(datasets.ImageFolder(tgt_data_path, tgt_transform), batch_size=batch_size, shuffle=True, drop_last=True)


### Model define

In [12]:
# network
# G = networks.xgan_generator(in_ngc,out_ngc,ngf)
G = networks.xgan_generator2(in_ngc,out_ngc,ngf,nb,coupled_layer,decoupled_layer)
# G = networks.xgan_generator3(in_ngc,out_ngc,ngf)
# G = networks.xgan_generator4(in_ngc,out_ngc,ngf,input_size,coupled_layer,decoupled_layer)

D = networks.discriminator(in_ndc, out_ndc, ndf)
C = networks.xgan_classifier2(latent_len)
# C = networks.xgan_classifier(latent_len)

G.to(device)
D.to(device)
C.to(device)

xgan_classifier2(
  (classifier): Linear(in_features=4096, out_features=2, bias=True)
)

In [13]:
# loss
MSE_loss = nn.MSELoss().to(device)
L1_loss = nn.L1Loss().to(device)
Cross_Entropy_loss = nn.CrossEntropyLoss().to(device)

def D_loss_criterion(D_decision,device,zeros,trick=True):
    if(zeros):
        if(trick):
            return MSE_loss(D_decision, torch.rand(D_decision.size(), device=device)/10.0)
        return MSE_loss(D_decision, torch.zeros(D_decision.size(), device=device))
    else:
        if(trick):
            return MSE_loss(D_decision, 1-torch.rand(D_decision.size(), device=device)/10.0)
        return MSE_loss(D_decision, torch.ones(D_decision.size(), device=device))

In [14]:
# optimizer
G_optimizer = optim.Adam(G.parameters(), lr=lrG, betas=(beta1, beta2))
D_optimizer = optim.Adam(D.parameters(), lr=lrD, betas=(beta1, beta2))
C_optimizer = optim.Adam(C.parameters(), lr=lrC, betas=(beta1,beta2))

In [None]:
train_hist = train_history.train_history(['per_epoch_time',
                                          'G_rec_loss',
                                          'G_sem_loss',
                                          'G_gan_loss',
                                          'G_dann_loss',
                                          'D_loss',
                                          'C_loss'])

In [15]:
# train history
train_hist = {}
train_hist['per_epoch_time'] = []
train_hist['total_time'] = []
train_hist['G_rec_loss']=[]
train_hist['G_sem_loss']=[]
train_hist['G_gan_loss']=[]
train_hist['G_dann_loss']=[]
train_hist['D_loss']=[]
train_hist['C_loss']=[]

### Load existing model parameters

In [40]:
G.load_state_dict(torch.load(os.path.join(result_path, 'G.pkl')))
C.load_state_dict(torch.load(os.path.join(result_path, 'C.pkl')))
D.load_state_dict(torch.load(os.path.join(result_path, 'D.pkl')))

### Load train hist

In [None]:
train_hist.load_train(os.path.join(result_path, 'train_hist.pkl'))

In [11]:
with open(os.path.join(result_path, 'train_hist.pkl'), 'rb') as pickle_file:
    train_hist = pickle.load(pickle_file)

### Starting epoch

In [22]:
#starting_epoch is used to avoid overriding of the previously generated results
starting_epoch = 40

### Train

In [23]:
print('training start!')
start_time = time.time()
num_pool = 50
fake_pool = utils.ImagePool(num_pool)
for epoch in range(train_epoch):
    epoch_start_time = time.time()
#     print("==> Epoch {}/{}".format(starting_epoch+epoch + 1, starting_epoch+train_epoch))
#     if (epoch + 1) > decay_epoch:
#         D_A_optimizer.param_groups[0]['lr'] -= lrD / 10
#         D_B_optimizer.param_groups[0]['lr'] -= lrD / 10
#         G_optimizer.param_groups[0]['lr'] -= lrG / 10
    
    G_rec_losses = []
    G_sem_losses = []
    G_gan_losses = []
    G_dann_losses = []
    
    D_losses = []
    C_losses = []
    for (real_S,_),(y,_) in zip(train_loader_S, train_loader_T):
        G.train()
         
        blur_T = y[:, :, :, input_size:]
        real_T = y[:, :, :, :input_size]
        real_S, real_T, blur_T = real_S.to(device), real_T.to(device), blur_T.to(device)

        # Train generator G
        # do not record grad for D and C to save time
        for model in [D, C]:
            for param in model.parameters():
                param.requires_grad = False
        # S->T
        real_S_latent = G.enc_s2t(real_S)     
        real_S_recon = G.dec_t2s(real_S_latent)
        fake_T = G.dec_s2t(real_S_latent)
        fake_T_latent = G.enc_t2s(fake_T)
              
        #T->S
        real_T_latent = G.enc_t2s(real_T)     
        real_T_recon = G.dec_s2t(real_T_latent)      
        fake_S = G.dec_t2s(real_T_latent)
        fake_S_latent = G.enc_s2t(fake_S)
        
        # identity loss
#         G_idt_loss = L1_loss(real_S, fake_T)
        
        # rec loss
        G_S_rec_loss = L1_loss(real_S, real_S_recon)
        G_T_rec_loss = L1_loss(real_T, real_T_recon)
        
        # semantic loss
        G_S_sem_loss = MSE_loss(real_S_latent,fake_T_latent)
        G_T_sem_loss = MSE_loss(real_T_latent, fake_S_latent)
        
        # gan loss, only do for S->T to save computing
        D_decision = D(fake_T)
        G_gan_loss = D_loss_criterion(D_decision,device,zeros=False,trick=False)

        # domain adversarial loss
        # source should be 0, so we want it to classify as 1, target should be 1, so we want it to classify as 0
        C_S_decision = C(real_S_latent)
        C_T_decision = C(real_T_latent)
        G_dann_loss = Cross_Entropy_loss(C_S_decision, torch.ones(C_S_decision.shape[0], dtype=torch.long, device=device)) + Cross_Entropy_loss(C_T_decision, torch.zeros(C_T_decision.shape[0],dtype=torch.long,device=device))     
        
        # Back propagation
        G_rec_loss = 0.5*(G_S_rec_loss + G_T_rec_loss)
        G_sem_loss = 0.5*(G_S_sem_loss + G_T_sem_loss)
#         G_loss = G_gan_loss + rec_lambda*G_rec_loss + sem_lambda*G_sem_loss + dann_lambda*G_dann_loss + idt_lambda*G_idt_loss
        G_loss = G_gan_loss + rec_lambda*G_rec_loss + sem_lambda*G_sem_loss + dann_lambda*G_dann_loss

        G_rec_losses.append(G_rec_loss)
        G_sem_losses.append(G_sem_loss)
        G_gan_losses.append(G_gan_loss)
        G_dann_losses.append(G_dann_loss)
        
        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        #train D and C parameters
        for model in [D, C]:
            for param in model.parameters():
                param.requires_grad = True
                
        # Train discriminator D
        D_real_decision = D(real_S)
        D_real_loss = D_loss_criterion(D_real_decision,device,zeros=False,trick=True)
        fake_T = fake_pool.query(fake_T.detach())
        D_fake_decision = D(fake_T)
        D_fake_loss = D_loss_criterion(D_fake_decision,device,zeros=True,trick=True)
        D_blur_decision = D(blur_T)
        D_blur_loss = D_loss_criterion(D_blur_decision,device,zeros=True,trick=True)
        D_loss = D_real_loss + D_fake_loss + D_blur_loss
        D_losses.append(D_loss)
        D_optimizer.zero_grad()
        D_loss.backward()
        D_optimizer.step()

        #Train classifier C
        real_S_latent = G.enc_s2t(real_S)
        real_T_latent = G.enc_t2s(real_T)
        C_S_decision = C(real_S_latent.detach())
        C_T_decision = C(real_T_latent.detach())
        C_loss = 0.5*(Cross_Entropy_loss(C_S_decision, torch.zeros(C_S_decision.shape[0],dtype=torch.long,device=device)) + Cross_Entropy_loss(C_T_decision, torch.ones(C_T_decision.shape[0],dtype=torch.long,device=device)))
        C_losses.append(C_loss)
        C_optimizer.zero_grad()
        C_loss.backward()
        C_optimizer.step()
        
    #record train history
    train_params = []
    per_epoch_time = time.time() - epoch_start_time
    train_params.append(per_epoch_time)
    for loss in [G_rec_losses,G_sem_losses,G_gan_losses,G_dann_losses,D_losses,C_losses]:
        train_params.append(torch.mean(torch.FloatTensor(loss)))  
    train_hist.add_params(train_params)
    print(str.format('{}/{}',starting_epoch+epoch+1,starting_epoch+train_epoch) + train_hist.get_last_param_str())
    
    #Save image result
    with torch.no_grad():
        G.eval()
        for n, (x, _) in enumerate(train_loader_S):
            x = x.to(device)
            G_latent = G.enc_s2t(x)
            G_result = G.dec_s2t(G_latent)
            G_recon = G.dec_t2s(G_latent)
            result = torch.cat((x[0], G_result[0], G_recon[0]), 2)
            path = os.path.join(result_path, 'G_S', str(epoch+starting_epoch) + '_epoch_'  + '_train_' + str(n + 1) + '.png')
            plt.imsave(path, (result.cpu().numpy().transpose(1, 2, 0) + 1) / 2)
            if n == 2:
                break

#         for n, (x,_) in enumerate(train_loader_T):
#             x = x.to(device)
#             G_latent = G.enc_t2s(x)
#             G_result = G.dec_t2s(G_latent)
#             G_recon = G.dec_s2t(G_latent)
#             result = torch.cat((x[0],G_result[0],G_recon[0]),2)
#             path = os.path.join(result_path,'G_T',str(epoch+starting_epoch) + '_epoch_' +'_train_'+str(n+1)+'.png')
#             plt.imsave(path, (result.cpu().numpy().transpose(1, 2, 0) + 1) / 2)
#             if n == 2:
#                 break
                
        torch.save(G.state_dict(), os.path.join(result_path, 'G.pkl'))
        torch.save(D.state_dict(), os.path.join(result_path, 'D.pkl'))
        torch.save(C.state_dict(), os.path.join(result_path, 'C.pkl'))
        train_hist.save_train(os.path.join(result_path,  'train_hist.pkl'))


training start!
[41/60] - time: 222.34, G_rec loss: 46639.766, G_sem loss: 37.057, G_gan loss: 0.903, G_dann loss: 16.002, D loss: 0.003, C loss: 0.007
[42/60] - time: 223.77, G_rec loss: 46470.855, G_sem loss: 35.062, G_gan loss: 0.906, G_dann loss: 15.740, D loss: 0.003, C loss: 0.007
[43/60] - time: 224.89, G_rec loss: 46318.078, G_sem loss: 34.304, G_gan loss: 0.904, G_dann loss: 15.765, D loss: 0.003, C loss: 0.006
[44/60] - time: 225.12, G_rec loss: 46063.312, G_sem loss: 33.828, G_gan loss: 0.903, G_dann loss: 15.906, D loss: 0.003, C loss: 0.006
[45/60] - time: 226.26, G_rec loss: 45899.707, G_sem loss: 33.440, G_gan loss: 0.908, G_dann loss: 15.983, D loss: 0.004, C loss: 0.007
[46/60] - time: 225.31, G_rec loss: 45769.133, G_sem loss: 33.352, G_gan loss: 0.903, G_dann loss: 16.059, D loss: 0.003, C loss: 0.006
[47/60] - time: 226.66, G_rec loss: 45575.254, G_sem loss: 33.204, G_gan loss: 0.904, G_dann loss: 16.170, D loss: 0.003, C loss: 0.007
[48/60] - time: 225.92, G_rec lo

KeyboardInterrupt: 

In [20]:
with open(os.path.join(result_path,  'train_hist.pkl'), 'wb') as f:
    pickle.dump(train_hist, f)