# Analyze results with histograms and plots made from scratch
October 27, 2020

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import subprocess as sp
import sys
import os
import glob
import pickle 

from matplotlib.colors import LogNorm, PowerNorm, Normalize
from ipywidgets import *

In [2]:
%matplotlib widget

In [3]:
sys.path.append('/global/u1/v/vpa/project/jpt_notebooks/Cosmology/Cosmo_GAN/repositories/lbann_cosmogan/3_analysis/')
from modules_image_analysis import *

In [4]:
### Transformation functions for image pixel values
def f_transform(x):
    return 2.*x/(x + 4.) - 1.

def f_invtransform(s):
    return 4.*(1. + s)/(1. - s)


## Read data

In [5]:
# main_dir='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/pytorch/results/128sq/'
# results_dir=main_dir+'20201002_064327'

In [6]:
dict1={'128':'/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/pytorch/results/128sq/',
      '512':'/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/pytorch/results/512sq/'}

u=interactive(lambda x: dict1[x], x=Select(options=dict1.keys()))
# display(u)


In [7]:
# parent_dir=u.result
parent_dir=dict1['128']
dir_lst=[i.split('/')[-1] for i in glob.glob(parent_dir+'20*')]
w=interactive(lambda x: x, x=Dropdown(options=dir_lst))
display(w)

