# Import Packages

In [12]:
import numpy as np
import matplotlib.pyplot as plt
import snntorch as snn
import snntorch.spikeplot as splt
import torch
import torch.nn as nn
import torch.nn.functional as F

from scipy.ndimage import gaussian_filter
from snntorch import spikegen, surrogate, utils
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from IPython.display import HTML

# Prefer the GPU when CUDA is usable; otherwise everything runs on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Default hyper-parameters read by TacNet through its ``configs`` mapping:
#   beta        -- ``beta`` handed to each snn.Leaky layer
#   kernel_size -- kernel size of the first Conv2d
#   num_steps   -- not read inside TacNet's code in this file (forward takes num_steps explicitly)
#   rf_channels -- number of output channels of the first Conv2d
#   spike_grad  -- surrogate gradient handed to each snn.Leaky layer
net_default = dict(
    beta=0.9,
    kernel_size=5,
    num_steps=100,
    rf_channels=16,
    spike_grad=surrogate.fast_sigmoid(slope=25),
)


class TouchDataset(Dataset):
    """Map-style dataset of tactile sensor frames paired with orientation labels."""

    def __init__(self, filename, transform=None) -> None:
        """Constructor for TouchDataset.

        Every value of the loaded mapping must provide a ``'sensordata'`` entry
        (stacked along axis 0 into ``self.sensordata``) and an
        ``'orientations'`` entry (concatenated into ``self.orientation``).

        Args:
            filename (str): Path of the touch data file, read with ``np.load``.
            transform (Compose, optional): Compose of transforms to apply on samples. Defaults to None.

        Raises:
            ValueError: If the file contains no records.
        """
        # SECURITY: allow_pickle=True lets np.load unpickle the file, which can
        # execute arbitrary code -- only load trusted data files.
        data = np.load(filename, allow_pickle=True)
        sensordata, orientation = [], []
        # Walk the mapping once, fetching each record a single time.
        for key in data.keys():
            record = data[key]
            sensordata.append(np.array(record['sensordata'], dtype=np.float32))
            orientation.append(np.array(record['orientations'], dtype=np.float32))
        if not sensordata:
            # np.vstack([]) would fail with an unhelpful message; say what is wrong.
            raise ValueError(f'No touch records found in {filename!r}')
        self.sensordata = np.vstack(sensordata)
        self.orientation = np.hstack(orientation)
        self.transform = transform

    def __len__(self):
        """Return the number of samples (one per orientation label)."""
        return len(self.orientation)

    def __getitem__(self, index):
        """Return ``(frame, orientation)`` for ``index``, transforming the frame if configured."""
        sample = self.sensordata[index, :, :]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, self.orientation[index]


class MinMaxScale(object):
    """Linearly rescale an array so its values span ``[minVal, maxVal]``."""

    def __init__(self, minVal, maxVal):
        """Initialize the transform to rescale the values into [minVal, maxVal].

        Args:
            minVal (float): Minimum value.
            maxVal (float): Maximum value.
        """
        self.min, self.max = minVal, maxVal

    def __call__(self, x):
        """Rescale ``x`` so that its minimum maps to ``self.min`` and its maximum to ``self.max``.

        Args:
            x (np.ndarray): Input array.

        Returns:
            np.ndarray: Array with the same shape as ``x``. A constant input has
            no spread to stretch, so every element is set to the lower bound
            ``self.min`` (float32), keeping the result inside the target range.
        """
        min_x, max_x = np.min(x), np.max(x)
        if min_x >= max_x:
            # Constant input: previously returned unchanged (i.e. possibly far
            # outside [self.min, self.max]); clamp to the lower bound instead.
            return np.full(np.shape(x), self.min, dtype=np.float32)
        return (x - min_x) / (max_x - min_x) * (self.max - self.min) + self.min


class ToSpike(object):
    def __init__(self, encoding='rate', **kwargs):
        """Initialize the spike generator.

        Args:
            encoding (str, optional): Spiking encoding options. Defaults to 'rate'.
        """
        self._encodings = ['rate', 'latency', 'delta']
        if encoding in self._encodings:
            self.encoding = encoding
        else:
            self.encoding = self._encodings[0]
        self.configs = kwargs
        match self.encoding:
            case 'rate':
                default = {
                    'num_steps': 100,
                    'gain': 0.5
                }
            case 'latency':
                deafult = {

                }
            case 'delta':
                default = {
                    'threshold': 4
                }
            case _:
                default = {

                }
        self.configs = default | self.configs

    def __call__(self, x):
        match self.encoding:
            case 'rate':
                y = spikegen.rate(
                    x,
                    num_steps=self.configs['num_steps'],
                    gain=self.configs['gain'])
            case 'delta':
                y = spikegen.delta(
                    x,
                    threshold=self.configs['threshold'])
            case _:
                y = x
        return y


