In [None]:
!pip install import-ipynb
import import_ipynb
# Make sure Utils is available in the Colab files before running this import.
import Utils

In [None]:
import os, time, pickle
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.autograd import Variable
import numpy as np
from matplotlib import pyplot as plt

In [None]:
def train_model(G, D, loss_key, training_data, root, model, disc, epochnum=200, size=256, order=True, crop=True):
  """Train generator ``G`` against discriminator ``D`` as a conditional GAN on the GPU.

  Every batch is turned into an (input, target) pair by ``Utils.prepare_image``.
  ``D`` is updated on the real pair (label 1) and on the generated pair
  (label 0); ``G`` is then updated with the objective selected by ``loss_key``.
  After each epoch a preview of that epoch's first batch is produced through
  ``Utils.create_picture``. After the last epoch the weights of both networks
  and the loss/timing history are written under ``root``.

  loss_key values:
    -5: L1 only (no adversarial term); only used for final testing, reverse order
     0: BCE only
     1: BCE + 10 * L1   (small L1 weight)
     2: BCE + 500 * L1  (large L1 weight)
     3: BCE + 10 * MSE  (small L2 weight)
     4: BCE + 500 * MSE (large L2 weight)

  Args:
    G: generator module, called as ``G(x)``.
    D: discriminator module, called as ``D(x, y)``; its output is fed to
      ``nn.BCELoss`` after ``squeeze()``.
    loss_key: one of the keys listed above.
    training_data: iterable yielding ``(batch, _)`` pairs; the second item is ignored.
    root: path prefix for all outputs. It is joined by plain string
      concatenation, so it must already end with a path separator.
    model: name inserted into every output file name.
    disc: discriminator identifier. The value 70 is treated as the default and
      left out of the output names; any other value is embedded as ``GAN<disc>``.
    epochnum: number of passes over ``training_data``.
    size: forwarded to ``Utils.prepare_image``.
    order: forwarded to ``Utils.prepare_image``.
    crop: forwarded to ``Utils.prepare_image``.

  Raises:
    ValueError: if ``loss_key`` is not a known key, or if ``training_data``
      yields no batch (there is then nothing to preview).
  """
  # loss_key -> (include adversarial BCE term, reconstruction loss name, weight)
  generator_objectives = {
      -5: (False, 'L1', None),
      0: (True, None, None),
      1: (True, 'L1', 10),
      2: (True, 'L1', 500),
      3: (True, 'MSE', 10),
      4: (True, 'MSE', 500),
  }
  # Fail before touching the GPU or the data; an unknown key used to surface
  # only as "'int' object has no attribute 'backward'" after the first D step.
  if loss_key not in generator_objectives:
    raise ValueError('unknown loss_key %r; expected one of %s'
                     % (loss_key, sorted(generator_objectives)))
  use_adversarial, recon_name, recon_weight = generator_objectives[loss_key]

  # loss
  BCE_loss = nn.BCELoss().cuda()
  reconstruction_losses = {'L1': nn.L1Loss().cuda(), 'MSE': nn.MSELoss().cuda()}
  recon_loss = reconstruction_losses.get(recon_name)

  # Adam optimizer
  G_optimizer = optim.Adam(G.parameters(), lr=0.002, betas=(0.5, 0.99))
  D_optimizer = optim.Adam(D.parameters(), lr=0.002, betas=(0.5, 0.99))

  train_hist = {
      'D_losses': [],
      'G_losses': [],
      'per_epoch_ptimes': [],
      'total_ptime': [],
  }

  # Sub-directory of Fixed_results/ that receives the per-epoch previews.
  if disc == 70:
    preview_dir = root + 'Fixed_results/loss_key' + str(loss_key) + '/'
  else:
    preview_dir = root + 'Fixed_results/loss_key' + str(loss_key) + '_GAN' + str(disc) + '/'

  # First batch of the most recent epoch that produced one; used for the preview.
  fixed_x_ = fixed_y_ = None

  print('training start!')
  start_time = time.time()
  for epoch in range(epochnum):
    epoch_start_time = time.time()
    for num_iter, (x_, _) in enumerate(training_data, start=1):
      # ---- train discriminator D ----
      D.zero_grad()

      x_, y_ = Utils.prepare_image(x_, size=size, order=order, crop=crop)
      x_, y_ = x_.cuda(), y_.cuda()

      D_result = D(x_, y_).squeeze()
      D_real_loss = BCE_loss(D_result, torch.ones_like(D_result))

      # Detach the fake batch: the D update must not backpropagate into G.
      G_result = G(x_).detach()
      D_result = D(x_, G_result).squeeze()
      D_fake_loss = BCE_loss(D_result, torch.zeros_like(D_result))

      D_train_loss = (D_real_loss + D_fake_loss) * 0.5
      D_train_loss.backward()
      D_optimizer.step()
      train_hist['D_losses'].append(D_train_loss.detach().cpu())

      # ---- train generator G ----
      G.zero_grad()

      G_result = G(x_)
      D_result = D(x_, G_result).squeeze()

      if use_adversarial:
        G_train_loss = BCE_loss(D_result, torch.ones_like(D_result))
        if recon_loss is not None:
          G_train_loss = G_train_loss + recon_weight * recon_loss(G_result, y_)
      else:
        G_train_loss = recon_loss(G_result, y_)
      G_train_loss.backward()
      G_optimizer.step()
      train_hist['G_losses'].append(G_train_loss.detach().cpu())

      if num_iter % 50 == 0:
        print(num_iter)
      if num_iter == 1:
        fixed_x_, fixed_y_ = x_, y_

    per_epoch_ptime = time.time() - epoch_start_time

    print('epoch ', epoch+1, ' finished')
    if fixed_x_ is None:
      raise ValueError('training_data yielded no batches; nothing to train on or preview')
    fixed_p = preview_dir + model + str(epoch + 1) + '.png'
    # The removed ``volatile=True`` flag is a no-op in current PyTorch;
    # no_grad() is its replacement for inference-only forward passes.
    with torch.no_grad():
      Utils.create_picture(G, fixed_x_, fixed_y_, (epoch + 1), fixed_p)
    train_hist['per_epoch_ptimes'].append(per_epoch_ptime)

  total_ptime = time.time() - start_time
  train_hist['total_ptime'].append(total_ptime)

  print("Training finish!... save training results")
  # Output names embed the discriminator id unless it is the default (70).
  if disc == 70:
    out_prefix = root + model + str(loss_key)
  else:
    out_prefix = root + model + str(loss_key) + 'GAN' + str(disc)
  torch.save(G.state_dict(), out_prefix + 'generator_param.pkl')
  torch.save(D.state_dict(), out_prefix + 'discriminator_param.pkl')
  with open(out_prefix + 'train_hist.pkl', 'wb') as f:
    pickle.dump(train_hist, f)
  if disc == 70:
    # Loss curves are only plotted for the default discriminator;
    # epochnum - 1 is the index of the final epoch.
    Utils.plot_results(train_hist, root, model, epochnum - 1, loss_key)

  print('results_saved')

