In [1]:
import os
import sys
import time
import gin
import numpy as np
import pandas as pd
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
    
from IPython.core.display import clear_output, display

import matplotlib.pyplot as plt

%matplotlib inline
%load_ext autoreload
%autoreload 2

  from IPython.core.display import clear_output, display
  from IPython.core.display import clear_output, display


In [2]:
os.environ["CUDA_VISIBLE_DEVICES"]="2"

In [3]:
import logging
logging.getLogger().setLevel(logging.DEBUG)
    
from eval.event_evaluation import EventEvaluator
from ariadne_v2.transformations import Compose, DropShort, DropTracksWithHoles, DropSpinningTracks, BakeStationValues, ConstraintsNormalize, FixStationsBMN
from ariadne.utils.model import get_checkpoint_path, weights_update

parse_cfg = {
    'csv_params' : {
        "sep": '\s+',
        "encoding": 'utf-8',
        "names": ['event',  'x', 'y', 'z', 'det','station', 'track', 'px', 'py', 'pz', 'vx', 'vy', 'vz']
    },
    'input_file_mask':'/zfs/store5.hydra.local/user/p/pgonchar/data/bmn/run7/events/simdata_ArPb_3.2AGeV_mb_46.txt',
    'events_quantity':'1..2000'
}

z_values = {0:12.344, 1: 15.614, 2: 24.499, 3: 39.702, 4: 64.535, 5: 112.649, 6: 135.330,7: 160.6635, 8: 183.668}
constraints = {'x': [-81.1348, 86.1815], 'y': [-27.17125, 38.90473], 'z': [11.97, 183.82]}

global_transformer = Compose([
    FixStationsBMN(),
    DropShort(num_stations=4),
    DropTracksWithHoles(),
    DropSpinningTracks(),
    BakeStationValues(values=z_values),
    ConstraintsNormalize(constraints=constraints),
]
)

In [4]:
import scripts.clean_cache

#to clean cache if needed
#scripts.clean_cache.clean_jit_cache('20d')

## Tracknet

In [5]:
import torch
import faiss

NUM_POINTS_TO_SEARCH = 5

from ariadne_v2.inference import IModelLoader
from ariadne.tracknet_v2.model import TrackNETv2
class TrackNetModelLoader(IModelLoader):    
    def __call__(self):
        tracknet_input_features=3
        tracknet_conv_features=32
        
        tracknet_ckpt_path_dict = {'model_dir': '/zfs/hybrilit.jinr.ru/user/d/drusov/ariadne/lightning_logs/TrackNETv2', 
                                   'version': 'OnlyRNN_alpha0.01_unbalanced', 'checkpoint': 'latest'}
        path_to_tracknet_ckpt = get_checkpoint_path(**tracknet_ckpt_path_dict)

        model = weights_update(model=TrackNETv2(input_features=tracknet_input_features,
                                        conv_features=tracknet_conv_features,
                                        rnn_type='gru',
                                        batch_first=True,
                                        use_causalconv=False,
                                        use_rnn=True),
                       checkpoint=torch.load(path_to_tracknet_ckpt, map_location=torch.device(DEVICE)))
        model.eval()
        model.to(DEVICE)

        model_hash = {
            "tracknet_ckpt_path_dict":path_to_tracknet_ckpt,
            'gin':gin.config_str(), 
            'model': '%r' % model,
            'NUM_POINTS_TO_SEARCH':NUM_POINTS_TO_SEARCH
        }
        return model_hash, [model]

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(DEVICE)
import ariadne.transformations  as trn

tracknet_transformer = trn.Compose([
    #trn.DropShort(num_stations=4),
    #trn.DropTracksWithHoles(),
    #trn.DropSpinningTracks(),
    #trn.BakeStationValues(values=z_values),
    #trn.ConstraintsNormalize(constraints=constraints),
    ]
)

NUM_COMPONENTS = 2
SUFX = ['_p', '_c']
COLS = ['x', 'y', 'z']

def build_index(target_df):
    cont = np.ascontiguousarray(target_df[COLS].values)
    return store_in_index(cont)

def build_hits(target_df):
    cont = np.ascontiguousarray(target_df[COLS].values)
    return cont

