# Import Packages

In [103]:
import numpy as np
import matplotlib.pyplot as plt
import snntorch as snn
import snntorch.spikeplot as splt
import torch
import torch.nn as nn
import torch.nn.functional as F

from snntorch import spikegen, surrogate, utils
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from IPython.display import HTML

# Run on the GPU when one is visible to torch, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Default settings merged into TacNet's configuration. The Leaky layers read
# 'beta' and 'spike_grad'; 'num_steps' is carried along in the merged config,
# but TacNet.forward takes its step count as an explicit argument.
net_default = dict(
    spike_grad=surrogate.fast_sigmoid(slope=25),
    beta=0.5,
    num_steps=50,
)


class TouchDataset(Dataset):
    """Dataset of tactile sensor frames paired with orientation labels.

    The data file is loaded with ``np.load(..., allow_pickle=True)`` and must
    behave like a mapping. Each value must provide a ``'sensordata'`` and an
    ``'orientations'`` entry. The per-key arrays are concatenated into one
    dataset: sensor data row-wise (``np.vstack``) and orientations with
    ``np.hstack``.
    """

    def __init__(self, filename, transform=None) -> None:
        """Constructor for TouchDataset.

        Args:
            filename (str or path-like): Path of the touch data file.
            transform (callable, optional): Transform (e.g. a torchvision
                ``Compose``) applied to each sample. Defaults to None.

        Raises:
            ValueError: If the file contains no recordings.
        """
        # SECURITY: allow_pickle=True lets np.load unpickle the file, which can
        # execute arbitrary code. Only load trusted data files.
        data = np.load(filename, allow_pickle=True)
        keys = list(data.keys())
        if not keys:
            # np.vstack([]) would fail with an unhelpful message; say why.
            raise ValueError(f"No recordings found in {filename!r}")
        sensordata = [np.array(data[k]['sensordata']) for k in keys]
        orientation = [np.array(data[k]['orientations']) for k in keys]
        self.sensordata = np.vstack(sensordata)
        self.orientation = np.hstack(orientation)
        self.transform = transform

    def __len__(self):
        """Return the number of samples (taken from the orientation labels)."""
        return len(self.orientation)

    def __getitem__(self, index):
        """Return ``(sample, orientation)`` for ``index``.

        The sample is ``self.sensordata[index, :, :]``. When a transform was
        given, the sample is passed through it first.
        """
        sample = self.sensordata[index, :, :]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, self.orientation[index]


class MinMaxScale(object):
    """Linearly rescale an array so its values span ``[minVal, maxVal]``."""

    def __init__(self, minVal, maxVal):
        """Initialize the transform to rescale the values into [minVal, maxVal].

        Args:
            minVal (float): Target minimum value.
            maxVal (float): Target maximum value.
        """
        self.min, self.max = minVal, maxVal

    def __call__(self, x):
        """Rescale ``x`` so its minimum maps to ``minVal`` and its maximum to ``maxVal``.

        Args:
            x (array-like): Values to rescale.

        Returns:
            numpy.ndarray: Rescaled values with the same shape as ``x``. A
            constant input (no spread) maps every element to ``minVal``.
        """
        min_x, max_x = np.min(x), np.max(x)
        span = max_x - min_x
        if min_x >= max_x:
            # Constant input: there is no spread to stretch. Every offset
            # x - min_x is zero, so dividing by 1 maps all elements to minVal.
            # Unlike returning a bare scalar, this keeps x's shape, which array
            # transforms such as ToTensor require. It also keeps the dtype the
            # regular path produces.
            span = 1
        return (x - min_x) / span * (self.max - self.min) + self.min


class ToSpike(object):
    """Transform that encodes its input as a spike train.

    Only ``'rate'`` encoding currently produces spikes (via
    ``spikegen.rate``). ``'latency'`` and ``'delta'`` are accepted, but the
    input is returned unchanged.
    """

    # Default keyword configuration for each supported encoding.
    # Keyword arguments passed to __init__ override these.
    _DEFAULTS = {
        'rate': {'num_steps': 100, 'gain': 0.5},
        'latency': {},
        'delta': {},
    }

    def __init__(self, encoding='rate', **kwargs):
        """Initialize the spike generator.

        Args:
            encoding (str, optional): Spiking encoding, one of ``'rate'``,
                ``'latency'`` or ``'delta'``. Unknown values fall back to
                ``'rate'`` with a warning. Defaults to 'rate'.
            **kwargs: Encoding options that override the per-encoding
                defaults (for ``'rate'``: ``num_steps``, ``gain``).
        """
        self._encodings = ['rate', 'latency', 'delta']
        if encoding in self._encodings:
            self.encoding = encoding
        else:
            import warnings
            warnings.warn(
                f"Unknown encoding {encoding!r}; falling back to "
                f"{self._encodings[0]!r}.")
            self.encoding = self._encodings[0]
        # `|` builds a new dict, so the class-level defaults are never mutated.
        self.configs = self._DEFAULTS[self.encoding] | kwargs

    def __call__(self, x):
        """Encode ``x`` according to ``self.encoding``.

        Returns:
            The rate-encoded spike train for ``'rate'``; otherwise ``x``
            unchanged (the other encodings are not implemented yet).
        """
        if self.encoding == 'rate':
            return spikegen.rate(
                x, num_steps=self.configs['num_steps'], gain=self.configs['gain'])
        return x


