In [1]:
# Before using this notebook, check which pax and hax versions you are running.

import pax, hax
INFO = """
Pax
version: {pax_version}
build: {pax_file}
Hax
version: {hax_version}
build: {hax_file}
"""

INFO = INFO.format(
pax_version = pax.__version__,
pax_file = pax.__file__,
hax_version = hax.__version__,
hax_file = hax.__file__
)

print(INFO)


Pax
version: 6.10.1
build: /project/lgrandi/anaconda3/envs/pax_head/lib/python3.4/site-packages/pax-6.10.1-py3.4.egg/pax/__init__.py
Hax
version: 2.5.0
build: /project/lgrandi/anaconda3/envs/pax_head/lib/python3.4/site-packages/hax-2.5.0-py3.4.egg/hax/__init__.py



In [7]:
import os, sys, time
import numpy as np
from numpy import sqrt, exp, pi, square
import pandas as pd
pd.options.mode.chained_assignment = None        # default='warn'; silences pandas SettingWithCopyWarning
import matplotlib
from matplotlib.colors import LogNorm
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'   # high-resolution inline figures; remove if you do not have a retina display
from scipy.optimize import curve_fit, minimize
from scipy.interpolate import interp1d, UnivariateSpline
import warnings
# NOTE(review): this silences *all* warnings for the whole session (e.g. numpy
# invalid-value warnings from sqrt of negative numbers); consider a narrower filter.
warnings.filterwarnings('ignore')
from multihist import Histdd, Hist1d
from tqdm import tqdm_notebook as tqdm
from multiprocessing import Pool
from contextlib import contextmanager

def plt_config(title=None, xlim=None, ylim=None, xlabel=None, ylabel=None, colorbar=False, sci=False, art=plt):
    """Apply common figure decorations in one call.

    Each of ``title``, ``xlim``, ``ylim``, ``xlabel`` and ``ylabel`` that is not None
    is forwarded to the same-named function on ``art``, e.g. ``plt.title(title)``.

    Parameters
    ----------
    title, xlim, ylim, xlabel, ylabel : optional
        Values passed unchanged to ``art.<name>(value)``. None means "leave as is".
    colorbar : str or bool
        If a string, add a colorbar with that label. If otherwise truthy, add a
        colorbar with a default "Number of Entries" label.
    sci : str or False
        If a string (the ``axis`` argument of ``ticklabel_format``, e.g. ``'x'``,
        ``'y'`` or ``'both'``), force scientific tick-label notation on that axis.
    art : object
        Object exposing pyplot-style ``title``/``xlim``/``ylim``/``xlabel``/``ylabel``/
        ``ticklabel_format``/``colorbar`` functions; defaults to ``matplotlib.pyplot``.
    """
    settings = (('title', title), ('xlim', xlim), ('ylim', ylim),
                ('xlabel', xlabel), ('ylabel', ylabel))
    for field, value in settings:
        # `is not None` rather than `!= None`: for numpy arrays `!= None` is element-wise,
        # which makes the `if` raise "truth value of an array is ambiguous".
        if value is not None:
            getattr(art, field)(value)
    if isinstance(sci, str):
        art.ticklabel_format(style='sci', axis=sci, scilimits=(0, 0))
    if isinstance(colorbar, str):
        art.colorbar(label=colorbar)
    elif colorbar:
        art.colorbar(label=r'$Number\ of\ Entries$')

@contextmanager
def initiate_plot(dimx=24, dimy=9):
    """Create a new figure, run the ``with`` body, then show the figure.

    The figure is bound to the module-level name ``fig`` (helpers in this notebook
    add subplots through it). It is also yielded, so ``with initiate_plot() as f:``
    works as well.

    Parameters
    ----------
    dimx, dimy : float
        Figure width and height, written to ``plt.rcParams['figure.figsize']``.
        This is a global setting, so later figures created without this helper
        inherit the size.

    If the body raises, the half-drawn figure is closed instead of shown and the
    exception propagates.
    """
    plt.rcParams['figure.figsize'] = (dimx, dimy)
    global fig
    fig = plt.figure()
    try:
        yield fig
    except BaseException:
        # Don't leave a partially drawn figure open (it would leak or pop up later).
        plt.close(fig)
        raise
    plt.show()
    
