In [1]:
import os
import yaml
import pickle
import numpy as np
import pandas as pd
import itertools
import scipy
import torch
import torch_geometric.datasets as datasets
import torch_geometric.data as Data
from torch_geometric.loader import DataLoader
import torch_geometric.transforms as transforms
import networkx as nx
from torch_geometric.utils.convert import to_networkx
import matplotlib.pyplot as plt

from produce_dataset import *

import torch.nn.functional as F
from torch_geometric.nn import GCNConv

In [2]:
ntuple_dir = '/ecoderemdvol/EleGun/EPGun-PU200/data/econ_ntuples/0002/'
root_dir = 'FloatingpointAutoEncoderEMDAEMSEttbarDummyHistomaxGenmatchGenclustersntuple'

df_econ = loadEconData(ntuple_dir,root_dir,'econ_data.csv',False)

df_gen = loadGenData(ntuple_dir,root_dir,'gen_data.csv')

In [3]:
def build_graph(df_econ,
            df_gen,
            zside_select,
            phi_min,
            phi_max):
    
    if zside_select !=0:
        df_econ = df_econ[df_econ.zside == zside_select]


    #only choose wafers in some phi range now
    
    df_econ = df_econ[df_econ.tc_phi>phi_min][df_econ.tc_phi<phi_max];
    
    df_nodes=df_econ[['ECON_0', 'ECON_1', 'ECON_2', 'ECON_3', 'ECON_4', 'ECON_5', 'ECON_6',
           'ECON_7', 'ECON_8', 'ECON_9', 'ECON_10', 'ECON_11', 'ECON_12','ECON_13', 'ECON_14', 'ECON_15', 
            'wafer_energy', 'tc_eta', 'tc_phi']]
    df_nodes.reset_index(inplace=True)
    
    embeddings = torch.tensor(df_nodes.values)

    eta=df_nodes['tc_eta']
    phi=df_nodes['tc_phi']
    idx = range(len(eta))
    indices_i = np.array([i for i,j in itertools.product(idx,idx)])
    indices_j = np.array([j for i,j in itertools.product(idx,idx)])

    del_R = np.empty([len(eta),len(eta)])
    del_R.shape

    for (i, j) in zip(indices_i,indices_j):
        del_R[i][j]=np.sqrt((eta[i]-eta[j])**2+((phi[i]-phi[j])%(2*np.pi))**2)

    del_R = torch.tensor(del_R)

    adj = np.zeros([len(eta),len(eta)])
    for (i, j) in zip(indices_i,indices_j):
        if del_R[i][j] <0.045 and  del_R[i][j]> 0 :
            adj[i][j]=1.0

    adj=torch.tensor(adj) 

    edge_index = (adj > 0.0).nonzero().t()
    edge_index.shape

    #predict node level feature: PU vs electron per simenergy

    feature_df= df_econ['wafer_energy']
    labels = (feature_df.where(feature_df==0,other=1))
    features = torch.tensor(labels.values).to(torch.long)

    graph = Data.Data(x=embeddings, edge_index=edge_index, y=features)
    graph.num_classes=2
    return graph

In [4]:
#Choose only front endcap for training

df_gen = (df_gen[df_gen.eta>0])
df_gen.reset_index(drop=True,inplace=True)

In [5]:
graphs = []

In [6]:
for gen_phi in df_gen['phi']:
    phi_min,phi_max= gen_phi-np.pi/12,gen_phi+np.pi/12
    graphs.append(build_graph(df_econ,df_gen,1,phi_min,phi_max));
    

  df_econ = df_econ[df_econ.tc_phi>phi_min][df_econ.tc_phi<phi_max];


In [9]:
class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(graphs[0].num_features, 16)
        self.conv2 = GCNConv(16, graphs[0].num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        #x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)

        return F.log_softmax(x, dim=1)


In [10]:
device = torch.device('cuda')
model = GCN().to(device)
model = model.double()

In [11]:
data = graphs[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

In [15]:
model.train()
for epoch in range(1000):
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out,data.y)
    loss.backward()
    optimizer.step()

In [16]:
model.eval()
pred = model(data).argmax(dim=1)
correct = (pred == data.y).sum()
acc = int(correct)/763
print(f'Accuracy: {acc:.4f}')

Accuracy: 0.9109


In [31]:
sums=0
for graph in graphs:
    val_data=graph.to(device)
    pred = model(val_data).argmax(dim=1)
    correct = (pred == val_data.y).sum()
    acc = int(correct)/val_data.y.size(dim=0)
    sums =sums+  acc
    print(f'Accuracy: {acc:.4f}')

Accuracy: 0.9109
Accuracy: 0.8229
Accuracy: 0.8533
Accuracy: 0.8551
Accuracy: 0.8428
Accuracy: 0.8595
Accuracy: 0.8229
Accuracy: 0.8599
Accuracy: 0.9017
Accuracy: 0.8395


In [32]:
sums/10

0.8568444354752167