In [1]:
import os, time, pickle
from lib import networks, utils
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import transforms
from torchvision import datasets

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.backends.cudnn.enabled:
    torch.backends.cudnn.benchmark = True
device

device(type='cuda')

### Hyper-parameters

In [28]:
#input channel for discriminator
in_ngc=3
#output channel for discriminator
out_ngc=3
#generator first layer number of filters
ngf=16
#input channel for discriminator
in_ndc=3
#output channel for discriminator
out_ndc=1
#discriminator first layer number of filters
ndf=32
#latent vector length
latent_len = 1024

batch_size=10
input_size=128
train_epoch=20

#learning rate, default=0.0002
lrD=0.0002
lrG=0.0004
lrC = 0.0002

#generator lambda
rec_lambda=1
sem_lambda=1
dann_lambda=1

#beta for Adam optimizer
beta1=0.5
beta2=0.999

### Folder structure

In [23]:
# results save path
project_name = 'XGAN_1'
result_path = project_name+'_results'
src_result_name = 'G_S'
tgt_result_name = 'G_T'

data_path = 'data'
src_data_path = os.path.join(data_path,'src_data_path_new')
tgt_data_path = os.path.join(data_path,'tgt_data_path')

#ensure data folder exists
if not os.path.isdir(os.path.join(result_path, src_result_name)):
    os.makedirs(os.path.join(result_path, src_result_name))
if not os.path.isdir(os.path.join(result_path, tgt_result_name)):
    os.makedirs(os.path.join(result_path, tgt_result_name))
if not os.path.isdir(os.path.join(src_data_path,'train')):
    os.makedirs(os.path.join(src_data_path,'train'))
if not os.path.isdir(os.path.join(tgt_data_path,'train')):
    os.makedirs(os.path.join(tgt_data_path,'train'))

### Load Data

