# Extract data from output files
### Analyze the output from a single LBANN run
March 9, 2020

April 6, 2020: Major edit to store files in order of epochs

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import subprocess as sp
import os
import glob
import sys

import time
from scipy import fftpack
from ipywidgets import interact, interact_manual,fixed, SelectMultiple


In [2]:
%matplotlib widget

In [3]:
sys.path.append('/global/u1/v/vpa/project/jpt_notebooks/Cosmology/Cosmo_GAN/LBANN/lbann_cosmogan/3_analysis/')
from modules_image_analysis import *

[NbConvertApp] Converting notebook modules_image_analysis.ipynb to script
[NbConvertApp] Writing 13357 bytes to modules_image_analysis.py


In [20]:
### Transformation functions for image pixel values
def f_transform(x):
    '''Map raw pixel values to the network range via s = 2x/(x+4) - 1 (1e-8 guards the denominator).'''
    denom = x + 4. + 1e-8
    return 2.*x/denom - 1.

def f_invtransform(s):
    '''Inverse of f_transform: x = 4(1+s)/(1-s), with 1e-8 guarding the denominator near s = 1.'''
    numer = 4.*(1. + s)
    denom = 1. - s + 1e-8
    return numer/denom

In [5]:

def f_get_samples(df,key):
    '''
    Extract an array of sample images from the DataFrame of stored image batches.

    For every row whose 'type' column equals key (in row order), take the first
    image of the stored batch.
    Images are of two types:
    1. *_gen have shape (64,1,128,128)   -> image ii[0,0,:,:]
    2. *_input have shape (64,16384)     -> flat image ii[0,:] reshaped to a square
    Returns an array of shape (n_rows, size, size).
    Raises AssertionError if key is not one of the four supported types.
    '''

    keys=['train_gen','train_input','val_gen','val_input']
    assert key in keys,"Given key %s is not in the list of keys %s"%(key,keys)

    lst=df[df.type==key]['image'].values

    if key.endswith('input'):
        ### Input images are stored flattened; recover the side length (=128 for 16384 pixels).
        ### Built-in int is used because the np.int alias was removed in NumPy 1.24.
        size=int(np.sqrt(lst[0].shape[-1]))
        samples=np.array([ii[0,:].reshape(size,size) for ii in lst])
    else :
        samples=np.array([ii[0,0,:,:] for ii in lst])

    return samples

## Extract image data 

In [7]:
### Earlier runs analysed with this notebook (uncomment one to re-analyse it).
### Only the last assignment below is actually used.
# fldr_name='20200316_112134_exagan'
# fldr_name='20200406_080207_exagan_with_mcr'
# fldr_name='20200407_093719_exagan_no_mcr'
# fldr_name='20200409_084926_exagan_no_mcr'
# fldr_name='20200409_083646_exagan_with_mcr'
# fldr_name='20200413_095840_exagan'
# fldr_name='20200421_055139_exagan'
# fldr_name='20200421_130545_exagan'
fldr_name='20200421_132207_exagan'


### Code for set of runs
# f_list=['20200401_125919_exagan_0.1_1','20200401_130321_exagan_0.1_4',
#         '20200401_130907_exagan_0.3_1','20200401_130646_exagan_0.3_4']
# fldr_name=f_list[0]


### Folder holding the dumped .npy outputs of the selected run
main_dir='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_data/%s/dump_outs/'%fldr_name
print(main_dir)


/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_data/20200421_055139_exagan/dump_outs/


In [8]:

## Get the .npy image files of each type from the dump_outs folder
files_dict={}
keys=['train_gen','train_input','val_gen','val_input']
file_strg_lst=['model0-training*-gen_img*-output0.npy','model0-training*-inp_img*-output0.npy','model0-validation*-gen_img*-output0.npy','model0-validation*-inp_img*-output0.npy']
for key,file_strg in zip(keys,file_strg_lst):
    files_dict[key]=np.array(glob.glob(main_dir+file_strg))
    if files_dict[key].shape[0]>1000 : 
        print('Warning the number of files is very large. Possibility of memory overload')

t1=time.time()
### First get a DataFrame of file names sorted by (type, epoch, step).
### Rows are collected in a list and the DataFrame is built once: DataFrame.append was
### removed in pandas 2.0, and appending row by row copies the whole frame each time.
records=[]
for key in keys: 
    files_arr=files_dict[key]  # Get array of files
    print(key,len(files_arr))
    for fname in files_arr:
        ### Epoch / step = text after the last 'epoch' / 'step' in the path, up to the next '-'
        records.append({'type':key,
                        'epoch':np.int32(fname.split('epoch')[-1].split('-')[0]),
                        'step':np.int64(fname.split('step')[-1].split('-')[0]),
                        'fname':fname})
