In [None]:
import torch
import torch.nn as nn

import argparse
import os
from math import log10
import json

import torch.optim as optim
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms


from statistics import mean
  

from torch.nn import init
import functools
import itertools

from torch.autograd import Variable
from torch.optim import lr_scheduler
import numpy as np
import time
from generators.generators import create_gen
from discriminators.discriminators import create_disc
from losses.Loss import GANLoss
from datasets.datasets import get_dataset
from util import ImagePool, set_requires_grad,tensor_to_plt,init_weights, mkdir
from Tensorboard_Logger import Logger

In [None]:
def get_scheduler(optimizer,opt):
    """Attach a linear-decay learning-rate schedule to ``optimizer``.

    The multiplier applied to the optimizer's base learning rate stays at 1.0
    while ``epoch + 1 + opt.epoch_count`` does not exceed ``opt.iter_constant``
    and then drops by ``1 / (opt.iter_decay + 1)`` for every further epoch.

    Args:
        optimizer: the torch optimizer whose learning rate is scheduled.
        opt: object exposing ``epoch_count``, ``iter_constant`` and
            ``iter_decay``; they are read each time the schedule is evaluated.

    Returns:
        A ``lr_scheduler.LambdaLR`` bound to ``optimizer``.
    """
    def linear_decay(epoch):
        # Number of epochs by which we are past the constant-rate phase
        # (negative or zero while still inside it).
        overshoot = epoch + 1 + opt.epoch_count - opt.iter_constant
        return 1.0 - max(0, overshoot) / float(opt.iter_decay + 1)

    return lr_scheduler.LambdaLR(optimizer, lr_lambda=linear_decay)

class Args():
    '''
    Hyperparameter container for one CycleGAN experiment.

    See Pix2Pix.ipynb for hyperparameter details.

    Every attribute has a default; any of them can be replaced at construction
    time with a keyword argument, e.g. ``Args(lr=1e-4, folder_name="run2")``.
    Unknown keyword names raise ``TypeError`` so that a typo cannot silently
    create an attribute nothing reads.
    '''
    def __init__(self, **overrides):
        # NOTE: attribute order is kept stable on purpose; the hyperparameter
        # dump serialises ``__dict__`` and therefore follows this order.
        self.batch_size = 32            # training DataLoader batch size
        self.test_batch_size = 32       # batch size of the test / preview DataLoaders
        self.input_dim = 3              # channel count handed to the A-side networks
        self.output_dim = 1             # channel count handed to the B-side networks
        self.gen_filters = 64           # passed to create_gen
        self.disc_filters = 64          # not read by the visible training code -- TODO confirm use
        self.total_iters = 6            # last epoch of the training loop (inclusive)
        self.epoch_count = 1            # first epoch of the training loop
        self.iter_constant = 200        # epochs before the learning rate starts decaying
        self.iter_decay = 200           # controls how fast the learning rate decays afterwards
        self.lr = 0.0002                # Adam learning rate
        self.beta1 = 0.5                # Adam beta1
        self.cuda = True                # train on "cuda:0" when True, otherwise on CPU
        self.threads = 8                # DataLoader worker count
        self.seed = 123                 # random seed
        self.lamb = 100                 # not read by the visible training code -- TODO confirm use
        self.use_ls = True              # not read by the visible training code -- TODO confirm use
        self.resblocks = 9              # not read by the visible training code -- TODO confirm use
        self.norm = "instance"          # normalisation name passed to create_gen
        self.dropout = False            # not read by the visible training code -- TODO confirm use
        self.gen = "Resnet"             # generator name; also used as the checkpoint file name
        self.disc = "Global"            # discriminator name passed to create_disc
        self.paired_dataset = False     # presumably consumed by get_dataset -- TODO confirm
        self.dataset_name = "sketchy"   # sub-folder of ./data holding the dataset
        self.folder_name = "12345"      # logger name and output folder of the run

        self.lambda_recon = 10 #reconstruction loss weight
        self.pool_size = 50 #how many images we store in our image pool, that keeps track of past images

        unknown = sorted(set(overrides) - set(self.__dict__))
        if unknown:
            raise TypeError("Unknown hyperparameter(s): " + ", ".join(unknown))
        for name, value in overrides.items():
            # Assigning to an existing key keeps its position in __dict__.
            setattr(self, name, value)


