# Testing cosmogan
Aug 25, 2020

Parts of the code are borrowed from:

- https://github.com/pytorch/tutorials/blob/11569e0db3599ac214b03e01956c2971b02c64ce/beginner_source/dcgan_faces_tutorial.py
- https://github.com/exalearn/epiCorvid/tree/master/cGAN

In [1]:
import os
import random
import logging
import sys
import argparse
import time
from datetime import datetime
import glob
import pickle
import collections
import socket
import shutil
import subprocess

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torchsummary import summary
from torch.utils.data import DataLoader, TensorDataset
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
# import torch.fft

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
import yaml

# # Import modules from other files
# from utils import *
# from spec_loss import *

In [2]:
%matplotlib widget

## Modules

In [3]:
# Mod for 3D
def _f_read_keys(gdict,keys):
    '''Return [gdict[key] for key in keys] (raises KeyError if a key is missing).'''
    return [gdict[key] for key in keys]


def _f_disc_features(main,ip,labels=None,cond_layer_cls=None):
    '''Run the layers of the nn.Sequential ``main`` in order on ``ip``.

    Returns a list with the output of every layer whose class name starts with
    'Conv', followed by the final output. f_FM_loss uses all but the last
    element; the last element is the discriminator score.
    Layers whose class is ``cond_layer_cls`` are called as layer(x, labels.float()).
    '''
    feats=[]
    x=ip
    for layer in main.children():
        if cond_layer_cls is not None and layer.__class__ is cond_layer_cls:
            x=layer(x,labels.float())
        else:
            x=layer(x)
        ## Keep Conv outputs for the feature-matching loss
        if layer.__class__.__name__.startswith('Conv'):
            feats.append(x)
    feats.append(x)
    return feats


def _f_build_model_2(gdict):
    '''Model 2: the label is concatenated to the Generator noise and, after a
    learned Linear(4,4) map, appended to the Discriminator input as an extra channel.

    Supports image_size 64 and 128. The "state size" comments assume every
    (transposed) conv layer doubles/halves the spatial size, as the configured
    kernel_size/stride/padding are expected to do.
    NOTE(review): BatchNorm3d(nc) after the first Linear treats dim 1 of a
    (batch,1,1,1,features) tensor as the channel, so this presumably needs nc == 1.
    '''
    img_size=gdict['image_size']
    if img_size not in (64,128):
        raise ValueError("Model 2 supports image_size 64 or 128, got %s"%(img_size))
    ## Side of the smallest 3D feature map: 4 for 64, 8 for 128
    base=img_size//16

    class Generator(nn.Module):
        def __init__(self, gdict):
            super(Generator, self).__init__()

            ## Define new variables from dict ('ngpu' is read only to require the key)
            _ngpu,nz,nc,ngf,kernel_size,stride,g_padding=_f_read_keys(gdict,['ngpu','nz','nc','ngf','kernel_size','stride','g_padding'])

            self.main = nn.Sequential(
                # nn.ConvTranspose3d(in_channels, out_channels, kernel_size,stride,padding,output_padding,groups,bias, Dilation,padding_mode)
                nn.Linear(nz+1,nc*ngf*8*base**3),
                nn.BatchNorm3d(nc,eps=1e-05, momentum=0.9, affine=True),
                nn.ReLU(inplace=True),
                View(shape=[-1,ngf*8,base,base,base]),
                # state size. (ngf*8) x base^3
                nn.ConvTranspose3d(ngf * 8, ngf * 4, kernel_size, stride, g_padding, output_padding=1, bias=False),
                nn.BatchNorm3d(ngf*4,eps=1e-05, momentum=0.9, affine=True),
                nn.ReLU(inplace=True),
                # state size. (ngf*4) x (2*base)^3
                nn.ConvTranspose3d( ngf * 4, ngf * 2, kernel_size, stride, g_padding, 1, bias=False),
                nn.BatchNorm3d(ngf*2,eps=1e-05, momentum=0.9, affine=True),
                nn.ReLU(inplace=True),
                # state size. (ngf*2) x (4*base)^3
                nn.ConvTranspose3d( ngf * 2, ngf, kernel_size, stride, g_padding, 1, bias=False),
                nn.BatchNorm3d(ngf,eps=1e-05, momentum=0.9, affine=True),
                nn.ReLU(inplace=True),
                # state size. (ngf) x (8*base)^3
                nn.ConvTranspose3d( ngf, nc, kernel_size, stride,g_padding, 1, bias=False),
                nn.Tanh()
                # output. (nc) x image_size^3
            )

        def forward(self, noise,labels):
            ## labels (batch,1) -> (batch,1,1,1,1), appended as one extra feature on the last dim
            ## of the noise (noise is built as (batch,1,1,1,nz) in f_gen_images).
            x=labels.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).float()
            gen_input=torch.cat((noise,x),-1)
            return self.main(gen_input)

    class Discriminator(nn.Module):
        def __init__(self, gdict):
            super(Discriminator, self).__init__()

            ## Define new variables from dict ('ngpu' is read only to require the key)
            _ngpu,nz,nc,ndf,kernel_size,stride,d_padding=_f_read_keys(gdict,['ngpu','nz','nc','ndf','kernel_size','stride','d_padding'])

            ## Registered before self.main to keep the original parameter order
            self.linear_transf=nn.Linear(4,4)
            self.main = nn.Sequential(
                # input is (nc+1) x image_size^3 (image + label channel)
                # nn.Conv3d(in_channels, out_channels, kernel_size,stride,padding,output_padding,groups,bias, Dilation,padding_mode)
                nn.Conv3d(nc+1, ndf,kernel_size, stride, d_padding,  bias=True),
                nn.BatchNorm3d(ndf,eps=1e-05, momentum=0.9, affine=True),
                nn.LeakyReLU(0.2, inplace=True),
                # state size. (ndf) x (8*base)^3
                nn.Conv3d(ndf, ndf * 2, kernel_size, stride, d_padding, bias=True),
                nn.BatchNorm3d(ndf * 2,eps=1e-05, momentum=0.9, affine=True),
                nn.LeakyReLU(0.2, inplace=True),
                # state size. (ndf*2) x (4*base)^3
                nn.Conv3d(ndf * 2, ndf * 4, kernel_size, stride, d_padding, bias=True),
                nn.BatchNorm3d(ndf * 4,eps=1e-05, momentum=0.9, affine=True),
                nn.LeakyReLU(0.2, inplace=True),
                # state size. (ndf*4) x (2*base)^3
                nn.Conv3d(ndf * 4, ndf * 8, kernel_size, stride, d_padding, bias=True),
                nn.BatchNorm3d(ndf * 8,eps=1e-05, momentum=0.9, affine=True),
                nn.LeakyReLU(0.2, inplace=True),
                # state size. (ndf*8) x base^3
                nn.Flatten(),
                nn.Linear(nc*ndf*8*base**3, 1)
    #             nn.Sigmoid()
            )

        def forward(self, img,labels):
            ## labels (batch,1) -> (batch,1,1,1,4) -> learned Linear(4,4)
            x=labels.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1,1,1,1,4).float()
            x=self.linear_transf(x)
            ## Flatten and repeat each value img_size^3/4 times -> batch*img_size^3 values,
            ## then reshape into one extra image-sized channel: (batch,1,img,img,img)
            x=torch.repeat_interleave(x,int((img_size*img_size*img_size)/4))
            x=x.view(labels.size(0),1,img_size,img_size,img_size)

            ip=torch.cat((img,x),axis=1)
            return _f_disc_features(self.main,ip)

    return Generator, Discriminator


def _f_build_model_3(gdict):
    '''Model 3: the label conditions every normalization layer through
    ConditionalInstanceNorm2d. The layer sizes are hard-coded for image_size 64
    (View to 4^3 and final Linear on (ndf*8) x 4^3).'''

    class ConditionalInstanceNorm2d(nn.Module):
        '''Instance norm whose per-channel scale (gamma) and shift (beta) are
        predicted from the label by a Linear layer.

        Kept under its historical name, but it normalizes 3D volumes: the
        feature maps here come from Conv3d, and nn.InstanceNorm2d rejects 5D input.'''
        def __init__(self, num_features, num_params):
            super().__init__()
            self.num_features = num_features
            self.InstNorm = nn.InstanceNorm3d(num_features, affine=False)
            self.affine = nn.Linear(num_params, num_features * 2)
            ## nn.Linear weight is (out_features, in_features): rows [:num_features]
            ## produce gamma and rows [num_features:] produce beta.
            self.affine.weight.data[:num_features].normal_(1, 0.02)  # Initialise scale at N(1, 0.02)
            self.affine.weight.data[num_features:].zero_()  # Initialise bias at 0

        def forward(self, x, y):
            out = self.InstNorm(x)
            gamma, beta = self.affine(y).chunk(2, 1)
            ## Broadcast (batch, C) over all spatial dims of x
            shape=[-1,self.num_features]+[1]*(x.dim()-2)
            return gamma.view(*shape) * out + beta.view(*shape)

    class ConditionalSequential(nn.Sequential):
        '''nn.Sequential that also feeds the labels to ConditionalInstanceNorm2d layers.'''
        def __init__(self,*args):
            super(ConditionalSequential, self).__init__(*args)

        def forward(self, inputs, labels):
            for module in self:
                if module.__class__ is ConditionalInstanceNorm2d:
                    inputs = module(inputs, labels.float())
                else:
                    inputs = module(inputs)

            return inputs

    class Generator(nn.Module):
        def __init__(self, gdict):
            super(Generator, self).__init__()

            ## Define new variables from dict ('ngpu' is read only to require the key)
            _ngpu,nz,nc,ngf,kernel_size,stride,g_padding=_f_read_keys(gdict,['ngpu','nz','nc','ngf','kernel_size','stride','g_padding'])

            self.main = ConditionalSequential(
                # nn.ConvTranspose3d(in_channels, out_channels, kernel_size,stride,padding,output_padding,groups,bias, Dilation,padding_mode)
                nn.Linear(nz,nc*ngf*8**3),# 262144
                nn.BatchNorm3d(nc,eps=1e-05, momentum=0.9, affine=True),
                nn.ReLU(inplace=True),
                View(shape=[-1,ngf*8,4,4,4]),
                nn.ConvTranspose3d(ngf * 8, ngf * 4, kernel_size, stride, g_padding, output_padding=1, bias=False),
                ConditionalInstanceNorm2d(ngf*4,1),
                nn.ReLU(inplace=True),
                # state size. (ngf*4) x 8^3
                nn.ConvTranspose3d( ngf * 4, ngf * 2, kernel_size, stride, g_padding, 1, bias=False),
                ConditionalInstanceNorm2d(ngf*2,1),
                nn.ReLU(inplace=True),
                # state size. (ngf*2) x 16^3
                nn.ConvTranspose3d( ngf * 2, ngf, kernel_size, stride, g_padding, 1, bias=False),
                ConditionalInstanceNorm2d(ngf,1),
                nn.ReLU(inplace=True),
                # state size. (ngf) x 32^3
                nn.ConvTranspose3d( ngf, nc, kernel_size, stride,g_padding, 1, bias=False),
                nn.Tanh()
            )

        def forward(self, noise,labels):
            return self.main(noise,labels)

    class Discriminator(nn.Module):
        def __init__(self, gdict):
            super(Discriminator, self).__init__()

            ## Define new variables from dict ('ngpu' is read only to require the key)
            _ngpu,nz,nc,ndf,kernel_size,stride,d_padding=_f_read_keys(gdict,['ngpu','nz','nc','ndf','kernel_size','stride','d_padding'])

            self.main = nn.Sequential(
                # input is (nc) x 64^3
                # nn.Conv3d(in_channels, out_channels, kernel_size,stride,padding,output_padding,groups,bias, Dilation,padding_mode)
                nn.Conv3d(nc, ndf,kernel_size, stride, d_padding,  bias=True),
                ConditionalInstanceNorm2d(ndf,1),
                nn.LeakyReLU(0.2, inplace=True),
                # state size. (ndf) x 32^3
                nn.Conv3d(ndf, ndf * 2, kernel_size, stride, d_padding, bias=True),
                ConditionalInstanceNorm2d(ndf*2,1),
                nn.LeakyReLU(0.2, inplace=True),
                # state size. (ndf*2) x 16^3
                nn.Conv3d(ndf * 2, ndf * 4, kernel_size, stride, d_padding, bias=True),
                ConditionalInstanceNorm2d(ndf*4,1),
                nn.LeakyReLU(0.2, inplace=True),
                # state size. (ndf*4) x 8^3
                nn.Conv3d(ndf * 4, ndf * 8, kernel_size, stride, d_padding, bias=True),
                ConditionalInstanceNorm2d(ndf*8,1),
                nn.LeakyReLU(0.2, inplace=True),
                # state size. (ndf*8) x 4^3
                nn.Flatten(),
                nn.Linear(nc*ndf*8*8*8, 1)
    #             nn.Sigmoid()
            )

        def forward(self, ip,labels):
            ## The conditional norm layers need the labels as a second argument
            return _f_disc_features(self.main,ip,labels,ConditionalInstanceNorm2d)

    return Generator, Discriminator


