# Testing cosmogan
Aug 25, 2020

Borrowing pieces of code from : 

- https://github.com/pytorch/tutorials/blob/11569e0db3599ac214b03e01956c2971b02c64ce/beginner_source/dcgan_faces_tutorial.py
- https://github.com/exalearn/epiCorvid/tree/master/cGAN

In [1]:
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torchsummary import summary
# import torchvision.datasets as dset
# import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

from torch.utils.data import DataLoader, TensorDataset

import time
from datetime import datetime
import glob
import pickle
import yaml
import logging

In [2]:
%matplotlib widget

## Modules

In [3]:
def f_load_config(config_file):
    with open(config_file) as f:
        config = yaml.load(f, Loader=yaml.SafeLoader)
    return config



### Transformation functions for image pixel values
def f_transform(x):
    return 2.*x/(x + 4.) - 1.

def f_invtransform(s):
    return 4.*(1. + s)/(1. - s)



In [4]:
# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)



# Generator Code
class View(nn.Module):
    def __init__(self, shape):
        super(View, self).__init__()
        self.shape = shape

    def forward(self, x):
        return x.view(*self.shape)



class Generator(nn.Module):
    def __init__(self, ngpu,nz,nc,ngf,kernel_size,stride,g_padding):
        super(Generator, self).__init__()
        self.ngpu = ngpu