def preprocess_one_event(event_df, n_neighbours=5):
    try:
        event_df = tracknet_transformer(event_df)
    except AssertionError as err:
        #print("ASS error %r" % err)
        return None
    return event_df

times = []

def run_tracknet_eval(result_from_prev_step, model, n_neighbours=5):
    timestart = time.time()
    model = model[0]
    preds, labels, ellipses = go_over_stations(result_from_prev_step, model, n_neighbours)
    times.append(time.time() - timestart)
    result = build_df(preds, labels, result_from_prev_step)
    for col in [f'hit_id_{i}' for i in range(9)]:
        result[col] = result[col].astype(int)
    return result

cuda


In [6]:
from ariadne.utils.base import store_in_index,search_in_index

In [7]:
def go_over_stations(df, model, n_neighbours, max_n_stations=9):
    max_batch_size = 256
    all_hits_index = build_index(df)
    tracks_vs_len, all_tracks, multiplicity = get_tracks(df)
    #print(tracks_vs_len)
    #seeds, target = to_cart(df)
    chunk_data_x = get_seeds_only_first_station(df)
    #print(chunk_data_x)
    #chunk_data_x = np.expand_dims(chunk_data_x, 1)
    #index = build_index(target)
    # search(target[:2], index)
    #chunk_data_x = seeds_to_input(seeds)
    chunk_data_len = torch.tensor(np.full(len(chunk_data_x), 1), dtype=torch.int64).to(DEVICE)
    gru_candidates = []
    x_candidates = []
    candidate_labels = []
    candidate_ellipses = []
    for stations_gone in range(1, max_n_stations):
        #print(f'===> {stations_gone}')
        num_batches = int(len(chunk_data_x) / max_batch_size) + 1
        #print(num_batches)
        num_right_batches = 0
        num_all_batches = 0
        next_stage_xs = []
        next_stage_lens = []
        if len(chunk_data_x) == 0:
            #print('Have zero ellipces on this station! Skipping other stations')
            #print(candidate_ellipses)
            #print(x_candidates)
            #print(candidate_labels)
            return x_candidates, candidate_labels, candidate_ellipses
        labels_for_empty_el_batch = []
        labels_for_full_el_batch = []
        station_df = df[df['station']==stations_gone]
        current_index = build_index(station_df)
        this_station_hits = build_hits(station_df)
        for batch_num in range(num_batches):
            min_i = max_batch_size*batch_num
            max_i = min(max_batch_size + min_i, len(chunk_data_x))
            if min_i==max_i:
                #print('Have zero ellipces on this station! Skipping other stations')
                #print(candidate_ellipses)
                #print(x_candidates)
                #print(candidate_labels)
                return x_candidates, candidate_labels, candidate_ellipses

            this_batch_x = torch.tensor(chunk_data_x[min_i: max_i]).to(DEVICE)
            this_batch_len = chunk_data_len[min_i: max_i]
            #print(this_batch_x)
            batch_prediction, batch_gru = model(this_batch_x.float(),
                                                     torch.tensor(this_batch_len, dtype=torch.int64).to(DEVICE),
                                                     return_gru_states=True)
            new_pred = torch.full((len(batch_prediction), 5), z_values[stations_gone], device=DEVICE)
            new_pred[:, :2] = batch_prediction[:, -1, :2]
            new_pred[:, 3:] = batch_prediction[:, -1, 2:]
            batch_prediction = new_pred
            batch_gru = batch_gru[:, -1]
            #if batch_num == 0:
              #  print(f'station {stations_gone+1}, on this station: {len(this_station_hits)} hits' )
            if len(this_station_hits) == 0:
            #    LOGGER.info('Have zero hits on this station! Skipping other stations')
            #    print(candidate_ellipses)
            #    print(x_candidates)
            #    print(candidate_labels)
                return x_candidates, candidate_labels, candidate_ellipses
            prediction_numpy = batch_prediction.detach().cpu().numpy()
            nearest_hits_index = search_in_index(prediction_numpy[:, :3],
                                                 current_index, 
                                                 n_neighbours,
                                                 n_dim=3)
            nearest_hits = this_station_hits[nearest_hits_index]
            nearest_hits, in_ellipse_mask = filter_hits_in_ellipses(prediction_numpy,
                                                               nearest_hits,
                                                               nearest_hits_index,
                                                               filter_station=False,
                                                               z_last=True,
                                                               find_n=nearest_hits_index.shape[1])
            nearest_hits = torch.from_numpy(nearest_hits)
            nearest_hits_mask = in_ellipse_mask 
            # here empty ellipses and all inputs for them are saved
            empty_xs, empty_ellipses, empty_grus, predicted_ellipses = get_data_for_empty_ellipses(this_batch_x, 
                                                                               batch_prediction, 
                                                                               batch_gru, 
                                                                               nearest_hits_mask)
            prolonged_batch_xs, prolonged_grus = prolong(this_batch_x,
                                                         batch_gru,
                                                         nearest_hits_mask,
                                                         nearest_hits,
                                                         stations_gone,
                                                         use_torch=True)
            next_stage_xs.append(prolonged_batch_xs)
            next_stage_lens.append(np.full(len(prolonged_batch_xs), stations_gone + 1))
            if stations_gone > 3:
                use_empty_ellipses = []
                empty_ellipses_station_intersections = []
                use_empty_ellipses = torch.ones(len(empty_ellipses), dtype=torch.bool)
                empty_xs = empty_xs[use_empty_ellipses]
                empty_grus = empty_grus[use_empty_ellipses]
                if len(empty_xs) > 0:
                    (candidate_labels,
                     labels_for_empty_el_batch,
                     gru_candidates,
                     x_candidates,
                     candidate_ellipses) =  get_candidates(empty_xs, 
                                                     tracks_vs_len[stations_gone],
                                                     prolonged_grus[:, -2], 
                                                     empty_ellipses,
                                                     all_hits_index,
                                                     candidate_labels, 
                                                     labels_for_full_el_batch, 
                                                     gru_candidates, 
                                                     x_candidates,
                                                     candidate_ellipses)
                    #print(candidate_ellipses)
            if (stations_gone == (max_n_stations-1)): # if we are predicting for last station
                if len(prolonged_batch_xs) > 0:  # now we prolong candidates with all hits *in* ellipses
                    (candidate_labels,
                     labels_for_full_el_batch, 
                     gru_candidates, 
                     x_candidates,
                     candidate_ellipses) =  get_candidates(prolonged_batch_xs, 
                                                     tracks_vs_len[stations_gone + 1],
                                                     prolonged_grus[:, -1], 
                                                     predicted_ellipses,
                                                     all_hits_index,
                                                     candidate_labels, 
                                                     labels_for_full_el_batch, 
                                                     gru_candidates, 
                                                     x_candidates,
                                                     candidate_ellipses)

        if len(next_stage_xs) > 1:
            chunk_data_x = np.concatenate(next_stage_xs, 0)
            chunk_data_len = np.concatenate(next_stage_lens, 0)
        else:
            try:
                chunk_data_x = next_stage_xs[0]
                chunk_data_len = next_stage_lens[0]
            except:
                continue
    return x_candidates, candidate_labels, candidate_ellipses