class TacNet(nn.Module):
    """Convolutional spiking network built from snntorch ``Leaky`` layers.

    The same input is presented at every simulation step. The spikes and
    membrane potentials of the final ``Leaky`` layer are recorded per step.
    """

    def __init__(self, dim_input, dim_output, **kwargs) -> None:
        """Build the network.

        Args:
            dim_input (int): Stored as ``self.dim_input``.
            dim_output (int): Stored as ``self.dim_output``.
            **kwargs: Overrides for the entries of ``net_default``
                (e.g. ``beta``, ``spike_grad``).
        """
        super().__init__()
        self.dim_input = dim_input
        self.dim_output = dim_output
        # Defaults first, caller overrides second. With `kwargs | net_default`
        # the module defaults won and user-supplied settings were ignored.
        self.configs = net_default | kwargs

        # NOTE(review): dim_input/dim_output are not used by the layers below.
        # The first Conv2d hard-codes 1 input channel, and no layer projects to
        # dim_output, so the output width depends on the input's spatial size.
        # TODO confirm the intended architecture.
        leaky_kwargs = {
            'beta': self.configs['beta'],
            'spike_grad': self.configs['spike_grad'],
            'init_hidden': True,
        }

        # Initialize Network
        # autopep8: off
        self.net = nn.Sequential(nn.Conv2d(1, 12, 5),
                                 nn.MaxPool2d(2),
                                 snn.Leaky(**leaky_kwargs),
                                 nn.Conv2d(12, 64, 5),
                                 nn.MaxPool2d(2),
                                 snn.Leaky(**leaky_kwargs),
                                 nn.Flatten(),
                                 snn.Leaky(**leaky_kwargs, output=True))
        # autopep8: on

    def forward(self, num_steps, x):
        """Simulate the network for ``num_steps`` steps on a static input.

        Args:
            num_steps (int): Number of simulation steps.
            x (torch.Tensor): Input presented identically at every step.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Output spikes and membrane
            potentials of the last layer, each stacked along a new leading
            (step) dimension.
        """
        spk_rec = []
        mem_rec = []
        # Reset hidden states for all LIF neurons in net so each call starts fresh.
        utils.reset(self.net)

        for _ in range(num_steps):
            spk_out, mem_out = self.net(x)
            spk_rec.append(spk_out)
            mem_rec.append(mem_out)

        return torch.stack(spk_rec), torch.stack(mem_rec)


# Encode Touch Dataset with Spikes

In [104]:
# Copyright 2022 wngfra.
# SPDX-License-Identifier: Apache-2.0
batch_size = 128
# Number of time steps the rate encoder (ToSpike) produces per sample.
num_steps = 1

# Per-sample pipeline:
#   1. rescale each frame to [0, 1];
#   2. convert the ndarray to a tensor;
#   3. rate-encode it into a spike train of length num_steps.
transform = transforms.Compose([
    MinMaxScale(0., 1.0),
    transforms.ToTensor(),
    ToSpike('rate', num_steps=num_steps)
])
touch_dataset = TouchDataset('../data/touch.pkl', transform=transform)
train_loader = DataLoader(touch_dataset, batch_size=batch_size, shuffle=True)

# NOTE(review): TacNet(1, 10) -- TacNet's layers currently do not use these
# two arguments; confirm they match the intended input/output sizes.
model = TacNet(1, 10).to(device)
# Iterate through minibatches
data = iter(train_loader)
# Pull one shuffled batch (labels discarded) as a smoke test of the pipeline.
spike_data, _ = next(data)
# NOTE(review): presumably spike_data is (batch, num_steps, 1, H, W) here.
# ToTensor adds a channel axis to a 2-D frame, and the rate encoder prepends
# a time axis. TacNet's Conv2d layers expect 4-D input, so a per-step slice
# will likely be needed before calling model. Confirm.
x = spike_data.float().to(device)


# References
1. F. Pascal, L. Bombrun, J.-Y. Tourneret, and Y. Berthoumieu, “Parameter Estimation For Multivariate Generalized Gaussian Distributions,” in IEEE Transactions on Signal Processing, vol. 61, no. 23, pp. 5960–5971, Dec. 1, 2013, doi: 10.1109/TSP.2013.2282909.
2. A. Parvizi-Fard, M. Amiri, D. Kumar, M. M. Iskarous, and N. V. Thakor, “A functional spiking neuronal network for tactile sensing pathway to process edge orientation,” Sci Rep, vol. 11, no. 1, p. 1320, Dec. 2021, doi: 10.1038/s41598-020-80132-4.
3. J. A. Pruszynski and R. S. Johansson, “Edge-orientation processing in first-order tactile neurons,” Nat Neurosci, vol. 17, no. 10, pp. 1404–1409, Oct. 2014, doi: 10.1038/nn.3804.
4. J. M. Yau, S. S. Kim, P. H. Thakur, and S. J. Bensmaia, “Feeling form: the neural basis of haptic shape perception,” Journal of Neurophysiology, vol. 115, no. 2, pp. 631–642, Feb. 2016, doi: 10.1152/jn.00598.2015.
5. G. Sutanto, Z. Su, S. Schaal, and F. Meier, “Learning Sensor Feedback Models from Demonstrations via Phase-Modulated Neural Networks,” in 2018 IEEE International Conference on Robotics and Automation (ICRA), Brisbane, QLD, May 2018, pp. 1142–1149, doi: 10.1109/ICRA.2018.8460986.