#         self.nz,self.nc,self.ngf=nz,nc,ngf
#         self.kernel_size,self.g_padding=kernel_size,g_padding

        self.main = nn.Sequential(
            # nn.ConvTranspose2d(in_channels, out_channels, kernel_size,stride,padding,output_padding,groups,bias, Dilation,padding_mode)
            nn.Linear(nz,nc*ngf*8*8*8),# 32768
            nn.BatchNorm2d(nc,eps=1e-05, momentum=0.9, affine=True),
            nn.ReLU(inplace=True),
            View(shape=[-1,ngf*8,8,8]),
            nn.ConvTranspose2d(ngf * 8, ngf * 4, kernel_size, stride, g_padding, output_padding=1, bias=False),
            nn.BatchNorm2d(ngf*4,eps=1e-05, momentum=0.9, affine=True),
            nn.ReLU(inplace=True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d( ngf * 4, ngf * 2, kernel_size, stride, g_padding, 1, bias=False),
            nn.BatchNorm2d(ngf*2,eps=1e-05, momentum=0.9, affine=True),
            nn.ReLU(inplace=True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d( ngf * 2, ngf, kernel_size, stride, g_padding, 1, bias=False),
            nn.BatchNorm2d(ngf,eps=1e-05, momentum=0.9, affine=True),
            nn.ReLU(inplace=True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d( ngf, nc, kernel_size, stride,g_padding, 1, bias=False),
            nn.Tanh()
        )
    
    def forward(self, input):
        return self.main(input)


class Discriminator(nn.Module):
    def __init__(self, ngpu, nz,nc,ndf,kernel_size,stride,d_padding):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            # nn.Conv2d(in_channels, out_channels, kernel_size,stride,padding,output_padding,groups,bias, Dilation,padding_mode)
            nn.Conv2d(nc, ndf,kernel_size, stride, d_padding,  bias=True),
            nn.BatchNorm2d(ndf,eps=1e-05, momentum=0.9, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, kernel_size, stride, d_padding, bias=True),
            nn.BatchNorm2d(ndf * 2,eps=1e-05, momentum=0.9, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, kernel_size, stride, d_padding, bias=True),
            nn.BatchNorm2d(ndf * 4,eps=1e-05, momentum=0.9, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, kernel_size, stride, d_padding, bias=True),
            nn.BatchNorm2d(ndf * 8,eps=1e-05, momentum=0.9, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Flatten(),
            nn.Linear(nc*ndf*8*8*8, 1)
        )

    def forward(self, input):
        return self.main(input)



In [5]:
def f_gen_images(netG,optimizerG,nz,device,ip_fname,strg,save_dir,op_size=500):
    '''Generate images for best saved models
     Arguments: ip_fname: name of input file
                strg: ['hist' or 'spec']
                op_size: Number of images to generate
    '''

    try:
        checkpoint=torch.load(ip_fname)
    except Exception as e:
        print(e)
        print("skipping generation of images for ",ip_fname)
        return
    
    netG.load_state_dict(checkpoint['G_state'])
#    netD.load_state_dict(checkpoint['D_state'])

    ## Load other stuff
    iters=checkpoint['iters']
    epoch=checkpoint['epoch']
    optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
    
    # Generate batch of latent vectors
    noise = torch.randn(op_size, 1, 1, nz, device=device)
    # Generate fake image batch with G
    netG.eval() ## This is required before running inference
    gen = netG(noise)
    gen_images=gen.detach().cpu().numpy()[:,0,:,:]
    print(gen_images.shape)

    op_fname='best-%s_gen_img_epoch-%s_step-%s.npy'%(strg,epoch,iters)

    np.save(save_dir+'/images/'+op_fname,gen_images)

    print("Image saved in ",op_fname)
    
    
def f_save_checkpoint(epoch,iters,best_chi1,best_chi2,netG,netD,optimizerG,optimizerD,save_loc):
    ''' Checkpoint model '''
    torch.save({'epoch':epoch,'iters':iters,'best_chi1':best_chi1,'best_chi2':best_chi2,
                'G_state':netG.state_dict(),'D_state':netD.state_dict(),'optimizerG_state_dict':optimizerG.state_dict(),
                'optimizerD_state_dict':optimizerD.state_dict()}, save_loc) 
    
    
def f_load_checkpoint(ip_fname,netG,netD,optimizerG,optimizerD):
    ''' Load saved checkpoint
    Also loads step, epoch, best_chi1, best_chi2'''
    
    try:
        checkpoint=torch.load(ip_fname)
    except Exception as e:
        print(e)
        print("skipping generation of images for ",ip_fname)
        raise SystemError
    
    ## Load checkpoint
    netG.load_state_dict(checkpoint['G_state'])
    netD.load_state_dict(checkpoint['D_state'])

    optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
    optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
    
    iters=checkpoint['iters']
    epoch=checkpoint['epoch']
    best_chi1=checkpoint['best_chi1']
    best_chi2=checkpoint['best_chi2']

    netG.train()
    netD.train()
    
    return iters,epoch,best_chi1,best_chi2

In [39]:
## Spectrum and histogram codes
### Pytorch code ###
####################

def f_torch_radial_profile(img, center=(None,None)):
    ''' Module to compute radial profile of a 2D image 
    Bincount causes issues with backprop, so not using this code
    '''
    
    y,x=torch.meshgrid(torch.arange(0,img.shape[0]),torch.arange(0,img.shape[1])) # Get a grid of x and y values
    if center[0]==None and center[1]==None:
        center = torch.Tensor([(x.max()-x.min())/2.0, (y.max()-y.min())/2.0]) # compute centers

    # get radial values of every pair of points
    r = torch.sqrt((x - center[0])**2 + (y - center[1])**2)
    r= r.int()
    
#     print(r.shape,img.shape)
    # Compute histogram of r values
    tbin=torch.bincount(torch.reshape(r,(-1,)),weights=torch.reshape(img,(-1,)).type(torch.DoubleTensor))
    nr = torch.bincount(torch.reshape(r,(-1,)))
    radialprofile = tbin / nr
    
    return radialprofile[1:-1]


def f_torch_get_azimuthalAverage_with_batch(image, center=None): ### Not used in this code.
    """
    Calculate the azimuthally averaged radial profile. Only use if you need to combine batches

    image - The 2D image
    center - The [x,y] pixel coordinates used as the center. The default is 
             None, which then uses the center of the image (including 
             fracitonal pixels).
    source: https://www.astrobetter.com/blog/2010/03/03/fourier-transforms-of-images-in-python/
    """
    
    batch, channel, height, width = image.shape
    # Create a grid of points with x and y coordinates
    y, x = np.indices([height,width])

    if not center:
        center = np.array([(x.max()-x.min())/2.0, (y.max()-y.min())/2.0])

    # Get the radial coordinate for every grid point. Array has the shape of image
    r = torch.tensor(np.hypot(x - center[0], y - center[1]))

    # Get sorted radii
    ind = torch.argsort(torch.reshape(r, (batch, channel,-1)))
    r_sorted = torch.gather(torch.reshape(r, (batch, channel, -1,)),2, ind)
    i_sorted = torch.gather(torch.reshape(image, (batch, channel, -1,)),2, ind)

    # Get the integer part of the radii (bin size = 1)
    r_int=r_sorted.to(torch.int32)

    # Find all pixels that fall within each radial bin.
    deltar = r_int[:,:,1:] - r_int[:,:,:-1]  # Assumes all radii represented
    rind = torch.reshape(torch.where(deltar)[2], (batch, -1))    # location of changes in radius
    rind=torch.unsqueeze(rind,1)
    nr = (rind[:,:,1:] - rind[:,:,:-1]).type(torch.float)       # number of radius bin

    # Cumulative sum to figure out sums for each radius bin

    csum = torch.cumsum(i_sorted, axis=-1)
    print(csum.shape,rind.shape,nr.shape)

    tbin = torch.gather(csum, 2, rind[:,:,1:]) - torch.gather(csum, 2, rind[:,:,:-1])
    radial_prof = tbin / nr

    return radial_prof


def f_torch_get_azimuthalAverage(image, center=None):
    """
    Calculate the azimuthally averaged radial profile.

    image - The 2D image
    center - The [x,y] pixel coordinates used as the center. The default is 
             None, which then uses the center of the image (including 
             fracitonal pixels).
    source: https://www.astrobetter.com/blog/2010/03/03/fourier-transforms-of-images-in-python/
    """
    
    height, width = image.shape
    # Create a grid of points with x and y coordinates
    y, x = np.indices([height,width])

    if not center:
        center = np.array([(x.max()-x.min())/2.0, (y.max()-y.min())/2.0])

    # Get the radial coordinate for every grid point. Array has the shape of image
    r = torch.tensor(np.hypot(x - center[0], y - center[1]))

    # Get sorted radii
    ind = torch.argsort(torch.reshape(r, (-1,)))
    print(type(ind),ind.get_device())
    r_sorted = torch.gather(torch.reshape(r, ( -1,)),0, ind)
    i_sorted = torch.gather(torch.reshape(image, ( -1,)),0, ind)

    
    # Get the integer part of the radii (bin size = 1)
    r_int=r_sorted.to(torch.int32)

    # Find all pixels that fall within each radial bin.
    deltar = r_int[1:] - r_int[:-1]  # Assumes all radii represented
    rind = torch.reshape(torch.where(deltar)[0], (-1,))    # location of changes in radius
    nr = (rind[1:] - rind[:-1]).type(torch.float)       # number of radius bin

    # Cumulative sum to figure out sums for each radius bin
    
    csum = torch.cumsum(i_sorted, axis=-1)
    tbin = torch.gather(csum, 0, rind[1:]) - torch.gather(csum, 0, rind[:-1])
    radial_prof = tbin / nr

    return radial_prof

def f_torch_fftshift(real, imag):
    for dim in range(0, len(real.size())):
        real = torch.roll(real, dims=dim, shifts=real.size(dim)//2)
        imag = torch.roll(imag, dims=dim, shifts=imag.size(dim)//2)
    return real, imag

def f_torch_compute_spectrum(arr):
    
    GLOBAL_MEAN=1.0
    arr=(arr-GLOBAL_MEAN)/(GLOBAL_MEAN)
    y1=torch.rfft(arr,signal_ndim=2,onesided=False)
    real,imag=f_torch_fftshift(y1[:,:,0],y1[:,:,1])    ## last index is real/imag part
    y2=real**2+imag**2     ## Absolute value of each complex number
    
#     print(y2.shape)
    z1=f_torch_get_azimuthalAverage(y2)     ## Compute radial profile
#     z1=f_torch_radial_profile(y2)     ## Compute radial profile
    
    return(z1)

def f_torch_compute_batch_spectrum(arr):
    
    batch_pk=torch.stack([f_torch_compute_spectrum(i) for i in arr])
    
    return batch_pk

def f_torch_image_spectrum(x,num_channels):
    '''
    Data has to be in the form (batch,channel,x,y)
    '''
    mean=[[] for i in range(num_channels)]    
    sdev=[[] for i in range(num_channels)]    

    for i in range(num_channels):
        arr=x[:,i,:,:]
#         print(i,arr.shape)
        batch_pk=f_torch_compute_batch_spectrum(arr)
#         print(batch_pk.shape)
        mean[i]=torch.mean(batch_pk,axis=0)
        sdev[i]=torch.std(batch_pk,axis=0)/np.sqrt(batch_pk.shape[0])
        
    mean=torch.stack(mean)
    sdev=torch.stack(sdev)
    return mean,sdev

def f_compute_hist(data,bins):
    hist_data=torch.histc(data,bins=bins)
    ## A kind of normalization of histograms: divide by total sum
    hist_data=(hist_data*bins)/torch.sum(hist_data)

    return hist_data

### Losses 
def loss_spectrum(spec_mean,spec_mean_ref,spec_std,spec_std_ref,image_size):
    ''' Loss function for the spectrum : mean + variance '''
    
    # Log ( sum( batch value - expect value) ^ 2 ))
    
    idx=int(image_size/2) ### For the spectrum, use only N/2 indices for loss calc.
    
    ### Warning: the first index is the channel number.For multiple channels, you are averaging over them, which is fine.
    spec_mean=torch.log(torch.mean(torch.pow(spec_mean[:,:idx]-spec_mean_ref[:,:idx],2)))
    spec_sdev=torch.log(torch.mean(torch.pow(spec_std[:,:idx]-spec_std_ref[:,:idx],2)))

     
    lambda1=1.0;lambda2=1.0;
    ans=lambda1*spec_mean+lambda2*spec_sdev
    return ans

def loss_hist(hist_sample,hist_ref):
    
    lambda1=1.0
    return torch.log(torch.mean(torch.pow(hist_sample-hist_ref,2)))
#     return lambda1*torch.mean(torch.pow(hist_sample-hist_ref,2)).item()


## Main code

In [7]:
config_file='1_main_code/config_128.yaml'
config_dict=f_load_config(config_file)
print(config_dict)

{'description': 'GAN', 'data': {'ip_fname': '/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/128_square/dataset_2_smoothing_200k/norm_1_train_val.npy', 'op_loc': '/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/pytorch/results/128sq/'}, 'training': {'workers': 2, 'nc': 1, 'nz': 64, 'ngf': 64, 'ndf': 64, 'lr': 0.0002, 'beta1': 0.5, 'kernel_size': 5, 'stride': 2, 'g_padding': 2, 'd_padding': 2, 'image_size': 128, 'flip_prob': 0.01}}


In [8]:

workers=config_dict['training']['workers']
nc=config_dict['training']['nc']
nc,nz,ngf,ndf=config_dict['training']['nc'],config_dict['training']['nz'],config_dict['training']['ngf'],config_dict['training']['ndf']
lr,beta1=config_dict['training']['lr'],config_dict['training']['beta1']
kernel_size,stride=config_dict['training']['kernel_size'],config_dict['training']['stride']
g_padding,d_padding=config_dict['training']['g_padding'],config_dict['training']['d_padding']
image_size=config_dict['training']['image_size']
ip_fname=config_dict['data']['ip_fname']
op_loc=config_dict['data']['op_loc']
flip_prob=config_dict['training']['flip_prob']


ngpu=1
batch_size=128
spec_loss_flag=True

In [9]:
manualSeed=21245
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)

device = torch.device("cuda" if (torch.cuda.is_available() and ngpu > 0) else "cpu")
print(device)

Random Seed:  21245
cpu


In [26]:
# ip_fname='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/128_square/dataset_2_smoothing_200k/norm_1_train_val.npy'
img=np.load(ip_fname)[:2000].transpose(0,1,2,3)
t_img=torch.from_numpy(img)
print(img.shape,t_img.shape)
del(img)

(2000, 1, 128, 128) torch.Size([2000, 1, 128, 128])


In [27]:
dataset=TensorDataset(t_img)
dataloader=DataLoader(dataset,batch_size=batch_size,shuffle=True,num_workers=1,drop_last=True)

len(dataset),(dataset[0][0]).shape

(2000, torch.Size([1, 128, 128]))

In [28]:
lr=0.0002

In [29]:
### Build Models ###
# Create generator
netG = Generator(ngpu,nz,nc,ngf,kernel_size,stride,g_padding).to(device)
netG.apply(weights_init)
print(netG)
summary(netG,(1,1,64))

# Create Discriminator
netD = Discriminator(ngpu, nz,nc,ndf,kernel_size,stride,g_padding).to(device)
netD.apply(weights_init)
print(netD)
summary(netD,(1,128,128))

# Handle multi-gpu if desired
ngpu=torch.cuda.device_count()

print("Number of GPUs used",ngpu)
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))
    netD = nn.DataParallel(netD, list(range(ngpu)))


Generator(
  (main): Sequential(
    (0): Linear(in_features=64, out_features=32768, bias=True)
    (1): BatchNorm2d(1, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): View()
    (4): ConvTranspose2d(512, 256, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1), bias=False)
    (5): BatchNorm2d(256, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
    (6): ReLU(inplace=True)
    (7): ConvTranspose2d(256, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1), bias=False)
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): ConvTranspose2d(128, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1), bias=False)
    (11): BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    (13): ConvTranspose2d(64, 1, kernel_size=(5, 5), s

In [30]:
# def f_size(ip):
#     p=2;s=2
# #     return (ip + 2 * 0 - 1 * (p-1) -1 )/ s + 1

#     return (ip-1)*s - 2 * p + 1 *(5-1)+ 1 + 1

# f_size(128)

In [31]:
# Initialize BCELoss function
# criterion = nn.BCELoss()
criterion = nn.BCEWithLogitsLoss()

# Create batch of latent vectors that we will use to visualize the progression of the generator
fixed_noise = torch.randn(batch_size, 1, 1, nz, device=device)

# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999),eps=1e-7)
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999),eps=1e-7)


