In [160]:
import math
import torch
import numpy as np
import matplotlib.pyplot as plt

from pathlib import Path
from typing import Tuple

try:
    from bindsnet.network import Network
except:
    from bindsnet.network import Network

from bindsnet.learning import PostPre
from bindsnet.models import DiehlAndCook2015
from bindsnet.network.monitors import Monitor
from bindsnet.network.topology import Connection
from bindsnet.network.nodes import LIFNodes, Input
from bindsnet.analysis.plotting import plot_spikes
from bindsnet.evaluation import all_activity, assign_labels, proportion_weighting

In [194]:
from typing import Dict, List


class Topographies2SNN(Network):

    def __init__(self,
                 input_shape: Tuple,
                 fc_layers: List[int],
                 ):
        super().__init__()

        fc1_layer = self._add_layer(fc_layers[0], name='fc1_layer')
        for band_name in ['Delta','Theta','Alpha','Beta','Gamma']:
            layer_name = f'{band_name}-topography'
            input_layer = self._add_input_layer(name=layer_name, shape=input_shape)
            self._add_connection(input_layer, fc1_layer,
                                 layer_name, 'fc1_layer')

        prev_layer = fc1_layer
        for i, n in enumerate(fc_layers[1:]):
            layer = self._add_layer(n, name=f'fc{i+2}_layer')
            self._add_connection(prev_layer, layer,
                                 f'fc{i+1}_layer', f'fc{i+2}_layer')
            prev_layer = layer

    def _add_input_layer(self, name, shape):
        input_layer = Input(
            n=math.prod(shape),
            shape=shape,
            traces=True,
            tc_trace=20.0
        )
        self.add_layer(input_layer, name=name)
        return input_layer

    def _add_layer(self, n, name):
        layer = LIFNodes(
            n=n,
            traces=True,
            rest=0.0,
            thresh=10,
        )
        self.add_layer(layer, name=name)
        return layer

    def _add_connection(self, source, target,
                        source_name, target_name
    ):
        w = 0.5 * torch.rand(source.n, target.n)
        conn = Connection(
            source=source,
            target=target,
            w=w,
            update_rule=PostPre,
            norm=78.4,
            nu=(1e-4, 1e-2),
        )
        self.add_connection(conn,
                            source=source_name,
                            target=target_name)


bands = {
    'Delta': (.5, 4),
    'Theta': (4, 8),
    'Alpha': (8, 14),
    'Beta': (14, 32),
    'Gamma': (32, 62),
}

def merge_n_shuffle_tensors(*dicts: Dict, label_tensor):
    merged_dict = {}

    batch_size = list(dicts[0].values())[0].shape[1]
    # Generate random permutations for shuffling along axis 1
    perm_indices = torch.randperm(3 * batch_size)
    label_tensor = label_tensor[perm_indices]

    # Loop through the keys in one of the dictionaries (assuming they all have the same keys)
    for key in dicts[0].keys():
        # Concatenate the tensors along the second dimension (b)
        merged_value = torch.cat([d[key] for d in dicts], dim=1)

        # Assign the merged tensor to the corresponding key in the merged dictionary
        merged_dict[key] = merged_value[:, perm_indices, :, :]
    return merged_dict, label_tensor

In [None]:
network = DiehlAndCook2015(
    n_inpt=5 * 11 * 11,
    n_neurons=100,
    exc=22.5,
    inh=120,
    dt=1,
    norm=78.4,
    nu=(1e-4, 1e-2),
    theta_plus=0.05,
    inpt_shape=(5, 11, 11),
)

### Load the data

In [197]:
output_neurons = 64

fc_layers = [121, output_neurons]
snn = Topographies2SNN(input_shape=(11, 11),
                       fc_layers=fc_layers,
                       )
ouput_layer_name = f'fc{len(fc_layers)}_layer'
output_path_dir = Path('../datasets/EEG_data_for_Mental_Attention_State_Detection/preprocessed_resonators')

