In [1]:
from wmpgnn.configs.config_loader import ConfigLoader
# Load the configuration
config_loader = ConfigLoader("config_files/heteromp_gnn.yaml", environment_prefix="DL")
print(config_loader.get("model"))          # Outputs: "resnet50"
print(config_loader.get("training.starting_learning_rate")) # Outputs: 0.001

# Override via environment variable
#os.environ["DL_MODEL_BATCH_SIZE"] = "64"
print(config_loader.get("training.batch_size"))    # Outputs: 64


{'type': 'heterognn', 'gnn_layers': 6, 'mlp_output_size': 16, 'mlp_layers': 4, 'mlp_channels': 128, 'weight_mlp_layers': 4, 'weight_mlp_channels': 16, 'weighted_mp': False, 'use_edge_weights': True, 'use_node_weights': True, 'node_types': ['tracks', 'PVs'], 'edge_types': ['tracks_tracks', 'tracks_PVs']}
0.001
32


In [2]:
print(config_loader.get("dataset"))  

{'data_type': 'heterogeneous', 'data_dir': '/home/sutclw/Work/Zurich/LHCb/GNNs/cached_hetero_data'}


In [3]:
# files_input_tr = sorted(glob.glob('/home/sutclw/Work/Zurich/LHCb/GNNs/cached_LCA_training_events_PYTHIA/training_dataset/input_*'))
# files_target_tr = sorted(glob.glob('/home/sutclw/Work/Zurich/LHCb/GNNs/cached_LCA_training_events_PYTHIA/training_dataset/target_*'))

# files_input_vl = sorted(glob.glob('/home/sutclw/Work/Zurich/LHCb/GNNs/cached_LCA_training_events_PYTHIA/validation_dataset/input_*'))
# files_target_vl = sorted(glob.glob('/home/sutclw/Work/Zurich/LHCb/GNNs/cached_LCA_training_events_PYTHIA/validation_dataset/target_*'))

In [4]:
# import os
# pv_loc = "/home/sutclw/Work/weighted_MP_gnn/pv_cached_data/"
# for file in files_target_tr:
#     # if "input" in file:
#     #     name = file[file.find("input"):]
#     #     cmd= "cp " + pv_loc + name + f" ../cached_hetero_data/validation_dataset/{name}"
#     #     os.system(cmd)
#     if "target" in file:
#         name = file[file.find("target"):]
#         cmd= "cp " + pv_loc + name + f" ../cached_hetero_data/training_dataset/{name}"
#         os.system(cmd)

# To do now

## Training
* load hetero model
* save dataframe of results
* get hetero dataset in the correct format
* check hetero training
* Figure out performance loss homog training
* integrate options for setting loss pars
* Finalize training script

## Inference
* inference class
* reconstruction performance 

## Paper

## Longer term
* load transformer model
* hyperpar optimization

In [5]:
import numpy as np
import networkx as nx
import os, time

import os.path as osp
import glob
from wmpgnn.datasets.graph_dataset import CustomDataset
from wmpgnn.datasets.hetero_graph_dataset import CustomHeteroDataset
from torch_geometric.loader import DataLoader
import contextlib
import torch
from torch import nn
from torch_scatter import scatter_add

In [6]:
#dataset = config_loader.get("dataset.data_dir")

In [7]:
#print(dataset)

