# Extract data from output files
### Analyze the output from a single LBANN run
March 9, 2020 \
April 6, 2020: Major edit to store files in order of epochs \
April 21, 2020: Major edit, added Jupyter widgets to compare pixel intensity plots \
May 8, 2020: Major edit, using all images for a given batch

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import subprocess as sp
import os
import glob
import sys

import time
from scipy import fftpack
# from ipywidgets import interact, interact_manual,fixed, SelectMultiple, IntText, IntSlider, FloatSlider,SelectionSlider,BoundedIntText
from ipywidgets import *

In [None]:
%matplotlib widget

In [3]:
# Make the project's analysis directory importable, then pull in its helpers.
# NOTE(review): the wildcard import is presumably where f_compare_pixel_intensity,
# f_compare_spectrum and f_pixel_intensity (used in later cells) come from -- they
# are not defined in this notebook; confirm against modules_image_analysis.
sys.path.append('/global/u1/v/vpa/project/jpt_notebooks/Cosmology/Cosmo_GAN/LBANN/lbann_cosmogan/3_analysis/')
from modules_image_analysis import *

[NbConvertApp] Converting notebook modules_image_analysis.ipynb to script
[NbConvertApp] Writing 15104 bytes to modules_image_analysis.py


In [60]:
### Transformation functions for image pixel values
def f_transform(x):
    '''
    Map pixel values x to s = 2x/(x+4) - 1.

    Works elementwise on scalars or numpy arrays. x = 0 maps to -1 and the
    result approaches +1 as x grows. The 1e-8 offset keeps the denominator
    nonzero at x = -4.
    '''
    denominator = x + 4. + 1e-8
    scaled = 2. * x / denominator
    return scaled - 1.


def f_invtransform(s):
    '''
    Map transformed values s back to pixel values via x = 4(1+s)/(1-s).

    Works elementwise on scalars or numpy arrays. s = -1 maps to 0. The 1e-8
    offset keeps the denominator nonzero at s = 1.
    '''
    numerator = 4. * (1. + s)
    denominator = 1. - s + 1e-8
    return numerator / denominator

## Explore image samples

In [43]:
       
def f_widget_compare(sample_names,sample_dict,Fig_type='pixel',rescale=True,log_scale=True,bins=25,mode='avg',normalize=True):
    '''
    Module to make widget plots for pixel intensity or spectrum comparison for multiple sample sets

    sample_names : iterable of keys of sample_dict selecting the sets to compare
    sample_dict  : dict mapping a set name to its array of samples
    Fig_type     : 'pixel' calls f_compare_pixel_intensity, 'spectrum' calls f_compare_spectrum
    rescale      : if True, f_invtransform is applied to every kept set before plotting
    log_scale    : forwarded to both plotting helpers
    bins, mode, normalize : forwarded to f_compare_pixel_intensity only

    Raises AssertionError if Fig_type is not 'pixel' or 'spectrum'.
    '''
    # Validate first, and report the offending Fig_type (the message used to print `mode`).
    assert Fig_type in ['pixel','spectrum'],"Invalid Fig_type %s"%(Fig_type)

    # Drop every sample set whose maximum exceeds 0.996. Names and arrays are
    # filtered together so each label stays attached to its own array.
    kept=[(key,sample_dict[key]) for key in sample_names if np.max(sample_dict[key])<=0.996]
    label_list=[key for key,_ in kept]
    img_list=[arr for _,arr in kept]
    print(len(img_list))

    # NOTE(review): this range is kept even when rescale=True; the later
    # redefinition of this function in this notebook uses (0,2000) in that
    # case -- confirm which is intended here.
    hist_range=(0,0.996)

    if rescale:
        img_list=[f_invtransform(img) for img in img_list]

    if Fig_type=='pixel':
        f_compare_pixel_intensity(img_lst=img_list,label_lst=label_list,normalize=normalize,log_scale=log_scale, mode=mode,bins=bins,hist_range=hist_range)
    elif Fig_type=='spectrum':
        f_compare_spectrum(img_lst=img_list,label_lst=label_list,log_scale=log_scale)



### Comparing datasets

In [6]:
dataset1='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/very_large_dataset_train.npy'
dataset2='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/large_dataset_train.npy'
dataset3='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/peter_dataset/raw_train.npy'

### Number of images drawn at random (without replacement) from each dataset
num_draw=50000

### s_input[i] holds the transformed random subset of the i-th dataset
s_input=[]
for fname in [dataset1,dataset2,dataset3]:
    arr=np.load(fname)
    num_samples=arr.shape[0]
    print(arr.shape)
    # Cap the draw at the dataset size: np.random.choice with replace=False
    # raises ValueError when asked for more samples than the population holds.
    idxs=np.random.choice(num_samples,size=min(num_draw,num_samples),replace=False)
