# Analyze results with histograms and plots made from scratch
October 27, 2020

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import subprocess as sp
import sys
import os
import glob
import pickle 

from matplotlib.colors import LogNorm, PowerNorm, Normalize
from ipywidgets import *

In [2]:
%matplotlib widget

In [3]:
sys.path.append('/global/u1/v/vpa/project/jpt_notebooks/Cosmology/Cosmo_GAN/repositories/cosmogan_pytorch/code/modules_image_analysis/')
from modules_img_analysis import *

In [6]:
### Transformation functions for image pixel values
def f_transform(x):
    return 2.*x/(x + 4.) - 1.

def f_invtransform(s):
    return 4.*(1. + s)/(1. - s)


## Read data

In [7]:
# main_dir='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/pytorch/results/128sq/'
# results_dir=main_dir+'20201002_064327'

In [8]:
dict1={'128':'/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/pytorch/results/128sq/',
      '512':'/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/pytorch/results/512sq/'}

u=interactive(lambda x: dict1[x], x=Select(options=dict1.keys()))
# display(u)


In [9]:
# parent_dir=u.result
parent_dir=dict1['128']
dir_lst=[i.split('/')[-1] for i in glob.glob(parent_dir+'20*')]
w=interactive(lambda x: x, x=Dropdown(options=dir_lst))
display(w)

