# Extract data from output files
### Analyze the output from a single LBANN run
March 9, 2020 \
April 6, 2020 : Major edit to store files in order of epochs \
April 21, 2020: Major edit, added jupyter widgets to compare pixel intensity plots \
May 8, 2020: Major edit, using all images for a given batch

In [3]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import subprocess as sp
import os
import glob
import sys

import time
from scipy import fftpack
# from ipywidgets import interact, interact_manual,fixed, SelectMultiple, IntText, IntSlider, FloatSlider,SelectionSlider,BoundedIntText
from ipywidgets import *

In [4]:
%matplotlib widget

In [5]:
sys.path.append('/global/u1/v/vpa/project/jpt_notebooks/Cosmology/Cosmo_GAN/LBANN/lbann_cosmogan/3_analysis/')
from modules_image_analysis import *

[NbConvertApp] Converting notebook modules_image_analysis.ipynb to script
[NbConvertApp] Writing 16243 bytes to modules_image_analysis.py
(197000, 1, 128, 128) <class 'numpy.float32'>


In [6]:
### Transformation functions for image pixel values
def f_transform(x):
    return 2.*x/(x + 4. + 1e-8) - 1.


def f_invtransform(s):
    return 4.*(1. + s)/(1. - s + 1e-8)

## Explore image samples

In [29]:

def f_widget_compare(sample_names,sample_dict,Fig_type='pixel',rescale=True,log_scale=True,bins=25,mode='avg',normalize=True):
    '''
    Module to make widget plots for pixel intensity or spectrum comparison for multiple sample sets
    '''
    
    ### Crop out large pixel values
    for key in sample_names:
        print(sample_dict[key].shape)
        sample_dict[key]=np.array([arr for arr in sample_dict[key] if np.max(arr)<=0.994])
        print(sample_dict[key].shape)
    
    img_list=[sample_dict[key] for key in sample_names]
    label_list=list(sample_names)
        
    hist_range=(0,0.996)
    
    if rescale: 
        for count,img in enumerate(img_list):
            img_list[count]=f_invtransform(img)
        hist_range=(0,2000)

    
    assert Fig_type in ['pixel','spectrum'],"Invalid mode %s"%(mode)
    
    if Fig_type=='pixel':
        f_compare_pixel_intensity(img_lst=img_list,label_lst=label_list,normalize=normalize,log_scale=log_scale, mode=mode,bins=bins,hist_range=hist_range)
    elif Fig_type=='spectrum':
        f_compare_spectrum(img_lst=img_list,label_lst=label_list,log_scale=log_scale)

        
        
        
def f_compare_pixel_intensity(img_lst,label_lst=['img1','img2'],bkgnd_arr=[],log_scale=True, normalize=True, mode='avg',bins=25, hist_range=None):
    '''
    Module to compute and plot histogram for pixel intensity of images
    Has 2 modes : simple and avg
    simple mode: No errors. Just flatten the input image array and compute histogram of full data
    avg mode(Default) : 
        - Compute histogram for each image in the image array
        - Compute errors across each histogram 
        
    bkgnd_arr : histogram of this array is plotting with +/- sigma band
    '''
    
    norm=normalize # Whether to normalize the histogram
    
    def f_batch_histogram(img_arr,bins,norm,hist_range):
        ''' Compute histogram statistics for a batch of images'''
        
        ## Extracting the range. This is important to ensure that the different histograms are compared correctly
        if hist_range==None : ulim,llim=np.max(img_arr),np.min(img_arr)
        else: ulim,llim=hist_range[1],hist_range[0]
#         print(ulim,llim)
        ### array of histogram of each image
        hist_arr=np.array([np.histogram(arr.flatten(), bins=bins, range=(llim,ulim), density=norm) for arr in img_arr]) ## range is important
        hist=np.stack(hist_arr[:,0]) # First element is histogram array
#         print(hist.shape)

        bin_list=np.stack(hist_arr[:,1]) # Second element is bin value 
        ### Compute statistics over histograms of individual images
        mean,err=np.mean(hist,axis=0),np.std(hist,axis=0)/np.sqrt(hist.shape[0])
        bin_edges=bin_list[0]
        centers = (bin_edges[:-1] + bin_edges[1:]) / 2
