# Extract data from output files
### Analyze the output from a single LBANN run
March 9, 2020 \
April 6, 2020 : to store files in order of epochs \
April 21, 2020: added jupyter widgets to compare pixel intensity plots \
May 8, 2020: using all images for a given batch \
May 29, 2020: Modified for new update of LBANN. File names of images changed, so new extraction code. Also added code for computing chi-squared. \
June 17, 2020: Removed train_inp, train_gen and val_inp to reduce memory overhead. From now on, the code only analyzes val_gen \
June 26, 2020: Added gathering of steps and new chi-square quantities.\
July 1, 2020: Switched back to storing mainly train_gen with large steps (10 steps saved for 256 batchsize).

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import subprocess as sp
import os
import glob
import sys

import itertools
import time
from scipy import fftpack
# from ipywidgets import interact, interact_manual,fixed, SelectMultiple, IntText, IntSlider, FloatSlider,SelectionSlider,BoundedIntText
from ipywidgets import *

In [2]:
%matplotlib widget

In [3]:
sys.path.append('/global/u1/v/vpa/project/jpt_notebooks/Cosmology/Cosmo_GAN/repositories/lbann_cosmogan/3_analysis')
from modules_image_analysis import *

[NbConvertApp] Converting notebook modules_image_analysis.ipynb to script
[NbConvertApp] Writing 17167 bytes to modules_image_analysis.py


In [4]:
### Transformation functions for image pixel values
def f_transform(x):
    return 2.*x/(x + 4. + 1e-8) - 1.

def f_invtransform(s):
    return 4.*(1. + s)/(1. - s + 1e-8)

In [5]:
# ### Other transformatino functinos
# ### Transformation functions for image pixel values

# def f_transform_new(x):
#     if x<=50:
#         a=0.03; b=-1.0
#         return a*x+b
#     elif x>50: 
#         a=0.5/np.log(300)
#         b=0.5-a*np.log(50)
#         return a*np.log(x)+b

# def f_invtransform_new(y):
#     if y<=0.5:
#         a=0.03;b=-1.0
#         return (y-b)/a
#     elif y>0.5: 
#         a=0.5/np.log(300)
#         b=0.5-a*np.log(50)
#         return np.exp((y-b)/a)
    

# def f_transform(x):
#     return np.vectorize(f_transform_new)(x)

# def f_invtransform(s):
#     return np.vectorize(f_invtransform_new)(s)

# f_transform_new(2000)

### Modules for Extraction

In [6]:
def f_get_files_df_sorted():
    '''
    Module to create Dataframe with filenames for each epoch and step
    Sorts by step and epoch
    '''
    
    ## Get images files and .npy arrays for each image in dump_outs folder
    t1=time.time()
    files_dict={}
#     keys=['train_gen','train_input','val_gen','val_input']
#     file_strg_lst=['model0-training*-gen_img*-output0.npy','model0-training*-inp_img*-output0.npy','model0-validation*-gen_img*-output0.npy','model0-validation*-inp_img*-output0.npy']
#     file_strg_lst=['sgd.training*_gen_img*_output0.npy','sgd.training*_inp_img*_output0.npy','sgd.validation*_gen_img*_output0.npy','sgd.validation*_inp_img*_output0.npy']

#     keys=['val_gen']
#     file_strg_lst=['sgd.validation*_gen_img*_output0.npy']
    keys=['train_gen']
    file_strg_lst=['sgd.training*_gen_img*_output0.npy']
    
    for key,file_strg in zip(keys,file_strg_lst):
        files_dict[key]=np.array(glob.glob(main_dir+file_strg))
        if files_dict[key].shape[0]>1000 : 
            print('Warning the number of files is very large. Possibility of memory overload')
    
    df_files=pd.DataFrame([])
    dict1={}
    t1=time.time()
    ### First get sorted Dataframe with file names
    for key in keys:
        files_arr=files_dict[key]  # Get array of files
        print(key,len(files_arr))
        for fname in files_arr:
            ### Extract the Epoch number and step number from the file name
            dict1['img_type']=key
            dict1['epoch']=np.int32(fname.split('epoch')[-1].split('.')[1])
            dict1['step']=np.int64(fname.split('step')[-1].split('.')[1].split('_')[0])
            dict1['fname']=fname
            
            df_files=df_files.append(dict1,ignore_index=True)
    ## Sort values
    df_files=df_files.sort_values(by=['img_type','epoch','step']).reset_index(drop=True)
    # df_files
    t2=time.time()
    print("Time for Sorting",t2-t1)
    
    return df_files