#     idxs=np.arange(1000,51000)
    arr=arr[idxs]
#     s_input.append(arr[:,:,:,0])
    # Keep index 0 of the last axis and map the pixel values with f_transform
    s_input.append(f_transform(arr[:,:,:,0]))


(210124, 128, 128, 1)
(105062, 128, 128, 1)
(197000, 128, 128, 1)


In [15]:
### Sample sets offered in the widget, keyed by the name shown in the selector
dict_samples={
    'vlarge':s_input[0],
    'large':s_input[1],
    'peter_data': s_input[2],
}

### Controls handed to f_widget_compare; sample_dict is wrapped in fixed() and so is passed through as-is
widget_controls=dict(
    sample_dict=fixed(dict_samples),
    sample_names=SelectMultiple(options=dict_samples.keys()),
    Fig_type=ToggleButtons(options=['pixel','spectrum']),
    bins=SelectionSlider(options=np.arange(10,200,10),value=50),
    mode=['avg','simple'],
)

interact_manual(f_widget_compare,**widget_controls)

interactive(children=(SelectMultiple(description='sample_names', options=('vlarge', 'large', 'peter_data'), va…

<function __main__.f_widget_compare(sample_names, sample_dict, Fig_type='pixel', rescale=True, log_scale=True, bins=25, mode='avg', normalize=True)>

#### Compare LBANN-generated images with the input images and with images from the Keras code

In [56]:
### Images produced by the keras code (already stored as a 3D array, used as loaded)
# img_keras='/global/cfs/cdirs/dasrepo/vpa/cosmogan/data/computed_data/exagan1/run_100k_samples_35epochs/models/gen_imgs.npy'
# img_keras='/global/cfs/cdirs/dasrepo/vpa/cosmogan/data/computed_data/exagan1/run_200k_samples_24epochs/models/gen_imgs.npy'
img_keras='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/exagan1/run_200k_samples_peter_dataset_20_epochs/models/gen_imgs.npy'
s_keras=np.load(img_keras)[:,:,:]

### Validation samples: keep at most the first 3000, take index 0 of the last axis, then transform
# img_raw='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/very_large_dataset_val.npy'
img_raw='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/peter_dataset/raw_val.npy'
a1=np.load(img_raw)
# f_transform is elementwise, so slicing before transforming gives the same values
s_raw=f_transform(a1[:3000,:,:,0])

print(s_raw.shape,s_keras.shape)

(2680, 128, 128) (3000, 128, 128)


In [85]:
### Extract a few images generated by LBANN directly for a set of epochs
# parent_dir='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_data/20200506_121613_exagan_200k_samples/'
parent_dir='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_data/20200513_121910_peters_dataset/'

### Collect the dump files epoch by epoch
f_list=[]
for epoch in np.arange(55,60):
    f_strg=parent_dir+'dump_outs/model0-validation-epoch{}-*gen_img*.npy'.format(epoch)
    # glob returns matches in arbitrary order; sort them so that the row
    # indices of s_lbann (used by later cells) are reproducible between runs.
    f_list.extend(sorted(glob.glob(f_strg)))
print(len(f_list))

# Fail with a clear message instead of the generic error np.vstack gives for an empty list
if not f_list:
    raise FileNotFoundError('No generated-image files found in %s'%(parent_dir+'dump_outs/'))

# Take index 0 along axis 1 of every file (presumably the channel axis -- confirm)
# and stack the files along the first axis.
arr=[np.load(fname)[:,0,:,:] for fname in f_list]
s_lbann=np.vstack(arr)
print(s_lbann.shape,np.max(s_lbann))

19
(2432, 128, 128) 0.99534506


In [96]:
### Largest pixel value of every generated image
max_val=s_lbann.max(axis=(1,2))

### Indices of the images whose largest value exceeds 0.994
print(np.nonzero(max_val>0.994))

### Largest value versus image index
plt.figure()
plt.plot(max_val)

(array([ 335,  450,  759,  948, 1962, 1985]),)


  after removing the cwd from sys.path.


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

[<matplotlib.lines.Line2D at 0x2aac8bde62e8>]

In [97]:
# Display the pixel values of image 450, one of the indices reported above
# (in the recorded run) as having a maximum above 0.994.
s_lbann[450]

array([[-0.97311324, -0.95562595, -0.963725  , ..., -0.9962903 ,
        -0.9947965 , -0.9611758 ],
       [-0.96600235, -0.9489714 , -0.98024267, ..., -0.97837526,
        -0.9280931 , -0.9281392 ],
       [-0.88892454, -0.92576075, -0.88998604, ..., -0.99630064,
        -0.98640203, -0.9119709 ],
       ...,
       [-0.99542767, -0.989753  , -0.9873477 , ..., -0.80070156,
        -0.85057217, -0.8973274 ],
       [-0.9970331 , -0.9861392 , -0.98753655, ..., -0.9227203 ,
        -0.97421265, -0.9459293 ],
       [-0.97018707, -0.98750865, -0.997626  , ..., -0.94706106,
        -0.9842882 , -0.9164196 ]], dtype=float32)

In [101]:

# Run f_pixel_intensity on generated image 330 twice: first on the stored
# values, then on the same image mapped back through f_invtransform.
# NOTE(review): f_pixel_intensity is not defined in this notebook; presumably it
# comes from modules_image_analysis and plots a pixel-intensity histogram -- confirm.
f_pixel_intensity(s_lbann[330],label='',normalize=False,log_scale=True,mode='simple')
f_pixel_intensity(f_invtransform(s_lbann[330]),label='',normalize=False,log_scale=True,mode='simple')




  plt.figure()


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

  plt.figure()


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

(array([16047,   178,    54,    43,    23,     7,    10,     4,     3,
            1,     2,     0,     1,     2,     1,     2,     1,     0,
            1,     0,     1,     0,     1,     1,     1]), None)

In [82]:

def f_widget_compare(sample_names,sample_dict,Fig_type='pixel',rescale=True,log_scale=True,bins=25,mode='avg',normalize=True):
    '''
    Module to make widget plots for pixel intensity or spectrum comparison for multiple sample sets

    sample_names : iterable of keys of sample_dict selecting the sets to compare
    sample_dict  : dict mapping a set name to its array of images (first axis = image index).
                   It is only read; the cropped arrays are local to this call.
    Fig_type     : 'pixel' calls f_compare_pixel_intensity, 'spectrum' calls f_compare_spectrum
    rescale      : if True, f_invtransform is applied to every set and the histogram range becomes (0,2000)
    log_scale    : forwarded to both plotting helpers
    bins, mode, normalize : forwarded to f_compare_pixel_intensity only

    Raises AssertionError if Fig_type is not 'pixel' or 'spectrum'.
    '''
    # Validate first, and report the offending Fig_type (the message used to print `mode`).
    assert Fig_type in ['pixel','spectrum'],"Invalid Fig_type %s"%(Fig_type)

    ### Crop out images with large pixel values.
    ### The caller's dictionary is left untouched, so repeated widget calls always start from the full sets.
    img_list=[]
    for key in sample_names:
        samples=np.asarray(sample_dict[key])
        print(samples.shape)
        # Largest value of each image, reduced over every axis except the first
        peak=samples.max(axis=tuple(range(1,samples.ndim)))
        cropped=samples[peak<=0.994]
        print(cropped.shape)
        img_list.append(cropped)
    label_list=list(sample_names)

    hist_range=(0,0.996)

    if rescale:
        img_list=[f_invtransform(img) for img in img_list]
        hist_range=(0,2000)

    if Fig_type=='pixel':
        f_compare_pixel_intensity(img_lst=img_list,label_lst=label_list,normalize=normalize,log_scale=log_scale, mode=mode,bins=bins,hist_range=hist_range)
    elif Fig_type=='spectrum':
        f_compare_spectrum(img_lst=img_list,label_lst=label_list,log_scale=log_scale)



In [83]:
### Sample sets offered in the widget, keyed by the name shown in the selector
dict_samples={
    'lbann':s_lbann,
    'keras':s_keras,
    'raw': s_raw,
}

### Controls handed to f_widget_compare; sample_dict is wrapped in fixed() and so is passed through as-is
widget_controls=dict(
    sample_dict=fixed(dict_samples),
    sample_names=SelectMultiple(options=dict_samples.keys()),
    Fig_type=ToggleButtons(options=['pixel','spectrum']),
    bins=SelectionSlider(options=np.arange(10,200,10),value=50),
    mode=['avg','simple'],
)

interact_manual(f_widget_compare,**widget_controls)

interactive(children=(SelectMultiple(description='sample_names', options=('lbann', 'keras', 'raw'), value=()),…

<function __main__.f_widget_compare(sample_names, sample_dict, Fig_type='pixel', rescale=True, log_scale=True, bins=25, mode='avg', normalize=True)>

In [84]:
### Keep only the generated images whose largest pixel value is below 0.996.
# A boolean mask over the per-image maxima replaces the Python loop over images;
# an empty selection keeps the image dimensions instead of collapsing to shape (0,).
# NOTE(review): the cut here is <0.996 while f_widget_compare crops at <=0.994 -- confirm which is intended.
stripped_arr=s_lbann[np.amax(s_lbann,axis=(1,2))<0.996]
print(stripped_arr.shape)

(2432, 128, 128)
