# Code describing the result of using pre-normalization (NMAD) of images
April 23, 2020



In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import subprocess as sp
import os
import sys
import time

In [2]:
%matplotlib widget

In [9]:
# sys.path.append('3_analysis/')
from modules_image_analysis import *

### Implementing MAD: Median Absolute Deviation
https://en.wikipedia.org/wiki/Median_absolute_deviation

$ MAD = \mathrm{median}\left( \left| X - \tilde{X} \right| \right) \ \ $   where $ \tilde{X} $ is the median of the array $ X $

$ \sigma = k \cdot MAD $, with $ k \approx 1.4826 $ for normally distributed data

## Implement normalization with actual data

Procedure for normalization : 

For each sample (3 image types)
- Compute sigma using the MAD method on the diff image
- Divide the entire sample by that value
- Some images give zero sigma; skip the normalization for these


In [14]:
### This is the code to perform the pre-normalization. The file save is being disabled. The actual code 6_pre_norm.py should be used to perform the actual operation.

# Directory holding the gathered .npy input files.
# NOTE: this assignment had been commented out, which made the np.load call below
# fail with a NameError on a fresh run unless the name was provided elsewhere.
save_location='/global/project/projectdirs/dasrepo/vpa/supernova_cnn/data/gathered_data/input_npy_files/'
f1='full_x.npy'          # input file that is loaded below
f2='renorm_full_x.npy'   # output file name for the rescaled samples (saving is disabled in this notebook)

### Read data from .npy file
ip_arr=np.load(save_location+f1)
print(ip_arr.shape)

def f_rescale_samples(samples):
    ''' Rescale individual images with the NMAD value of their diff image

    For every sample, sigma = 1.4826 * MAD is computed on the diff image
    (channel index 2) and all channels of that sample are divided by it.
    Samples with sigma below 1e-10 cannot be normalized: they are reported
    on stdout, recorded, and copied with a scale factor of 1.0.

    Arguments:
        samples: array indexed as [sample, row, col, channel]

    Returns:
        scaled_samples: array with the same shape as samples. Floating-point
            input keeps its dtype; any other dtype (e.g. integers) is promoted
            to float64 so the scaled values are not truncated on assignment.
        lst_zeros: list of sample indices that were left unscaled
    '''
    def f_sigma(arr):
        '''
        Compute the MAD of arr and return the corresponding sigma estimate
        '''
        flat=arr.flatten()
        center=np.median(flat)
        mad=np.median(np.abs(flat-center))
        k=1.4826 ### For normal distribution
        return mad*k

    samples=np.asarray(samples)
    # Allocating with the input dtype would silently truncate the result for integer input
    out_dtype=samples.dtype if np.issubdtype(samples.dtype,np.floating) else np.float64
    # Every entry is overwritten in the loop below, so no initialization is needed
    scaled_samples=np.empty(samples.shape,dtype=out_dtype)

    lst_zeros=[] # List to store indices where the MAD value is zero
    ### For every sample, compute sigma for the diff image (idx = 2) and multiply the sample by its inverse
    for i,row in enumerate(samples):
        scale=f_sigma(row[:,:,2])
        if scale<1e-10:
            print("Small value",i,scale)
            lst_zeros.append(i)
            scale=1.0 # Can't divide by zero, keep the sample as it is
        scaled_samples[i]=row*(1.0/scale)

    return scaled_samples,lst_zeros

# Time the rescaling of the first 10000 samples of ip_arr only
t1=time.time()
rescaled_arr,zero_lst=f_rescale_samples(ip_arr[:10000])
t2=time.time()
print(t2-t1)  # elapsed wall-clock time in seconds
# zero_lst holds the indices that f_rescale_samples left unscaled (sigma below its threshold)
print('Number of zero median images',len(zero_lst))
print(rescaled_arr.shape)

# !!! Don't write to file. For that use the file 6_pre_norm.py
############## np.save(save_location+f2,rescaled_arr)

(898963, 51, 51, 3)
Small value 1808 0.0
Small value 1856 0.0
Small value 3176 0.0
Small value 5066 0.0
2.226393222808838
Number of zero median images 4
(10000, 51, 51, 3)


### Analyzing the zero NMAD valued points
- There are a few images (761, about 0.08%) for which the diff image has an NMAD value of 0.
- This happens because most pixels in these images have the same value (typically 0, sometimes 1). The median is then equal to this value, so the median absolute deviation is zero.
- We are using a rescale factor of 1.0 for these images (since we can't divide by 0).
- We don't filter out such images because there are a few good images that could be caught in this.

In [15]:
# Display the indices (within the 10000 rescaled samples) that were left unscaled
zero_lst
# zero_imgs=ip_arr[zero_lst][:,:,:,2]
# print(zero_imgs.shape)

[1808, 1856, 3176, 5066]

In [16]:
def f_plot_grid(arr,cols=16):
    ''' Plot the images in arr on a grid of subplots

    The grid has `cols` columns and as many rows as are needed to hold
    arr.shape[0] images. Image i is drawn at row i//cols, column i%cols.
    Cells of the last row that have no image are left as empty axes.
    x tick labels are hidden everywhere except the bottom row and y tick
    labels everywhere except the first column.

    Arguments:
        arr: array whose first axis indexes the images passed to imshow
        cols: number of columns of the grid, must not exceed arr.shape[0]

    Raises:
        AssertionError: if cols is greater than the number of images
    '''
    size=arr.shape[0]
    assert cols<=size, "cols %s greater than array size %s"%(cols,size)

    rows=int(np.ceil(size/cols))

    # squeeze=False keeps axarr 2D even for a single row or a single column,
    # so that the [row,col] indexing and the slicing below always work
    fig,axarr=plt.subplots(rows,cols,figsize=(8,8),constrained_layout=True,squeeze=False)

    # Iterate over the existing images only, instead of over every grid cell
    # with a bare except hiding the out-of-range (and any other) errors
    for i in range(size):
        row,col=divmod(i,cols)
        axarr[row,col].imshow(arr[i])

    # Drop axis labels, done once after all images have been drawn
    plt.setp([a.get_xticklabels() for a in axarr[:-1,:].flatten()], visible=False)
    plt.setp([a.get_yticklabels() for a in axarr[:,1:].flatten()], visible=False)



In [17]:
# Diff images (channel 2) of samples 1806 to 1809, stacked along a new first axis
img_arr=np.stack([ip_arr[idx,:,:,2] for idx in range(1806,1810)])
# Alternative selections:
# img_arr=np.array([ip_arr[idx][:,:,2] for idx in [10168,10167,10169,10170]])
# img_arr=np.array([ip_arr[idx][:,:,2] for idx in zero_lst[-13:-1]])

f_plot_grid(img_arr,cols=2)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [18]:
# NOTE(review): f_plot_intensity_grid is not defined in this notebook; presumably it comes from
# the star import of modules_image_analysis — verify
f_plot_intensity_grid(img_arr,cols=2)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …