# Test post compute 3D

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import subprocess as sp
import sys
import os
import glob
import pickle 
import itertools

from matplotlib.colors import LogNorm, PowerNorm, Normalize
from ipywidgets import *

In [2]:
%matplotlib widget

In [3]:
### Transformation functions for image pixel values
def f_transform(x):
    return 2.*x/(x + 4.) - 1.

def f_invtransform(s):
    return 4.*(1. + s)/(1. - s)


## Histogram modules

In [4]:
def f_batch_histogram(img_arr,bins,norm,hist_range):
    ''' Compute histogram statistics for a batch of images'''

    ## Extracting the range. This is important to ensure that the different histograms are compared correctly
    if hist_range==None : ulim,llim=np.max(img_arr),np.min(img_arr)
    else: ulim,llim=hist_range[1],hist_range[0]
#         print(ulim,llim)
    ### array of histogram of each image
    hist_arr=np.array([np.histogram(arr.flatten(), bins=bins, range=(llim,ulim), density=norm) for arr in img_arr]) ## range is important
    hist=np.stack(hist_arr[:,0]) # First element is histogram array
#         print(hist.shape)
    bin_list=np.stack(hist_arr[:,1]) # Second element is bin value 
    ### Compute statistics over histograms of individual images
    mean,err=np.mean(hist,axis=0),np.std(hist,axis=0)/np.sqrt(hist.shape[0])
    bin_edges=bin_list[0]
    centers = (bin_edges[:-1] + bin_edges[1:]) / 2

    return mean,err,centers
    

def f_pixel_intensity(img_arr,bins=25,label='validation',mode='avg',normalize=False,log_scale=True,plot=True, hist_range=None):
    '''
    Module to compute and plot histogram for pixel intensity of images
    Has 2 modes : simple and avg
        simple mode: No errors. Just flatten the input image array and compute histogram of full data
        avg mode(Default) : 
            - Compute histogram for each image in the image array
            - Compute errors across each histogram 
    '''
    
    norm=normalize # Whether to normalize the histogram
    
    if plot: 
        plt.figure()
        plt.xlabel('Pixel value')
        plt.ylabel('Counts')
        plt.title('Pixel Intensity Histogram')

        if log_scale: plt.yscale('log')
    
    if mode=='simple':
        hist, bin_edges = np.histogram(img_arr.flatten(), bins=bins, density=norm, range=hist_range)
        centers = (bin_edges[:-1] + bin_edges[1:]) / 2

        if plot: plt.errorbar(centers, hist, fmt='o-', label=label)
        return hist,None
    
    elif mode=='avg': 
        ### Compute histogram for each image. 
        mean,err,centers=f_batch_histogram(img_arr,bins,norm,hist_range)

        if plot: plt.errorbar(centers,mean,yerr=err,fmt='o-',label=label)  
        return mean,err
    
def f_compare_pixel_intensity(img_lst,label_lst=['img1','img2'],bkgnd_arr=[],log_scale=True, normalize=True, mode='avg',bins=25, hist_range=None):
    '''
    Module to compute and plot histogram for pixel intensity of images
    Has 2 modes : simple and avg
    simple mode: No errors. Just flatten the input image array and compute histogram of full data
    avg mode(Default) : 
        - Compute histogram for each image in the image array
        - Compute errors across each histogram 
        
    bkgnd_arr : histogram of this array is plotting with +/- sigma band
    '''
    
    norm=normalize # Whether to normalize the histogram
    
    def f_batch_histogram(img_arr,bins,norm,hist_range):
        ''' Compute histogram statistics for a batch of images'''
        
        ## Extracting the range. This is important to ensure that the different histograms are compared correctly
        if hist_range==None : ulim,llim=np.max(img_arr),np.min(img_arr)
        else: ulim,llim=hist_range[1],hist_range[0]
#         print(ulim,llim)
        ### array of histogram of each image
        hist_arr=np.array([np.histogram(arr.flatten(), bins=bins, range=(llim,ulim), density=norm) for arr in img_arr]) ## range is important
        hist=np.stack(hist_arr[:,0]) # First element is histogram array
