# Code to view the ATLAS 2D data 

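Defines a `data_set` class that loads the ATLAS RPV-SUSY 2D image data (train/val/test HDF5 files) and summarizes and plots its labels, weights and images.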
September 17, 2019



In [3]:
import os
import subprocess as sp

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import h5py


In [4]:
%matplotlib widget

In [5]:
# Directory holding train.h5 / val.h5 / test.h5. It can be overridden with the
# ATLAS_DATA_DIR environment variable instead of editing this hardcoded path.
# os.path.join(..., '') guarantees a trailing separator, so the
# `data_dir + 'val.h5'` concatenations in later cells stay valid.
data_dir = os.path.join(
    os.environ.get('ATLAS_DATA_DIR',
                   '/global/project/projectdirs/dasrepo/vpa/atlas_cnn/data/RPVSusyData/'),
    '')

In [10]:
class data_set:
    ''' Simple class to store one ATLAS data set read from an HDF5 file.

    Attributes:
        filename : path of the HDF5 file the data was read from.
        images   : array read from ``all_events/hist`` with a trailing channel
                   axis appended (``np.expand_dims(..., -1)``).
        labels   : array read from ``all_events/y``; label 0.0 is counted as
                   background, label 1.0 is used as the signal marker in plots.
        weights  : ``log(w + 1)`` of the array read from ``all_events/weight``.

    Methods: summary (f_compute_summary) and plotting (f_plot_labels,
    f_plot_weights, f_plot_images).
    Example objects: train data, test data, validation data.
    '''

    def __init__(self, filename):
        self.filename = filename
        self.f_get_data()
        print("Created object from file ", filename)

    def f_get_data(self, idx=None):
        '''
        Read images, labels and weights from ``self.filename`` into attributes.

        Args:
            idx: optional slice end; only events ``[:idx]`` are read. ``None``
                 (default) reads everything. Useful when the data is too large.

        Raises:
            SystemError: if the file cannot be opened (original error chained).
        '''
        try:
            # Explicit read-only mode: older h5py defaulted to append mode ('a').
            hf = h5py.File(self.filename, 'r')
        except Exception as e:
            print(e)
            print("Name of file", self.filename)
            raise SystemError("Could not open file %s" % self.filename) from e

        # Context manager closes the file once the slices are in memory
        # (slicing an h5py dataset returns an in-memory numpy array).
        with hf:
            events = hf['all_events']
            self.images = np.expand_dims(events['hist'][:idx], -1)
            self.labels = events['y'][:idx]
            weights = events['weight'][:idx]
        # log1p is the numerically accurate form of log(w + 1)
        self.weights = np.log1p(weights)

    def f_compute_summary(self):
        ''' Print a summary of the data set.

        - Array shapes of labels, weights and images.
        - Label summary: signal / background counts and the signal ratio
          (label == 0.0 is background, everything else counts as signal).
        - Image summary: number and indices of images whose pixels are all zero.
        '''
        ### General summary
        print("Array shapes: Labels: %s, Weights: %s, Images: %s"%(self.labels.shape,self.weights.shape,self.images.shape))

        ### Label summary ###
        total = self.labels.shape[0]
        bkgnd = int(np.count_nonzero(self.labels == 0.0))
        signal = total - bkgnd
        print("Signal: {0}\tBKGND: {1}\tTotal: {2}".format(signal, bkgnd, total))
        if total:
            print("Signal ratio: {0} % ".format(signal*100.0/total))
        else:
            # Avoid ZeroDivisionError on an empty data set
            print("Signal ratio: undefined (empty data set)")

        ### Image summary ###
        # Use this object's images (previously read the global `validation` by mistake).
        imgs = self.images[..., 0]
        # An image is "zero" if no pixel is non-zero; reduce over every non-event axis.
        # Reducing over axes (rather than reshape(n, -1)) also works for 0 images.
        is_zero = ~imgs.any(axis=tuple(range(1, imgs.ndim)))
        zero_idx = np.flatnonzero(is_zero)
        if zero_idx.size == 0:
            print("No zero images")
        else:
            # Report indices rather than dumping the full pixel arrays
            print("Zero images", zero_idx.size)
            print(zero_idx)

    def f_plot_labels(self):
        ''' Plots giving information on labels (2x2 grid).
        1- Plot of all labels (signal = 1.0, background = 0.0) vs event index
        2- Histogram of labels (3 bins)
        3- Plot of signal locations (event indices whose label == 1.0)
        4- Histogram of signal locations (50 bins)
        '''
        vals = self.labels
        # Event indices whose label is exactly 1.0
        sigs = np.where(vals == 1.0)[0]

        fig, axes = plt.subplots(2, 2)
        fig.subplots_adjust(wspace=0.2, hspace=0.35)
        (ax_all, ax_hist), (ax_sig, ax_sig_hist) = axes

        ax_all.plot(vals, linestyle='', marker='*')
        ax_all.set_title("Plot of all labels")

        ax_hist.hist(vals, bins=3)
        ax_hist.set_title("Histogram of background(=0) and signal")

        ax_sig.plot(sigs, np.ones(sigs.shape), linestyle='', marker='o')
        ax_sig.set_title("Plot of signal locations")

        ax_sig_hist.hist(sigs, bins=50, color='c')
        ax_sig_hist.set_title('Histogram of signal locations')

    def f_plot_weights(self):
        '''
        Plot weights (the log-transformed values stored in self.weights).
        2 plots: weight vs event index, and a 20-bin histogram.
        '''
        vals = self.weights
        fig, (ax_all, ax_hist) = plt.subplots(1, 2)
        fig.subplots_adjust(wspace=0.2, hspace=0.35)

        ax_all.plot(vals, linestyle='', marker='o')
        ax_all.set_title("Plot of all weights")

        ax_hist.hist(vals, bins=20)
        ax_hist.set_title("Histogram of weights")

    def f_plot_images(self, idx=0):
        ''' Plot up to 10 images as 2D maps on a 2x5 grid.
        input argument : idx = index of the first image
        Shows images idx .. idx+9; panels with no image left are hidden.
        '''
        rows, cols = 2, 5

        ### Get part of the image array to plot (drop the trailing channel axis)
        arr = self.images[idx:idx + rows*cols][..., 0]
        # Tick spacing derived from the image size instead of hardcoded limits
        height, width = arr.shape[1:3]

        fig, axarr = plt.subplots(rows, cols, figsize=(10, 2))
        for i, ax in enumerate(axarr.flat):
            if i >= len(arr):
                # Fewer than rows*cols images remain after idx: hide the empty panel
                ax.set_visible(False)
                continue
            ax.imshow(arr[i], origin='lower', alpha=0.9)
            ax.set_xticks(np.arange(0, width, 10))
            ax.set_yticks(np.arange(0, height, 10))

        # Keep x tick labels only on the bottom row and y tick labels only on the first column
        plt.setp([a.get_xticklabels() for a in axarr[:-1, :].flatten()], visible=False)
        plt.setp([a.get_yticklabels() for a in axarr[:, 1:].flatten()], visible=False)



In [7]:
# Load the validation split
validation = data_set(filename=data_dir + 'val.h5')

Created object from file  /global/project/projectdirs/dasrepo/vpa/atlas_cnn/data/RPVSusyData/train.h5


In [9]:
# Summary statistics, then label, weight and image plots for the validation split
validation.f_compute_summary()
validation.f_plot_labels()
validation.f_plot_weights()
validation.f_plot_images(idx=0)

Array shapes: Labels: (412416,), Weights: (412416,), Images: (412416, 64, 64, 1)
Signal: 175795	BKGND: 236621	Total: 412416
Signal ratio: 42.625649829298574 % 
No zero images


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [None]:
file_list = ['train.h5', 'val.h5', 'test.h5']
# One data_set per split, keyed by split name; loaded in train, val, test order
dict1 = {
    split: data_set(data_dir + fname)
    for fname, split in zip(file_list, ['train', 'val', 'test'])
}