# Generate Images
March 1, 2021

In [1]:
import argparse
import os
import sys

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data

from torchsummary import summary
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

from torch.utils.data import DataLoader, TensorDataset

import time
from datetime import datetime
import glob
import pickle
import yaml
import collections

In [2]:
%matplotlib widget

In [3]:
sys.path.append('/global/u1/v/vpa/project/jpt_notebooks/Cosmology/Cosmo_GAN/repositories/cosmogan_pytorch/code/5_3d_cgan/1_main_code/')
from utils import *

## Build Model structure

In [4]:
def f_load_config(config_file):
    '''Parse a YAML configuration file and return the resulting object.

    Arguments:
        config_file: path of the YAML file to read.
    Returns whatever yaml.SafeLoader builds from the file contents.
    '''
    with open(config_file) as stream:
        return yaml.load(stream, Loader=yaml.SafeLoader)

def f_manual_add_argparse():
    '''Build the argument namespace by hand (use only in a Jupyter notebook).

    Stands in for command-line parsing. Returns an argparse.Namespace carrying
    the attributes config, mode, ip_fldr, facility, distributed and ngpu.
    '''
    # Settings that were left disabled in this cell:
    #   local_rank=0
    #   mode='continue'
    #   ip_fldr='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/pytorch/results/128sq/20201211_093818_nb_test/'
    return argparse.Namespace(
        config='1_main_code/config_3d_cgan_64_cori.yaml',
        mode='fresh',
        ip_fldr='',
        facility='cori',
        distributed=False,
        ngpu=1,
    )
             
    
def weights_init(m):
    '''Custom in-place weight initialisation, applied to netG and netD via ``.apply``.

    The layer kind is decided from the class name of ``m``:
      * name contains 'Conv'      -> weight drawn from N(0.0, 0.02)
      * name contains 'BatchNorm' -> weight drawn from N(1.0, 0.02), bias set to 0
    Any other module is left untouched. Returns None.
    '''
    layer_kind = type(m).__name__

    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return

    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

def f_init_gdict(args,gdict):
    '''Build the global settings dictionary from the YAML config file plus args.

    Arguments:
        args : namespace whose `config` attribute is the path of the YAML file;
               every attribute of args is copied into the result.
        gdict: ignored -- the result is the 'parameters' section of the config,
               so whatever the caller passes in is left untouched.
    Returns the 'parameters' mapping with the args attributes written on top
    (the value from args wins when a key is present in both).
    '''
    # Parse the YAML file named by args.config
    with open(args.config) as cfg_stream:
        parsed = yaml.load(cfg_stream, Loader=yaml.SafeLoader)

    settings = parsed['parameters']
    # Overlay the args attributes on the values read from the config
    settings.update(vars(args))
    return settings

In [5]:
def f_gen_images(gdict,netG,optimizerG,sigma,ip_fname,op_loc,op_strg='inf_img_',op_size=500):
    '''Generate images from a saved checkpoint and write them to a .npy file.

    Arguments:
        gdict     : settings dictionary; the keys 'nz', 'device' and 'multi-gpu' are read.
        netG      : generator network. Its weights are overwritten from the checkpoint
                    and it is left in eval mode.
        optimizerG: generator optimizer. Its state is restored from the checkpoint.
        sigma     : sigma input value, repeated for every generated image.
        ip_fname  : name of input (checkpoint) file.
        op_loc    : folder the output file is written to (created if it does not exist).
        op_strg   : string prefix for the output file name.
        op_size   : number of images to generate.

    Returns None. The output file is named
    '<op_strg>_label-<sigma>_epoch-<epoch>_step-<iters>.npy', with epoch and iters
    taken from the checkpoint. If the checkpoint cannot be read, the error is
    printed and nothing is generated.
    '''

    nz,device=gdict['nz'],gdict['device']

    try:
        # Map the stored tensors straight onto the device in use. Without
        # map_location they are restored to the devices they were saved from,
        # which may not exist here and ignores a run that was forced onto the CPU.
        checkpoint=torch.load(ip_fname,map_location=device)
    except Exception as e:
        print(e)
        print("skipping generation of images for ",ip_fname)
        return

    ## Load checkpoint
    # A DataParallel wrapper keeps the real network under .module
    if gdict['multi-gpu']:
        netG.module.load_state_dict(checkpoint['G_state'])
    else:
        netG.load_state_dict(checkpoint['G_state'])

    ## Load other stuff
    iters=checkpoint['iters']
    epoch=checkpoint['epoch']
    optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])

    # Generate batch of latent vectors and the matching column of sigma labels
    noise = torch.randn(op_size, 1, 1, 1, nz, device=device)
    tnsr_cosm_params=(torch.ones(op_size,device=device)*sigma).view(op_size,1)

    # Generate fake image batch with G
    netG.eval() ## This is required before running inference
    with torch.no_grad(): ## This is important. fails without it for multi-gpu
        gen = netG(noise,tnsr_cosm_params)
        gen_images=gen.detach().cpu().numpy()
        print(gen_images.shape)

    op_fname='%s_label-%s_epoch-%s_step-%s.npy'%(op_strg,sigma,epoch,iters)
    # Join the path properly so op_loc works with or without a trailing slash,
    # and make sure the target folder exists before writing.
    if op_loc:
        os.makedirs(op_loc,exist_ok=True)
    np.save(os.path.join(op_loc,op_fname),gen_images)

    print("Image saved in ",op_fname)