def get_tracks(df, min_len=4):
    tracks = df[df['track'] != -1].groupby('track')
    multiplicity = tracks.ngroups
    tracks_vs_len = {3: [], 4: [], 5: [], 6: [], 7: [], 8: [], 9: []}
    all_tracks = []
    for i, track in tracks:
        temp_track = track[['x', 'y', 'z']].values
        if len(temp_track) >= min_len:
            tracks_vs_len[len(temp_track)].append(temp_track)
            all_tracks.append(temp_track)
    for stations_in_track, this_track_list in tracks_vs_len.items():
        if len(this_track_list) > 0:
            tracks_vs_len[stations_in_track] = np.stack(this_track_list, 0)
    return tracks_vs_len, all_tracks, multiplicity

def get_seeds_only_first_station(df, columns=['x','y','z']):
    real = df#[df['track']!=-1]
    temp1 = real[real.station == 0]
    st0_hits = temp1[columns].values
    # all possible combinations
    seeds = np.zeros((len(temp1), 1, 3))
    seeds[:, 0, ] = st0_hits
    return seeds

def filter_hits_in_ellipses(ellipses, nearest_hits, hits_index, z_last=True, filter_station=True, find_n=10):
    """Function to get hits, which are in given ellipse.
    Space is 3-dimentional, so either first component of ellipce must be z-coordinate or third.
    Ellipse semiaxises must include x- and y-axis.
    Arguments:
        ellipse (np.array of size 5): predicted index with z-component like
                                      (x,y,z, x-semiaxis, y_semiaxis) or (z, x,y, x-semiaxis, y_semiaxis)
        nearest_hits (np.array of shape (n_hits, 3) or only 3): some hits in 3-dim space
        z_last (bool): If True, first component of vector is interpreted as z, if other, third.
        filter_station (bool): if True, only hits with same z-coordinate are considered, else all hits
    Returns:
        numpy.ndarry with filtered hits, all of them in given ellipse, sorted by increasing of distance
    """
    assert nearest_hits.shape[-1] == 3, "index is 3-dimentional, please add z-coordinate to centers"
    if nearest_hits.ndim < 2:
        nearest_hits = np.expand_dims(nearest_hits, 0)
    if nearest_hits.ndim < 3:
        nearest_hits = np.expand_dims(nearest_hits, 0)
    assert ellipses.shape[-1] == 5, "index is 3-dimentional, you need to provide z-coordinate (z_c, x_c, y_c, x_r, y_r) or (x_c, y_c, z_c, x_r, y_r)"
    #ellipses = np.expand_dims(ellipses, -1)
    #find_n = len(nearest_hits)
    ellipses = np.expand_dims(ellipses, 2)
    #found_hits = nearest_hits.reshape(-1, find_n, nearest_hits.shape[-1])
    if z_last:
        x_part = (ellipses[:,0].repeat(find_n,1) - nearest_hits[:, :, 0]) / ellipses[:, 3].repeat(find_n,1)
        #print(x_part**2)
        y_part = (ellipses[:,1].repeat(find_n,1) - nearest_hits[:, :, 1]) / ellipses[:, 4].repeat(find_n,1)
        #print(y_part**2)
    else:
        x_part = (nearest_hits[:, :, 1] - ellipses[:, 1].repeat(find_n, 1)) / ellipses[:, -2].repeat(find_n, 1)
        y_part = (nearest_hits[:, :, 2] - ellipses[:, 2].repeat(find_n, 1)) / ellipses[:, -1].repeat(find_n, 1)
    left_side = x_part**2 + y_part**2
    is_in_ellipse = left_side <= 1
    is_in_ellipse *= hits_index != -1
    return nearest_hits, is_in_ellipse

