In [2]:
%run Include.ipynb
%run Net.ipynb
%run Data.ipynb
%run Topo_treatment.ipynb
%run Viewer.ipynb
        
class GAN(object):
    """Generator/discriminator pair trained adversarially, with an optional topology hook.

    Both networks are built from layer specs via ``Net.parse_layers`` and
    ``Network_template``; ``GANLoss`` computes the D and G losses and ``Edges_``
    provides the topology utilities used when ``withTopo`` is enabled.
    Runtime settings (device, dataset, checkpoint/image paths, logging cadence)
    are read from the global ``FLAGS`` object.
    """

    def __init__(self, general, adv_params, G_arch, D_arch):
        """Build both networks, the loss module and the two Adam optimizers.

        Args:
            general: dict with "learning_rate", "beta1", "beta2" (Adam settings
                shared by both optimizers) and "loss" (mode passed to ``GANLoss``).
            adv_params: dict; ``adv_params["wgangp"]["N_CRITIC"]`` is the number
                of discriminator steps per generator step. The whole dict is
                also handed to ``Edges_``.
            G_arch, D_arch: architecture specs understood by ``Net.parse_layers``.
        """
        lr        = general["learning_rate"]
        beta1     = general["beta1"]
        beta2     = general["beta2"]
        loss_mode = general["loss"]

        self.N_critic = adv_params["wgangp"]["N_CRITIC"]

        # Process-wide cuDNN switch; affects every model, not just this one.
        cudnn.benchmark = FLAGS.cudnn_benchmark
        gpu_num     = FLAGS.gpu_num
        self.device = torch.device("cuda:0" if torch.cuda.is_available()
                                   and FLAGS.gpu_enable else "cpu")
        # Torch is seeded from Python's `random`, so runs are only reproducible
        # if `random` itself was seeded beforehand.
        torch.manual_seed(random.randint(1, 10000))

        # inputG_dims / inputD_dims are concatenated after a batch size to form
        # a tensor shape (see train), so they are sequences of ints.
        self.inputG_dims, G_layers = Net.parse_layers(G_arch)
        self.inputD_dims, D_layers = Net.parse_layers(D_arch)
        self.netG = Network_template(gpu_num, G_layers).to(self.device)
        Net.init_weights(self.netG, "normal")
        self.netD = Network_template(gpu_num, D_layers).to(self.device)
        Net.init_weights(self.netD, "normal")

        self.et         = Edges_(adv_params, debug=False)
        self.criterion  = GANLoss(loss_mode).to(self.device)
        self.optimizerD = optim.Adam(self.netD.parameters(), lr=lr, betas=(beta1, beta2))
        self.optimizerG = optim.Adam(self.netG.parameters(), lr=lr, betas=(beta1, beta2))

    def sample_(self, shape):
        """Return ``netG`` applied to a standard-normal latent batch of ``shape``.

        The latent tensor is created on ``self.device``. The output keeps its
        autograd graph, so it can be used directly for a generator update.
        """
        z_ = torch.randn(shape, device=self.device)
        return self.netG(z_)

    def D_iteration(self, Dreal_device, Dfake_device):
        """Run one discriminator update and return the D loss as a Python float.

        NOTE(review): ``Dfake_device`` comes straight from ``sample_`` and is
        still attached to netG's graph. Unless ``GANLoss`` detaches it
        internally, this backward pass also accumulates gradients into netG.
        Those are discarded by ``netG.zero_grad()`` at the start of
        ``G_iteration``, so they are wasted work rather than a correctness issue.
        """
        self.netD.zero_grad()
        errD = self.criterion(["D", self.netD, self.device, Dreal_device, Dfake_device])
        errD.backward()
        self.optimizerD.step()
        return errD.item()

    def G_iteration(self, Dfake_device, withTopo):
        """Run one generator update and return the G loss as a Python float.

        Args:
            Dfake_device: generator output for this step. It must still be
                attached to netG's autograd graph so the loss reaches netG.
            withTopo: when True, also run ``self.et.fix_with_topo`` on a CPU
                NumPy copy of the samples. The PD pool it relies on is only
                loaded by ``train`` in that mode.
        """
        self.netG.zero_grad()
        errG = self.criterion(["G", self.netD, Dfake_device])

        if withTopo:
            # TODO: the topology-corrected batch is not yet fed into the
            # generator loss. An earlier draft backpropagated
            # `self.et.test(Dfake_device, self.device)` here instead.
            # The trailing (-1.0, 1.0) presumably is the sample value range;
            # confirm against Edges_.fix_with_topo.
            self.et.fix_with_topo(Dfake_device.detach().cpu().numpy(), 1, -1.0, 1.0)

        errG.backward()
        self.optimizerG.step()
        return errG.item()

    def train(self, data_params, withTopo, resume_step=5000):
        """Run the adversarial training loop over ``FLAGS.dataset``.

        Args:
            data_params: dict with "epochs", "batch_size", "batch_workers",
                "shuffle" and "drop_last"; all but "epochs" are forwarded to
                ``Data_fetcher.fetch_dataset``.
            withTopo: load ``self.et``'s PD pool from ``FLAGS.pds_path`` and
                enable the topology hook in ``G_iteration``.
            resume_step: checkpoint step loaded from ``FLAGS.model_save`` when
                ``FLAGS.continue_model`` is set (default keeps the old 5000).

        The discriminator is updated on every batch; the generator once every
        ``self.N_critic`` batches. Losses are printed every
        ``FLAGS.print_step`` steps. Every ``FLAGS.save_step`` steps an image
        grid from a fixed latent batch and both state dicts are saved.
        """
        epochs        = data_params["epochs"]
        batch_size    = data_params["batch_size"]
        batch_workers = data_params["batch_workers"]
        shuffle       = data_params["shuffle"]
        drop_last     = data_params["drop_last"]
        dataloader    = Data_fetcher.fetch_dataset(FLAGS.dataset, batch_size, batch_workers, shuffle, drop_last)

        if withTopo:
            self.et.load_pd_pool(FLAGS.pds_path, "dat", 1.0, batch_size, 0)
        if FLAGS.continue_model:
            # map_location lets a checkpoint saved on GPU load on a CPU-only host.
            self.netG.load_state_dict(torch.load(
                '%s/netG_step_%d.pth' % (FLAGS.model_save, resume_step),
                map_location=self.device))
            self.netD.load_state_dict(torch.load(
                '%s/netD_step_%d.pth' % (FLAGS.model_save, resume_step),
                map_location=self.device))

        # Fixed latent batch so saved image grids are comparable across steps.
        fixed_z_ = torch.randn([batch_size] + self.inputG_dims, device=self.device)
        # Defensive: the first batch (step 0) always runs a G step, but never
        # let the progress print hit an unbound name.
        g_loss = float("nan")
        step = 0
        for epoch in range(epochs):
            for i, data in enumerate(dataloader, 0):
                Dreal_device = data['image'].to(self.device)
                # The last batch may be smaller when drop_last is False.
                latent_shape = [data['image'].shape[0]] + self.inputG_dims
                Dfake_device = self.sample_(latent_shape)
                d_loss = self.D_iteration(Dreal_device, Dfake_device)

                if step % self.N_critic == 0:
                    # Fresh samples: the previous batch's graph was consumed by D's backward.
                    Dfake_device = self.sample_(latent_shape)
                    g_loss = self.G_iteration(Dfake_device, withTopo)
                step = step + 1

                if step % FLAGS.print_step == 0:
                    print('[%d/%d][%d/%d] D_loss: %.4f G_loss: %.4f'
                          % (epoch, epochs, i, len(dataloader), d_loss, g_loss))
                if step % FLAGS.save_step == 0:
                    # ===== Save images (no autograd graph needed for a snapshot)
                    with torch.no_grad():
                        Dfake_device_ = self.netG(fixed_z_)
                    vutils.save_image(Dfake_device_,
                                      '%s/generated_step_%d.png' % (FLAGS.image_save, step),
                                      normalize=True)
                    # ===== Save models ====
                    torch.save(self.netG.state_dict(), '%s/netG_step_%d.pth' % (FLAGS.model_save, step))
                    torch.save(self.netD.state_dict(), '%s/netD_step_%d.pth' % (FLAGS.model_save, step))
        print("Training complete.")