## Explicit columns keep sort_values working even when no files were found
df_files=pd.DataFrame(records,columns=['type','epoch','step','fname'])
## Sort values
df_files=df_files.sort_values(by=['type','epoch','step']).reset_index(drop=True)
print("Sorting done")

t2=time.time()
### Then read images one by one, in sorted order, and create a new DataFrame. This is time-consuming.
sorted_fnames=df_files.fname.values
### Deliberately kept as list because some of the input arrays have different dimensions, causing creation of array of arrays in some cases
images=[np.load(fname) for fname in sorted_fnames]  

##### Create new Dataframe with sorted images
df_full=pd.DataFrame([])
df_full['image']=images
t3=time.time()
for col in ['epoch','step','type','fname']: df_full[col]=df_files[col].values
print("Extraction done")

print("Time for Sorting",t2-t1)
print("Time for Reading images",t3-t2)

df=df_full.copy()
print(df.shape)


train_gen 2220
train_input 2220
val_gen 248
val_input 248
Sorting done
Extraction done
Time for Sorting 20.03162717819214
Time for Reading images 184.31417989730835
(4936, 5)


In [9]:
## Slice DataFrame before getting samples: keep 1 image per epoch (the last stored step)

def f_filter_epoch(df_input):
    '''
    Get just the last stored step image for each epoch.

    For each image type, in the order train_gen, train_input, val_gen, val_input,
    keep the last row (in df_input's row order) of every epoch, with epochs ascending.
    Rows of any other type are dropped. Returns a new DataFrame with a 0..n-1 index.
    '''
    keys=['train_gen','train_input','val_gen','val_input']
    ### Collect the pieces and concatenate once: DataFrame.append was removed in pandas 2.0
    pieces=[]
    for key in keys:
        df1=df_input[df_input.type==key]
        ### keep='last' picks the final row of each epoch; the stable sort then orders epochs ascending
        last_steps=df1.drop_duplicates(subset='epoch',keep='last')
        pieces.append(last_steps.sort_values(by='epoch',kind='stable'))

    return pd.concat(pieces).reset_index(drop=True)

### Keep only the last stored step of every epoch, then show the resulting shape
df=f_filter_epoch(df_input=df_full)
df.shape

(240, 5)

## Extract samples 

In [10]:
### Available options : keys=['train_gen','train_input','val_gen','val_input']
### Take the first image of each stored batch for four image types, then report their shapes
samples1,samples2,samples3,samples4=(f_get_samples(df,key) for key in ('train_input','val_gen','train_gen','val_input'))
print(*(arr.shape for arr in (samples1,samples2,samples3,samples4)),sep='\n')

(60, 128, 128)
(60, 128, 128)
(60, 128, 128)
(60, 128, 128)


## Find the region without very high pixel values


In [11]:
def f_plot_max_values(samples,cutoff=0.994):
    '''
    Scatter-plot the maximum pixel value of each image against its index.

    Images whose maximum is below `cutoff` are drawn as red stars, the others
    (max >= cutoff) as blue diamonds; a dashed black line marks the cutoff.
    The y-axis is fixed to [0.9, 1.0].
    '''
    peaks=np.array([np.max(img) for img in samples])
    idx_low=np.flatnonzero(peaks<cutoff)
    idx_high=np.flatnonzero(peaks>=cutoff)

    plt.figure()
    for idx,mark,colr in ((idx_low,'*','r'),(idx_high,'D','b')):
        plt.plot(idx,peaks[idx],linestyle='',marker=mark,color=colr)

    plt.axhline(y=cutoff,linestyle='--',color='k')
    plt.ylim(0.9,1.0)
    
### Max pixel value per generated validation image, split at cutoff 0.9945
f_plot_max_values(samples2,cutoff=0.9945)
# f_plot_max_values(samples4,cutoff=0.992)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

## Compare images