#         print(bin_edges,centers)

        return mean,err,centers
    
    plt.figure()
    
    ## Plot background distribution
    if len(bkgnd_arr):
        if mode=='simple':
            hist, bin_edges = np.histogram(bkgnd_arr.flatten(), bins=bins, density=norm, range=hist_range)
            centers = (bin_edges[:-1] + bin_edges[1:]) / 2
            plt.errorbar(centers, hist, color='k',marker='*',linestyle=':', label='bkgnd')

        elif mode=='avg':
            ### Compute histogram for each image. 
            mean,err,centers=f_batch_histogram(bkgnd_arr,bins,norm,hist_range)
            plt.plot(centers,mean,linestyle=':',color='k',label='bkgnd')
            plt.fill_between(centers, mean - err, mean + err, color='k', alpha=0.4)
    
    ### Plot the rest of the datasets
    for img,label in zip(img_lst,label_lst):     
        if mode=='simple':
            hist, bin_edges = np.histogram(img.flatten(), bins=bins, density=norm, range=hist_range)
            centers = (bin_edges[:-1] + bin_edges[1:]) / 2
            plt.errorbar(centers, hist, fmt='o-', label=label)

        elif mode=='avg':
            ### Compute histogram for each image. 
            mean,err,centers=f_batch_histogram(img,bins,norm,hist_range)
#             print('Centers',centers)
            plt.errorbar(centers,mean,yerr=err,fmt='o-',label=label)

    if log_scale: 
        plt.yscale('log')
        plt.xscale('symlog',linthreshx=50)

    plt.legend()
    plt.xlabel('Pixel value')
    plt.ylabel('Counts')
    plt.title('Pixel Intensity Histogram')
    

def f_widget_compare(sample_names,sample_dict,Fig_type='pixel',rescale=True,log_scale=True,bins=25,mode='avg',normalize=True,bkgnd=[]):
    '''
    Module to make widget plots for pixel intensity or spectrum comparison for multiple sample sets
    '''
    
    ### Crop out large pixel values
    for key in sample_names:
        print(sample_dict[key].shape)
        sample_dict[key]=np.array([arr for arr in sample_dict[key] if np.max(arr)<=0.994])
        print(sample_dict[key].shape)
    
    img_list=[sample_dict[key] for key in sample_names]
    label_list=list(sample_names)
    
    
    bins=np.concatenate([np.array([-0.5]),np.arange(0.5,20.5,1),np.arange(20.5,100.5,5),np.arange(100.5,700.5,100),np.array([2000])]) #bin edges to use
    
    if rescale: 
        for count,img in enumerate(img_list):
            img_list[count]=f_invtransform(img)
        if len(bkgnd): bkgnd=f_invtransform(bkgnd)
        hist_range=(0,2000)
    else:
        bins=f_transform(bins)
        hist_range=(-1,0.996)
    assert Fig_type in ['pixel','spectrum'],"Invalid mode %s"%(mode)
    
    if Fig_type=='pixel':
#         f_compare_pixel_intensity(img_lst=img_list,label_lst=label_list,normalize=normalize,log_scale=log_scale, mode=mode,bins=bins,hist_range=hist_range)
        f_compare_pixel_intensity(img_lst=img_list,label_lst=label_list,normalize=normalize,log_scale=log_scale, mode=mode,bins=bins,hist_range=hist_range,bkgnd_arr=bkgnd)

    elif Fig_type=='spectrum':
        f_compare_spectrum(img_lst=img_list,label_lst=label_list,log_scale=log_scale)




## Compare lbann images with input and keras code images

In [25]:
### Load images from keras code
img_keras='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/exagan1/run_200k_samples_peter_dataset_20_epochs/models/gen_imgs.npy'
img_keras='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/exagan1/run_2_fixed_cosmology/models/gen_imgs.npy'
a1=np.load(img_keras)
s_keras=a1[:,:,:]

### Load validation samples
# img_raw='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/peter_dataset/raw_val.npy'
img_raw='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/dataset_2_smooothing_200k/val.npy'

a1=np.load(img_raw)
s_raw=f_transform(a1[:,:,:,0])[:3000]

print(s_raw.shape,s_keras.shape)

(3000, 128, 128) (5000, 128, 128)