def f_get_model(model_name,gdict):
    ''' Module to define Generator and Discriminator.

    Arguments:
        model_name: 2 -> label concatenated to the generator noise and appended as an
                         extra discriminator input channel (image_size 64 or 128).
                    3 -> label injected through conditional instance-norm layers
                         (layer sizes hard-coded for image_size 64).
        gdict: config dict. 'image_size' is read here; the model constructors read
               ngpu, nz, nc, ngf/ndf, kernel_size, stride, g_padding/d_padding.
    Returns:
        (Generator, Discriminator): the classes (not instances); each is
        instantiated as Cls(gdict). Both Discriminators return a list holding the
        output of every Conv layer followed by the final score (for f_FM_loss).
    Raises:
        ValueError: unknown model_name, or an unsupported image_size for model 2.
    '''
    if model_name==2:
        return _f_build_model_2(gdict)
    if model_name==3:
        return _f_build_model_3(gdict)
    raise ValueError("Unknown model_name %s: expected 2 or 3"%(model_name))

In [4]:
### Transformation functions for image pixel values
def f_transform(x,a):
    ''' Map pixel values to s = 2x/(x+a) - 1 (for x >= 0, a > 0 this lies in [-1, 1); x = a maps to 0). '''
    denom = x + float(a)
    ratio = 2.*x/denom
    return ratio - 1.

def f_invtransform(s,a):
    ''' Inverse of f_transform: recover x = a*(1+s)/(1-s) from the transformed value s. '''
    numerator = float(a)*(1. + s)
    return numerator/(1. - s)

# Generator Code
class View(nn.Module):
    ''' Reshape layer for use inside nn.Sequential: forward returns x.view(*shape). '''
    def __init__(self, shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        target_shape = self.shape
        return x.view(*target_shape)

def f_gen_images(gdict,netG,optimizerG,sigma,ip_fname,op_loc,op_strg='inf_img_',op_size=500):
    '''Load a saved checkpoint into netG and save a batch of generated images as .npy.

    Arguments:
        gdict: config dict; reads 'nz', 'device' and 'multi-gpu'.
        netG: generator, called as netG(noise, params). Its weights are replaced by
              the checkpoint's 'G_state' and it is left in eval mode.
        optimizerG: its state is restored from the checkpoint.
        sigma: parameter value given to every generated image.
        ip_fname: checkpoint file to load.
        op_loc: output location; prepended verbatim to the file name, so it
                should end with a path separator.
        op_strg: prefix of the output file name.
        op_size: number of images to generate.
    Returns None. If the checkpoint cannot be loaded, the error is printed and
    nothing is generated.
    '''
    nz=gdict['nz']
    device=gdict['device']

    # Map the stored tensors to the CPU when no GPU is present
    load_kwargs={} if torch.cuda.is_available() else {'map_location':torch.device('cpu')}
    try:
        checkpoint=torch.load(ip_fname,**load_kwargs)
    except Exception as e:
        print(e)
        print("skipping generation of images for ",ip_fname)
        return

    # With DataParallel the real model sits in .module
    target=netG.module if gdict['multi-gpu'] else netG
    target.load_state_dict(checkpoint['G_state'])

    iters,epoch=checkpoint['iters'],checkpoint['epoch']
    optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])

    # Latent vectors of shape (op_size,1,1,1,nz) and one parameter value per image
    noise=torch.randn(op_size,1,1,1,nz,device=device)
    tnsr_cosm_params=(torch.ones(op_size,device=device)*sigma).view(op_size,1)

    netG.eval()  # inference mode for norm layers
    with torch.no_grad():  # required for multi-gpu inference
        gen_images=netG(noise,tnsr_cosm_params).detach().cpu().numpy()
        print(gen_images.shape)

    op_fname='%s_epoch-%s_step-%s.npy'%(op_strg,epoch,iters)
    np.save(op_loc+op_fname,gen_images)
    print("Image saved in ",op_fname)
    
def f_save_checkpoint(gdict,epoch,iters,best_chi1,best_chi2,netG,netD,optimizerG,optimizerD,save_loc):
    ''' Checkpoint model.

    Saves epoch, iters, best_chi1, best_chi2, the state dicts of both networks
    (unwrapped from .module when gdict['multi-gpu'] is set) and of both
    optimizers to save_loc with torch.save.
    '''
    if gdict['multi-gpu']:  ## Dataparallel: real model is in .module
        g_model,d_model=netG.module,netD.module
    else:
        g_model,d_model=netG,netD

    state={'epoch':epoch,'iters':iters,'best_chi1':best_chi1,'best_chi2':best_chi2,
           'G_state':g_model.state_dict(),'D_state':d_model.state_dict(),
           'optimizerG_state_dict':optimizerG.state_dict(),
           'optimizerD_state_dict':optimizerD.state_dict()}
    torch.save(state,save_loc)
    
def f_load_checkpoint(ip_fname,netG,netD,optimizerG,optimizerD,gdict):
    ''' Load a saved checkpoint into the networks and optimizers.
    Also loads step, epoch, best_chi1, best_chi2.

    Arguments:
        ip_fname: checkpoint file (as written by f_save_checkpoint).
        netG, netD: networks; when gdict['multi-gpu'] is set, the states are
                    loaded into their .module.
        optimizerG, optimizerD: optimizers whose states are restored.
        gdict: config dict; reads 'multi-gpu'.
    Returns:
        iters,epoch,best_chi1,best_chi2,netD,optimizerD,netG,optimizerG
        (both networks are switched to train mode).
    Raises:
        SystemError: if the file cannot be loaded (the original error is chained).
    '''
    ## Map tensors to the current GPU, or to the CPU when CUDA is unavailable
    ## (torch.cuda.current_device() fails on CPU-only machines).
    if torch.cuda.is_available():
        device=torch.device('cuda',torch.cuda.current_device())
    else:
        device=torch.device('cpu')
    print("torch device",device)

    try:
        checkpoint=torch.load(ip_fname,map_location=device)
    except Exception as e:
        print(e)
        print("Failed to load checkpoint ",ip_fname)
        raise SystemError("Could not load checkpoint %s"%(ip_fname)) from e

    ## Load checkpoint
    if gdict['multi-gpu']:
        netG.module.load_state_dict(checkpoint['G_state'])
        netD.module.load_state_dict(checkpoint['D_state'])
    else:
        netG.load_state_dict(checkpoint['G_state'])
        netD.load_state_dict(checkpoint['D_state'])

    optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
    optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])

    iters=checkpoint['iters']
    epoch=checkpoint['epoch']
    best_chi1=checkpoint['best_chi1']
    best_chi2=checkpoint['best_chi2']

    netG.train()
    netD.train()

    return iters,epoch,best_chi1,best_chi2,netD,optimizerD,netG,optimizerG



## 

### Spec loss modules

In [48]:
####################
### Pytorch code ###
####################

## Mod for 3D

def f_get_rad(img):
    ''' Precompute radial distances for f_torch_get_azimuthalAverage.

    Uses the last three dims of img.shape. Returns (r, ind):
      r   - float64 tensor with that 3D shape; each entry is the distance of the
            voxel from the grid centre ((size-1)/2 along each axis, may be fractional)
      ind - torch.argsort of r flattened
    Both are detached.
    '''
    height,width,depth=img.shape[-3:]
    # Integer coordinate grids; z varies along the first axis, x along the last
    z,y,x=np.indices([height,width,depth])

    cx=(x.max()-x.min())/2.0
    cy=(y.max()-y.min())/2.0
    cz=(z.max()-z.min())/2.0
    dist=np.sqrt((x-cx)**2 + (y-cy)**2 + (z-cz)**2)

    r=torch.tensor(dist)
    ind=torch.argsort(r.reshape(-1))  # voxel order by increasing radius
    return r.detach(),ind.detach()

def f_torch_get_azimuthalAverage(image,r,ind):
    """
    Radially average a 3D image using integer-width radial bins.

    image - the 3D image (flattened internally)
    r, ind - radial-distance tensor and argsort of r flattened, as produced by f_get_rad
    Returns a 1D tensor with the mean of image over each integer radial bin
    that lies between two changes of the integer radius: the innermost bin
    (radius < 1) and the outermost bin are not included.
    source: https://www.astrobetter.com/blog/2010/03/03/fourier-transforms-of-images-in-python/
    """
    flat_r=r.reshape(-1)
    flat_img=image.reshape(-1)
    r_sorted=flat_r.gather(0,ind)
    i_sorted=flat_img.gather(0,ind)

    # Integer part of each radius (bin width 1)
    r_bin=r_sorted.to(torch.int32)

    # Positions where the integer radius increases, i.e. the last element of each bin.
    # Assumes all integer radii are represented.
    changes=torch.where(r_bin[1:]-r_bin[:-1])[0].reshape(-1)
    counts=(changes[1:]-changes[:-1]).type(torch.float)  # pixels per bin

    # Per-bin sums from differences of the cumulative sum
    cumulative=torch.cumsum(i_sorted,dim=-1)
    bin_sums=cumulative.gather(0,changes[1:])-cumulative.gather(0,changes[:-1])
    return bin_sums/counts