interactive(children=(Dropdown(description='x', options=('20210422_122718_bs128_lr0.0002_nodes4', '20210426_15…

In [10]:
result=w.result
result_dir=parent_dir+result
print(result_dir)


/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/pytorch/results/128sq/20210426_152318_bs64_lr0.004_nodes4


## Plot Losses

In [11]:
df_metrics=pd.read_pickle(result_dir+'/df_metrics.pkle').astype(np.float64)

In [12]:
df_metrics.step.values.shape,df_metrics.G_full.values.shape

((9101,), (9101,))

In [13]:
df_metrics.head(10)

Unnamed: 0,step,epoch,Dreal,Dfake,Dfull,G_adv,G_full,spec_loss,hist_loss,spec_chi,hist_chi,gp_loss,fm_loss,D(x),D_G_z1,D_G_z2,time
0,0.0,0.0,0.751897,0.648592,1.400488,17.208977,103.863686,86.654709,1.739667,,,,,-0.096876,-0.108521,-17.208977,9.390066
1,1.0,0.0,0.30362,4.714922,5.018542,60.976093,inf,inf,1.578732,,3.952453,,,20.664104,4.786402,-60.976093,0.826967
2,2.0,0.0,0.978796,0.874735,1.853531,41.75872,151.234131,109.475403,1.627122,,3.348714,,,7.584488,-54.740135,-41.75872,0.569624
3,3.0,0.0,0.791863,53.749104,54.540966,39.112743,126.128311,87.015564,1.502819,95.93573,1.347393,,,42.349342,54.597664,-39.112743,0.570859
4,4.0,0.0,23.686087,0.412406,24.098494,2.2e-05,87.292374,87.292351,1.157086,86.956802,0.84455,,,-24.052135,-28.130615,18.958542,0.572721
5,5.0,0.0,1.030887,24.931623,25.962511,25.755295,113.052917,87.297623,0.774959,87.198143,0.532577,,,8.728026,25.339025,-25.755295,0.570574
6,6.0,0.0,7.439674,0.061945,7.501619,0.065382,96.549355,96.483971,0.985109,103.531616,1.2036,,,-7.580256,-9.98056,13.657116,0.570932
7,7.0,0.0,0.202389,15.220135,15.422523,32.878639,119.753365,86.874725,1.14379,86.765869,1.175164,,,14.271322,15.665222,-32.878639,0.574702
8,8.0,0.0,1.279805,0.463489,1.743294,31.066399,118.145065,87.078667,1.044349,87.050423,1.094693,,,7.463947,-20.486927,-31.066357,0.591222
9,9.0,0.0,1.159972,14.179234,15.339205,66.740944,153.089813,86.348869,0.852234,86.926758,0.943685,,,8.89714,14.095462,-66.740944,0.570953


In [14]:
df_metrics[df_metrics.step%50==0].head(10)

Unnamed: 0,step,epoch,Dreal,Dfake,Dfull,G_adv,G_full,spec_loss,hist_loss,spec_chi,hist_chi,gp_loss,fm_loss,D(x),D_G_z1,D_G_z2,time
0,0.0,0.0,0.751897,0.648592,1.400488,17.208977,103.863686,86.654709,1.739667,,,,,-0.096876,-0.108521,-17.208977,9.390066
50,50.0,0.0,0.29146,7.306765,7.598224,18.856049,105.540604,86.684555,1.307298,86.645187,1.366883,,,16.213894,7.41024,-18.856049,0.570443
100,100.0,1.0,14.874335,0.472264,15.3466,18.901756,104.963417,86.061661,1.477945,86.092682,1.830539,,,-14.665679,-27.193026,-18.773832,0.571011
150,150.0,2.0,0.414515,4.84883,5.263345,16.009462,102.12764,86.118179,2.567076,86.255249,2.786861,,,9.862619,4.630409,-16.009218,0.571313
200,200.0,3.0,1.309308,0.245749,1.555058,9.700077,96.883286,87.183212,1.864866,87.112305,1.847177,,,0.39636,-14.024017,-9.697426,0.571914
250,250.0,3.0,0.083656,1.28722,1.370876,4.947682,90.257057,85.309372,2.581921,85.06517,2.500983,,,8.162456,-0.203456,-4.81082,0.583194
300,300.0,4.0,0.225797,0.127405,0.353202,3.122599,88.156876,85.034279,2.176003,87.707848,2.031773,,,3.016414,-2.822923,-3.047012,0.572028
350,350.0,5.0,0.09008,0.095891,0.185971,5.180593,91.926369,86.745773,1.888858,86.742302,1.852068,,,4.31534,-3.505306,-5.170539,0.573107
400,400.0,6.0,0.723791,0.304956,1.028748,1.74502,88.513947,86.768929,1.988866,86.809692,2.068506,,,0.282616,-1.456361,-1.459283,0.57246
450,450.0,6.0,0.145443,0.190148,0.335591,3.151639,87.634117,84.482475,2.552919,84.69915,2.293831,,,2.752791,-2.525089,-3.082799,0.572256


In [15]:
def f_plot_metrics(df,col_list):
    
    plt.figure()
    for key in col_list:
        plt.plot(df_metrics[key],label=key,marker='*',linestyle='')
    plt.legend()
    
#     col_list=list(col_list)
#     df.plot(kind='line',x='step',y=col_list)
    
# f_plot_metrics(df_metrics,['spec_chi','hist_chi'])

interact_manual(f_plot_metrics,df=fixed(df_metrics), col_list=SelectMultiple(options=df_metrics.columns.values))

interactive(children=(SelectMultiple(description='col_list', options=('step', 'epoch', 'Dreal', 'Dfake', 'Dful…

<function __main__.f_plot_metrics(df, col_list)>

In [16]:
chi=df_metrics.dropna().quantile(q=0.05,axis=0)['hist_chi']
print(chi)
df_metrics[df_metrics['hist_chi']<=chi].sort_values(by=['hist_chi']).head(10)

nan


Unnamed: 0,step,epoch,Dreal,Dfake,Dfull,G_adv,G_full,spec_loss,hist_loss,spec_chi,hist_chi,gp_loss,fm_loss,D(x),D_G_z1,D_G_z2,time


In [17]:
display(df_metrics.sort_values(by=['hist_chi']).head(8))
display(df_metrics.sort_values(by=['spec_chi']).head(8))

Unnamed: 0,step,epoch,Dreal,Dfake,Dfull,G_adv,G_full,spec_loss,hist_loss,spec_chi,hist_chi,gp_loss,fm_loss,D(x),D_G_z1,D_G_z2,time
6632,6632.0,102.0,0.087768,0.470205,0.557972,6.474729,92.738846,86.264114,-1.899475,86.199135,-1.921451,,,5.610585,-1.461596,-6.471198,0.571661
9073,9073.0,139.0,0.102981,0.173592,0.276574,5.667845,91.409142,85.741295,-1.576246,85.882576,-1.868022,,,4.013289,-3.638928,-5.661482,0.577638
6633,6633.0,102.0,0.152671,0.097149,0.24982,5.670695,91.776039,86.105347,-1.837071,86.215729,-1.852083,,,3.527763,-3.922218,-5.648987,0.569683
6636,6636.0,102.0,0.086279,0.204076,0.290355,5.626267,91.639824,86.013557,-1.895997,85.513786,-1.827704,,,4.888808,-2.85206,-5.610027,0.56998
6640,6640.0,102.0,0.122821,0.141854,0.264675,4.281014,89.791313,85.5103,-1.828347,85.578041,-1.798325,,,3.11469,-3.159945,-4.247231,0.569897
6626,6626.0,101.0,0.157181,0.095641,0.252822,5.240767,88.414749,83.173981,-1.702831,82.219383,-1.797341,,,3.253948,-5.75547,-5.2172,0.573421
6641,6641.0,102.0,0.21359,0.174547,0.388137,4.379724,89.765953,85.38623,-1.743866,85.378059,-1.796956,,,2.923847,-3.543823,-4.354399,0.570502
6637,6637.0,102.0,0.127824,0.143183,0.271007,4.673038,90.526566,85.853531,-1.826571,85.189117,-1.79579,,,3.180319,-4.770911,-4.651279,0.5694


Unnamed: 0,step,epoch,Dreal,Dfake,Dfull,G_adv,G_full,spec_loss,hist_loss,spec_chi,hist_chi,gp_loss,fm_loss,D(x),D_G_z1,D_G_z2,time
9100,,,,,,,,,,74.77607,-0.660791,,,,,,
6741,6741.0,103.0,0.115843,0.080065,0.195908,5.795529,87.336227,81.540695,-0.614875,75.350883,-0.887191,,,3.848163,-5.83242,-5.787385,0.573493
9029,9029.0,138.0,0.101974,0.119474,0.221448,5.628984,88.211189,82.582207,-0.82636,75.712143,-0.744714,,,4.093917,-3.367816,-5.602524,0.625813
6496,6496.0,99.0,0.067673,0.121554,0.189228,9.495139,90.492851,80.997711,0.274495,76.342667,-0.020472,,,5.117194,-7.163622,-9.494532,0.571131
7958,7958.0,122.0,0.115167,0.111931,0.227098,5.526088,86.673988,81.147903,-0.717747,76.351166,-0.794518,,,3.274301,-5.396588,-5.517005,0.588912
6495,6495.0,99.0,0.145468,0.967947,1.113415,8.745124,89.869179,81.124054,-0.038785,76.399269,-0.058288,,,7.551452,-0.183608,-8.743879,0.569442
9088,9088.0,139.0,0.112672,0.132351,0.245024,4.826756,83.269814,78.443054,-0.437219,76.586319,-0.616009,,,4.424865,-3.795499,-4.806999,0.571102
7990,7990.0,122.0,0.090425,0.176208,0.266633,5.524712,84.159325,78.634613,-0.64618,76.682259,-0.488107,,,5.008602,-2.707216,-5.516378,0.571079


## Plot

In [19]:
dict_samples={}

In [20]:
sigma_list=[0.5,0.65,0.8,1.1]
label_list=[0,1,2,3]
bkgnd_dict={}
num_bkgnd=2000

for label in label_list:
    fname='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/128_square/dataset_5_4univ_cgan/norm_1_sig_{0}_train_val.npy'.format(sigma_list[label])
    samples=np.load(fname,mmap_mode='r+')[-num_bkgnd:][:,0,:,:]
    bkgnd_dict[str(label)]=samples

In [21]:
fname='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/pytorch/results/128sq/20210106_063353_cgan_b_sigmafloat_3classes_m2/images/inference_spec_epoch-11_step-27150_label-1.npy'
imgs=np.load(fname)
print(imgs.shape)
dict_samples['label:1_best_spec_0.65']=imgs

FileNotFoundError: [Errno 2] No such file or directory: '/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/pytorch/results/128sq/20210106_063353_cgan_b_sigmafloat_3classes_m2/images/inference_spec_epoch-11_step-27150_label-1.npy'

In [22]:
# bkgnd_dict.keys()

In [23]:
step_list=[27150]
labels_list=[0,1,2]
for step in step_list:
    for label in labels_list:
        fname='*gen_img_label-{1}_epoch-*_step-{0}.npy'.format(step,label)
        fle=glob.glob(result_dir+'/images/'+fname)[0]
        imgs=np.load(fle)
    #     img_list.append(imgs)
    #     labels_list.append('Step:',step)
        strg='label:%s_step:%s'%(label,step)
        dict_samples[strg]=imgs
    

IndexError: list index out of range

In [24]:
def f_widget_compare(sample_names,sample_dict,label,bkgnd_dict,Fig_type='pixel',rescale=True,log_scale=True,bins=25,mode='avg',normalize=True,crop=False):
    '''
    Module to make widget plots for pixel intensity or spectrum comparison for multiple sample sets
    '''
    assert Fig_type in ['pixel','spectrum','grid'],"Invalid mode %s"%(mode)
    
    bkgnd_arr=bkgnd_dict[str(label)]
    if crop: # Crop out large pixel values
        for key in sample_names:
            print(sample_dict[key].shape)
            sample_dict[key]=np.array([arr for arr in sample_dict[key] if np.max(arr)<=0.994])
            print(sample_dict[key].shape)
    
    img_list=[sample_dict[key] for key in sample_names]
    label_list=list(sample_names)
        
    hist_range=(0,0.996)
    
    if rescale: 
        for count,img in enumerate(img_list):
            img_list[count]=f_invtransform(img)
        hist_range=(0,2000)
        bkgnd_arr=f_invtransform(bkgnd_arr)
    
    if Fig_type=='pixel':
        f_compare_pixel_intensity(img_lst=img_list,label_lst=label_list,bkgnd_arr=bkgnd_arr,normalize=normalize,log_scale=log_scale, mode=mode,bins=bins,hist_range=hist_range)
        plt.xlim(0,1500)
    elif Fig_type=='spectrum':
        f_compare_spectrum(img_lst=img_list,label_lst=label_list,bkgnd_arr=bkgnd_arr,log_scale=log_scale)
    elif Fig_type=='grid':
        for key in label_list:
            f_plot_grid(dict_samples[key][40:58],cols=6,fig_size=(6,3))


In [25]:
# dict1=dict_samples
# dict1.update(bkgnd_dict)

In [26]:
bkgnd=samples

bins=np.concatenate([np.array([-0.5]),np.arange(0.5,20.5,1),np.arange(20.5,100.5,5),np.arange(100.5,1000.5,50),np.array([2000])]) #bin edges to use
# bins=f_transform(bins)   ### scale to (-1,1) 
# bins=100
interact_manual(f_widget_compare,sample_dict=fixed(dict_samples),
                sample_names=SelectMultiple(options=dict_samples.keys()),label=label_list,
                bkgnd_dict=fixed(bkgnd_dict),
                Fig_type=ToggleButtons(options=['pixel','spectrum','grid']),bins=fixed(bins),mode=['avg','simple'])

interactive(children=(SelectMultiple(description='sample_names', options=(), value=()), Dropdown(description='…

<function __main__.f_widget_compare(sample_names, sample_dict, label, bkgnd_dict, Fig_type='pixel', rescale=True, log_scale=True, bins=25, mode='avg', normalize=True, crop=False)>

## Plot images

In [None]:
# f_plot_grid(bkgnd[100:118],cols=6,fig_size=(6,3))
f_plot_grid(dict_samples['label:1_best_spec_0.65'][40:58],cols=6,fig_size=(6,3))

## Test directly

In [None]:
dict_samples.keys()

In [None]:
img=dict_samples['label:1_best_spec_0.65']
bkgnd_img=bkgnd_dict['1']

In [None]:
f_pixel_intensity(img,bins=50,label='validation',mode='avg',normalize=False,log_scale=True,plot=True, hist_range=None)
f_pixel_intensity(bkgnd_img,bins=50,label='validation',mode='avg',normalize=False,log_scale=True,plot=True, hist_range=None)

In [None]:
f_compare_pixel_intensity(img_lst=[img],label_lst=['a'],bkgnd_arr=bkgnd_dict['1'],normalize=True,log_scale=True, mode='avg',bins=50)
# f_compare_pixel_intensity(img_lst=[img],label_lst=['a'],bkgnd_arr=bkgnd_dict['0'],normalize=True,log_scale=False, mode='avg',bins=50)



In [None]:
img_list=[img,bkgnd_dict['0'],bkgnd_dict['1'],bkgnd_dict['2'],bkgnd_dict['3']]
f_compare_spectrum(img_lst=img_list,label_lst=['inf','bk0','bk1','bk2','bk3'],bkgnd_arr=bkgnd_dict['1'],log_scale=True)