In [32]:
### Extract a few images generated by Lban directly for a set of epochs
parent_dir='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_data/20200529_111342_seed3273_80epochs/'
parent_dir='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_data/20200612_113211_exagan/'

ff=[]
# for epoch in np.arange(46,47):
for epoch in np.arange(66,76):
    f_strg=parent_dir+'dump_outs/trainer0/model0/sgd.validation.epoch.{}*gen_img*.npy'.format(epoch)
    lst=glob.glob(f_strg)
    ff.append(lst)
f_list=[fle for a in ff for fle in a] ## Flattening out a list of lists
print(len(f_list))

arr=[np.load(fname)[:,0,:,:] for fname in f_list]
s_lbann=np.vstack(arr)
print(s_lbann.shape,np.max(s_lbann))

48
(6144, 128, 128) 0.99737597


In [33]:
max_val=np.amax(s_lbann,axis=(1,2))
# print(np.where(max_val>0.994))
max_val.shape

plt.figure()
plt.plot(max_val)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

[<matplotlib.lines.Line2D at 0x2aae0abb3a90>]

In [34]:
# f_pixel_intensity(s_lbann[330],label='',normalize=False,log_scale=True,mode='simple')
# f_pixel_intensity(f_invtransform(s_lbann[330]),label='',normalize=False,log_scale=True,mode='simple')

In [35]:
dict_samples={'lbann':s_lbann, 'keras':s_keras,'raw': s_raw}

interact_manual(f_widget_compare,sample_dict=fixed(dict_samples),
                sample_names=SelectMultiple(options=dict_samples.keys()),
                Fig_type=ToggleButtons(options=['pixel','spectrum']),
                bins=SelectionSlider(options=np.arange(10,200,10),value=50),
                mode=['avg','simple'],bkgnd=fixed([]))

interactive(children=(SelectMultiple(description='sample_names', options=('lbann', 'keras', 'raw'), value=()),…

<function __main__.f_widget_compare(sample_names, sample_dict, Fig_type='pixel', rescale=True, log_scale=True, bins=25, mode='avg', normalize=True, bkgnd=[])>

## Compare different sample sets
#### Compare lbann images with input and keras code images

In [13]:
### Load images from keras code
# img_keras='/global/cfs/cdirs/dasrepo/vpa/cosmogan/data/computed_data/exagan1/run_100k_samples_35epochs/models/gen_imgs.npy'
# img_keras='/global/cfs/cdirs/dasrepo/vpa/cosmogan/data/computed_data/exagan1/run_200k_samples_24epochs/models/gen_imgs.npy'
img_keras='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/exagan1/run_200k_samples_peter_dataset_20_epochs/models/gen_imgs.npy'

a1=np.load(img_keras)
s_keras=a1[:,:,:]

### Load validation samples
img_raw='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/very_large_dataset_val.npy'
a1=np.load(img_raw)
s_raw=f_transform(a1[:,:,:,0])[:3000]

print(s_raw.shape,s_keras.shape)

(3000, 128, 128) (3000, 128, 128)


In [14]:
### Extract a few images generated by Lban directly for a set of epochs
# parent_dir='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_data/20200506_121613_exagan_200k_samples/'
parent_dir='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_data/20200513_121910_peters_dataset/'

ff=[]

for epoch in [34,37]:
    f_strg=parent_dir+'dump_outs/model0-validation-epoch{}-*gen_img*.npy'.format(epoch)
    lst=glob.glob(f_strg)
    ff.append(lst)
f_list=[fle for a in ff for fle in a] ## Flattening out a list of lists
print(len(f_list))

arr=[np.load(fname)[:,0,:,:] for fname in f_list]
s_lbann=np.vstack(arr)
print(s_lbann.shape,np.max(s_lbann))

0


ValueError: need at least one array to concatenate

In [None]:
dict_samples={'lbann':s_lbann, 'keras':s_keras,'raw': s_raw}

interact_manual(f_widget_compare,sample_dict=fixed(dict_samples),
                sample_names=SelectMultiple(options=dict_samples.keys()),
                Fig_type=ToggleButtons(options=['pixel','spectrum']),
                bins=SelectionSlider(options=np.arange(10,200,10),value=50),
                mode=['avg','simple'])