# Example of Graph Neural Network

In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell executes and edits to project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# System imports
import os
import sys
import yaml

# External imports
import matplotlib.pyplot as plt
import scipy as sp
from sklearn.decomposition import PCA
from sklearn.metrics import auc
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger

# Add the directory two levels above the working directory to the import path.
# presumably this is where the `LightningModules` package imported below lives — TODO confirm
sys.path.append('../..')
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

## Attention Mechanism

In [3]:
from LightningModules.GNN.Models.agnn import ResAGNN
from LightningModules.GNN.Models.vanilla_agnn import VanillaResAGNN
from LightningModules.GNN.Models.checkpoint_agnn import CheckpointedResAGNN
from LightningModules.GNN.Models.interaction_multistep_gnn import CheckpointedInteractionMultistepGNN

In [4]:
# Load the model/training hyperparameters from the YAML file in the current
# working directory into `hparams`.
# NOTE(review): yaml.load with FullLoader is fine for a trusted local config;
# consider yaml.safe_load if this file could ever come from an untrusted source.
with open("example_gnn.yaml") as f:
    hparams = yaml.load(f, Loader=yaml.FullLoader)

In [5]:
# model = VanillaResAGNN(hparams)

In [6]:
# Build the CheckpointedResAGNN variant from the loaded hyperparameters.
# The neighbouring cells keep two alternative model classes commented out.
model = CheckpointedResAGNN(hparams)

In [7]:
# model = CheckpointedInteractionMultistepGNN(hparams)

### Dataset

In [8]:
%%time
# Call the module's setup hook for the "fit" stage and time it.
# presumably this is what populates `model.trainset`, which the next cell reads — TODO confirm
model.setup(stage="fit")

CPU times: user 274 µs, sys: 4.13 ms, total: 4.4 ms
Wall time: 4.19 ms


In [9]:
# Take the first entry of the training set for inspection.
sample = model.trainset[0]

In [10]:
# Display the sample: its attributes and their shapes.
sample

Data(cell_data=[63740, 11], edge_index=[2, 402137], event_file="/project/projectdirs/m3443/data/ITk-upgrade/processed/full_events_v4/event000000001", hid=[63740], modulewise_true_edges=[2, 56793], nhits=[63740], pid=[63740], primary=[63740], pt=[63740], signal_true_edges=[2, 56502], x=[63740, 3], y=[402137], y_pid=[402137])

In [11]:
# Ratio of the summed edge labels `y` to the number of signal true edges.
# presumably the fraction of signal true edges that are present in the graph — TODO confirm
sample.y.sum()/sample.signal_true_edges.shape[1]

tensor(0.9361)

In [12]:
# Summed edge labels `y` divided by the number of edges in the graph,
# i.e. the fraction of positive edges (assuming `y` holds 0/1 labels — TODO confirm).
sample.y.sum()/sample.edge_index.shape[1]

tensor(0.1315)

In [13]:
# Short alias for the sample's edge index (two rows, one column per edge).
edges = sample.edge_index

In [14]:
# Short alias for the per-hit `pid` values (presumably particle ids — TODO confirm).
pid = sample.pid

In [15]:
# Show the edge index shape: (2, number of edges).
edges.shape

torch.Size([2, 402137])

In [16]:
# Fraction of edges whose two endpoints carry the same `pid` value.
(sample.pid[edges[0]] == sample.pid[edges[1]]).sum()/sample.edge_index.shape[1]

tensor(0.2224)

### Memory Test

In [17]:
%%time
# Run the "fit"-stage setup again (timed) before the memory measurement below.
model.setup(stage="fit")

CPU times: user 370 µs, sys: 4.23 ms, total: 4.6 ms
Wall time: 4.35 ms


In [18]:
# Fetch the first training sample again and move it to the selected device.
sample = model.trainset[0].to(device)

In [19]:
# Move the model to the same device as the sample.
model = model.to(device)

In [20]:
# Measure the peak GPU memory of a single forward pass on one sample.
# The peak-memory counter is reset first so the reading in the next cell reflects
# only this call. The reset is guarded because `device` falls back to "cpu" when
# no GPU is available, and resetting CUDA statistics without CUDA raises an error.
if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()
# Inputs are moved to `device` explicitly so the cell also works if `sample`
# was not moved beforehand.
output = model(sample.x.to(device), sample.edge_index.to(device))