def f_filter_epoch(df_input,num_sliced=1):
    '''
    Get just equally spaced steps for each epoch
    '''
    print('Extracting %s steps of each epoch'%(num_sliced))
    df_output=pd.DataFrame([])
#     for key in ['train_gen','train_input','val_gen','val_input']: 
    for key in ['train_gen']: 
        ### For each type of images, get list of epochs
        df1=df_input[df_input.img_type==key]
        epochs=np.unique(df1.epoch.values).astype(int)

        for epoch in epochs:### Extract the last few steps in each epoch
            arr_step=df1[df1.epoch==epoch].step.values   ## Get all steps
            idxs=np.round(np.linspace(0,len(arr_step)-1,num_sliced).astype(int)) ## Get indices with equal spacing 
            df2=df1[df1.step.isin(arr_step[idxs])]        ## Get dataframe with these steps
            df_output=df_output.append(df2)  
    
    return df_output.reset_index(drop=True)

def f_get_images_df(df_files):
    '''
    Read dataframe with file names, read files and create new dataframe with images as numpy arrays
    Also computes number of images with intensity beyond a cutoff
    '''
    
    def f_row(df_row):
        '''
        Extract image
        '''
        fname,key=df_row.fname,df_row.img_type
        a1=np.load(fname)
        if key.endswith('input'): 
            size=np.int(np.sqrt(a1.shape[-1])) ### Extract size of images (=128)
            batch_size=a1.shape[0] ### Number of batches
            samples=a1.reshape(batch_size,size,size)
        elif key.endswith('gen') : samples=a1[:,0,:,:]
        else : raise SystemError

        return samples
    
    def f_high_pixel(df_row,cutoff=0.9966):
        '''
        Get number of images with a pixel about max cut-off value
        '''
        max_arr=np.amax(df_row.images,axis=(1,2))
        num_large=max_arr[max_arr>cutoff].shape[0]

        return num_large
    
    t1=time.time()
    ##### Create new Dataframe with sorted images
    df=df_files.copy()
    df['images']=df.apply(lambda row: f_row(row), axis=1)
    t2=time.time()
    print("Time for Reading images",t2-t1)
    
    ### Store the number of images with large pixel value
    cutoff=0.9966
    df['num_large']=df.apply(lambda row: f_high_pixel(row,cutoff), axis=1)
    
    return df
    


In [7]:
def f_get_sample_epochs(df,img_type,start_epoch=None,end_epoch=None):
    '''
    Module to extract images for a range of epochs given a dataframe
    '''
    if start_epoch==None and end_epoch==None:
        max_epoch=np.int(np.max(df.epoch.values))
#         print(max_epoch)
        start_epoch=0; end_epoch=max_epoch
#     if end_epoch==None: end_epoch=start_epoch+1
    
    arr=df[(df.epoch>=start_epoch) & (df.epoch<=end_epoch) & (df.img_type==img_type)].images.values
    arr=np.vstack(arr)
    
    return arr


def f_get_step(df,img_type,epoch,step):
    '''
    Module to extract images for a specific step and epoch
    '''
    
    arr=df[(df.epoch==epoch) & (df.step==step) & (df.img_type==img_type)].images.values
    arr=np.vstack(arr)
    
    return arr

def f_get_step_group(df,img_type,step_list):
    '''
    Module to extract images for a range of epochs given a dataframe
    '''
    arr=df[(df.step.isin(step_list)) & (df.img_type==img_type)].images.values
    arr=np.vstack(arr)
    
    return arr

## Extract image data 