### Plotting ####
# Global matplotlib style for every figure in this notebook.
# Each key appears exactly once: in a dict literal a repeated key silently keeps only
# its last value, so duplicates hide which setting is actually in effect.
params = {
    # NOTE(review): 'Agg' is a non-interactive backend; together with `%matplotlib inline`
    # its effect depends on the matplotlib version -- confirm it is wanted.
    'backend': 'Agg',
    # colormap
    'image.cmap' : 'viridis',
    # figure
    'figure.figsize' : (4, 2),
    'font.size' : 32,
    'font.family' : 'serif',
    'font.serif' : ['Times'],
    # axes
    'axes.titlesize' : 32,
    'axes.labelsize' : 32,
    'axes.linewidth' : 2,
    # ticks
    'xtick.labelsize' : 25,
    'ytick.labelsize' : 25,
    'xtick.major.size' : 16,
    'xtick.minor.size' : 8,
    'ytick.major.size' : 16,
    'ytick.minor.size' : 8,
    'xtick.major.width' : 2,
    'xtick.minor.width' : 2,
    'ytick.major.width' : 2,
    'ytick.minor.width' : 2,
    'xtick.direction' : 'in',
    'ytick.direction' : 'in',
    'xtick.major.pad' : 10,
    # markers / lines
    'lines.markersize' : 12,
    'lines.markeredgewidth' : 3,
    'errorbar.capsize' : 8,
    'lines.linewidth' : 3,
    'lines.marker' : None,
    # output
    'savefig.bbox' : 'tight',
    'legend.fontsize' : 24,
    # NOTE(review): 'text.latex.unicode' was deprecated in matplotlib 3.0; drop it when upgrading.
    'text.latex.unicode': True,
}
plt.rcParams.update(params)
plt.rc('text', usetex=False)
    
# Initialise hax only if it has no configuration yet (presumably from an earlier
# execution of this cell).
# NOTE(review): processed data below is pax v6.10.0 while the installed pax reports
# 6.10.1; 'loose' version policy presumably tolerates this -- confirm.
# NOTE(review): the '' entry in main_data_paths presumably means the current directory -- verify.
if not hax.config:
    hax.init(
        raw_data_access_mode='local',
        raw_data_local_path=['/project/lgrandi/xenon1t/raw_for_waveforms/'],
        main_data_paths=['/project2/lgrandi/xenon1t/processed/pax_v6.10.0/', ''],
        minitree_paths=['/scratch/midway2/zhut/data/SingleScatter/data/minitrees/'],
        pax_version_policy='loose',
        use_rundb=True,
        make_minitree=True,
        minitree_caching=False,
    )

# Run metadata table; used to look up a run's name from its number.
datasets = hax.runs.datasets

In [8]:
# Which run / event to inspect. `name` is the run's dataset name, used to locate
# its raw data folder.
run_number = 11797
event_number = 1
_run_names = datasets.loc[datasets.number == run_number, 'name']
if len(_run_names) == 0:
    # Fail with a clear message instead of an "index 0 is out of bounds" IndexError.
    raise ValueError('Run %d not found in hax.runs.datasets' % run_number)
name = _run_names.values[0]

In [9]:
from pax import core
# pax processor restricted to the single event `event_number` of raw run `name`.
# Output goes to DummyOutput with no pre-output or encoder plugin, so nothing is written;
# events are processed manually in the next cell.
core_processor = core.Processor(
    config_names=('_base', 'XENON1T'),
    just_testing=False,
    config_dict=dict(
        pax=dict(
            input_name='/project/lgrandi/xenon1t/raw_for_waveforms/{}'.format(name),
            events_to_process=[event_number],
            output='Dummy.DummyOutput',
            pre_output=[],
            encoder_plugin=None,
        ),
    ),
)

processor MainProcess L66 INFO This is PAX version 6.10.1, running with configuration for XENON1T.
ReadZipped MainProcess L102 INFO InputFromFolder: Selecting file /project/lgrandi/xenon1t/raw_for_waveforms/170804_0004/XENON1T-11797-000000000-000000999-000001000.zip (number 1/1 in folder) for reading