In [75]:
def f_pixel_intensity(img_arr,bins=25,label='validation',mode='avg',normalize=False,log_scale=True,plot=True):
    '''
    Module to compute and plot histogram for pixel intensity of images
    Has 2 modes : simple and avg
        simple mode: No errors. Just flatten the input image array and compute histogram of full data
        avg mode(Default) : 
            - Compute histogram for each image in the image array, all over one common range
            - Return the per-bin mean and its standard error across images

    Parameters
        img_arr   : array of images; avg mode iterates over its first axis
        bins      : number of histogram bins (used by both modes)
        label     : legend label of the plotted curve
        mode      : 'simple' or 'avg'
        normalize : passed to np.histogram as `density`
        log_scale : use a logarithmic y-axis when plotting
        plot      : draw the histogram in a new figure
    Returns
        simple mode : (hist, None)
        avg mode    : (mean, err)
    Raises
        ValueError if mode is not 'simple' or 'avg'
    '''
    ### Validate before any figure is created
    if mode not in ('simple','avg'):
        raise ValueError("Unknown mode %r: expected 'simple' or 'avg'"%(mode,))

    norm=normalize # Whether to normalize the histogram
    
    def f_batch_histogram(img_arr,bins,norm):
        ''' Compute per-bin mean, standard error and bin centers over the histograms of a batch of images'''
        
        ## Extracting the range. One common range makes the per-image histograms share identical bins
        ulim,llim=np.max(img_arr),np.min(img_arr)
        ### Counts (length bins) and edges (length bins+1) are kept apart: packing the
        ### (hist, edges) pairs into one np.array is ragged, which NumPy >= 1.24 rejects
        hist_list=[]
        for arr in img_arr:
            counts,bin_edges=np.histogram(np.ravel(arr), bins=bins, range=(llim,ulim), density=norm)
            hist_list.append(counts)
        hist=np.stack(hist_list)
        ### Compute statistics over histograms of individual images (mean, standard error of the mean)
        mean,err=np.mean(hist,axis=0),np.std(hist,axis=0)/np.sqrt(hist.shape[0])
        ### Same range and bin count for every image => same edges; use the last computed ones
        centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    
        return mean,err,centers
    
    if plot: 
        plt.figure()
        plt.xlabel('Pixel value')
        plt.ylabel('Counts')
        plt.title('Pixel Intensity Histogram')

        if log_scale: plt.yscale('log')
    
    if mode=='simple':
        ### Use the bins argument (not a fixed 25) so both modes honor it
        hist, bin_edges = np.histogram(np.ravel(img_arr), bins=bins, density=norm)
        centers = (bin_edges[:-1] + bin_edges[1:]) / 2

        if plot:
            plt.errorbar(centers, hist, fmt='o-', label=label)
            plt.legend()
        return hist,None
    
    ### avg mode: compute histogram for each image
    mean,err,centers=f_batch_histogram(img_arr,bins,norm)

    if plot:
        plt.errorbar(centers,mean,yerr=err,fmt='o-',label=label)
        plt.legend()
    return mean,err

# _,_=f_pixel_intensity(f_invtransform(samples2),bins=100,label='validation',mode='simple',normalize=False,log_scale=True,plot=True)

In [82]:
def f_widget_pixel_intensity(arr,label,start,end,normalize=True,log_scale=True,rescale=True):
    '''
    Plot the pixel-intensity histogram (simple mode) of arr[start:end], for use with ipywidgets.
    Options for normalization, log-scale y-axis, and rescale.
    Rescale converts image pixel values from (-1,1) to the original pixel range (via f_invtransform).
    If the slice is empty (e.g. start >= end), the full array is used instead.
    '''
    sliced_arr=arr[start:end]
    if len(sliced_arr)<1:
        print('Input indices %s %s are invalid.\nUsing full array'%(start,end))
        ### Reset the bounds as well, so the plot label describes the data actually used
        start,end=0,'end'
        sliced_arr=arr[:]
    if rescale: ### Converting from pixel intensity range (-1,1) to original range
        sliced_arr=f_invtransform(sliced_arr)
    print('Array size used',sliced_arr.shape)
    
    f_pixel_intensity(sliced_arr,label=label+': {0}-{1}'.format(str(start),str(end)),normalize=normalize,log_scale=log_scale,mode='simple')

# f_widget_pixel_intensity(samples2,'s2',0,None,True,True,True)

In [63]:
### Interactive histogram of samples2 (val_gen); labelled 's2' to match the naming used in the comparison cells
interact_manual(f_widget_pixel_intensity,
                arr=fixed(samples2),
                label=fixed('s2'),
                start=np.arange(0,100,10),
                end=np.arange(0,100,10))