In [40]:
### Precompute metrics with validation data for computing losses
val_img=f_invtransform(np.load(ip_fname)[210000:213000]).transpose(0,1,2,3)
t_val_img=torch.from_numpy(val_img)
## Stored mean and std of spectrum for full input data once
mean_spec_val,sdev_spec_val=f_torch_image_spectrum(t_val_img,1)
bns=50
hist_val=f_compute_hist(t_val_img,bins=bns)
del(val_img)
del(t_val_img)


In [42]:
mean_spec_val.shape

torch.Size([1, 88])

In [34]:
run_suffix='_nb_test'
### Create prefix for foldername 
now=datetime.now()
fldr_name=now.strftime('%Y%m%d_%H%M%S') ## time format
# print(fldr_name)
save_dir=op_loc+fldr_name+run_suffix

if not os.path.exists(save_dir):
    os.makedirs(save_dir+'/models')
    os.makedirs(save_dir+'/images')

num_epochs=4


In [35]:
logging.basicConfig(filename=save_dir+'/log.log',filemode='w',format='%(name)s - %(levelname)s - %(message)s')

In [47]:
### Initialize variables
keys=['Dreal','Dfake','Dfull','G_adv','G_full','spec_loss','hist_loss','spec_chi','hist_chi']
size=int(len(dataloader) * num_epochs)+1
metric_dict=dict(zip(keys,[np.empty(size)*np.nan for i in range(len(keys))]))