In [10]:
# Read the first selected raw event, push it through the configured pax plugins,
# and print the per-plugin timing report.
events = core_processor.get_events()
event = core_processor.process_event(next(events))
core_processor.make_timing_report(1)

processor MainProcess L372 INFO Timing report:
+------------------------------+-------+-------------+------+-----------+
| Plugin                       |     % | /event (ms) |  #/s | Total (s) |
+------------------------------+-------+-------------+------+-----------+
| ReadZipped                   |     0 |           0 |  n/a |       0.0 |
| DecodeZPickle                |   5.1 |        97.6 | 10.2 |       0.1 |
| SortPulses                   |   0.2 |         3.8 |      |       0.0 |
| PulseProperties              |   1.1 |        21.7 | 46.1 |       0.0 |
| CheckBoundsAndCount          |   0.6 |        10.7 | 93.1 |       0.0 |
| DesaturatePulses             |   5.7 |       109.7 |  9.1 |       0.1 |
| FindHits                     |   7.5 |       144.8 |  6.9 |       0.1 |
| HitfinderDiagnosticPlots     |   0.0 |         0.1 |      |       0.0 |
| SumWaveform                  |   6.6 |       127.6 |  7.8 |       0.1 |
| GapSizeClustering            |   6.4 |       124.1 |  8.1 |    

In [11]:
# NOTE: plot_2d_waveform / plot_hit_pattern are defined in later cells; run those
# first (or move this cell below them), otherwise a fresh kernel raises NameError.
# plot_2d_waveform fills the global `possible_interaction_peaks` consumed by the second call.
plot_2d_waveform(event=event, xlim=[990, 1100])
plot_hit_pattern(possible_interaction_peaks=possible_interaction_peaks)

NameError: name 'plot_2d_waveform' is not defined

In [253]:
from MultipleS2Peaks import MultipleS2Peaks

# Shared helper used by classify(): it supplies sample_duration, drift_velocity_liquid,
# drift_time_gate and determine_interaction().
ms2p = MultipleS2Peaks()

def classify(peak, s1, s2):
    """Label a pax peak as 's1', 'single_e', 's2' or 'e_train'.

    Parameters
    ----------
    peak : pax peak
        Peak to classify. Only ``peak.type`` 's1' and 's2' are handled; any other
        type returns None.
    s1, s2 : pax peak
        Main S1 and S2 of the event's first interaction. ``s1`` sets the drift-time
        origin; ``s2.area`` is passed on to the classifier.

    Returns
    -------
    's1' for S1 peaks. For S2 peaks: 'single_e' if area < 150. Otherwise the peak's
    depth, area, top-pattern-fit goodness of fit, 50%-area width and the main S2
    area are passed to ``ms2p.determine_interaction``. The result is 's2' if it is
    flagged as an interaction, 'e_train' otherwise.
    """
    if peak.type == 's1':
        return peak.type
    if peak.type != 's2':
        return None
    if peak.area < 150:
        return 'single_e'

    drift_time = (peak.index_of_maximum - s1.index_of_maximum) * ms2p.sample_duration
    z = -ms2p.drift_velocity_liquid * (drift_time - ms2p.drift_time_gate)

    # Goodness of fit of the top-pattern position fit (last matching result wins).
    # Peaks without a PosRecTopPatternFit result previously crashed with
    # UnboundLocalError; NaN is passed instead.
    # NOTE(review): confirm determine_interaction treats NaN as intended.
    gof = float('nan')
    for rp in peak.reconstructed_positions:
        if rp.algorithm == 'PosRecTopPatternFit':
            gof = rp.goodness_of_fit

    features = pd.DataFrame([dict(z=z,
                                  area=peak.area,
                                  goodness_of_fit_tpf=gof,
                                  range_50p_area=list(peak.range_area_decile)[5],
                                  s2=s2.area)])
    ans = ms2p.determine_interaction(features)
    return 'e_train' if ans.not_interaction.values[0] else 's2'

############## Plotting while processing #######################
def hit_color(hits):
    """Return an (n_hits, 4) float RGBA array colouring each hit for a scatter plot.

    Accepted hits run from blue (low) to red (high) with height / noise_sigma,
    saturating at 15. Rejected hits are pure green. Alpha is fixed at 0.75.
    """
    strength = np.clip(hits.height / hits.noise_sigma, 0, 15) / 15
    rejected = hits.is_rejected.astype(int)
    accepted = 1 - rejected

    red = accepted * strength
    green = rejected
    blue = accepted * (1 - strength)
    alpha = np.full(len(hits), 0.75)
    return np.column_stack([red, green, blue, alpha])