In [8]:
# fldr_name='20200529_111342_seed3273_80epochs'
# fldr_name='20200701_065330_batchsize_512/'
# fldr_name='20200701_070005_batchsize_256/'
fldr_name='20200718_114324_batchsize_512/'
# fldr_name='20200718_135530_batchsize_256/'

main_dir='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_data/{0}/dump_outs/trainer0/model0/'.format(fldr_name)
print(main_dir)

/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_data/20200718_114324_batchsize_512//dump_outs/trainer0/model0/


In [9]:
### Extract validation data
fname='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/dataset_2_smoothing_200k/norm_1_train_val.npy'
s_val=np.load(fname,mmap_mode='r')[:8000][:,0,:,:]
print(s_val.shape)

(8000, 128, 128)


In [10]:
### Get dataframe with file names, sorted by epoch and step
df_files=f_get_files_df_sorted()
### Slice out rows to keep only the last few steps for each epoch.
# df_files=f_filter_epoch(df_files,num_sliced=10)

#############################################################
### Read images one by one into a numpy array and create a new DataFrame
df_full=f_get_images_df(df_files)
print(df_full.shape)
# ### Filter to keep just one step per epoch
# df_full=f_filter_epoch(df_full,1)

train_gen 1191
Time for Sorting 4.624292850494385
Time for Reading images 103.36069083213806
(1191, 6)


In [11]:
# df_files.head(20)

## Chi-square

In [12]:
def f_compute_chisqr(df,s_input):
    ''' Compute chi-sqr values of pixel intensity histogram and spectrum for each row
    Uses the module f_pixel_intensity to compute histograms and f_compute_spectrum for spectrum
    '''
    
    def f_chisqr(df_row,val_hist,val_err,val_spec,val_spec_err,bins,transform):
        ''' Compute chi-sqr of rows wrt to input data'''
        
        val_dr=val_hist.copy()
        val_dr[val_dr<=0.]=1.0    ### Avoiding division by zero for zero bins
        
        ### Get all images in a batch
        sample=df_row.images if not transform else f_invtransform(df_row.images)

        ### Compute pixel histogram for row   ### !!Both pixel histograms MUST have same bins and normalization!
        gen_hist,gen_err=f_pixel_intensity(sample,plot=False,normalize=True,bins=bins,mode='avg')
        spec,spec_err=f_compute_spectrum(sample,plot=False)

        ### Compute chi-sqr
        ### Used in keras code : np.sum(np.divide(np.power(valhist - samphist, 2.0), valhist))
        ###  chi_sqr :: sum((Obs-Val)^2/(Val))
        sq_diff=(gen_hist-val_hist)**2        
        chi_sqr_list=[]
        
        for count,(start,end) in enumerate(zip([0,22,38,0],[22,38,None,None])):  # 4 lists : small, medium, large pixel values and full 
            chi_sqr_list.append(np.sum(np.divide(sq_diff[start:end],val_dr[start:end])))
        
        idx=None  # Choosing the number of histograms to use. Eg : -5 to skip last 5 bins
        
        chi_sqr_list.append(np.sum(np.divide(sq_diff[:idx],1.0))) ## chi-sqr without denominator division
        chi_sqr_list.append(np.sum(gen_err[:idx])/np.sum(val_err[:idx])) ## measures total spread in histograms wrt to input data
        
        ### computing the spectral loss chi-square
        chi_sqr_list.append(np.sum((val_spec[:50]-spec[:50])**2/(spec[:50]**2)))
        
#         chi_sqr1=np.sum(np.divide(np.power(gen_hist[:idx] - val_hist[:idx], 2.0), val_dr[:idx]))
#         chi_sqr2=np.sum(np.divide(np.power(gen_hist[:idx] - val_hist[:idx], 2.0), 1.0))
#         chi_sqr3=np.sum(gen_err[:idx])/np.sum(val_err[:idx])  ## measures total spread in histograms wrt to input data
        
        return chi_sqr_list
    
    ########################
    ###### Code starts ########
    transform=False  # If true, it computes histogram in the orignal scale of pixels ie. 0-2000 
    
    ## Get bins for histograms
    bins=np.concatenate([np.array([-0.5]),np.arange(0.5,20.5,1),np.arange(20.5,100.5,5),np.arange(100.5,1000.5,50),np.array([2000])]) #bin edges to use
    if not transform: bins=f_transform(bins)   ### scale to (-1,1)