def get_data_for_empty_ellipses(x, preds, grus, mask):
    empty_ellipses_mask = (mask.sum(axis=-1) == 0) # if nothing is found in ellipse
    empty_ellipses = preds[empty_ellipses_mask]#.detach().cpu().numpy()
    full_ellipses = preds[~empty_ellipses_mask]
    empty_xs = x[empty_ellipses_mask]
    empty_grus = grus[empty_ellipses_mask]#.detach().cpu().numpy()
    return empty_xs, empty_ellipses, empty_grus, full_ellipses

def prolong(x, gru, nearest_hits_mask, nearest_hits, stations_gone, use_torch=False):
    if use_torch:
        xs_for_prolong = x.unsqueeze(1).repeat((1, nearest_hits_mask.shape[-1], 1,1))
        grus_for_prolong = gru.unsqueeze(1)
        grus_for_prolong = grus_for_prolong.repeat((1, nearest_hits_mask.shape[-1], 1))
        prolonged_xs = torch.zeros((len(xs_for_prolong),
                                    nearest_hits_mask.shape[-1],
                                    xs_for_prolong.shape[-2] + 1,
                                    xs_for_prolong.shape[-1]))
        prolonged_xs[:, :, :xs_for_prolong.shape[-2], :] = xs_for_prolong
        prolonged_xs[:, :, xs_for_prolong.shape[-2], :] = nearest_hits
        #print(prolonged_xs.shape)
        prolonged_xs = prolonged_xs.reshape(-1, stations_gone + 1, 3)
        nearest_hits_mask = nearest_hits_mask.reshape(-1)
        prolonged_xs = prolonged_xs[nearest_hits_mask]
        prolonged_grus = grus_for_prolong.reshape(-1, grus_for_prolong.shape[-1])
        prolonged_grus = prolonged_grus[nearest_hits_mask]
    else:
        xs_for_prolong = np.expand_dims(x, 1).repeat(nearest_hits_mask.shape[-1], 1)
        grus_for_prolong = np.expand_dims(gru.detach().cpu().numpy(), 1).repeat(nearest_hits_mask.shape[-1], 1)
        prolonged_xs = np.zeros(
            (len(xs_for_prolong), nearest_hits_mask.shape[-1], xs_for_prolong.shape[2] + 1, 3))

        prolonged_xs[:, :, :xs_for_prolong.shape[2], :] = xs_for_prolong
        prolonged_xs[:, :, xs_for_prolong.shape[2], :] = nearest_hits
        prolonged_xs = prolonged_xs[nearest_hits_mask].reshape(-1, stations_gone + 1, 3)
        prolonged_grus = grus_for_prolong[nearest_hits_mask].reshape(-1,
                                                                     grus_for_prolong.shape[-2],
                                                                     grus_for_prolong.shape[-1])
    return prolonged_xs, prolonged_grus