#         print(hist.shape)

        bin_list=np.stack(hist_arr[:,1]) # Second element is bin value 
        ### Compute statistics over histograms of individual images
        mean,err=np.mean(hist,axis=0),np.std(hist,axis=0)/np.sqrt(hist.shape[0])
        bin_edges=bin_list[0]
        centers = (bin_edges[:-1] + bin_edges[1:]) / 2
#         print(bin_edges,centers)

        return mean,err,centers
    
    plt.figure()
    
    ## Plot background distribution
    if len(bkgnd_arr):
        if mode=='simple':
            hist, bin_edges = np.histogram(bkgnd_arr.flatten(), bins=bins, density=norm, range=hist_range)
            centers = (bin_edges[:-1] + bin_edges[1:]) / 2
            plt.errorbar(centers, hist, color='k',marker='*',linestyle=':', label='bkgnd')

        elif mode=='avg':
            ### Compute histogram for each image. 
            mean,err,centers=f_batch_histogram(bkgnd_arr,bins,norm,hist_range)
            plt.plot(centers,mean,linestyle=':',color='k',label='bkgnd')
            plt.fill_between(centers, mean - err, mean + err, color='k', alpha=0.4)
    
    ### Plot the rest of the datasets
    for img,label,mrkr in zip(img_lst,label_lst,itertools.cycle('>^*sDHPdpx_')):     
        if mode=='simple':
            hist, bin_edges = np.histogram(img.flatten(), bins=bins, density=norm, range=hist_range)
            centers = (bin_edges[:-1] + bin_edges[1:]) / 2
            plt.errorbar(centers, hist, fmt=mrkr+'-', label=label)

        elif mode=='avg':
            ### Compute histogram for each image. 
            mean,err,centers=f_batch_histogram(img,bins,norm,hist_range)
#             print('Centers',centers)
            plt.errorbar(centers,mean,yerr=err,fmt=mrkr+'-',label=label)

    if log_scale: 
        plt.yscale('log')
        plt.xscale('symlog',linthreshx=50)

    plt.legend()
    plt.xlabel('Pixel value')
    plt.ylabel('Counts')
    plt.title('Pixel Intensity Histogram')

### Spectral modules

In [5]:
## numpy code
def f_radial_profile_3d(data, center=(None,None)):
    ''' Module to compute radial profile of a 2D image '''
    
    z, y, x = np.indices((data.shape)) # Get a grid of x and y values
    
    center=[]
    if not center:
        center = np.array([(x.max()-x.min())/2.0, (y.max()-y.min())/2.0, (z.max()-z.min())/2.0]) # compute centers
        
    # get radial values of every pair of points
    r = np.sqrt((x - center[0])**2 + (y - center[1])**2+ + (z - center[2])**2)
    r = r.astype(np.int)
    
    # Compute histogram of r values
    tbin = np.bincount(r.ravel(), data.ravel())
    nr = np.bincount(r.ravel()) 
    radialprofile = tbin / nr
    
    return radialprofile[1:-1]

def f_compute_spectrum_3d(arr):
    '''
    compute spectrum for a 3D image
    '''
#     GLOBAL_MEAN=1.0
#     arr=((arr - GLOBAL_MEAN)/GLOBAL_MEAN)
    y1=np.fft.fftn(arr)
    y1=np.fft.fftshift(y1)
#     print(y1.shape)
    y2=abs(y1)**2
    z1=f_radial_profile_3d(y2)
    return(z1)
   
def f_batch_spectrum_3d(arr):
    batch_pk=np.array([f_compute_spectrum_3d(i) for i in arr])
    return batch_pk

### Code ###
def f_image_spectrum_3d(x,num_channels):
    '''
    Compute spectrum when image has a channel index
    Data has to be in the form (batch,channel,x,y)
    '''
    mean=[[] for i in range(num_channels)]    
    sdev=[[] for i in range(num_channels)]    

    for i in range(num_channels):
        arr=x[:,i,:,:,:]