#     bins=100
#     print(bins)
    
    ### Get pixel histogram of all input data
    val_hist,val_err=f_pixel_intensity(s_input,plot=False,normalize=True,bins=bins,mode='avg')    
    ### Computing spectrum ###
    val_spec,val_spec_err=f_compute_spectrum(s_input,plot=False)
    del s_input


    ### Get chi-sqr for each row (step-epoch) for generated data
    chi_sqrs=df.apply(lambda row: f_chisqr(row,val_hist=val_hist,val_err=val_err,val_spec=val_spec,val_spec_err=val_spec_err,bins=bins,transform=transform), axis=1).values
    chi_vals=np.array(list(zip(*chi_sqrs)))  ## transposing list of list
        
    chi_sqr_keys=['chi_sqr1a','chi_sqr1b','chi_sqr1c','chi_sqr1d','chi_sqr2','chi_img_var','chi_spec']
    for key,chi_val in zip(chi_sqr_keys,chi_vals):
        df[key]=chi_val
    
    return df

def f_get_best_chisqr_models(df):
    
    chi_sqr_keys=['chi_sqr1a','chi_sqr1b','chi_sqr1c','chi_sqr1d','chi_sqr2','chi_img_var','chi_spec']
    q_dict=dict(df.quantile(q=0.2,axis=0)[chi_sqr_keys])
    print(q_dict)
    
    df_sliced=df.query('chi_sqr1d < {0} & chi_spec < {1}'.format(q_dict['chi_sqr1d'],q_dict['chi_spec']))[['epoch','step','img_type','num_large']+chi_sqr_keys]
    
    return df_sliced


In [13]:
t1=time.time()
# df1=f_compute_chisqr(df_full.loc[[0,1,2,3]],s_val) # Test on small df
df_full=f_compute_chisqr(df_full,s_val)
t2=time.time()
print("Time to compute chi-sqr",t2-t1)

Time to compute chi-sqr 1327.955484867096


In [14]:
df=df_full.copy()

In [15]:
df_sliced=f_get_best_chisqr_models(df_full)
print(df_sliced.shape)

{'chi_sqr1a': 0.006639435447355462, 'chi_sqr1b': 0.0015197749166827547, 'chi_sqr1c': 0.00436794207645463, 'chi_sqr1d': 0.0209644692685795, 'chi_sqr2': 0.002158372915136797, 'chi_img_var': 3.8579657967629752, 'chi_spec': 0.290335942226161}
(109, 11)


In [16]:
df_sliced

Unnamed: 0,epoch,step,img_type,num_large,chi_sqr1a,chi_sqr1b,chi_sqr1c,chi_sqr1d,chi_sqr2,chi_img_var,chi_spec
517,26.0,10340.0,train_gen,0,0.007180,0.008461,0.003576,0.019217,0.003065,5.090208,0.167641
644,32.0,12880.0,train_gen,0,0.008244,0.007410,0.003983,0.019637,0.002239,3.794643,0.257526
653,32.0,13060.0,train_gen,0,0.006324,0.005145,0.004215,0.015684,0.005843,3.705136,0.256142
657,33.0,13140.0,train_gen,0,0.005799,0.009406,0.004032,0.019237,0.005611,3.718602,0.192347
670,33.0,13400.0,train_gen,0,0.004185,0.001189,0.013480,0.018853,0.000416,4.062990,0.269271
707,35.0,14140.0,train_gen,0,0.001134,0.013567,0.004473,0.019174,0.000322,3.836082,0.276532
728,36.0,14560.0,train_gen,0,0.001854,0.006157,0.007909,0.015920,0.001185,3.986074,0.182779
745,37.0,14900.0,train_gen,0,0.005406,0.001140,0.004742,0.011288,0.001491,4.487515,0.283064
752,37.0,15040.0,train_gen,1,0.004004,0.004217,0.004830,0.013050,0.000536,4.029918,0.130782
756,38.0,15120.0,train_gen,0,0.001398,0.004901,0.007643,0.013942,0.000303,4.139280,0.150373