In [29]:
# data_loader
transform = transforms.Compose([
        transforms.Resize((input_size, input_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])
# train_loader_S = utils.data_load(os.path.join('data', src_data), 'train', transform, batch_size, shuffle=True, drop_last=True)
# train_loader_T = utils.data_load(os.path.join('data', tgt_data), 'train', transform, batch_size, shuffle=True, drop_last=True)
train_loader_S = torch.utils.data.DataLoader(datasets.ImageFolder(src_data_path, transform), batch_size=batch_size, shuffle=True, drop_last=True)
train_loader_T = torch.utils.data.DataLoader(datasets.ImageFolder(tgt_data_path, transform), batch_size=batch_size, shuffle=True, drop_last=True)


### Model define

In [15]:
# network
G = networks.xgan_generator(in_ngc,out_ngc,ngf)
D = networks.discriminator(in_ndc, out_ndc, ndf)
C = networks.xgan_classifier(latent_len)

G.to(device)
D.to(device)
C.to(device)

xgan_classifier(
  (classifier): Linear(in_features=1024, out_features=2, bias=True)
)

In [16]:
# loss
MSE_loss = nn.MSELoss().to(device)
L1_loss = nn.L1Loss().to(device)
Cross_Entropy_loss = nn.CrossEntropyLoss().to(device)

In [32]:
# optimizer
G_optimizer = optim.Adam(G.parameters(), lr=lrG, betas=(beta1, beta2))
D_optimizer = optim.Adam(D.parameters(), lr=lrD, betas=(beta1, beta2))
C_optimizer = optim.Adam(C.parameters(), lr=lrC, betas=(beta1,beta2))

In [18]:
# train history
train_hist = {}
train_hist['per_epoch_time'] = []
train_hist['total_time'] = []
train_hist['G_loss']=[]
train_hist['D_loss']=[]
train_hist['C_loss']=[]

### Load existing model parameters

In [None]:
G.load_state_dict(torch.load(os.path.join(result_path, 'G.pkl')))
C.load_state_dict(torch.load(os.path.join(result_path, 'C.pkl')))
D.load_state_dict(torch.load(os.path.join(result_path, 'D.pkl')))

### Load train hist

In [None]:
with open(os.path.join(result_path, 'train_hist.pkl'), 'rb') as pickle_file:
    train_hist = pickle.load(pickle_file)

### Starting epoch

In [33]:
#starting_epoch is used to avoid overriding of the previously generated results
starting_epoch = 9

### Train

In [34]:
print('training start!')
start_time = time.time()
num_pool = 50
fake_pool = utils.ImagePool(num_pool)
for epoch in range(train_epoch):
    epoch_start_time = time.time()
    print("==> Epoch {}/{}".format(starting_epoch+epoch + 1, starting_epoch+train_epoch))
#     if (epoch + 1) > decay_epoch:
#         D_A_optimizer.param_groups[0]['lr'] -= lrD / 10
#         D_B_optimizer.param_groups[0]['lr'] -= lrD / 10
#         G_optimizer.param_groups[0]['lr'] -= lrG / 10
    
    G_losses = []
    D_losses = []
    C_losses = []
    for (real_S,_),(real_T,_) in zip(train_loader_S, train_loader_T):
        G.train()

        # input image data
        real_S = real_S.to(device)
        real_T = real_T.to(device)

        # Train generator G
        # S->T
        real_S_latent = G.enc_s2t(real_S)     
        real_S_recon = G.dec_t2s(real_S_latent)
        fake_T = G.dec_s2t(real_S_latent)
        fake_T_latent = G.enc_t2s(fake_T)
              
        #T->S
        real_T_latent = G.enc_t2s(real_T)     
        real_T_recon = G.dec_s2t(real_T_latent)      
        fake_S = G.dec_t2s(real_T_latent)
        fake_S_latent = G.enc_s2t(fake_S)
        
        # rec loss
        G_S_rec_loss = L1_loss(real_S, real_S_recon)
        G_T_rec_loss = L1_loss(real_T, real_T_recon)
        
        # semantic loss
        G_S_sem_loss = MSE_loss(real_S_latent,fake_T_latent)
        G_T_sem_loss = MSE_loss(real_T_latent, fake_S_latent)
        
        # gan loss, only do for S->T to save computing
        D_decision = D(fake_T)
        G_gan_loss = MSE_loss(D_decision, torch.ones(D_decision.size(), device=device))

        # domain adversarial loss
        # source should be 0, so we want it to classify as 1, target should be 1, so we want it to classify as 0
        C_S_decision = C(real_S_latent)
        C_T_decision = C(real_T_latent)
        G_dann_loss = Cross_Entropy_loss(C_S_decision, torch.ones(C_S_decision.shape[0], dtype=torch.long, device=device)) + Cross_Entropy_loss(C_T_decision, torch.zeros(C_T_decision.shape[0],dtype=torch.long,device=device))     
        
        #fix D and C parameters
        for model in [D, C]:
            for param in model.parameters():
                param.requires_grad = False
        
        # Back propagation
        G_rec_loss = 0.5*(G_S_rec_loss + G_T_rec_loss)
        G_sem_loss = 0.5*(G_S_sem_loss + G_T_sem_loss)
        G_loss = G_gan_loss + rec_lambda*G_rec_loss + sem_lambda*G_sem_loss + dann_lambda*G_dann_loss
        G_losses.append(G_loss)
        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        #train D and C parameters
        for model in [D, C]:
            for param in model.parameters():
                param.requires_grad = True
                
        # Train discriminator D
        D_real_decision = D(real_S)
        D_real_loss = MSE_loss(D_real_decision, torch.ones(D_real_decision.size(), device=device))
        fake_T = fake_pool.query(fake_T.detach())
        D_fake_decision = D(fake_T)
        D_fake_loss = MSE_loss(D_fake_decision, torch.zeros(D_fake_decision.size(), device=device))
        D_loss = (D_real_loss + D_fake_loss) * 0.5
        D_losses.append(D_loss)
        D_optimizer.zero_grad()
        D_loss.backward()
        D_optimizer.step()

        #Train classifier C
        real_S_latent = G.enc_s2t(real_S)
        real_T_latent = G.enc_t2s(real_T)
        C_S_decision = C(real_S_latent.detach())
        C_T_decision = C(real_T_latent.detach())
        C_loss = 0.5*(Cross_Entropy_loss(C_S_decision, torch.zeros(C_S_decision.shape[0],dtype=torch.long,device=device)) + Cross_Entropy_loss(C_T_decision, torch.ones(C_T_decision.shape[0],dtype=torch.long,device=device)))
        C_losses.append(C_loss)
        C_optimizer.zero_grad()
        C_loss.backward()
        C_optimizer.step()
        
    per_epoch_time = time.time() - epoch_start_time
    train_hist['per_epoch_time'].append(per_epoch_time)
    
    G_loss_avg = float(torch.mean(torch.FloatTensor(G_losses)).cpu().numpy())
    D_loss_avg = float(torch.mean(torch.FloatTensor(D_losses)).cpu().numpy())
    C_loss_avg = float(torch.mean(torch.FloatTensor(C_losses)).cpu().numpy())

    train_hist['G_loss'].append(G_loss_avg)
    train_hist['D_loss'].append(D_loss_avg)
    train_hist['C_loss'].append(C_loss_avg)

    
    print(
    '[%d/%d] - time: %.2f, G loss: %.3f, D loss: %.3f, C loss: %.3f' % ((epoch + 1), train_epoch, per_epoch_time, G_loss_avg, D_loss_avg, C_loss_avg))
    
    #Save image result
    with torch.no_grad():
        G.eval()
        for n, (x, _) in enumerate(train_loader_S):
            x = x.to(device)
            G_latent = G.enc_s2t(x)
            G_result = G.dec_s2t(G_latent)
            G_recon = G.dec_t2s(G_latent)
            result = torch.cat((x[0], G_result[0], G_recon[0]), 2)
            path = os.path.join(result_path, 'G_S', str(epoch+starting_epoch) + '_epoch_'  + '_train_' + str(n + 1) + '.png')
            plt.imsave(path, (result.cpu().numpy().transpose(1, 2, 0) + 1) / 2)
            if n == 4:
                break

        for n, (x,_) in enumerate(train_loader_T):
            x = x.to(device)
            G_latent = G.enc_t2s(x)
            G_result = G.dec_t2s(G_latent)
            G_recon = G.dec_s2t(G_latent)
            result = torch.cat((x[0],G_result[0],G_recon[0]),2)
            path = os.path.join(result_path,'G_T',str(epoch+starting_epoch) + '_epoch_' +'_train_'+str(n+1)+'.png')
            plt.imsave(path, (result.cpu().numpy().transpose(1, 2, 0) + 1) / 2)
            if n == 4:
                break
                
        torch.save(G.state_dict(), os.path.join(result_path, 'G.pkl'))
        torch.save(D.state_dict(), os.path.join(result_path, 'D.pkl'))
        torch.save(C.state_dict(), os.path.join(result_path, 'C.pkl'))

total_time = time.time() - start_time
train_hist['total_time'].append(total_time)
with open(os.path.join(result_path,  'train_hist.pkl'), 'wb') as f:
    pickle.dump(train_hist, f)

training start!
==> Epoch 10/29




[1/20] - time: 173.65, G loss: 205703.672, D loss: 0.000, C loss: 0.372
==> Epoch 11/29
[2/20] - time: 173.36, G loss: 420135.969, D loss: 0.000, C loss: 1.610
==> Epoch 12/29


KeyboardInterrupt: 