def get_candidates(preds, targets, grus, ellipses, index, labels,  labels_for_batch, gru_candidates, track_candidates, candidate_ellipses):
    orig_ellipses = torch.zeros((len(ellipses), 4))
    orig_ellipses[:, :2] = ellipses[:, :2]
    orig_ellipses[:, 2:] = ellipses[:, 3:]
    labels_for_ellipses = get_labels_faiss(targets,
                                     preds.detach().cpu().numpy(), index=index)
    labels_for_batch.append(labels_for_ellipses)
    gru_candidates.extend(grus)  # because we need gru for predicted ellipse
    labels.extend(labels_for_ellipses)
    track_candidates.extend(preds)
    candidate_ellipses.extend(orig_ellipses)
    return labels, labels_for_batch, gru_candidates, track_candidates, candidate_ellipses

def get_labels_faiss(gt_tracks, predicted_tracks, index):
    labels_for_ellipses = np.zeros(len(predicted_tracks))
    assert len(predicted_tracks) > 0, 'Can not compute labels for empty set of tracks!'
    if len(gt_tracks) > 0:
        tracks_len = gt_tracks.shape[-2]
        gt_tracks = gt_tracks.reshape(-1, gt_tracks.shape[-1])
        predicted_tracks = predicted_tracks.reshape(-1, gt_tracks.shape[-1])
        gt_index = search_in_index(gt_tracks, index, 1, n_dim=3).flatten().reshape(-1, tracks_len)
        predicted_index = search_in_index(predicted_tracks, index, 1, n_dim=3).flatten().reshape(-1, tracks_len)
        expanded_gt = np.expand_dims(gt_index, 0).repeat(len(predicted_index), 0)
        expanded_preds = np.expand_dims(predicted_index, 1).repeat(expanded_gt.shape[1], 1)
        labels_for_ellipses += np.any(np.all(np.equal(expanded_gt, expanded_preds), axis=-1),axis=-1)
        assert len(labels_for_ellipses) == len(expanded_preds), 'length of labels and xs is different!'
    return labels_for_ellipses

def build_df(preds, labels, event_df):
    curr_event = event_df.event.values[0]
    n_stations = 9
    
    new_columns = [*[f'hit_id_{i}' for i in range(n_stations)], 'event', 'track_pred']
    new_df = pd.DataFrame(columns=new_columns)
    new_df.track_pred = new_df.track_pred.astype('bool')
    
    if len(preds) > 10000:
        print(f'{len(preds)} preds found in {curr_event}, skipping')
        return new_df

    big_index = build_index(event_df)
    event_idxs = np.array(event_df.index_old)
    tracks_list = []
    for idx1, pred in enumerate(preds):
        found = searching(pred, big_index)
        hit_list = event_idxs[found]
        to_pad = n_stations - len(hit_list)
        padded_hits = np.pad(hit_list, (0, to_pad), constant_values=-1)
        
        padded_hits = np.append(padded_hits, curr_event)
        padded_hits = np.append(padded_hits, True)
        tracks_list.append(padded_hits)
    
    new_df = pd.DataFrame(tracks_list, columns=new_columns)
    new_df.track_pred = new_df.track_pred.astype('bool')
    return new_df

