# Analyze results for 3D CGAN
Feb 22, 2021

In [64]:
import     numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import subprocess as sp
import sys
import os
import glob
import pickle 

from matplotlib.colors import LogNorm, PowerNorm, Normalize
import seaborn as sns
from functools import reduce

In [65]:
from ipywidgets import *

In [66]:
%matplotlib widget

In [67]:
sys.path.append('/global/u1/v/vpa/project/jpt_notebooks/Cosmology/Cosmo_GAN/repositories/cosmogan_pytorch/code/modules_image_analysis/')
from modules_img_analysis import *

In [68]:
sys.path.append('/global/u1/v/vpa/project/jpt_notebooks/Cosmology/Cosmo_GAN/repositories/cosmogan_pytorch/code/5_3d_cgan/1_main_code/')
import post_analysis_pandas as post


In [69]:
### Transformation functions for image pixel values
def f_transform(x):
    '''Map raw pixel values x (>= 0) onto [-1, 1): s = 2x/(x+4) - 1.'''
    scaled = 2. * x / (x + 4.)
    return scaled - 1.

def f_invtransform(s):
    '''Inverse of f_transform: x = 4(1+s)/(1-s).'''
    numerator = 4. * (1. + s)
    return numerator / (1. - s)


### Read validation data

In [70]:
### Histogram bin edges in raw pixel units, then mapped to (-1,1) like the images
bin_edges_raw=np.concatenate([
    np.array([-0.5]),
    np.arange(0.5,100.5,5),
    np.arange(100.5,300.5,20),
    np.arange(300.5,1000.5,50),
    np.array([2000]),
])
bins=f_transform(bin_edges_raw)

### Validation ("background") statistics, one entry per sigma value
VAL_FNAME_TEMPLATE='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/3d_data/dataset2a_3dcgan_4univs_64cube_simple_splicing/norm_1_sig_{0}_train_val.npy'
sigma_lst=[0.5,0.65,0.8,1.1]
labels_lst=range(len(sigma_lst))
num_bkgnd=200   # number of validation cubes taken from the end of each file
bkgnd_dict={}

for label in labels_lst:
    sigma=sigma_lst[label]
    fname=VAL_FNAME_TEMPLATE.format(sigma)
    print(fname)
    # memory-mapped so the whole file is not loaded; index 0 drops the size-1 channel axis
    samples=np.load(fname,mmap_mode='r')[-num_bkgnd:][:,0,:,:]
    bkgnd_dict[str(sigma)]=post.f_compute_hist_spect(samples,bins)
    del samples

/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/3d_data/dataset2a_3dcgan_4univs_64cube_simple_splicing/norm_1_sig_0.5_train_val.npy
/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/3d_data/dataset2a_3dcgan_4univs_64cube_simple_splicing/norm_1_sig_0.65_train_val.npy
/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/3d_data/dataset2a_3dcgan_4univs_64cube_simple_splicing/norm_1_sig_0.8_train_val.npy
/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/3d_data/dataset2a_3dcgan_4univs_64cube_simple_splicing/norm_1_sig_1.1_train_val.npy


## Read data

In [71]:
# main_dir='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/pytorch/results/128sq/'
# results_dir=main_dir+'20201002_064327'

In [72]:
### Candidate results folders keyed by image size
### NOTE(review): both keys currently point to the same folder
results_root='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/pytorch/results/3d_cGAN/'
dict1={size:results_root for size in ('64','512')}

u=interactive(lambda x: dict1[x], x=Select(options=dict1.keys()))


In [73]:
# Alternatively take the folder from the size selector above: parent_dir=u.result
parent_dir=dict1['64']
# glob returns entries in arbitrary filesystem order; sort so the dropdown order
# (and hence its default selection) is reproducible
dir_lst=sorted(os.path.basename(path) for path in glob.glob(parent_dir+'202102*'))
w=interactive(lambda x: x, x=Dropdown(options=dir_lst))
display(w)