interactive(children=(Dropdown(description='x', options=('20201202_094646_cgan_model2', '20201002_073628', '20…

In [8]:
result=w.result
result_dir=parent_dir+result
print(result_dir)


/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/pytorch/results/128sq/20201202_094646_cgan_model2


## Plot Losses

In [9]:
df_metrics=pd.read_pickle(result_dir+'/df_metrics.pkle').astype(np.float64)

In [10]:
df_metrics.step.values.shape,df_metrics.G_full.values.shape

((41401,), (41401,))

In [11]:
df_metrics.head(10)

Unnamed: 0,step,epoch,Dreal,Dfake,Dfull,G_adv,G_full,spec_loss,hist_loss,spec_chi,hist_chi,D(x),D_G_z1,D_G_z2,time
0,0.0,0.0,0.67296,0.750564,1.423525,3.602696,346.572235,342.969543,6.603271,,,0.053961,0.083414,-3.540892,19.436753
1,1.0,0.0,0.851392,0.398092,1.249485,0.30658,343.163757,342.857178,6.7256,343.068054,7.19015,-0.202783,-0.929195,1.802442,0.14972
2,2.0,0.0,0.083336,3.283406,3.366741,3.237345,345.247284,342.009949,6.656205,342.222504,6.623254,4.191892,3.189792,-3.190942,0.133791
3,3.0,0.0,0.465928,0.190177,0.656105,3.798806,344.589539,340.790741,6.547747,340.908936,5.977823,0.587307,-1.71475,-3.768611,0.134441
4,4.0,0.0,0.468809,0.151318,0.620128,1.824184,342.248016,340.423828,6.490149,340.363647,6.483237,0.597263,-2.128833,-1.639891,0.13238
5,5.0,0.0,0.09116,1.061312,1.152472,5.579582,346.395569,340.815979,6.411236,340.842041,6.225676,3.071388,0.596954,-5.571957,0.135365
6,6.0,0.0,0.48335,0.110982,0.594332,4.167254,344.618896,340.45163,6.4181,340.107788,6.366264,0.650413,-3.611985,-4.148392,0.19442
7,7.0,0.0,0.137368,0.221207,0.358575,3.297258,343.015045,339.717773,6.338644,339.913055,6.126013,2.391047,-1.672691,-3.257637,0.135216
8,8.0,0.0,0.094639,0.337416,0.432055,5.915369,347.367279,341.451904,6.124936,341.955536,6.286113,3.179992,-1.08884,-5.911065,0.13496
9,9.0,0.0,0.205061,0.09617,0.301231,4.588552,343.866516,339.277954,6.145626,339.299408,6.013839,1.786248,-3.772867,-4.576407,0.135608


In [12]:
df_metrics[df_metrics.step%50==0].head(10)

Unnamed: 0,step,epoch,Dreal,Dfake,Dfull,G_adv,G_full,spec_loss,hist_loss,spec_chi,hist_chi,D(x),D_G_z1,D_G_z2,time
0,0.0,0.0,0.67296,0.750564,1.423525,3.602696,346.572235,342.969543,6.603271,,,0.053961,0.083414,-3.540892,19.436753
50,50.0,0.0,0.762363,0.318434,1.080797,20.783304,360.805267,340.021973,2.282651,339.934723,0.711537,0.369389,-19.837168,-20.783304,0.140395
100,100.0,0.0,0.125499,3.644004,3.769504,15.110949,356.888489,341.777527,2.761917,342.013885,1.204785,4.838884,3.664841,-15.110945,0.141217
150,150.0,0.0,0.152837,0.125762,0.278599,3.991689,343.73172,339.740021,0.653847,340.702881,0.52926,2.793908,-3.061347,-3.965852,0.141147
200,200.0,0.0,0.1121,0.159,0.271101,5.673723,347.27417,341.600433,-4.217441,340.596497,-2.505859,3.501714,-2.252074,-5.66935,0.140421
250,250.0,0.0,0.090053,0.087462,0.177515,7.406257,344.893738,337.487488,3.252811,339.291687,1.022795,2.934886,-6.198949,-7.405168,0.1419
300,300.0,0.0,0.092608,0.097842,0.19045,6.374515,344.067657,337.693146,5.277166,335.678894,3.734727,3.210438,-6.42559,-6.371436,0.147581
350,350.0,0.0,0.210713,0.307418,0.518131,5.329717,342.498779,337.169067,6.577486,335.769806,6.96396,4.630486,-4.110314,-5.289142,0.147337
400,400.0,0.0,0.111493,0.063506,0.174999,4.762858,338.448517,333.685669,7.880299,335.123535,6.808002,4.128319,-4.084805,-4.748539,0.139996
450,450.0,0.0,0.089481,0.103592,0.193073,4.641104,336.853607,332.212494,7.273122,332.681702,6.340181,4.223292,-4.393451,-4.627389,0.141747


In [13]:
def f_plot_metrics(df,col_list):
    
    plt.figure()
    for key in col_list:
        plt.plot(df_metrics[key],label=key,marker='*',linestyle='')
    plt.legend()
    
#     col_list=list(col_list)
#     df.plot(kind='line',x='step',y=col_list)
    
# f_plot_metrics(df_metrics,['spec_chi','hist_chi'])

interact_manual(f_plot_metrics,df=fixed(df_metrics), col_list=SelectMultiple(options=df_metrics.columns.values))

interactive(children=(SelectMultiple(description='col_list', options=('step', 'epoch', 'Dreal', 'Dfake', 'Dful…

<function __main__.f_plot_metrics(df, col_list)>

In [14]:
chi=df_metrics.dropna().quantile(q=0.05,axis=0)['hist_chi']
print(chi)
df_metrics[df_metrics['hist_chi']<=chi].sort_values(by=['hist_chi']).head(10)

-8.051040649414062


Unnamed: 0,step,epoch,Dreal,Dfake,Dfull,G_adv,G_full,spec_loss,hist_loss,spec_chi,hist_chi,D(x),D_G_z1,D_G_z2,time
41208,41208.0,9.0,0.080038,0.080873,0.160911,4.253773,324.765503,320.511719,-8.132553,325.683441,-9.8409,4.207676,-4.299721,-4.23812,0.144696
40710,40710.0,9.0,0.084432,0.085563,0.169994,4.285488,317.735962,313.45047,-8.312079,322.920776,-9.838562,4.442131,-4.261558,-4.2703,0.144514
35959,35959.0,8.0,0.107712,0.078438,0.186151,4.100684,304.961395,300.860718,-7.500225,323.116608,-9.794497,4.281426,-4.1561,-4.082481,0.154152
41196,41196.0,9.0,0.060627,0.078132,0.138759,4.295483,331.643127,327.347656,-9.048525,327.350311,-9.771997,4.676192,-4.272749,-4.280361,0.143568
41167,41167.0,9.0,0.089724,0.077798,0.167522,4.735286,322.160461,317.425171,-7.775158,320.057373,-9.760061,4.741716,-4.67995,-4.726073,0.14363
41213,41213.0,9.0,0.076932,0.078849,0.155781,4.01766,312.682678,308.665009,-8.854787,319.672638,-9.742317,4.717999,-4.05975,-3.998183,0.141837
35554,35554.0,8.0,0.080147,0.065477,0.145624,4.40283,310.893066,306.490234,-7.785518,320.356934,-9.72444,4.356524,-4.357229,-4.387339,0.14383
35551,35551.0,8.0,0.084702,0.085277,0.169979,4.065558,325.611298,321.545746,-8.185736,314.398346,-9.724274,4.577144,-4.024002,-4.046775,0.142547
33082,33082.0,7.0,0.08771,0.081305,0.169015,4.713353,329.81604,325.102692,-5.672636,316.338348,-9.680647,4.370604,-4.642154,-4.703523,0.154768
41163,41163.0,9.0,0.074545,0.088874,0.163418,4.345701,321.499695,317.153992,-8.966484,316.946777,-9.679063,4.790658,-3.994018,-4.330885,0.144475


In [15]:
display(df_metrics.sort_values(by=['hist_chi']).head(8))
display(df_metrics.sort_values(by=['spec_chi']).head(8))

Unnamed: 0,step,epoch,Dreal,Dfake,Dfull,G_adv,G_full,spec_loss,hist_loss,spec_chi,hist_chi,D(x),D_G_z1,D_G_z2,time
41208,41208.0,9.0,0.080038,0.080873,0.160911,4.253773,324.765503,320.511719,-8.132553,325.683441,-9.8409,4.207676,-4.299721,-4.23812,0.144696
40710,40710.0,9.0,0.084432,0.085563,0.169994,4.285488,317.735962,313.45047,-8.312079,322.920776,-9.838562,4.442131,-4.261558,-4.2703,0.144514
35959,35959.0,8.0,0.107712,0.078438,0.186151,4.100684,304.961395,300.860718,-7.500225,323.116608,-9.794497,4.281426,-4.1561,-4.082481,0.154152
41196,41196.0,9.0,0.060627,0.078132,0.138759,4.295483,331.643127,327.347656,-9.048525,327.350311,-9.771997,4.676192,-4.272749,-4.280361,0.143568
41167,41167.0,9.0,0.089724,0.077798,0.167522,4.735286,322.160461,317.425171,-7.775158,320.057373,-9.760061,4.741716,-4.67995,-4.726073,0.14363
41213,41213.0,9.0,0.076932,0.078849,0.155781,4.01766,312.682678,308.665009,-8.854787,319.672638,-9.742317,4.717999,-4.05975,-3.998183,0.141837
35554,35554.0,8.0,0.080147,0.065477,0.145624,4.40283,310.893066,306.490234,-7.785518,320.356934,-9.72444,4.356524,-4.357229,-4.387339,0.14383
35551,35551.0,8.0,0.084702,0.085277,0.169979,4.065558,325.611298,321.545746,-8.185736,314.398346,-9.724274,4.577144,-4.024002,-4.046775,0.142547


Unnamed: 0,step,epoch,Dreal,Dfake,Dfull,G_adv,G_full,spec_loss,hist_loss,spec_chi,hist_chi,D(x),D_G_z1,D_G_z2,time
16644,16644.0,4.0,0.074879,0.135497,0.210376,5.063136,313.686462,308.623322,-3.31355,298.52356,-0.293596,4.749221,-3.869322,-5.048425,0.140701
14043,14043.0,3.0,0.09992,0.131062,0.230981,5.278117,317.761383,312.483276,-1.835354,299.012512,-2.198043,4.715596,-4.265014,-5.263249,0.141332
13610,13610.0,3.0,0.099271,0.093766,0.193036,4.761892,313.654358,308.892456,-3.091039,299.299805,-3.646973,6.208107,-4.027986,-4.749377,0.143846
9978,9978.0,2.0,0.110018,0.142275,0.252293,5.761327,320.387054,314.625732,2.07321,299.981506,0.688096,4.680516,-4.731729,-5.746074,0.146882
13600,13600.0,3.0,0.11529,0.091482,0.206772,4.118304,324.477203,320.358887,-2.297143,300.0495,-6.769023,4.823624,-4.328673,-4.08246,0.140881
11225,11225.0,2.0,0.145078,0.113686,0.258764,6.397127,319.578278,313.181152,-1.29037,300.227997,1.810378,3.507452,-6.589367,-6.38949,0.136481
13599,13599.0,3.0,0.13364,0.05679,0.19043,4.077094,323.247986,319.170898,-3.642219,300.27594,-3.586799,3.694057,-4.724432,-4.049538,0.138298
13597,13597.0,3.0,0.095615,0.072385,0.168,4.601977,331.261108,326.659119,-4.623798,300.74881,-6.327124,4.244675,-4.159423,-4.579351,0.138917


## Plot

In [16]:
dict_samples={}

In [17]:
sigma_list=[0.5,0.65,0.8,1.1]
label_list=[0,1,2,3]
bkgnd_dict={}
num_bkgnd=2000

for label in label_list:
    fname='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/128_square/dataset_5_4univ_cgan/norm_1_sig_{0}_train_val.npy'.format(sigma_list[label])
    samples=np.load(fname,mmap_mode='r+')[-num_bkgnd:][:,0,:,:]
    bkgnd_dict[str(label)]=samples

In [18]:
step_list=[40800]
labels_list=[0,1,2,3]
for step in step_list:
    for label in labels_list:
        fname='*gen_img_label-{1}_epoch-*_step-{0}.npy'.format(step,label)
        fle=glob.glob(result_dir+'/images/'+fname)[0]
        imgs=np.load(fle)
    #     img_list.append(imgs)
    #     labels_list.append('Step:',step)
        strg='label:%s_step:%s'%(label,step)
        dict_samples[strg]=imgs
    

In [19]:
def f_widget_compare(sample_names,sample_dict,label,bkgnd_dict,Fig_type='pixel',rescale=True,log_scale=True,bins=25,mode='avg',normalize=True,crop=False):
    '''
    Module to make widget plots for pixel intensity or spectrum comparison for multiple sample sets
    '''
    assert Fig_type in ['pixel','spectrum','grid'],"Invalid mode %s"%(mode)
    
    bkgnd_arr=bkgnd_dict[str(label)]
    if crop: # Crop out large pixel values
        for key in sample_names:
            print(sample_dict[key].shape)
            sample_dict[key]=np.array([arr for arr in sample_dict[key] if np.max(arr)<=0.994])
            print(sample_dict[key].shape)
    
    img_list=[sample_dict[key] for key in sample_names]
    label_list=list(sample_names)
        
    hist_range=(0,0.996)
    
    if rescale: 
        for count,img in enumerate(img_list):
            img_list[count]=f_invtransform(img)
        hist_range=(0,2000)
        bkgnd_arr=f_invtransform(bkgnd_arr)
    
    if Fig_type=='pixel':
        f_compare_pixel_intensity(img_lst=img_list,label_lst=label_list,bkgnd_arr=bkgnd_arr,normalize=normalize,log_scale=log_scale, mode=mode,bins=bins,hist_range=hist_range)
        plt.xlim(0,1500)
    elif Fig_type=='spectrum':
        f_compare_spectrum(img_lst=img_list,label_lst=label_list,bkgnd_arr=bkgnd_arr,log_scale=log_scale)
    elif Fig_type=='grid':
        for key in label_list:
            f_plot_grid(dict_samples[key][40:58],cols=6,fig_size=(6,3))


In [20]:
# dict1=dict_samples
# dict1.update(bkgnd_dict)

In [21]:
bkgnd=samples

bins=np.concatenate([np.array([-0.5]),np.arange(0.5,20.5,1),np.arange(20.5,100.5,5),np.arange(100.5,1000.5,50),np.array([2000])]) #bin edges to use
# bins=f_transform(bins)   ### scale to (-1,1) 
# bins=100
interact_manual(f_widget_compare,sample_dict=fixed(dict_samples),
                sample_names=SelectMultiple(options=dict_samples.keys()),label=label_list,
                bkgnd_dict=fixed(bkgnd_dict),
                Fig_type=ToggleButtons(options=['pixel','spectrum','grid']),bins=fixed(bins),mode=['avg','simple'])

interactive(children=(SelectMultiple(description='sample_names', options=('label:0_step:40800', 'label:1_step:…

<function __main__.f_widget_compare(sample_names, sample_dict, label, bkgnd_dict, Fig_type='pixel', rescale=True, log_scale=True, bins=25, mode='avg', normalize=True, crop=False)>

## Plot images

In [None]:
# f_plot_grid(bkgnd[100:118],cols=6,fig_size=(6,3))
f_plot_grid(dict_samples['label:3_step:40800'][40:58],cols=6,fig_size=(6,3))

In [23]:
dict_samples.keys()

dict_keys(['label:0_step:40800', 'label:1_step:40800', 'label:2_step:40800', 'label:3_step:40800'])

In [27]:
img=dict_samples['label:0_step:40800']
bkgnd_img=bkgnd_dict['0']

In [29]:
f_pixel_intensity(img,bins=50,label='validation',mode='avg',normalize=False,log_scale=True,plot=True, hist_range=None)
f_pixel_intensity(bkgnd_img,bins=50,label='validation',mode='avg',normalize=False,log_scale=True,plot=True, hist_range=None)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

(array([1.9136950e+02, 9.9779850e+02, 1.8696150e+03, 1.8893970e+03,
        1.4258565e+03, 1.1977155e+03, 1.3601545e+03, 1.3363435e+03,
        1.0688030e+03, 7.9219850e+02, 6.3844050e+02, 5.4602950e+02,
        4.5231900e+02, 3.6659100e+02, 3.0921750e+02, 2.6155500e+02,
        2.2118350e+02, 1.8883450e+02, 1.6224450e+02, 1.3967000e+02,
        1.2106750e+02, 1.0571700e+02, 9.1712500e+01, 8.0341000e+01,
        7.1059000e+01, 6.2470000e+01, 5.4669000e+01, 4.8535000e+01,
        4.2796000e+01, 3.7751000e+01, 3.3472000e+01, 2.9657000e+01,
        2.6354000e+01, 2.3410500e+01, 2.0811500e+01, 1.8300000e+01,
        1.6251000e+01, 1.4422000e+01, 1.2442000e+01, 1.0896500e+01,
        9.5050000e+00, 8.2150000e+00, 7.1485000e+00, 6.0400000e+00,
        4.9315000e+00, 3.9805000e+00, 3.0995000e+00, 2.1330000e+00,
        1.1805000e+00, 2.9550000e-01]),
 array([1.1898544 , 3.58897673, 4.05028541, 2.32789093, 1.26038147,
        0.89202846, 1.11004406, 1.06009658, 1.12010539, 1.08693884,
        

In [30]:
f_compare_pixel_intensity(img_lst=[img],label_lst=['a'],bkgnd_arr=bkgnd_dict['0'],normalize=True,log_scale=False, mode='avg',bins=50)
# f_compare_pixel_intensity(img_lst=[img],label_lst=['a'],bkgnd_arr=bkgnd_dict['0'],normalize=True,log_scale=False, mode='avg',bins=50)




Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

  hist_arr=np.array([np.histogram(arr.flatten(), bins=bins, range=(llim,ulim), density=norm) for arr in img_arr]) ## range is important
  hist_arr=np.array([np.histogram(arr.flatten(), bins=bins, range=(llim,ulim), density=norm) for arr in img_arr]) ## range is important
