## Compare results for multiple cosmologies

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

import subprocess as sp
import os
import glob
import sys

import itertools
import time

from ipywidgets import *

In [2]:
%matplotlib widget

In [3]:
sys.path.append('/global/u1/v/vpa/project/jpt_notebooks/Cosmology/Cosmo_GAN/repositories/lbann_cosmogan/3_analysis')
from modules_image_analysis import *

In [4]:
### Transformation functions for image pixel values
def f_transform(x):
    return 2.*x/(x + 4. + 1e-8) - 1.

def f_invtransform(s):
    return 4.*(1. + s)/(1. - s + 1e-8)

In [5]:
parent_dir='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/dataset_3_different_universes_6k/'
f_strg=parent_dir+'*.npy'
f_list=glob.glob(f_strg)
print(f_list)

['/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/dataset_3_different_universes_6k/Om0.3_Sg1.1_H70.0.npy', '/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/dataset_3_different_universes_6k/Om0.45_Sg0.5_H70.0.npy', '/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/dataset_3_different_universes_6k/Om0.45_Sg0.8_H100.0.npy', '/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/dataset_3_different_universes_6k/Om0.3_Sg0.5_H40.0.npy', '/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/dataset_3_different_universes_6k/Om0.45_Sg0.5_H40.0.npy', '/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/dataset_3_different_universes_6k/Om0.45_Sg0.5_H100.0.npy', '/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/dataset_3_different_universes_6k/Om0.3_Sg1.1_H40.0.npy', '/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/dataset_3_different_universes_6k/Om0.45_Sg1.1_H70.0.npy', '/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/dataset_3_different_universes_6k/Om0.45_S

In [6]:
def f_compute_hist_spect(sample,bins):
    ''' Compute pixel intensity histograms and radial spectrum for 2D arrays
    Input : Image arrays and bins
    Output: dictionary with 5 arrays : Histogram values, errors and bin centers, Spectrum values and errors.
    '''
    ### Compute pixel histogram for row
    gen_hist,gen_err,hist_bins=f_batch_histogram(sample,bins=bins,norm=True,hist_range=None)
    ### Compute spectrum for row
    spec,spec_err=f_compute_spectrum(sample,plot=False)

    dict1={'hist_val':gen_hist,'hist_err':gen_err,'hist_bin_centers':hist_bins,'spec_val':spec,'spec_err':spec_err }
    return dict1

In [7]:
### Extract validation data
fname='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/dataset_2_smoothing_200k/norm_1_train_val.npy'
s_val=np.load(fname,mmap_mode='r')[:8000][:,0,:,:]
print(s_val.shape)

(8000, 128, 128)


In [8]:
bins=np.concatenate([np.array([-0.5]),np.arange(0.5,20.5,1),np.arange(20.5,100.5,5),np.arange(100.5,1000.5,50),np.array([2000])]) #bin edges to use
bins=f_transform(bins)   ### scale to (-1,1) 
### Compute histogram and spectrum of raw data 
dict_val=f_compute_hist_spect(s_val,bins)

In [35]:
# print(bins)

In [43]:
df_runs=pd.DataFrame([])
dict1={}
for count,fname in enumerate(f_list):
    lst=fname.split('.npy')[0].split('/')[-1].split('_')
#     print(lst)
    keys=['omega_m','sigma_8','H_0']
    values=[float(lst[0][2:]),float(lst[1][2:]),float(lst[2][1:])]
    dict1=dict(zip(keys,values))
    strg=str(count)+'_Og={0}_Sg={1}_H0={2}'.format(dict1['omega_m'],dict1['sigma_8'],dict1['H_0'])
    print(strg)

    dict1['label']=strg
    dict1['fname']=fname
    images=f_transform(np.load(fname)[:,:,:,0])
    f_transform(images)
#     print(images.shape,np.max(images),np.min(images))
    dict_sample=f_compute_hist_spect(images,bins) ## list of 5 numpy arrays 
    dict1.update(dict_sample)
    del(images)
    df_runs=df_runs.append(dict1,ignore_index=True)
    
df_runs

/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/dataset_3_different_universes_6k/Om0.3_Sg1.1_H70.0.npy 0_Og=0.3_Sg=1.1_H0=70.0
(6144, 128, 128) 0.9928713 -0.99987966
/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/dataset_3_different_universes_6k/Om0.45_Sg0.5_H70.0.npy 1_Og=0.45_Sg=0.5_H0=70.0
(6144, 128, 128) 0.9797709 -0.9989238
/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/dataset_3_different_universes_6k/Om0.45_Sg0.8_H100.0.npy 2_Og=0.45_Sg=0.8_H0=100.0
(6144, 128, 128) 0.98340845 -0.99973017
/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/dataset_3_different_universes_6k/Om0.3_Sg0.5_H40.0.npy 3_Og=0.3_Sg=0.5_H0=40.0
(6144, 128, 128) 0.9890187 -0.9984752
/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/dataset_3_different_universes_6k/Om0.45_Sg0.5_H40.0.npy 4_Og=0.45_Sg=0.5_H0=40.0
(6144, 128, 128) 0.9878609 -0.9985411
/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/dataset_3_different_universes_6k/Om0.45_Sg0.5_H100.0.npy 5_Og=0.45_Sg=0.5_H0=100

Unnamed: 0,H_0,fname,hist_bin_centers,hist_err,hist_val,label,omega_m,sigma_8,spec_err,spec_val
0,70.0,/global/cfs/cdirs/m3363/vayyar/cosmogan_data/r...,"[-1.031746031584782, -0.6161616169043975, -0.3...","[0.0009079648516789308, 0.0006726266991459041,...","[1.2738733017147361, 0.7588932646952613, 0.212...",0_Og=0.3_Sg=1.1_H0=70.0,0.3,1.1,"[76684.91293652191, 798.5318860808909, 397.337...","[206959783.34564114, 109497.94028127899, 74247..."
1,70.0,/global/cfs/cdirs/m3363/vayyar/cosmogan_data/r...,"[-1.031746031584782, -0.6161616169043975, -0.3...","[0.0005896064371543193, 0.0004095359008676652,...","[0.9127408801837834, 1.2222250299728976, 0.334...",1_Og=0.45_Sg=0.5_H0=70.0,0.45,0.5,"[46532.331353938775, 318.2938782704704, 163.83...","[191164735.4860935, 53968.635256076384, 42226...."
2,100.0,/global/cfs/cdirs/m3363/vayyar/cosmogan_data/r...,"[-1.031746031584782, -0.6161616169043975, -0.3...","[0.0004989796421372787, 0.0003750298284552049,...","[1.1105024619732171, 0.9422431528181456, 0.277...",2_Og=0.45_Sg=0.8_H0=100.0,0.45,0.8,"[41024.957127590584, 274.0651537559888, 163.08...","[198339446.12083697, 47473.55019415706, 41648...."
3,40.0,/global/cfs/cdirs/m3363/vayyar/cosmogan_data/r...,"[-1.031746031584782, -0.6161616169043975, -0.3...","[0.0011721732340012365, 0.0007746468675593247,...","[0.9011799669786164, 1.2711475529731002, 0.312...",3_Og=0.3_Sg=0.5_H0=40.0,0.3,0.5,"[90279.4120620341, 679.3301888056498, 310.0421...","[191748664.4415345, 102321.69588139419, 61701...."
4,40.0,/global/cfs/cdirs/m3363/vayyar/cosmogan_data/r...,"[-1.031746031584782, -0.6161616169043975, -0.3...","[0.0009674000830841397, 0.0006665671900912941,...","[0.9687526046140419, 1.1728589282742705, 0.295...",4_Og=0.45_Sg=0.5_H0=40.0,0.45,0.5,"[75607.36567714384, 624.5992992171072, 289.742...","[193913008.11533976, 94984.244893716, 60539.48..."
5,100.0,/global/cfs/cdirs/m3363/vayyar/cosmogan_data/r...,"[-1.031746031584782, -0.6161616169043975, -0.3...","[0.000423274646576603, 0.00031677142474488475,...","[0.85460860409485, 1.2812215808025866, 0.37229...",5_Og=0.45_Sg=0.5_H0=100.0,0.45,0.5,"[33074.07803817575, 180.57280115908108, 108.10...","[188669374.32882974, 32758.059599335018, 29222..."
6,40.0,/global/cfs/cdirs/m3363/vayyar/cosmogan_data/r...,"[-1.031746031584782, -0.6161616169043975, -0.3...","[0.0015370041672688928, 0.0011073430650724353,...","[1.3337079646939394, 0.6959759920380478, 0.189...",6_Og=0.3_Sg=1.1_H0=40.0,0.3,1.1,"[132254.28288502817, 1901.2041026683396, 842.4...","[210612438.37024233, 174585.3499182778, 97067...."
7,70.0,/global/cfs/cdirs/m3363/vayyar/cosmogan_data/r...,"[-1.031746031584782, -0.6161616169043975, -0.3...","[0.0007657065360985548, 0.00056600424796568, 0...","[1.3069146160218803, 0.7092073179530239, 0.205...",7_Og=0.45_Sg=1.1_H0=70.0,0.45,1.1,"[66381.27962227042, 668.1076184972667, 360.782...","[208254232.41118303, 95086.84574871638, 71360...."
8,40.0,/global/cfs/cdirs/m3363/vayyar/cosmogan_data/r...,"[-1.031746031584782, -0.6161616169043975, -0.3...","[0.0011230635547877885, 0.0008252430830342482,...","[1.2304222447869781, 0.8262177420468624, 0.217...",8_Og=0.45_Sg=0.8_H0=40.0,0.45,0.8,"[93056.66721009935, 1034.1837977318226, 496.95...","[204920860.56500146, 130862.43393145881, 82034..."
9,70.0,/global/cfs/cdirs/m3363/vayyar/cosmogan_data/r...,"[-1.031746031584782, -0.6161616169043975, -0.3...","[0.0011449259393211775, 0.0007876688685772156,...","[1.035386958036178, 1.0846735591013432, 0.2758...",9_Og=0.15_Sg=0.8_H0=70.0,0.15,0.8,"[91192.41188227247, 782.6977612167556, 356.387...","[196952288.63943708, 108756.3470529826, 66763...."


In [45]:

def f_plot_hist_spec_best(df,dict_bkg):

    fig=plt.figure(figsize=(6,6))
    ax1=fig.add_subplot(121)
    ax2=fig.add_subplot(122)
    for (i,row),marker in zip(df.iterrows(),itertools.cycle('>^*sDHPdpx_')):

        x1=row.hist_bin_centers
        y1=row.hist_val
        yerr1=row.hist_err
        x1=f_invtransform(x1)

        y2=row.spec_val
        yerr2=row.spec_err
        x2=np.arange(len(y2))

        label=row.label
        ax1.errorbar(x1,y1,yerr1,marker=marker,markersize=5,linestyle='',label=label)
    #     ax2.errorbar(x2,y2,yerr2,marker=marker,markersize=5,linestyle='',label='{0}-{1}'.format(epoch,step))

        ax2.fill_between(x2, y2 - yerr2, y2 + yerr2, alpha=0.4)
        ax2.plot(x2, y2, marker=marker, linestyle=':',label=label)

    ### Plot input data
    x,y,yerr=dict_bkg['hist_bin_centers'],dict_bkg['hist_val'],dict_bkg['hist_err']
    x=f_invtransform(x)
    ax1.errorbar(x, y,yerr,color='k',linestyle='-',label='bkgnd')   

    y,yerr=dict_bkg['spec_val'],dict_bkg['spec_err']
    x=np.arange(len(y))
    ax2.fill_between(x, y - yerr, y + yerr, color='k',alpha=0.8)

#     plt.legend()
    # plt.yscale('log')
    ax1.set_xscale('symlog',linthreshx=50)
    ax1.set_yscale('log')
    ax2.set_yscale('log')


f_plot_hist_spec_best(df_runs,dict_val)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [80]:

def f_plot_hist_spec_best(df,dict_bkg,plot_type):

    fig=plt.figure(figsize=(6,6))
    for (i,row),marker in zip(df.iterrows(),itertools.cycle('>^*sDHPdpx_')):
        
        label=row.label
        if plot_type=='hist':
            x1=row.hist_bin_centers
            y1=row.hist_val
            yerr1=row.hist_err
            x1=f_invtransform(x1)
            
            plt.errorbar(x1,y1,yerr1,marker=marker,markersize=5,linestyle='',label=label)
        if plot_type=='spec':
            
            y2=row.spec_val
            yerr2=row.spec_err
            x2=np.arange(len(y2))
            
            plt.fill_between(x2, y2 - yerr2, y2 + yerr2, alpha=0.4)
            plt.plot(x2, y2, marker=marker, linestyle=':',label=label)

    ### Plot input data
    if plot_type=='hist':
        x,y,yerr=dict_bkg['hist_bin_centers'],dict_bkg['hist_val'],dict_bkg['hist_err']
        x=f_invtransform(x)
        plt.errorbar(x, y,yerr,color='k',linestyle='-',label='bkgnd')   
    if plot_type=='spec':
        y,yerr=dict_bkg['spec_val'],dict_bkg['spec_err']
        x=np.arange(len(y))
        plt.fill_between(x, y - yerr, y + yerr, color='k',alpha=0.8)

    plt.legend(loc='upper',bbox_to_anchor=(0.3, 0.75),ncol=2, fancybox=True, shadow=True,prop={'size':6})
    plt.yscale('log')
#     ax1.set_xscale('symlog',linthreshx=50)
#     ax1.set_yscale('log')
#     ax2.set_yscale('log')


f_plot_hist_spec_best(df_runs,dict_val,'hist')
f_plot_hist_spec_best(df_runs,dict_val,'spec')

  This is separate from the ipykernel package so we can avoid doing imports until


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

	best
	upper right
	upper left
	lower left
	lower right
	right
	center left
	center right
	lower center
	upper center
	center
This will raise an exception in 3.3.
  This is separate from the ipykernel package so we can avoid doing imports until


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

	best
	upper right
	upper left
	lower left
	lower right
	right
	center left
	center right
	lower center
	upper center
	center
This will raise an exception in 3.3.


In [13]:
def f_widget_compare(sample_names,sample_dict,Fig_type='pixel',rescale=True,log_scale=True,bins=25,mode='avg',normalize=True,bkgnd=[]):
    '''
    Module to make widget plots for pixel intensity or spectrum comparison for multiple sample sets
    '''
    
    ### Crop out large pixel values
#     for key in sample_names:
#         print(sample_dict[key].shape)
#         sample_dict[key]=np.array([arr for arr in sample_dict[key] if np.max(arr)<=0.994])
#         print(sample_dict[key].shape)
    
    img_list=[sample_dict[key] for key in sample_names]
    label_list=list(sample_names)
    
    if rescale: 
        for count,img in enumerate(img_list):
            img_list[count]=f_invtransform(img)
        if len(bkgnd): bkgnd=f_invtransform(bkgnd)
#         hist_range=(0,2000)
    else:
        bins=f_transform(bins)
#         hist_range=(-1,0.996)
    assert Fig_type in ['pixel','spectrum'],"Invalid mode %s"%(mode)
    
    if Fig_type=='pixel':
#         f_compare_pixel_intensity(img_lst=img_list,label_lst=label_list,normalize=normalize,log_scale=log_scale, mode=mode,bins=bins,hist_range=hist_range)
        f_compare_pixel_intensity(img_lst=img_list,label_lst=label_list,normalize=normalize,log_scale=log_scale, mode=mode,bins=bins,hist_range=None,bkgnd_arr=bkgnd)

    elif Fig_type=='spectrum':
        f_compare_spectrum(img_lst=img_list,label_lst=label_list,log_scale=log_scale,bkgnd_arr=bkgnd)



In [None]:
# # s_input=[]
# dict_samples={}
# for idx,row in df_runs.iterrows():
#     strg=str(idx)+'_Om={0}_Sg={1}_H0={2}'.format(row.omega_m,row.sigma_8,row.H_0)
#     print(strg)
#     img=np.load(row.fname)[:,:,:,0]
#     dict_samples.update({strg:img})

In [14]:
bins=np.concatenate([np.array([-0.5]),np.arange(0.5,20.5,1),np.arange(20.5,100.5,5),np.arange(100.5,1000.5,50),np.array([2000])]) #bin edges to use

In [16]:
### Load validation input samples
img_raw='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/dataset_2_smoothing_200k/train.npy'
a1=np.load(img_raw)[:10000]
s_raw=f_transform(a1[:,:,:,0])[:10000]

print(s_raw.shape)

(10000, 128, 128)


In [17]:
# dict_samples={'raw': s_raw,'keras1':s_keras[0],'keras2':s_keras[1],'s_new1':s_new1,'s_new2':s_new2,'s_new3':s_new3}

bkgnd=s_raw
# bkgnd=[]
interact_manual(f_widget_compare,sample_dict=fixed(dict_samples),
                sample_names=SelectMultiple(options=dict_samples.keys()),
                Fig_type=ToggleButtons(options=['pixel','spectrum']),
                bins=fixed(bins),
                mode=['avg','simple'],bkgnd=fixed(bkgnd))

interactive(children=(SelectMultiple(description='sample_names', options=('0_Om=0.3_Sg=1.1_H0=70.0', '1_Om=0.4…

<function __main__.f_widget_compare(sample_names, sample_dict, Fig_type='pixel', rescale=True, log_scale=True, bins=25, mode='avg', normalize=True, bkgnd=[])>