trial = '3'
band_name = 'Delta'
minute = 3
step = 0
# prefer to have batch size that is divided by 3 and its divide the number 645
# batch_size = 129
batch_size = 30
update_step = 1
update_interval = batch_size * update_step
sim_time = 1000
# sim_time = 153600 // 4
labeled_inputs = {
    key: {
        f'{band_name}-topography': torch.load(output_path_dir / trial / band_name / f'{minute + 10*i}.pt')[:sim_time,
                                   step*batch_size//3:(step+1)*batch_size//3, :, :]
        for band_name in bands.keys()
    }
    for i, key in enumerate(['focus', 'unfocus', 'drowsed'])
}
origin_label_tensor = torch.tensor([0] * (batch_size//3) +
                                   [1] * (batch_size//3) +
                                   [2] * (batch_size//3))

spike_record = torch.zeros((update_interval, sim_time, output_neurons))
n_classes = 3
assignments = -torch.ones(output_neurons)
proportions = torch.zeros((output_neurons, n_classes))
rates = torch.zeros((output_neurons, n_classes))

# Sequence of accuracy estimates.
accuracy = {"all": [], "proportion": []}
labels = []

# Voltage recording for excitatory and inhibitory layers.
spikes = {}
for layer in set(snn.layers):
    if layer.endswith('topography'):
        continue
    spikes[layer] = Monitor(
        snn.layers[layer], state_vars=["s"], time=sim_time,
    )
    snn.add_monitor(spikes[layer], name="%s_spikes" % layer)

### Training Process

In [199]:
labels = []
for step in range(0, update_step*10 + 1):
    if step % update_step == 0 and step > 0:
        label_tensor = torch.tensor(labels)

        # Get network predictions.
        all_activity_pred = all_activity(
            spikes=spike_record, assignments=assignments, n_labels=n_classes
        )
        proportion_pred = proportion_weighting(
            spikes=spike_record,
            assignments=assignments,
            proportions=proportions,
            n_labels=n_classes,
        )

        # Compute network accuracy according to available classification strategies.
        accuracy["all"].append(
            100
            * torch.sum(label_tensor.long() == all_activity_pred).item()
            / len(label_tensor)
        )
        accuracy["proportion"].append(
            100
            * torch.sum(label_tensor.long() == proportion_pred).item()
            / len(label_tensor)
        )

        print(
            "\nAll activity accuracy: %.2f (last), %.2f (average), %.2f (best)"
            % (
                accuracy["all"][-1],
                np.mean(accuracy["all"]),
                np.max(accuracy["all"]),
            )
        )
        print(
            "Proportion weighting accuracy: %.2f (last), %.2f (average), %.2f"
            " (best)\n"
            % (
                accuracy["proportion"][-1],
                np.mean(accuracy["proportion"]),
                np.max(accuracy["proportion"]),
            )
        )

        # Assign labels to excitatory layer neurons.
        assignments, proportions, rates = assign_labels(
            spikes=spike_record,
            labels=label_tensor,
            n_labels=n_classes,
            rates=rates,
        )

        labels = []

    inputs, label_tensor = merge_n_shuffle_tensors(labeled_inputs['focus'], labeled_inputs['unfocus'], labeled_inputs['drowsed'],
                                                   label_tensor=origin_label_tensor)
    labels.extend(label_tensor.tolist())

    # Run the network on the input.
    snn.run(inputs=inputs, time=sim_time)
    s = spikes[ouput_layer_name].get("s").permute((1, 0, 2))
    spike_record[
                (step * batch_size) % update_interval :
                (step * batch_size % update_interval) + s.size(0)
            ] = s
    spikes_output = {
        monitor_name: monitor.get('s')
        for monitor_name, monitor in spikes.items()
    }
    print({monitor_name: monitor.get('s').sum() for monitor_name, monitor in spikes.items()})
    # plot_spikes(spikes_output)
    # plt.show()
    snn.reset_state_variables()  # Reset state variables.

RuntimeError: output with shape [121, 121] doesn't match the broadcast shape [30, 121, 121]

In [202]:
list(inputs.values())[0].shape

torch.Size([1000, 30, 11, 11])