In [1]:
import os
import sys
import gin
import numpy as np
import pandas as pd
# Put the parent of the current working directory on sys.path (once) so that
# packages living one level above the notebook can be imported
# (presumably the project packages used below, e.g. `ariadne`, `eval`, `scripts`).
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

# `IPython.display` is the public location of these helpers; importing them
# from `IPython.core.display` is deprecated.
from IPython.display import clear_output, display

import matplotlib.pyplot as plt

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [125]:
import logging
logging.getLogger().setLevel(logging.DEBUG)
    
from eval.event_evaluation import EventEvaluator
from ariadne_v2.transformations import Compose, ConstraintsNormalize, ToCylindrical, DropSpinningTracks, DropShort, DropEmpty

# Parsing configuration passed to EventEvaluator further down.
parse_cfg = {
    # Keyword arguments for the CSV reader (presumably pandas.read_csv — confirm).
    'csv_params' : {
        # Raw string: '\s' in a plain string literal is an invalid escape sequence.
        "sep": r'\s+',
        #"nrows": 15000,
        "encoding": 'utf-8',
        "names": ['event',  'x', 'y', 'z', 'station', 'track', 'px', 'py', 'pz', 'X0', 'Y0', 'Z0']
    },
    # NOTE(review): absolute, machine-specific path — adjust for your environment.
    #'input_file_mask': "C:\\Users\\egor\\dubna\\ariadne\\data_bes3\\3.txt",
    'input_file_mask': "/Users/egor/prog/dubna/ariadne/data/bes3/events/3.txt",
    'events_quantity':':'
}

# Transformation pipeline handed to EventEvaluator together with parse_cfg.
# The steps are applied through Compose in the order listed here.
# NOTE(review): DropShort is given num_stations=3 here, while transformer_g
# further down calls DropShort() with its default — confirm both agree.
global_transformer = Compose([
        DropSpinningTracks(),
        DropShort(num_stations=3),
        DropEmpty()
    ])

In [79]:
import scripts.clean_cache

# Optional maintenance step: run this cell only when the cache has to be cleaned.
# The recorded output shows an interactive [Y/N] confirmation followed by the
# deleted cache paths.
# NOTE(review): the meaning of the '2h' argument is not visible here
# (presumably an age threshold) — confirm against scripts.clean_cache.
scripts.clean_cache.clean_jit_cache('2h')

OK with that [Y/N]? Y
deleting path df/682cbbad00e5204cedc1681993c97d66
deleting path dc/95dbdc36c370159d948190a9d82d95a8
deleting path df/2cce5da32854c4132352aa3d434333f8
deleting path df/25a62677bf6b9fd3fdfd1ee6cb3ada1a
deleting path df/bc4a38d0aa7cd8744f1b0c84bae07823
deleting path df/16f48a1ad69c3304f44d065c307fbc93


# GraphNet

In [126]:
from ariadne.graph_net.graph_utils.graph_prepare_utils import to_pandas_graph_from_df, get_pd_line_graph, \
    apply_nodes_restrictions, apply_edge_restriction, construct_output_graph
from ariadne.transformations import Compose, ConstraintsNormalize, ToCylindrical

from ariadne_v2.inference import IModelLoader

import torch

