# Testing cosmogan for 3D images
Jan 4, 2021


In [1]:
import os
import random
import logging
import sys

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torchsummary import summary
from torch.utils.data import DataLoader, TensorDataset

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

import argparse
import time
from datetime import datetime
import glob
import pickle
import yaml
import collections

In [2]:
%matplotlib widget

## Modules

In [3]:
def f_load_config(config_file):
    """Parse a YAML configuration file and hand back its contents.

    Arguments:
        config_file: path of the YAML file to read.
    Returns:
        The object built by the YAML safe loader. Callers in this notebook
        index it as a nested dict (config['training'], config['data']).
    """
    with open(config_file) as cfg_handle:
        return yaml.safe_load(cfg_handle)

### Transformation functions for image pixel values
def f_transform(x):
    """Map a pixel value x to s = 2x/(x+4) - 1 (0 maps to -1, 4 maps to 0).

    Works element-wise on anything supporting arithmetic (floats, numpy
    arrays, tensors). f_invtransform applies the inverse formula.
    """
    scaled = 2. * x / (x + 4.)
    return scaled - 1.

def f_invtransform(s):
    """Invert f_transform: x = 4(1+s)/(1-s) (-1 maps to 0, 0 maps to 4).

    Works element-wise on floats, numpy arrays and tensors. The expression
    divides by (1 - s), so s == 1 is a singular point.
    """
    numerator = 4. * (1. + s)
    return numerator / (1. - s)


### Model definition

