In [None]:
import sys
sys.path.append('/Users/theoglauch/Documents/PhD/DeepIceLearning/lib/')
from model_parse import parse_functional_model
import numpy as np
from icecube import icetray
from I3Tray import I3Tray
from configparser import ConfigParser
from functions_create_dataset import *
from collections import OrderedDict
import time
from icecube.dataclasses import I3MapStringDouble
from icecube import dataclasses


class DeepLearningClassifier(icetray.I3ConditionalModule):
    """IceTray module that classifies P-frames with a trained neural network.

    For every physics frame the module:

    1. fills one zero-initialised input array per network branch;
    2. fills in per-DOM features computed from the pulses of the configured
       pulsemap;
    3. runs the model;
    4. stores the five class scores in an ``I3MapStringDouble`` under the
       ``save_as`` key.

    Parameters
    ----------
    pulsemap : str
        Frame key of the pulse series map (or mask) to read pulses from.
    save_as : str
        Frame key under which the classification map is written.
    """

    # Order of the network's output vector: index i of prediction[0] is
    # written to the frame under _CLASS_LABELS[i].
    _CLASS_LABELS = ('Skimming', 'Cascade', 'Through_Going_Track',
                     'Starting_Track', 'Stopping_Track')

    def __init__(self, context):
        icetray.I3ConditionalModule.__init__(self, context)
        self.AddParameter("pulsemap", "Define the name of the pulsemap",
                          "InIceDSTPulses")
        self.AddParameter("save_as", "Define the Output key",
                          "Deep_Learning_Classification")
        print('Init Deep Learning Classifier..this may take a while')

    def Configure(self):
        """Load the trained network and prepare the per-branch feature recipes.

        IceTray calls this once, before any frame is processed. It reads the
        following files from the current working directory:

        * ``run_info.npy`` -- input/output shapes and transformations;
        * ``grid.npy`` -- mapping from (string, om) to grid position;
        * ``model.py`` -- the network definition (imported as a module);
        * ``weights.npy`` -- the trained weights;
        * ``config.cfg`` -- feature expressions, in the ``Input_Times`` and
          ``Input_Charges`` sections.

        Raises
        ------
        IOError
            If ``./config.cfg`` cannot be read.
        """
        # Read parameters once here instead of on every frame.
        self._pulsemap = self.GetParameter("pulsemap")
        self._save_as = self.GetParameter("save_as")

        # NOTE(review): allow_pickle=True unpickles these files, so they must
        # be trusted, locally produced training artefacts.
        self._runinfo = np.load('./run_info.npy', allow_pickle=True)[()]
        self._grid = np.load('./grid.npy', allow_pickle=True)[()]
        self._inp_shapes = self._runinfo['inp_shapes']
        self._out_shapes = self._runinfo['out_shapes']
        self._inp_trans = self._runinfo['inp_trans']
        self._out_trans = self._runinfo['out_trans']

        import model as func_model_def
        self._model = func_model_def.model(self._inp_shapes, self._out_shapes)
        self._model.load_weights('./weights.npy')

        dataset_configparser = ConfigParser()
        # ConfigParser.read silently skips missing files. Fail with a clear
        # message rather than a KeyError on the first section access.
        if not dataset_configparser.read('./config.cfg'):
            raise IOError('Could not read ./config.cfg')
        inp_defs = dict()
        inp_defs.update(dataset_configparser['Input_Times'])
        inp_defs.update(dataset_configparser['Input_Charges'])

        # self._inputs[branch] is a list of (compiled_expression, transform)
        # pairs. The list index is the feature index in the last grid axis.
        self._inputs = []
        for branch_key in self._inp_shapes.keys():
            branch_inputs = []
            for feat_key in self._inp_shapes[branch_key].keys():
                if feat_key == 'general':
                    # 'general' holds the branch's array shape, not a feature.
                    continue
                if 'charge_quantile' in feat_key:
                    # The 4th '_'-separated token holds the quantile digits,
                    # e.g. '..._5' -> 0.5.
                    quantile = float('0.' + feat_key.split('_')[3])
                    feature = 'pulses_quantiles(charges, times, {})'.format(
                        quantile)
                else:
                    feature = inp_defs[feat_key.replace('IC_', '')]
                # Compile once so Physics does not re-parse the expression for
                # every DOM of every frame.
                # NOTE(review): these expressions are eval'ed and come from
                # config.cfg, so that file must be trusted.
                code = compile(feature, '<feature {}>'.format(feat_key), 'eval')
                branch_inputs.append((code, self._inp_trans[branch_key][feat_key]))
            self._inputs.append(branch_inputs)

    def Physics(self, frame):
        """Classify one P-frame and write the class scores into it."""
        timer_t0 = time.time()
        t0 = get_t0(frame)
        # from_frame accepts both a plain pulse map and a pulse-map mask.
        pulses = dataclasses.I3RecoPulseSeriesMap.from_frame(frame,
                                                             self._pulsemap)
        f_slice = [np.zeros(self._inp_shapes[branch_key]['general'])
                   for branch_key in self._inp_shapes.keys()]
        for omkey in pulses.keys():
            dom = (omkey.string, omkey.om)
            if dom not in self._grid:
                # This DOM is not part of the network's input grid.
                continue
            gpos = self._grid[dom]
            dom_pulses = pulses[omkey]
            # charges/times/widths are referenced by name from the compiled
            # feature expressions evaluated below; keep these local names.
            charges = np.array([p.charge for p in dom_pulses])
            times = np.array([p.time for p in dom_pulses]) - t0
            widths = np.array([p.width for p in dom_pulses])
            for branch_c, inp_branch in enumerate(self._inputs):
                for inp_c, (code, trans) in enumerate(inp_branch):
                    f_slice[branch_c][gpos[0], gpos[1], gpos[2], inp_c] = \
                        trans(eval(code))
        # NOTE(review): np.array(f_slice, ndmin=5) gives a (1, ...) batch only
        # for a single input branch. With several branches it stacks them
        # along axis 0 instead of passing a list of inputs -- confirm against
        # the model definition.
        prediction = self._model.predict(np.array(f_slice, ndmin=5),
                                         batch_size=None, verbose=0,
                                         steps=None)
        output = I3MapStringDouble()
        for idx, label in enumerate(self._CLASS_LABELS):
            output[label] = float(prediction[0][idx])
        frame.Put(self._save_as, output)
        print('Time {:.2f}'.format(time.time() - timer_t0))
        print(prediction)
        self.PushFrame(frame)

    def DAQ(self, frame):
        """Pass Q-frames through unchanged."""
        self.PushFrame(frame)

    def Finish(self):
        """No cleanup is needed; nothing is held open by this module."""
        pass

In [None]:
# Build and run a minimal tray: read one i3 file and classify every P-frame.
input_files = ['/Users/theoglauch/Documents/PhD/I3Files/diffuse/11069_2289.i3.bz2']

tray = I3Tray()
tray.AddModule('I3Reader', 'reader', FilenameList=input_files)
tray.AddModule(DeepLearningClassifier, "DeepLearningClassifier")

tray.Execute()
tray.Finish()