In [10]:
import numpy as np
import torch
import pandas as pd
from itertools import product
from torch_geometric.data import Data

evt = "/mnt/data0/Trackml_dataset_100_events/Example_3/trackml_100_events/event000021000-"

#cells = pd.read_csv(evt+'cells.csv')
hits = pd.read_csv(evt+'hits.csv')
#particles = pd.read_csv(evt+'particles.csv')
#truth = pd.read_csv(evt+'truth.csv')

#print("Cells:",cells.head())
#print("Hits:",hits.head())
#print("Particles:",particles.head())
#print("Truth:",truth.head())

#Only keep hits with volume_id = 8,13,17
hits = hits[hits.volume_id.isin([8,13,17])]

#Relabel layer_id to contain volume information
hits['layer_id'] = hits['layer_id']/2 - 1
hits['layer_id'] = hits['volume_id']*10 + hits['layer_id']

hits = hits.sort_values(by='layer_id')

def semi_fully_connected_graph(event_path):
    hits = pd.read_csv(event_path + 'hits.csv')
    hits = hits[hits.volume_id.isin([8,13,17])] #only consider barrel detector
    hits = hits[hits.y > 0] #only consider top half of detector

    hits['layer_id'] = hits['layer_id']/2 - 1
    hits['layer_id'] = hits['volume_id']*10 + hits['layer_id']

    hits = hits.sort_values(by='layer_id')

    edge_index = torch.empty(1,2)

    for i,layer in enumerate(hits['layer_id'].unique()):
        print("Layer:",layer)
        print("Layerindex:",i)
        print("Total number of layers:",hits['layer_id'].unique().shape[0])

        """
        Naiive approach: itertools.product contains nested for loop, not yet parallelized for every layer
        """
        last = int(hits['layer_id'].unique().shape[0]) - 2

        #Hits in current layer
        layer_hits = hits[hits.layer_id == int(layer)]

        #Create all combinations of hits in current layer and next layer
        next_layer = hits['layer_id'].unique()[i+1] 
        next_layer_hits = hits[hits.layer_id == int(next_layer)]
        combinations_next = torch.tensor(list(product(layer_hits['hit_id'],next_layer_hits['hit_id'])))
        edge_index = torch.cat((edge_index,combinations_next),dim=0)

        if i == last:
            break

        #Create all combinations of hits in current layer and nextnext layer
        nextnext_layer = hits['layer_id'].unique()[i+2]
        nextnext_layer_hits = hits[hits.layer_id == int(nextnext_layer)]
        combinations_nextnext = torch.tensor(list(product(layer_hits['hit_id'],nextnext_layer_hits['hit_id'])))
        edge_index = torch.cat((edge_index,combinations_nextnext),dim=0)

    graph = Data()
    graph.x = torch.tensor(hits['hit_id'].values)
    graph.edge_index = edge_index
    
    return graph

graph = semi_fully_connected_graph(evt)
torch.save(graph,'hit_graph_small_event000021000.pyg')

Layer: 80.0
Layerindex: 0
Total number of layers: 10
Layer: 81.0
Layerindex: 1
Total number of layers: 10
Layer: 82.0
Layerindex: 2
Total number of layers: 10
Layer: 83.0
Layerindex: 3
Total number of layers: 10
Layer: 130.0
Layerindex: 4
Total number of layers: 10
Layer: 131.0
Layerindex: 5
Total number of layers: 10
Layer: 132.0
Layerindex: 6
Total number of layers: 10
Layer: 133.0
Layerindex: 7
Total number of layers: 10
Layer: 170.0
Layerindex: 8
Total number of layers: 10


In [28]:
import random
graph = torch.load('hit_graph_small_event000021000.pyg')

#Random filter to remove edges and unconnected nodes
def random_filter(graph, p=0.9999):
    print('Number of edges before filter:',graph.edge_index.shape[0])
    remove_edges = random.sample(range(0,graph.edge_index.shape[0]),int(graph.edge_index.shape[0]*p))
    graph.edge_index = torch.from_numpy(np.delete(graph.edge_index.numpy(),remove_edges,axis=0))
    print('Number of edges after filter:',graph.edge_index.shape[0])
    graph = delete_unconnected_nodes(graph)
    return graph

def delete_unconnected_nodes(graph):
    print('Total number of nodes:',graph.x.shape[0])
    print('Number of connected nodes:',np.unique(graph.edge_index.numpy()).shape[0])
    graph.x = torch.tensor(np.unique(graph.edge_index.numpy()))
    return graph

graph = random_filter(graph)
print(graph)

Number of edges before filter: 190923020
Number of edges after filter: 19093
Total number of nodes: 33590
Number of connected nodes: 22258
Data(x=[22258], edge_index=[19093, 2])


In [None]:
def linegraph(graph):
    

    new_graph = Data()
    return new_graph

In [35]:
A = [0,1,2,3]
B = [4,5,6,7]

combinations = np.array(list(product(A,B))).T
print(combinations)

[[0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3]
 [4 5 6 7 4 5 6 7 4 5 6 7 4 5 6 7]]