#### View best epochs

#### Locations with best chi_sqr

In [17]:
chi_sqr_keys=['epoch','step','chi_sqr1a','chi_sqr1b','chi_sqr1c','chi_sqr1d','chi_sqr2','chi_img_var','chi_spec']
# index location of min/max values of chi squares
inds=np.array([df.chi_spec.idxmin(axis=1),df.chi_sqr1d.idxmin(axis=1),df.chi_img_var.idxmax(axis=1)])
df.loc[inds][chi_sqr_keys]

Unnamed: 0,epoch,step,chi_sqr1a,chi_sqr1b,chi_sqr1c,chi_sqr1d,chi_sqr2,chi_img_var,chi_spec
1065,53.0,21300.0,0.023039,0.000653,0.009766,0.033458,0.013052,5.288517,0.072174
1036,52.0,20720.0,0.000728,0.000863,0.001318,0.002909,0.000483,4.056787,0.235563
0,0.0,0.0,193.108875,0.110686,0.013167,193.232728,12.39876,10.757241,17731.082532


### Sorted dataframe by best chi-squares

In [18]:
df_full.sort_values(by=['chi_spec','chi_sqr1d'])[['epoch','step','chi_sqr1d','chi_spec']].head(20)

Unnamed: 0,epoch,step,chi_sqr1d,chi_spec
1065,53.0,21300.0,0.033458,0.072174
1014,51.0,20280.0,0.01308,0.083109
926,46.0,18520.0,0.028072,0.083602
958,48.0,19160.0,0.047092,0.095514
793,39.0,15860.0,0.035201,0.101665
943,47.0,18860.0,0.01518,0.107749
867,43.0,17340.0,0.014532,0.108531
905,45.0,18100.0,0.0197,0.109562
906,45.0,18120.0,0.020013,0.112525
1141,57.0,22820.0,0.077618,0.112565


In [19]:
df_full[chi_sqr_keys].describe()


Unnamed: 0,epoch,step,chi_sqr1a,chi_sqr1b,chi_sqr1c,chi_sqr1d,chi_sqr2,chi_img_var,chi_spec
count,1191.0,1191.0,1191.0,1191.0,1191.0,1191.0,1191.0,1191.0,1191.0
mean,29.476071,11900.0,0.279168,0.015427,0.11785,0.412444,0.07672,4.497628,84.8699
std,17.324269,6879.127852,5.654854,0.027909,1.293057,5.804783,0.431816,0.888723,1244.601044
min,0.0,0.0,0.000307,0.000129,0.000503,0.002909,6.7e-05,0.862455,0.072174
25%,14.5,5950.0,0.008995,0.00196,0.004909,0.024053,0.00314,3.956633,0.327335
50%,29.0,11900.0,0.02365,0.005966,0.009291,0.046784,0.012356,4.373948,0.615323
75%,44.0,17850.0,0.065484,0.016222,0.02127,0.129353,0.044381,4.855019,1.291721
max,59.0,23800.0,193.108875,0.326767,34.132857,193.232728,12.39876,10.757241,32209.181642


In [20]:
df_sliced