def plot_2d_waveform(event, xlim):
    """Plot the TPC sum waveform above a time-vs-channel scatter of all TPC hits.

    Also rebuilds the module-level list ``possible_interaction_peaks`` (entries are
    ``dict(type=label, peak=peak)``) with the S1/S2-labelled peaks inside ``xlim``.
    ``plot_hit_pattern`` consumes this list.

    Parameters
    ----------
    event : processed pax event
        Its first interaction's S1 and S2 are the reference peaks for ``classify``.
    xlim : sequence of two floats
        Time window ``[start, stop]`` in microseconds (sample index * 0.01). Only
        s1/s2-type peaks with area > 100 whose maximum lies inside it are shaded,
        labelled and collected.

    Returns
    -------
    0 if the event has no interactions (nothing is plotted), otherwise None.
    """
    global possible_interaction_peaks
    possible_interaction_peaks = []

    if len(event.interactions) < 1:
        print('No interaction in this event')
        return 0

    s1 = event.peaks[event.interactions[0].s1]
    s2 = event.peaks[event.interactions[0].s2]
    # Marker colour for each label classify() can return for s1/s2-type peaks.
    label_colors = dict(s1='C0', s2='C1', e_train='C2', single_e='C4')

    with initiate_plot(30, 15):
        # Hide ticks of the placeholder axes; two stacked axes are drawn inside its area.
        # Booleans instead of 'off': newer matplotlib treats the string 'off' as truthy.
        axm = plt.gca()
        axm.tick_params(axis='both', bottom=False, labelbottom=False, left=False, labelleft=False)
        pos = axm.get_position()
        top = [pos.x0, pos.y0 + 0.5 * pos.height, pos.width, pos.height * 0.5]
        bot = [pos.x0, pos.y0, pos.width, pos.height * 0.5]

        # ---- upper panel: sum waveform (symlog scale) ----
        axt = plt.axes(top)
        w = event.get_sum_waveform('tpc').samples[:]
        time = np.arange(len(w)) * 0.01
        ymax = np.max(w) ** 1.5

        plt.yscale('symlog')
        plt.plot(time, w, color='k', lw=2.0)

        plt_config(xlim=xlim, ylim=[-1, ymax], ylabel='Amplitude [pe/bin]')
        axt.tick_params(axis='x', bottom=False, labelbottom=False, left=False, labelleft=False)

        # ---- lower panel: one dot per hit, time vs channel ----
        axb = plt.axes(bot)
        # NOTE(review): presumably the boundary between top- and bottom-array channel numbers.
        plt.axhline(127, color='k', zorder=0)
        for peak in event.peaks:
            if peak.detector != 'tpc':
                continue

            hits = pd.DataFrame(peak.hits)
            # NOTE(review): 248 presumably = number of TPC PMT channels; area > 2 drops tiny hits.
            hits = hits[(hits.channel < 248) & (hits.area > 2)]
            axb.scatter(hits.index_of_maximum * 0.01, hits.channel,
                        c=hit_color(hits), edgecolor='none', s=10 * np.clip(hits.area, 0, 10))

            if not (peak.area > 100 and peak.type in ('s1', 's2')):
                continue
            x, y = peak.index_of_maximum * 0.01, w[peak.index_of_maximum]
            if x > xlim[1] or x < xlim[0]:
                continue

            axt.axvspan(peak.left * 0.01, peak.right * 0.01, color='grey', alpha=0.3)
            axb.axvspan(peak.left * 0.01, peak.right * 0.01, color='grey', alpha=0.3)

            label = classify(peak, s1, s2)
            # Random vertical offset so labels of neighbouring peaks overlap less.
            ytext = np.random.uniform(y, min(ymax, y * 5))
            axt.text(x, ytext, label)
            axt.scatter(x, y, color=label_colors[label])

            if label in ('s1', 's2'):
                possible_interaction_peaks.append(dict(type=label, peak=peak))

        plt_config(xlim=xlim, ylim=[0, 249], xlabel=r'Time [$\mu s$]', ylabel='PMT channel')

