In [1]:
import os
import sys
import gin
import numpy as np
import pandas as pd
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
    
from IPython.core.display import clear_output, display

import matplotlib.pyplot as plt

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [53]:
import logging
logging.getLogger().setLevel(logging.DEBUG)
    
from eval.event_evaluation import EventEvaluator
from ariadne_v2.transformations import Compose, ConstraintsNormalize, ToCylindrical, DropSpinningTracks, DropShort, DropEmpty

parse_cfg = {
    'csv_params' : {
        "sep": '\s+',
        #"nrows": 15000,
        "encoding": 'utf-8',
        "names": ['event',  'x', 'y', 'z', 'station', 'track', 'px', 'py', 'pz', 'X0', 'Y0', 'Z0']
    },
    'input_file_mask': "C:\\Users\\egor\\dubna\\ariadne\\data_bes3\\3.txt",
    'events_quantity':'0..100'
}

global_transformer = Compose([
        DropSpinningTracks(),
        DropShort(num_stations=3),
        DropEmpty()
    ])

In [78]:
import scripts.clean_cache

#to clean cache if needed
#scripts.clean_cache.clean_jit_cache('1h')

# GraphNet

In [79]:
from ariadne.graph_net.graph_utils.graph_prepare_utils import to_pandas_graph_from_df, get_pd_line_graph, \
    apply_nodes_restrictions, apply_edge_restriction, construct_output_graph
from ariadne.transformations import Compose, ConstraintsNormalize, ToCylindrical

from ariadne_v2.inference import IModelLoader

import torch

class GraphModelLoader(IModelLoader):    
    def __call__(self):
        from ariadne.graph_net.model import GraphNet_v1
        import torch
        
        gin.bind_parameter('GraphNet_v1.input_dim', 5)
        gin.bind_parameter('GraphNet_v1.hidden_dim', 128)
        gin.bind_parameter('GraphNet_v1.n_iters', 1)
        
        def weights_update_g(model, checkpoint):
            model_dict = model.state_dict()    
            pretrained_dict =  checkpoint['state_dict']
            real_dict = {}
            for (k,v) in model_dict.items():
                needed_key = None
                for pretr_key in pretrained_dict:
                    if k in pretr_key:
                        needed_key = pretr_key
                        break
                assert needed_key is not None, "key %s not in pretrained_dict %r!" % (k, pretrained_dict.keys())
                real_dict[k] = pretrained_dict[needed_key]
        
            model.load_state_dict(real_dict)
            model.eval()
            return model
        
        #path_g = '/zfs/hybrilit.jinr.ru/user/g/gooldan/bes/ariadne/lightning_logs/version_63115/checkpoints/epoch=49.ckpt'
        #path_g = '/Users/egor/prog/dubna/models2606/zfs 4/hybrilit.jinr.ru/user/g/gooldan/bes/ariadne/lightning_logs/version_63115/checkpoints/epoch=49.ckpt'
        path_g = 'C:\\Users\\egor\\dubna\\ariadne\\lightning_logs\\GraphNet_v1\\version_32\\epoch=201-step=271689.ckpt'
        
        checkpoint_g = torch.load(path_g) if torch.cuda.is_available() else torch.load(path_g, map_location=torch.device('cpu'))
        model_g = weights_update_g(model=GraphNet_v1(), 
                           checkpoint=checkpoint_g)
        model_hash = {"path_g":path_g, 'gin':gin.config_str(), 'model': '%r' % model_g}
        return model_hash, model_g

suff_df = ('_p', '_c')
gin.bind_parameter('get_pd_line_graph.restrictions_0',(-0.07, 0.07))
gin.bind_parameter('get_pd_line_graph.restrictions_1', (-0.32, 0.32))
gin.bind_parameter('get_pd_line_graph.suffix_c', '_c')
gin.bind_parameter('get_pd_line_graph.suffix_p','_p')
gin.bind_parameter('get_pd_line_graph.spec_kwargs', {'suffix_c': '_c', 
                                                     'suffix_p':'_p', 
                                                     'axes':['r', 'phi', 'z']} )
_edge_restriction = 0.16