class TacNet(nn.Module):
    """Convolutional spiking network built from Conv2d layers and snntorch Leaky neurons.

    NOTE(review): ``dim_input`` and ``dim_output`` are stored but not used to
    size any layer; the final layer emits 32 feature maps -- confirm the
    intended readout.
    """

    def __init__(self, dim_input, dim_output, **kwargs) -> None:
        """Build the network and initialise its first convolution.

        Args:
            dim_input: Stored as ``self.dim_input``.
            dim_output: Stored as ``self.dim_output``.
            **kwargs: Overrides for entries of ``net_default`` (``beta``,
                ``kernel_size``, ``rf_channels``, ``spike_grad``, ...).
        """
        super().__init__()
        self.dim_input = dim_input
        self.dim_output = dim_output
        # The right-hand operand wins in a dict merge, so user kwargs must come
        # last to override the module-level defaults (previously reversed, which
        # silently ignored every override).
        self.configs = net_default | kwargs

        # Initialize Network
        # autopep8: off
        self.net = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=self.configs['rf_channels'], kernel_size=self.configs['kernel_size'], stride=1),
                                 snn.Leaky(beta=self.configs['beta'], spike_grad=self.configs['spike_grad'], init_hidden=True),
                                 nn.Conv2d(self.configs['rf_channels'], 32, 3),
                                 nn.MaxPool2d(2),
                                 snn.Leaky(beta=self.configs['beta'], spike_grad=self.configs['spike_grad'], init_hidden=True, output=True))
        # autopep8: on
        self.init_hidden()

    def init_hidden(self):
        """Initialise the first convolution's weights with a centred 2-D Gaussian kernel.

        The weights are deterministic, not sampled. A unit impulse at the kernel
        centre is blurred with ``sigma = kernel_size // 2`` over the two spatial
        axes only. The same kernel is written into every output channel.
        """
        center = self.configs['kernel_size'] // 2
        weight = self.net[0].weight
        kernel = np.zeros(weight.shape)
        kernel[:, :, center, center] = 1
        # sigma=0 on the (out, in) channel axes leaves them untouched; only the
        # spatial axes are smoothed.
        kernel = gaussian_filter(kernel, sigma=(0, 0, center, center))
        with torch.no_grad():
            # copy_ casts the float64 numpy kernel to the parameter's dtype.
            weight.copy_(torch.from_numpy(kernel))

    def forward(self, num_steps, x):
        """Simulate the network for ``num_steps`` steps, feeding the same ``x`` each step.

        Args:
            num_steps (int): Number of simulation steps.
            x (torch.Tensor): Input passed unchanged to the network at every
                step; presumably (batch, 1, H, W) since the first Conv2d has a
                single input channel -- confirm.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Output spikes and membrane
            potentials of the last Leaky layer, each stacked along a new leading
            time dimension.
        """
        spk_rec, mem_rec = [], []
        # Reset hidden states of all LIF neurons so consecutive calls don't leak state.
        utils.reset(self.net)

        for _ in range(num_steps):
            spk_out, mem_out = self.net(x)
            spk_rec.append(spk_out)
            mem_rec.append(mem_out)

        return torch.stack(spk_rec), torch.stack(mem_rec)


# Encode Touch Dataset with Spikes

In [13]:
# Copyright 2022 wngfra.
# SPDX-License-Identifier: Apache-2.0
batch_size = 128
# Number of steps TacNet.forward is simulated for.
# NOTE(review): ToSpike('rate') below uses its own default num_steps (100);
# keep the two values in sync or pass num_steps to ToSpike explicitly.
num_steps = 100

# Per-sample pipeline: rescale each frame into [0, 1], convert it to a tensor,
# then rate-encode it into a spike train.
transform = transforms.Compose([
    MinMaxScale(0., 1.0),
    transforms.ToTensor(),
    ToSpike('rate'),
])
# NOTE(review): TouchDataset reads this file with np.load(allow_pickle=True);
# only point it at trusted data.
touch_dataset = TouchDataset('../data/touch.pkl', transform=transform)
train_loader = DataLoader(touch_dataset, batch_size=batch_size, shuffle=True)

# dim_input=1 / dim_output=10 are stored by TacNet but do not size its layers.
model = TacNet(1, 10).to(device)

# Iterate through minibatches
for train_data, train_target in train_loader:
    train_data = train_data.to(device)
    train_target = train_target.to(device)
    
    # NOTE(review): assuming ToTensor yields (1, H, W) per frame and
    # spikegen.rate prepends a time axis, train_data is
    # (batch, num_steps, 1, H, W); [:, 0] then keeps only the FIRST time step,
    # and TacNet.forward replays that single frame num_steps times -- confirm
    # whether each simulation step should instead see its own time slice.
    spk, mem = model(num_steps, train_data[:, 0, :, :])