In [None]:
%run include.ipynb
%run Net.ipynb
%run Data.ipynb
%run viewer.ipynb
        
class GAN(object):
    """Adversarial trainer pairing a generator ``netG`` with a discriminator ``netD``.

    Hyper-parameters come from the constructor dictionaries; device selection,
    checkpoint/image/log locations and print/save intervals are read from the
    global ``FLAGS`` object.
    """

    # Directory holding pre-generated noise files (written by save_noise_,
    # read by sample_save_ and train).
    NOISE_DIR = 'D:/Data/fixed_z/'

    def __init__(self, general, adv_params, G_arch, D_arch):
        """Build both networks, their Adam optimizers and the loss criteria.

        Args:
            general: dict with "learning_rate", "beta1", "beta2" (Adam
                hyper-parameters) and "loss", "reduction" (forwarded to GANLoss).
            adv_params: dict; adv_params["wgangp"]["N_CRITIC"] sets how often
                the generator is updated (once every N_CRITIC training steps).
            G_arch, D_arch: architecture descriptions passed to
                Net.parse_layers for the generator / discriminator.
        """
        lr        = general["learning_rate"]
        beta1     = general["beta1"]
        beta2     = general["beta2"]
        loss_mode = general["loss"]
        reduction = general["reduction"]

        # The generator is stepped only when step % N_critic == 0 (see train).
        self.N_critic = adv_params["wgangp"]["N_CRITIC"]

        cudnn.benchmark = FLAGS.cudnn_benchmark
        gpu_num     = FLAGS.gpu_num
        self.device = torch.device("cuda:0" if torch.cuda.is_available()
                                   and FLAGS.gpu_enable else "cpu")
        # Torch is seeded from Python's `random`; runs are reproducible only
        # if `random` itself has been seeded beforehand.
        torch.manual_seed(random.randint(1, 10000))

        self.inputG_dims, G_layers = Net.parse_layers(G_arch)
        self.inputD_dims, D_layers = Net.parse_layers(D_arch)
        self.netG = Network_template(gpu_num, G_layers).to(self.device)
        Net.init_weights(self.netG, "normal")
        self.netD = Network_template(gpu_num, D_layers).to(self.device)
        Net.init_weights(self.netD, "normal")

        self.criterion  = GANLoss(loss_mode, reduction).to(self.device)
        # NOTE(review): criterionT is not referenced by any method shown here.
        self.criterionT = GANLoss("vanilla_topo", "sum").to(self.device)
        self.optimizerD = optim.Adam(self.netD.parameters(), lr=lr, betas=(beta1, beta2))
        self.optimizerG = optim.Adam(self.netG.parameters(), lr=lr, betas=(beta1, beta2))

    # ------------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------------
    def _noise_path(self, name):
        """Full path of the noise file ``name`` inside NOISE_DIR."""
        return self.NOISE_DIR + name

    @staticmethod
    def _checkpoint_path(tag, step):
        """Path of the ``tag`` ("netG" or "netD") checkpoint saved at ``step``."""
        return '%s/%s_step_%d.pth' % (FLAGS.model_save, tag, step)

    @staticmethod
    def _mean_or_nan(values):
        """Mean of ``values`` as a float; NaN (without a RuntimeWarning) when empty."""
        return float(np.mean(values)) if len(values) else float("nan")

    def _load_models(self, step):
        """Load netG / netD weights saved at ``step`` from FLAGS.model_save."""
        for net, tag in ((self.netG, "netG"), (self.netD, "netD")):
            # map_location lets checkpoints written on a GPU load on a CPU-only run.
            state = torch.load(self._checkpoint_path(tag, step), map_location=self.device)
            net.load_state_dict(state)

    def _save_models(self, step):
        """Save netG / netD weights for ``step`` under FLAGS.model_save."""
        Path(FLAGS.model_save).mkdir(parents=True, exist_ok=True)
        torch.save(self.netG.state_dict(), self._checkpoint_path("netG", step))
        torch.save(self.netD.state_dict(), self._checkpoint_path("netD", step))

    # ------------------------------------------------------------------
    # Noise / sampling
    # ------------------------------------------------------------------
    def save_noise_(self, shape, name):
        '''
        Draw standard-normal noise of ``shape`` and write it to NOISE_DIR/name.

        shape: shape of the noise, usually it is [batch_size, 128, 1, 1]
        name: should be like 128_128_1_1_0.dat
        all noise are saved under NOISE_DIR (D:/Data/fixed_z/ by default)
        '''
        Path(self.NOISE_DIR).mkdir(parents=True, exist_ok=True)
        z_ = torch.randn(shape, device=self.device)
        FileIO.write_binary(self._noise_path(name), z_.cpu().numpy().flatten(), list(z_.shape), 'f')

    def sample_z_(self, shape):
        """Run netG on fresh standard-normal noise of ``shape`` (autograd enabled)."""
        z_ = torch.randn(shape, device=self.device)
        return self.netG(z_)

    def sample_save_(self, name, shape, direc, scalor, binary_out):
        '''
        Generate images from a stored noise file and write them into ``direc``.

        name: noise file under NOISE_DIR, e.g. 128_128_1_1_0.dat (see save_noise_)
        shape: shape of the stored noise, usually [N, 128, 1, 1]
        direc: output directory; created if missing
        scalor, binary_out: forwarded unchanged to FileIO.save_image_batch

        If FLAGS.continue_model is set, the checkpoint for FLAGS.model_step is
        loaded first. Noise is pushed through netG in chunks of 1000 vectors.
        '''
        Path(direc).mkdir(parents=True, exist_ok=True)
        if FLAGS.continue_model:
            self._load_models(FLAGS.model_step)
            print("Models loaded at step %d" % FLAGS.model_step)

        interval = 1000
        z_ = FileIO.read_binary(self._noise_path(name), shape, 'f')
        n_total = z_.shape[0]
        # Inference only: no autograd graph is needed for the saved samples.
        with torch.no_grad():
            for si in range(0, n_total, interval):
                se = min(si + interval, n_total)
                z_sub_ = torch.from_numpy(z_[si:se, :]).to(self.device)
                samples = self.netG(z_sub_)
                # `si` is passed as the starting index for this chunk's outputs.
                FileIO.save_image_batch(samples.detach().cpu().numpy(), direc, 'gen', scalor, si, binary_out)

    # ------------------------------------------------------------------
    # Optimisation steps
    # ------------------------------------------------------------------
    def D_iteration(self, Dreal_device, Dfake_device):
        """One discriminator update; returns the discriminator loss as a float.

        NOTE(review): Dfake_device comes from sample_z_ with autograd enabled;
        whether GANLoss detaches it from netG is not visible here.
        """
        self.netD.zero_grad()
        errD = self.criterion(["D", self.netD, self.device, Dreal_device, Dfake_device])
        errD.backward()
        self.optimizerD.step()
        return errD.item()

    def G_iteration(self, Dfake_device):
        """One generator update; returns the generator loss as a float.

        netD also accumulates gradients here; they are cleared by the
        netD.zero_grad() call at the start of the next D_iteration.
        """
        self.netG.zero_grad()
        errG = self.criterion(["G", self.netD, Dfake_device])
        errG.backward()
        self.optimizerG.step()
        return errG.item()

    def debug_input_info(self, data):
        """Print shape/range of a batch tensor and show its first item."""
        data = np.squeeze(data.detach().cpu().numpy())
        print("squeezed shape: ", data.shape)
        print("max: ", np.amax(data), "min: ", np.amin(data))
        viewer.imshow_(data[0,:], 'jet', gray=True)

    # ------------------------------------------------------------------
    # Training loop
    # ------------------------------------------------------------------
    def _save_snapshot(self, fixed_z_, step):
        """Show/save netG(fixed_z_) as an image grid and checkpoint both networks."""
        # Inference only: avoid building an autograd graph for the preview.
        with torch.no_grad():
            Dfake_device_ = self.netG(fixed_z_)

        Dfake_device_cpu = np.squeeze(Dfake_device_.detach().cpu().numpy())
        # NOTE(review): [0,:] assumes the batch axis survives np.squeeze,
        # i.e. batch_size > 1 — confirm.
        viewer.imshow_(Dfake_device_cpu[0,:], 'jet')

        Path(FLAGS.image_save).mkdir(parents=True, exist_ok=True)
        vutils.save_image(Dfake_device_.detach().cpu(),
                          '%s/generated_step_%d.png' % (FLAGS.image_save, step), normalize=True)
        self._save_models(step)

    def train(self, data_params):
        """Run adversarial training over FLAGS.dataset.

        Args:
            data_params: dict with "epochs", "batch_size", "batch_workers",
                "shuffle", "drop_last" (forwarded to Data_fetcher.fetch_dataset).

        Every step updates the discriminator; the generator is updated when
        step % N_critic == 0. Mean losses are printed/appended to
        FLAGS.log_path every FLAGS.print_step steps, and a preview image plus
        both checkpoints are written every FLAGS.save_step steps.
        """
        epochs        = data_params["epochs"]
        batch_size    = data_params["batch_size"]
        batch_workers = data_params["batch_workers"]
        shuffle       = data_params["shuffle"]
        drop_last     = data_params["drop_last"]
        dataloader    = Data_fetcher.fetch_dataset(FLAGS.dataset, batch_size, batch_workers, shuffle, drop_last, 0.5)

        step = 0
        if FLAGS.continue_model:
            self._load_models(FLAGS.model_step)
            # A checkpoint labelled S was saved after S completed steps, so the
            # next step to run is S itself (matches an uninterrupted run).
            step = FLAGS.model_step

        g_lrec = []
        d_lrec = []
        fixed_z_ = FileIO.read_binary(self._noise_path('128_128_1_1_0.dat'), [batch_size]+self.inputG_dims, 'f')
        fixed_z_ = torch.from_numpy(fixed_z_).to(self.device)

        # `with` guarantees the log is closed even if training raises.
        with open(FLAGS.log_path, "a") as log:
            for epoch in range(epochs):
                for i, data in enumerate(dataloader):
                    Dreal_device = data.to(self.device)
                    noise_shape = [data.shape[0]] + self.inputG_dims
                    Dfake_device = self.sample_z_(noise_shape)
                    d_lrec.append(self.D_iteration(Dreal_device, Dfake_device))

                    if step % self.N_critic == 0:
                        # Fresh samples so the generator step uses the updated netD.
                        g_lrec.append(self.G_iteration(self.sample_z_(noise_shape)))
                    step += 1

                    if step % FLAGS.print_step == 0:
                        # g_lrec can be empty when print_step < N_critic -> prints nan.
                        msg = ('[%d/%d][%d/%d] D_loss: %.4f G_loss: %.4f Step: %d'
                               % (epoch, epochs, i, len(dataloader),
                                  self._mean_or_nan(d_lrec), self._mean_or_nan(g_lrec), step))
                        g_lrec.clear()
                        d_lrec.clear()
                        print(msg)
                        log.write(msg+"\n")
                        log.flush()
                    if step % FLAGS.save_step == 0:
                        self._save_snapshot(fixed_z_, step)
        print("Training complete.")