In [23]:
import h5py
import numpy as np
import os
import torch
from torch.utils.data import DataLoader

import lava.lib.dl.slayer as slayer
from lava.lib.dl import netx

from cytometrybin import BinCytometryDataset, BinCytometryNetwork

## Run the network to test it

In [3]:
# Paths and settings for the trained model being evaluated.
data_folder = '../data/bin_1ms_comp_ds'
trained_folder = 'logs/trained_fold4_nodelay_mar1'
delay = False  # presumably must match how the checkpoint was trained (folder name says "nodelay")
checkpoint_idx = 49
checkpoint_name = f'net_{checkpoint_idx}.pt'
net_filename = os.path.join(trained_folder, 'netx', f'net_{checkpoint_idx}.net')
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [6]:
# Rebuild the SLAYER network and load the trained weights from the checkpoint file.
net_slayer = BinCytometryNetwork(delay=delay).to(device)
checkpoint = torch.load(
    os.path.join(trained_folder, 'checkpoints', checkpoint_name),
    # Map tensors onto the selected device so a GPU-saved checkpoint also loads on CPU.
    map_location=device,
)
net_slayer.load_state_dict(checkpoint)

<All keys matched successfully>

Make sure the loaded SLAYER network still performs well on a batch of data.

In [None]:
ds = BinCytometryDataset(data_folder=data_folder)
# A seeded generator makes the shuffle order reproducible, so the batch evaluated
# below (and the accuracy reported for it) is the same on every run.
dl = DataLoader(
    ds,
    batch_size=128,
    shuffle=True,
    num_workers=0,
    generator=torch.Generator().manual_seed(0),
)

no test_fi given...


In [7]:
# Run the SLAYER network on a single batch from the loader.
inp, lab = next(iter(dl))
# lab.sum() / batch size is the fraction of label-1 samples (assumes 0/1 labels).
print('inp.shape', inp.shape, 'ratio of classes:', lab.sum()/lab.shape[0])
net_slayer.eval()
with torch.no_grad():
    spikes, counts = net_slayer(inp.to(device))

inp.shape torch.Size([128, 1536, 1000]) ratio of classes: tensor(0.4922)


In [9]:
spks = spikes.cpu().detach()
cnts = counts.cpu().detach()
print(spks.shape, cnts.shape, lab.shape)
# Output spike rate per neuron: spike count over axis 2 divided by that axis's length
# (1000 in the printed shape), read from the tensor rather than hard-coded.
num_steps = spks.shape[2]
print('output spike rates:')
print((spks.sum(axis=2) / num_steps)[:5].T)
# Rate-based prediction, presumably the output neuron with the highest spike rate — see slayer docs.
preds = slayer.classifier.Rate.predict(spks)
print('accuracy:', (preds == lab).sum().item() / lab.shape[0])

torch.Size([128, 2, 1000]) torch.Size([3]) torch.Size([128])
output spike rates:
tensor([[0.0200, 0.0410, 0.0510, 0.0310, 0.0480],
        [0.1170, 0.1570, 0.1690, 0.1530, 0.1510]])
accuracy: 0.984375


Save the network to an HDF5 file (netx format) via `export_hdf5`

In [13]:
# The SLAYER network was checked on a batch above; persist it as an HDF5
# network file so it can be reloaded with netx below.
export_path = 'logs/working_network.net'
net_slayer.export_hdf5(export_path)

# Bug analysis 

What follows is a bug analysis; I pushed a PR to lava-dl to fix this.

This is fixed in the Loihi cloud already, see `loihi_inference.ipynb` in this repository.

Inspect the network weights directly from the HDF5 file

In [24]:
# Read the first layer's weights straight from the exported HDF5 file and compare
# them with the in-memory SLAYER weights of the same layer.
# The context manager closes the file handle once the weights have been read.
with h5py.File('logs/working_network.net', 'r') as netfile:
    print('top left 3x3 of netx network weights:')
    print(netfile['layer']['0']['weight'][:3, :3])
print('top left 3x3 of slayer network weights:')
print(net_slayer.blocks[0].synapse.weight.squeeze().detach().cpu().numpy()[:3, :3])

top left 3x3 of netx network weights:
[[ -0.  10. -24.]
 [  0.  -0.  -0.]
 [-32.  40.  38.]]
top left 3x3 of netx network weights:
[[-1.2434849e-02  1.5916014e-01 -3.6128327e-01]
 [ 1.8655114e-03 -5.3635093e-05 -3.9743161e-04]
 [-5.1087338e-01  6.0952455e-01  5.9264773e-01]]


Load the network with netx (`netx.hdf5.Network`), then compare its weights with the HDF5 file

In [25]:
# Build the Lava process network from the exported HDF5 description and show its layer table.
netx_path = 'logs/working_network.net'
net_netx = netx.hdf5.Network(netx_path)
print(net_netx)

|   Type   |  W  |  H  |  C  | ker | str | pad | dil | grp |delay|
|Dense     |    1|    1|  512|     |     |     |     |     |False|
|Dense     |    1|    1|  512|     |     |     |     |     |False|
|Dense     |    1|    1|    2|     |     |     |     |     |False|


In [29]:
# First-layer weights as parsed from the file vs. as held by the built Lava process.
config_weights = net_netx.net_config['layer'][0]['weight']
loaded_weights = net_netx.layers[0].synapse.weights.init
print('top left 5x5 of netx config network weights:')
print(config_weights[:5, :5])
print('top left 5x5 of netx loaded network weights:')
print(loaded_weights[:5, :5])

top left 5x5 of netx config network weights:
[[  0  10 -24   2   4]
 [  0   0   0   0   0]
 [-32  40  38  40 -26]
 [  0   0   0   0  -2]
 [  0   2   2   0   0]]
top left 5x5 of netx loaded network weights:
[[0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]]


In [31]:
# For every layer, print the weight sum held by the Lava process next to the sum
# in the parsed HDF5 config; a mismatch points at the netx loading bug studied here.
# Iterating over the layers (instead of a hard-coded count) covers any network depth.
for i, layer in enumerate(net_netx.layers):
    print(layer.synapse.weights.init.sum(), end=' ')
    print(net_netx.net_config['layer'][i]['weight'].sum())

0 -3713256
-2257818 -2257818
-960 -60


In [35]:
# One line per Lava layer: class, process name, output shape and total synaptic weight.
for layer in net_netx.layers:
    weight_sum = layer.synapse.weights.init.sum()
    print(type(layer).__name__, layer.name, layer.shape, '\tsum of weights:', weight_sum)

Dense Process_1 (512,) 	sum of weights: 0
Dense Process_4 (512,) 	sum of weights: -2257818
Dense Process_7 (2,) 	sum of weights: -960
