## Used to create training files
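Before running this notebook, run the solution algorithm on the training events and save each solution as a CSV file named `<event><solution_prefix>.csv` (for example `event000001000_cluster.csv`) inside `clustered_path`.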

In [1]:
from tqdm import tqdm
import pickle
from sklearn.cluster.dbscan_ import dbscan
import sys
sys.path.insert(0, 'other/')
from trackml.dataset import load_dataset
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from functions.other import calc_features, get_event, score_event_fast, load_obj,hit_score
from functions.expand import *
from functions.cluster import *
from functions.ml_model import merge_with_probabilities,precision_and_recall,get_features,get_predictions


In [2]:
min_hits=4  # particles with fewer hits than this are dropped from the true tracks (see get_true_tracks)
train_path='../data/LHC/train/'    # path to the training data
clustered_path = '../data/LHC/clustered/' # path to the saved solution CSV files
solution_prefix = '_cluster' # appended after the event name to form the solution file name
num_train=25  # together with num_val: events 0 .. num_train+num_val-1 are processed
num_val=5  # event-id threshold separating the test split from the train split


In [3]:
def save_obj(obj, filename):
    """Pickle ``obj`` into the file at ``filename`` and report where it went.

    The file is opened in binary write mode (an existing file is overwritten)
    and the object is serialized with ``pickle.HIGHEST_PROTOCOL``.

    Example:
        filename = "folder/filename.pkl"
        arr = [3, 4, 5]
        save_obj(arr, filename)
    """
    with open(filename, 'wb') as sink:
        pickler = pickle.Pickler(sink, pickle.HIGHEST_PROTOCOL)
        pickler.dump(obj)
    # Printed only after the file has been written and closed.
    print("saved to " + filename)


def load_obj(filename):
    """Read back an object that was pickled to ``filename``.

    The "loaded from" message is printed once the file has been opened, before
    the payload is actually unpickled.

    NOTE: unpickling can execute arbitrary code; only use this on files you
    created yourself (e.g. with ``save_obj``).

    Example:
        filename = "folder/filename.pkl"
        arr = load_obj(filename)
    """
    with open(filename, 'rb') as source:
        print("loaded from " + filename)
        payload = pickle.Unpickler(source).load()
    return payload


def get_true_tracks(hits,particles,truth):
    """Build the per-track feature table for the ground-truth tracks of one event.

    Every hit is labelled with its ``particle_id`` (left-merged from ``truth``
    on ``hit_id``) and with that particle's ``nhits`` (left-merged from
    ``particles`` on ``particle_id``). Hits whose ``nhits`` is below the
    module-level ``min_hits`` (or missing) are dropped, ``particle_id`` is
    renamed to ``track_id`` and the result is passed to ``get_features``.

    Returns:
        The ``get_features`` output restricted to the feature columns listed below.
    """
    wanted_columns = [
        'svolume', 'nclusters', 'nhitspercluster',
        'xmax', 'ymax', 'zmax',
        'xmin', 'ymin', 'zmin',
        'zmean', 'xvar', 'yvar', 'zvar',
    ]
    hit_to_particle = truth[['hit_id', 'particle_id']]
    particle_sizes = particles[['particle_id', 'nhits']]
    labelled = (
        hits
        .merge(hit_to_particle, on='hit_id', how='left')
        .merge(particle_sizes, on='particle_id', how='left')
    )
    long_enough = labelled[labelled.nhits >= min_hits]
    true_tracks = long_enough.rename(columns={'particle_id': 'track_id'})
    return get_features(true_tracks)[wanted_columns]

In [7]:
# Build the labelled track table (target=1: true track, target=0: wrong track)
# for events 0 .. num_train+num_val-1, then split it by event id and pickle it.

# Per-track feature columns kept for the wrong tracks; this matches the column
# list returned by get_true_tracks so both frames share one schema.
feature_cols = ['svolume', 'nclusters', 'nhitspercluster',
                'xmax', 'ymax', 'zmax', 'xmin', 'ymin', 'zmin',
                'zmean', 'xvar', 'yvar', 'zvar']

# Seeded generator so the down-sampling and shuffling below (and therefore the
# saved train/test files) are reproducible across Restart & Run All.
sample_rng = np.random.RandomState(0)

# Collect one frame per event and concatenate once at the end; growing a
# DataFrame with .append inside the loop is quadratic and .append no longer
# exists in pandas >= 2.0.
event_frames = []

for i in range(0, num_train + num_val):
    event = 'event000001{:03d}'.format(i)
    print('do event', event)
    hits, cells, particles, truth = get_event(train_path, event)
    sub = pd.read_csv('{}{}{}.csv'.format(clustered_path, event, solution_prefix), index_col=False)

    # Score every submitted hit, then aggregate per reconstructed track.
    ssub = hit_score(sub, truth)
    gp = ssub.groupby('track_id').agg({'score': 'sum', 'weight': 'sum', 'track_len': 'max'}).reset_index()
    ssub = ssub[['hit_id', 'track_id']].merge(gp, on='track_id', how='left')

    # WRONG tracks: summed score is zero or below the summed weight, and the
    # track has more than 4 hits.
    # NOTE(review): presumably score < weight means some hits of the track were
    # not credited by hit_score -- confirm against hit_score.
    # NOTE(review): the original TODO asked for tracks with at least 4 and at most
    # 23 hits whose hits do not all share one particle_id; the filter here uses
    # track_len > 4 (min_hits is 4) and has no upper bound -- confirm which is intended.
    is_wrong = ((ssub.score == 0) | (ssub.score < ssub.weight)) & (ssub.track_len > 4)
    ssub = ssub[is_wrong][['hit_id', 'track_id']]
    ssub = ssub.merge(hits, on='hit_id', how='left')

    # One row per wrong track, one column per feature. .assign returns a new
    # frame, so the label is never written onto a slice of another frame.
    df_wrong = get_features(ssub)[feature_cols].assign(target=0)
    df_true = get_true_tracks(hits, particles, truth).assign(target=1)
    # We want the number of true tracks to be about the number of wrong tracks.
    df_true = df_true.sample(frac=0.35, random_state=sample_rng)
    print('true:{}'.format(df_true.shape[0]))
    print('wrong:{}'.format(df_wrong.shape[0]))

    df_both = pd.concat([df_true, df_wrong], ignore_index=False, sort=True)
    df_both['event_id'] = i
    df_both = df_both.sample(frac=1, random_state=sample_rng).reset_index(drop=True)  # shuffle
    event_frames.append(df_both)

df_all_subs = pd.concat(event_frames, ignore_index=True)

# The first num_val events (ids 0 .. num_val-1) are held out, the remaining
# num_train events are used for training. The previous `<= num_val` / `> num_val`
# comparison put num_val+1 events into the test set and only num_train-1 into
# the training set.
df_train = df_all_subs[df_all_subs['event_id'] >= num_val]
df_test = df_all_subs[df_all_subs['event_id'] < num_val]

save_obj(df_train, 'files/df_train_v2-reduced.pkl')
save_obj(df_test, 'files/df_test_v1.pkl')

do event event000001000
true:3252
wrong:3087
do event event000001001
true:2379
wrong:1922
do event event000001002
true:3383
wrong:3290
do event event000001003
true:2737
wrong:2404
do event event000001004
true:3773
wrong:3888
do event event000001005
true:2889
wrong:2632
do event event000001006
true:3151
wrong:2876
do event event000001007
true:2927
wrong:2644
do event event000001008
true:2952
wrong:2735
do event event000001009
true:2846
wrong:2540
do event event000001010
true:2695
wrong:2480
do event event000001011
true:3109
wrong:2920
do event event000001012
true:2900
wrong:2511
do event event000001013
true:2775
wrong:2503
do event event000001014
true:3372
wrong:3337
do event event000001015
true:3280
wrong:3177
do event event000001016
true:3130
wrong:2881
do event event000001017
true:3318
wrong:3226
do event event000001018
true:2329
wrong:1911
do event event000001019
true:3285
wrong:3215
do event event000001020
true:2400
wrong:1970
do event event000001021
true:2716
wrong:2450
do event e