In [1]:
import math
import torch
import numpy as np
import matplotlib.pyplot as plt

from pathlib import Path
from typing import Tuple

try:
    from bindsnet.network import Network
except:
    from bindsnet.network import Network

from bindsnet.learning import PostPre
from bindsnet.models import DiehlAndCook2015
from bindsnet.network.monitors import Monitor
from bindsnet.network.topology import Connection
from bindsnet.network.nodes import LIFNodes, Input
from bindsnet.analysis.plotting import plot_spikes
from bindsnet.evaluation import all_activity, assign_labels, proportion_weighting

In [2]:
from typing import Dict, List


class Topographies2SNN(Network):
    """Fully-connected spiking network fed by one topography input layer per EEG band.

    Every band gets its own ``Input`` layer named ``'<band>-topography'``. All of them
    project onto the first LIF layer ``'fc1_layer'``. The remaining entries of
    ``fc_layers`` are chained as ``fc2_layer``, ``fc3_layer``, ... with all-to-all
    ``Connection``s using the ``PostPre`` update rule.

    Args:
        input_shape: shape of one topography map, e.g. ``(11, 11)`` in this notebook.
        fc_layers: neuron counts of the fully-connected LIF layers, first to last.
        band_names: bands to create input layers for. Defaults to the five bands that
            were previously hard-coded. Pass e.g. ``list(bands)`` to create input layers
            only for the bands that are actually fed to ``run``.
        nu: learning rate(s) forwarded to every ``Connection``. ``None`` (the default)
            does not pass ``nu`` at all, exactly as before.
            NOTE(review): bindsnet learning rules appear to disable learning when no
            ``nu`` is given. The identical spike counts across training iterations in
            this notebook's output are consistent with that. Confirm, and pass e.g.
            ``nu=(1e-4, 1e-2)`` to actually train.

    Raises:
        ValueError: if ``fc_layers`` is empty.
    """

    DEFAULT_BANDS = ('Delta', 'Theta', 'Alpha', 'Beta', 'Gamma')

    def __init__(self,
                 input_shape: Tuple,
                 fc_layers: List[int],
                 band_names: Tuple[str, ...] = DEFAULT_BANDS,
                 nu=None,
                 ):
        super().__init__()
        if not fc_layers:
            raise ValueError("fc_layers must contain at least one layer size")

        fc1_layer = self._add_layer(fc_layers[0], name='fc1_layer')
        # Each band's topography is a separate input population, all converging on fc1.
        for band_name in band_names:
            layer_name = f'{band_name}-topography'
            input_layer = self._add_input_layer(name=layer_name, shape=input_shape)
            self._add_connection(input_layer, fc1_layer,
                                 layer_name, 'fc1_layer', nu=nu)

        # Chain the remaining fully-connected layers: fc1 -> fc2 -> ... -> fcN.
        prev_layer, prev_name = fc1_layer, 'fc1_layer'
        for index, n in enumerate(fc_layers[1:], start=2):
            name = f'fc{index}_layer'
            layer = self._add_layer(n, name=name)
            self._add_connection(prev_layer, layer, prev_name, name, nu=nu)
            prev_layer, prev_name = layer, name

    def _add_input_layer(self, name, shape):
        """Add a traced ``Input`` layer of ``prod(shape)`` neurons under ``name`` and return it."""
        input_layer = Input(
            n=math.prod(shape),
            shape=shape,
            traces=True,
            tc_trace=20.0
        )
        self.add_layer(input_layer, name=name)
        return input_layer

    def _add_layer(self, n, name):
        """Add a traced LIF layer of ``n`` neurons (rest 0.0, threshold 10) under ``name`` and return it."""
        layer = LIFNodes(
            n=n,
            traces=True,
            rest=0.0,
            thresh=10,
        )
        self.add_layer(layer, name=name)
        return layer

    def _add_connection(self, source, target,
                        source_name, target_name,
                        nu=None,
    ):
        """Connect ``source`` to ``target`` all-to-all with ``PostPre`` and weights drawn from U[0, 0.5).

        ``nu`` is forwarded to ``Connection`` only when given, so ``nu=None`` builds the
        exact same connection as before this parameter existed.
        """
        w = 0.5 * torch.rand(source.n, target.n)
        # norm=78.4 was also tried here at some point (left disabled).
        extra_kwargs = {} if nu is None else {'nu': nu}
        conn = Connection(
            source=source,
            target=target,
            w=w,
            update_rule=PostPre,
            **extra_kwargs,
        )
        self.add_connection(conn,
                            source=source_name,
                            target=target_name)


