# Extract data from output files
### Analyze the output from a single LBANN run
March 9, 2020 \
April 6, 2020 : Major edit to store files in order of epochs \
April 21, 2020: Major edit, added jupyter widgets to compare pixel intensity plots \
May 8, 2020: Major edit, using all images for a given batch

In [17]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import subprocess as sp
import os
import glob
import sys

import time
from scipy import fftpack
# from ipywidgets import interact, interact_manual,fixed, SelectMultiple, IntText, IntSlider, FloatSlider,SelectionSlider,BoundedIntText
from ipywidgets import *

In [18]:
%matplotlib widget

In [19]:
sys.path.append('/global/u1/v/vpa/project/jpt_notebooks/Cosmology/Cosmo_GAN/LBANN/lbann_cosmogan/3_analysis/')
from modules_image_analysis import *

In [20]:
### Transformation functions for image pixel values
def f_transform(x):
    return 2.*x/(x + 4. + 1e-8) - 1.


def f_invtransform(s):
    return 4.*(1. + s)/(1. - s + 1e-8)

## Explore image samples

In [21]:
def f_widget_compare(sample_names,sample_dict,Fig_type='pixel',rescale=True,log_scale=True,bins=25,mode='avg',normalize=True,bkgnd=[]):
    '''
    Module to make widget plots for pixel intensity or spectrum comparison for multiple sample sets
    '''
    
    ### Crop out large pixel values
    for key in sample_names:
        print(sample_dict[key].shape)
        sample_dict[key]=np.array([arr for arr in sample_dict[key] if np.max(arr)<=0.994])
        print(sample_dict[key].shape)
    
    img_list=[sample_dict[key] for key in sample_names]
    label_list=list(sample_names)
    
    
    bins=np.concatenate([np.array([-0.5]),np.arange(0.5,20.5,1),np.arange(20.5,100.5,5),np.arange(100.5,1000.5,50),np.array([2000])]) #bin edges to use
    
    if rescale: 
        for count,img in enumerate(img_list):
            img_list[count]=f_invtransform(img)
        if len(bkgnd): bkgnd=f_invtransform(bkgnd)
#         hist_range=(0,2000)
    else:
        bins=f_transform(bins)
#         hist_range=(-1,0.996)
    assert Fig_type in ['pixel','spectrum'],"Invalid mode %s"%(mode)
    
    if Fig_type=='pixel':
#         f_compare_pixel_intensity(img_lst=img_list,label_lst=label_list,normalize=normalize,log_scale=log_scale, mode=mode,bins=bins,hist_range=hist_range)
        f_compare_pixel_intensity(img_lst=img_list,label_lst=label_list,normalize=normalize,log_scale=log_scale, mode=mode,bins=bins,hist_range=None,bkgnd_arr=bkgnd)

    elif Fig_type=='spectrum':
        f_compare_spectrum(img_lst=img_list,label_lst=label_list,log_scale=log_scale,bkgnd_arr=bkgnd)



## Compare lbann images with input and keras code images

In [31]:
### Load validation input samples
img_raw='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/dataset_2_smoothing_200k/train.npy'
a1=np.load(img_raw)[:10000]
s_raw=f_transform(a1[:,:,:,0])[:10000]

print(s_raw.shape)
# print(s_raw.shape,[a.shape for a in s_keras])

(10000, 128, 128)


In [7]:
# ### Load images from keras code
# img_keras='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/exagan1/run_200k_samples_peter_dataset_20_epochs/models/gen_imgs.npy'
# img_keras='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/exagan1/run_3_fixed_cosmology/models/gen_imgs.npy'

# s_keras=[[] for i in range(3)]
# for count,i in enumerate([1,2,3]):
#     img_keras='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/exagan1/run_{0}_fixed_cosmology/models/gen_imgs.npy'.format(str(i))
#     a1=np.load(img_keras)
#     s_keras[count]=a1[:,:,:]

In [33]:
### Extract a few images generated by Lban 
parent_dir='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_data/generate_images/20200702_071108_gen_img_exagan/'
parent_dir='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_data/generate_images/20200702_074157_gen_img_exagan/'

f_strg=parent_dir+'dump_outs/trainer0/model0/sgd.testing.epoch.*.step.*_gen_img_instance1_activation_output0.npy'
f_list=glob.glob(f_strg)
print(f_list)

arr=[np.load(fname)[:,0,:,:] for fname in f_list]
s_lbann=np.vstack(arr)
print(s_lbann.shape,np.max(s_lbann))