In [None]:

class Train_CycleGan:
    """Trainer for a CycleGAN between domain A (logged as photos) and domain B (logged as sketches).

    ``G_ab`` maps A to B and ``G_ba`` maps B to A.  ``D_a`` / ``D_b`` score
    images of their domain and are trained with a least-squares (MSE) target of
    1 for real and 0 for generated images.  Per-epoch mean losses are collected
    in ``gen_loss`` and ``disc_loss``.
    """

    def __init__(self,opt,traindataset,testdataset):
        """Build loaders, networks, losses, optimizers and schedulers.

        Args:
            opt: hyperparameter object (see ``Args``).
            traindataset: dataset whose batches unpack into three tensors; the
                first two are used as the domain A / domain B training inputs.
            testdataset: dataset with the same item layout, used only to draw
                one fixed preview batch.
        """
        self.dataset = DataLoader(dataset=traindataset, batch_size=opt.batch_size, shuffle=True,num_workers=opt.threads)
        self.test_set = DataLoader(dataset=testdataset, batch_size=opt.test_batch_size, shuffle=True,num_workers=opt.threads)
        # One fixed batch per split, reused after every epoch for the preview images.
        self.atest, self.btest,self.btestreal = next(iter(self.test_set))
        self.dataviz = DataLoader(dataset=traindataset, batch_size=opt.test_batch_size, shuffle=True,num_workers=opt.threads)
        self.atrain, self.btrain,self.btrainreal = next(iter(self.dataviz))

        self.device = torch.device("cuda:0" if opt.cuda else "cpu")
        self.writer = Logger(opt.folder_name)
        self.writer.write_photo_to_tb(self.atest,"photos test")
        self.writer.write_sketch_to_tb(self.btestreal,"sketches test")
        self.writer.write_photo_to_tb(self.atrain,"photos train")
        # NOTE(review): the test preview logs ``btestreal`` while the train
        # preview logs ``btrain`` -- confirm this asymmetry is intended.
        self.writer.write_sketch_to_tb(self.btrain,"sketches train")

        self.G_ab = create_gen(opt.gen,opt.input_dim,opt.output_dim,opt.gen_filters,opt.norm)
        self.G_ab.to(self.device)
        init_weights(self.G_ab)

        self.G_ba = create_gen(opt.gen,opt.output_dim,opt.input_dim,opt.gen_filters,opt.norm)
        self.G_ba.to(self.device)
        init_weights(self.G_ba)

        self.D_b = create_disc(opt.disc,opt.output_dim,use_sigmoid=False) #discriminator for sketches
        self.D_b.to(self.device)
        init_weights(self.D_b)

        self.D_a = create_disc(opt.disc,opt.input_dim,use_sigmoid=False) #discriminator for images
        self.D_a.to(self.device)
        init_weights(self.D_a)

        self.MSE = nn.MSELoss().to(self.device)
        self.L1 = nn.L1Loss().to(self.device)

        # Both generators share one optimizer, both discriminators another.
        self.optimizer_G = torch.optim.Adam(itertools.chain(self.G_ab.parameters(),
                                                            self.G_ba.parameters()),
                                            lr=opt.lr, betas=(opt.beta1, 0.999))
        self.optimizer_D = torch.optim.Adam(itertools.chain(self.D_a.parameters(),
                                                            self.D_b.parameters()),
                                            lr=opt.lr, betas=(opt.beta1, 0.999))

        self.optimizers = [self.optimizer_G, self.optimizer_D]
        self.schedulers = [get_scheduler(optimizer,opt) for optimizer in self.optimizers]

        # Per-epoch histories; l1_loss and gan_loss are never filled by train().
        self.gen_loss = []
        self.disc_loss = []
        self.l1_loss = []
        self.gan_loss = []

        self.fake_A_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images
        self.fake_B_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images

    def train(self,opt):
        """Run the training loop for epochs ``opt.epoch_count .. opt.total_iters`` (inclusive).

        Each batch performs one generator update followed by one discriminator
        update.  After every epoch the schedulers are stepped, the mean losses
        are recorded and preview images are written to the logger.  When all
        epochs are done the loss curves are plotted through the logger.
        """
        for epoch in range(opt.epoch_count, opt.total_iters+1):
            lossdlist = []
            lossglist = []
            t1 = time.time()

            for i, batch in enumerate(self.dataset):
                print("training epoch ",epoch,"batch", i,"/",len(self.dataset))
                real_A, real_B = batch[0].to(self.device), batch[1].to(self.device)

                fake_A, fake_B, gen_loss = self._generator_step(real_A, real_B, opt)
                lossglist.append(gen_loss)
                lossdlist.append(self._discriminator_step(real_A, real_B, fake_A, fake_B))

            #update_learning_rate()
            for scheduler in self.schedulers:
                scheduler.step()
            lr = self.optimizers[0].param_groups[0]['lr']
            print('learning rate = %.7f' % lr)
            t2 = time.time()
            diff = t2-t1
            epoch_d_loss = mean(lossdlist)
            epoch_g_loss = mean(lossglist)
            print("iteration:",epoch,"loss D:", epoch_d_loss,"loss G:", epoch_g_loss)
            print("Took ", diff, "seconds")
            print("Estimated time left:", diff*(opt.total_iters - epoch))

            self.gen_loss.append(epoch_g_loss)
            self.disc_loss.append(epoch_d_loss)

            self._log_samples(epoch)

        self.writer.plot_losses(self.gen_loss,self.disc_loss,[])

    def _generator_step(self, real_A, real_B, opt):
        """Update both generators on one batch; return ``(fake_A, fake_B, loss_value)``."""
        # The discriminators only act as fixed critics during this step.
        set_requires_grad(nets=self.D_a, requires_grad=False)
        set_requires_grad(nets=self.D_b, requires_grad=False)
        self.optimizer_G.zero_grad()

        fake_A = self.G_ba(real_B)
        fake_B = self.G_ab(real_A)
        recon_A = self.G_ba(fake_B)
        recon_B = self.G_ab(fake_A)

        pred_fake_A = self.D_a(fake_A)
        pred_fake_B = self.D_b(fake_B)

        # Targets are shaped after the prediction they are compared with, so a
        # difference between the two discriminators' output shapes cannot be
        # silently broadcast by MSELoss.
        gen_loss_A = self.MSE(pred_fake_A, torch.ones_like(pred_fake_A))
        gen_loss_B = self.MSE(pred_fake_B, torch.ones_like(pred_fake_B))

        cycle_loss_A = self.L1(recon_A, real_A) * opt.lambda_recon
        cycle_loss_B = self.L1(recon_B, real_B) * opt.lambda_recon

        gen_loss = gen_loss_A + gen_loss_B + cycle_loss_A + cycle_loss_B
        gen_loss.backward()
        self.optimizer_G.step()
        return fake_A, fake_B, gen_loss.item()

    def _discriminator_step(self, real_A, real_B, fake_A, fake_B):
        """Update both discriminators on one batch; return the mean of their two losses."""
        set_requires_grad(nets=self.D_a, requires_grad=True)
        set_requires_grad(nets=self.D_b, requires_grad=True)
        self.optimizer_D.zero_grad()

        # Detach before pooling: the buffer outlives this iteration and must not
        # keep the generators' autograd graphs alive, however ImagePool stores
        # its entries.
        fake_A = self.fake_A_pool.query(fake_A.detach()) #.to(self.device)
        fake_B = self.fake_B_pool.query(fake_B.detach()) #.to(self.device)

        pred_real_A = self.D_a(real_A)
        pred_fake_A = self.D_a(fake_A.detach())
        pred_real_B = self.D_b(real_B)
        pred_fake_B = self.D_b(fake_B.detach())

        a_dis_real_loss = self.MSE(pred_real_A, torch.ones_like(pred_real_A))
        a_dis_fake_loss = self.MSE(pred_fake_A, torch.zeros_like(pred_fake_A))
        b_dis_real_loss = self.MSE(pred_real_B, torch.ones_like(pred_real_B))
        b_dis_fake_loss = self.MSE(pred_fake_B, torch.zeros_like(pred_fake_B))

        a_dis_loss = (a_dis_real_loss + a_dis_fake_loss)*0.5
        b_dis_loss = (b_dis_real_loss + b_dis_fake_loss)*0.5

        a_dis_loss.backward()
        b_dis_loss.backward()
        self.optimizer_D.step()

        return a_dis_loss.item()*.5 + b_dis_loss.item()*0.5

    def _log_samples(self, epoch):
        """Write G_ab outputs for the fixed train and test preview batches to the logger."""
        with torch.no_grad():
            out1 = self.G_ab(self.atrain.to(self.device))
            title= "Epoch "+str(epoch) +"Training"
            self.writer.write_sketch_to_tb(out1.detach(),title)

            out2 = self.G_ab(self.atest.to(self.device))
            title= "Epoch "+str(epoch)
            self.writer.write_sketch_to_tb(out2.detach(),title)

    @staticmethod
    def _unwrap(net):
        """Return ``net.module`` for wrapped networks (e.g. DataParallel), else ``net`` itself."""
        return getattr(net, "module", net)

    def save_model(self,folderpath,modelpath):
        """Save network and optimizer state dicts as one checkpoint at ``modelpath``.

        ``folderpath`` is created first via ``mkdir``.  Networks are unwrapped
        before saving, so the checkpoint has the same keys whether or not the
        networks are wrapped in a ``.module`` container.
        """
        mkdir(folderpath)
        torch.save({
            'genAB': self._unwrap(self.G_ab).state_dict(),
            'genBA': self._unwrap(self.G_ba).state_dict(),
            'discA': self._unwrap(self.D_a).state_dict(),
            'discB': self._unwrap(self.D_b).state_dict(),
            'optimizerG_state_dict': self.optimizer_G.state_dict(),
            'optimizerD_state_dict': self.optimizer_D.state_dict(),

            }, modelpath)

    def save_arrays(self,path):
        """Save the recorded loss histories as ``.npy`` files inside ``path``."""
        np.save( os.path.join(path,"ganloss"),np.asarray(self.gen_loss))
        np.save( os.path.join(path,"discloss"),np.asarray(self.disc_loss))
        np.save( os.path.join(path,"l1loss"),np.asarray(self.l1_loss))

    def save_hyper_params(self,folderpath,opt):
        """Write the attributes of ``opt`` as JSON to ``params.txt`` inside ``folderpath``."""
        with open(os.path.join(folderpath,'params.txt'), 'w') as file:
             file.write(json.dumps(opt.__dict__))
        
        