# Frequency bands to load, as (low, high) edges — presumably in Hz.
# Only Delta is active; uncomment others to load them as well.
bands = dict(
    Delta=(0.5, 4),
    # Theta=(4, 8),
    # Alpha=(8, 14),
    # Beta=(14, 32),
    # Gamma=(32, 62),
)

def merge_n_shuffle_tensors(*dicts: Dict, label_tensor):
    """Concatenate per-class input dicts along the batch axis and shuffle them jointly with the labels.

    Args:
        *dicts: one dict per class, all with the same keys. Each value is a tensor whose
            axis 1 is the batch axis (in this notebook: (time, batch, H, W)).
        label_tensor: 1-D tensor with one label per concatenated sample, in the same
            order as the concatenation of ``dicts``.

    Returns:
        ``(merged_dict, shuffled_labels)``. Every merged tensor and the labels are permuted
        with the same random permutation, so sample ``i`` of each tensor still matches
        label ``i``.

    Raises:
        ValueError: if no dict (or an empty first dict) is given, or if the number of
            labels differs from the total number of concatenated samples.
    """
    if not dicts or not dicts[0]:
        raise ValueError("merge_n_shuffle_tensors needs at least one non-empty dict")

    # Concatenate along the batch axis (dim 1) for every key of the first dict.
    merged = {key: torch.cat([d[key] for d in dicts], dim=1) for key in dicts[0]}

    # Size the permutation from the real merged batch instead of assuming 3 equal parts.
    total = next(iter(merged.values())).shape[1]
    if label_tensor.shape[0] != total:
        raise ValueError(
            f"label_tensor has {label_tensor.shape[0]} entries but the merged batch "
            f"has {total} samples"
        )

    perm_indices = torch.randperm(total)
    # Index only the batch axis so tensors of any rank >= 2 are supported.
    shuffled = {key: value[:, perm_indices] for key, value in merged.items()}
    return shuffled, label_tensor[perm_indices]

## Our Architecture

### Build the network and load the data

In [250]:
output_neurons = 64

fc_layers = [121, output_neurons]
snn = Topographies2SNN(input_shape=(11, 11),
                       fc_layers=fc_layers,
                       )
# Name of the last fully-connected layer, i.e. the network's output ('fc2_layer' here).
# The misspelled variable name is kept because the training cell reads it.
ouput_layer_name = f'fc{len(fc_layers)}_layer'
output_path_dir = Path('../datasets/EEG_data_for_Mental_Attention_State_Detection/preprocessed_resonators')

trial = '3'
band_name = 'Delta'
minute = 3
step = 0
# batch_size is split evenly over the three classes, so it must be a multiple of 3.
# Ideally it should also divide 645 (presumably the number of available samples — confirm).
# batch_size = 129
batch_size = 30
update_step = 1
update_interval = batch_size * update_step
sim_time = 5000
# sim_time = 153600 // 4

class_names = ['focus', 'unfocus', 'drowsed']
n_classes = len(class_names)
# An uneven split would make the data slices, the labels and spike_record disagree in size.
if batch_size % n_classes != 0:
    raise ValueError(f"batch_size ({batch_size}) must be divisible by {n_classes}")
samples_per_class = batch_size // n_classes