def searching(points, index):
    cont = np.ascontiguousarray(points.cpu())
    return search_in_index(cont, index, 1, n_dim=3).flatten()


In [8]:
N_STATIONS = 9

evaluator = EventEvaluator(parse_cfg, global_transformer, N_STATIONS)
events = evaluator.prepare(model_loader=TrackNetModelLoader())[0]
all_results = evaluator.build_all_tracks()
model_results = evaluator.run_model(preprocess_one_event, run_tracknet_eval)
start_time = time.time()
results_tracknet = evaluator.solve_results(model_results, all_results)
end_time = time.time()
print(f'solving took {end_time - start_time} seconds')

read entry b454a12f953a77a36bdec209025234a1 hit
[prepare]: started processing a df simdata_ArPb_3.2AGeV_mb_46.txt with 1283239 rows:
read entry 7d8f6db760f5fa6b4c0bbb8354579dc1 hit
[prepare] finished
[prepare] loading your model(s)...
[prepare] finished loading your model(s)...
[build_all_tracks] start
read entry d5bffb45e73d5372d559f219ef01293a hit
read entry c2e5618193c38aa57d62358b64cb868f hit
[build_all_tracks] cache hit, finish
[run model] start
read entry 2b3f80300b54c1942ff70da5e04e3b5d hit
read entry f1204caec0519044d6592073a79dd3aa hit


processed: 1999: 100%|██████████| 1999/1999 [02:29<00:00, 13.41it/s]
[run model] cache miss, finish
[solve results] start
[solve results] finish
[solve results] final stats:
Total events evaluated: 1746
Total tracks evaluated: 12358
Track Efficiency (recall): 0.9114 
Track Purity (precision): 0.0208 
Fully reconstructed event ratio: 0.6262
Mean cpu time per event: 0.0003 sec (3192.70 events per second) 
Mean gpu time per event: 0.0706 sec (14.

In [9]:
times_arr = np.array(times)
print(f"Building raw track candidates: time mean - {times_arr.mean()}, time std - {np.std(times_arr)}")

Building raw track candidates: time mean - 0.037587725203773154, time std - 0.018141825237064712


## Data processing

Data pattern:
    - Dict:
        - tracks_candidates - List of length N - number of track candidates:
            - np.array: shape (M, 3) - M - number of hits per track;
        - labels - np.array: shape (N) - labels for track candidates (True - real, False - ghost);
        - events - np.array: shape (N) - number of event;
        - total_tracks - int: number of real tracks in all events;

In [10]:
STATION_COLUMNS = [f"hit_id_{n}" for n in range(N_STATIONS)]

In [11]:
events_xyz = events[['x', 'y', 'z']].copy()
events_xyz.set_index(events.index_old, inplace=True)

In [12]:
from tqdm import tqdm

In [13]:
total = len(results_tracknet[0].query('pred != -1'))
results_candidates = results_tracknet[0].query('pred != 0')

labels = results_candidates.pred.values > 0
events_ids = results_candidates.event_id.values

track_candidates = []
results_candidates = results_candidates[STATION_COLUMNS]
for track_candidate in tqdm(results_candidates.itertuples(index=False, name=None), total=len(results_candidates)):
    track_candidate = np.array(track_candidate)
    track_candidate = track_candidate[track_candidate >= 0]
    track_candidates.append(events_xyz.loc[track_candidate].values)

result_dict = {'track_candidates': track_candidates,
               'labels': labels,
               'events': events_ids,
               'total_tracks': total
}

100%|██████████| 541894/541894 [01:41<00:00, 5319.23it/s]


In [14]:
import pickle

In [15]:
with open("data.pkl", "wb") as f:
    pickle.dump(result_dict, f)

In [16]:
with open("data.pkl", "rb") as f:
    output = pickle.load(f)

In [17]:
len(output['labels'][output['labels']]) / output['total_tracks']

0.9113934293575012

In [18]:
len(output['labels'][output['labels']]) / len(output['labels'])

0.020784507671241976