Unnamed: 0,epoch,step,img_type,num_large,chi_sqr1a,chi_sqr1b,chi_sqr1c,chi_sqr1d,chi_sqr2,chi_img_var,chi_spec
517,26.0,10340.0,train_gen,0,0.007180,0.008461,0.003576,0.019217,0.003065,5.090208,0.167641
644,32.0,12880.0,train_gen,0,0.008244,0.007410,0.003983,0.019637,0.002239,3.794643,0.257526
653,32.0,13060.0,train_gen,0,0.006324,0.005145,0.004215,0.015684,0.005843,3.705136,0.256142
657,33.0,13140.0,train_gen,0,0.005799,0.009406,0.004032,0.019237,0.005611,3.718602,0.192347
670,33.0,13400.0,train_gen,0,0.004185,0.001189,0.013480,0.018853,0.000416,4.062990,0.269271
707,35.0,14140.0,train_gen,0,0.001134,0.013567,0.004473,0.019174,0.000322,3.836082,0.276532
728,36.0,14560.0,train_gen,0,0.001854,0.006157,0.007909,0.015920,0.001185,3.986074,0.182779
745,37.0,14900.0,train_gen,0,0.005406,0.001140,0.004742,0.011288,0.001491,4.487515,0.283064
752,37.0,15040.0,train_gen,1,0.004004,0.004217,0.004830,0.013050,0.000536,4.029918,0.130782
756,38.0,15120.0,train_gen,0,0.001398,0.004901,0.007643,0.013942,0.000303,4.139280,0.150373


In [21]:
### Plot chi-sqr values
df_sliced.plot(x="epoch", y=["chi_sqr1d", "chi_img_var", "chi_spec"],style='.',marker='*')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.axes._subplots.AxesSubplot at 0x2aaadcc76a58>

### High Pixel images

In [22]:
### Plot number of high pixel images
plt.figure()
plt.plot(df[df.img_type=='val_gen'].epoch,df[df.img_type=='val_gen'].num_large,linestyle='',marker='*')
plt.xlabel('Steps in Epochs')
plt.ylabel('Number of large pixel images from a batch of images')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Text(0, 0.5, 'Number of large pixel images from a batch of images')

In [23]:
df[(df.num_large>0) &(df.img_type=='val_gen')][['epoch','step','num_large']]

Unnamed: 0,epoch,step,num_large


## Compare samples

In [None]:
def f_widget_compare(sample_names,sample_dict,Fig_type='pixel',rescale=True,log_scale=True,bins=25,mode='avg',normalize=True,bkgnd=[]):
    '''
    Module to make widget plots for pixel intensity or spectrum comparison for multiple sample sets
    '''
#     ### Crop out large pixel values
#     for key in sample_names:
#         print(sample_dict[key].shape)
#         sample_dict[key]=np.array([arr for arr in sample_dict[key] if np.max(arr)<=0.994])
#         print(sample_dict[key].shape)
    
    img_list=[sample_dict[key] for key in sample_names]
    label_list=list(sample_names)
    
    
    bins=np.concatenate([np.array([-0.5]),np.arange(0.5,20.5,1),np.arange(20.5,100.5,5),np.arange(100.5,1000.5,50),np.array([2000])]) #bin edges to use
    
    if rescale: 
        for count,img in enumerate(img_list):
            img_list[count]=f_invtransform(img)
        if len(bkgnd): bkgnd=f_invtransform(bkgnd)
#         hist_range=(0,2000)
    else:
        bins=f_transform(bins)
#         hist_range=(-1,0.996)
    assert Fig_type in ['pixel','spectrum'],"Invalid mode %s"%(mode)
    
    if Fig_type=='pixel':
#         f_compare_pixel_intensity(img_lst=img_list,label_lst=label_list,normalize=normalize,log_scale=log_scale, mode=mode,bins=bins,hist_range=hist_range)
        f_compare_pixel_intensity(img_lst=img_list,label_lst=label_list,normalize=normalize,log_scale=log_scale, mode=mode,bins=bins,hist_range=None,bkgnd_arr=bkgnd)

    elif Fig_type=='spectrum':
        f_compare_spectrum(img_lst=img_list,label_lst=label_list,log_scale=log_scale,bkgnd_arr=bkgnd)


### Compare different steps

In [None]:
# img_list,labels_list=f_get_sample_epochs(df,'train_gen',10)

img_list,labels_list=[],[]
for a,b in df_sliced.iterrows():
    epoch,step=int(b.epoch),int(b.step)
    img_list.append(f_get_step(df,'train_gen',epoch,step))
    labels_list.append('%s:%s'%(str(epoch),str(step)))

dict_samples=dict.fromkeys(labels_list)
for key,val in zip(labels_list,img_list): dict_samples[key]=val