In [6]:
if __name__=="__main__":
    # Set-up cell: reads the config, picks the device, then builds the
    # generator and its optimizer so that later cells can load checkpoints.
    torch.backends.cudnn.benchmark=True
    t0=time.time()
    #################################
    ### Set up ###
    args=f_manual_add_argparse()

    # Initialize variables: the second argument of f_init_gdict is not used by
    # it, so an empty dict is handed over directly.
    gdict=f_init_gdict(args,{})
    print(gdict)

    # The device choice uses the ngpu value coming from args; ngpu is then
    # replaced by the number of GPUs torch can see.
    # NOTE(review): this means the requested ngpu does not limit how many GPUs
    # DataParallel uses below -- confirm that this is intended.
    gdict['device']=torch.device("cuda" if (torch.cuda.is_available() and gdict['ngpu'] > 0) else "cpu")
    gdict['ngpu']=torch.cuda.device_count()
    gdict['multi-gpu']=(gdict['device'].type == 'cuda') and (gdict['ngpu'] > 1)

    # f_get_model is not defined in this notebook (presumably provided by the
    # `from utils import *` cell -- verify).
    Generator, Discriminator=f_get_model(gdict['model'],gdict)
    print("Building GAN networks")

    # Create Generator and give it the custom initial weights
    netG = Generator(gdict).to(gdict['device'])
    netG.apply(weights_init)
    # summary(netG,(1,1,64))

    print("Number of GPUs used %s"%(gdict['ngpu']))
    if gdict['multi-gpu']:
        netG = nn.DataParallel(netG, list(range(gdict['ngpu'])))

    optimizerG = optim.Adam(netG.parameters(), lr=gdict['learn_rate_g'], betas=(gdict['beta1'], 0.999),eps=1e-7)