In [231]:
def _annotate_pmts(ax, pmt_locations, pmts_hit, peak):
    """Write each channel in ``pmts_hit`` at its PMT position; saturated ones larger and white."""
    for pmt in pmts_hit:
        saturated = peak.is_channel_saturated[pmt]
        ax.text(pmt_locations[pmt, 0], pmt_locations[pmt, 1], pmt,
                fontsize=8 if saturated else 6,
                va='center', ha='center',
                color='white' if saturated else 'black')


def plot_hit_pattern(possible_interaction_peaks):
    """Plot waveform, top-array and bottom-array hit patterns for each candidate peak.

    One row of three panels is drawn per entry of ``possible_interaction_peaks``
    (``dict(type=..., peak=...)`` as built by ``plot_2d_waveform``):
      1. the peak's slice of the TPC sum waveform, annotated with type and area;
      2. per-channel area on the top PMT array (log colour scale, 0.1 to 1e4);
      3. the same for the bottom PMT array.
    For S2 rows, the ``PosRecNeuralNet`` position is marked with an 'x' on that
    row's top-array panel. It is also marked on the top-array panel of the most
    recent S1 row above it, if there is one.

    Reads the globals ``event`` (sum waveform), ``core_processor`` (PMT map, TPC
    radius) and ``fig`` (created by ``initiate_plot``). Prints a message and returns
    without plotting when the list is empty.
    """
    n_plots = len(possible_interaction_peaks)
    if n_plots == 0:
        print('No possible interaction peaks to plot')
        return

    config = core_processor.config['DEFAULT']
    pmts = {array: config['channels_%s' % array] for array in ('top', 'bottom')}
    pmt_locations = np.array([[config['pmts'][ch]['position']['x'],
                               config['pmts'][ch]['position']['y']]
                              for ch in range(config['n_channels'])])
    tpc_radius = config['tpc_radius']
    hitpattern_limits = (1e-1, 1e4)
    # Limits go on the norm itself: passing vmin/vmax together with norm= is
    # rejected by newer matplotlib versions.
    scatter_plot_kwarg = dict(norm=LogNorm(vmin=hitpattern_limits[0], vmax=hitpattern_limits[1]),
                              alpha=0.4, s=300)

    with initiate_plot(20, 6 * n_plots):
        w = event.get_sum_waveform('tpc').samples[:]
        time = np.arange(len(w)) * 0.01
        # Top-array axes of the latest S1 row; S2 positions are overlaid on it.
        # (Previously this was unbound when an S2 row came before any S1 row -> crash.)
        ax_s1_top = None

        for p_i, p in enumerate(possible_interaction_peaks):
            t, peak = p['type'], p['peak']
            # Older matplotlib only accepts single-digit 'CN' colour codes.
            marker_color = 'C%d' % (p_i % 10)

            # ---- column 1: this peak's part of the sum waveform ----
            fig.add_subplot(n_plots, 3, p_i * 3 + 1)
            in_peak = slice(peak.left, peak.right + 1)
            mid = time[peak.index_of_maximum]
            half_span = max(mid - time[peak.left], time[peak.right] - mid)
            plt.plot(time[in_peak], w[in_peak], color='k')
            plt.text(mid, w[peak.index_of_maximum] * 1.1, t + ':%.1f' % peak.area, ha='center')
            plt_config(ylim=[-10, np.max(w[in_peak]) * 1.25],
                       xlim=[mid - half_span, mid + half_span])

            # ---- column 2: top PMT array ----
            ax_top = fig.add_subplot(n_plots, 3, p_i * 3 + 2)
            ax_top.axis('off')
            ax_top.add_artist(plt.Circle((0, 0), tpc_radius, edgecolor='black', fill=None))
            if t == 's1':
                ax_s1_top = ax_top

            pmts_hit = [ch for ch in pmts['top'] if peak.does_channel_contribute[ch]]
            ax_top.scatter(*pmt_locations[pmts_hit].T, c=peak.area_per_channel[pmts_hit],
                           **scatter_plot_kwarg)
            _annotate_pmts(ax_top, pmt_locations, pmts_hit, peak)

            if t == 's2':
                for rp in peak.reconstructed_positions:
                    if rp.algorithm != 'PosRecNeuralNet':
                        continue
                    for target in (ax_top, ax_s1_top):
                        if target is not None:
                            target.plot([rp.x], [rp.y], marker='x', color=marker_color,
                                        alpha=0.8, markersize=20, markeredgewidth=3)
            plt_config(xlim=[-53, 53], ylim=[-53, 53])

            # ---- column 3: bottom PMT array ----
            ax_bottom = fig.add_subplot(n_plots, 3, p_i * 3 + 3)
            pmts_hit = [ch for ch in pmts['bottom'] if peak.does_channel_contribute[ch]]
            ax_bottom.scatter(*pmt_locations[pmts_hit].T, c=peak.area_per_channel[pmts_hit],
                              **scatter_plot_kwarg)
            ax_bottom.axis('off')
            ax_bottom.add_artist(plt.Circle((0, 0), tpc_radius, edgecolor='black', fill=None))
            plt_config(xlim=[-53, 53], ylim=[-53, 53])
            _annotate_pmts(ax_bottom, pmt_locations, pmts_hit, peak)