#         print(i,arr.shape)
        batch_pk=f_batch_spectrum_3d(arr)
#         print(batch_pk)
        mean[i]=np.mean(batch_pk,axis=0)
        sdev[i]=np.var(batch_pk,axis=0)
    mean=np.array(mean)
    sdev=np.array(sdev)
    return mean,sdev


def f_plot_spectrum_3d(img_arr,plot=False,label='input',log_scale=True):
    '''
    Module to compute Average of the 1D spectrum for a batch of 3d images
    '''
    num = img_arr.shape[0]
    Pk = f_batch_spectrum_3d(img_arr)

    #mean,std = np.mean(Pk, axis=0),np.std(Pk, axis=0)/np.sqrt(Pk.shape[0])
    mean,std = np.mean(Pk, axis=0),np.std(Pk, axis=0)
    k=np.arange(len(mean))
    
    if plot: 
        plt.figure()
        plt.plot(k, mean, 'k:')
        plt.plot(k, mean + std, 'k-',label=label)
        plt.plot(k, mean - std, 'k-')
    #     plt.xscale('log')
        if log_scale: plt.yscale('log')
        plt.ylabel(r'$P(k)$')
        plt.xlabel(r'$k$')
        plt.title('Power Spectrum')
        plt.legend()

    return mean,std


def f_compare_spectrum_3d(img_lst,label_lst=['img1','img2'],bkgnd_arr=[],log_scale=True):
    '''
    Compare the spectrum of 2 sets s: 
    img_lst contains the set of images arrays, Each is of the form (num_images,height,width)
    label_lst contains the labels used in the plot
    '''
    plt.figure()
    
    ## Plot background distribution
    if len(bkgnd_arr):
        Pk= f_batch_spectrum_3d(bkgnd_arr)
        mean,err = np.mean(Pk, axis=0),np.std(Pk, axis=0)/np.sqrt(Pk.shape[0])
        k=np.arange(len(mean))
        plt.plot(k, mean,color='k',linestyle='-',label='bkgnd')    
        plt.fill_between(k, mean - err, mean + err, color='k',alpha=0.8)
    
    
    for img_arr,label,mrkr in zip(img_lst,label_lst,itertools.cycle('>^*sDHPdpx_')): 
        Pk= f_batch_spectrum_3d(img_arr)
        mean,err = np.mean(Pk, axis=0),np.std(Pk, axis=0)/np.sqrt(Pk.shape[0])

        k=np.arange(len(mean))
#         print(mean.shape,std.shape)
        plt.fill_between(k, mean - err, mean + err, alpha=0.4)
        plt.plot(k, mean, marker=mrkr, linestyle=':',label=label)

    if log_scale: plt.yscale('log')
    plt.ylabel(r'$P(k)$')
    plt.xlabel(r'$k$')
    plt.title('Power Spectrum')
    plt.legend()  
    


### Read data

In [9]:
# fname='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/pytorch/results/3d/20210111_104029_3d_/images/best_hist_epoch-8_step-13530.npy'
fname='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/results_from_other_code/pytorch/results/3d/20210210_060657_3d_l0.5_80k/images/gen_img_epoch-15_step-37730.npy'
a1=np.load(fname)[:,0,:,:,:]

fname='/global/cfs/cdirs/m3363/vayyar/cosmogan_data/raw_data/3d_data/dataset1_smoothing_const_params_64cube_100k/val.npy'
val_arr=np.load(fname,mmap_mode='r')[-500:,0,:,:,:]
print(a1.shape,val_arr.shape)

val_arr=f_transform(val_arr)

(32, 64, 64, 64) (500, 64, 64, 64)


In [10]:
np.max(val_arr),np.max(a1)

(0.9906874, 0.9971457)

### Histogram

In [11]:
_,_=f_pixel_intensity(a1)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

  hist_arr=np.array([np.histogram(arr.flatten(), bins=bins, range=(llim,ulim), density=norm) for arr in img_arr]) ## range is important