class GraphModelLoader(IModelLoader):
    """Builds a ``GraphNet_v1`` model and fills it with weights from a checkpoint.

    Calling an instance returns ``(model_hash, model)``: ``model_hash`` is a dict
    with the checkpoint path, the current gin config string and the model repr;
    ``model`` is the network in eval mode.
    """

    def __call__(self):
        from ariadne.graph_net.model import GraphNet_v1
        import torch

        # Architecture hyper-parameters of the network being instantiated.
        gin.bind_parameter('GraphNet_v1.input_dim', 5)
        gin.bind_parameter('GraphNet_v1.hidden_dim', 128)
        gin.bind_parameter('GraphNet_v1.n_iters', 1)

        def find_checkpoint_key(model_key, pretrained_dict):
            """Return the checkpoint key matching ``model_key`` or ``None``.

            Preference order: exact match, then a key ending in ``.<model_key>``
            (checkpoint keys that carry an extra prefix), and only then the first
            key that merely contains ``model_key``. The last step is the previous
            behaviour, kept as a fallback; on its own it can pick an unrelated
            tensor whose name happens to contain ``model_key``.
            """
            if model_key in pretrained_dict:
                return model_key
            dotted_suffix = '.' + model_key
            for candidate in pretrained_dict:
                if candidate.endswith(dotted_suffix):
                    return candidate
            for candidate in pretrained_dict:
                if model_key in candidate:
                    return candidate
            return None

        def weights_update_g(model, checkpoint):
            """Load ``checkpoint['state_dict']`` into ``model`` and switch it to eval mode.

            Raises AssertionError when a model parameter has no counterpart
            in the checkpoint.
            """
            pretrained_dict = checkpoint['state_dict']
            real_dict = {}
            for k in model.state_dict():
                needed_key = find_checkpoint_key(k, pretrained_dict)
                assert needed_key is not None, "key %s not in pretrained_dict %r!" % (k, pretrained_dict.keys())
                real_dict[k] = pretrained_dict[needed_key]

            model.load_state_dict(real_dict)
            model.eval()
            return model

        # NOTE(review): absolute, machine-specific checkpoint locations.
        #path_g = '/zfs/hybrilit.jinr.ru/user/g/gooldan/bes/ariadne/lightning_logs/version_63115/checkpoints/epoch=49.ckpt'
        path_g = '/Users/egor/prog/dubna/ariadne/version_32/epoch=201-step=271689.ckpt'
        #path_g = 'C:\\Users\\egor\\dubna\\ariadne\\lightning_logs\\GraphNet_v1\\version_32\\epoch=201-step=271689.ckpt'

        # The model is never moved off the CPU in this loader, so the checkpoint
        # is always mapped to CPU, whether or not CUDA is available.
        checkpoint_g = torch.load(path_g, map_location=torch.device('cpu'))
        model_g = weights_update_g(model=GraphNet_v1(),
                                   checkpoint=checkpoint_g)
        model_hash = {"path_g":path_g, 'gin':gin.config_str(), 'model': '%r' % model_g}
        return model_hash, model_g

# Suffix pair forwarded to to_pandas_graph_from_df as `suffixes` in get_graph.
suff_df = ('_p', '_c')

# gin bindings configuring get_pd_line_graph; bound one by one in the order listed.
# NOTE(review): the numeric restriction windows are not explained in this
# notebook — confirm their meaning against get_pd_line_graph.
for _gin_name, _gin_value in (
    ('get_pd_line_graph.restrictions_0', (-0.07, 0.07)),
    ('get_pd_line_graph.restrictions_1', (-0.32, 0.32)),
    ('get_pd_line_graph.suffix_c', '_c'),
    ('get_pd_line_graph.suffix_p', '_p'),
    ('get_pd_line_graph.spec_kwargs', {'suffix_c': '_c',
                                       'suffix_p': '_p',
                                       'axes': ['r', 'phi', 'z']}),
):
    gin.bind_parameter(_gin_name, _gin_value)

# Value passed as `edge_restriction` to apply_edge_restriction in get_graph.
_edge_restriction = 0.16

from collections import namedtuple

# Graph record extended with the hit-index triplets (`v1v2v3`) and the event id.
# The tuple's type name is 'Graph', even though it is bound to GraphWithIndices.
GraphWithIndices = namedtuple(typename='Graph',
                              field_names='X Ri Ro y v1v2v3 ev_id')

# Per-event preprocessing used by get_graph: drop steps first, then conversion to
# cylindrical coordinates, then normalization of the r / phi / z columns with the
# fixed constraint ranges given below.
# NOTE(review): in this cell Compose, ConstraintsNormalize and ToCylindrical are
# re-imported from `ariadne.transformations`, while DropSpinningTracks, DropShort
# and DropEmpty come from `ariadne_v2.transformations` — confirm the mix is intended.
# NOTE(review): DropShort() uses its default here, whereas global_transformer
# passes num_stations=3 — confirm both agree.
transformer_g = Compose([
    DropSpinningTracks(),
    DropShort(),
    DropEmpty(),
    ToCylindrical(),
    ConstraintsNormalize(
        columns=('r', 'phi', 'z'),
        constraints = {'phi': [-3.15, 3.15], 'r': [80.0, 167.0], 'z': [-423.5, 423.5]},
        use_global_constraints = True
    ),
])


def construct_graph_with_indices(graph, v1v2v3, ev_id):
    """Copy X, Ri, Ro and y from ``graph`` into a GraphWithIndices, adding ``v1v2v3`` and ``ev_id``."""
    return GraphWithIndices(
        X=graph.X,
        Ri=graph.Ri,
        Ro=graph.Ro,
        y=graph.y,
        v1v2v3=v1v2v3,
        ev_id=ev_id,
    )