from collections import namedtuple
GraphWithIndices = namedtuple('Graph', ['X', 'Ri', 'Ro', 'y', 'v1v2v3', 'ev_id' ])

transformer_g = Compose([
    DropSpinningTracks(),
    DropShort(),
    DropEmpty(),
    ToCylindrical(),
    ConstraintsNormalize(
        columns=('r', 'phi', 'z'),
        constraints = {'phi': [-3.15, 3.15], 'r': [80.0, 167.0], 'z': [-423.5, 423.5]},
        use_global_constraints = True
    ),
])


def construct_graph_with_indices(graph, v1v2v3, ev_id):
    return GraphWithIndices(graph.X, graph.Ri, graph.Ro, graph.y, v1v2v3, ev_id)


def get_graph(event):
    event = event[['event','x','y','z','station','track', 'index_old']]
    
    try:
        event = transformer_g(event)
    except AssertionError as err:
        print("ASS error %r" % err)
        return None
    
    event.index = event['index_old'].values
    event =  event[['event','r','phi','z','station','track']]
    
    G = to_pandas_graph_from_df(event, suffixes=suff_df, compute_is_true_track=True)

    
    nodes_t, edges_t = get_pd_line_graph(G, apply_nodes_restrictions)

    edges_filtered = apply_edge_restriction(edges_t, edge_restriction=_edge_restriction)
    graph = construct_output_graph(nodes_t, edges_filtered, ['y_p', 'y_c', 'z_p', 'z_c', 'z'],
                                     [np.pi, np.pi, 1., 1., 1.], 'edge_index_p', 'edge_index_c')
    ev_id = event.event.values[0]
    graph_with_inds = construct_graph_with_indices(graph,
                                                   edges_filtered[['from_ind', 'cur_ind', 'to_ind']].values, ev_id)
    return graph_with_inds

from ariadne.graph_net.dataset import collate_fn
def eval_event(tgt_graph, model_g):
    batch_input, batch_target = collate_fn([tgt_graph])
    with torch.no_grad():
        y_pred = model_g(batch_input['inputs']).numpy().flatten() > 0.5
    
    eval_df = pd.DataFrame(columns=['track_pred', 'hit_id_0', 'hit_id_1', 'hit_id_2'])
    eval_df['track_pred'] = y_pred
    eval_df[['hit_id_0', 'hit_id_1', 'hit_id_2']] = tgt_graph.v1v2v3
    return eval_df


In [134]:
N_STATIONS = 3

evaluator = EventEvaluator(parse_cfg, global_transformer, N_STATIONS)
events = evaluator.prepare(model_loader=GraphModelLoader())[0]
all_tracks, all_events = evaluator.build_all_tracks()
reco_tracks, reco_events = evaluator.run_model(get_graph, eval_event)


read entry b235ba1f4d43a05f0736916b7e3f7bb4 hit
[prepare]: started processing a df 3.txt with 4601 rows:
read entry 013a46cb50fd316ba73140aa64898d77 hit
[prepare] finished
[prepare] loading your model(s)...
[prepare] finished loading your model(s)...
[build_all_tracks] start
read entry f073cf26fd4ba34f3ee42aee72ca10be hit
read entry 7ca8839d25ac3d884b5544ad8b8ae9c5 hit
[build_all_tracks] cache hit, finish
[run model] start
read entry fed44061f32622f56226346ef247293e hit
read entry 036c49b4b9637290fe40c0d2b93991fb hit
[run model] cache hit, finish


In [135]:
all_tracks

Unnamed: 0,event_id,track,px,py,pz,pred,hit_id_0,hit_id_1,hit_id_2
0,0,11,-0.193425,-0.004823,-0.029994,0,8,9,10
0,0,22,0.342074,-0.229190,0.254322,0,11,12,13
0,0,23,-0.258273,-0.097875,-0.528234,0,14,15,16
0,1,11,0.094377,0.492997,-0.320608,0,76,77,78
0,1,12,-0.302329,0.193706,0.509225,0,79,80,81
...,...,...,...,...,...,...,...,...,...
0,97,18,0.213483,0.201679,0.224834,0,4558,4559,4560
0,98,8,0.260138,-0.645867,0.196305,0,4589,4590,4591
0,98,9,-0.325851,0.927250,-0.192919,0,4592,4593,4594
0,98,11,0.614087,-0.697357,0.039403,0,4595,4596,4597