['/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_data/generate_images/20200702_074157_gen_img_exagan/dump_outs/trainer0/model0/sgd.testing.epoch.3.step.3_gen_img_instance1_activation_output0.npy', '/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_data/generate_images/20200702_074157_gen_img_exagan/dump_outs/trainer0/model0/sgd.testing.epoch.0.step.0_gen_img_instance1_activation_output0.npy', '/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_data/generate_images/20200702_074157_gen_img_exagan/dump_outs/trainer0/model0/sgd.testing.epoch.1.step.1_gen_img_instance1_activation_output0.npy', '/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_data/generate_images/20200702_074157_gen_img_exagan/dump_outs/trainer0/model0/sgd.testing.epoch.2.step.2_gen_img_instance1_activation_output0.npy']
(10023, 128, 128) 0.9929066


In [None]:
# max_val=np.amax(s_lbann,axis=(1,2))
# # print(np.where(max_val>0.994))
# max_val.shape

# plt.figure()
# plt.plot(max_val)

In [None]:
# f_pixel_intensity(s_lbann[330],label='',normalize=False,log_scale=True,mode='simple')
# f_pixel_intensity(f_invtransform(s_lbann[330]),label='',normalize=False,log_scale=True,mode='simple')

In [12]:
### Another dataset
fname='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_data/generate_images/20200702_071108_gen_img_exagan/dump_outs/trainer0/model0/sgd.testing.epoch.0.step.0_gen_img_instance1_activation_output0.npy'
s_new1=np.load(fname)[:,0,:,:]

fname='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_data/20200629_090438_batchsize_512/dump_outs/trainer0/model0/sgd.validation.epoch.46.step.4606_gen_img_instance1_activation_output0.npy'
s_new2=np.load(fname)[:,0,:,:]

fname='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_data/20200629_090438_batchsize_512/dump_outs/trainer0/model0/sgd.validation.epoch.46.step.4676_gen_img_instance1_activation_output0.npy'
s_new3=np.load(fname)[:,0,:,:]

In [34]:
dict_samples={'lbann':s_lbann,'raw': s_raw}
# dict_samples={'lbann1':s_lbann[:3500],'lbann2':s_lbann[3500:7000],'lbann3':s_lbann[7000:],'raw': s_raw,'new1':s_new1,'new2':s_new2,'new3':s_new3}

# dict_samples={'new':s_new, 'keras1':s_keras[0],'keras2':s_keras[1],'keras3':s_keras[2],'raw': s_raw}
bkgnd=s_raw
bkgnd=[]
interact_manual(f_widget_compare,sample_dict=fixed(dict_samples),
                sample_names=SelectMultiple(options=dict_samples.keys()),
                Fig_type=ToggleButtons(options=['pixel','spectrum']),
                bins=SelectionSlider(options=np.arange(10,200,10),value=50),
                mode=['avg','simple'],bkgnd=fixed(bkgnd))

interactive(children=(SelectMultiple(description='sample_names', options=('lbann', 'raw'), value=()), ToggleBu…

<function __main__.f_widget_compare(sample_names, sample_dict, Fig_type='pixel', rescale=True, log_scale=True, bins=25, mode='avg', normalize=True, bkgnd=[])>

## Comparison for older dataset with multiple cosmologies

In [None]:
### Load images from keras code
img_keras='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/exagan1/run_200k_samples_peter_dataset_20_epochs/models/gen_imgs.npy'
a1=np.load(img_keras)
s_keras=a1[:,:,:]

### Load validation samples
img_raw='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/peter_dataset/raw_train.npy'

a1=np.load(img_raw)
s_raw=f_transform(a1[:,:,:,0])[:3000]

print(s_raw.shape,s_keras.shape)

In [None]:
### Extract a few images generated by Lban directly for a set of epochs
parent_dir='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_data/20200529_111342_seed3273_80epochs/'

ff=[]

for epoch in [60,70]:
    f_strg=parent_dir+'dump_outs/trainer0/model0/sgd.validation.epoch.{}*gen_img*.npy'.format(epoch)
    lst=glob.glob(f_strg)
    ff.append(lst)
f_list=[fle for a in ff for fle in a] ## Flattening out a list of lists
print(len(f_list))

arr=[np.load(fname)[:,0,:,:] for fname in f_list]
s_lbann=np.vstack(arr)
print(s_lbann.shape,np.max(s_lbann))

In [None]:
dict_samples={'lbann':s_lbann, 'keras':s_keras,'raw': s_raw}
bkgnd=[]

interact_manual(f_widget_compare,sample_dict=fixed(dict_samples),
                sample_names=SelectMultiple(options=dict_samples.keys()),
                Fig_type=ToggleButtons(options=['pixel','spectrum']),
                bins=SelectionSlider(options=np.arange(10,200,10),value=50),
                mode=['avg','simple'],bkgnd=fixed(bkgnd))