# Debug training
May 19, 2021

In [1]:
import     numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import subprocess as sp
import sys
import os
import glob
import pickle 

from matplotlib.colors import LogNorm, PowerNorm, Normalize
import seaborn as sns
from functools import reduce

import socket

In [2]:
from ipywidgets import *

In [3]:
%matplotlib widget

In [4]:
base_dict={'cori':'/global/u1/v/vpa/project/jpt_notebooks/Cosmology/Cosmo_GAN/repositories/cosmogan_pytorch/',
         'summit':'/autofs/nccs-svm1_home1/venkitesh/projects/cosmogan/cosmogan_pytorch/'}
facility='cori' if socket.gethostname()[:4]=='cori' else 'summit'

base_dir=base_dict[facility]

In [5]:
sys.path.append(base_dir+'/code/modules_image_analysis/')
from modules_img_analysis import *
# sys.path.append(base_dir+'/code/5_3d_cgan/1_main_code/')
# import post_analysis_pandas as post

In [6]:
### Transformation functions for image pixel values
def f_transform(x):
    return 2.*x/(x + 4.) - 1.

def f_invtransform(s):
    return 4.*(1. + s)/(1. - s)


## Read data

In [7]:
dict1={'cori':{
'2d':'/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/pytorch/results/128_square/',
'3d':'/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/pytorch/results/3D/',
'3d_cgan':'/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/pytorch/results/3d_cGAN/'},
'summit':{'2d':'/gpfs/alpine/ast153/proj-shared/venkitesh/Cosmogan/data/results_pytorch/2d/',
          '3d':'/gpfs/alpine/ast153/proj-shared/venkitesh/Cosmogan/data/results_pytorch/3d/'}}

In [8]:
# parent_dir=u.result
parent_dir=dict1[facility]['3d_cgan']
dir_lst=[i.split('/')[-1] for i in glob.glob(parent_dir+'202106*')]
dir_lst
w=interactive(lambda x: x, x=Dropdown(options=dir_lst))
display(w)