In [None]:
opt = Args()

# Args declares a seed; apply it to the global torch and numpy RNGs before any
# dataset or loader is created, so that a Restart & Run All draws the same
# random numbers from those generators on every run.
torch.manual_seed(opt.seed)
np.random.seed(opt.seed)

# Dataset layout: ./data/<dataset_name>/{train,test}/{photo,sketch}
data_root = os.path.join(os.getcwd(),"data",opt.dataset_name)

photo_path_train = os.path.join(data_root,"train", "photo")
sketch_path_train = os.path.join(data_root,"train", "sketch")
# Only the training split is built with erase=True.
train_set = get_dataset(photo_path_train,sketch_path_train, opt,flip=False,jitter=False,erase=True)

photo_path_test = os.path.join(data_root,"test", "photo")
sketch_path_test = os.path.join(data_root,"test", "sketch")
testing_set =  get_dataset(photo_path_test,sketch_path_test, opt,flip=False,jitter=False,erase=False)

In [None]:
# Each entry of ``exps`` is trained from scratch and saved under its own folder.
exps = [opt]
for option in exps:
    experiment = Train_CycleGan(option,train_set,testing_set)
    experiment.train(option)
    folderpath = os.path.join(os.getcwd(),option.folder_name)
    # The checkpoint file is named after the generator architecture.
    model_path = os.path.join(folderpath,option.gen)
    experiment.save_model(folderpath,model_path)
    experiment.save_arrays(folderpath)
    # Save the hyperparameters of THIS experiment (the loop variable), not the
    # global ``opt``; they differ as soon as ``exps`` holds more than one entry.
    experiment.save_hyper_params(folderpath,option)