def f_torch_fftshift(real, imag):
    ''' fftshift the real and imaginary parts: roll each of the first real.dim()
    dimensions by size//2 (the same shift as np.fft.fftshift over all axes).
    Returns (real, imag). '''
    ndim = real.dim()
    shifted = []
    for part in (real, imag):
        for axis in range(ndim):
            part = torch.roll(part, shifts=part.size(axis) // 2, dims=axis)
        shifted.append(part)
    return shifted[0], shifted[1]

def f_torch_compute_spectrum(arr,r,ind):
    ''' Radially averaged power spectrum of a single 3D field.

    arr is normalised as (arr-GLOBAL_MEAN)/GLOBAL_MEAN, Fourier transformed over
    its last three dims, fft-shifted, converted to squared magnitude and averaged
    in radial bins via f_torch_get_azimuthalAverage(..., r, ind).

    Works on PyTorch < 1.8 (torch.rfft) and >= 1.8, where torch.rfft was removed
    and torch.fft.fftn is used instead.
    '''
    GLOBAL_MEAN=1.0
    arr=(arr-GLOBAL_MEAN)/(GLOBAL_MEAN)

    if hasattr(torch,'rfft'):
        ## Old API: last index holds the real/imag part  ## Mod for 3D
        y1=torch.rfft(arr,signal_ndim=3,onesided=False)
        real,imag=y1[...,0],y1[...,1]
    else:
        ## PyTorch >= 1.8: complex output
        y1=torch.fft.fftn(arr,dim=(-3,-2,-1))
        real,imag=y1.real,y1.imag
    real,imag=f_torch_fftshift(real,imag)

    y2=real**2+imag**2     ## Squared magnitude of each Fourier coefficient
    z1=f_torch_get_azimuthalAverage(y2,r,ind)     ## Compute radial profile
    return z1


def f_torch_compute_batch_spectrum(arr,r,ind):
    ''' Compute f_torch_compute_spectrum for every field along the first dim of arr
    and stack the resulting 1D spectra into one tensor. '''
    spectra=[]
    for field in arr:
        spectra.append(f_torch_compute_spectrum(field,r,ind))
    return torch.stack(spectra)

def f_torch_image_spectrum(x,num_channels,r,ind):
    '''
    Per-channel mean and variance (over the batch) of the radial power spectrum.

    Data has to be in the form (batch,channel,x,y,z).
    Returns (mean, var), each a tensor stacked over the first num_channels channels.
    '''
    means=[]
    variances=[]
    for ch in range(num_channels):
        batch_pk=f_torch_compute_batch_spectrum(x[:,ch,:,:,:],r,ind)  # Mod for 3D
        means.append(torch.mean(batch_pk,axis=0))
        variances.append(torch.var(batch_pk,axis=0))

    return torch.stack(means),torch.stack(variances)

def f_compute_hist(data,bins):
    ''' Histogram of data with `bins` equal-width bins between its min and max,
    scaled so the bin values sum to `bins`.

    If torch.histc fails, the error is printed and a zero histogram is returned.
    The zero histogram is created on data's device when data is a tensor, so
    later arithmetic with tensors on that device still works.
    '''
    try:
        hist_data=torch.histc(data,bins=bins)
        ## A kind of normalization of histograms: divide by total sum
        hist_data=(hist_data*bins)/torch.sum(hist_data)
    except Exception as e:
        print(e)
        device=data.device if isinstance(data,torch.Tensor) else None
        hist_data=torch.zeros(bins,device=device)

    return hist_data

### Losses 
def loss_spectrum(spec_mean,spec_mean_ref,spec_var,spec_var_ref,image_size,lambda_spec_mean,lambda_spec_var):
    ''' Loss function for the spectrum: lambda_spec_mean*loss_mean + lambda_spec_var*loss_var.

    loss_mean = mean(log((spec_mean - spec_mean_ref)**2 + 1e6)) over the first
    image_size/2 bins. The leading dim is the channel; channels are averaged together.
    loss_var is currently disabled and is always 0 (see the commented variants).

    Returns an inf tensor if spec_mean or spec_var contains NaN; raises
    SystemExit if the result is NaN.
    '''
    device=spec_mean.device

    if (torch.isnan(spec_mean).any() or torch.isnan(spec_var).any()):
        ans=torch.tensor(float("inf"),device=device)
        return ans

    idx=int(image_size/2) ### For the spectrum, use only N/2 indices for loss calc.
    ### Warning: the first index is the channel number.For multiple channels, you are averaging over them, which is fine.

    # loss_mean=torch.log(torch.mean(torch.pow(spec_mean[:,:idx]-spec_mean_ref[:,:idx],2)))
    # loss_var=torch.log(torch.mean(torch.pow(spec_var[:,:idx]-spec_var_ref[:,:idx],2)))

    epsilon_spec=1e6 ## correction in case of a 0 inside the log (= min value of spectrum)
    loss_mean=torch.mean(torch.log(torch.pow(spec_mean[:,:idx]-spec_mean_ref[:,:idx],2)+epsilon_spec))

    # loss_var =torch.mean(torch.log(torch.pow(spec_var[:,:idx]-spec_var_ref[:,:idx],2)+epsilon_spec))
    ## For variance, square takes value beyond 1e32, so it fails. value is always positive, so can instead do square of log.
    # loss_var =torch.mean(torch.pow(torch.log(spec_var[:,:idx])-torch.log(spec_var_ref[:,:idx]),2))

    ## Variance term disabled. Created on the inputs' device rather than via a
    ## global gdict, which is not a parameter of this function.
    loss_var=torch.zeros(1,device=device)

    ans=lambda_spec_mean*loss_mean+lambda_spec_var*loss_var

    if (torch.isnan(ans).any()) :
        print("loss spec mean %s, loss spec var %s"%(loss_mean,loss_var))
        # print("spec mean %s, ref %s"%(spec_mean, spec_mean_ref))
        print("spec var %s, ref %s"%(spec_var, spec_var_ref))
        raise SystemExit

    return ans

def loss_hist(hist_sample,hist_ref):
    ''' Histogram loss: mean over bins of log((hist_sample - hist_ref)**2 + 1e-10). '''
    lambda1=1.0
    epsilon_hist=1e-10   # keeps the log finite when the histograms match exactly
    sq_diff=torch.pow(hist_sample-hist_ref,2)
    per_bin=torch.log(sq_diff+epsilon_hist)
    return lambda1*torch.mean(per_bin)


def f_FM_loss(real_output,fake_output,lambda_fm,gdict):
    '''
    Feature-matching loss. For every Discriminator output except the last, add
    (mean(real) - mean(fake))**2; returns lambda_fm times the total as a
    1-element tensor on gdict['device'].
    '''
    total=torch.Tensor([0.0]).to(gdict['device'])
    pairs=zip(real_output[:-1],fake_output[:-1])
    for real_feat,fake_feat in pairs:
        gap=torch.mean(real_feat)-torch.mean(fake_feat)
        total=total.clone()+torch.sum(torch.square(gap))
    return lambda_fm*total

def f_gp_loss(grads,l=1.0):
    '''
    Module to implement gradient penalty loss:
    l * mean over the batch (dim 0) of the per-sample sum of squared gradients.

    All non-batch dims are summed, so 4D (2D-image) and 5D (3D-volume) gradients
    are both fully reduced per sample. Summing only dims [1,2,3] would leave
    the last spatial dim of a 5D tensor to be averaged instead of summed.
    '''
    per_sample=torch.sum(torch.square(grads),dim=tuple(range(1,grads.dim())))
    return l*torch.mean(per_sample)

def f_get_loss_cond(loss_type,img_tensor,cosm_params,gdict,bins=None,hist_val_tnsr=None,spec_mean_tnsr=None,spec_var_tnsr=None,r=None,ind=None,real_output=None,fake_output=None,grads=None):
    ''' Module to compute one of the losses for conditional GAN, summed over sigma categories.

    For each value in gdict['sigma_list'] that matches more than one entry of
    cosm_params, the loss selected by loss_type is computed and stored:
      'hist': loss_hist of that category's images vs hist_val_tnsr[count],
              weighted by the category's fraction of the batch
      'spec': loss_spectrum of f_invtransform(images, gdict['kappa']) vs
              spec_mean_tnsr/spec_var_tnsr[count], weighted likewise
      'fm'  : f_FM_loss on the full real_output/fake_output (unweighted)
      'gp'  : f_gp_loss on the full grads (unweighted)
    Any other loss_type contributes 0. Returns the sum over categories.
    NOTE(review): for 'fm' and 'gp' the same whole-batch value is stored once per
    qualifying category, so it scales with the number of categories present — confirm intended.
    '''
    per_sigma=torch.zeros(len(gdict['sigma_list']),device=gdict['device'])

    for count,sigma in enumerate(gdict['sigma_list']):
        idxs=torch.where(cosm_params==sigma)[0]   # samples belonging to this category
        n_sel=idxs.size(0)
        if n_sel<=1:
            continue
        frac=n_sel/img_tensor.shape[0]   # fraction of the batch in this category
        img=img_tensor[idxs]

        if loss_type=='hist':
            per_sigma[count]=loss_hist(f_compute_hist(img,bins),hist_val_tnsr[count])*frac
        elif loss_type=='spec':
            mean,var=f_torch_image_spectrum(f_invtransform(img,gdict['kappa']),1,r,ind)
            per_sigma[count]=loss_spectrum(mean,spec_mean_tnsr[count],var,spec_var_tnsr[count],gdict['image_size'],gdict['lambda_spec_mean'],gdict['lambda_spec_var'])*frac
        elif loss_type=='fm':
            per_sigma[count]=f_FM_loss(real_output,fake_output,gdict['lambda_fm'],gdict)
        elif loss_type=='gp':
            per_sigma[count]=f_gp_loss(grads,gdict['lambda_gp'])

    return per_sigma.sum()


## Start

In [9]:
########## Modules
### Setup modules ###
def f_manual_add_argparse():
    ''' Build the args namespace by hand; use only in jpt notebook,
    in place of f_parse_args (which is meant for the .py script). '''
    return argparse.Namespace(
        config='config_3d_cgan_64_cori.yaml',
        mode='fresh',        # set to 'continue' to resume a previous run
        local_rank=0,
        facility='cori',
        distributed=False,
    )

def f_parse_args(argv=None):
    """Parse command line arguments. Only for .py file.

    argv: optional list of argument strings; None (default) parses sys.argv[1:].
    Returns an argparse.Namespace with config, mode, local_rank, facility and distributed.
    """
    parser = argparse.ArgumentParser(description="Run script to train GAN using pytorch", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    add_arg = parser.add_argument

    add_arg('--config','-cfile',  type=str, default='config_3d_Cgan.yaml', help='Name of config file')
    add_arg('--mode','-m',  type=str, choices=['fresh','continue','fresh_load'],default='fresh', help='Whether to start fresh run or continue previous run or fresh run loading a config file.')
    add_arg("--local_rank", default=0, type=int,help='Local rank of GPU on node. Using for pytorch DDP. ')
    add_arg("--facility", default='cori', choices=['cori','summit'],type=str,help='Facility: cori or summit ')
    add_arg("--ddp", dest='distributed' ,default=False,action='store_true',help='use Distributed DataParallel for Pytorch or DataParallel')

    return parser.parse_args(argv)


def try_barrier(rank):
    """
    Used in Distributed data parallel.
    Attempt a barrier across all ranks, but ignore any error it raises
    (e.g. when no process group has been initialized).

    Args:
        rank: rank of the calling process; only used in the printed message.
    """
    print('BAR %d'%rank)
    try:
        dist.barrier()
    except Exception as err:  # best effort; a bare 'except' would also swallow KeyboardInterrupt/SystemExit
        logging.debug("Barrier on rank %s failed: %s", rank, err)

def f_init_gdict(args,gdict):
    '''
    Create the global dictionary from the config file and the parsed arguments.

    The dictionary under the 'parameters' key of the YAML file named by
    ``args.config`` becomes the result; every attribute of ``args`` is then copied
    into it (overriding config entries with the same name). The incoming ``gdict``
    argument is not used.

    Raises:
        AssertionError: if distributed training is requested together with a
            non-zero 'lambda_gp' (gradient penalty is not supported under DDP).
    '''
    with open(args.config) as cfg_stream:
        parsed_cfg=yaml.load(cfg_stream, Loader=yaml.SafeLoader)

    gdict=parsed_cfg['parameters']
    gdict.update(vars(args))  # command-line / notebook arguments take precedence

    if not gdict['distributed']:
        print("Not using DDP")
    else:
        assert not gdict['lambda_gp'],"GP couplings is %s. Cannot use Gradient penalty loss in pytorch DDP"%(gdict['lambda_gp'])
    return gdict


def f_get_img_samples(ip_arr,rank=0,num_ranks=1,batch_size=None,shuffle=False):
    '''
    Return the shard of the image array that belongs to ``rank``.

    The array is split along axis 0 into ``num_ranks`` shards of
    ``len(ip_arr)//num_ranks`` samples each; any remainder samples are dropped.

    Args:
        ip_arr: array indexable along axis 0 (e.g. a memory-mapped numpy array).
        rank: index of the shard to return.
        num_ranks: total number of shards (ranks).
        batch_size: batch size that must fit in one shard. When None, falls back to
            the module-level ``gdict['batch_size']`` (original behaviour).
        shuffle: if True, pick a random subset of indices instead of a contiguous
            slice. Shards are only disjoint if every rank has the same numpy RNG state.

    Returns:
        A numpy array copy of the shard.

    Raises:
        SystemExit: if ``batch_size`` exceeds the number of samples per shard.
    '''
    if batch_size is None:
        batch_size=gdict['batch_size']  # legacy: read from the global config dictionary

    data_size=ip_arr.shape[0]
    size=data_size//num_ranks

    if batch_size>size:
        print("Caution: batchsize %s is greater than samples per GPU %s"%(batch_size,size))
        raise SystemExit

    start,stop=rank*size,(rank+1)*size
    if shuffle:
        ### Get a set of random indices from numpy array
        idxs=np.arange(data_size)
        np.random.shuffle(idxs)
        return ip_arr[idxs[start:stop]].copy()
    return ip_arr[start:stop].copy()

def f_setup(gdict,metrics_df,log):
    '''
    Set up directories, initialize random seeds, add GPU info and configure logging.

    Mutates ``gdict`` in place, adding 'ngpu', 'device', 'multi-gpu', 'world_size',
    'world_rank', 'local_rank' and 'save_dir'. Sets the environment variables read by
    ``torch.distributed`` with ``init_method="env://"`` ('MASTER_PORT' always; 'MASTER_ADDR',
    'WORLD_SIZE', 'RANK' for summit or distributed runs).

    Args:
        gdict: global configuration dictionary. Reads 'facility', 'distributed', 'mode',
            'op_loc', 'run_suffix', 'config' (fresh modes), 'ip_fldr' (continue mode),
            'seed' ('random' or an int-like value) and 'deterministic'.
        metrics_df: pandas DataFrame for training metrics; replaced by the stored
            ``df_metrics.pkle`` when gdict['mode']=='continue'.
        log: if True, send logging output to ``<save_dir>/log.log`` (and to stdout on rank 0).

    Returns:
        metrics_df, possibly reloaded from disk.
    '''
    import subprocess  ## only needed for the summit host-file lookup below

    torch.backends.cudnn.benchmark=True
    # torch.autograd.set_detect_anomaly(True)  ## enable only when debugging NaNs (slow)

    ########################
    ###### Environment variables for torch.distributed ######
    ## New additions. Code taken from Jan B.
    os.environ['MASTER_PORT'] = "8885"

    if gdict['facility']=='summit':
        ## Master = first host (in sorted order) of the LSF host file that is not a batch/login node.
        ## The host file path comes from the scheduler environment, not from user input.
        get_master = "echo $(cat {} | sort | uniq | grep -v batch | grep -v login | head -1)".format(os.environ['LSB_DJOB_HOSTFILE'])
        os.environ['MASTER_ADDR'] = subprocess.check_output(get_master, shell=True).decode().strip()
        os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']
        os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
        gdict['local_rank'] = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
    elif gdict['distributed']:
        os.environ['WORLD_SIZE'] = os.environ['SLURM_NTASKS']
        os.environ['RANK'] = os.environ['SLURM_PROCID']
        gdict['local_rank'] = int(os.environ['SLURM_LOCALID'])

    ## Special declarations
    gdict['ngpu']=torch.cuda.device_count()
    gdict['device']=torch.device("cuda" if torch.cuda.is_available() else "cpu")
    gdict['multi-gpu']=(gdict['device'].type == 'cuda') and (gdict['ngpu'] > 1)

    ########################
    ###### Set up Distributed Data parallel ######
    if gdict['distributed']:
        # gdict['local_rank']=args.local_rank  ## This is needed when using pytorch -m torch.distributed.launch
        gdict['world_size']=int(os.environ['WORLD_SIZE'])
        torch.cuda.set_device(gdict['local_rank']) ## Very important
        dist.init_process_group(backend='nccl', init_method="env://")
        gdict['world_rank']= dist.get_rank()

        device = torch.cuda.current_device()
        logging.info("World size %s, world rank %s, local rank %s device %s, hostname %s, GPUs on node %s\n"%(gdict['world_size'],gdict['world_rank'],gdict['local_rank'],device,socket.gethostname(),gdict['ngpu']))
    else:
        gdict['world_size'],gdict['world_rank'],gdict['local_rank']=1,0,0

    ########################
    ###### Set up directories #######
    if gdict['mode'] in ['fresh','fresh_load']:
        ### Rank 0 builds the time-stamped folder name; it is broadcast so all ranks share it.
        if gdict['world_rank']==0:
            dt_strg=datetime.now().strftime('%Y%m%d_%H%M%S') ## time format
            dt_lst=[int(i) for i in dt_strg.split('_')] # List storing day and time
            dt_tnsr=torch.LongTensor(dt_lst).to(gdict['device'])  ## Tensor to pass to other GPUs
        else:
            dt_tnsr=torch.Tensor([0,0]).long().to(gdict['device'])
        ### Pass directory name to other ranks
        if gdict['distributed']: dist.broadcast(dt_tnsr, src=0)

        gdict['save_dir']=gdict['op_loc']+str(int(dt_tnsr[0]))+'_'+str(int(dt_tnsr[1]))+'_'+gdict['run_suffix']

        if gdict['world_rank']==0 and not os.path.exists(gdict['save_dir']): # Create directories for rank 0
            os.makedirs(gdict['save_dir']+'/models')
            os.makedirs(gdict['save_dir']+'/images')
            shutil.copy(gdict['config'],gdict['save_dir'])

    elif gdict['mode']=='continue': ## For checkpointed runs
        gdict['save_dir']=gdict['ip_fldr']
        ### Read loss data
        metrics_df=pd.read_pickle(gdict['save_dir']+'/df_metrics.pkle').astype(np.float64)

    ########################
    ### Initialize random seed
    manualSeed = np.random.randint(1, 10000) if gdict['seed']=='random' else int(gdict['seed'])
    random.seed(manualSeed)
    np.random.seed(manualSeed)
    torch.manual_seed(manualSeed)
    torch.cuda.manual_seed_all(manualSeed)

    if gdict['deterministic']:
        logging.info("Running with deterministic sequence. Performance will be slower")
        torch.backends.cudnn.deterministic=True
        # torch.backends.cudnn.enabled = False
        torch.backends.cudnn.benchmark = False

    ########################
    if log:
        ### Write all logging.info statements to the log file (all ranks) and stdout (rank 0)
        logfile=gdict['save_dir']+'/log.log'
        log_format="%(asctime)-15s %(levelname)-8s %(message)s"
        if gdict['world_rank']==0:
            ## Attach the handlers explicitly: the previous basicConfig(filename=logfile) plus an
            ## extra FileHandler on the same file wrote every message to the log file twice.
            Lg = logging.getLogger()
            Lg.setLevel(logging.DEBUG)
            lg_handler_file = logging.FileHandler(logfile, mode="a+")
            lg_handler_file.setFormatter(logging.Formatter(log_format))
            lg_handler_stdout = logging.StreamHandler(sys.stdout)
            Lg.addHandler(lg_handler_file)
            Lg.addHandler(lg_handler_stdout)

            ## 'args' is the module-level namespace built by the main script (None if absent).
            logging.info('Args: {0}'.format(globals().get('args')))
            logging.info('Start: %s'%(datetime.now().strftime('%Y-%m-%d  %H:%M:%S')))

        if gdict['distributed']:  try_barrier(gdict['world_rank'])

        if gdict['world_rank']!=0:
            logging.basicConfig(level=logging.DEBUG, filename=logfile, filemode="a+", format=log_format)

    return metrics_df

class Dataset:
    # NOTE(review): this definition is dead code. A second `class Dataset` is defined right
    # after it in this cell and replaces this one. Unlike the later version, this one reads
    # 'norm_1_sig_<sigma>_train_val.npy' files and does not call f_transform on the loaded images.
    def __init__(self,gdict):
        '''
        Load training dataset and compute spectrum and histogram for a small batch of training and validation dataset.

        Sets self.train_dataloader (shuffled (image, sigma) pairs for this rank),
        self.train_spec_mean / train_spec_var / train_hist (stats of the last 50 samples
        of each file), self.val_spec_mean / val_spec_var / val_hist (samples [-100:-50])
        and self.r / self.ind (radial coordinates from f_get_rad, on gdict['device']).
        '''
        
        ## Load training dataset
        t0a=time.time()
        for count,sigma in enumerate(gdict['sigma_list']):
            fname=gdict['ip_fname']+'/norm_1_sig_%s_train_val.npy'%(sigma)
            # identity transpose: no reordering, but only valid for 5-D arrays
            x=np.load(fname,mmap_mode='r')[:gdict['num_imgs']].transpose(0,1,2,3,4) ## Mod for 3D
            x=f_get_img_samples(x,gdict['world_rank'],gdict['world_size'])
            size=x.shape[0]
            y=sigma*np.ones(size)  # one conditioning value per image

            if count==0:
                img=x[:]
                c_pars=y[:]
            else: 
                img=np.vstack([img,x]) # Store images
                c_pars=np.hstack([c_pars,y]) # Store cosmological parameters

        ### Manually shuffling numpy arrays to mix sigma values
        size=img.shape[0]
        idxs=np.random.choice(size,size=size,replace=False)  # a random permutation of all indices
        img=img[idxs]
        c_pars=c_pars[idxs]
        ## convert to tensors
        t_img=torch.from_numpy(img)
        cosm_params=torch.Tensor(c_pars).view(size,1)

        dataset=TensorDataset(t_img,cosm_params)
        self.train_dataloader=DataLoader(dataset,batch_size=gdict['batch_size'],shuffle=True,num_workers=0,drop_last=True)
        logging.info("Size of dataset for GPU %s : %s"%(gdict['world_rank'],len(self.train_dataloader.dataset)))

        t0b=time.time()
        logging.info("Time for creating dataloader %s for rank %s"%(t0b-t0a,gdict['world_rank']))


        # Precompute spectrum and histogram for small training and validation data for computing losses           
        def f_compute_summary_stats(idx1=-50,idx2=None):
            # Compute hist and spec for samples [idx1:idx2] of every sigma file.
            # Returns the per-sigma results stacked along a new first axis, plus r and ind.
            with torch.no_grad():
                spec_mean_list=[];spec_var_list=[];hist_val_list=[]
                for count,sigma in enumerate(gdict['sigma_list']):
                    ip_fname=gdict['ip_fname']+'/norm_1_sig_%s_train_val.npy'%(sigma)
                    val_img=np.load(ip_fname,mmap_mode='r')[idx1:idx2].transpose(0,1,2,3,4).copy() ## Mod for 3D
                    t_val_img=torch.from_numpy(val_img).to(gdict['device'])

                    # Precompute radial coordinates (once, from the first sigma's images)
                    if count==0: 
                        r,ind=f_get_rad(val_img)
                        r=r.to(gdict['device']); ind=ind.to(gdict['device'])
                    # Stored mean and std of spectrum for full input data once
                    mean_spec_val,var_spec_val=f_torch_image_spectrum(f_invtransform(t_val_img,gdict['kappa']),1,r,ind)
                    hist_val=f_compute_hist(t_val_img,bins=gdict['bns'])
                    
                    ## Quit if reference spectrum or histogram have nans 
                    if (torch.isnan(mean_spec_val).any() or torch.isnan(var_spec_val).any() or torch.isnan(hist_val).any() ): 
                        print(mean_spec_val,var_spec_val,hist_val)
                        raise SystemError

                    spec_mean_list.append(mean_spec_val)
                    spec_var_list.append(var_spec_val)
                    hist_val_list.append(hist_val)
    #             del val_img; del t_val_img; del img; del spec_mean_list; del spec_var_list; del hist_val_list   
                return torch.stack(spec_mean_list),torch.stack(spec_var_list),torch.stack(hist_val_list),r,ind
        
        with torch.no_grad():
            # Reference statistics from the last 50 samples of each file
            self.train_spec_mean,self.train_spec_var,self.train_hist,self.r,self.ind=f_compute_summary_stats(-50,None)
            ## Compute for validation data (samples [-100:-50])
            self.val_spec_mean,self.val_spec_var,self.val_hist,_,_=f_compute_summary_stats(-100,-50)

class Dataset:
    '''
    Training data loader plus reference summary statistics used by the spectral and
    histogram losses.

    Attributes set by __init__:
        train_dataloader: DataLoader over shuffled (image, sigma) pairs of this rank.
        train_spec_mean, train_spec_var, train_hist: per-sigma statistics of the last
            50 samples of each input file, stacked along a new first axis.
        val_spec_mean, val_spec_var, val_hist: the same for samples [-100:-50].
        r, ind: radial coordinates returned by f_get_rad, moved to gdict['device'].
    '''
    def __init__(self,gdict):
        '''
        Load training dataset and compute spectrum and histogram for a small batch of training and validation dataset.

        Reads gdict keys 'sigma_list', 'ip_fname', 'num_imgs', 'world_rank', 'world_size',
        'kappa', 'batch_size', 'device' and 'bns'.

        Raises:
            SystemError: if a reference spectrum or histogram contains NaNs.
        '''

        ## Load training dataset
        t0a=time.time()
        img_list=[];c_pars_list=[]
        for sigma in gdict['sigma_list']:
            # fname=gdict['ip_fname']+'/norm_1_sig_%s_train_val.npy'%(sigma)
            fname=gdict['ip_fname']+'/Om0.3_Sg%s_H70.0.npy'%(sigma)
            # identity transpose: no reordering, but only valid for 5-D arrays
            x=np.load(fname,mmap_mode='r')[:gdict['num_imgs']].transpose(0,1,2,3,4) ## Mod for 3D
            x=f_get_img_samples(x,gdict['world_rank'],gdict['world_size'])
            x=f_transform(x,gdict['kappa'])
            print("shape of input file",x.shape)
            img_list.append(x) # Store images
            c_pars_list.append(sigma*np.ones(x.shape[0])) # Store cosmological parameters (one per image)

        ## Concatenate once instead of re-stacking the growing array for every sigma
        img=np.concatenate(img_list,axis=0)
        c_pars=np.concatenate(c_pars_list)

        ### Manually shuffling numpy arrays to mix sigma values
        size=img.shape[0]
        idxs=np.random.choice(size,size=size,replace=False)  # a random permutation of all indices
        img=img[idxs]
        c_pars=c_pars[idxs]
        ## convert to tensors
        t_img=torch.from_numpy(img)
        cosm_params=torch.Tensor(c_pars).view(size,1)

        dataset=TensorDataset(t_img,cosm_params)
        self.train_dataloader=DataLoader(dataset,batch_size=gdict['batch_size'],shuffle=True,num_workers=0,drop_last=True)
        logging.info("Size of dataset for GPU %s : %s"%(gdict['world_rank'],len(self.train_dataloader.dataset)))

        t0b=time.time()
        logging.info("Time for creating dataloader %s for rank %s"%(t0b-t0a,gdict['world_rank']))

        # Precompute spectrum and histogram for small training and validation data for computing losses
        def f_compute_summary_stats(idx1=-50,idx2=None):
            '''Spectrum mean/var and histogram of samples [idx1:idx2] of every sigma file, stacked per sigma, plus r and ind.'''
            with torch.no_grad():
                spec_mean_list=[];spec_var_list=[];hist_val_list=[]
                for count,sigma in enumerate(gdict['sigma_list']):
                    # ip_fname=gdict['ip_fname']+'/norm_1_sig_%s_train_val.npy'%(sigma)
                    ip_fname=gdict['ip_fname']+'/Om0.3_Sg%s_H70.0.npy'%(sigma)

                    val_img=np.load(ip_fname,mmap_mode='r')[idx1:idx2].transpose(0,1,2,3,4).copy() ## Mod for 3D
                    val_img=f_transform(val_img,gdict['kappa'])

                    t_val_img=torch.from_numpy(val_img).to(gdict['device'])

                    # Precompute radial coordinates (once, from the first sigma's images)
                    if count==0:
                        r,ind=f_get_rad(val_img)
                        r=r.to(gdict['device']); ind=ind.to(gdict['device'])
                    # Stored mean and std of spectrum for full input data once
                    mean_spec_val,var_spec_val=f_torch_image_spectrum(f_invtransform(t_val_img,gdict['kappa']),1,r,ind)
                    hist_val=f_compute_hist(t_val_img,bins=gdict['bns'])

                    ## Quit if reference spectrum or histogram have nans (they would poison every loss)
                    if (torch.isnan(mean_spec_val).any() or torch.isnan(var_spec_val).any() or torch.isnan(hist_val).any() ):
                        print(mean_spec_val,var_spec_val,hist_val)
                        raise SystemError

                    spec_mean_list.append(mean_spec_val)
                    spec_var_list.append(var_spec_val)
                    hist_val_list.append(hist_val)
                return torch.stack(spec_mean_list),torch.stack(spec_var_list),torch.stack(hist_val_list),r,ind

        with torch.no_grad():
            # Reference statistics from the last 50 samples of each file
            self.train_spec_mean,self.train_spec_var,self.train_hist,self.r,self.ind=f_compute_summary_stats(-50,None)
            ## Compute for validation data (samples [-100:-50])
            self.val_spec_mean,self.val_spec_var,self.val_hist,_,_=f_compute_summary_stats(-100,-50)

class GAN_model():
    '''
    Generator/discriminator pair together with their optimizers, loss and LR schedulers.

    Reads gdict keys 'model', 'device', 'world_rank', 'ngpu', 'multi-gpu', 'distributed',
    'local_rank', 'learn_rate_d', 'learn_rate_g', 'beta1', 'mode', 'save_dir' (continue),
    'chkpt_file' (fresh_load), 'num_imgs', 'sigma_list', 'batch_size', 'world_size',
    'lr_d_epochs', 'lr_g_epochs', 'lr_d_gamma' and 'lr_g_gamma'.
    Writes 'best_chi1', 'best_chi2', 'iters' and 'start_epoch' back into gdict.
    '''
    def __init__(self,gdict,print_model=False):

        def weights_init(m):
            '''Init conv weights from N(0, 0.02); BatchNorm weights from N(1, 0.02) with zero bias.'''
            classname = m.__class__.__name__
            if classname.find('Conv') != -1:
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif classname.find('BatchNorm') != -1:
                nn.init.normal_(m.weight.data, 1.0, 0.02)
                nn.init.constant_(m.bias.data, 0)

        ## Choose model
        Generator, Discriminator=f_get_model(gdict['model'],gdict) ## Mod for cGAN

        # Create Generator
        self.netG = Generator(gdict).to(gdict['device'])
        self.netG.apply(weights_init)
        # Create Discriminator
        self.netD = Discriminator(gdict).to(gdict['device'])
        self.netD.apply(weights_init)

        if print_model and gdict['world_rank']==0:
            print(self.netG)
            print(self.netD)
            print("Number of GPUs used %s"%(gdict['ngpu']))

        ## Wrap models for parallel training.
        ## DDP must wrap the models whenever the process group is in use, even if only one GPU is
        ## visible to the process; otherwise gradients are never averaged across ranks.
        ## output_device takes a single device index, not a list.
        if gdict['distributed']:
            self.netG=DistributedDataParallel(self.netG,device_ids=[gdict['local_rank']],output_device=gdict['local_rank'])
            self.netD=DistributedDataParallel(self.netD,device_ids=[gdict['local_rank']],output_device=gdict['local_rank'])
        elif gdict['multi-gpu']:
            self.netG = nn.DataParallel(self.netG, list(range(gdict['ngpu'])))
            self.netD = nn.DataParallel(self.netD, list(range(gdict['ngpu'])))

        #### Initialize loss and optimizers ####
        # self.criterion = nn.BCELoss()
        self.criterion = nn.BCEWithLogitsLoss()

        self.optimizerD = optim.Adam(self.netD.parameters(), lr=gdict['learn_rate_d'], betas=(gdict['beta1'], 0.999),eps=1e-7)
        self.optimizerG = optim.Adam(self.netG.parameters(), lr=gdict['learn_rate_g'], betas=(gdict['beta1'], 0.999),eps=1e-7)

        if gdict['distributed']:  try_barrier(gdict['world_rank'])

        if gdict['mode']=='fresh':
            iters,start_epoch,best_chi1,best_chi2=0,0,1e10,1e10

        elif gdict['mode']=='continue':
            iters,start_epoch,best_chi1,best_chi2,self.netD,self.optimizerD,self.netG,self.optimizerG=f_load_checkpoint(gdict['save_dir']+'/models/checkpoint_last.tar',\
                                                                                                                        self.netG,self.netD,self.optimizerG,self.optimizerD,gdict)
            if gdict['world_rank']==0: logging.info("\nContinuing existing run. Loading checkpoint with epoch {0} and step {1}\n".format(start_epoch,iters))
            if gdict['distributed']:  try_barrier(gdict['world_rank'])
            start_epoch+=1  ## Start with the next epoch

        elif gdict['mode']=='fresh_load':
            ## Load weights/optimizer state only; counters and best metrics start fresh.
            iters,start_epoch,best_chi1,best_chi2,self.netD,self.optimizerD,self.netG,self.optimizerG=f_load_checkpoint(gdict['chkpt_file'],\
                                                                                                                        self.netG,self.netD,self.optimizerG,self.optimizerD,gdict)
            if gdict['world_rank']==0: logging.info("Fresh run loading checkpoint file {0}".format(gdict['chkpt_file']))
            iters,start_epoch,best_chi1,best_chi2=0,0,1e10,1e10

        else:
            raise ValueError("Unknown mode %s. Expected 'fresh', 'continue' or 'fresh_load'"%(gdict['mode']))

        ## Add to gdict
        for key,val in zip(['best_chi1','best_chi2','iters','start_epoch'],[best_chi1,best_chi2,iters,start_epoch]): gdict[key]=val

        ## Set up learn rate scheduler. Schedulers are stepped once per batch, so the epoch
        ## milestones from the config are converted to optimizer steps.
        ## NOTE(review): scheduler state is not restored in 'continue' mode, so its step count restarts at 0.
        lr_stepsize=int((gdict['num_imgs']*len(gdict['sigma_list']))/(gdict['batch_size']*gdict['world_size'])) # convert epoch number to step
        lr_d_epochs=[i*lr_stepsize for i in gdict['lr_d_epochs']]
        lr_g_epochs=[i*lr_stepsize for i in gdict['lr_g_epochs']]
        self.schedulerD = optim.lr_scheduler.MultiStepLR(self.optimizerD, milestones=lr_d_epochs,gamma=gdict['lr_d_gamma'])
        self.schedulerG = optim.lr_scheduler.MultiStepLR(self.optimizerG, milestones=lr_g_epochs,gamma=gdict['lr_g_gamma'])


## Train loop

In [10]:
def f_train_loop(gan_model,Dset,metrics_df,gdict,fixed_noise,fixed_cosm_params):
    '''
    Train the conditional GAN from gdict['start_epoch'] up to (not including) gdict['epochs'].

    Every batch performs one discriminator update and then one generator update.
    - Discriminator: adversarial loss (gan_model.criterion) on the real and fake batches,
      plus a gradient-penalty loss when gdict['lambda_gp'] is set, and optional gradient clipping.
    - Generator: adversarial loss, plus a spectral loss when gdict['lambda_spec_mean'] is set
      and a feature-matching loss when gdict['lambda_fm'] is set.
    Both learning-rate schedulers are stepped once per batch.

    Rank 0 additionally:
    - records per-step metrics in ``metrics_df``;
    - evaluates validation histogram/spectrum chi values on ``fixed_noise``;
    - saves checkpoints under ``<save_dir>/models`` (last step of each epoch, best-by-metric
      from epoch 2 on, and every 'checkpoint_size' steps when gdict['save_steps_list']=='all');
    - saves generated samples under ``<save_dir>/images``;
    - pickles ``metrics_df`` after every epoch.

    Args:
        gan_model: object exposing netG, netD, optimizerG/D, schedulerG/D and criterion.
        Dset: object exposing train_dataloader, r, ind and precomputed train/val statistics.
        metrics_df: pandas DataFrame, updated in place on rank 0.
        gdict: global configuration dictionary.
        fixed_noise: latent tensor used to monitor the generator.
        fixed_cosm_params: conditioning values paired with ``fixed_noise`` for validation.

    Raises:
        SystemError: if the generated batch, the discriminator loss or the generator loss
            contains NaNs.
    '''

    ## Define new variables from dict
    keys=['image_size','start_epoch','epochs','iters','best_chi1','best_chi2','save_dir','device','flip_prob','nz','batch_size','bns']
    image_size,start_epoch,epochs,iters,best_chi1,best_chi2,save_dir,device,flip_prob,nz,batchsize,bns=list(collections.OrderedDict({key:gdict[key] for key in keys}).values())
    ## Conditioning values as a tensor, so random conditions for fake images are drawn by indexing
    sigma_tnsr=torch.tensor(gdict['sigma_list'],device=device)

    for epoch in range(start_epoch,epochs):
        t_epoch_start=time.time()
        num_batches=len(Dset.train_dataloader)
        for count, data in enumerate(Dset.train_dataloader):

            ####### Train GAN ########
            gan_model.netG.train(); gan_model.netD.train()  ### Needed after the eval() calls used for validation below

            tme1=time.time()
            ### Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            gan_model.netD.zero_grad()

            real_cpu = data[0].to(device)
            real_cosm_params=data[1].to(device)
            real_cpu.requires_grad=True  ## Needed to differentiate D's output w.r.t. its input (gradient penalty)

            b_size = real_cpu.size(0)
            real_label = torch.full((b_size,), 1, device=device,dtype=float)
            fake_label = torch.full((b_size,), 0, device=device,dtype=float)
            g_label = torch.full((b_size,), 1, device=device,dtype=float) ## No flipping for Generator labels
            # Flip ceil(b_size*flip_prob) randomly chosen (with replacement) real/fake labels
            for idx in np.random.choice(np.arange(b_size),size=int(np.ceil(b_size*flip_prob))):
                real_label[idx]=0; fake_label[idx]=1

            # Generate fake image batch with G, conditioned on randomly drawn sigma values
            noise = torch.randn(b_size, 1, 1, 1, nz, device=device) ### Mod for 3D
            rnd_idx=torch.randint(len(gdict['sigma_list']),(b_size,),device=device)
            fake_cosm_params=sigma_tnsr[rnd_idx].unsqueeze(-1)

            fake = gan_model.netG(noise,fake_cosm_params)

            if torch.isnan(fake).any(): ## Check nans in generated images
                print("fake image tensor has nans: %s of %s entries"%(torch.isnan(fake).sum().item(),fake.numel()))
                raise SystemError

            # Forward pass real batch through D
            real_output = gan_model.netD(real_cpu,real_cosm_params)
            errD_real = gan_model.criterion(real_output[-1].view(-1), real_label.float())
            errD_real.backward(retain_graph=True)  # graph is reused by the gradient penalty below
            D_x = real_output[-1].mean().item()

            # Forward pass fake batch through D
            fake_output = gan_model.netD(fake.detach(),fake_cosm_params)  # The detach is important
            errD_fake = gan_model.criterion(fake_output[-1].view(-1), fake_label.float())
            errD_fake.backward(retain_graph=True)
            D_G_z1 = fake_output[-1].mean().item()

            errD = errD_real + errD_fake

            if gdict['lambda_gp']: ## Add gradient - penalty loss
                grads=torch.autograd.grad(outputs=real_output[-1],inputs=real_cpu,grad_outputs=torch.ones_like(real_output[-1]),allow_unused=False,create_graph=True)[0]
                gp_loss=f_get_loss_cond('gp',fake,fake_cosm_params,gdict,grads=grads)
                gp_loss.backward(retain_graph=True)
                errD = errD + gp_loss
            else:
                gp_loss=torch.Tensor([np.nan])

            ### Implement Gradient clipping
            if gdict['grad_clip']:
                nn.utils.clip_grad_norm_(gan_model.netD.parameters(),gdict['grad_clip'])

            if torch.isnan(errD).any(): ## Check nans
                print("errD has nans",errD_real,errD_fake,errD)
                print("fake image tensor",fake_output[-1])
                raise SystemError

            gan_model.optimizerD.step()
            lr_d=gan_model.optimizerD.param_groups[0]['lr']
            gan_model.schedulerD.step()

            ###Update G network: maximize log(D(G(z)))
            gan_model.netG.zero_grad()
            output = gan_model.netD(fake,fake_cosm_params)
            errG_adv = gan_model.criterion(output[-1].view(-1), g_label.float())
            # Histogram pixel intensity loss (logged and stored; not added to errG)
            hist_loss=f_get_loss_cond('hist',fake,fake_cosm_params,gdict,bins=gdict['bns'],hist_val_tnsr=Dset.train_hist)

            # Spectral loss (computes the spectrum of `fake` internally)
            spec_loss=f_get_loss_cond('spec',fake,fake_cosm_params,gdict,spec_mean_tnsr=Dset.train_spec_mean,spec_var_tnsr=Dset.train_spec_var,r=Dset.r,ind=Dset.ind)

            errG=errG_adv
            if gdict['lambda_spec_mean']: errG = errG+ spec_loss
            if gdict['lambda_fm']:## Add feature matching loss
                fm_loss=f_get_loss_cond('fm',fake,fake_cosm_params,gdict,real_output=[i.detach() for i in real_output],fake_output=output)
                errG= errG+ fm_loss
            else:
                fm_loss=torch.Tensor([np.nan])

            if torch.isnan(errG).any():
                logging.info(errG)
                raise SystemError

            # Calculate gradients for G
            errG.backward()
            D_G_z2 = output[-1].mean().item()

            ### Implement Gradient clipping
            if gdict['grad_clip']:
                nn.utils.clip_grad_norm_(gan_model.netG.parameters(),gdict['grad_clip'])

            gan_model.optimizerG.step()
            lr_g=gan_model.optimizerG.param_groups[0]['lr']
            gan_model.schedulerG.step()

            tme2=time.time()
            ####### Store metrics ########
            # Output training stats
            if gdict['world_rank']==0:
                if ((count % gdict['checkpoint_size'] == 0)):
                    logging.info('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_adv: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                          % (epoch, epochs, count, num_batches, errD.item(), errG_adv.item(),errG.item(), D_x, D_G_z1, D_G_z2))
                    logging.info("Spec loss: %s,\t hist loss: %s"%(spec_loss.item(),hist_loss.item()))
                    logging.info("Training time for step %s : %s"%(iters, tme2-tme1))

                # Save metrics
                cols=['step','epoch','Dreal','Dfake','Dfull','G_adv','G_full','spec_loss','hist_loss','fm_loss','gp_loss','D(x)','D_G_z1','D_G_z2','lr_d','lr_g','time']
                vals=[iters,epoch,errD_real.item(),errD_fake.item(),errD.item(),errG_adv.item(),errG.item(),spec_loss.item(),hist_loss.item(),fm_loss.item(),gp_loss.item(),D_x,D_G_z1,D_G_z2,lr_d,lr_g,tme2-tme1]
                for col,val in zip(cols,vals):  metrics_df.loc[iters,col]=val

                ### Checkpoint the best model
                checkpoint=True
                iters += 1  ### Model has been updated, so update iters before saving metrics and model.

                ### Compute validation metrics for updated model
                gan_model.netG.eval()
                with torch.no_grad():
                    fake = gan_model.netG(fixed_noise,fixed_cosm_params)
                    hist_chi=f_get_loss_cond('hist',fake,fixed_cosm_params,gdict,bins=gdict['bns'],hist_val_tnsr=Dset.val_hist)
                    spec_chi=f_get_loss_cond('spec',fake,fixed_cosm_params,gdict,spec_mean_tnsr=Dset.val_spec_mean,spec_var_tnsr=Dset.val_spec_var,r=Dset.r,ind=Dset.ind)

                # Storing chi for next step
                for col,val in zip(['spec_chi','hist_chi'],[spec_chi.item(),hist_chi.item()]):  metrics_df.loc[iters,col]=val

                # Checkpoint model for continuing run
                if count == num_batches-1: ## Checkpoint at last step of epoch
                    f_save_checkpoint(gdict,epoch,iters,best_chi1,best_chi2,gan_model.netG,gan_model.netD,gan_model.optimizerG,gan_model.optimizerD,save_loc=save_dir+'/models/checkpoint_last.tar')
                    shutil.copy(save_dir+'/models/checkpoint_last.tar',save_dir+'/models/checkpoint_%s_%s.tar'%(epoch,iters)) # Store last step for each epoch

                if (checkpoint and (epoch > 1)): # Choose best models by metric
                    if hist_chi< best_chi1:
                        best_chi1=hist_chi.item()  # update first so the saved checkpoint records this model's chi
                        f_save_checkpoint(gdict,epoch,iters,best_chi1,best_chi2,gan_model.netG,gan_model.netD,gan_model.optimizerG,gan_model.optimizerD,save_loc=save_dir+'/models/checkpoint_best_hist.tar')
                        logging.info("Saving best hist model at epoch %s, step %s."%(epoch,iters))

                    if  spec_chi< best_chi2:
                        best_chi2=spec_chi.item()  # update first so the saved checkpoint records this model's chi
                        f_save_checkpoint(gdict,epoch,iters,best_chi1,best_chi2,gan_model.netG,gan_model.netD,gan_model.optimizerG,gan_model.optimizerD,save_loc=save_dir+'/models/checkpoint_best_spec.tar')
                        logging.info("Saving best spec model at epoch %s, step %s"%(epoch,iters))

                    if ((gdict['save_steps_list']=='all') and (iters % gdict['checkpoint_size'] == 0)):
                        f_save_checkpoint(gdict,epoch,iters,best_chi1,best_chi2,gan_model.netG,gan_model.netD,gan_model.optimizerG,gan_model.optimizerD,save_loc=save_dir+'/models/checkpoint_{0}.tar'.format(iters))
                        logging.info("Saving given-step at epoch %s, step %s."%(epoch,iters))

                # Save G's output on fixed_noise, once per conditioning value
                if ((iters % gdict['checkpoint_size'] == 0) or ((epoch == epochs-1) and (count == num_batches-1))):
                    gan_model.netG.eval()
                    with torch.no_grad():
                        for c_pars in gdict['sigma_list']:
                            tnsr_cosm_params=(torch.ones(gdict['op_size'],device=device)*c_pars).view(gdict['op_size'],1)
                            fake = gan_model.netG(fixed_noise,tnsr_cosm_params).detach().cpu()
                            img_arr=np.array(fake)
                            fname='gen_img_label-%s_epoch-%s_step-%s'%(c_pars,epoch,iters)
                            np.save(save_dir+'/images/'+fname,img_arr)

        t_epoch_end=time.time()
        if gdict['world_rank']==0:
            logging.info("Time taken for epoch %s, count %s: %s for rank %s"%(epoch,count,t_epoch_end-t_epoch_start,gdict['world_rank']))
            # Save Metrics to file after each epoch
            metrics_df.to_pickle(save_dir+'/df_metrics.pkle')
            logging.info("best chis: {0}, {1}".format(best_chi1,best_chi2))



## Main

In [49]:
#########################
### Main code #######
#########################

if __name__=="__main__":
    ## Entry point: configure, build the GAN, load data, train, then generate images from the best checkpoints.
    jpt=True ## True inside a jupyter notebook; set to False when running as a .py script
    t0=time.time()
    args=f_parse_args() if not jpt else f_manual_add_argparse()

    #################################
    ### Set up global dictionary###
    gdict={}
    gdict=f_init_gdict(args,gdict)
    # gdict['num_imgs']=200

    if jpt: ## override for jpt nbks
        gdict['num_imgs']=50
        gdict['epochs']=5
        gdict['run_suffix']='nb_test'

    ### Set up metrics dataframe
    cols=['step','epoch','Dreal','Dfake','Dfull','G_adv','G_full','spec_loss','hist_loss','spec_chi','hist_chi','gp_loss','fm_loss','D(x)','D_G_z1','D_G_z2','time']
    metrics_df=pd.DataFrame(columns=cols)

    # Setup
    metrics_df=f_setup(gdict,metrics_df,log=(not jpt))

    ## Build GAN
    gan_model=GAN_model(gdict,True)

    fixed_noise = torch.randn(gdict['op_size'], 1, 1, 1, gdict['nz'], device=gdict['device']) #Latent vectors to view G progress    # Mod for 3D
    rnd_idx=torch.randint(len(gdict['sigma_list']),(gdict['op_size'],1),device=gdict['device'])
    fixed_cosm_params=torch.tensor([gdict['sigma_list'][i] for i in rnd_idx.long()],device=gdict['device']).unsqueeze(-1)

    if gdict['distributed']:  try_barrier(gdict['world_rank'])

    ## Load data and precompute
    Dset=Dataset(gdict)

    if gdict['distributed']:  try_barrier(gdict['world_rank'])

    #################################
    ########## Train loop and save metrics and images ######
    if gdict['world_rank']==0:
        logging.info(gdict)
        logging.info("Starting Training Loop...")

    f_train_loop(gan_model,Dset,metrics_df,gdict,fixed_noise,fixed_cosm_params)

    if gdict['world_rank']==0: ## Generate images for best saved models ######
        op_loc=gdict['save_dir']+'/images/'
        for cl in gdict['sigma_list']:
            for metric in ['spec','hist']:
                ip_fname=gdict['save_dir']+'/models/checkpoint_best_%s.tar'%(metric)
                ## Best-model checkpoints are only written from epoch 2 on, so short runs may have none.
                if not os.path.exists(ip_fname):
                    logging.info("No checkpoint %s found; skipping image generation"%(ip_fname))
                    continue
                f_gen_images(gdict,gan_model.netG,gan_model.optimizerG,cl,ip_fname,op_loc,op_strg='gen_img_best_%s'%(metric),op_size=32)

    if gdict['distributed']: dist.destroy_process_group()  ## release NCCL resources

    tf=time.time()
    logging.info("Total time %s"%(tf-t0))
    logging.info('End: %s'%(datetime.now().strftime('%Y-%m-%d  %H:%M:%S')))

Not using DDP
Generator(
  (main): Sequential(
    (0): Linear(in_features=65, out_features=32768, bias=True)
    (1): BatchNorm3d(1, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): View()
    (4): ConvTranspose3d(512, 256, kernel_size=(5, 5, 5), stride=(2, 2, 2), padding=(2, 2, 2), output_padding=(1, 1, 1), bias=False)
    (5): BatchNorm3d(256, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
    (6): ReLU(inplace=True)
    (7): ConvTranspose3d(256, 128, kernel_size=(5, 5, 5), stride=(2, 2, 2), padding=(2, 2, 2), output_padding=(1, 1, 1), bias=False)
    (8): BatchNorm3d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): ConvTranspose3d(128, 64, kernel_size=(5, 5, 5), stride=(2, 2, 2), padding=(2, 2, 2), output_padding=(1, 1, 1), bias=False)
    (11): BatchNorm3d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    

In [38]:
# t1=torch.Tensor([[3.9327e+19, 1.9647e+17, 9.8353e+15, 5.6435e+15, 4.4559e+15, 2.4807e+15,
#          3.5780e+15, 3.4637e+15, 3.0052e+15, 3.2010e+15, 4.8305e+15, 3.4678e+15,
#          3.0497e+15, 4.0568e+15, 8.5679e+15, 7.7981e+15, 3.9892e+15, 3.9012e+15,
#          3.4423e+15, 3.1456e+15, 4.0592e+15, 7.4298e+15, 5.1034e+15, 4.5205e+15,
#          3.6105e+15, 4.7984e+15, 3.9527e+15, 5.2564e+15, 3.9960e+15, 3.3404e+15,
#          3.3596e+15, 4.1172e+15, 3.7113e+15, 3.9968e+15, 5.6466e+15, 3.8044e+15,
#          3.7668e+15, 4.3343e+15, 3.9033e+15, 3.5537e+15, 3.1440e+15, 3.4889e+15,
#          3.2363e+15, 3.5794e+15, 4.6079e+15, 6.4687e+15, 1.4444e+16, 3.0655e+15,
#          3.6433e+15, 9.4971e+15, 3.0938e+15, 3.8165e+15, 1.0852e+17]])
# t2=torch.Tensor([[1.4227e+18, 4.5814e+17, 2.3731e+17, 1.2961e+17, 6.2309e+16, 3.0852e+16,
#          1.6995e+16, 9.4283e+15, 5.1132e+15, 2.9901e+15, 1.6873e+15, 9.9320e+14,
#          5.8631e+14, 3.3785e+14, 1.9779e+14, 1.1977e+14, 7.2761e+13, 4.5273e+13,
#          2.7184e+13, 1.5852e+13, 9.7745e+12, 6.0562e+12, 3.7181e+12, 2.3326e+12,
#          1.5128e+12, 9.9294e+11, 6.7178e+11, 4.6634e+11, 3.2382e+11, 2.4292e+11,
#          1.8293e+11, 1.2277e+11, 7.5090e+10, 4.7818e+10, 3.1517e+10, 2.0512e+10,
#          1.3701e+10, 9.3295e+09, 6.4210e+09, 4.6735e+09, 3.5680e+09, 2.4889e+09,
#          1.9398e+09, 1.3137e+09, 8.0469e+08, 5.5723e+08, 3.6250e+08, 2.4713e+08,
#          2.2210e+08, 1.8036e+08, 1.5043e+08, 1.4761e+08, 2.0282e+08]])

# # t1,t2

In [39]:
# # Dset.train_spec_var[-1]

# # idx=int(image_size/2) ### For the spectrum, use only N/2 indices for loss calc.
# idx=32
# ### Warning: the first index is the channel number. For multiple channels, you are averaging over them, which is fine.

# # loss_mean=torch.log(torch.mean(torch.pow(spec_mean[:,:idx]-spec_mean_ref[:,:idx],2)))
# # loss_var=torch.log(torch.mean(torch.pow(spec_var[:,:idx]-spec_var_ref[:,:idx],2)))


# epsilon_spec=1e6 ## correction in case of a 0 inside the log (= min value of spectrum)
# # loss_mean=torch.mean(torch.log(torch.pow(spec_mean[:,:idx]-spec_mean_ref[:,:idx],2)+epsilon_spec))
# loss_var =torch.mean(torch.log(torch.pow(t1[:,:idx]-t2[:,:idx],2)+epsilon_spec))    
# print(loss_var)
# # ans=lambda_spec_mean*loss_mean+lambda_spec_var*loss_var

In [35]:
torch.mean(torch.pow(torch.log(t1[:,:idx])-torch.log(t2[:,:idx]),2))

tensor(32.9701)

In [45]:
Dset.train_spec_mean

tensor([[[1.8552e+08, 9.6616e+07, 5.7577e+07, 3.8662e+07, 2.8498e+07,
          2.1656e+07, 1.6994e+07, 1.3622e+07, 1.1247e+07, 9.3366e+06,
          7.9250e+06, 6.6476e+06, 5.6010e+06, 4.7114e+06, 4.0233e+06,
          3.4015e+06, 2.8958e+06, 2.4542e+06, 2.0917e+06, 1.7853e+06,
          1.5208e+06, 1.2952e+06, 1.1163e+06, 9.5578e+05, 8.2363e+05,
          7.1118e+05, 6.2108e+05, 5.4527e+05, 4.8070e+05, 4.3283e+05,
          3.8789e+05, 3.3613e+05, 2.7904e+05, 2.3179e+05, 1.9561e+05,
          1.6538e+05, 1.4153e+05, 1.2193e+05, 1.0564e+05, 9.1189e+04,
          8.2313e+04, 7.1219e+04, 6.3291e+04, 5.3632e+04, 4.4296e+04,
          3.7274e+04, 3.1618e+04, 2.7908e+04, 2.5641e+04, 2.3574e+04,
          2.2025e+04, 2.0532e+04, 2.0107e+04]],

        [[4.8709e+08, 2.7155e+08, 1.7317e+08, 1.2641e+08, 9.6805e+07,
          7.8709e+07, 6.3263e+07, 5.1803e+07, 4.2764e+07, 3.5329e+07,
          2.9729e+07, 2.4546e+07, 2.0229e+07, 1.6615e+07, 1.3684e+07,
          1.1291e+07, 9.3133e+06, 7.7389e

In [40]:
# metrics_df.plot('step','time')
# metrics_df.plot(x='step',y=['hist_loss','spec_chi'],kind='line')
metrics_df

Unnamed: 0,step,epoch,Dreal,Dfake,Dfull,G_adv,G_full,spec_loss,hist_loss,spec_chi,hist_chi,gp_loss,fm_loss,D(x),D_G_z1,D_G_z2,time,lr_d,lr_g
0,0.0,0.0,0.678347,0.716826,1.39517,3.17817,3.17817,0.0,-0.497655,,,,,0.0411101,0.05727,-3.1299,0.725209,0.0001,0.0001
1,1.0,0.0,0.157085,0.64002,0.797106,6.29367,6.29367,0.0,-0.608192,0.0,-1.11564,,,3.18166,-0.194364,-6.29078,0.709936,0.0001,0.0001
2,2.0,0.0,0.797101,0.13975,0.93685,2.04063,2.04063,0.0,-0.742643,0.0,-0.829946,,,-0.0485594,-3.31683,-1.89558,0.70562,0.0001,0.0001
3,3.0,0.0,0.151557,1.352,1.50356,4.28733,4.28733,0.0,-0.913809,0.0,-0.848158,,,5.43908,1.10805,-4.27125,0.706481,0.0001,0.0001
4,4.0,1.0,0.139006,0.254753,0.393759,5.53134,5.53134,0.0,-1.10428,0.0,-0.930297,,,3.65098,-1.52158,-5.52616,0.725973,0.0001,0.0001
5,5.0,1.0,0.25504,0.127463,0.382503,3.96377,3.96377,0.0,-1.10383,0.0,-1.0895,,,1.52898,-2.86921,-3.94011,0.710698,0.0001,0.0001
6,6.0,1.0,0.199959,0.343768,0.543727,4.26291,4.26291,0.0,-1.19995,0.0,-1.09049,,,2.65582,-1.12418,-4.24629,0.705691,0.0001,0.0001
7,7.0,1.0,0.15692,0.255365,0.412286,5.06021,5.06021,0.0,-1.43114,0.0,-1.17235,,,2.37262,-1.50275,-5.05145,0.704702,0.0001,0.0001
8,8.0,2.0,0.22008,0.130066,0.350146,4.01742,4.01742,0.0,-1.57985,0.0,-1.32292,,,2.01309,-2.66047,-3.99448,0.711303,0.0001,0.0001
9,9.0,2.0,0.162802,0.317971,0.480774,6.10452,6.10452,0.0,-1.61647,0.0,-1.49051,,,3.25047,-1.2149,-6.10139,0.735922,0.0001,0.0001


In [None]:
gdict

In [None]:
print(metrics_df)

In [None]:

# Parameter counts (in millions) of the generator and the discriminator, displayed as a (G, D) tuple.
tuple(sum(p.numel() for p in net.parameters()) / 1e6 for net in (gan_model.netG, gan_model.netD))

In [None]:
# Layer summary of the generator. Inputs: latent noise whose last dimension is the configured nz
# (read from gdict rather than hard-coding 128 -- the Generator's first Linear expects nz+1 features)
# and the 1-element cosmological parameter.
# NOTE(review): torchsummary prepends its own batch dimension to every size; confirm the leading 1
# in the noise size matches the rank the Generator's forward expects.
summary(gan_model.netG,[(1,1,1,1,gdict['nz']),tuple([1])])
# summary(gan_model.netD,[(1,128,128,128),tuple([1])])


In [None]:
class SimpleConv(nn.Module):
    '''Toy two-input module: both inputs go through the same Conv2d(1->1, 3x3, stride 1, padding 1) + ReLU stack.'''

    def __init__(self):
        super().__init__()
        layers = [nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1), nn.ReLU()]
        # Stored as `features` so state_dict keys stay 'features.0.weight' / 'features.0.bias'.
        self.features = nn.Sequential(*layers)

    def forward(self, x, y):
        '''Return (features(x), features(y)); the weights are shared between the two branches.'''
        return tuple(self.features(inp) for inp in (x, y))
    
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleConv().to(device)

# torchsummary builds its dummy inputs on "cuda" unless told otherwise; pass the actual device type
# so the CPU fallback chosen above also works.
summary(model, [(1, 16, 16), (1, 28, 28)], device=device.type)

In [None]:
tuple([1])

### Debug

In [None]:
# Feature matching loss


In [None]:
# netG = Generator(gdict).to(gdict['device'])
# netG.apply(weights_init)
# # # #     print(netG)
# # summary(netG,(1,1,64))
# # Create Discriminator
# netD = Discriminator(gdict).to(gdict['device'])
# netD.apply(weights_init)
# #     print(netD)
# summary(netD,(1,128,128))

In [None]:
# Short aliases for the GAN's generator and discriminator, for the debug cells that follow.
# To inspect them: summary(netG,(1,1,64)) / summary(netD,(1,128,128))
netG, netD = gan_model.netG, gan_model.netD

In [None]:
# Get real data
data=iter(Dset.train_dataloader)
real_data1=data.next()[0]
real_data1.shape
real_data2=data.next()[0]
real_data2.shape
# Get fake data
noise1 = torch.randn(gdict['batch_size'], 1, 1, 1, gdict['nz'], device=gdict['device']) ### Mod for 3D
rnd_idx=torch.randint(len(gdict['sigma_list']),(gdict['batch_size'],1),device=gdict['device'])
fake_cosm_params=torch.tensor([gdict['sigma_list'][i] for i in rnd_idx.long()],device=gdict['device']).unsqueeze(-1)
fake1 = gan_model.netG(noise1,fake_cosm_params)   

In [None]:
from torch.autograd import Variable
from torch.autograd import grad as torch_grad

def f_gradient_penalty(netD, real_data, generated_data,cosm_params,gdict):
    '''
    WGAN-GP style gradient penalty computed on random interpolates between real and generated samples.

    Arguments:
        netD: discriminator, called as netD(images, cosm_params); the last element of its output is used.
        real_data: batch of real samples (for the 3D model presumably (batch, channels, D, H, W) -- TODO confirm).
        generated_data: batch of generated samples, same shape as real_data.
        cosm_params: conditioning parameters passed unchanged to netD.
        gdict: config dict; gdict['device'] (when present) selects where the penalty is computed,
               otherwise real_data's device is used.

    Returns:
        Scalar tensor: batch mean of (||d netD / d x_hat||_2 - 1)^2.
    '''
    # Use the real batch size so a smaller final batch does not break the broadcast/reshape below.
    batch_size = real_data.size(0)
    device = gdict.get('device', real_data.device)

    # Detach the inputs: only the interpolates should carry gradient.
    real = real_data.detach().to(device)
    fake = generated_data.detach().to(device)

    # One mixing coefficient per sample, broadcast over all remaining dimensions (any input rank).
    alpha = torch.rand((batch_size,) + (1,) * (real.dim() - 1), device=device)
    interpolated = (alpha * real + (1 - alpha) * fake).requires_grad_(True)

    # Discriminator output on the interpolated examples
    prob_interpolated = netD(interpolated,cosm_params)[-1]

    # Gradients of the discriminator output w.r.t. the interpolated inputs; keep the graph so the
    # penalty can itself be back-propagated into netD.
    gradients = torch.autograd.grad(outputs=prob_interpolated, inputs=interpolated,
                                    grad_outputs=torch.ones_like(prob_interpolated),
                                    create_graph=True, retain_graph=True)[0]

    # Flatten everything except the batch dimension to take one L2 norm per example.
    gradients = gradients.reshape(batch_size, -1)
    # Small epsilon keeps sqrt differentiable when a gradient is exactly zero.
    gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)

    return ((gradients_norm - 1) ** 2).mean()


In [None]:
# Gradient penalty between two real batches (sanity check) and between real and generated data.
# `real_cpu` / `fake` are not defined in these debug cells, so use the batches built above.
gp_real_real=f_gradient_penalty(netD,real_data1,real_data2,fake_cosm_params,gdict)
gp_loss=f_gradient_penalty(gan_model.netD,real_data1,fake1.detach(),fake_cosm_params,gdict)
gp_real_real, gp_loss

In [None]:
def f_div_loss(fake1,noise1,cosm_params,netG,gdict,lambda_div=0.1):
    '''
    Diversity loss to discourage mode collapse.

    Draws a second latent vector noise2 (same shape, dtype and device as noise1), generates
    fake2 = netG(noise2, cosm_params), and returns the reciprocal of
        lambda_div * ||fake1 - fake2|| / ||noise1 - noise2||
    Minimizing this value therefore rewards the generator for producing more different images
    for different noise inputs.

    Arguments:
        fake1: images generated from noise1 with cosm_params.
        noise1: latent tensor that produced fake1.
        cosm_params: conditioning parameters passed to netG.
        netG: generator, called as netG(noise, cosm_params).
        gdict: config dict; kept for interface compatibility (noise2 now follows noise1 directly).
        lambda_div: weight on the image/noise distance ratio (default 0.1, the former hard-coded value).

    Returns:
        Scalar tensor. It is inf if fake1 and fake2 are identical.
    '''
    # Match noise1 exactly so partial batches (smaller than gdict['batch_size']) work.
    noise2 = torch.randn_like(noise1)
    fake2 = netG(noise2,cosm_params)

    ans=lambda_div*torch.norm(fake1-fake2)/torch.norm(noise1-noise2)

    return 1.0/ans

# Evaluate the diversity loss on the debug batch built above and display it.
div_loss = f_div_loss(fake1=fake1, noise1=noise1, cosm_params=fake_cosm_params,
                      netG=gan_model.netG, gdict=gdict)
div_loss