iters = 0; start_epoch=0
best_chi1,best_chi2=1e10,1e10

In [37]:
torch.backends.cudnn.benchmark=True

In [48]:
t0=time.time()
print("Starting Training Loop...")

for epoch in range(start_epoch,num_epochs):
    t_epoch_start=time.time()
    for count, data in enumerate(dataloader, 0):
        netG.train(); netD.train();  ### Need to add these after inference and before training

        tme1=time.time()
        ### Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        netD.zero_grad()

        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        real_label = torch.full((b_size,), 1, device=device)
        fake_label = torch.full((b_size,), 0, device=device)
        g_label = torch.full((b_size,), 1, device=device) ## No flipping for Generator labels

        ## Flip labels with probability flip_prob
        for idx in np.random.choice(np.arange(b_size),size=int(np.ceil(b_size*flip_prob))):
            real_label[idx]=0; fake_label[idx]=1

        # Generate fake image batch with G
        noise = torch.randn(b_size, 1, 1, nz, device=device)
        fake = netG(noise)            

        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        errD_real = criterion(output, real_label)
        errD_real.backward()
        D_x = output.mean().item()

        # Forward pass real batch through D
        output = netD(fake.detach()).view(-1)
        errD_fake = criterion(output, fake_label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()

        ###Update G network: maximize log(D(G(z)))
        netG.zero_grad()

        output = netD(fake).view(-1)
        errG_adv = criterion(output, g_label)

        # Histogram pixel intensity loss
        hist_gen=f_compute_hist(f_invtransform(fake),bins=bns)
        hist_loss=loss_hist(hist_gen,hist_val.to(device))

        # Add spectral loss
        mean,sdev=f_torch_image_spectrum(f_invtransform(fake),1)  ### compute spectral mean,std for fake images for batch
        spec_loss=loss_spectrum(mean,mean_spec_val,sdev,sdev_spec_val,image_size)   
        print(type(spec_loss),type(errG_adv))
#             errG=errG_adv
        errG=errG_adv+spec_loss.detach()
#         errG=errG_adv+torch.Tensor(np.array([spec_loss]))
#             if spec_loss_flag: errG+=spec_loss
#                 errG+=hist_loss

        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        tme2=time.time()
        # Output training stats
        if count % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_adv: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, count, len(dataloader), errD.item(), errG_adv.item(),errG.item(), D_x, D_G_z1, D_G_z2)),
            print("Spec loss: %s,\t hist loss: %s"%(spec_loss,hist_loss)),
            print("Time taken for step %s : %s"%(iters, tme2-tme1))

        # Save metrics

        for key,val in zip(['Dreal','Dfake','Dfull','G_adv','G_full','spec_loss','hist_loss'],[errD_real.item(),errD_fake.item(),errD.item(),errG_adv.item(),errG.item(),spec_loss,hist_loss]):
            metric_dict[key][iters]=val

        ### Checkpoint the best model
        checkpoint=True

        iters += 1  ### Model has been updated, so update iters before saving metrics and model.

        ### Compute validation metrics for updated model
        netG.eval()
        with torch.no_grad():
            #fake = netG(fixed_noise).detach().cpu()
            fake = netG(fixed_noise)