In [None]:
def test_images(dataset, generator, test_data, key, order=True):
  """Run ``generator`` over ``test_data`` and save input/output/target PNGs.

  Each batch is split along its last axis into two halves; one half is fed to
  the generator and the other is saved as the target. Only the first image of
  every batch is written. Files go to
  ``<dataset>_results/test_results/loss_key<key>/`` as ``<n>_input.png``,
  ``<n>_output.png`` and ``<n>_target.png`` with ``n`` counting batches from 1.
  The directory is created if it does not exist.

  Args:
    dataset: prefix of the results directory.
    generator: callable mapping an input batch to an output batch.
    test_data: iterable yielding ``(batch, _)`` pairs; the second item is ignored.
    key: loss key, only used to name the output directory.
    order: if True the left half is the target and the right half the input;
      if False the roles are swapped.
  """
  out_dir = dataset + '_results/test_results/loss_key' + str(key) + '/'
  # plt.imsave does not create missing directories.
  os.makedirs(out_dir, exist_ok=True)

  def to_image(tensor):
    # Channel-first -> channel-last, then shift values by +1 and halve them
    # (assumes tensors are normalised to [-1, 1] -- TODO confirm).
    return (tensor.detach().cpu().numpy().transpose(1, 2, 0) + 1) / 2

  for n, (x_, _) in enumerate(test_data, start=1):
    # The split column is taken from axis 2, so each half is as wide as the
    # batch is tall (assumes square halves -- TODO confirm).
    size = x_.size()[2]
    left, right = x_[:, :, :, 0:size], x_[:, :, :, size:]
    if order:
      y_, x_ = left, right
    else:
      y_, x_ = right, left

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
      test_image = generator(x_)

    plt.imsave(out_dir + str(n) + '_input.png', to_image(x_[0]))
    plt.imsave(out_dir + str(n) + '_output.png', to_image(test_image[0]))
    plt.imsave(out_dir + str(n) + '_target.png', to_image(y_[0]))


In [None]:
def test_models_disc(generator, discriminator, test_load, order = True):
  size = 256
  n= 0
  gen_loss = 0
  for x_,_ in test_load:
    if order:
      x_ = x_[:, :, :, size:]
    else:
      x_ = x_[:, :, :, 0:size]
    y_ = generator(x_)
    lossx = discriminator(x_,y_).detach().numpy()
    lossxm = np.mean(lossx)
    gen_loss += lossxm
    n += 1
    

  return gen_loss/n

In [None]:
def test_models_inv(generator_1, generator_2, test_loader, order = True):
  size = 256
  L1_loss = nn.L1Loss()
  n= 0
  gen_loss = 0
  for x_,_ in test_loader:
    if order:
      x_ = x_[:, :, :, size:]
    else:
      x_ = x_[:, :, :, 0:size]
    y_ = generator_1(x_)
    xn = generator_2(y_)
    lossxt = L1_loss(x_,xn)
    gen_loss += lossxt
    n += 1
  loss = gen_loss.item()

  return loss/n