dict_samples.keys()
# ### Compare with input
# # dict_samples['keras']=s_keras
# dict_samples['input']=s_val
bkgnd=[]
bkgnd=s_val
interact_manual(f_widget_compare,sample_dict=fixed(dict_samples),
                sample_names=SelectMultiple(options=dict_samples.keys()),
                Fig_type=ToggleButtons(options=['pixel','spectrum']),bins=IntText(value=50),mode=['avg','simple'],bkgnd=fixed(bkgnd))



### Plot step groups in best epochs

In [None]:
print(np.unique(df_sliced.epoch.values))
step_list=df_sliced[df_sliced.epoch==26].step.values
print(step_list)

In [None]:
img_list,labels_list=[],[]
for epoch in np.unique(df_sliced.epoch.values).astype(int):
    step_list=df_sliced[df_sliced.epoch==epoch].step.values
    print(epoch,step_list)
    img_list.append(f_get_step_group(df,'train_gen',step_list))
    labels_list.append('%s'%(str(epoch)))

dict_samples=dict.fromkeys(labels_list)
for key,val in zip(labels_list,img_list): dict_samples[key]=val

dict_samples.keys()
# # ### Compare with input
# # # dict_samples['keras']=s_keras
# # dict_samples['input']=s_val
# bkgnd=[]
bkgnd=s_val
interact_manual(f_widget_compare,sample_dict=fixed(dict_samples),
                sample_names=SelectMultiple(options=dict_samples.keys()),
                Fig_type=ToggleButtons(options=['pixel','spectrum']),bins=IntText(value=50),mode=['avg','simple'],bkgnd=fixed(bkgnd))



In [None]:
# img_lst=[f_invtransform(i) for i in img_list]
# bins=np.concatenate([np.array([-0.5]),np.arange(0.5,20.5,1),np.arange(20.5,100.5,5),np.arange(100.5,1000.5,50),np.array([2000])]) #bin edges to use
# # bins=200
# f_compare_pixel_intensity(img_list,labels_list,normalize=True,log_scale=True, mode='avg',bins=bins,hist_range=None)
# f_compare_spectrum(img_list,labels_list,log_scale=True)


## View image block

In [None]:
def f_plot_grid(arr,cols=16,fig_size=(15,5)):
    ''' Plot a grid of images
    '''
    size=arr.shape[0]    
    rows=int(np.ceil(size/cols))
    print(rows,cols)
    
    fig,axarr=plt.subplots(rows,cols,figsize=fig_size, gridspec_kw = {'wspace':0, 'hspace':0})
    if rows==1: axarr=np.reshape(axarr,(rows,cols))
    if cols==1: axarr=np.reshape(axarr,(rows,cols))
    
    for i in range(min(rows*cols,size)):
        row,col=int(i/cols),i%cols
        try: 
            axarr[row,col].imshow(arr[i],origin='lower',interpolation='nearest',cmap='cool', extent = [0, 128, 0, 128])
        # Drop axis label
        except Exception as e:
            print('Exception:',e)
            pass
        temp=plt.setp([a.get_xticklabels() for a in axarr[:-1,:].flatten()], visible=False)
        temp=plt.setp([a.get_yticklabels() for a in axarr[:,1:].flatten()], visible=False)
    
#     fig.subplots_adjust(wspace=0.00,hspace=0.000)
#     fig.tight_layout()

f_plot_grid(img_arr,cols=6,fig_size=(10,5))


In [None]:
fname='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_data/20200701_054823_exagan/dump_outs/trainer0/model0/sgd.training.epoch.21.step.8480_gen_img_instance1_activation_output0.npy'
# fname='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_data/generate_images/20200629_145233_gen_img_exagan/dump_outs/trainer0/model0/sgd.testing.epoch.0.step.0_gen_img_instance1_activation_output0.npy'
s_new=np.load(fname)[:,0,:,:]
print(s_new.shape)

In [None]:
f_plot_grid(s_new[100:118],cols=6,fig_size=(10,5))

In [None]:
f_plot_grid(s_val[100:118],cols=6,fig_size=(10,5))