interactive(children=(Dropdown(description='x', options=('20210602_112153_cgan_bs32_nodes8_lrd-4x-lrg_cori',),…

In [9]:
result=w.result
result_dir=parent_dir+result
print(result_dir)

/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/pytorch/results/3d_cGAN/20210602_112153_cgan_bs32_nodes8_lrd-4x-lrg_cori


## Plot Losses

In [10]:
df_metrics=pd.read_pickle(result_dir+'/df_metrics.pkle').astype(np.float64)
# df_metrics.tail(10)
def f_plot_metrics(df,col_list):
    
    plt.figure()
    for key in col_list:
        plt.plot(df_metrics[key],label=key,marker='*',linestyle='')
    plt.legend()
    
#     col_list=list(col_list)
#     df.plot(kind='line',x='step',y=col_list)

f_plot_metrics(df_metrics,['hist_chi'])

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [11]:
f_plot_metrics(df_metrics,['lr_d','lr_g'])

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [12]:
interact_manual(f_plot_metrics,df=fixed(df_metrics), col_list=SelectMultiple(options=df_metrics.columns.values))

interactive(children=(SelectMultiple(description='col_list', options=('step', 'epoch', 'Dreal', 'Dfake', 'Dful…

<function __main__.f_plot_metrics(df, col_list)>

In [13]:
df_metrics[(df_metrics.lr_d>=6.69e-04) ]

Unnamed: 0,step,epoch,Dreal,Dfake,Dfull,G_adv,G_full,spec_loss,hist_loss,spec_chi,hist_chi,gp_loss,fm_loss,D(x),D_G_z1,D_G_z2,time,lr_d,lr_g


In [14]:
df_metrics.plot(kind='scatter',x='step',y=['lr_d'])

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<AxesSubplot:xlabel='step', ylabel='[lr_d]'>

In [15]:
np.unique(df_metrics.lr_d.values),np.unique(df_metrics.lr_g.values)

# display(df_metrics.sort_values(by=['hist_chi']).head(8))
# display(df_metrics.sort_values(by=['spec_chi']).head(8))

(array([1.e-05,    nan]), array([2.5e-06,     nan]))

## Calculating learn rates

In [16]:
Nsteps=5;Lf=0.00005;Li=0.0008
Lambda=(Lf/Li)**(1.0/Nsteps)
print(Lambda,Lambda**2)

lst=[10,40,60,70,80,100]
# lst=range(1,11)
[(Li*Lambda**(count+1),i) for count,i in enumerate(lst)]

0.5743491774985174 0.3298769776932235


[(0.000459479341998814, 10),
 (0.0002639015821545788, 40),
 (0.00015157165665103975, 60),
 (8.705505632961239e-05, 70),
 (4.9999999999999975e-05, 80),
 (2.8717458874925857e-05, 100)]

In [17]:
Lambda=0.25;Li=0.004
lst=[10,40,60,70,80,100,140,180]
# lst=range(1,11)
[(Li*Lambda**(count+1),i) for count,i in enumerate(lst)]

[(0.001, 10),
 (0.00025, 40),
 (6.25e-05, 60),
 (1.5625e-05, 70),
 (3.90625e-06, 80),
 (9.765625e-07, 100),
 (2.44140625e-07, 140),
 (6.103515625e-08, 180)]

## Grid plot of images

In [18]:
epoch=37
flist=glob.glob(result_dir+'/images/gen_img*_epoch-{0}_step*'.format(epoch))
steps_list=[fname.split('/')[-1].split('step-')[-1].split('.')[0] for fname in flist]

print(*steps_list)

21830 22050 21850 22020 22080 21670 21840 21980 22160 22140 21980 22010 21730 21890 21940 21940 22110 21820 21820 21780 22050 21710 22000 22020 21970 22090 21680 21990 22130 22050 21730 21860 22190 21810 21960 21820 21800 22040 21910 21880 22010 21990 21900 22080 22060 21930 21790 21660 22110 21680 22020 21940 21930 22100 22070 22230 21930 21760 22100 22190 21740 21670 21700 22030 21830 22160 21710 21760 22140 21810 21650 22180 22220 21790 22160 21870 21840 22230 22000 22070 22130 22040 21870 21900 22060 22200 21980 21970 21650 22180 22090 21920 22150 22190 21960 21740 22200 21900 21700 21950 22150 21780 22120 21860 21830 21750 21690 21990 22070 21750 21840 22200 21750 22030 21660 21720 22110 22100 21690 21950 21890 22170 22210 22230 21860 22220 21710 21890 21950 21770 21770 21800 21800 22120 22180 22080 22090 22040 22150 21810 21660 22000 22120 21690 22010 21850 21780 22210 22170 21910 21920 21880 21730 21920 21720 22060 22170 21960 21880 21670 22140 22210 22030 21970 21760 21910 2168

In [35]:
steps_list

['21830',
 '22050',
 '21850',
 '22020',
 '22080',
 '21670',
 '21840',
 '21980',
 '22160',
 '22140',
 '21980',
 '22010',
 '21730',
 '21890',
 '21940',
 '21940',
 '22110',
 '21820',
 '21820',
 '21780',
 '22050',
 '21710',
 '22000',
 '22020',
 '21970',
 '22090',
 '21680',
 '21990',
 '22130',
 '22050',
 '21730',
 '21860',
 '22190',
 '21810',
 '21960',
 '21820',
 '21800',
 '22040',
 '21910',
 '21880',
 '22010',
 '21990',
 '21900',
 '22080',
 '22060',
 '21930',
 '21790',
 '21660',
 '22110',
 '21680',
 '22020',
 '21940',
 '21930',
 '22100',
 '22070',
 '22230',
 '21930',
 '21760',
 '22100',
 '22190',
 '21740',
 '21670',
 '21700',
 '22030',
 '21830',
 '22160',
 '21710',
 '21760',
 '22140',
 '21810',
 '21650',
 '22180',
 '22220',
 '21790',
 '22160',
 '21870',
 '21840',
 '22230',
 '22000',
 '22070',
 '22130',
 '22040',
 '21870',
 '21900',
 '22060',
 '22200',
 '21980',
 '21970',
 '21650',
 '22180',
 '22090',
 '21920',
 '22150',
 '22190',
 '21960',
 '21740',
 '22200',
 '21900',
 '21700',
 '21950',


In [47]:
# fname=flist[0]
# fname,fname.split('/')[-1].split('step-')[-1].split('.')[0]

step=9550
fname=glob.glob(result_dir+'/images/gen_img_*_epoch-{0}_step-{1}.npy'.format(epoch,step))[0]
fname

'/gpfs/alpine/ast153/proj-shared/venkitesh/Cosmogan/data/results_pytorch/3d/20210519_81818_cgan_bs16_lr0.001_nodes8_spec0.1/images/gen_img_label-0.5_epoch-170_step-9550.npy'

In [48]:
images=np.load(fname)[:,0,:,:]
print(images.shape)
f_plot_grid(images[:8,:,:,0],cols=4,fig_size=(8,4))

(32, 64, 64, 64)
2 4


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [19]:
epochs_list=np.unique(df_metrics.epoch.values)[:-1]
def f2(epoch): 
    flist=glob.glob(result_dir+'/images/gen_img*_epoch-{0}*'.format(int(epoch)))
#     print(flist)
    steps_list=[fname.split('/')[-1].split('step-')[-1].split('.')[0] for fname in flist]
#     steps_list=df_metrics[df_metrics.epoch==epoch].step.values
    return steps_list

w2=interactive(f2,epoch=Dropdown(options=epochs_list) )
display(w2)


interactive(children=(Dropdown(description='epoch', options=(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0,…

In [20]:
result=w2.result
steps_list=result

In [21]:
def f3(steps):     
    for step in steps:
        
        fname=glob.glob(result_dir+'/images/gen_img_*step-{0}.npy'.format(step))[0]
#         print(fname)

        images=np.load(fname)[:,0,:,:]
#         print(images.shape)
        f_plot_grid(images[:8,:,:,0],cols=4,fig_size=(8,4))

# w3=interactive(f3,step=Dropdown(options=steps_list))
# display(w3)

interact(f3,steps=SelectMultiple(options=steps_list))


interactive(children=(SelectMultiple(description='steps', options=('41110', '41380', '40980', '41060', '41090'…

<function __main__.f3(steps)>