In [8]:
class DataHandler:
    """ Class which loads data using an appropriate 
        dataset class """
    def __init__(self, config):
        config_loader = config
        print
        data_path = config_loader.get("dataset.data_dir")
        data_type = config_loader.get("dataset.data_type")
        files_input_tr = sorted(glob.glob(f'{data_path}/training_dataset/input_*'))
        files_target_tr = sorted(glob.glob(f'{data_path}/training_dataset/target_*'))
        files_input_vl = sorted(glob.glob(f'{data_path}/validation_dataset/input_*'))
        files_target_vl = sorted(glob.glob(f'{data_path}/validation_dataset/target_*'))


        if data_type == "homogeneous":
            self.train_dataset = CustomDataset(files_input_tr, files_target_tr )
            self.val_dataset = CustomDataset(files_input_vl, files_target_vl )
        elif data_type == "heterogeneous":
            self.train_dataset = CustomHeteroDataset(files_input_tr, files_target_tr )
            self.val_dataset = CustomHeteroDataset(files_input_vl, files_target_vl )  
        else:
            raise Exception(f"Unexpected data type {data_type}. Please use homogeneous or heterogeneous.")
            
    def load_data(self):
         self.dataset_tr = self.train_dataset.get()    
         self.dataset_vl = self.val_dataset.get() 
        
    def get_train_dataloader(self, batch_size=32):   
        return DataLoader(self.dataset_tr, batch_size=batch_size, drop_last=True)


    def get_val_dataloader(self, batch_size=32):    
        return DataLoader(self.dataset_vl, batch_size=batch_size, drop_last=True)

In [9]:
config_loader.get("dataset")

{'data_type': 'heterogeneous',
 'data_dir': '/home/sutclw/Work/Zurich/LHCb/GNNs/cached_hetero_data'}

In [10]:
data_loader = DataHandler(config_loader)
data_loader.load_data()

In [11]:
train_loader = data_loader.get_train_dataloader()
val_loader = data_loader.get_val_dataloader()

In [12]:
len(train_loader)

1234

In [13]:
config_loader = ConfigLoader("config_files/heteromp_gnn.yaml", environment_prefix="DL")

config_loader.get("model")

{'type': 'heterognn',
 'gnn_layers': 6,
 'mlp_output_size': 16,
 'mlp_layers': 4,
 'mlp_channels': 128,
 'weight_mlp_layers': 4,
 'weight_mlp_channels': 16,
 'weighted_mp': False,
 'use_edge_weights': True,
 'use_node_weights': True,
 'node_types': ['tracks', 'PVs'],
 'edge_types': ['tracks_tracks', 'tracks_PVs']}

In [14]:

edges = config_loader.get("model")['edge_types']

In [15]:
[(edge.split('_')[0],'to', edge.split('_')[1]) for edge in edges]

[('tracks', 'to', 'tracks'), ('tracks', 'to', 'PVs')]

In [16]:
from wmpgnn.model.gnn_model import GNN
from wmpgnn.model.hetero_gnn_model import HeteroGNN
class ModelLoader:
    """ Class to set up the model """
    def __init__(self, config):
        config_loader = config
        model_type = config_loader.get("model.type")
        model_type="heterognn"
        if model_type == "mpgnn":
            self.model = GNN(mlp_output_size=config_loader.get("model.mlp_output_size"), edge_op=4, num_blocks=config_loader.get("model.gnn_layers"),
                             mlp_layers = config_loader.get("model.mlp_layers"),
                             mlp_channels = config_loader.get("model.mlp_channels"),
                             weight_mlp_channels = config_loader.get("model.weight_mlp_channels"),
                             weight_mlp_layers= config_loader.get("model.weight_mlp_layers"),
                             use_edge_weights = config_loader.get("model.use_edge_weights"),
                             use_node_weights = config_loader.get("model.use_node_weights"),
                             weighted_mp = config_loader.get("model.weighted_mp")
                            )                            
        elif model_type == "heterognn":
            nodes = config_loader.get("model")['node_types']
            edges = config_loader.get("model")['edge_types']
            edges = [(edge.split('_')[0],'to', edge.split('_')[1]) for edge in edges]
            self.model = HeteroGNN(node_types = nodes, edge_types = edges,
                             mlp_output_size=config_loader.get("model.mlp_output_size"), edge_op=4,
                             num_blocks=config_loader.get("model.gnn_layers"),
                             mlp_layers = config_loader.get("model.mlp_layers"),
                             mlp_channels = config_loader.get("model.mlp_channels"),
                             weight_mlp_channels = config_loader.get("model.weight_mlp_channels"),
                             weight_mlp_layers= config_loader.get("model.weight_mlp_layers"),
                             use_edge_weights = config_loader.get("model.use_edge_weights"),
                             use_node_weights = config_loader.get("model.use_node_weights"),
                             weighted_mp = config_loader.get("model.weighted_mp")
                            )    
        elif model_type == "transformer":
            pass
            

    def get_model(self):
        return self.model
        