In [21]:
# Peak GPU memory allocated since the counter was reset in the previous cell.
# The value is in bytes; dividing by 1024**3 gives GiB (printed with the label "Gb").
print(torch.cuda.max_memory_allocated()/1024**3, "Gb")

4.124407768249512 Gb


### Train GNN

In [22]:
import ninja

In [23]:
from pytorch_lightning.plugins import DeepSpeedPlugin

In [24]:
# Run the "fit"-stage setup once more before handing the model to the Trainer.
model.setup(stage="fit")

In [25]:
import torch

In [None]:
# Send training metrics to Weights & Biases under the given project and group.
logger = WandbLogger(
    group="Gnn_train_1500",
    project="ITk_0.5GeV_GNN",
)

# One GPU, 16-bit precision, at most 10 epochs; Lightning writes its artefacts
# under `default_root_dir`.
# NOTE(review): `gpus=1` and the absolute path are specific to the original
# machine — adjust before running elsewhere.
trainer = Trainer(
    default_root_dir="/global/cfs/cdirs/m3443/usr/ryanliu/gnn_example/",
    logger=logger,
    gpus=1,
    precision=16,
    max_epochs=10,
)

# Start training; the model supplies its own dataloaders.
trainer.fit(model)

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mexatrkx[0m (use `wandb login --relogin` to force relogin)


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type       | Params
---------------------------------------------
0 | edge_network  | Sequential | 32.7 K
1 | node_network  | Sequential | 21.6 K
2 | input_network | Sequential | 13.2 K
---------------------------------------------
67.6 K    Trainable params
0         Non-trainable params
67.6 K    Total params
0.135     Total estimated model params size (MB)


Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(


Validation sanity check:  50%|█████     | 1/2 [00:01<00:01,  1.10s/it]



                                                                      

  rank_zero_warn(


Epoch 0:  94%|█████████▍| 1500/1600 [1:05:59<04:23,  2.64s/it, loss=0.429, v_num=zb3a]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/100 [00:00<?, ?it/s][A




Epoch 0:  94%|█████████▍| 1502/1600 [1:06:00<04:18,  2.64s/it, loss=0.429, v_num=zb3a]




Validating:   2%|▏         | 2/100 [00:01<01:05,  1.50it/s][A




Epoch 0:  94%|█████████▍| 1504/1600 [1:06:01<04:12,  2.63s/it, loss=0.429, v_num=zb3a]




Validating:   4%|▍         | 4/100 [00:03<01:16,  1.26it/s][A




Epoch 0:  94%|█████████▍| 1506/1600 [1:06:03<04:07,  2.63s/it, loss=0.429, v_num=zb3a]




Validating:   6%|▌         | 6/100 [00:04<01:14,  1.26it/s][A




Epoch 0:  94%|█████████▍| 1508/1600 [1:06:05<04:01,  2.63s/it, loss=0.429, v_num=zb3a]




Validating:   8%|▊         | 8/100 [00:05<01:02,  1.47it/s][A




Epoch 0:  94%|█████████▍| 1510/1600 [1:06:06<03:56,  2.63s/it, loss=0.429, v_num=zb3a]




Validating:  10%|█         | 10/100 [00:07<01:02,  1.43it/s][A




Epoch 0:  94%|█████████▍| 1512/1600 [1:06:08<03:50,  2.62s/it, loss=0.429, v_num=zb3a]




Validating:  12%|█▏        | 12/100 [00:08<01:01,  1.42it/s][A




Epoch 0:  95%|█████████▍| 1514/1600 [1:06:09<03:45,  2.62s/it, loss=0.429, v_num=zb3a]




Validating:  14%|█▍        | 14/100 [00:09<00:52,  1.64it/s][A




Epoch 0:  95%|█████████▍| 1516/1600 [1:06:10<03:39,  2.62s/it, loss=0.429, v_num=zb3a]




Validating:  16%|█▌        | 16/100 [00:11<00:52,  1.59it/s][A




Epoch 0:  95%|█████████▍| 1518/1600 [1:06:11<03:34,  2.62s/it, loss=0.429, v_num=zb3a]




Validating:  18%|█▊        | 18/100 [00:12<00:54,  1.50it/s][A




Epoch 0:  95%|█████████▌| 1520/1600 [1:06:13<03:29,  2.61s/it, loss=0.429, v_num=zb3a]




Validating:  20%|██        | 20/100 [00:13<00:51,  1.56it/s][A




Epoch 0:  95%|█████████▌| 1522/1600 [1:06:14<03:23,  2.61s/it, loss=0.429, v_num=zb3a]




Validating:  22%|██▏       | 22/100 [00:14<00:47,  1.64it/s][A




Epoch 0:  95%|█████████▌| 1524/1600 [1:06:15<03:18,  2.61s/it, loss=0.429, v_num=zb3a]




Validating:  24%|██▍       | 24/100 [00:16<00:45,  1.69it/s][A




Epoch 0:  95%|█████████▌| 1526/1600 [1:06:16<03:12,  2.61s/it, loss=0.429, v_num=zb3a]




Validating:  26%|██▌       | 26/100 [00:17<00:42,  1.76it/s][A




Epoch 0:  96%|█████████▌| 1528/1600 [1:06:17<03:07,  2.60s/it, loss=0.429, v_num=zb3a]




Validating:  28%|██▊       | 28/100 [00:18<00:43,  1.64it/s][A




Epoch 0:  96%|█████████▌| 1530/1600 [1:06:19<03:02,  2.60s/it, loss=0.429, v_num=zb3a]




Validating:  30%|███       | 30/100 [00:20<00:48,  1.43it/s][A




Epoch 0:  96%|█████████▌| 1532/1600 [1:06:20<02:56,  2.60s/it, loss=0.429, v_num=zb3a]




Validating:  32%|███▏      | 32/100 [00:21<00:47,  1.44it/s][A




Epoch 0:  96%|█████████▌| 1534/1600 [1:06:21<02:51,  2.60s/it, loss=0.429, v_num=zb3a]




Validating:  34%|███▍      | 34/100 [00:22<00:44,  1.49it/s][A




Epoch 0:  96%|█████████▌| 1536/1600 [1:06:23<02:45,  2.59s/it, loss=0.429, v_num=zb3a]




Validating:  36%|███▌      | 36/100 [00:23<00:40,  1.59it/s][A




Epoch 0:  96%|█████████▌| 1538/1600 [1:06:24<02:40,  2.59s/it, loss=0.429, v_num=zb3a]




Validating:  38%|███▊      | 38/100 [00:25<00:38,  1.61it/s][A




Epoch 0:  96%|█████████▋| 1540/1600 [1:06:25<02:35,  2.59s/it, loss=0.429, v_num=zb3a]




Validating:  40%|████      | 40/100 [00:26<00:40,  1.49it/s][A




Epoch 0:  96%|█████████▋| 1542/1600 [1:06:27<02:29,  2.59s/it, loss=0.429, v_num=zb3a]




Validating:  42%|████▏     | 42/100 [00:27<00:36,  1.60it/s][A




Epoch 0:  96%|█████████▋| 1544/1600 [1:06:28<02:24,  2.58s/it, loss=0.429, v_num=zb3a]




Validating:  44%|████▍     | 44/100 [00:29<00:38,  1.47it/s][A




Epoch 0:  97%|█████████▋| 1546/1600 [1:06:29<02:19,  2.58s/it, loss=0.429, v_num=zb3a]




Validating:  46%|████▌     | 46/100 [00:30<00:36,  1.48it/s][A




Epoch 0:  97%|█████████▋| 1548/1600 [1:06:31<02:14,  2.58s/it, loss=0.429, v_num=zb3a]




Validating:  48%|████▊     | 48/100 [00:31<00:32,  1.59it/s][A




Epoch 0:  97%|█████████▋| 1550/1600 [1:06:32<02:08,  2.58s/it, loss=0.429, v_num=zb3a]
Epoch 1:  94%|█████████▍| 1500/1600 [1:05:56<04:23,  2.64s/it, loss=0.3, v_num=zb3a]  
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/100 [00:00<?, ?it/s][A




Epoch 1:  94%|█████████▍| 1502/1600 [1:05:57<04:18,  2.63s/it, loss=0.3, v_num=zb3a]




Validating:   2%|▏         | 2/100 [00:01<01:10,  1.39it/s][A




Epoch 1:  94%|█████████▍| 1504/1600 [1:05:58<04:12,  2.63s/it, loss=0.3, v_num=zb3a]




Validating:   4%|▍         | 4/100 [00:03<01:12,  1.33it/s][A




Epoch 1:  94%|█████████▍| 1506/1600 [1:06:00<04:07,  2.63s/it, loss=0.3, v_num=zb3a]




Validating:   6%|▌         | 6/100 [00:04<01:06,  1.42it/s][A




Epoch 1:  94%|█████████▍| 1508/1600 [1:06:01<04:01,  2.63s/it, loss=0.3, v_num=zb3a]




Validating:   8%|▊         | 8/100 [00:05<01:04,  1.42it/s][A




Epoch 1:  94%|█████████▍| 1510/1600 [1:06:03<03:56,  2.62s/it, loss=0.3, v_num=zb3a]




Validating:  10%|█         | 10/100 [00:07<01:05,  1.37it/s][A




Epoch 1:  94%|█████████▍| 1512/1600 [1:06:04<03:50,  2.62s/it, loss=0.3, v_num=zb3a]




Validating:  12%|█▏        | 12/100 [00:08<01:03,  1.38it/s][A




Epoch 1:  95%|█████████▍| 1514/1600 [1:06:05<03:45,  2.62s/it, loss=0.3, v_num=zb3a]




Validating:  14%|█▍        | 14/100 [00:09<00:53,  1.61it/s][A




Epoch 1:  95%|█████████▍| 1516/1600 [1:06:06<03:39,  2.62s/it, loss=0.3, v_num=zb3a]




Validating:  16%|█▌        | 16/100 [00:11<00:50,  1.66it/s][A




Epoch 1:  95%|█████████▍| 1518/1600 [1:06:08<03:34,  2.61s/it, loss=0.3, v_num=zb3a]




Validating:  18%|█▊        | 18/100 [00:12<00:52,  1.57it/s][A




Epoch 1:  95%|█████████▌| 1520/1600 [1:06:09<03:28,  2.61s/it, loss=0.3, v_num=zb3a]




Validating:  20%|██        | 20/100 [00:13<00:50,  1.59it/s][A




Epoch 1:  95%|█████████▌| 1522/1600 [1:06:10<03:23,  2.61s/it, loss=0.3, v_num=zb3a]




Validating:  22%|██▏       | 22/100 [00:14<00:47,  1.66it/s][A




Epoch 1:  95%|█████████▌| 1524/1600 [1:06:12<03:18,  2.61s/it, loss=0.3, v_num=zb3a]




Validating:  24%|██▍       | 24/100 [00:15<00:44,  1.70it/s][A




Epoch 1:  95%|█████████▌| 1526/1600 [1:06:12<03:12,  2.60s/it, loss=0.3, v_num=zb3a]




Validating:  26%|██▌       | 26/100 [00:17<00:42,  1.76it/s][A




Epoch 1:  96%|█████████▌| 1528/1600 [1:06:14<03:07,  2.60s/it, loss=0.3, v_num=zb3a]




Validating:  28%|██▊       | 28/100 [00:18<00:43,  1.65it/s][A




Epoch 1:  96%|█████████▌| 1530/1600 [1:06:15<03:01,  2.60s/it, loss=0.3, v_num=zb3a]




Validating:  30%|███       | 30/100 [00:19<00:48,  1.43it/s][A




Epoch 1:  96%|█████████▌| 1532/1600 [1:06:17<02:56,  2.60s/it, loss=0.3, v_num=zb3a]




Validating:  32%|███▏      | 32/100 [00:21<00:47,  1.44it/s][A




Epoch 1:  96%|█████████▌| 1534/1600 [1:06:18<02:51,  2.59s/it, loss=0.3, v_num=zb3a]




Validating:  34%|███▍      | 34/100 [00:22<00:44,  1.49it/s][A




Epoch 1:  96%|█████████▌| 1536/1600 [1:06:19<02:45,  2.59s/it, loss=0.3, v_num=zb3a]




Validating:  36%|███▌      | 36/100 [00:23<00:40,  1.58it/s][A




Epoch 1:  96%|█████████▌| 1538/1600 [1:06:20<02:40,  2.59s/it, loss=0.3, v_num=zb3a]




Validating:  38%|███▊      | 38/100 [00:25<00:38,  1.61it/s][A




Epoch 1:  96%|█████████▋| 1540/1600 [1:06:22<02:35,  2.59s/it, loss=0.3, v_num=zb3a]




Validating:  40%|████      | 40/100 [00:26<00:40,  1.48it/s][A




Epoch 1:  96%|█████████▋| 1542/1600 [1:06:23<02:29,  2.58s/it, loss=0.3, v_num=zb3a]




Validating:  42%|████▏     | 42/100 [00:27<00:36,  1.60it/s][A




Epoch 1:  96%|█████████▋| 1544/1600 [1:06:25<02:24,  2.58s/it, loss=0.3, v_num=zb3a]




Validating:  44%|████▍     | 44/100 [00:29<00:38,  1.47it/s][A




Epoch 1:  97%|█████████▋| 1546/1600 [1:06:26<02:19,  2.58s/it, loss=0.3, v_num=zb3a]




Validating:  46%|████▌     | 46/100 [00:30<00:36,  1.49it/s][A




Epoch 1:  97%|█████████▋| 1548/1600 [1:06:27<02:13,  2.58s/it, loss=0.3, v_num=zb3a]




Validating:  48%|████▊     | 48/100 [00:31<00:32,  1.60it/s][A




Epoch 1:  97%|█████████▋| 1550/1600 [1:06:28<02:08,  2.57s/it, loss=0.3, v_num=zb3a]
Epoch 2:  94%|█████████▍| 1500/1600 [1:06:00<04:24,  2.64s/it, loss=0.254, v_num=zb3a]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/100 [00:00<?, ?it/s][A




Epoch 2:  94%|█████████▍| 1502/1600 [1:06:01<04:18,  2.64s/it, loss=0.254, v_num=zb3a]




Validating:   2%|▏         | 2/100 [00:01<01:09,  1.40it/s][A




Epoch 2:  94%|█████████▍| 1504/1600 [1:06:02<04:12,  2.63s/it, loss=0.254, v_num=zb3a]




Validating:   4%|▍         | 4/100 [00:02<01:11,  1.34it/s][A




Epoch 2:  94%|█████████▍| 1506/1600 [1:06:04<04:07,  2.63s/it, loss=0.254, v_num=zb3a]




Validating:   6%|▌         | 6/100 [00:04<01:06,  1.42it/s][A




Epoch 2:  94%|█████████▍| 1508/1600 [1:06:05<04:01,  2.63s/it, loss=0.254, v_num=zb3a]




Validating:   8%|▊         | 8/100 [00:05<00:58,  1.58it/s][A




Epoch 2:  94%|█████████▍| 1510/1600 [1:06:07<03:56,  2.63s/it, loss=0.254, v_num=zb3a]




Validating:  10%|█         | 10/100 [00:06<01:01,  1.46it/s][A




Epoch 2:  94%|█████████▍| 1512/1600 [1:06:08<03:50,  2.62s/it, loss=0.254, v_num=zb3a]




Validating:  12%|█▏        | 12/100 [00:08<01:01,  1.44it/s][A




Epoch 2:  95%|█████████▍| 1514/1600 [1:06:09<03:45,  2.62s/it, loss=0.254, v_num=zb3a]




Validating:  14%|█▍        | 14/100 [00:09<00:52,  1.65it/s][A




Epoch 2:  95%|█████████▍| 1516/1600 [1:06:10<03:40,  2.62s/it, loss=0.254, v_num=zb3a]




Validating:  16%|█▌        | 16/100 [00:10<00:49,  1.68it/s][A




Epoch 2:  95%|█████████▍| 1518/1600 [1:06:12<03:34,  2.62s/it, loss=0.254, v_num=zb3a]




Validating:  18%|█▊        | 18/100 [00:12<00:52,  1.57it/s][A




Epoch 2:  95%|█████████▌| 1520/1600 [1:06:13<03:29,  2.61s/it, loss=0.254, v_num=zb3a]




Validating:  20%|██        | 20/100 [00:13<00:50,  1.59it/s][A




Epoch 2:  95%|█████████▌| 1522/1600 [1:06:14<03:23,  2.61s/it, loss=0.254, v_num=zb3a]




Validating:  22%|██▏       | 22/100 [00:14<00:46,  1.66it/s][A




Epoch 2:  95%|█████████▌| 1524/1600 [1:06:16<03:18,  2.61s/it, loss=0.254, v_num=zb3a]




Validating:  24%|██▍       | 24/100 [00:15<00:49,  1.54it/s][A




Epoch 2:  95%|█████████▌| 1526/1600 [1:06:17<03:12,  2.61s/it, loss=0.254, v_num=zb3a]




Validating:  26%|██▌       | 26/100 [00:17<00:46,  1.60it/s][A




Epoch 2:  96%|█████████▌| 1528/1600 [1:06:18<03:07,  2.60s/it, loss=0.254, v_num=zb3a]




Validating:  28%|██▊       | 28/100 [00:18<00:45,  1.58it/s][A




Epoch 2:  96%|█████████▌| 1530/1600 [1:06:19<03:02,  2.60s/it, loss=0.254, v_num=zb3a]




Validating:  30%|███       | 30/100 [00:19<00:49,  1.40it/s][A




Epoch 2:  96%|█████████▌| 1532/1600 [1:06:21<02:56,  2.60s/it, loss=0.254, v_num=zb3a]




Validating:  32%|███▏      | 32/100 [00:21<00:47,  1.43it/s][A




Epoch 2:  96%|█████████▌| 1534/1600 [1:06:22<02:51,  2.60s/it, loss=0.254, v_num=zb3a]




Validating:  34%|███▍      | 34/100 [00:22<00:44,  1.49it/s][A




Epoch 2:  96%|█████████▌| 1536/1600 [1:06:24<02:46,  2.59s/it, loss=0.254, v_num=zb3a]




Validating:  36%|███▌      | 36/100 [00:23<00:40,  1.58it/s][A




Epoch 2:  96%|█████████▌| 1538/1600 [1:06:25<02:40,  2.59s/it, loss=0.254, v_num=zb3a]




Validating:  38%|███▊      | 38/100 [00:25<00:38,  1.61it/s][A




Epoch 2:  96%|█████████▋| 1540/1600 [1:06:26<02:35,  2.59s/it, loss=0.254, v_num=zb3a]




Validating:  40%|████      | 40/100 [00:26<00:45,  1.33it/s][A




Epoch 2:  96%|█████████▋| 1542/1600 [1:06:28<02:30,  2.59s/it, loss=0.254, v_num=zb3a]




Validating:  42%|████▏     | 42/100 [00:28<00:39,  1.48it/s][A




Epoch 2:  96%|█████████▋| 1544/1600 [1:06:29<02:24,  2.58s/it, loss=0.254, v_num=zb3a]




Validating:  44%|████▍     | 44/100 [00:29<00:39,  1.41it/s][A




Epoch 2:  97%|█████████▋| 1546/1600 [1:06:30<02:19,  2.58s/it, loss=0.254, v_num=zb3a]




Validating:  46%|████▌     | 46/100 [00:30<00:37,  1.45it/s][A




Epoch 2:  97%|█████████▋| 1548/1600 [1:06:32<02:14,  2.58s/it, loss=0.254, v_num=zb3a]




Validating:  48%|████▊     | 48/100 [00:32<00:33,  1.57it/s][A




Epoch 2:  97%|█████████▋| 1550/1600 [1:06:33<02:08,  2.58s/it, loss=0.254, v_num=zb3a]
Epoch 3:  94%|█████████▍| 1500/1600 [1:05:57<04:23,  2.64s/it, loss=0.26, v_num=zb3a] 
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/100 [00:00<?, ?it/s][A




Epoch 3:  94%|█████████▍| 1502/1600 [1:05:58<04:18,  2.64s/it, loss=0.26, v_num=zb3a]




Validating:   2%|▏         | 2/100 [00:01<01:15,  1.30it/s][A




Epoch 3:  94%|█████████▍| 1504/1600 [1:06:00<04:12,  2.63s/it, loss=0.26, v_num=zb3a]




Validating:   4%|▍         | 4/100 [00:03<01:15,  1.27it/s][A




Epoch 3:  94%|█████████▍| 1506/1600 [1:06:01<04:07,  2.63s/it, loss=0.26, v_num=zb3a]




Validating:   6%|▌         | 6/100 [00:04<01:08,  1.38it/s][A




Epoch 3:  94%|█████████▍| 1508/1600 [1:06:02<04:01,  2.63s/it, loss=0.26, v_num=zb3a]




Validating:   8%|▊         | 8/100 [00:05<00:59,  1.55it/s][A




Epoch 3:  94%|█████████▍| 1510/1600 [1:06:04<03:56,  2.63s/it, loss=0.26, v_num=zb3a]




Validating:  10%|█         | 10/100 [00:07<01:01,  1.46it/s][A




Epoch 3:  94%|█████████▍| 1512/1600 [1:06:05<03:50,  2.62s/it, loss=0.26, v_num=zb3a]




Validating:  12%|█▏        | 12/100 [00:08<01:01,  1.43it/s][A




Epoch 3:  95%|█████████▍| 1514/1600 [1:06:06<03:45,  2.62s/it, loss=0.26, v_num=zb3a]




Validating:  14%|█▍        | 14/100 [00:09<00:52,  1.65it/s][A




Epoch 3:  95%|█████████▍| 1516/1600 [1:06:07<03:39,  2.62s/it, loss=0.26, v_num=zb3a]




Validating:  16%|█▌        | 16/100 [00:10<00:50,  1.68it/s][A




Epoch 3:  95%|█████████▍| 1518/1600 [1:06:09<03:34,  2.61s/it, loss=0.26, v_num=zb3a]




Validating:  18%|█▊        | 18/100 [00:12<00:52,  1.57it/s][A




Epoch 3:  95%|█████████▌| 1520/1600 [1:06:10<03:28,  2.61s/it, loss=0.26, v_num=zb3a]




Validating:  20%|██        | 20/100 [00:13<00:50,  1.59it/s][A




Epoch 3:  95%|█████████▌| 1522/1600 [1:06:11<03:23,  2.61s/it, loss=0.26, v_num=zb3a]




Validating:  22%|██▏       | 22/100 [00:14<00:47,  1.65it/s][A




Epoch 3:  95%|█████████▌| 1524/1600 [1:06:13<03:18,  2.61s/it, loss=0.26, v_num=zb3a]




Validating:  24%|██▍       | 24/100 [00:15<00:44,  1.70it/s][A




Epoch 3:  95%|█████████▌| 1526/1600 [1:06:13<03:12,  2.60s/it, loss=0.26, v_num=zb3a]




Validating:  26%|██▌       | 26/100 [00:16<00:42,  1.76it/s][A




Epoch 3:  96%|█████████▌| 1528/1600 [1:06:15<03:07,  2.60s/it, loss=0.26, v_num=zb3a]




Validating:  28%|██▊       | 28/100 [00:18<00:43,  1.64it/s][A




Epoch 3:  96%|█████████▌| 1530/1600 [1:06:16<03:01,  2.60s/it, loss=0.26, v_num=zb3a]




Validating:  30%|███       | 30/100 [00:19<00:48,  1.43it/s][A




Epoch 3:  96%|█████████▌| 1532/1600 [1:06:18<02:56,  2.60s/it, loss=0.26, v_num=zb3a]




Validating:  32%|███▏      | 32/100 [00:21<00:47,  1.44it/s][A




Epoch 3:  96%|█████████▌| 1534/1600 [1:06:19<02:51,  2.59s/it, loss=0.26, v_num=zb3a]




Validating:  34%|███▍      | 34/100 [00:22<00:44,  1.49it/s][A




Epoch 3:  96%|█████████▌| 1536/1600 [1:06:20<02:45,  2.59s/it, loss=0.26, v_num=zb3a]




Validating:  36%|███▌      | 36/100 [00:23<00:40,  1.59it/s][A




Epoch 3:  96%|█████████▌| 1538/1600 [1:06:21<02:40,  2.59s/it, loss=0.26, v_num=zb3a]




Validating:  38%|███▊      | 38/100 [00:24<00:38,  1.61it/s][A




Epoch 3:  96%|█████████▋| 1540/1600 [1:06:23<02:35,  2.59s/it, loss=0.26, v_num=zb3a]




Validating:  40%|████      | 40/100 [00:26<00:40,  1.49it/s][A




Epoch 3:  96%|█████████▋| 1542/1600 [1:06:24<02:29,  2.58s/it, loss=0.26, v_num=zb3a]




Validating:  42%|████▏     | 42/100 [00:27<00:36,  1.60it/s][A




Epoch 3:  96%|█████████▋| 1544/1600 [1:06:26<02:24,  2.58s/it, loss=0.26, v_num=zb3a]




Validating:  44%|████▍     | 44/100 [00:29<00:38,  1.45it/s][A




Epoch 3:  97%|█████████▋| 1546/1600 [1:06:27<02:19,  2.58s/it, loss=0.26, v_num=zb3a]




Validating:  46%|████▌     | 46/100 [00:30<00:36,  1.48it/s][A




Epoch 3:  97%|█████████▋| 1548/1600 [1:06:28<02:13,  2.58s/it, loss=0.26, v_num=zb3a]




Validating:  48%|████▊     | 48/100 [00:31<00:32,  1.59it/s][A




Epoch 3:  97%|█████████▋| 1550/1600 [1:06:29<02:08,  2.57s/it, loss=0.26, v_num=zb3a]
Epoch 4:  72%|███████▏  | 1156/1600 [50:35<19:25,  2.63s/it, loss=0.191, v_num=zb3a] 

In [None]:
# Peak GPU memory allocated so far (bytes converted to GiB, labelled "Gb"),
# read after the training cell above.
print(torch.cuda.max_memory_allocated()/1024**3, "Gb")

FP16

In [None]:
# Same peak-memory readout as the cell above (bytes converted to GiB, labelled "Gb").
# presumably kept to record the figure for the FP16 run named in the preceding note — TODO confirm
print(torch.cuda.max_memory_allocated()/1024**3, "Gb")

## Load Model

In [None]:
# Checkpoint written by an earlier training run.
# NOTE(review): absolute, user-specific path — adjust before running elsewhere.
checkpoint_path = "/global/cfs/cdirs/m3443/usr/ryanliu/gnn_example/ITk_0.5GeV_GNN/1ummswz1/checkpoints/epoch=0-step=4.ckpt"

# Raw checkpoint dictionary, kept only for manual inspection. It is mapped to the
# CPU so that it loads on machines without a GPU and does not hold a second copy
# of the weights in GPU memory.
checkpoint = torch.load(checkpoint_path, map_location="cpu")

# Rebuild the model from the checkpoint and move it to the selected device.
# NOTE(review): this restores a VanillaResAGNN, whereas the training section above
# builds a CheckpointedResAGNN — confirm the checkpoint matches this class.
model = VanillaResAGNN.load_from_checkpoint(checkpoint_path).to(device)

In [None]:
# Put the model in evaluation mode; the trailing semicolon suppresses the
# cell's printed return value.
model.eval();

In [None]:
# Override the dataset split stored in the loaded hyperparameters, then re-run
# the "fit"-stage setup so that it takes effect.
# presumably the three numbers are train/validation/test event counts — TODO confirm
model.hparams["datatype_split"] = [200, 1, 10]
model.setup(stage="fit")

In [None]:
# Make sure the model is on the selected device (it was already moved there when
# it was loaded from the checkpoint, so this is a no-op in a top-to-bottom run).
model = model.to(device)

In [None]:
# Directory that receives the per-event result files written by the next cell.
# NOTE(review): absolute, user-specific path — adjust before running elsewhere.
output_dir = "/global/cfs/cdirs/m3443/usr/ryanliu/gnn_example/ITk_0.5GeV_GNN/"

In [None]:
# Evaluate the loaded model on every batch of the training dataloader and write
# one record array (senders, receivers, score, truth) per event into `output_dir`.
# Gradients are disabled because this is inference only.
with torch.no_grad():
    for batch in model.train_dataloader():

        print(batch)

        output = model.shared_evaluation(batch.to(device), 0, log=False)
        print(output)

        # Last path component of the first event file attached to this batch.
        event_name = os.path.basename(batch.event_file[0])
        print(event_name)

        # Edge list in both directions: the original edges followed by the same
        # edges with their two rows swapped.
        bidirectional_edges = torch.cat([batch.edge_index, batch.edge_index.flip(0)], dim=-1)

        # Stack into four rows, one per output field.
        # assumes `score` and `truth` have one entry per bidirectional edge — TODO confirm
        gnn_results = np.vstack([
            bidirectional_edges.cpu().numpy(),
            output["score"].cpu().numpy(),
            output["truth"].cpu().numpy(),
        ])
        gnn_recarray = np.rec.fromarrays(gnn_results, names=["senders", "receivers", "score", "truth"])

        # The file is named after the last four characters of the event name.
        # NOTE(review): np.save writes the .npy format even though the file
        # extension is ".npz" — confirm what downstream readers expect.
        out_path = os.path.join(output_dir, event_name[-4:] + ".npz")
        with open(out_path, 'wb') as f:
            np.save(f, gnn_recarray)