def get_graph(event):
    """Build the model input graph for a single event.

    Args:
        event: DataFrame with at least the columns
            'event', 'x', 'y', 'z', 'station', 'track' and 'index_old'.

    Returns:
        A GraphWithIndices carrying the graph arrays, the
        ('from_ind', 'cur_ind', 'to_ind') values of the surviving edges and the
        event id, or None when ``transformer_g`` raises an AssertionError
        (the error is printed).
    """
    hits = event[['event','x','y','z','station','track', 'index_old']]

    try:
        hits = transformer_g(hits)
    except AssertionError as err:
        print("ASS error %r" % err)
        return None

    # Re-index the transformed hits by their 'index_old' values, then keep
    # only the columns needed for graph construction.
    hits.index = hits['index_old'].values
    hits = hits[['event','r','phi','z','station','track']]

    pd_graph = to_pandas_graph_from_df(hits, suffixes=suff_df, compute_is_true_track=True)
    nodes_t, edges_t = get_pd_line_graph(pd_graph, apply_nodes_restrictions)
    edges_filtered = apply_edge_restriction(edges_t, edge_restriction=_edge_restriction)

    graph = construct_output_graph(
        nodes_t,
        edges_filtered,
        ['y_p', 'y_c', 'z_p', 'z_c', 'z'],
        [np.pi, np.pi, 1., 1., 1.],
        'edge_index_p',
        'edge_index_c',
    )

    ev_id = hits.event.values[0]
    triplet_indices = edges_filtered[['from_ind', 'cur_ind', 'to_ind']].values
    return construct_graph_with_indices(graph, triplet_indices, ev_id)

from ariadne.graph_net.dataset import collate_fn
def eval_event(tgt_graph, model_g, threshold=0.5):
    """Run the model on one graph and return its per-candidate decisions.

    Args:
        tgt_graph: graph tuple as produced by ``get_graph``; its ``v1v2v3``
            field supplies the three hit ids of every candidate.
        model_g: callable model, applied to the collated ``'inputs'`` batch.
        threshold: model outputs strictly above this value are marked as
            predicted tracks (default 0.5, the previously hard-coded value).

    Returns:
        DataFrame with columns 'track_pred' (bool), 'hit_id_0', 'hit_id_1'
        and 'hit_id_2', one row per candidate.
    """
    # The collated target is not needed for inference.
    batch_input, _ = collate_fn([tgt_graph])
    with torch.no_grad():
        scores = model_g(batch_input['inputs'])
    # .cpu() is a no-op for CPU tensors and keeps .numpy() valid if the model
    # output ever lives on another device.
    y_pred = scores.cpu().numpy().flatten() > threshold

    eval_df = pd.DataFrame(columns=['track_pred', 'hit_id_0', 'hit_id_1', 'hit_id_2'])
    eval_df['track_pred'] = y_pred
    eval_df[['hit_id_0', 'hit_id_1', 'hit_id_2']] = tgt_graph.v1v2v3
    return eval_df


In [None]:
# Third positional argument of EventEvaluator (presumably the number of detector
# stations — confirm against EventEvaluator).
N_STATIONS = 3

# Evaluation driver: configure the evaluator, prepare the events with the GraphNet
# model loader, build the table of all tracks, then run the model using the
# get_graph / eval_event callbacks defined above.
evaluator = EventEvaluator(parse_cfg, global_transformer, N_STATIONS)
# Only the first element of the value returned by prepare() is kept.
events = evaluator.prepare(model_loader=GraphModelLoader())[0]
all_tracks, all_events = evaluator.build_all_tracks()
reco_tracks, reco_events = evaluator.run_model(get_graph, eval_event)


read entry 682cbbad00e5204cedc1681993c97d66 hit
[prepare]: started processing a df 3.txt with 1130161 rows:
read entry 0fc24df1e7a03a01a0a8f98d0f88573a hit


In [82]:
# Show the first table returned by evaluator.build_all_tracks().
all_tracks

Unnamed: 0,event_id,track,px,py,pz,pred,hit_id_0,hit_id_1,hit_id_2
0,0,11,-0.193425,-0.004823,-0.029994,0,8,9,10
1,0,22,0.342074,-0.229190,0.254322,0,11,12,13
2,0,23,-0.258273,-0.097875,-0.528234,0,14,15,16
3,1,11,0.094377,0.492997,-0.320608,0,76,77,78
4,1,12,-0.302329,0.193706,0.509225,0,79,80,81
...,...,...,...,...,...,...,...,...,...
395,97,18,0.213483,0.201679,0.224834,0,4558,4559,4560
396,98,8,0.260138,-0.645867,0.196305,0,4589,4590,4591
397,98,9,-0.325851,0.927250,-0.192919,0,4592,4593,4594
398,98,11,0.614087,-0.697357,0.039403,0,4595,4596,4597