# One dict per class: {'<band>-topography': tensor[:sim_time, this step's samples, :, :]}.
# Class i is read from file f'{minute + 10*i}.pt' of the trial/band directory.
labeled_inputs = {
    class_name: {
        f'{band}-topography': torch.load(output_path_dir / trial / band / f'{minute + 10 * i}.pt')[
            :sim_time, step * samples_per_class:(step + 1) * samples_per_class, :, :
        ]
        for band in bands.keys()
    }
    for i, class_name in enumerate(class_names)
}
# Labels in concatenation order: samples_per_class zeros, then ones, then twos.
origin_label_tensor = torch.arange(n_classes).repeat_interleave(samples_per_class)

# Output-layer spikes of the current update window, laid out (sample, time, neuron).
spike_record = torch.zeros((update_interval, sim_time, output_neurons))
assignments = -torch.ones(output_neurons)
proportions = torch.zeros((output_neurons, n_classes))
rates = torch.zeros((output_neurons, n_classes))

# Sequence of accuracy estimates.
accuracy = {"all": [], "proportion": []}
labels = []

# Spike monitors for every non-input layer (the inputs are the '<band>-topography' layers).
# Iterating the layer dict directly (not a set) keeps the monitor order stable across runs.
spikes = {}
for layer in snn.layers:
    if layer.endswith('topography'):
        continue
    spikes[layer] = Monitor(
        snn.layers[layer], state_vars=["s"], time=sim_time,
    )
    snn.add_monitor(spikes[layer], name="%s_spikes" % layer)

### Training Process

In [223]:
# Train/evaluate loop for Topographies2SNN. Every `update_step` batches, the spikes
# recorded over the last `update_interval` samples are used to (1) score the current
# neuron->class assignments and (2) recompute those assignments.
# NOTE(review): `labeled_inputs` is built once before this loop, so every iteration feeds
# the same samples (only reshuffled). The printed accuracy is therefore measured on the
# samples that produced the assignments. Consider loading a new slice per step.
# NOTE(review): the monitored spike counts in the recorded output are identical in every
# iteration, which suggests the weights are not changing — check the connections' `nu`.
labels = []
for step in range(0, update_step*10 + 1):
    # Evaluation happens at the start of an iteration, on the previous window's spikes.
    if step % update_step == 0 and step > 0:
        label_tensor = torch.tensor(labels)

        # Get network predictions.
        all_activity_pred = all_activity(
            spikes=spike_record, assignments=assignments, n_labels=n_classes
        )
        proportion_pred = proportion_weighting(
            spikes=spike_record,
            assignments=assignments,
            proportions=proportions,
            n_labels=n_classes,
        )

        # Compute network accuracy according to available classification strategies.
        accuracy["all"].append(
            100
            * torch.sum(label_tensor.long() == all_activity_pred).item()
            / len(label_tensor)
        )
        accuracy["proportion"].append(
            100
            * torch.sum(label_tensor.long() == proportion_pred).item()
            / len(label_tensor)
        )

        print(
            "\nAll activity accuracy: %.2f (last), %.2f (average), %.2f (best)"
            % (
                accuracy["all"][-1],
                np.mean(accuracy["all"]),
                np.max(accuracy["all"]),
            )
        )
        print(
            "Proportion weighting accuracy: %.2f (last), %.2f (average), %.2f"
            " (best)\n"
            % (
                accuracy["proportion"][-1],
                np.mean(accuracy["proportion"]),
                np.max(accuracy["proportion"]),
            )
        )

        # Recompute the output neurons' class assignments from the window just scored
        # (this network has no separate excitatory layer; the comment below is inherited
        # from the Diehl & Cook example).
        # Assign labels to excitatory layer neurons.
        assignments, proportions, rates = assign_labels(
            spikes=spike_record,
            labels=label_tensor,
            n_labels=n_classes,
            rates=rates,
        )

        labels = []

    # Concatenate the three class batches and shuffle samples and labels together.
    inputs, label_tensor = merge_n_shuffle_tensors(labeled_inputs['focus'], labeled_inputs['unfocus'], labeled_inputs['drowsed'],
                                                   label_tensor=origin_label_tensor)
    labels.extend(label_tensor.tolist())

    # Run the network on the input.
    snn.run(inputs=inputs, time=sim_time)
    # Move the batch axis first so `s` matches spike_record's (sample, time, neuron) layout.
    s = spikes[ouput_layer_name].get("s").permute((1, 0, 2))
    spike_record[
                (step * batch_size) % update_interval :
                (step * batch_size % update_interval) + s.size(0)
            ] = s
    # Only consumed by the commented-out plot below.
    spikes_output = {
        monitor_name: monitor.get('s')
        for monitor_name, monitor in spikes.items()
    }
    # Total spike count per monitored layer for this batch.
    print({monitor_name: monitor.get('s').sum() for monitor_name, monitor in spikes.items()})
    # plot_spikes(spikes_output)
    # plt.show()
    snn.reset_state_variables()  # Reset state variables.