In [12]:
img_lst=[a1,val_arr]
label_lst=['a1','val']
f_compare_pixel_intensity(img_lst,label_lst=label_lst,bkgnd_arr=[],log_scale=True, normalize=True, mode='avg',bins=25, hist_range=None)


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

  hist_arr=np.array([np.histogram(arr.flatten(), bins=bins, range=(llim,ulim), density=norm) for arr in img_arr]) ## range is important
  plt.xscale('symlog',linthreshx=50)


## Power spectrum

In [28]:
a1.shape

(32, 1, 64, 64, 64)

In [34]:
# f_image_spectrum_3d(a1,1)
f_compute_spectrum_3d(a1[0])

array([7.08387989e+06, 3.28201950e+06, 1.96989945e+06, 1.29067951e+06,
       8.88855280e+05, 5.79303274e+05, 4.77015606e+05, 3.58533467e+05,
       2.55271201e+05, 1.87804421e+05, 1.52941311e+05, 1.20373592e+05,
       8.99020955e+04, 6.93547466e+04, 5.53105024e+04, 4.32803411e+04,
       3.21364362e+04, 2.45587023e+04, 1.87231209e+04, 1.53876053e+04,
       1.13001430e+04, 9.40523400e+03, 7.63251142e+03, 6.31246125e+03,
       5.21533361e+03, 4.41272406e+03, 3.82497734e+03, 3.62602297e+03,
       3.14059634e+03, 2.73483562e+03, 3.66875483e+03, 2.10761531e+03,
       1.63054553e+03, 1.35017626e+03, 1.47162297e+03, 1.06022791e+03,
       1.01064532e+03, 9.81427116e+02, 9.53277740e+02, 8.38876481e+02,
       8.19499169e+02, 8.21931962e+02, 9.07073995e+02, 2.15897225e+03,
       6.72364738e+02, 6.55671060e+02, 1.36135356e+03, 5.34185661e+02,
       5.79213578e+02, 8.48842685e+02, 5.68254283e+02, 7.06984838e+02,
       6.71464160e+02])

In [35]:
# f_image_spectrum_3d(a1,1)

In [13]:
_,_=f_plot_spectrum_3d(val_arr[:50],plot=True)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [18]:
img_lst=[a1,val_arr]
f_compare_spectrum_3d(img_lst)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [13]:
f_plot_spectrum_3d(a1,plot=False,label='input',log_scale=True)