In [124]:
# Candidates the model accepted as tracks.
tracks_pred = reco_tracks[reco_tracks.track_pred]
reco_tracks_impulses = tracks_pred[['px', 'py', 'pz']]
# Explicit copy: a column is added below and must not be written into a slice
# of reco_tracks.
reco_tracks_preds = tracks_pred[['event_id','track_pred','hit_id_0','hit_id_1', 'hit_id_2']].copy()
# Remember each row's original index label so its momenta can be looked up after the merge.
reco_tracks_preds['idx_old'] = tracks_pred.index

# Outer join on event id + hit triplet: rows present only in all_tracks get
# NaN in track_pred, rows present only in the predictions get NaN in track.
results = pd.merge(all_tracks, reco_tracks_preds,  how='outer', on=['event_id', 'hit_id_0','hit_id_1', 'hit_id_2'])

# Tracks without a matching prediction count as "not predicted":
# NaN == True is False, so this yields a clean boolean column.
results['track_pred'] = results['track_pred'].eq(True)

results.loc[results.track_pred, 'pred'] = 1

# Ghosts: predicted triplets with no counterpart in all_tracks.
ghosts_idx_all = pd.isna(results.track)
ghosts_idx_reco = results.loc[ghosts_idx_all, 'idx_old'].astype('int')
# Fix: this lookup previously used `ghosts_idx`, a name that is not defined
# anywhere in the notebook; the labels computed just above are the ones needed.
ghosts_impulses = reco_tracks_impulses.loc[ghosts_idx_reco]
results.loc[ghosts_idx_all, ['px', 'py', 'pz']] = ghosts_impulses[['px','py','pz']].values
results.loc[ghosts_idx_all, 'track'] = -1
results.loc[ghosts_idx_all, 'pred'] = -1
results = results.drop(['track_pred', 'idx_old'], axis=1)
results['pred'] = results['pred'].astype('int')
results['track'] = results['track'].astype('int')

# Everything that is not a correctly found track: missed tracks (pred == 0)
# and ghosts (pred == -1).
results[results.pred != 1]


#ghosts_impulses

Unnamed: 0,event_id,track,px,py,pz,pred,hit_id_0,hit_id_1,hit_id_2
58,17,11,0.039055,0.016044,-0.098930,0,704,705,706
66,20,12,0.064797,-0.056723,-0.356679,0,819,820,821
69,20,20,-0.142656,-0.029636,0.141729,0,831,832,833
81,22,17,0.133924,-0.109195,0.006847,0,1045,1046,1047
82,22,19,0.260980,-0.179360,0.094161,0,1048,1049,1050
...,...,...,...,...,...,...,...,...,...
461,95,-1,-0.278283,-0.215039,-0.194499,-1,4374,4391,4420
462,96,-1,-0.245785,-0.283860,-0.283428,-1,4454,4465,4480
463,97,-1,-0.533304,-0.373059,-0.647679,-1,4496,4509,4527
464,97,-1,-0.533304,-0.373059,-0.647679,-1,4558,4510,4528


In [60]:
# Outer join of all tracks with every model candidate on event id + hit triplet.
recall_results = all_tracks.merge(
    reco_tracks,
    how='outer',
    on=['event_id', 'hit_id_0','hit_id_1', 'hit_id_2'],
)
# Rows whose track_pred is neither True nor False (i.e. missing after the join)
# are treated as "not predicted".
not_found_tracks = ~recall_results.track_pred.isin([True, False])
recall_results.loc[not_found_tracks, 'track_pred'] = False
# Mark the rows the model predicted as tracks, then drop the helper column.
recall_results.loc[recall_results.track_pred, 'pred'] = 1
recall_results = recall_results.drop(columns=['track_pred'])


In [61]:
# Show the joined table; px/py/pz appear with _x / _y suffixes because both
# merged frames carry those columns.
recall_results