{'fc2_layer': tensor(11153), 'fc1_layer': tensor(31069)}

All activity accuracy: 33.33 (last), 33.33 (average), 33.33 (best)
Proportion weighting accuracy: 33.33 (last), 33.33 (average), 33.33 (best)

{'fc2_layer': tensor(11153), 'fc1_layer': tensor(31069)}

All activity accuracy: 46.67 (last), 40.00 (average), 46.67 (best)
Proportion weighting accuracy: 50.00 (last), 41.67 (average), 50.00 (best)

{'fc2_layer': tensor(11153), 'fc1_layer': tensor(31069)}

All activity accuracy: 46.67 (last), 42.22 (average), 46.67 (best)
Proportion weighting accuracy: 50.00 (last), 44.44 (average), 50.00 (best)

{'fc2_layer': tensor(11153), 'fc1_layer': tensor(31069)}

All activity accuracy: 46.67 (last), 43.33 (average), 46.67 (best)
Proportion weighting accuracy: 50.00 (last), 45.83 (average), 50.00 (best)

{'fc2_layer': tensor(11153), 'fc1_layer': tensor(31069)}

All activity accuracy: 46.67 (last), 44.00 (average), 46.67 (best)
Proportion weighting accuracy: 50.00 (last), 46.67 (average), 50.00 (be

## Using the Diehl & Cook (2015) model

In [3]:
output_path_dir = Path('../datasets/EEG_data_for_Mental_Attention_State_Detection/preprocessed_resonators')

trial = '3'
band_name = 'Delta'
minute = 3
step = 0
# batch_size is split evenly over the three classes, so it must be a multiple of 3.
# Ideally it should also divide 645 (presumably the number of available samples — confirm).
# batch_size = 129
batch_size = 15
update_step = 1
update_interval = batch_size * update_step
sim_time = 2500
# sim_time = 153600 // 4

class_names = ['focus', 'unfocus', 'drowsed']
n_classes = len(class_names)
# An uneven split would make the data slices, the labels and spike_record disagree in size.
if batch_size % n_classes != 0:
    raise ValueError(f"batch_size ({batch_size}) must be divisible by {n_classes}")
samples_per_class = batch_size // n_classes

# One dict per class: {'<band>-topography': tensor[:sim_time, this step's samples, :, :]}.
# Class i is read from file f'{minute + 10*i}.pt' of the trial/band directory.
labeled_inputs = {
    class_name: {
        f'{band}-topography': torch.load(output_path_dir / trial / band / f'{minute + 10 * i}.pt')[
            :sim_time, step * samples_per_class:(step + 1) * samples_per_class, :, :
        ]
        for band in bands.keys()
    }
    for i, class_name in enumerate(class_names)
}
# Labels in concatenation order: samples_per_class zeros, then ones, then twos.
origin_label_tensor = torch.arange(n_classes).repeat_interleave(samples_per_class)

# Sequence of accuracy estimates.
accuracy = {"all": [], "proportion": []}
labels = []

In [4]:
from bindsnet.utils import get_square_assignments, get_square_weights
from bindsnet.analysis.plotting import (
    plot_assignments,
    plot_input,
    plot_performance,
    plot_spikes,
    plot_voltages,
    plot_weights,
)

n_neurons = 100
# Side length of the square topography maps fed to the network (11 x 11 here).
input_side = 11
snn = DiehlAndCook2015(
    n_inpt=input_side * input_side,
    # n_inpt=5 * input_side * input_side,
    n_neurons=n_neurons,
    exc=22.5,
    inh=22.5,
    dt=1,
    # NOTE(review): 78.4 looks like the weight-normalisation value used for 784-pixel MNIST
    # inputs. With 121 inputs, the same per-input budget would be ~12.1 — confirm the intent.
    norm=78.4,
    nu=(1e-5, 1e-1),
    theta_plus=0.05,
    inpt_shape=(1, input_side, input_side),
    # inpt_shape=(5, input_side, input_side),
)

# Set up spike monitors on every layer. Iterating the layer dict directly (not a set)
# keeps the monitor order — and the subplot order of plot_spikes — stable across runs.
spikes = {}
for layer in snn.layers:
    spikes[layer] = Monitor(snn.layers[layer], state_vars=["s"], time=sim_time)
    snn.add_monitor(spikes[layer], name="%s_spikes" % layer)

output_layer_name = "Ae"
# Output-layer spikes of the current update window, laid out (sample, time, neuron).
spike_record = torch.zeros((update_interval, sim_time, n_neurons))
# Every neuron starts with assignment -1 (no class yet).
assignments = -torch.ones(n_neurons)
proportions = torch.zeros((n_neurons, n_classes))
rates = torch.zeros((n_neurons, n_classes))

# Figure/axes handles, passed back into the plotting helpers on each training step.
inpt_ims, inpt_axes = None, None
spike_ims, spike_axes = None, None
weights_im = None
assigns_im = None
perf_ax = None
voltage_axes, voltage_ims = None, None
# Neurons are laid out on an n_sqrt x n_sqrt grid for the weight/assignment plots.
n_sqrt = int(np.ceil(np.sqrt(n_neurons)))

# Voltage recording for excitatory and inhibitory layers.
exc_voltage_monitor = Monitor(snn.layers["Ae"], ["v"], time=sim_time)
inh_voltage_monitor = Monitor(snn.layers["Ai"], ["v"], time=sim_time)
snn.add_monitor(exc_voltage_monitor, name="exc_voltage")
snn.add_monitor(inh_voltage_monitor, name="inh_voltage")

In [37]:
128*.9  # NOTE(review): leftover scratch arithmetic; its purpose isn't evident from the notebook — document or delete.

115.2

In [8]:
# NOTE(review): the training loop below redraws its figures every step via plt.pause. With the
# static inline backend each redraw becomes a new image (see the repeated "<Figure ...>" output).
# An interactive backend such as `%matplotlib widget` may be what was intended — confirm.
%matplotlib inline

In [11]:
from tqdm import tqdm

# Train/evaluate loop for the Diehl & Cook network. Every `update_step` batches, the spikes
# recorded over the last `update_interval` samples are used to score the current
# neuron->class assignments and then to recompute them.
labels = []
n_steps = update_step * 1000 + 1
# Redraw the diagnostic figures every `plot_interval` steps (1 = every step, as before).
# Plotting dominates the per-step cost, so raise this to train faster.
plot_interval = 1

# The context manager closes the bar even when the loop is interrupted (e.g. KeyboardInterrupt).
# Previously, a dangling bar was left behind.
with tqdm(total=n_steps) as pbar:
    for step in range(n_steps):
        pbar_postfix = {}
        # Evaluation happens at the start of an iteration, on the previous window's spikes.
        if step % update_step == 0 and step > 0:
            label_tensor = torch.tensor(labels)

            # Get network predictions.
            all_activity_pred = all_activity(
                spikes=spike_record, assignments=assignments, n_labels=n_classes
            )
            proportion_pred = proportion_weighting(
                spikes=spike_record,
                assignments=assignments,
                proportions=proportions,
                n_labels=n_classes,
            )

            # Compute network accuracy according to available classification strategies.
            accuracy["all"].append(
                100
                * torch.sum(label_tensor.long() == all_activity_pred).item()
                / len(label_tensor)
            )
            accuracy["proportion"].append(
                100
                * torch.sum(label_tensor.long() == proportion_pred).item()
                / len(label_tensor)
            )
            pbar_postfix['acc_all'] = accuracy["all"][-1]
            pbar_postfix['mean_all'] = np.mean(accuracy["all"])
            pbar_postfix['best_all'] = np.max(accuracy["all"])

            pbar_postfix['acc_prop'] = accuracy["proportion"][-1]
            pbar_postfix['mean_prop'] = np.mean(accuracy["proportion"])
            pbar_postfix['best_prop'] = np.max(accuracy["proportion"])

            # Assign labels to excitatory layer neurons.
            assignments, proportions, rates = assign_labels(
                spikes=spike_record,
                labels=label_tensor,
                n_labels=n_classes,
                rates=rates,
            )

            labels = []

        # Concatenate the three class batches and shuffle samples and labels together.
        inputs, label_tensor = merge_n_shuffle_tensors(
            labeled_inputs['focus'], labeled_inputs['unfocus'], labeled_inputs['drowsed'],
            label_tensor=origin_label_tensor,
        )
        # Stack the per-band (time, batch, 11, 11) tensors along a new axis 2, giving
        # (time, batch, n_bands, 11, 11) for the single 'X' input layer.
        inputs = {'X': torch.stack(list(inputs.values()), dim=2)}
        labels.extend(label_tensor.tolist())

        # Run the network on the input.
        snn.run(inputs=inputs, time=sim_time)
        # Move the batch axis first so `s` matches spike_record's (sample, time, neuron) layout.
        s = spikes[output_layer_name].get("s").permute((1, 0, 2))
        record_start = (step * batch_size) % update_interval
        spike_record[record_start:record_start + s.size(0)] = s

        # Get voltage recording.
        exc_voltages = exc_voltage_monitor.get("v")
        inh_voltages = inh_voltage_monitor.get("v")

        for monitor_name, monitor in spikes.items():
            # .item() so the progress bar shows a plain number rather than "tensor(...)".
            pbar_postfix[f'{monitor_name}_s'] = monitor.get('s').sum().item()

        if step % plot_interval == 0:
            # Input of the first sample in the batch, summed over time.
            image = inputs["X"][:, 0].sum(0).view(11, 11)
            peak = image.max()
            # Scale to 0..255, but leave an all-zero window as is (0/0 would produce NaNs).
            if peak > 0:
                image = image / peak * 255
            inpt = inputs["X"][:, 0].view(sim_time, 11 * 11).sum(0).view(11, 11)
            sample_label = label_tensor[0]
            input_exc_weights = snn.connections[("X", "Ae")].w
            square_weights = get_square_weights(
                input_exc_weights.view(11 * 11, n_neurons), n_sqrt, 11
            )
            square_assignments = get_square_assignments(assignments, n_sqrt)
            spikes_ = {
                layer: spikes[layer].get("s")[:, 0].contiguous() for layer in spikes
            }
            voltages = {"Ae": exc_voltages, "Ai": inh_voltages}
            inpt_axes, inpt_ims = plot_input(
                image, inpt, label=sample_label, axes=inpt_axes, ims=inpt_ims
            )
            spike_ims, spike_axes = plot_spikes(spikes_, ims=spike_ims, axes=spike_axes)
            weights_im = plot_weights(square_weights, im=weights_im)
            assigns_im = plot_assignments(square_assignments, im=assigns_im)
            perf_ax = plot_performance(
                accuracy, x_scale=update_step * batch_size, ax=perf_ax
            )
            voltage_ims, voltage_axes = plot_voltages(
                voltages, ims=voltage_ims, axes=voltage_axes, plot_type="line"
            )
            plt.pause(1e-8)

        snn.reset_state_variables()  # Reset state variables.
        # Previously commented out, leaving the bar frozen at 0/N.
        pbar.set_postfix(pbar_postfix)
        pbar.update(1)



  0%|          | 0/1001 [00:19<?, ?it/s][A


<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

KeyboardInterrupt: 

In [10]:
# Display any figures still pending from the training loop above.
plt.show()