In [7]:
import numpy as np
from lava.proc.lif.process import LIFReset
from lava.proc.io.source import RingBuffer
from lava.proc.dense.process import Dense
from lava.proc.monitor.process import Monitor
from lava.magma.core.run_conditions import RunSteps
from lava.magma.core.run_configs import Loihi1SimCfg
from matplotlib import pyplot as plt
from tqdm import tqdm
import typing as ty

import gsc_dataset_loader

# Import Process level primitives
from lava.magma.core.process.process import AbstractProcess
from lava.magma.core.process.variable import Var
from lava.magma.core.process.ports.ports import InPort, OutPort

# Import parent classes for ProcessModels
from lava.magma.core.model.sub.model import AbstractSubProcessModel
from lava.magma.core.model.py.model import PyLoihiProcessModel

# Import ProcessModel ports, data-types
from lava.magma.core.model.py.ports import PyInPort, PyOutPort
from lava.magma.core.model.py.type import LavaPyType

# Import execution protocol and hardware resources
from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol
from lava.magma.core.resources import CPU

# Import decorators
from lava.magma.core.decorator import implements, requires

In [17]:
# load gsc dataset

train_x, train_y, validation_x, validation_y, test_x, test_y = gsc_dataset_loader.load_gsc("/its/home/ts468/data/rawSC/rawSC_80input/", 
                                                                        1, 
                                                                        80,
                                                                        1,
                                                                        True)

print(train_x.shape)

!! validation dataset loaded successfully
(2000, 80)


In [18]:
train_y

array([27])

In [9]:
params = {}
params["DT_MS"] = 1.0
params["TAU_MEM"] = 20.0
params["TAU_SYN"] = 5.0


In [10]:
# transform some parmeters
tau_mem_fac = 1.0-np.exp(-params["DT_MS"]/params["TAU_MEM"])
tau_syn_fac = 1.0-np.exp(-params["DT_MS"]/params["TAU_SYN"])

# load connections
w_i2h = np.load("0-Conn_Pop0_Pop1-g.npy")
w_i2h = w_i2h.reshape((80,512)).T
w_i2h *= tau_mem_fac
#w_i2h /= p["TAU_MEM"]
w_h2h = np.load("0-Conn_Pop1_Pop1-g.npy")
w_h2h = w_h2h.reshape((512,512)).T
w_h2h *= tau_mem_fac
#w_h2h /= p["TAU_MEM"]
w_h2o = np.load("0-Conn_Pop1_Pop2-g.npy")
w_h2o = w_h2o.reshape((512,35)).T
w_h2o *= tau_mem_fac
#w_h2o /= p["TAU_MEM"]

In [11]:
class SpikeInput(AbstractProcess):
    def __init__(self,
                 vth: int,
                 num_steps_per_image: ty.Optional[int] = 2000):
        super().__init__()
        shape = (80,)
        self.spikes_out = OutPort(shape=shape)  # Input spikes to the classifier
        self.label_out = OutPort(shape=(1,))  # Ground truth labels to OutputProc
        self.num_steps_per_image = Var(shape=(1,), init=num_steps_per_image)
        self.input_img = Var(shape=shape)
        self.ground_truth_label = Var(shape=(1,))
        self.v = Var(shape=shape, init=0)
        self.vth = Var(shape=(1,), init=vth)

@implements(proc=SpikeInput, protocol=LoihiProtocol)
@requires(CPU)
class PySpikeInputModel(PyLoihiProcessModel):
    spikes_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, bool, precision=1)
    label_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, np.int32,
                                      precision=32)
    num_steps_per_image: int = LavaPyType(int, int, precision=32)
    input_img: np.ndarray = LavaPyType(np.ndarray, int, precision=32)
    ground_truth_label: int = LavaPyType(int, int, precision=32)
    v: np.ndarray = LavaPyType(np.ndarray, int, precision=32)
    vth: int = LavaPyType(int, int, precision=32)
    
    def __init__(self, proc_params):
        super().__init__(proc_params=proc_params)
        self.curr_img_id = 0

    def post_guard(self):
        """Guard function for PostManagement phase.
        """
        if self.time_step % self.num_steps_per_image == 1:
            return True
        return False

    def run_post_mgmt(self):
        """Post-Management phase: executed only when guard function above 
        returns True.
        """
        self.ground_truth_label = train_y[self.curr_img_id]
        self.v = np.zeros(self.v.shape)
        self.label_out.send(np.array([self.ground_truth_label]))
        self.curr_img_id += 1

    def run_spk(self):
        """Spiking phase: executed unconditionally at every time-step
        """
        print(self.time_step)
        self.v[:] = self.v #+ self.input_img
        s_out = self.v > self.vth
        self.v[s_out] = 0  # reset voltage to 0 after a spike
        self.spikes_out.send(s_out)