Unnamed: 0,event_id,track,px_x,py_x,pz_x,pred,hit_id_0,hit_id_1,hit_id_2,px_y,py_y,pz_y
0,0,11.0,-0.193425,-0.004823,-0.029994,1.0,8,9,10,-0.258273,-0.229190,-0.528234
1,0,22.0,0.342074,-0.229190,0.254322,1.0,11,12,13,-0.258273,-0.229190,-0.528234
2,0,23.0,-0.258273,-0.097875,-0.528234,1.0,14,15,16,-0.258273,-0.229190,-0.528234
3,1,11.0,0.094377,0.492997,-0.320608,1.0,76,77,78,-0.302329,-0.303768,-0.320608
4,1,12.0,-0.302329,0.193706,0.509225,1.0,79,80,81,-0.302329,-0.303768,-0.320608
...,...,...,...,...,...,...,...,...,...,...,...,...
6818,98,,,,,,4561,4599,4580,-0.325851,-0.697357,-0.192919
6819,98,,,,,,4598,4568,4600,-0.325851,-0.697357,-0.192919
6820,98,,,,,,4598,4568,4580,-0.325851,-0.697357,-0.192919
6821,98,,,,,,4561,4568,4600,-0.325851,-0.697357,-0.192919


In [55]:
# Left join: keep every model candidate and attach the matching row of all_tracks, if any.
precision_results = reco_tracks.merge(
    all_tracks,
    how='left',
    on=['event_id', 'hit_id_0','hit_id_1', 'hit_id_2'],
)

In [58]:
# Candidates the model accepted as tracks.
true_pred_tracks = reco_tracks[reco_tracks.track_pred]

# Left join against all_tracks: accepted candidates without a match keep NaN in `track`.
precision_results = pd.merge(true_pred_tracks, all_tracks, how='left', on=['event_id', 'hit_id_0','hit_id_1', 'hit_id_2'])
reco_ghosts_idx = pd.isna(precision_results.track)
# Explicit copy so that setting `pred` below modifies an independent frame
# instead of a slice of precision_results (avoids SettingWithCopyWarning).
reco_ghosts = precision_results[reco_ghosts_idx].copy()
reco_ghosts['pred'] = -1


In [59]:
# Display the ghost candidates (the dangling `reco_ghosts = ` was a syntax error).
reco_ghosts

Unnamed: 0,event_id,track_pred,px_x,py_x,pz_x,hit_id_0,hit_id_1,hit_id_2,track,px_y,py_y,pz_y,pred
6,1,True,-0.302329,-0.303768,-0.320608,82,38,60,,,,,-1
22,6,True,-0.231819,-0.105091,-0.133781,163,225,226,,,,,-1
33,8,True,-0.522000,-0.272917,-0.095293,282,262,268,,,,,-1
39,10,True,-0.639351,-0.256428,-0.084903,335,336,345,,,,,-1
49,13,True,-0.336088,-0.355324,0.000135,474,434,452,,,,,-1
...,...,...,...,...,...,...,...,...,...,...,...,...,...
421,95,True,-0.278283,-0.215039,-0.194499,4374,4391,4420,,,,,-1
423,96,True,-0.245785,-0.283860,-0.283428,4454,4465,4480,,,,,-1
434,97,True,-0.533304,-0.373059,-0.647679,4496,4509,4527,,,,,-1
435,97,True,-0.533304,-0.373059,-0.647679,4558,4510,4528,,,,,-1


In [20]:
# Where pred_y equals 0.0, copy it over pred_x.
# NOTE(review): this cell looks stale (execution count 20, earlier than the cells
# above). The merges visible in this notebook do not appear to produce `pred_x` /
# `pred_y` columns in precision_results, so this probably fails on a fresh
# Restart & Run All — confirm and update or remove.
precision_results.loc[precision_results.pred_y == 0.0, 'pred_x'] = precision_results.loc[precision_results.pred_y == 0.0, 'pred_y'] 

In [21]:
# Accepted candidates whose pred_x is anything other than 0.0 (NaN included).
precision_results[precision_results.track_pred.eq(True) & precision_results.pred_x.ne(0.0)]

Unnamed: 0,event_id,track_pred,hit_id_0,hit_id_1,hit_id_2,pred_x,track,px,py,pz,pred_y
18,1,True,82,38,60,-1.0,,,,,
221,6,True,163,225,226,-1.0,,,,,
289,8,True,282,262,268,-1.0,,,,,
303,10,True,335,336,345,-1.0,,,,,
377,13,True,474,434,452,-1.0,,,,,
...,...,...,...,...,...,...,...,...,...,...,...
6606,95,True,4374,4391,4420,-1.0,,,,,
6652,96,True,4454,4465,4480,-1.0,,,,,
6767,97,True,4496,4509,4527,-1.0,,,,,
6769,97,True,4558,4510,4528,-1.0,,,,,