#                 print('size of fake image array',fake.shape)

            hist_gen=f_compute_hist(f_invtransform(fake),bins=bns)
#                 hist_chi=loss_hist(hist_gen,hist_val)
            hist_chi=loss_hist(hist_gen,hist_val.to(device))
            mean,sdev=f_torch_image_spectrum(f_invtransform(fake),1)
            spec_chi=loss_spectrum(mean,mean_spec_val,sdev,sdev_spec_val,image_size)            
# 
#             hist_chi=0.1
        for key,val in zip(['spec_chi','hist_chi'],[spec_chi,hist_chi]):  metric_dict[key][iters]=val            
        if count == len(dataloader)-1: ## Check point at last step of epoch
            # Checkpoint model for continuing run
            f_save_checkpoint(epoch,iters,best_chi1,best_chi2,netG,netD,optimizerG,optimizerD,save_loc=save_dir+'/models/checkpoint_last.tar')  

        if (checkpoint and (epoch > 1)):
            # Choose best models by metric
            if hist_chi< best_chi1:
                f_save_checkpoint(epoch,iters,best_chi1,best_chi2,netG,netD,optimizerG,optimizerD,save_loc=save_dir+'/models/checkpoint_best_hist.tar')
                best_chi1=hist_chi
                print("Saving best hist model at epoch %s, step %s."%(epoch,iters))

            if  spec_chi< best_chi2:
                f_save_checkpoint(epoch,iters,best_chi1,best_chi2,netG,netD,optimizerG,optimizerD,save_loc=save_dir+'/models/checkpoint_best_spec.tar')
                best_chi2=spec_chi
                print("Saving best spec model at epoch %s, step %s"%(epoch,iters))

        # Save G's output on fixed_noise

        if (iters % 50 == 0) or ((epoch == num_epochs-1) and (count == len(dataloader)-1)):
            netG.eval()
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
                img_arr=np.array(fake[:,0,:,:])
                fname='gen_img_epoch-%s_step-%s'%(epoch,iters)
                np.save(save_dir+'/images/'+fname,img_arr)


    t_epoch_end=time.time()
    print("Time taken for epoch %s: %s"%(epoch,t_epoch_end-t_epoch_start))