{'ip_fname': '/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/3d_data/dataset2a_3dcgan_4univs_64cube_simple_splicing', 'op_loc': '/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/pytorch/results/3d_cGAN/', 'image_size': 64, 'num_imgs': 50000, 'kappa': 40, 'ip_fldr': '', 'chkpt_file': '/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/pytorch/results/3d_cGAN/20211028_113034_cgan_bs32_nodes8_lr0.00001-fixed_fm00_spec0.0_kappa40/models/checkpoint_last.tar', 'workers': 2, 'nc': 1, 'nz': 64, 'ngf': 64, 'ndf': 64, 'beta1': 0.5, 'kernel_size': 5, 'stride': 2, 'g_padding': 2, 'd_padding': 2, 'flip_prob': 0.01, 'bns': 50, 'checkpoint_size': 10, 'sigma_list': [0.5, 0.8, 1.1], 'model': 2, 'batch_size': 32, 'epochs': 300, 'op_size': 16, 'learn_rate_d': 1e-05, 'learn_rate_g': 1e-05, 'lr_d_epochs': [15, 30, 45, 60, 75, 90], 'lr_d_gamma': 1.0, 'lr_g_epochs': [15, 30, 45, 60, 75, 90], 'lr_g_gamma': 1.0, 'deterministic': False, 'seed': 234373, 'lambda_spec_m

## Run Inference

In [7]:
ls /global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/pytorch/results/3d_cGAN/

[0m[01;34m20210223_210217_3dcgan_predict_0.8_m2[0m/
[01;34m20210227_050213_3dcgan_predict_0.65_m2[0m/
[01;34m20210309_200006_3dcgan_predict_0.65_m2[0m/
[01;34m20210519_81818_cgan_bs16_lr0.001_nodes8_spec0.1_lrg0.001fixed_lrd-fastfall_good-results[0m[K/
[01;34m20210615_72613_cgan_bs32_nodes8_lr0.0001_good_cgan_run[0m/
[01;34m20210616_212328_cgan_bs32_nodes8_lr0.0001_fm50[0m/
[01;34m20210617_204752_cgan_bs32_nodes8_lr0.0001-vary_fm50[0m/
[01;34m20210619_224213_cgan_bs32_nodes8_lr0.0001-vary_fm50[0m/
[01;34m20210620_113852_cgan_bs32_nodes8_lr0.0001-vary_fm50_spec0.01[0m/
[01;34m20210620_63445_cgan_bs32_nodes8_lr0.0001-vary_fm50_spec0.05[0m/


In [8]:
# ## For single checkpoint
# main_dir='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/pytorch/results/128sq/'
# fldr='20210119_134802_cgan_predict_0.65_m2'
# param_label=1.1
# op_loc=main_dir+fldr+'/images/'
# ip_fname=main_dir+fldr+'/models/checkpoint_best_spec.tar'
# f_gen_images(gdict,netG,optimizerG,param_label,ip_fname,op_loc,op_strg='inference_spec',op_size=1000)


In [7]:
## For multiple checkpoints
# For every training step in step_list, locate its checkpoint file and write
# one batch of generated images per parameter value in param_list.
main_dir='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/pytorch/results/3d_cGAN/'
fldr='20210929_32023_cgan_bs32_nodes8_lr0.0001-fixed_fm10_spec0.0_kappa10_goodrun'

param_list=[0.5,0.65,0.8,1.1]
op_loc=main_dir+fldr+'/images/'
step_list=[42770,43060,48560,48570]
for step in step_list:
    # Resolve the checkpoint once per step: the lookup does not depend on the
    # parameter value. Sorting makes the choice repeatable if the wildcard
    # matches more than one file.
    # (For exact names use main_dir+fldr+'/models/checkpoint_{0}.tar'.format(step) instead.)
    matches=sorted(glob.glob(main_dir+fldr+'/models/checkpoint_*{0}.tar'.format(step)))
    if not matches:
        # Say what is wrong instead of surfacing "list index out of range"
        print("no checkpoint file found for step",step)
        print("skipping ",step)
        continue
    ip_fname=matches[0]

    for param_label in param_list:
        try:
            print(ip_fname)
            f_gen_images(gdict,netG,optimizerG,param_label,ip_fname,op_loc,op_strg='inference',op_size=128)
        except Exception as e:
            # Best effort: report the failure and carry on with the next label
            print(e)
            print("skipping ",step)

/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/pytorch/results/3d_cGAN/20210929_32023_cgan_bs32_nodes8_lr0.0001-fixed_fm10_spec0.0_kappa10_goodrun/models/checkpoint_42770.tar
(128, 1, 64, 64, 64)
Image saved in  inference_label-0.5_epoch-146_step-42770.npy
/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/pytorch/results/3d_cGAN/20210929_32023_cgan_bs32_nodes8_lr0.0001-fixed_fm10_spec0.0_kappa10_goodrun/models/checkpoint_42770.tar
(128, 1, 64, 64, 64)
Image saved in  inference_label-0.65_epoch-146_step-42770.npy
/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/pytorch/results/3d_cGAN/20210929_32023_cgan_bs32_nodes8_lr0.0001-fixed_fm10_spec0.0_kappa10_goodrun/models/checkpoint_42770.tar
(128, 1, 64, 64, 64)
Image saved in  inference_label-0.8_epoch-146_step-42770.npy
/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/pytorch/results/3d_cGAN/20210929_32023_cgan_bs32_nodes8_lr0.0001-fixed_fm10_spec0.0_kappa10_g

In [10]:
# fname=op_loc+'inference_spec_epoch-11_step-37040.npy'
# a1=np.load(fname)
# print(a1.shape)