In [4]:
# custom weights initialization called on netG and netD
def weights_init(m):
    """Re-initialise one layer in place; intended for use with ``net.apply``.

    Layers whose class name contains 'Conv' get weights drawn from N(0, 0.02).
    Layers whose class name contains 'BatchNorm' get weights drawn from
    N(1, 0.02) and a zero bias. Every other layer is left untouched.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

class View(nn.Module):
    """Reshaping layer: applies ``Tensor.view`` so a reshape can sit inside ``nn.Sequential``."""

    def __init__(self, shape):
        super().__init__()
        # Target dimensions; unpacked into Tensor.view on every forward call.
        self.shape = shape

    def forward(self, x):
        reshaped = x.view(*self.shape)
        return reshaped

class Generator(nn.Module):
    """3D GAN generator: latent vector -> Linear -> 4x4x4 volume -> four transposed convolutions -> tanh.

    Hyper-parameters are read from gdict keys 'ngpu', 'nz', 'nc', 'ngf',
    'kernel_size', 'stride' and 'g_padding' ('ngpu' must be present but is
    not used). The training code feeds noise shaped (batch, 1, 1, 1, nz).

    With the configuration logged in this notebook (kernel_size=5, stride=2,
    g_padding=2, output_padding=1) every transposed convolution doubles each
    spatial dimension: 4 -> 8 -> 16 -> 32 -> 64.

    NOTE(review): BatchNorm3d(nc) directly after the Linear layer and the
    View to ngf*8 channels only line up when nc == 1 -- confirm before using
    more channels.
    """

    def __init__(self, gdict):
        super().__init__()

        ## Pull the hyper-parameters out of the shared dict
        names = ('ngpu', 'nz', 'nc', 'ngf', 'kernel_size', 'stride', 'g_padding')
        _ngpu, nz, nc, ngf, kernel_size, stride, g_padding = [gdict[name] for name in names]

        # Dense stem: project the latent vector, then fold it into (ngf*8, 4, 4, 4).
        layers = [
            nn.Linear(nz, nc * ngf * 8 ** 3),
            nn.BatchNorm3d(nc, eps=1e-05, momentum=0.9, affine=True),
            nn.ReLU(inplace=True),
            View(shape=[-1, ngf * 8, 4, 4, 4]),
        ]

        # Three upsampling stages, each halving the channel count.
        for width_in, width_out in ((ngf * 8, ngf * 4), (ngf * 4, ngf * 2), (ngf * 2, ngf)):
            layers += [
                nn.ConvTranspose3d(width_in, width_out, kernel_size, stride, g_padding, output_padding=1, bias=False),
                nn.BatchNorm3d(width_out, eps=1e-05, momentum=0.9, affine=True),
                nn.ReLU(inplace=True),
            ]

        # Final upsampling to nc channels; tanh bounds the output to [-1, 1].
        layers += [
            nn.ConvTranspose3d(ngf, nc, kernel_size, stride, g_padding, output_padding=1, bias=False),
            nn.Tanh(),
        ]

        # Attribute name and layer order are kept so state-dict keys stay 'main.<index>.*'.
        self.main = nn.Sequential(*layers)

    def forward(self, ip):
        return self.main(ip)

class Discriminator(nn.Module):
    """3D GAN discriminator: four strided Conv3d stages followed by a single linear logit.

    Hyper-parameters are read from gdict keys 'ngpu', 'nz', 'nc', 'ndf',
    'kernel_size', 'stride' and 'd_padding' ('ngpu' and 'nz' must be present
    but are not used).

    The final Linear layer expects nc*ndf*8*8*8 input features, i.e. ndf*8
    channels on a 4x4x4 grid when nc == 1; with kernel_size=5, stride=2,
    d_padding=2 (the configuration logged in this notebook) that corresponds
    to 64x64x64 inputs. No sigmoid is applied: the output is a raw logit,
    which is what the BCEWithLogitsLoss criterion used in this notebook expects.
    """

    def __init__(self, gdict):
        super().__init__()

        ## Pull the hyper-parameters out of the shared dict
        names = ('ngpu', 'nz', 'nc', 'ndf', 'kernel_size', 'stride', 'd_padding')
        _ngpu, _nz, nc, ndf, kernel_size, stride, d_padding = [gdict[name] for name in names]

        # Channel count at the input and after each of the four conv stages.
        widths = (nc, ndf, ndf * 2, ndf * 4, ndf * 8)

        layers = []
        for ch_in, ch_out in zip(widths[:-1], widths[1:]):
            layers.extend([
                nn.Conv3d(ch_in, ch_out, kernel_size, stride, d_padding, bias=True),
                nn.BatchNorm3d(ch_out, eps=1e-05, momentum=0.9, affine=True),
                nn.LeakyReLU(0.2, inplace=True),
            ])

        layers.append(nn.Flatten())
        layers.append(nn.Linear(nc * ndf * 8 * 8 * 8, 1))

        # Attribute name and layer order are kept so state-dict keys stay 'main.<index>.*'.
        self.main = nn.Sequential(*layers)

    def forward(self, ip):
        return self.main(ip)



### checkpoint and generate images

In [5]:
def f_gen_images(gdict,netG,optimizerG,ip_fname,op_loc,op_strg='inf_img_',op_size=500):
    '''Generate images from a saved checkpoint and store them as a .npy file.

    Arguments:
        gdict: settings dict; reads 'nz', 'device' and 'multi-gpu'.
        netG: generator network (wrapped in DataParallel when gdict['multi-gpu'] is true).
        optimizerG: generator optimizer; its state is restored from the checkpoint.
        ip_fname: path of the checkpoint file to load.
        op_loc: output location, prepended as-is to the file name (so it should end with a separator).
        op_strg: prefix of the output file name.
        op_size: number of images to generate.
    Returns:
        None. If the checkpoint cannot be loaded, the error is printed and generation is skipped.
    Side effects:
        Leaves netG in eval mode and writes '<op_strg>_epoch-<epoch>_step-<iters>.npy' under op_loc.
    '''

    nz,device=gdict['nz'],gdict['device']

    try:
        # map_location lets a checkpoint written on a GPU be restored on whatever device is in use.
        checkpoint=torch.load(ip_fname,map_location=device)
    except Exception as e:
        # Best effort: a missing or unreadable checkpoint only skips this generation step.
        print(e)
        print("skipping generation of images for ",ip_fname)
        return

    ## Load generator weights (checkpoints store the unwrapped module's state)
    target_net=netG.module if gdict['multi-gpu'] else netG
    target_net.load_state_dict(checkpoint['G_state'])

    ## Load other stuff
    iters=checkpoint['iters']
    epoch=checkpoint['epoch']
    optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])

    netG.eval() ## This is required before running inference
    # Inference only: without no_grad an autograd graph would be kept for the whole op_size batch.
    with torch.no_grad():
        # Generate batch of latent vectors
        noise = torch.randn(op_size, 1, 1, 1, nz, device=device)
        # Generate fake image batch with G
        gen = netG(noise)
    # Keep only the first channel.
    gen_images=gen.detach().cpu().numpy()[:,0,:,:]
    print(gen_images.shape)

    op_fname='%s_epoch-%s_step-%s.npy'%(op_strg,epoch,iters)

    np.save(op_loc+op_fname,gen_images)

    print("Image saved in ",op_fname)
    
def f_save_checkpoint(gdict,epoch,iters,best_chi1,best_chi2,netG,netD,optimizerG,optimizerD,save_loc):
    ''' Write a training checkpoint to save_loc with torch.save.

    The stored dict has the keys 'epoch', 'iters', 'best_chi1', 'best_chi2',
    'G_state', 'D_state', 'optimizerG_state_dict' and 'optimizerD_state_dict'.
    When gdict['multi-gpu'] is true the networks are DataParallel wrappers, so
    the state of the wrapped .module is saved instead of the wrapper's.
    '''

    if gdict['multi-gpu']: ## Dataparallel
        gen_net, disc_net = netG.module, netD.module
    else:
        gen_net, disc_net = netG, netD

    state = {
        'epoch': epoch,
        'iters': iters,
        'best_chi1': best_chi1,
        'best_chi2': best_chi2,
        'G_state': gen_net.state_dict(),
        'D_state': disc_net.state_dict(),
        'optimizerG_state_dict': optimizerG.state_dict(),
        'optimizerD_state_dict': optimizerD.state_dict(),
    }
    torch.save(state, save_loc)
    
def f_load_checkpoint(ip_fname,netG,netD,optimizerG,optimizerD,gdict):
    ''' Restore networks and optimizers from a saved checkpoint.

    Arguments:
        ip_fname: path of the checkpoint file.
        netG, netD: generator / discriminator (DataParallel wrappers when gdict['multi-gpu'] is true).
        optimizerG, optimizerD: optimizers whose state is restored in place.
        gdict: settings dict; reads 'multi-gpu' and, when present, 'device'.
    Returns:
        (iters, epoch, best_chi1, best_chi2) as stored in the checkpoint.
    Raises:
        SystemError: if the checkpoint file cannot be loaded (original error attached as the cause).
    Side effects:
        Both networks are switched to train mode.
    '''

    try:
        # Map stored tensors onto the device in use so a GPU checkpoint also loads on a CPU-only
        # host; when gdict has no 'device' entry this falls back to torch.load's default.
        checkpoint=torch.load(ip_fname,map_location=gdict.get('device'))
    except Exception as e:
        print(e)
        print("could not load checkpoint ",ip_fname)
        raise SystemError("could not load checkpoint %s"%(ip_fname)) from e

    ## Load network weights (checkpoints store the unwrapped module's state)
    if gdict['multi-gpu']:
        netG.module.load_state_dict(checkpoint['G_state'])
        netD.module.load_state_dict(checkpoint['D_state'])
    else:
        netG.load_state_dict(checkpoint['G_state'])
        netD.load_state_dict(checkpoint['D_state'])

    optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
    optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])

    iters=checkpoint['iters']
    epoch=checkpoint['epoch']
    best_chi1=checkpoint['best_chi1']
    best_chi2=checkpoint['best_chi2']

    # Training is expected to resume after loading.
    netG.train()
    netD.train()

    return iters,epoch,best_chi1,best_chi2

### spectral loss

In [29]:
####################
### Pytorch code ###
####################
def f_get_rad(img):
    ''' Precompute the radial grid used by f_torch_get_azimuthalAverage.

    Only the last three entries of img.shape are used. Returns (r, ind):
    r holds, for every grid point of that 3D shape, its Euclidean distance
    from the centre of the index range (float64 tensor); ind is the
    permutation that sorts the flattened r in ascending order.
    '''

    dims = img.shape[-3:]
    # One integer index grid per axis, in axis order.
    z, y, x = np.indices(list(dims))

    # Centre of each axis = half the span of its index range.
    cx, cy, cz = [(grid.max() - grid.min()) / 2.0 for grid in (x, y, z)]

    # Distance of every grid point from the centre; same shape as the image volume.
    dist = np.sqrt((x - cx) ** 2 + (y - cy) ** 2 + (z - cz) ** 2)
    r = torch.tensor(dist)

    # Ordering of the flattened grid by increasing radius.
    ind = torch.argsort(r.reshape(-1))

    return r.detach(), ind.detach()

def f_torch_get_azimuthalAverage(image,r,ind):
    """
    Calculate the radially binned average of an image (bin width = 1).

    image - The 3D image (any shape with as many elements as r)
    r     - Radial coordinate of every grid point, as returned by f_get_rad
    ind   - Indices that sort the flattened r, as returned by f_get_rad

    Pixels are grouped by the integer part of their radius. The returned 1D
    tensor holds the mean of each group lying between the first and the last
    change of integer radius, so the innermost and outermost groups are not
    part of the output.
    source: https://www.astrobetter.com/blog/2010/03/03/fourier-transforms-of-images-in-python/
    """

    # Radii and pixel values, ordered by increasing radius.
    radii = r.reshape(-1)[ind]
    values = image.reshape(-1)[ind]

    # Integer part of the radius identifies the bin.
    bin_id = radii.to(torch.int32)

    # Positions where the bin changes (assumes all radii represented).
    step = bin_id[1:] - bin_id[:-1]
    edges = torch.where(step)[0].reshape(-1)
    counts = (edges[1:] - edges[:-1]).type(torch.float)   # pixels per bin

    # A running sum gives each bin total as a difference of two entries.
    running = torch.cumsum(values, dim=-1)
    bin_totals = running[edges[1:]] - running[edges[:-1]]

    return bin_totals / counts

def f_torch_fftshift(real, imag):
    """Circularly shift both tensors by half their length along every dimension.

    real and imag are processed in lockstep over the dimensions of real and
    returned as a (real, imag) pair.
    """
    for axis in range(real.dim()):
        real = torch.roll(real, shifts=real.size(axis) // 2, dims=axis)
        imag = torch.roll(imag, shifts=imag.size(axis) // 2, dims=axis)
    return real, imag

def f_torch_compute_spectrum(arr,r,ind):
    ''' Radially averaged power spectrum of one 3D image.

    Arguments:
        arr: 3D tensor (one image, one channel).
        r, ind: radial grid and sort order from f_get_rad.
    Returns:
        1D tensor: the squared magnitude of the centred 3D FFT, averaged per radial bin.
    '''

    GLOBAL_MEAN=1.0
    arr=(arr-GLOBAL_MEAN)/(GLOBAL_MEAN)

    # torch.rfft was removed in PyTorch 1.8; prefer torch.fft.fftn and only fall back to the
    # legacy call on versions that do not ship the torch.fft module.
    try:
        from torch.fft import fftn
    except ImportError:
        fftn=None

    if fftn is not None:
        y1=fftn(arr,dim=(-3,-2,-1)) ## 3D FFT (complex tensor)
        real,imag=f_torch_fftshift(y1.real,y1.imag)
    else:
        y1=torch.rfft(arr,signal_ndim=3,onesided=False) ## 3D FFT
        real,imag=f_torch_fftshift(y1[:,:,:,0],y1[:,:,:,1])    ## last index is real/imag part

    y2=real**2+imag**2     ## Squared magnitude of each complex number
    z1=f_torch_get_azimuthalAverage(y2,r,ind)     ## Compute radial profile
    return z1


def f_torch_compute_batch_spectrum(arr,r,ind):
    ''' Apply f_torch_compute_spectrum to every image along the first axis of arr
    and stack the resulting 1D spectra into one tensor (one row per image). '''

    per_image = []
    for image in arr:
        per_image.append(f_torch_compute_spectrum(image, r, ind))

    return torch.stack(per_image)

def f_torch_image_spectrum(x,num_channels,r,ind):
    '''
    Batch statistics of the radial power spectrum, per channel.

    Data has to be in the form (batch,channel,x,y,z).
    Returns (mean, sdev), each stacked over the first num_channels channels:
    mean is the spectrum averaged over the batch, and sdev is the variance
    over the batch (torch.var), despite its name.
    '''
    channel_means, channel_vars = [], []

    for ch in range(num_channels):
        batch_pk = f_torch_compute_batch_spectrum(x[:, ch, :, :, :], r, ind)
        channel_means.append(torch.mean(batch_pk, dim=0))
        channel_vars.append(torch.var(batch_pk, dim=0))

    return torch.stack(channel_means), torch.stack(channel_vars)


def f_compute_hist(data,bins):
    ''' Normalised histogram of the values in data.

    Arguments:
        data: tensor of values; the histogram range is left to torch.histc's defaults.
        bins: number of histogram bins.
    Returns:
        1D tensor of length bins, scaled so that its entries sum to bins.
        If the histogram cannot be computed, the error is printed and a zero
        tensor of length bins is returned instead.
    '''

    try:
        hist_data=torch.histc(data,bins=bins)
        ## A kind of normalization of histograms: divide by total sum
        hist_data=(hist_data*bins)/torch.sum(hist_data)
    except Exception as e:
        print(e)
        # Keep the fallback on the same device as the input so later arithmetic with
        # reference histograms does not hit a device mismatch.
        hist_data=torch.zeros(bins,device=getattr(data,'device',None))

    return hist_data

### Losses 
def loss_spectrum(spec_mean,spec_mean_ref,spec_std,spec_std_ref,image_size,lambda1):
    ''' Spectral loss: lambda1 * ( log(MSE of the means) + log(MSE of the spreads) ).

    Inputs are indexed as [channel, bin]; only the first image_size/2 bins of
    each spectrum enter the loss, and the squared differences are averaged
    over channels and bins together.
    '''

    half = int(image_size/2) ### For the spectrum, use only N/2 indices for loss calc.

    mean_term = torch.log(torch.mean((spec_mean[:, :half] - spec_mean_ref[:, :half]) ** 2))
    spread_term = torch.log(torch.mean((spec_std[:, :half] - spec_std_ref[:, :half]) ** 2))

    # Both terms carry the same weight.
    weight_mean = weight_spread = lambda1
    ans = weight_mean * mean_term + weight_spread * spread_term

    # Only the spread term is checked for NaN here.
    if torch.isnan(spread_term).any():
        print("spec loss with nan", ans)

    return ans
    
def loss_hist(hist_sample,hist_ref):
    ''' Histogram loss: log of the mean squared difference between two histograms (weight fixed at 1.0). '''

    weight = 1.0
    sq_err = (hist_sample - hist_ref) ** 2
    return weight * torch.log(sq_err.mean())


In [7]:
# def f_size(ip):
#     p=2;s=2
# #     return (ip + 2 * 0 - 1 * (p-1) -1 )/ s + 1

#     return (ip-1)*s - 2 * p + 1 *(5-1)+ 1 + 1

# f_size(128)

In [8]:
# logging.basicConfig(filename=save_dir+'/log.log',filemode='w',format='%(name)s - %(levelname)s - %(message)s')

## Main code

In [42]:
def f_train_loop(dataloader,metrics_df,gdict):
    ''' Run GAN training for every epoch in range(gdict['start_epoch'], gdict['epochs']).

    Besides its arguments, this function reads names defined at notebook level:
    netG, netD, optimizerG, optimizerD, criterion, fixed_noise, r, ind,
    hist_val, mean_spec_val and sdev_spec_val.

    Arguments:
        dataloader: iterable of batches; element [0] of each batch is the real-image tensor.
        metrics_df: pandas DataFrame filled in place (one row per step) and pickled
                    to <save_dir>/df_metrics.pkle after every epoch.
        gdict: settings dict; reads the keys listed in `keys` below plus 'lambda1',
               'spec_loss_flag', 'checkpoint_size' and 'save_steps_list'.
    Returns:
        None. Checkpoints go to <save_dir>/models and sample images to <save_dir>/images.
    Raises:
        SystemError: if the generator loss becomes NaN.
    '''
    
    ## Define new variables from dict
    keys=['image_size','start_epoch','epochs','iters','best_chi1','best_chi2','save_dir','device','flip_prob','nz','batchsize','bns']
    image_size,start_epoch,epochs,iters,best_chi1,best_chi2,save_dir,device,flip_prob,nz,batchsize,bns=[gdict[key] for key in keys]
    
    for epoch in range(start_epoch,epochs):
        t_epoch_start=time.time()
        for count, data in enumerate(dataloader, 0):
            
            ####### Train GAN ########
            netG.train(); netD.train();  ### Need to add these after inference and before training

            tme1=time.time()
            ### Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            netD.zero_grad()
            
            real_cpu = data[0].to(device)
            b_size = real_cpu.size(0)
            # Float fill values: an integer fill makes newer PyTorch versions infer an integer
            # dtype, which the BCE-with-logits criterion does not accept as a target.
            real_label = torch.full((b_size,), 1.0, device=device)
            fake_label = torch.full((b_size,), 0.0, device=device)
            g_label = torch.full((b_size,), 1.0, device=device) ## No flipping for Generator labels
            # Flip the labels of ceil(b_size*flip_prob) randomly drawn samples (drawn with replacement)
            for idx in np.random.choice(np.arange(b_size),size=int(np.ceil(b_size*flip_prob))):
                real_label[idx]=0; fake_label[idx]=1

            # Generate fake image batch with G
            noise = torch.randn(b_size, 1, 1, 1, nz, device=device)
            fake = netG(noise)            

            # Forward pass real batch through D
            output = netD(real_cpu).view(-1)
            errD_real = criterion(output, real_label)
            errD_real.backward()
            D_x = output.mean().item()

            # Forward pass fake batch through D (detached, so no gradient reaches G here)
            output = netD(fake.detach()).view(-1)
            errD_fake = criterion(output, fake_label)
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            errD = errD_real + errD_fake
            optimizerD.step()

            ###Update G network: maximize log(D(G(z)))
            netG.zero_grad()
            output = netD(fake).view(-1)
            errG_adv = criterion(output, g_label)
            # Histogram pixel intensity loss (only logged; it is not added to errG)
            hist_gen=f_compute_hist(fake,bins=bns)
            hist_loss=loss_hist(hist_gen,hist_val.to(device))

            # Add spectral loss
            mean,sdev=f_torch_image_spectrum(f_invtransform(fake),1,r.to(device),ind.to(device))
            spec_loss=loss_spectrum(mean,mean_spec_val.to(device),sdev,sdev_spec_val.to(device),image_size,gdict['lambda1'])
            
            if gdict['spec_loss_flag']: errG=errG_adv+spec_loss
            else: errG=errG_adv
            
            if torch.isnan(errG).any():
                logging.info(errG)
                raise SystemError("Generator loss is NaN at step %s"%(iters))
            
            # Calculate gradients for G
            errG.backward()
            D_G_z2 = output.mean().item()
            optimizerG.step()
            
            tme2=time.time()
            
            ####### Store metrics ########
            # Output training stats
            if count % gdict['checkpoint_size'] == 0:
                print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_adv: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                      % (epoch, epochs, count, len(dataloader), errD.item(), errG_adv.item(),errG.item(), D_x, D_G_z1, D_G_z2))
                print("Spec loss: %s,\t hist loss: %s"%(spec_loss.item(),hist_loss.item()))
                print("Training time for step %s : %s"%(iters, tme2-tme1))

            # Save metrics
            cols=['step','epoch','Dreal','Dfake','Dfull','G_adv','G_full','spec_loss','hist_loss','D(x)','D_G_z1','D_G_z2','time']
            vals=[iters,epoch,errD_real.item(),errD_fake.item(),errD.item(),errG_adv.item(),errG.item(),spec_loss.item(),hist_loss.item(),D_x,D_G_z1,D_G_z2,tme2-tme1]
            for col,val in zip(cols,vals):  metrics_df.loc[iters,col]=val

            ### Checkpoint the best model
            checkpoint=True
            iters += 1  ### Model has been updated, so update iters before saving metrics and model.

            ### Compute validation metrics for updated model
            netG.eval()
            with torch.no_grad():
                #fake = netG(fixed_noise).detach().cpu()
                fake = netG(fixed_noise)
                hist_gen=f_compute_hist(fake,bins=bns)
                hist_chi=loss_hist(hist_gen,hist_val.to(device))
                mean,sdev=f_torch_image_spectrum(f_invtransform(fake),1,r.to(device),ind.to(device))
                spec_chi=loss_spectrum(mean,mean_spec_val.to(device),sdev,sdev_spec_val.to(device),image_size,gdict['lambda1'])      
            # Storing chi for next step
            for col,val in zip(['spec_chi','hist_chi'],[spec_chi.item(),hist_chi.item()]):  metrics_df.loc[iters,col]=val            

            # NOTE(review): best-model and save_steps_list checkpoints are only written once
            # epoch > 1, so early steps listed in save_steps_list are never saved -- confirm intent.
            if (checkpoint and (epoch > 1)): # Choose best models by metric
                if hist_chi< best_chi1:
                    # Record the new best first so the checkpoint stores the chi of the model it holds.
                    best_chi1=hist_chi.item()
                    f_save_checkpoint(gdict,epoch,iters,best_chi1,best_chi2,netG,netD,optimizerG,optimizerD,save_loc=save_dir+'/models/checkpoint_best_hist.tar')
                    logging.info("Saving best hist model at epoch %s, step %s."%(epoch,iters))

                if  spec_chi< best_chi2:
                    best_chi2=spec_chi.item()
                    f_save_checkpoint(gdict,epoch,iters,best_chi1,best_chi2,netG,netD,optimizerG,optimizerD,save_loc=save_dir+'/models/checkpoint_best_spec.tar')
                    logging.info("Saving best spec model at epoch %s, step %s"%(epoch,iters))
                    
                if iters in gdict['save_steps_list']:
                    f_save_checkpoint(gdict,epoch,iters,best_chi1,best_chi2,netG,netD,optimizerG,optimizerD,save_loc=save_dir+'/models/checkpoint_{0}.tar'.format(iters))
                    logging.info("Saving given-step at epoch %s, step %s."%(epoch,iters))

            # Checkpoint model for continuing run. Written after the best-model bookkeeping so a
            # resumed run starts from up-to-date best_chi values.
            if count == len(dataloader)-1: ## Check point at last step of epoch
                f_save_checkpoint(gdict,epoch,iters,best_chi1,best_chi2,netG,netD,optimizerG,optimizerD,save_loc=save_dir+'/models/checkpoint_last.tar')  
                    
            # Save G's output on fixed_noise
            if ((iters % gdict['checkpoint_size'] == 0) or ((epoch == epochs-1) and (count == len(dataloader)-1))):
                netG.eval()
                with torch.no_grad():
                    fake = netG(fixed_noise).detach().cpu()
                    img_arr=np.array(fake[:,0,:,:])
                    fname='gen_img_epoch-%s_step-%s'%(epoch,iters)
                    np.save(save_dir+'/images/'+fname,img_arr)
        
        t_epoch_end=time.time()
        logging.info("Time taken for epoch %s: %s"%(epoch,t_epoch_end-t_epoch_start))
        # Save Metrics to file after each epoch
        metrics_df.to_pickle(save_dir+'/df_metrics.pkle')
        
    logging.info("best chis: {0}, {1}".format(best_chi1,best_chi2))

In [43]:
def f_init_gdict(gdict,config_dict):
    ''' Copy the selected entries of the config file into the global dictionary gdict.

    Keys are taken from config_dict['training'] first and config_dict['data']
    second; gdict is modified in place and nothing is returned.
    '''
    wanted = (
        ('training', ('workers','nc','nz','ngf','ndf','beta1','kernel_size','stride','g_padding','d_padding','flip_prob')),
        ('data', ('image_size','checkpoint_size','num_imgs','ip_fname','op_loc')),
    )
    for section, names in wanted:
        for name in names:
            gdict[name] = config_dict[section][name]

## Start

In [44]:
# Set-up cell: build the settings dict, load data, precompute the validation
# statistics, and create the networks, criterion and optimizers that
# f_train_loop reads as notebook-level names.
if __name__=="__main__":
    torch.backends.cudnn.benchmark=True
#     torch.autograd.set_detect_anomaly(True)

    t0=time.time()
    #################################
#     args=f_parse_args()
    # Manually add args ( different for jupyter notebook)
    args=argparse.Namespace()
    args.config='1_main_code/config_3d.yaml'
    args.ngpu=1
    args.batchsize=32
    args.spec_loss_flag=True
    args.checkpoint_size=50
    args.epochs=10
    args.learn_rate=0.0002
    args.mode='fresh'
#     args.mode='continue'
#     args.ip_fldr='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/pytorch/results/128sq/20201211_093818_nb_test/'
    args.run_suffix='nb_test'
    args.deterministic=False
    args.seed='234373'
    args.lambda1=0.1
    args.save_steps_list=[5,10]

    ### Set up ###
    config_file=args.config
    config_dict=f_load_config(config_file)

    # Initialize variables
    gdict={}
    f_init_gdict(gdict,config_dict)
    
    ## Add args variables to gdict
    ## (args.checkpoint_size is not in this list, so the config-file value stays in effect)
    for key in ['ngpu','batchsize','mode','spec_loss_flag','epochs','learn_rate','lambda1','save_steps_list']:
        gdict[key]=vars(args)[key]
       
    ###### Set up directories #######
    if gdict['mode']=='fresh':
        # Create prefix for foldername        
        fldr_name=datetime.now().strftime('%Y%m%d_%H%M%S') ## time format
        gdict['save_dir']=gdict['op_loc']+fldr_name+'_'+args.run_suffix
        
        if not os.path.exists(gdict['save_dir']):
            os.makedirs(gdict['save_dir']+'/models')
            os.makedirs(gdict['save_dir']+'/images')
        
    elif gdict['mode']=='continue': ## For checkpointed runs
        gdict['save_dir']=args.ip_fldr
        ### Read loss data (os.path.join works with or without a trailing slash on save_dir)
        with open (os.path.join(gdict['save_dir'],'df_metrics.pkle'),'rb') as f:
            metrics_dict=pickle.load(f) 

#     ### Write all logging.info statements to stdout and log file (different for jpt notebooks)
#     logfile=gdict['save_dir']+'/log.log'
#     logging.basicConfig(level=logging.DEBUG, filename=logfile, filemode="a+", format="%(asctime)-15s %(levelname)-8s %(message)s")
    
#     Lg = logging.getLogger()
#     Lg.setLevel(logging.DEBUG)
#     lg_handler_file = logging.FileHandler(logfile)
#     lg_handler_stdout = logging.StreamHandler(sys.stdout)
#     Lg.addHandler(lg_handler_file)
#     Lg.addHandler(lg_handler_stdout)
    
#     logging.info('Args: {0}'.format(args))
#     logging.info(config_dict)
#     logging.info('Start: %s'%(datetime.now().strftime('%Y-%m-%d  %H:%M:%S')))
#     if gdict['spec_loss_flag']: logging.info("Using Spectral loss")

    ### Override (different for jpt notebooks)
    gdict['num_imgs']=2000
    
    ## Special declarations
    gdict['bns']=50
    gdict['device']=torch.device("cuda" if (torch.cuda.is_available() and gdict['ngpu'] > 0) else "cpu")
    # From here on 'ngpu' holds the number of visible GPUs, not the requested value.
    gdict['ngpu']=torch.cuda.device_count()
    
    gdict['multi-gpu']=True if (gdict['device'].type == 'cuda') and (gdict['ngpu'] > 1) else False 
    print(gdict)
    
    ### Initialize random seed
    if args.seed=='random': manualSeed = np.random.randint(1, 10000)
    else: manualSeed=int(args.seed)
    logging.info("Seed:{0}".format(manualSeed))
    random.seed(manualSeed)
    np.random.seed(manualSeed)
    torch.manual_seed(manualSeed)
    torch.cuda.manual_seed_all(manualSeed)
    logging.info('Device:{0}'.format(gdict['device']))
    
    if args.deterministic: 
        logging.info("Running with deterministic sequence. Performance will be slower")
        torch.backends.cudnn.deterministic=True
#         torch.backends.cudnn.enabled = False
        torch.backends.cudnn.benchmark = False
    
    #################################
    ####### Read data and precompute ######
    # transpose(0,1,2,3,4) is the identity permutation; it only requires the array to be 5-D.
    img=np.load(gdict['ip_fname'],mmap_mode='r')[:gdict['num_imgs']].transpose(0,1,2,3,4).astype(np.float32)
#     img=f_transform(img)
    t_img=torch.from_numpy(img)
    print("%s, %s"%(img.shape,t_img.shape))

    dataset=TensorDataset(t_img)
    dataloader=DataLoader(dataset,batch_size=gdict['batchsize'],shuffle=True,num_workers=0,drop_last=True)

    # Precompute metrics with validation data for computing losses
    with torch.no_grad():
        # The last 30 images of the same file serve as the validation set.
        val_img=np.load(gdict['ip_fname'],mmap_mode='r')[-30:].transpose(0,1,2,3,4).astype(np.float32)
#         val_img=f_transform(val_img)
        t_val_img=torch.from_numpy(val_img).to(gdict['device'])

        # Precompute radial coordinates
        r,ind=f_get_rad(img)
        r=r.to(gdict['device']); ind=ind.to(gdict['device'])
        # Stored mean and variance of spectrum for the validation data once
        mean_spec_val,sdev_spec_val=f_torch_image_spectrum(f_invtransform(t_val_img),1,r,ind)
        hist_val=f_compute_hist(t_val_img,bins=gdict['bns'])
    #     del val_img; del t_val_img; del img; del t_img

    #################################
    ###### Build Networks ###
    # Define Models
    print("Building GAN networks")
    # Create Generator
    netG = Generator(gdict).to(gdict['device'])
    netG.apply(weights_init)
#     print(netG)
    # Summary input sizes follow the configuration instead of hard-coded 64s.
    summary(netG,(1,1,1,gdict['nz']))
    # Create Discriminator
    netD = Discriminator(gdict).to(gdict['device'])
    netD.apply(weights_init)
#     print(netD)
    summary(netD,(gdict['nc'],)+(gdict['image_size'],)*3)
    
    print("Number of GPUs used %s"%(gdict['ngpu']))
    if (gdict['multi-gpu']):
        netG = nn.DataParallel(netG, list(range(gdict['ngpu'])))
        netD = nn.DataParallel(netD, list(range(gdict['ngpu'])))
    
    #### Initialize networks ####
    # criterion = nn.BCELoss()
    criterion = nn.BCEWithLogitsLoss()
    
    # The optimizers are needed in both modes: a continued run restores their state into
    # these objects, so they must exist before f_load_checkpoint is called.
    optimizerD = optim.Adam(netD.parameters(), lr=gdict['learn_rate'], betas=(gdict['beta1'], 0.999),eps=1e-7)
    optimizerG = optim.Adam(netG.parameters(), lr=gdict['learn_rate'], betas=(gdict['beta1'], 0.999),eps=1e-7)

    if gdict['mode']=='fresh':
        ### Initialize variables
        iters,start_epoch,best_chi1,best_chi2=0,0,1e10,1e10    
    
    ### Load network weights for continuing run
    elif gdict['mode']=='continue':
        iters,start_epoch,best_chi1,best_chi2=f_load_checkpoint(gdict['save_dir']+'/models/checkpoint_last.tar',netG,netD,optimizerG,optimizerD,gdict) 
        logging.info("Continuing existing run. Loading checkpoint with epoch {0} and step {1}".format(start_epoch,iters))
        start_epoch+=1  ## Start with the next epoch  

    ## Add to gdict
    for key,val in zip(['best_chi1','best_chi2','iters','start_epoch'],[best_chi1,best_chi2,iters,start_epoch]): gdict[key]=val
    print(gdict)
    
    fixed_noise = torch.randn(gdict['batchsize'], 1, 1, 1, gdict['nz'], device=gdict['device']) #Latent vectors to view G progress    



{'workers': 2, 'nc': 1, 'nz': 64, 'ngf': 64, 'ndf': 64, 'beta1': 0.5, 'kernel_size': 5, 'stride': 2, 'g_padding': 2, 'd_padding': 2, 'flip_prob': 0.01, 'image_size': 64, 'checkpoint_size': 10, 'num_imgs': 2000, 'ip_fname': '/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/3d_data/train.npy', 'op_loc': '/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/pytorch/results/3d/', 'ngpu': 1, 'batchsize': 32, 'mode': 'fresh', 'spec_loss_flag': True, 'epochs': 10, 'learn_rate': 0.0002, 'lambda1': 0.1, 'save_steps_list': [5, 10], 'save_dir': '/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/pytorch/results/3d/20210108_092135_nb_test', 'bns': 50, 'device': device(type='cuda'), 'multi-gpu': False}
(2000, 1, 64, 64, 64), torch.Size([2000, 1, 64, 64, 64])
Building GAN networks
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1       [-1, 1, 1, 1, 32768]       

In [45]:
# Training cell: run the training loop, then generate images from the two
# best checkpoints written during training.
if __name__=="__main__":
    #################################
    ### Empty metrics table; rows are added per step inside f_train_loop
    cols=['step','epoch','Dreal','Dfake','Dfull','G_adv','G_full','spec_loss','hist_loss','spec_chi','hist_chi','D(x)','D_G_z1','D_G_z2','time']
    metrics_df=pd.DataFrame(columns=cols)

    #################################
    ########## Train loop and save metrics and images ######
    print("Starting Training Loop...")
    f_train_loop(dataloader,metrics_df,gdict)

    ## Generate 200 images each for the best-spectrum and best-histogram models ######
    op_loc=gdict['save_dir']+'/images/'
    for ckpt_name,prefix in (('checkpoint_best_spec.tar','best_spec'),('checkpoint_best_hist.tar','best_hist')):
        ip_fname=gdict['save_dir']+'/models/'+ckpt_name
        f_gen_images(gdict,netG,optimizerG,ip_fname,op_loc,op_strg=prefix,op_size=200)

    ## Wall-clock time since t0, which is set in the set-up cell
    tf=time.time()
    print("Total time %s"%(tf-t0))
    print('End: %s'%(datetime.now().strftime('%Y-%m-%d  %H:%M:%S')))
    

Starting Training Loop...
[0/10][0/62]	Loss_D: 1.4373	Loss_adv: 6.3713	Loss_G: 18.0367	D(x): -0.1694	D(G(z)): -0.1583 / -6.3686
Spec loss: 11.665422439575195,	 hist loss: 3.8162171840667725
Training time for step 0 : 0.710432767868042
[0/10][10/62]	Loss_D: 0.6024	Loss_adv: 11.5320	Loss_G: 22.3816	D(x): 6.7351	D(G(z)): -11.1264 / -11.5320
Spec loss: 10.849590301513672,	 hist loss: 3.8161778450012207
Training time for step 10 : 0.6442148685455322
[0/10][20/62]	Loss_D: 1.8957	Loss_adv: 8.4732	Loss_G: 19.3249	D(x): 2.6858	D(G(z)): 1.6109 / -8.4730
Spec loss: 10.851640701293945,	 hist loss: 3.80815052986145
Training time for step 20 : 0.644254207611084


KeyboardInterrupt: 

In [39]:
# metrics_df.plot('step','time')
# metrics_df
# Line plot of the histogram loss and the validation spectral chi against the training step.
plot_columns=['hist_loss','spec_chi']
metrics_df.plot(x='step',y=plot_columns,kind='line')


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<AxesSubplot:xlabel='step'>

## Testing

In [None]:
# Testing set-up cell: same configuration steps as the main run, but limited
# to 20 training images.
if __name__=="__main__":
    torch.backends.cudnn.benchmark=True
#     torch.autograd.set_detect_anomaly(True)

    t0=time.time()
    #################################
#     args=f_parse_args()
    # Manually add args ( different for jupyter notebook)
    args=argparse.Namespace()
    args.config='1_main_code/config_3d.yaml'
    args.ngpu=1
    args.batchsize=32
    args.spec_loss_flag=True
    args.checkpoint_size=50
    args.epochs=10
    args.learn_rate=0.0002
    args.mode='fresh'
#     args.mode='continue'
#     args.ip_fldr='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/pytorch/results/128sq/20201211_093818_nb_test/'
    args.run_suffix='nb_test'
    args.deterministic=False
    args.seed='234373'
    args.lambda1=0.1
    args.save_steps_list=[5,10]

    ### Set up ###
    config_file=args.config
    config_dict=f_load_config(config_file)

    # Initialize variables
    gdict={}
    f_init_gdict(gdict,config_dict)
    
    ## Add args variables to gdict
    ## (args.checkpoint_size is not in this list, so the config-file value stays in effect)
    for key in ['ngpu','batchsize','mode','spec_loss_flag','epochs','learn_rate','lambda1','save_steps_list']:
        gdict[key]=vars(args)[key]
       
    ###### Set up directories #######
    if gdict['mode']=='fresh':
        # Create prefix for foldername        
        fldr_name=datetime.now().strftime('%Y%m%d_%H%M%S') ## time format
        gdict['save_dir']=gdict['op_loc']+fldr_name+'_'+args.run_suffix
        
        if not os.path.exists(gdict['save_dir']):
            os.makedirs(gdict['save_dir']+'/models')
            os.makedirs(gdict['save_dir']+'/images')
        
    elif gdict['mode']=='continue': ## For checkpointed runs
        gdict['save_dir']=args.ip_fldr
        ### Read loss data (os.path.join works with or without a trailing slash on save_dir)
        with open (os.path.join(gdict['save_dir'],'df_metrics.pkle'),'rb') as f:
            metrics_dict=pickle.load(f) 

#     ### Write all logging.info statements to stdout and log file (different for jpt notebooks)
#     logfile=gdict['save_dir']+'/log.log'
#     logging.basicConfig(level=logging.DEBUG, filename=logfile, filemode="a+", format="%(asctime)-15s %(levelname)-8s %(message)s")
    
#     Lg = logging.getLogger()
#     Lg.setLevel(logging.DEBUG)
#     lg_handler_file = logging.FileHandler(logfile)
#     lg_handler_stdout = logging.StreamHandler(sys.stdout)
#     Lg.addHandler(lg_handler_file)
#     Lg.addHandler(lg_handler_stdout)
    
#     logging.info('Args: {0}'.format(args))
#     logging.info(config_dict)
#     logging.info('Start: %s'%(datetime.now().strftime('%Y-%m-%d  %H:%M:%S')))
#     if gdict['spec_loss_flag']: logging.info("Using Spectral loss")

    ### Override (different for jpt notebooks): small sample for quick tests
    gdict['num_imgs']=20
    
    ## Special declarations
    gdict['bns']=50
    gdict['device']=torch.device("cuda" if (torch.cuda.is_available() and gdict['ngpu'] > 0) else "cpu")
    # From here on 'ngpu' holds the number of visible GPUs, not the requested value.
    gdict['ngpu']=torch.cuda.device_count()
    
    gdict['multi-gpu']=True if (gdict['device'].type == 'cuda') and (gdict['ngpu'] > 1) else False 
    print(gdict)
    
    ### Initialize random seed
    if args.seed=='random': manualSeed = np.random.randint(1, 10000)
    else: manualSeed=int(args.seed)
    logging.info("Seed:{0}".format(manualSeed))
    random.seed(manualSeed)
    np.random.seed(manualSeed)
    torch.manual_seed(manualSeed)
    torch.cuda.manual_seed_all(manualSeed)
    logging.info('Device:{0}'.format(gdict['device']))
    
    if args.deterministic: 
        logging.info("Running with deterministic sequence. Performance will be slower")
        torch.backends.cudnn.deterministic=True
#         torch.backends.cudnn.enabled = False
        torch.backends.cudnn.benchmark = False
    




In [None]:
#################################
####### Read data and precompute ######
# Load the training images, apply f_transform, and precompute the validation
# statistics (radial grid, spectrum mean/variance, pixel histogram).
# float32 matches the default parameter dtype of the networks; a plain `float`
# cast would give float64 tensors that the float32 layers cannot consume.
# transpose(0,1,2,3,4) is the identity permutation; it only requires the array to be 5-D.
img=np.load(gdict['ip_fname'],mmap_mode='r')[:gdict['num_imgs']].transpose(0,1,2,3,4).astype(np.float32)
img=f_transform(img)
t_img=torch.from_numpy(img)
print("%s, %s"%(img.shape,t_img.shape))

dataset=TensorDataset(t_img)
dataloader=DataLoader(dataset,batch_size=gdict['batchsize'],shuffle=True,num_workers=0,drop_last=True)

# Precompute metrics with validation data for computing losses
with torch.no_grad():
    # The last 30 images of the same file serve as the validation set.
    val_img=np.load(gdict['ip_fname'],mmap_mode='r')[-30:].transpose(0,1,2,3,4).astype(np.float32)
    val_img=f_transform(val_img)
    t_val_img=torch.from_numpy(val_img).to(gdict['device'])

    # Precompute radial coordinates
    r,ind=f_get_rad(img)
    r=r.to(gdict['device']); ind=ind.to(gdict['device'])
    # Stored mean and variance of spectrum for the validation data once
    # (computed on the inverse-transformed images)
    mean_spec_val,sdev_spec_val=f_torch_image_spectrum(f_invtransform(t_val_img),1,r,ind)
    hist_val=f_compute_hist(t_val_img,bins=gdict['bns'])
#     del val_img; del t_val_img; del img; del t_img

In [None]:
# Value range of the (transformed) validation images: (max, min)
(t_val_img.max(), t_val_img.min())

In [None]:
def weights_init(m):
    """Re-initialise a single module in place, chosen by its class name.

    Intended to be passed to ``nn.Module.apply``.

    * Class name contains ``'Conv'``: weights drawn from a normal
      distribution with mean 0.0 and std 0.02.
    * Otherwise, class name contains ``'BatchNorm'``: weights drawn from a
      normal distribution with mean 1.0 and std 0.02, bias set to 0.
    * Any other module is left untouched.
    """
    layer_kind = type(m).__name__

    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return

    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

class Generator(nn.Module):
    """Generator: a Linear projection followed by four ConvTranspose3d stages.

    ``gdict`` must contain the keys 'ngpu', 'nz', 'nc', 'ngf', 'kernel_size',
    'stride' and 'g_padding' ('ngpu' is looked up but not used by this class).
    The last activation is Tanh, so outputs lie in [-1, 1].

    NOTE(review): the Linear layer emits ``nc*ngf*8**3`` features while the
    View reshapes to ``ngf*8*4*4*4`` values per sample, and BatchNorm3d(nc)
    sits directly after the Linear layer -- this looks like it assumes
    ``nc == 1``; confirm before using multi-channel data.
    NOTE(review): PyTorch's BatchNorm ``momentum`` is the weight given to the
    *new* batch statistic, so 0.9 makes the running stats track the latest
    batch closely -- confirm this is intended.
    """

    def __init__(self, gdict):
        super().__init__()

        ## Define new variables from dict
        wanted = ('ngpu', 'nz', 'nc', 'ngf', 'kernel_size', 'stride', 'g_padding')
        _ngpu, nz, nc, ngf, kernel_size, stride, g_padding = [gdict[name] for name in wanted]

        # Projection of the latent vector, reshaped to (ngf*8) channels on a
        # 4 x 4 x 4 grid.
        stack = [
            nn.Linear(nz, nc * ngf * 8 ** 3),
            nn.BatchNorm3d(nc, eps=1e-05, momentum=0.9, affine=True),
            nn.ReLU(inplace=True),
            View(shape=[-1, ngf * 8, 4, 4, 4]),
        ]

        # Three upsampling stages that halve the channel count each time.
        # Presumably each stage doubles the spatial size (4 -> 8 -> 16 -> 32);
        # that depends on kernel_size/stride/g_padding -- verify against config.
        for c_in, c_out in ((ngf * 8, ngf * 4), (ngf * 4, ngf * 2), (ngf * 2, ngf)):
            stack += [
                nn.ConvTranspose3d(c_in, c_out, kernel_size, stride, g_padding, output_padding=1, bias=False),
                nn.BatchNorm3d(c_out, eps=1e-05, momentum=0.9, affine=True),
                nn.ReLU(inplace=True),
            ]

        # Final stage down to nc output channels, squashed by Tanh.
        stack += [
            nn.ConvTranspose3d(ngf, nc, kernel_size, stride, g_padding, output_padding=1, bias=False),
            nn.Tanh(),
        ]

        self.main = nn.Sequential(*stack)

    def forward(self, ip):
        """Run ``ip`` through the sequential stack and return the result."""
        return self.main(ip)

class Discriminator(nn.Module):
    """Discriminator: four strided Conv3d stages followed by a Linear head.

    ``gdict`` must contain the keys 'ngpu', 'nz', 'nc', 'ndf', 'kernel_size',
    'stride' and 'd_padding' ('ngpu' and 'nz' are looked up but not used by
    this class).
    The network ends in a Linear layer with one output and no Sigmoid, so it
    returns raw scores.

    NOTE(review): the Linear layer expects ``nc*ndf*8*8*8`` input features;
    with ``nc == 1`` that equals ``ndf*8`` channels on a 4 x 4 x 4 grid,
    presumably reached from 64^3 inputs -- confirm against the config.
    NOTE(review): PyTorch's BatchNorm ``momentum`` is the weight given to the
    *new* batch statistic, so 0.9 makes the running stats track the latest
    batch closely -- confirm this is intended.
    """

    def __init__(self, gdict):
        super().__init__()

        ## Define new variables from dict
        wanted = ('ngpu', 'nz', 'nc', 'ndf', 'kernel_size', 'stride', 'd_padding')
        _ngpu, _nz, nc, ndf, kernel_size, stride, d_padding = [gdict[name] for name in wanted]

        # Channel widths at the boundaries of the four convolution stages.
        widths = [nc, ndf, ndf * 2, ndf * 4, ndf * 8]

        stack = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            # Conv3d positional args: in_channels, out_channels, kernel_size,
            # stride, padding.
            stack.append(nn.Conv3d(c_in, c_out, kernel_size, stride, d_padding, bias=True))
            stack.append(nn.BatchNorm3d(c_out, eps=1e-05, momentum=0.9, affine=True))
            stack.append(nn.LeakyReLU(0.2, inplace=True))

        # Flatten everything but the batch dimension and map to one score.
        stack.append(nn.Flatten())
        stack.append(nn.Linear(nc * ndf * 8 * 8 * 8, 1))

        self.main = nn.Sequential(*stack)

    def forward(self, ip):
        """Run ``ip`` through the sequential stack and return the result."""
        return self.main(ip)



In [None]:
# Build both networks on the configured device, apply the custom weight
# initialisation and print a per-layer summary of each.
netD = Discriminator(gdict).to(gdict['device'])
netD.apply(weights_init)
# Summary input excludes the batch dimension: nc channels of a 64^3 volume.
# The channel count is read from the config so it always matches the first
# Conv3d layer of the discriminator.
summary(netD, (gdict['nc'], 64, 64, 64))

netG = Generator(gdict).to(gdict['device'])
netG.apply(weights_init)
#print(netG)
# The latent size is read from the config instead of being hard-coded to 64,
# so the summary still works when nz != 64.
summary(netG, (1, 1, 1, gdict['nz']))

In [None]:
# Fresh generator on the configured device (weights_init is not applied here,
# so the layers keep their default initialisation).
netG = Generator(gdict).to(gdict['device'])
# Latent vectors to view G progress
noise = torch.randn(
    (gdict['batchsize'], 1, 1, 1, gdict['nz']),
    device=gdict['device'],
)
netG(noise).size()

In [None]:
# Feed a generated batch through the discriminator and report both shapes.
fake = netG(noise)
print(fake.size())
output = netD(fake)
print(output.size())