tf=time.time()
print("Total time",tf-t0)

### Save Losses to files
with open (save_dir+'/metrics.pickle', 'wb') as f:
    pickle.dump(metric_dict,f)

# ### Generate images for best saved models    
# model_fname=save_dir+'/models/checkpoint_best_spec.tar'
# f_gen_images(netG,optimizerG,nz,device,model_fname,'spec',save_dir,2000)

# model_fname=save_dir+'/models/checkpoint_best_hist.tar'
# f_gen_images(netG,optimizerG,nz,device,model_fname,'hist',save_dir,2000)  




Starting Training Loop...
<class 'torch.Tensor'> <class 'torch.Tensor'>
[0/4][0/15]	Loss_D: 0.1732	Loss_adv: 4.8921	Loss_G: 53.6308	D(x): 5.1619	D(G(z)): -5.4940 / -4.8835
Spec loss: tensor(48.7386, dtype=torch.float64, grad_fn=<AddBackward0>),	 hist loss: tensor(3.9308, grad_fn=<LogBackward>)
Time taken for step 0 : 26.673779249191284
<class 'torch.Tensor'> <class 'torch.Tensor'>
<class 'torch.Tensor'> <class 'torch.Tensor'>
<class 'torch.Tensor'> <class 'torch.Tensor'>
<class 'torch.Tensor'> <class 'torch.Tensor'>
<class 'torch.Tensor'> <class 'torch.Tensor'>
<class 'torch.Tensor'> <class 'torch.Tensor'>
<class 'torch.Tensor'> <class 'torch.Tensor'>
<class 'torch.Tensor'> <class 'torch.Tensor'>
<class 'torch.Tensor'> <class 'torch.Tensor'>
<class 'torch.Tensor'> <class 'torch.Tensor'>
<class 'torch.Tensor'> <class 'torch.Tensor'>
<class 'torch.Tensor'> <class 'torch.Tensor'>
<class 'torch.Tensor'> <class 'torch.Tensor'>
<class 'torch.Tensor'> <class 'torch.Tensor'>
Time taken for epo

In [None]:
best_chi1,best_chi2,hist_chi, spec_chi

In [None]:
mean_spec_val.is_cuda,mean.is_cuda,sdev.is_cuda

In [None]:
type(hist_gen),type(hist_gen_loss),type(hist_val)

hist_gen.is_cuda,hist_val.is_cuda, hist_gen_loss.is_cuda


In [None]:
hist_gen.is_cuda 

In [None]:

hist_gen_loss-hist_val.to(device)
hist_gen_loss-hist_val.to(device)

In [None]:
# ! jupyter nbconvert --to script cosmogan_test.ipynb

In [None]:
print(metric_dict.keys())
metric_dict['spec_chi']
# metric_dict['hist_chi']

In [None]:
len(dataloader)

In [None]:
### Test the loss functions on results and keras results

In [None]:
0.17*3125