In [None]:
from sklearn.mixture import GaussianMixture

# Three-component Gaussian mixture in (log10(area), normalised width) space;
# means_init seeds the components at hand-picked cluster centres.
means_init = [(5, 1), (4.3, 1.8), (3.9, 0.25)]

# Fixed random_state: the k-means initialisation of weights/covariances is random,
# so without it each re-run can yield a different fit.
gmix = GaussianMixture(n_components=3, covariance_type='full', means_init=means_init,
                       max_iter=200, random_state=0)
# NOTE(review): df_ms_peaks is not defined in any visible cell; it must come from an
# earlier (deleted or unsaved) cell -- add it so Restart & Run All works.
trsa = df_ms_peaks

def s2_width_model(z):
    """Expected S2 width at interaction depth ``z`` from a diffusion-broadening model.

        width(z) = 1.349 * sqrt(w0**2 - 2 * coeff * dif_const * z / v**3)

    ``z`` is negative below the gate, so the second term adds to the initial
    spread ``w0``. Works element-wise on numpy arrays / pandas Series and returns
    NaN where the radicand is negative.

    NOTE(review): 1.349 presumably converts a Gaussian sigma into its central
    50%-area width (2 * 0.6745 sigma), matching ``range_50p_area``.
    """
    sigma_0 = 229.58          # initial width, 309.7 / 1.349
    diffusion_scale = 0.925
    diffusion_const = 31.73
    drift_velocity = .1335
    radicand = square(sigma_0) - 2 * diffusion_scale * diffusion_const * z / drift_velocity ** 3
    return 1.349 * sqrt(radicand)

# Features: log10(area) and the 50%-area width normalised by the model width at the peak's depth.
X = np.column_stack([np.log10(trsa.area), trsa.range_50p_area / s2_width_model(trsa.z)])
gmix.fit(X)


with initiate_plot(20, 12):
    labels_ = gmix.predict(X)
    # Peaks split by predicted mixture component (kept as type0/1/2 for later cells).
    type0, type1, type2 = [trsa[labels_ == k] for k in range(3)]

    # Plot each component directly rather than eval()-ing string-built expressions.
    for comp_i, comp in enumerate([type0, type1, type2]):
        plt.scatter(comp.area, comp.range_50p_area / s2_width_model(comp.z),
                    edgecolor='none', s=30, color='C%d' % comp_i, alpha=0.3)

    plt_config(xlim=[10**2.3, 1e7], ylim=[0, 2.5],
               xlabel='Area [pe]', ylabel='range_50p_area / s2_width_model(z)')
    plt.xscale('log');


In [None]:
# Persist the fitted mixture model. Disabled by default so Restart & Run All does not
# overwrite the shared file; set SAVE_GMIX_CLASSIFIER = True to write it.
SAVE_GMIX_CLASSIFIER = False
if SAVE_GMIX_CLASSIFIER:
    import pickle
    with open('/project2/lgrandi/zhut/s2_width_classifier_gmix_v6.10.0.pkl', 'wb') as f:
        pickle.dump(gmix, f)