interactive(children=(Dropdown(description='start', options=(0, 10, 20, 30, 40, 50, 60, 70, 80, 90), value=0),…

<function __main__.f_widget_pixel_intensity(arr, label, start, end, normalize=True, log_scale=True, rescale=True)>

In [None]:
# interact_manual(f_widget_pixel_intensity,arr=[samples1,samples2],label=['s1','s2'],start=np.arange(0,100,10),end=np.arange(0,100,10))

In [None]:
# f_compare_pixel_intensity([samples1,samples2,samples3,samples4],label_lst=['s1','s2','s3','s4'],normalize=False,log_scale=True, mode='avg',bins=25)

## Compare pixel intensities

In [191]:
def f_widget_compare_pixel_intensity(sample_names,sample_dict,rescale=True,log_scale=True,bins=25,mode='avg',normalize=True):
    '''
    Widget callback comparing pixel-intensity histograms of the selected sample sets.
    sample_names : keys into sample_dict; they are also used as the plot labels
    sample_dict  : mapping from name to image array
    rescale      : if True, map pixel values back with f_invtransform first
    The remaining options are forwarded to f_compare_pixel_intensity.
    '''
    chosen=[sample_dict[name] for name in sample_names]
    label_list=list(sample_names)

    img_list=[f_invtransform(img) for img in chosen] if rescale else chosen

    f_compare_pixel_intensity(img_lst=img_list,label_lst=label_list,normalize=normalize,log_scale=log_scale, mode=mode,bins=bins)
    
# f_widget_compare_pixel_intensity(dict_samples.keys(),dict_samples)

### Compare different epochs

In [192]:
def f_get_sample_epochs(samples,step=10):
    '''
    Split samples into consecutive chunks of `step` entries along the first axis, plus the full array.

    Returns (img_list, labels_list). Labels are 'i1:i2' slice strings; the last chunk uses
    its true end index, so a partial final chunk is labelled accurately. The full array is
    appended last with the label '0:end'.
    Raises ValueError if step < 1.
    '''
    if step<1:
        raise ValueError('step must be a positive integer, got %s'%(step,))

    size=samples.shape[0]
    img_list,labels_list=[],[]
    for i1 in range(0,size,step):
        i2=min(i1+step,size)
        img_list.append(samples[i1:i2])
        labels_list.append('%s:%s'%(i1,i2))
    img_list.append(samples)
    labels_list.append('0:end')
    
    return img_list,labels_list


In [193]:
img_list,label_list=f_get_sample_epochs(samples2)
### Map each slice label (e.g. '0:10') to its images; dict insertion order follows the slices.
### Uses label_list (the name unpacked above); the undefined `labels_list` raised NameError in a fresh kernel.
dict_samples=dict(zip(label_list,img_list))

interact_manual(f_widget_compare_pixel_intensity,sample_dict=fixed(dict_samples),sample_names=SelectMultiple(options=dict_samples.keys()),bins=np.arange(10,201,20),mode=['avg','simple'])

interactive(children=(SelectMultiple(description='sample_names', options=('0:10', '10:20', '20:30', '30:40', '…

<function __main__.f_widget_compare_pixel_intensity(sample_names, sample_dict, rescale=True, log_scale=True, bins=25, mode='avg', normalize=True)>

### Compare different sample sets

In [195]:
### s1=train_input, s2=val_gen (samples 10-29 only), s3=train_gen, s4=val_input
dict_samples=dict(s1=samples1, s2=samples2[10:30], s3=samples3, s4=samples4)
interact_manual(f_widget_compare_pixel_intensity,
                sample_dict=fixed(dict_samples),
                sample_names=SelectMultiple(options=dict_samples.keys()),
                bins=np.arange(10,201,20),
                mode=['avg','simple'])

interactive(children=(SelectMultiple(description='sample_names', options=('s1', 's2', 's3', 's4'), value=()), …

<function __main__.f_widget_compare_pixel_intensity(sample_names, sample_dict, rescale=True, log_scale=True, bins=25, mode='avg', normalize=True)>

In [168]:
# f_pixel_intensity(f_invtransform(samples2),normalize=False)
# f_compare_pixel_intensity([samples1,samples2,samples3,samples4],label_lst=['s1','s2','s3','s4'],normalize=False,log_scale=True, mode='avg',bins=25)

### Plot grid of intensity histograms

In [169]:
# f_plot_intensity_grid(samples2[40:80][::5],cols=6)
# f_plot_intensity_grid(f_invtransform(samples2[22:52][::3]),cols=6)

## Spectrum

In [None]:
# f_compute_spectrum(samples1)
# f_compute_spectrum(f_invtransform(samples2[51:80]))

In [None]:
### Compare spectra of val_input (samples4) vs val_gen (samples2) over a slice of samples,
### first on the transformed pixel scale, then after f_invtransform.
### start,end are set explicitly so this cell does not depend on values left over from other cells.
# start,end=22,52
start,end=23,33
f_compare_spectrum(samples4[start:end],samples2[start:end],label1='input',label2='generated')
f_compare_spectrum(f_invtransform(samples4[start:end]),f_invtransform(samples2[start:end]),label1='input',label2='generated')

In [None]:
start,end=33,None
### val_input vs val_gen from index 33 onward: transformed scale first, then after f_invtransform
inp_slice,gen_slice=samples4[start:end],samples2[start:end]
f_compare_spectrum(inp_slice,gen_slice,label1='input',label2='generated')
f_compare_spectrum(f_invtransform(inp_slice),f_invtransform(gen_slice),label1='input',label2='generated')