In [12]:
input = SpikeInput(vth = 1.,
                   num_steps_per_image= 80)

hidden = LIFReset(shape=(512, ),                         # Number and topological layout of units in the process
                  vth=1.,                             # Membrane threshold
                  dv=tau_mem_fac,                              # Inverse membrane time-constant
                  du=tau_syn_fac,                              # Inverse synaptic time-constant
                  bias_mant=0.0,           # Bias added to the membrane voltage in every timestep
                  name="hidden",
                  reset_interval=1000)

output = LIFReset(shape=(35, ),                         # Number and topological layout of units in the process
                  vth=1e9,                             # Membrane threshold
                  dv=tau_mem_fac,                              # Inverse membrane time-constant
                  du=tau_syn_fac,                              # Inverse synaptic time-constant
                  bias_mant=0.0,           # Bias added to the membrane voltage in every timestep
                  name="output",
                  reset_interval=1000)

in_to_hid = Dense(weights= w_i2h,     # Initial value of the weights, chosen randomly
              name='in_to_hid')

hid_to_hid = Dense(weights=w_h2h,
                   name='hid_to_hid')

hid_to_out = Dense(weights=w_h2o,
                   name= 'hid_to_out')

input.spikes_out.connect(in_to_hid.s_in)
in_to_hid.a_out.connect(hidden.a_in)
hidden.s_out.connect(hid_to_hid.s_in)
hidden.s_out.connect(hid_to_out.s_in)
hid_to_hid.a_out.connect(hidden.a_in)
hid_to_out.a_out.connect(output.a_in)

# monitor outputs

monitor_output = Monitor()
num_steps = int(1000/params["DT_MS"])

monitor_output.probe(output.v, train_x.shape[1])

# run something
run_condition = RunSteps(num_steps=num_steps)
run_cfg = Loihi1SimCfg(select_tag="floating_pt")

n_sample = train_x.shape[-2]//num_steps
for i in tqdm(range(n_sample)):
    output.run(condition=run_condition, run_cfg=run_cfg)

output_v = monitor_output.get_data()
good = 0
for i in range(1):#n_sample):
    continue
    out_v = output_v["output"]["v"][i*num_steps:(i+1)*num_steps,:]
    sum_v = np.sum(out_v,axis=0)
    pred = np.argmax(sum_v)
    print(f"Pred: {pred}, True:{Y_test[i]}")
    if pred == Y_test[i]:
        good += 1

#print(f"test accuracy: {good/n_sample*100}")
output.stop()

# Got 87.1% from the pre-trained model that had 89% in GeNN


  0%|          | 0/2 [00:00<?, ?it/s]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
Encountered Fatal Exception: index 1 is out of bounds for axis 0 with size 1
Encountered Fatal Exception: index 80 is out of bounds for axis 0 with size 80Traceback: 

Traceback: 
Traceback (most recent call last):
  File "/its/home/ts468/PhD/ve/GeNN_4_9_0/lib/python3.8/site-packages/lava/magma/runtime/runtime.py", line 98, in target_fn
    actor.start(*args, **kwargs)
  File "/its/home/ts468/PhD/ve/GeNN_4_9_0/lib/python3.8/site-packages/lava/magma/core/model/py/model.py", line 93, in start
    self.run()
  File "/its/home/ts468/PhD/ve/GeNN_4_9_0/lib/python3.8/site-packages/lava/magma/core/model/py/model.py", line 232, in run
    raise inst
  File "/its/home/ts468/PhD/ve/GeNN_4_9_0/lib/python3.8/site-packages/lava/magma/core/model/py/model.py", line 218, i

  0%|          | 0/2 [00:00<?, ?it/s]


RuntimeError: 2 Exception(s) occurred. See output above for details.