(array([6.96167903e+06, 2.80630157e+06, 1.67154431e+06, 1.06305220e+06,
        7.13692333e+05, 4.74283132e+05, 3.92096050e+05, 2.90817689e+05,
        2.12122431e+05, 1.54404535e+05, 1.25565492e+05, 9.99182558e+04,
        7.37214371e+04, 5.73612980e+04, 4.67560598e+04, 3.70876773e+04,
        2.70082463e+04, 2.09121164e+04, 1.57593858e+04, 1.29648434e+04,
        9.58378518e+03, 7.98290854e+03, 6.47732760e+03, 5.41248283e+03,
        4.46644971e+03, 3.78834905e+03, 3.28060415e+03, 3.13369668e+03,
        2.69492795e+03, 2.33735546e+03, 3.19385736e+03, 1.79852550e+03,
        1.40233130e+03, 1.17182419e+03, 1.27015575e+03, 9.21219968e+02,
        8.65756896e+02, 8.55088268e+02, 8.25483017e+02, 7.21995490e+02,
        7.07217585e+02, 7.01261455e+02, 7.71470596e+02, 1.91281284e+03,
        5.84742510e+02, 5.79849086e+02, 1.25191567e+03, 4.69943191e+02,
        5.03094955e+02, 7.47896999e+02, 4.97527664e+02, 6.04793264e+02,
        5.89144310e+02]),
 array([1.49123570e+06, 8.11238682e+05

In [16]:
f_batch_spectrum_3d(a1)

array([[7.08387989e+06, 3.28201950e+06, 1.96989945e+06, ...,
        5.68254283e+02, 7.06984838e+02, 6.71464160e+02],
       [6.14720807e+06, 3.14619459e+06, 1.80444738e+06, ...,
        6.44194892e+02, 7.36690274e+02, 7.11681808e+02],
       [8.05727954e+06, 2.48402353e+06, 1.41107729e+06, ...,
        4.05522432e+02, 5.08364180e+02, 5.32853469e+02],
       ...,
       [6.63576575e+06, 2.25974970e+06, 1.28140746e+06, ...,
        3.92831149e+02, 4.91930796e+02, 4.86510721e+02],
       [8.26878881e+06, 3.62144638e+06, 2.23273761e+06, ...,
        5.96735444e+02, 7.28725161e+02, 6.51105340e+02],
       [7.46938330e+06, 2.79671083e+06, 1.73894088e+06, ...,
        4.83933536e+02, 6.31198788e+02, 5.70291443e+02]])

In [19]:
a1.shape

(32, 64, 64, 64)

In [21]:
print(a1.shape)
f_plot_spectrum_3d(a1,plot=False,label='input',log_scale=True)

(32, 64, 64, 64)


(array([6.96167903e+06, 2.80630157e+06, 1.67154431e+06, 1.06305220e+06,
        7.13692333e+05, 4.74283132e+05, 3.92096050e+05, 2.90817689e+05,
        2.12122431e+05, 1.54404535e+05, 1.25565492e+05, 9.99182558e+04,
        7.37214371e+04, 5.73612980e+04, 4.67560598e+04, 3.70876773e+04,
        2.70082463e+04, 2.09121164e+04, 1.57593858e+04, 1.29648434e+04,
        9.58378518e+03, 7.98290854e+03, 6.47732760e+03, 5.41248283e+03,
        4.46644971e+03, 3.78834905e+03, 3.28060415e+03, 3.13369668e+03,
        2.69492795e+03, 2.33735546e+03, 3.19385736e+03, 1.79852550e+03,
        1.40233130e+03, 1.17182419e+03, 1.27015575e+03, 9.21219968e+02,
        8.65756896e+02, 8.55088268e+02, 8.25483017e+02, 7.21995490e+02,
        7.07217585e+02, 7.01261455e+02, 7.71470596e+02, 1.91281284e+03,
        5.84742510e+02, 5.79849086e+02, 1.25191567e+03, 4.69943191e+02,
        5.03094955e+02, 7.47896999e+02, 4.97527664e+02, 6.04793264e+02,
        5.89144310e+02]),
 array([1.49123570e+06, 8.11238682e+05

In [22]:
f_compute_spectrum_3d(a1[0])

array([7.08387989e+06, 3.28201950e+06, 1.96989945e+06, 1.29067951e+06,
       8.88855280e+05, 5.79303274e+05, 4.77015606e+05, 3.58533467e+05,
       2.55271201e+05, 1.87804421e+05, 1.52941311e+05, 1.20373592e+05,
       8.99020955e+04, 6.93547466e+04, 5.53105024e+04, 4.32803411e+04,
       3.21364362e+04, 2.45587023e+04, 1.87231209e+04, 1.53876053e+04,
       1.13001430e+04, 9.40523400e+03, 7.63251142e+03, 6.31246125e+03,
       5.21533361e+03, 4.41272406e+03, 3.82497734e+03, 3.62602297e+03,
       3.14059634e+03, 2.73483562e+03, 3.66875483e+03, 2.10761531e+03,
       1.63054553e+03, 1.35017626e+03, 1.47162297e+03, 1.06022791e+03,
       1.01064532e+03, 9.81427116e+02, 9.53277740e+02, 8.38876481e+02,
       8.19499169e+02, 8.21931962e+02, 9.07073995e+02, 2.15897225e+03,
       6.72364738e+02, 6.55671060e+02, 1.36135356e+03, 5.34185661e+02,
       5.79213578e+02, 8.48842685e+02, 5.68254283e+02, 7.06984838e+02,
       6.71464160e+02])

In [23]:
a1[0].shape

(64, 64, 64)