In [137]:
reco_tracks

Unnamed: 0,event_id,track_pred,hit_id_0,hit_id_1,hit_id_2
0,0,True,8,9,10
1,0,False,8,9,7
2,0,False,1,5,7
3,0,True,11,12,13
4,0,True,14,15,16
...,...,...,...,...,...
35,98,False,4561,4599,4580
36,98,False,4598,4568,4600
37,98,False,4598,4568,4580
38,98,False,4561,4568,4600


In [117]:
reco_tracks

True

In [138]:
recall_results = pd.merge(all_tracks, reco_tracks,  how='left', on=['event_id', 'hit_id_0','hit_id_1', 'hit_id_2'])

In [139]:
recall_results[recall_results.track_pred == True]

Unnamed: 0,event_id,track,px,py,pz,pred,hit_id_0,hit_id_1,hit_id_2,track_pred
0,0,11,-0.193425,-0.004823,-0.029994,0,8,9,10,True
1,0,22,0.342074,-0.229190,0.254322,0,11,12,13,True
2,0,23,-0.258273,-0.097875,-0.528234,0,14,15,16,True
3,1,11,0.094377,0.492997,-0.320608,0,76,77,78,True
4,1,12,-0.302329,0.193706,0.509225,0,79,80,81,True
...,...,...,...,...,...,...,...,...,...,...
395,97,18,0.213483,0.201679,0.224834,0,4558,4559,4560,True
396,98,8,0.260138,-0.645867,0.196305,0,4589,4590,4591,True
397,98,9,-0.325851,0.927250,-0.192919,0,4592,4593,4594,True
398,98,11,0.614087,-0.697357,0.039403,0,4595,4596,4597,True


In [140]:
reco_tracks['pred'] = -1

In [141]:
reco_tracks

Unnamed: 0,event_id,track_pred,hit_id_0,hit_id_1,hit_id_2,pred
0,0,True,8,9,10,-1
1,0,False,8,9,7,-1
2,0,False,1,5,7,-1
3,0,True,11,12,13,-1
4,0,True,14,15,16,-1
...,...,...,...,...,...,...
35,98,False,4561,4599,4580,-1
36,98,False,4598,4568,4600,-1
37,98,False,4598,4568,4580,-1
38,98,False,4561,4568,4600,-1


In [142]:
precision_results = pd.merge(reco_tracks, all_tracks, how='left', on=['event_id', 'hit_id_0','hit_id_1', 'hit_id_2'])

In [128]:
precision_results[(precision_results.track_pred == True) & precision_results.pred != 0.0]

Unnamed: 0,event_id,track_pred,hit_id_0,hit_id_1,hit_id_2,track,px,py,pz,pred


In [149]:
precision_results[precision_results.pred_y == 0.0, 'pred_x']=precision_results[precision_results.pred_y == 0.0, 'pred_y'].pred_y.values

ValueError: Length of values (397) does not match length of index (6820)

In [152]:
precision_results.loc[precision_results.pred_y == 0.0, 'pred_x'] = precision_results.loc[precision_results.pred_y == 0.0, 'pred_y'] 

In [155]:
precision_results[(precision_results.track_pred == True) & (precision_results.pred_x != 0.0)]

Unnamed: 0,event_id,track_pred,hit_id_0,hit_id_1,hit_id_2,pred_x,track,px,py,pz,pred_y
18,1,True,82,38,60,-1.0,,,,,
221,6,True,163,225,226,-1.0,,,,,
289,8,True,282,262,268,-1.0,,,,,
303,10,True,335,336,345,-1.0,,,,,
377,13,True,474,434,452,-1.0,,,,,
...,...,...,...,...,...,...,...,...,...,...,...
6606,95,True,4374,4391,4420,-1.0,,,,,
6652,96,True,4454,4465,4480,-1.0,,,,,
6767,97,True,4496,4509,4527,-1.0,,,,,
6769,97,True,4558,4510,4528,-1.0,,,,,
