In [1]:
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [2]:
%matplotlib widget

In [3]:
ls ..

README.md     [0m[01;34mdatasets[0m/     [01;34mmodels[0m/     [01;34mscripts[0m/  [01;34mtrainers[0m/
[01;34m__pycache__[0m/  [01;34mdistributed[0m/  [01;34mnotebooks[0m/  temp      [01;34mutils[0m/
[01;34mconfigs[0m/      [01;34mlogs[0m/         prepare.py  train.py


In [4]:
!head ../temp

2 train batch 0 loss 0.0775 l1 4003.12 l2 30.4848 grad 2.241 idx 48891
4 train batch 0 loss 0.0732 l1 4003.12 l2 30.4848 grad 2.241 idx 14705
6 train batch 0 loss 0.0722 l1 4003.12 l2 30.4848 grad 2.241 idx 34559
3 train batch 0 loss 0.0744 l1 4003.12 l2 30.4848 grad 2.241 idx 38492
0 train batch 0 loss 0.0755 l1 4003.12 l2 30.4848 grad 2.241 idx 7233
1 train batch 0 loss 0.0706 l1 4003.12 l2 30.4848 grad 2.241 idx 14759
7 train batch 0 loss 0.0833 l1 4003.12 l2 30.4848 grad 2.241 idx 42289
5 train batch 0 loss 0.0706 l1 4003.12 l2 30.4848 grad 2.241 idx 9181
1 train batch 1 loss 0.0702 l1 4003.15 l2 30.4849 grad 0.946 idx 35995
3 train batch 1 loss 0.0767 l1 4003.15 l2 30.4849 grad 0.946 idx 38965


In [5]:
# Parse the whitespace-separated training log. Each record alternates label
# tokens with values ("<worker> train batch <n> loss <x> l1 <x> ... idx <n>"),
# so only the value positions are kept and given explicit names.
# `sep=r'\s+'` is the supported spelling of the deprecated `delim_whitespace=True`.
data = pd.read_csv('../temp', sep=r'\s+',
                   names=['worker', 'batch', 'loss', 'l1', 'l2', 'grad', 'idx'],
                   usecols=[0, 3, 5, 7, 9, 11, 13])

In [6]:
data.head()

Unnamed: 0,worker,batch,loss,l1,l2,grad,idx
0,2,0,0.0775,4003.12,30.4848,2.241,48891
1,4,0,0.0732,4003.12,30.4848,2.241,14705
2,6,0,0.0722,4003.12,30.4848,2.241,34559
3,3,0,0.0744,4003.12,30.4848,2.241,38492
4,0,0,0.0755,4003.12,30.4848,2.241,7233


In [7]:
# Keep only the rows logged by worker (rank) 0.
data0 = data.loc[data['worker'] == 0]

In [8]:
# Loss per batch: every logged row as a small point, rank 0 as a connected line.
fig, ax = plt.subplots(figsize=(8, 6))
ax.scatter(data['batch'], data['loss'], s=1)
ax.plot(data0['batch'], data0['loss'])
ax.set_xlabel('Batch')
ax.set_ylabel('Loss')
ax.set_title('Training loss');

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [9]:
# L2 weight norm per batch, as logged by rank 0.
fig, ax = plt.subplots(figsize=(9, 6))
ax.plot(data0['batch'], data0['l2'])
ax.set_xlabel('Batch')
ax.set_ylabel('L2 weight norm')
ax.set_title('Model weight norm');

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [10]:
# Gradient norm per batch, as logged by rank 0; the grid makes it easier to
# read off the batch number where the peaks begin.
fig, ax = plt.subplots(figsize=(9, 6))
ax.plot(data0['batch'], data0['grad'])
ax.set_xlabel('Batch')
ax.set_ylabel('Gradient norm')
ax.set_title('Gradient norm');
ax.grid()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Based on the gradient-norm plot, training appears to become unstable starting at batch 266, and the loss begins to misbehave one batch later, at batch 267. The exact onset is hard to pin down because the peaks build up somewhat gradually. Still, let's start by looking at the samples in batch 266.

In [11]:
# All rows logged for batch 266, where the gradient norm first looks abnormal.
batch266 = data.loc[data['batch'] == 266]
batch266

Unnamed: 0,worker,batch,loss,l1,l2,grad,idx
2128,1,266,0.0691,4006.31,30.508,3.168,30426
2129,6,266,0.0729,4006.31,30.508,3.168,38365
2130,0,266,0.0767,4006.31,30.508,3.168,22538
2131,5,266,0.0755,4006.31,30.508,3.168,18115
2132,3,266,0.0683,4006.31,30.508,3.168,26133
2133,2,266,0.0762,4006.31,30.508,3.168,31060
2134,7,266,0.0713,4006.31,30.508,3.168,26522
2135,4,266,0.0736,4006.31,30.508,3.168,17986


In [12]:
# Values of the logged `idx` field for the batch-266 rows, as a NumPy array
# (used below to index `filenames`).
idxs = batch266.idx.to_numpy()

In [13]:
idxs

array([30426, 38365, 22538, 18115, 26133, 31060, 26522, 17986])

In [14]:
# Directory whose 'event*' .npz files are listed and inspected below.
# NOTE(review): absolute, user-specific scratch path — adjust when running elsewhere.
input_dir = '/global/cscratch1/sd/sfarrell/heptrkx/data/hitgraphs_high_000'

In [15]:
# Full paths of the entries in `input_dir` whose names start with 'event'
# and do not end in '_ID.npz'.
# NOTE(review): os.listdir returns entries in arbitrary order; indexing this
# array with the logged `idx` values assumes the training run enumerated the
# directory in the same order — confirm against the dataset code.
event_names = [name for name in os.listdir(input_dir)
               if name.startswith('event') and not name.endswith('_ID.npz')]
filenames = np.array([os.path.join(input_dir, name) for name in event_names])

In [16]:
filenames[idxs]

array(['/global/cscratch1/sd/sfarrell/heptrkx/data/hitgraphs_high_000/event000003929_g001.npz',
       '/global/cscratch1/sd/sfarrell/heptrkx/data/hitgraphs_high_000/event000004617_g005.npz',
       '/global/cscratch1/sd/sfarrell/heptrkx/data/hitgraphs_high_000/event000003275_g002.npz',
       '/global/cscratch1/sd/sfarrell/heptrkx/data/hitgraphs_high_000/event000002068_g001.npz',
       '/global/cscratch1/sd/sfarrell/heptrkx/data/hitgraphs_high_000/event000001936_g003.npz',
       '/global/cscratch1/sd/sfarrell/heptrkx/data/hitgraphs_high_000/event000006108_g002.npz',
       '/global/cscratch1/sd/sfarrell/heptrkx/data/hitgraphs_high_000/event000002630_g000.npz',
       '/global/cscratch1/sd/sfarrell/heptrkx/data/hitgraphs_high_000/event000002162_g006.npz'],
      dtype='<U85')

In [17]:
def process_file(filename):
    """Summarize one saved graph file.

    Parameters
    ----------
    filename : str or path-like
        Path to an ``.npz`` archive containing at least the arrays ``X`` and ``y``.

    Returns
    -------
    dict
        ``n_nodes``: length of the first axis of ``X``;
        ``n_edges``: length of the first axis of ``y``;
        ``purity``: mean of ``y``.
    """
    with np.load(filename) as f:
        n_nodes = f['X'].shape[0]
        # Every lookup on the archive reads that member from disk again,
        # so fetch `y` once and reuse it for both statistics.
        y = f['y']
    return dict(n_nodes=n_nodes, n_edges=y.shape[0], purity=y.mean())

def process_files(filenames):
    """Build a per-file summary table.

    Parameters
    ----------
    filenames : iterable of str
        Paths accepted by ``process_file``.

    Returns
    -------
    pandas.DataFrame
        One row per input path, in input order, with the fields returned by
        ``process_file`` plus a ``file`` column holding the path.
    """
    # The paths are iterated here and reused for the `file` column, so
    # materialize them once: a one-shot iterator would be exhausted on the
    # second use, and an indexed Series would be aligned by label instead of
    # by position.
    filenames = list(filenames)
    records = [process_file(f) for f in filenames]
    return pd.DataFrame.from_records(records).assign(file=filenames)

def summarize_dataset(data):
    """Print headline statistics for a table of per-graph summaries.

    `data` must expose `n_nodes`, `n_edges` and `purity` columns. The row
    count, the min/max of the node and edge counts, and the mean purity
    (four decimals) are written to stdout; nothing is returned.
    """
    nodes, edges = data.n_nodes, data.n_edges
    rows = [
        ('Samples:', data.shape[0]),
        ('Min nodes:', nodes.min()),
        ('Max nodes:', nodes.max()),
        ('Min edges:', edges.min()),
        ('Max edges:', edges.max()),
    ]
    for label, value in rows:
        print(label, value)
    print('Mean purity: %.4f' % data.purity.mean())

In [18]:
graphs = process_files(filenames[idxs])

In [19]:
graphs

Unnamed: 0,n_nodes,n_edges,purity,file
0,4451,22408,0.064575,/global/cscratch1/sd/sfarrell/heptrkx/data/hit...
1,5031,28319,0.063173,/global/cscratch1/sd/sfarrell/heptrkx/data/hit...
2,4875,26364,0.063685,/global/cscratch1/sd/sfarrell/heptrkx/data/hit...
3,4015,18062,0.073691,/global/cscratch1/sd/sfarrell/heptrkx/data/hit...
4,4243,21336,0.065382,/global/cscratch1/sd/sfarrell/heptrkx/data/hit...
5,4504,23374,0.071062,/global/cscratch1/sd/sfarrell/heptrkx/data/hit...
6,3870,17700,0.074407,/global/cscratch1/sd/sfarrell/heptrkx/data/hit...
7,5159,29866,0.060972,/global/cscratch1/sd/sfarrell/heptrkx/data/hit...