In [17]:
model_loader = ModelLoader(config_loader)
model = model_loader.get_model()

In [18]:
model

HeteroGNN(
  (_encoder): HeteroGraphCoder(
    (_global_model): WrappedModelFnModule(
      (_model): MLP(-1, 128, 128, 128, 16)
    )
    (_edge_models_model_dict): ModuleDict(
      (('tracks', 'to', 'tracks')): WrappedModelFnModule(
        (_model): MLP(-1, 128, 128, 128, 16)
      )
      (('tracks', 'to', 'PVs')): WrappedModelFnModule(
        (_model): MLP(-1, 128, 128, 128, 16)
      )
    )
    (_node_models_model_dict): ModuleDict(
      (tracks): WrappedModelFnModule(
        (_model): MLP(-1, 128, 128, 128, 16)
      )
      (PVs): WrappedModelFnModule(
        (_model): MLP(-1, 128, 128, 128, 16)
      )
    )
  )
  (_blocks): ModuleList(
    (0-5): 6 x HeteroGraphNetwork(
      (_edge_block): HeteroEdgeBlock(
        (_edge_models_model_dict): ModuleDict(
          (('tracks', 'to', 'tracks')): MLP(-1, 128, 128, 128, 16)
          (('tracks', 'to', 'PVs')): MLP(-1, 128, 128, 128, 16)
        )
      )
      (_node_block): HeteroNodeBlock(
        (_received_edges_aggregat

In [19]:
model._blocks[0]._node_block

HeteroNodeBlock(
  (_received_edges_aggregator): HeteroEdgesToNodesAggregator()
  (_sent_edges_aggregator): HeteroEdgesToNodesAggregator()
  (_node_models_model_dict): ModuleDict(
    (tracks): MLP(-1, 128, 128, 128, 16)
    (PVs): MLP(-1, 128, 128, 128, 16)
  )
)

In [20]:
config_loader.get("model")

{'type': 'heterognn',
 'gnn_layers': 6,
 'mlp_output_size': 16,
 'mlp_layers': 4,
 'mlp_channels': 128,
 'weight_mlp_layers': 4,
 'weight_mlp_channels': 16,
 'weighted_mp': False,
 'use_edge_weights': True,
 'use_node_weights': True,
 'node_types': ['tracks', 'PVs'],
 'edge_types': ['tracks_tracks', 'tracks_PVs']}

In [21]:

from abc import ABC, abstractmethod

class Trainer(ABC):

    def __init__(self, config, model, train_loader, val_loader):
        self.config = config
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader    
        self.train_acc = []
        self.val_acc = [] 
        self.train_loss = []
        self.val_loss = [] 
        self.epochs = []
        
    @abstractmethod
    def eval_one_epoch(self, train=True):
        pass
        
    @abstractmethod
    def train(self, epochs=10, learning_rate=0.001):
        pass

    def save_model(self, file_name):
        torch.save(self.model.state_dict(), file_name)

    def save_dataframe(self, file_name):
        pass

    def plot_loss(self, file_name="loss.png", show=True):
        import matplotlib.pyplot as plt
        import os
        
        plt.plot(self.train_loss, label="Train Loss")
        plt.plot(self.val_loss, label="Validation Loss")
        
        plt.xlabel('epoch')
        plt.ylabel('Cross Entropy Loss')
        
        plt.legend()
        if show:
            plt.show()
        plt.savefig(file_name)
        

    def plot_accuracy(self, file_name="acc.png", show=True):

        class1_acc_vl=[]
        class2_acc_vl=[]
        class3_acc_vl=[]
        class4_acc_vl=[]
        
        for vl_acc in self.val_acc:
            class1_acc_vl.append(vl_acc[0])
            class2_acc_vl.append(vl_acc[1])
            class3_acc_vl.append(vl_acc[2])
            class4_acc_vl.append(vl_acc[3])
        
        class1_acc_tr=[]
        class2_acc_tr=[]
        class3_acc_tr=[]
        class4_acc_tr=[]
        
        for tr_acc in self.train_acc:
            class1_acc_tr.append(tr_acc[0])
            class2_acc_tr.append(tr_acc[1])
            class3_acc_tr.append(tr_acc[2])
            class4_acc_tr.append(tr_acc[3])
            
        fig, axarr = plt.subplots(1, 2, figsize=(10, 5))
                
        axarr[0].plot(class1_acc_tr, label="LCA=0")
        axarr[0].plot(class2_acc_tr, label="LCA=1")
        axarr[0].plot(class3_acc_tr, label="LCA=2")
        axarr[0].plot(class4_acc_tr, label="LCA=3")
        
        axarr[0].set_xlabel('epoch')
        axarr[0].set_ylabel('training accuracy')
        
        axarr[0].legend()
        
        axarr[1].plot(class1_acc_vl, label="LCA=0")
        axarr[1].plot(class2_acc_vl, label="LCA=1")
        axarr[1].plot(class3_acc_vl, label="LCA=2")
        axarr[1].plot(class4_acc_vl, label="LCA=3")
        
        axarr[1].set_xlabel('epoch')
        axarr[1].set_ylabel('validation accuracy')
        
        axarr[1].legend()
        
        fig.tight_layout()
        if show:
            plt.show()        
        plt.savefig(file_name)
    

In [22]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [23]:
optimizer

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.001
    maximize: False
    weight_decay: 0
)

In [24]:
def positive_edge_weight(loader):
    sum_edges = 0 
    sum_pos = 0
    for data in loader:
        sum_edges += data.edges.shape[0]
        sum_pos  += torch.sum(data.y[:,0]==0).item()
    return sum_edges/(2*sum_pos)

def positive_node_weight(loader):
    sum_nodes = 0 
    sum_pos = 0
    for data in loader:
        num_nodes=data.nodes.shape[0]
        #out = data.edges.new_zeros(num_nodes, 4)
        node_sum = scatter_add(data.y,data.senders,dim=0)
        ynodes = (1.*(torch.sum(node_sum[:,1:],1)>0)).unsqueeze(1)
        sum_nodes += num_nodes
        sum_pos  += torch.sum(ynodes==1).item()
    return sum_nodes/(2*sum_pos)

In [25]:
#positive_edge_weight(train_loader)

In [26]:
#positive_node_weight(train_loader)

In [27]:
from wmpgnn.util.functions import weight_four_class, acc_four_class

class GNNTrainer(Trainer):
    """ Class for training """
    def __init__(self, config, model, train_loader, val_loader, add_BCE=True):
        super().__init__(config, model, train_loader, val_loader)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
        weights=weight_four_class(self.train_loader)
        self.criterion = nn.CrossEntropyLoss(weight=weights)
        pos_weight = positive_edge_weight(train_loader)
        pos_weight = torch.tensor([pos_weight])  
        self.criterionBCE = nn.BCEWithLogitsLoss(weight=pos_weight) 
        pos_weight = positive_node_weight(train_loader)
        pos_weight = torch.tensor([pos_weight])
        self.criterionBCEnodes = nn.BCEWithLogitsLoss(weight=pos_weight) 
        
        self.criterion.to('cuda')
        self.criterionBCE.cuda()
        self.criterionBCEnodes.cuda()
        self.model.cuda()    

        self.add_BCE = add_BCE
        self.alpha_BCE = 0.2

    def set_alpha_BCE(alpha):
        self.alpha_BCE = alpha

    def eval_one_epoch(self, train = True):
        running_loss = 0.
        last_loss = 0.
        acc_one_epoch = []
        if train == True:
            data_loader = self.train_loader
        else:
            data_loader = self.val_loader
        last_batch = len(data_loader)
        for i, data in enumerate(data_loader): 
            data['graph_globals'] = data['graph_globals'].unsqueeze(1)
            data.receivers = data.receivers - torch.min(data.receivers)
            data.senders = data.senders - torch.min(data.senders)
            data.edgepos = data.edgepos - torch.min(data.edgepos)
            if train:
                self.optimizer.zero_grad()

            data.to('cuda')
            yBCE_start = 1.*(data.y[:,0]==0).unsqueeze(1)
            num_nodes=data.nodes.shape[0]
            out = data.edges.new_zeros(num_nodes, data.edges.shape[1])
            node_sum = scatter_add(data.y,data.senders,out=out,dim=0)
            ynodes_start = (1.*(torch.sum(node_sum[:,1:],1)>0)).unsqueeze(1)
            label0 = data.y.argmax(dim=1)
            answers = torch.ones_like(data.edges).cuda()
    
            outputs = self.model(data)
            data = outputs
            label= data.y.argmax(dim=1)
            num_nodes=data.nodes.shape[0]
            out = data.edges.new_zeros(num_nodes, data.edges.shape[1])
            node_sum = scatter_add(data.y,data.senders,out=out,dim=0)
            ynodes = (1.*(torch.sum(node_sum[:,1:],1)>0)).unsqueeze(1)
        
            loss = self.criterion(outputs.edges, label) 
            yBCE = 1.*(data.y[:,0]==0).unsqueeze(1)
            if self.add_BCE:
                count = 0 
                for block in self.model._blocks:
                    loss += self.alpha_BCE*self.criterionBCE(block._network.edge_logits, yBCE)
                    loss += self.alpha_BCE*self.criterionBCEnodes(block._network.node_logits, ynodes)
                    count += 1
        
            
            acc_one_batch = acc_four_class(outputs.edges, label)
            acc_one_epoch.append(acc_one_batch)
            if train:
                loss.backward()            
                self.optimizer.step()
    
            running_loss += loss.item()
            if (i+1) == last_batch:
                last_loss = running_loss / last_batch # loss per batch
                print('  batch {} last_batch {} loss: {}'.format(i + 1, last_batch, last_loss))
    
                running_loss = 0.
        
        acc_one_epoch=torch.stack(acc_one_epoch)
    
        return last_loss, acc_one_epoch.nanmean(dim=0)        
        
    def train(self, epochs=10, starting_epoch =0, learning_rate=0.001):
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
        for epoch in range(starting_epoch, epochs):
            print(f"At epoch {epoch}")
            self.epochs.append(epoch)
            train_loss, train_acc = self.eval_one_epoch()
            self.model.train(False)
            val_loss, val_acc = self.eval_one_epoch(train=False)    
            self.train_loss.append(train_loss)
            self.train_acc.append(train_acc)
            self.val_loss.append(val_loss)
            self.val_acc.append(val_acc)
            print(f'Train Loss: {train_loss:03f}')
            print(f'Val Loss: {val_loss:03f}')
            print(f'Train Acc: {train_acc}')
            print(f'Val Acc: {val_acc}')
            

            


In [47]:
from wmpgnn.util.functions import weight_four_class, acc_four_class


def hetero_positive_edge_weight(loader):
    sum_edges = 0 
    sum_pos = 0
    for data in loader:
        sum_edges += data[('tracks','to','tracks')].edges.shape[0]
        sum_pos  += torch.sum(data[('tracks','to','tracks')].y[:,0]==0).item()
    return sum_edges/(2*sum_pos)

def hetero_positive_node_weight(loader):
    sum_nodes = 0 
    sum_pos = 0
    for data in loader:
        num_nodes=data['tracks'].x.shape[0]
        #out = data.edges.new_zeros(num_nodes, 4)
        node_sum = scatter_add(data[('tracks','to','tracks')].y, data[('tracks','to','tracks')].edge_index[0],dim=0)
        ynodes = (1.*(torch.sum(node_sum[:,1:],1)>0)).unsqueeze(1)
        sum_nodes += num_nodes
        sum_pos  += torch.sum(ynodes==1).item()
    return sum_nodes/(2*sum_pos)
    
class HeteroGNNTrainer(Trainer):
    """ Class for training """
    def __init__(self, config, model, train_loader, val_loader, add_BCE=True):
        super().__init__(config, model, train_loader, val_loader)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
        weights=weight_four_class(self.train_loader,hetero=True)
        self.criterion = nn.CrossEntropyLoss(weight=weights)
        pos_weight = hetero_positive_edge_weight(train_loader)
        pos_weight = torch.tensor([pos_weight])  
        print(pos_weight)
        #self.criterionBCE = nn.BCELoss(weight=pos_weight) 
        self.criterionBCE = nn.BCEWithLogitsLoss(pos_weight=pos_weight) 
        pos_weight = hetero_positive_node_weight(train_loader)
        pos_weight = torch.tensor([pos_weight])
        print(pos_weight)
        #self.criterionBCEnodes = nn.BCELoss(weight=pos_weight) 
        self.criterionBCEnodes = nn.BCEWithLogitsLoss(pos_weight=pos_weight) 
        pos_weight = torch.tensor([6.1118])
        #self.criterionBCE_PV = nn.BCELoss(weight=pos_weight) 
        self.criterionBCE_PV = nn.BCEWithLogitsLoss(pos_weight=pos_weight) 
        self.criterionBCE_PV.cuda()
        
        self.criterion.to('cuda')
        self.criterionBCE.cuda()
        self.criterionBCEnodes.cuda()
        self.model.cuda()    

        self.add_BCE = add_BCE
        self.alpha_BCE = 0.2

        self.train_PV_acc = []
        self.val_PV_acc = []

    def set_alpha_BCE(alpha):
        self.alpha_BCE = alpha
        
    def eval_one_epoch(self, train = True):
        running_loss = 0.
        last_loss = 0.
        acc_one_epoch = []
        PV_acc_one_epoch = []
        PV_node_acc_one_epoch = []
        if train == True:
            data_loader = self.train_loader
        else:
            data_loader = self.val_loader
        last_batch = len(data_loader)
        #print(last_batch)
        for i, data in enumerate(data_loader): 
            #print(i, train)
            if train:
                self.optimizer.zero_grad()
            data.to('cuda')
            data['tracks'].x = torch.hstack( [ data['tracks'].x[:,:6] , data['tracks'].x[:,9:10] ] )
            
            outputs = self.model(data)

            data = outputs
            
            label= data[('tracks', 'to', 'tracks')].y.argmax(dim=1)
            PVlabel= torch.tensor(data[('tracks', 'to', 'PVs')].y,dtype=torch.float32)

            loss = self.criterion(outputs[('tracks', 'to', 'tracks')].edges, label) 

            num_nodes=data['tracks'].x.shape[0]
            out = data[('tracks', 'to', 'tracks')].edges.new_zeros(num_nodes, data[('tracks', 'to', 'tracks')].y.shape[1])
            node_sum = scatter_add(data[('tracks', 'to', 'tracks')].y,data[('tracks', 'to', 'tracks')].edge_index[0],out=out,dim=0)
            ynodes = (1.*(torch.sum(node_sum[:,1:],1)>0)).unsqueeze(1)
            yBCE = 1.*(data[('tracks', 'to', 'tracks')].y[:,0]==0).unsqueeze(1)
    
            yb = ynodes[data[('tracks', 'to', 'PVs')]['edge_index'][0]]*data[('tracks', 'to', 'PVs')].y
            pv_sum = scatter_add(yb,data[('tracks', 'to', 'PVs')].edge_index[1],dim=0)
            pv_target = 1.*(pv_sum > 0)
            
            if self.add_BCE:
                for block in self.model._blocks:
                    # loss += 1*self.criterionBCE(block.edge_weights[('tracks', 'to', 'tracks')], yBCE)
                    # loss += 1*self.criterionBCE_PV(block.edge_weights[('tracks', 'to', 'PVs')], PVlabel)  
                    # loss += 1*self.criterionBCEnodes(block.node_weights['tracks'], ynodes)
                    # loss += 1*self.criterionBCEnodes(block.node_weights['tracks'], ynodes)
                    loss += 1*self.criterionBCE(block.edge_logits[('tracks', 'to', 'tracks')], yBCE)
                    loss += 1*self.criterionBCE_PV(block.edge_logits[('tracks', 'to', 'PVs')], PVlabel)  
                    loss += 1*self.criterionBCEnodes(block.node_logits['tracks'], ynodes)
                    #loss += 1*self.criterionBCEnodes(block.node_logits['PVs'], pv_target)
            acc_one_batch = acc_four_class(outputs[('tracks', 'to', 'tracks')].edges, label)
            acc_one_epoch.append(acc_one_batch)

            PV_acc_one_batch =  torch.sum(PVlabel == ( self.model._blocks[-1].edge_weights[('tracks', 'to', 'PVs')]> 0.5 ))/ PVlabel.shape[0]
            PV_acc_one_epoch.append(PV_acc_one_batch)     

            PV_node_acc_one_batch =  torch.sum(pv_target == ( self.model._blocks[-1].node_weights['PVs']> 0.5 ))/ pv_target.shape[0]
            PV_node_acc_one_epoch.append(PV_node_acc_one_batch)    
            if train:
                loss.backward()
                self.optimizer.step()
    
            running_loss += loss.item()
            if (i+1) == last_batch:
                last_loss = running_loss / last_batch # loss per batch
                print('  batch {} last_batch {} loss: {}'.format(i + 1, last_batch, last_loss))
    
                running_loss = 0.
        
        acc_one_epoch=torch.stack(acc_one_epoch)
        PV_acc_one_epoch=torch.stack(PV_acc_one_epoch)
        PV_node_acc_one_epoch=torch.stack(PV_node_acc_one_epoch)

        return last_loss, acc_one_epoch.nanmean(dim=0), PV_acc_one_epoch.nanmean(dim=0), PV_node_acc_one_epoch.nanmean(dim=0)
            
     
    def train(self, epochs=10, starting_epoch =0, learning_rate=0.001):
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
        for epoch in range(starting_epoch, epochs):
            print(f"At epoch {epoch}")
            self.epochs.append(epoch)
            train_loss, train_acc, train_PV_acc, train_B_PV_acc = self.eval_one_epoch()
            self.model.train(False)
            val_loss, val_acc, val_PV_acc, val_B_PV_acc = self.eval_one_epoch(train=False)    
            self.train_loss.append(train_loss)
            self.train_acc.append(train_acc)
            self.train_PV_acc.append(train_PV_acc)
            self.val_loss.append(val_loss)
            self.val_acc.append(val_acc)
            self.val_PV_acc.append(val_PV_acc)
            print(f'Train Loss: {train_loss:03f}')
            print(f'Val Loss: {val_loss:03f}')
            print(f'Train Acc: {train_acc}')
            print(f'Val Acc: {val_acc}')
            print(f'Train PV edge Acc: {train_PV_acc}')
            print(f'Val PV edge Acc: {val_PV_acc}')
            print(f'Train PV edge Acc: {train_B_PV_acc}')
            print(f'Val PV edge Acc: {val_B_PV_acc}')

In [42]:
#model_loader = ModelLoader(config_loader)
#model = model_loader.get_model()

In [43]:
len(val_loader)

290

In [44]:
#weights = torch.tensor([2.5158e-01, 1.7982e+02, 6.0619e+01, 3.2675e+02])

In [48]:
trainer  = HeteroGNNTrainer(config_loader, model, train_loader, val_loader, add_BCE=True)

tensor([2.5165e-01, 1.7152e+02, 5.7953e+01, 3.1798e+02])
tensor([76.2227])
tensor([8.5526])


In [49]:
trainer.train(10)

At epoch 0


  PVlabel= torch.tensor(data[('tracks', 'to', 'PVs')].y,dtype=torch.float32)


  batch 1234 last_batch 1234 loss: 2.8590355613050042
  batch 290 last_batch 290 loss: 2.8858507131708078
Train Loss: 2.859036
Val Loss: 2.885851
Train Acc: tensor([0.9763, 0.6959, 0.5021, 0.8447])
Val Acc: tensor([0.9817, 0.7637, 0.4233, 0.8571])
Train PV edge Acc: 0.9925808310508728
Val PV edge Acc: 0.9917919039726257
Train PV edge Acc: 0.9791470766067505
Val PV edge Acc: 0.9788030982017517
At epoch 1
  batch 1234 last_batch 1234 loss: 2.809573397740746
  batch 290 last_batch 290 loss: 2.9316425019297108
Train Loss: 2.809573
Val Loss: 2.931643
Train Acc: tensor([0.9773, 0.7004, 0.5084, 0.8475])
Val Acc: tensor([0.9836, 0.7363, 0.4767, 0.8367])
Train PV edge Acc: 0.9928807616233826
Val PV edge Acc: 0.992691159248352
Train PV edge Acc: 0.9814252257347107
Val PV edge Acc: 0.9837502241134644
At epoch 2
  batch 1234 last_batch 1234 loss: 2.7549289445629768
  batch 290 last_batch 290 loss: 2.987028119481843
Train Loss: 2.754929
Val Loss: 2.987028
Train Acc: tensor([0.9775, 0.7023, 0.5130, 

In [50]:
trainer.train(12, starting_epoch=10,learning_rate=0.0001)

At epoch 10


  PVlabel= torch.tensor(data[('tracks', 'to', 'PVs')].y,dtype=torch.float32)


  batch 1234 last_batch 1234 loss: 3.289735750209184
  batch 290 last_batch 290 loss: 3.0422112662216714
Train Loss: 3.289736
Val Loss: 3.042211
Train Acc: tensor([0.9733, 0.6762, 0.5066, 0.8298])
Val Acc: tensor([0.9737, 0.7502, 0.4813, 0.8409])
Train PV edge Acc: 0.9886472821235657
Val PV edge Acc: 0.9903491735458374
Train PV edge Acc: 0.978159487247467
Val PV edge Acc: 0.978970468044281
At epoch 11
  batch 1234 last_batch 1234 loss: 2.828364562370016
  batch 290 last_batch 290 loss: 2.7923335231583692
Train Loss: 2.828365
Val Loss: 2.792334
Train Acc: tensor([0.9780, 0.7158, 0.5333, 0.8540])
Val Acc: tensor([0.9767, 0.7547, 0.5050, 0.8484])
Train PV edge Acc: 0.9919545650482178
Val PV edge Acc: 0.9925172924995422
Train PV edge Acc: 0.9785208702087402
Val PV edge Acc: 0.9764626622200012


In [None]:
trainer.train(32, starting_epoch=30,learning_rate=0.0001)

In [None]:
trainer.plot_loss(show=True)

In [None]:
trainer.model