interactive(children=(Dropdown(description='x', options=('20210223_210217_3dcgan_predict_0.8_m2', '20210225_20…

In [74]:
# Run folder picked in the dropdown above (widget state: re-select after a kernel restart)
result=w.result
result_dir=''.join([parent_dir,result])
print(result_dir)

/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/pytorch/results/3d_cGAN/20210227_050213_3dcgan_predict_0.65_m2


## Plot Losses

In [75]:
metrics_file=result_dir+'/df_metrics.pkle'
# Training log of the run, cast to float64 (read_pickle unpickles: only use on trusted files)
df_metrics=pd.read_pickle(metrics_file).astype(np.float64)


In [76]:
# Last 10 logged rows
df_metrics.iloc[-10:]

Unnamed: 0,step,epoch,Dreal,Dfake,Dfull,G_adv,G_full,spec_loss,hist_loss,spec_chi,hist_chi,D(x),D_G_z1,D_G_z2,time
67341,67341.0,49.0,0.156118,0.148114,0.304232,3.482937,111.945435,108.462494,-0.442317,108.139931,-0.488237,3.793265,-3.367856,-3.450812,0.518365
67342,67342.0,49.0,0.109853,0.132423,0.242277,3.965339,112.274101,108.308762,-0.261318,107.195915,-0.814009,3.696232,-3.085463,-3.945446,0.51671
67343,67343.0,49.0,0.138316,0.141367,0.279684,3.926706,110.151787,106.225082,-0.38567,106.280052,-0.543033,3.453577,-3.307747,-3.906112,0.497142
67344,67344.0,49.0,0.128451,0.142791,0.271242,2.863164,109.803337,106.94017,-0.453996,106.952103,-0.521182,2.970235,-3.632308,-2.799794,0.499532
67345,67345.0,49.0,0.126615,0.181449,0.308064,5.506667,111.32933,105.822662,-0.117044,107.134361,0.063623,4.096053,-2.325142,-5.502391,0.542141
67346,67346.0,49.0,0.157101,0.156469,0.31357,4.87049,112.470329,107.599838,-0.321778,107.913025,0.15602,3.298939,-5.057765,-4.861575,0.504071
67347,67347.0,49.0,0.132959,0.142505,0.275464,3.220198,112.018547,108.798347,-0.321858,107.663956,0.081704,3.699559,-4.06784,-3.173584,0.525678
67348,67348.0,49.0,0.155534,0.177702,0.333236,5.634856,114.666641,109.031784,-0.116069,106.397987,0.064157,4.152534,-2.328207,-5.630915,0.492672
67349,67349.0,49.0,0.106288,0.170729,0.277016,5.575987,113.576813,108.000824,-0.425292,107.321632,-0.888086,3.468798,-5.436711,-5.571213,0.505218
67350,,,,,,,,,,109.91964,-0.773152,,,,


In [77]:
def f_plot_metrics(df,col_list):
    '''Plot selected metric columns of ``df`` as marker-only series against the dataframe index.

    Parameters
    ----------
    df : dataframe of training metrics (e.g. ``df_metrics``).
    col_list : iterable of column names of ``df`` to plot.
    '''
    if len(col_list)==0:   # nothing selected yet: skip the empty figure and the legend warning
        return
    _,ax=plt.subplots()
    for key in col_list:
        # read from the ``df`` argument, not from a notebook global
        ax.plot(df[key],label=key,marker='*',linestyle='')
    ax.set_xlabel(df.index.name if df.index.name else 'index')
    ax.set_title('Training metrics')
    ax.legend()

interact_manual(f_plot_metrics,df=fixed(df_metrics), col_list=SelectMultiple(options=df_metrics.columns.values));

interactive(children=(SelectMultiple(description='col_list', options=('step', 'epoch', 'Dreal', 'Dfake', 'Dful…

<function __main__.f_plot_metrics(df, col_list)>

In [78]:

# 20th percentile of the histogram chi-square over all logged rows
chi=df_metrics['hist_chi'].quantile(q=0.2)
print(chi)
# Rows at or below that percentile, excluding epochs 0 and 1, best first
low_chi_mask=(df_metrics['hist_chi']<=chi)&(df_metrics['epoch']>1)
df_metrics[low_chi_mask].sort_values(by=['hist_chi']).head(10)

-1.3490444421768188


Unnamed: 0,step,epoch,Dreal,Dfake,Dfull,G_adv,G_full,spec_loss,hist_loss,spec_chi,hist_chi,D(x),D_G_z1,D_G_z2,time
57761,57761.0,42.0,0.148994,0.137105,0.286099,3.992824,109.888618,105.895798,-1.713227,108.614792,-2.577394,3.096658,-2.991325,-3.972713,0.462246
51767,51767.0,38.0,0.139329,0.132224,0.271553,4.056891,112.532043,108.475151,-1.925579,109.068878,-2.559478,3.740458,-3.400756,-4.037302,0.469805
61628,61628.0,45.0,0.166746,0.16629,0.333036,4.223664,112.337486,108.113823,-2.153617,109.370346,-2.55757,3.433569,-2.894193,-4.207905,0.541233
61611,61611.0,45.0,0.142531,0.186677,0.329209,6.068652,115.103165,109.034515,-2.125532,109.564293,-2.522994,4.246599,-2.061249,-6.066096,0.549664
51769,51769.0,38.0,0.25996,0.132622,0.392583,3.77737,110.760124,106.982758,-1.841927,108.812904,-2.52001,3.191755,-3.325062,-3.745019,0.467042
41963,41963.0,31.0,0.110635,0.118655,0.229291,4.670769,113.709496,109.038727,-0.720874,109.528412,-2.518847,4.082439,-3.374731,-4.658276,0.53596
52354,52354.0,38.0,0.217338,0.153826,0.371164,3.185548,110.476944,107.291397,-1.575383,109.051994,-2.518464,2.943043,-3.885588,-3.139442,0.479537
53374,53374.0,39.0,0.162265,0.154714,0.316978,3.847303,112.276215,108.428909,-1.077739,109.332817,-2.512703,3.287806,-3.534106,-3.823246,0.475777
51896,51896.0,38.0,0.152291,0.133256,0.285547,3.967165,111.075348,107.108185,-1.994942,108.432281,-2.510559,3.542818,-3.589303,-3.94611,0.482049
61627,61627.0,45.0,0.134835,0.138505,0.27334,3.853326,113.611565,109.75824,-2.099251,109.541901,-2.505042,3.507229,-3.388855,-3.82971,0.513114


In [79]:
# display(df_metrics.sort_values(by=['hist_chi']).head(8))
# display(df_metrics.sort_values(by=['spec_chi']).head(8))

<!-- ### Read validation data -->

## Read stored chi-squares for images

In [82]:
## Sigma labels are encoded in the saved file names: df_processed_<sigma>.pkle
flist=glob.glob(result_dir+'/df_processed*')
### Sorting is important for labels to match !!
sigma_lst=sorted(os.path.basename(fpath).split('df_processed_')[-1].split('.pkle')[0] for fpath in flist)

labels_lst=np.arange(len(sigma_lst))

In [83]:
# Sigma labels found on disk and their integer indices (used as column suffixes)
(sigma_lst, labels_lst)

(['0.5', '0.8', '1.1'], array([0, 1, 2]))

In [84]:
### Create a merged dataframe: one row per (epoch, step, img_type, label); every other
### column gets the suffix '_<idx>' of its sigma (e.g. 'chi_1_0' for sigma_lst[0]).
merge_keys=['epoch','step','img_type','label']

df_list=[]
for label in labels_lst:
    df=pd.read_pickle(result_dir+'/df_processed_{0}.pkle'.format(str(sigma_lst[label])))
    df[['epoch','step']]=df[['epoch','step']].astype(int)
    df['label']=df.epoch.astype(str)+'-'+df.step.astype(str) # Add label column for plotting
    suffix='_'+str(label)
    # suffix every column, then strip the suffix again from the join keys
    df_list.append(df.add_suffix(suffix).rename(columns={key+suffix:key for key in merge_keys}))

df_merged=reduce(lambda left,right: pd.merge(left,right,on=['step','epoch','img_type','label']),df_list)

### Sum each chi-square type over all sigma labels (one 'sum_<type>' column per type)
for chi_type in ['chi_1','chi_spec1','chi_1c']:
    per_label_cols=[chi_type+'_'+str(label) for label in labels_lst]
    df_merged['sum_'+chi_type]=df_merged[per_label_cols].sum(axis=1)
del df_list



In [94]:

def f_plot_hist_spec(df,param_labels,sigma_lst,steps_list,bkg_dict,plot_type):
    '''Overlay statistics of generated images at selected steps on the validation ("background") data.

    Parameters
    ----------
    df : merged dataframe; per-sigma columns carry the suffix '_<idx>', where idx is the
         position of the sigma label in ``sigma_lst`` (e.g. 'hist_val_0', 'spec_val_2').
    param_labels : iterable of sigma labels (entries of ``sigma_lst``) to plot.
    sigma_lst : ordered list of sigma labels used to build the column suffixes.
    steps_list : iterable of training steps to plot.
    bkg_dict : maps str(sigma label) -> dict with keys 'hist_bin_centers', 'hist_val',
               'hist_err', 'spec_val', 'spec_sdev' for the validation data.
    plot_type : 'hist' (pixel histogram), 'spec' (spectrum), 'spec_relative'
                (spectrum divided by the validation spectrum) or 'grid' (image grid).

    Notes
    -----
    Relies on the notebook globals ``num_bkgnd`` ('spec') and ``f_plot_grid`` ('grid').
    '''
    import itertools   # used to cycle markers; not imported at the top of the notebook

    img_size=64
    assert plot_type in ['hist','spec','grid','spec_relative'],"Invalid mode %s"%(plot_type)

    if plot_type in ['hist','spec','spec_relative']:
        plt.figure(figsize=(6,6))

    # The step filter does not depend on the sigma label, so apply it once.
    df=df[df.step.isin(steps_list)]

    for par_label in param_labels:
        idx=sigma_lst.index(par_label)
        suffix='_%s'%(idx)
        dict_bkg=bkg_dict[str(par_label)]

        for (_,row),marker in zip(df.iterrows(),itertools.cycle('>^*sDHPdpx_')):
            label=row.label+'_'+str(par_label)
            if plot_type=='hist':
                x1=f_invtransform(row['hist_bin_centers'+suffix])   # back to raw pixel units
                y1=row['hist_val'+suffix]
                yerr1=row['hist_err'+suffix]
                plt.errorbar(x1,y1,yerr1,marker=marker,markersize=5,linestyle='',label=label)

            elif plot_type=='spec':
                y2=row['spec_val'+suffix]
                # error on the mean spectrum: standard deviation / sqrt(number of images)
                yerr2=row['spec_sdev'+suffix]/np.sqrt(row['num_imgs'+suffix])
                x2=np.arange(len(y2))
                plt.fill_between(x2, y2 - yerr2, y2 + yerr2, alpha=0.4)
                plt.plot(x2, y2, marker=marker, linestyle=':',label=label)

            elif plot_type=='spec_relative':
                y2=row['spec_val'+suffix]
                yerr2=row['spec_sdev'+suffix]
                x2=np.arange(len(y2))
                ### Reference spectrum
                y1,yerr1=dict_bkg['spec_val'],dict_bkg['spec_sdev']
                y=y2/y1
                ## Variance is sum of variance of both variables, since they are uncorrelated
                # delta_r = |r| * sqrt((delta_a/a)^2 + (delta_b/b)^2) / sqrt(N)
                yerr=(np.abs(y))*np.sqrt((yerr1/y1)**2+(yerr2/y2)**2)/np.sqrt(row['num_imgs'+suffix])
                plt.fill_between(x2, y - yerr, y + yerr, alpha=0.4)
                plt.plot(x2, y, marker=marker, linestyle=':',label=label)

            elif plot_type=='grid':
                # After the merge every non-key column carries the sigma suffix, so the image
                # file name lives in 'fname'+suffix.
                # NOTE(review): assumes df_processed provides a 'fname' column.
                images=np.load(row['fname'+suffix])[:,:,:,0]
                print(images.shape)
                f_plot_grid(images[:18],cols=6,fig_size=(10,5))

        ### Plot reference data
        if plot_type=='hist':
            x,y,yerr=dict_bkg['hist_bin_centers'],dict_bkg['hist_val'],dict_bkg['hist_err']
            x=f_invtransform(x)
            plt.errorbar(x, y,yerr,color='k',linestyle='-',label='bkgnd')
            plt.title('Pixel Intensity Histogram')
            # 'linthreshx' was renamed 'linthresh' in matplotlib 3.3 and removed in 3.5
            try:
                plt.xscale('symlog',linthreshx=50)
            except TypeError:
                plt.xscale('symlog',linthresh=50)

        if plot_type=='spec':
            y,yerr=dict_bkg['spec_val'],dict_bkg['spec_sdev']/np.sqrt(num_bkgnd)
            x=np.arange(len(y))
            plt.fill_between(x, y - yerr, y + yerr, color='k',alpha=0.8)
            plt.title('Spectrum')
            plt.xlim(0,img_size/2)

        if plot_type=='spec_relative':
            plt.axhline(y=1.0,color='k',linestyle='-.')
            plt.title("Relative spectrum")
            plt.xlim(0,img_size/2)
            plt.ylim(0.5,2)

    if plot_type in ['hist','spec']:
        plt.yscale('log')
    # 'grid' draws its own figures and has no labelled artists to put in a legend
    if plot_type!='grid':
        plt.legend(bbox_to_anchor=(0.3, 0.75),ncol=2, fancybox=True, shadow=True,prop={'size':6})


# f_plot_hist_spec(df_merged,[sigma_lst[-1]],sigma_lst,[best_step[0]],bkgnd_dict,'hist')

### Slice best steps

In [95]:
def f_slice_merged_df(df,cutoff=0.2,sort_col='chi_1',col_mode='all',label='all',params_lst=[0,1,2],head=10,epoch_range=[0,None],use_sum=True,display_flag=False):
    ''' View dataframe after slicing
    '''

    if epoch_range[1]==None: epoch_range[1]=df.max()['epoch']
    df=df[(df.epoch<=epoch_range[1])&(df.epoch>=epoch_range[0])]

    ######### Apply cutoff to keep reasonable chi1 and chispec1
    #### Add chi-square columns to use
    chi_cols=[]
    if use_sum: ## Add sum chi-square columns
        for j in ['chi_1','chi_spec1','chi_1c']: chi_cols.append('sum_'+j)
        
    if label=='all': ### Add chi-squares for all labels
        for j in ['chi_1','chi_spec1','chi_1c']:
            for idx,i in enumerate(params_lst): chi_cols.append(j+'_'+str(idx))
    else: ## Add chi-square for specific label
        assert label in params_lst, "label %s is not in %s"%(label,params_lst)
        label_idx=params_lst.index(label)
        print(label_idx)
        for j in ['chi_1','chi_spec1','chi_1c']: chi_cols.append(j+'_'+str(label_idx))
#     print(chi_cols)
    q_dict=dict(df_merged.quantile(q=cutoff,axis=0)[chi_cols])
    # print(q_dict)
    strg=['%s < %s'%(key,q_dict[key]) for key in chi_cols ]
    query=" & ".join(strg)
    # print(query)
    df=df.query(query)
    
    # Sort dataframe
    df1=df[df.epoch>0].sort_values(by=sort_col)
    chis=[i for i in df_merged.columns if 'chi' in i]
    col_list=['label']+chis+['epoch','step']
    if (col_mode=='short'): 
        col_list=['label']+[i for i in df_merged.columns if i.startswith('sum')]
        col_list=['label']+chi_cols
    df2=df1.head(head)[col_list]
    
    if display_flag: display(df2) # Display df
    
    return df2

# f_slice_merged_df(df_merged,cutoff=0.3,sort_col='sum_chi_1',label=0.65,params_lst=[0.5,0.65,0.8,1.1],use_sum=True,head=2000,display_flag=False,epoch_range=[7,None])

In [96]:
### Widget to explore the best steps. All options are derived from df_merged
### (not from variables left over by earlier loops).
cols_to_sort=sorted({c for c in df_merged.columns if c.startswith(('chi_1_','chi_spec1_','sum'))})
max_epoch=int(df_merged.epoch.max())

w=interactive(f_slice_merged_df,df=fixed(df_merged),
              cutoff=widgets.FloatSlider(value=0.3, min=0, max=1.0, step=0.01),
              col_mode=['all','short'], display_flag=widgets.Checkbox(value=False),
              use_sum=widgets.Checkbox(value=True),
              label=ToggleButtons(options=['all']+sigma_lst), params_lst=fixed(sigma_lst),
              head=widgets.IntSlider(value=10,min=1,max=20,step=1),
              epoch_range=widgets.IntRangeSlider(value=[0,max_epoch],min=0,max=max_epoch,step=1),
              sort_col=cols_to_sort
              )
display(w)

interactive(children=(FloatSlider(value=0.3, description='cutoff', max=1.0, step=0.01), Dropdown(description='…

In [97]:
# Table currently produced by the slicing widget (depends on widget state, not reproducible from code alone)
sliced_view=w.result
df_sliced=sliced_view

In [108]:
### Best steps according to each summed chi-square: epochs >= 7, keeping only rows
### below the 0.9 quantile of every chi-square column.
best_step_queries=[('sum_chi_1',4),('sum_chi_spec1',4),('sum_chi_1c',2)]   # (sort column, rows kept)
best_step=[]
for sort_key,n_keep in best_step_queries:
    df_top=f_slice_merged_df(df_merged,cutoff=0.9,sort_col=sort_key,label='all',use_sum=True,head=n_keep,display_flag=False,epoch_range=[7,None],params_lst=sigma_lst)
    best_step.append(df_top.step.values)

best_step=np.unique([s for steps in best_step for s in steps])
print(best_step)
best_step

[27600 30950 42620 43830 50440 58320 58330 59500 64670 65600]


array([27600, 30950, 42620, 43830, 50440, 58320, 58330, 59500, 64670,
       65600])

In [109]:
# best_step=[6176]
# best_step= [32300, 35810, 36020, 37030, 38640, 42480, 43850]

# best_step=np.arange(40130,40135).astype(int)

In [110]:
# Rows of the merged table for the selected best steps
df_best=df_merged.loc[df_merged.step.isin(best_step)]
print(df_best.shape)
epoch_step_pairs=[]
for s in best_step:
    rows_for_step=df_best[df_best.step==s]
    epoch_step_pairs.append((rows_for_step.epoch.values[0],rows_for_step.step.values[0]))
print(epoch_step_pairs)

(10, 61)
[(20, 27600), (22, 30950), (31, 42620), (32, 43830), (37, 50440), (43, 58320), (43, 58330), (44, 59500), (48, 64670), (48, 65600)]


In [111]:
# Summed chi-squares for the best steps
sum_cols=df_merged.filter(regex='^sum').columns.tolist()
col_list=['label']+sum_cols
df_best[col_list]

Unnamed: 0,label,sum_chi_1,sum_chi_spec1,sum_chi_1c
2759,20-27600,0.254456,22.434903,0.002421
3094,22-30950,0.088541,20.0761,0.0463
4262,31-42620,0.050784,19.675746,0.005717
4383,32-43830,0.081891,18.40763,0.00542
5044,37-50440,0.058981,21.908821,0.008988
5833,43-58320,0.20654,12.504562,0.020564
5834,43-58330,0.29811,13.698965,0.007691
5951,44-59500,0.604574,13.098999,0.015755
6468,48-64670,0.486883,16.524098,0.002573
6561,48-65600,0.681249,13.732368,0.302913


### Interactive plot

In [112]:
### Interactive comparison of generated vs. validation statistics for the best steps
plot_widgets=dict(param_labels=SelectMultiple(options=sigma_lst),
                  steps_list=SelectMultiple(options=best_step),
                  plot_type=ToggleButtons(options=['hist','spec','grid','spec_relative']))
interact_manual(f_plot_hist_spec,
                df=fixed(df_merged),
                sigma_lst=fixed(sigma_lst),
                bkg_dict=fixed(bkgnd_dict),
                **plot_widgets)

interactive(children=(SelectMultiple(description='param_labels', options=('0.5', '0.8', '1.1'), value=()), Sel…

<function __main__.f_plot_hist_spec(df, param_labels, sigma_lst, steps_list, bkg_dict, plot_type)>

In [None]:
# ### Check deterministic
# main_dir='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/pytorch/results/128sq/'
# epoch,step=0,230
# f1='20210113_185116_cgan_predict_0.65_m2/images/gen_img_label-0.5_epoch-{0}_step-{1}.npy'.format(epoch,step)
# f2='20210113_092234_cgan_predict_0.65_m2/images/gen_img_label-0.5_epoch-{0}_step-{1}.npy'.format(epoch,step)

# epoch,step=3,20
# f1='20210114_191648_nb_test/images/gen_img_label-0.5_epoch-{0}_step-{1}.npy'.format(epoch,step)
# f2='20210114_193009_nb_test/images/gen_img_label-0.5_epoch-{0}_step-{1}.npy'.format(epoch,step)

# a1=np.load(main_dir+f1)
# a2=np.load(main_dir+f2)
# # print(a1.shape,a2.shape)

# print(np.mean(a1),np.mean(a2))
# print(np.max(a1),np.max(a2))

### Delete unwanted stored models
(Deterministic runs are not working well, so only the checkpoints of the best steps are kept.)

In [113]:
# Model checkpoints saved for the selected run
fldr=result_dir
print(fldr)
checkpoint_glob=fldr+'/models/checkpoint_*.tar'
flist=glob.glob(checkpoint_glob)
len(flist)

/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/pytorch/results/3d_cGAN/20210227_050213_3dcgan_predict_0.65_m2


3742

In [117]:
# Delete stored model checkpoints whose step is not one of the best steps.
# File names are expected to end in '_<step>.tar'; files that do not parse are left alone.
# Kept steps are printed; deletion failures are reported instead of silently ignored.
assert len(best_step)>0, "best_step is empty: refusing to delete every checkpoint"
for i in flist:
    try:
        step=int(os.path.basename(i).split('_')[-1].split('.')[0])
    except ValueError:
        continue   # not a '<name>_<step>.tar' checkpoint
    if step in best_step:
        print(step)
        continue
    try:
        os.remove(i)
    except OSError as e:
        print("Could not delete %s: %s"%(i,e))

42620
30950
50440
65600
43830
58330
58320


In [118]:
# Steps whose checkpoints were kept
np.asarray(best_step)

array([27600, 30950, 42620, 43830, 50440, 58320, 58330, 59500, 64670,
       65600])

In [119]:
# Shape of the full sigma=0.5 train/val file (memory-mapped, so the data is not loaded)
sample_file='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/3d_data/dataset2a_3dcgan_4univs_64cube_simple_splicing/norm_1_sig_{0}_train_val.npy'.format(0.5)
np.load(sample_file,mmap_mode='r').shape

(14376, 1, 64, 64, 64)