In [1]:
import torch
import numpy as np
# Load the pre-extracted spike data and labels onto the CPU.
# weights_only=True restricts unpickling to tensors and plain containers: it
# prevents arbitrary code execution from a tampered .pt file and silences the
# torch.load FutureWarning about the changing default.
cpu = torch.device('cpu')
spikes_tensor = torch.load("tensors/spike_data_tensor.pt", map_location=cpu, weights_only=True)
labels_tensor = torch.load("tensors/labels_tensor.pt", map_location=cpu, weights_only=True)

# Per-class sample counts of the raw labels (entry i = number of samples with label i).
label_distribution = torch.bincount(labels_tensor)
print(f'Original Labels distribution: {label_distribution}')

Original Labels distribution: tensor([1294,   95, 1046,   53,   40])


  spikes_tensor = torch.load("tensors/spike_data_tensor.pt" ,map_location=torch.device('cpu'))
  labels_tensor = torch.load("tensors/labels_tensor.pt", map_location=torch.device('cpu'))


In [2]:
from torch.utils.data import Dataset, DataLoader, Subset
import tonic

class CCMKDataset(Dataset):
    """Binary, class-balanced view of the spike dataset.

    Keeps only samples whose label is 0 or ``target_label``, relabels
    ``target_label`` as 1, and randomly downsamples so both classes end up with
    the same number of samples. The stored samples are ordered by class (all
    label-0 samples first), so use a shuffling DataLoader for training.

    Args:
        spikes_tensor: Spike data indexed as ``[sample, channel, time]``
            (assumed layout; channels are selected on dim 1).
        labels_tensor: Integer label per sample, same length as ``spikes_tensor``.
            It is not modified.
        seed: If given, seeds NumPy's and PyTorch's *global* RNGs before
            sampling, which makes the downsampling reproducible. This is a
            global side effect: any later code using those RNGs is affected.
        **kwargs:
            nchannels: Optional index/list of channels to keep. Default: all.
            target_label: Original label used as the positive class (mapped to
                1). Must not be 0. Default: 2.
            device: Device the tensors are moved to. Default: CUDA if
                available, else CPU.

    Raises:
        ValueError: If ``target_label`` is 0, which would merge it with the
            negative class and leave an empty dataset after balancing.
    """

    def __init__(self, spikes_tensor, labels_tensor, seed=None, **kwargs):
        nchannels = kwargs.get('nchannels', None)
        target_label = kwargs.get('target_label', 2)
        device = kwargs.get('device', torch.device('cuda' if torch.cuda.is_available() else 'cpu'))

        if target_label == 0:
            raise ValueError("target_label must differ from the negative label 0")

        # Seed the global RNGs so the downsampling below is reproducible.
        if seed is not None:
            np.random.seed(seed)
            torch.manual_seed(seed)

        spikes = spikes_tensor.to(device)
        labels = labels_tensor.to(device)

        if nchannels is not None:
            spikes = spikes[:, nchannels, :]

        # Keep only the negative class (0) and the target class; drop every other label.
        keep = (labels == 0) | (labels == target_label)
        spikes = spikes[keep]
        # Boolean indexing returns a copy, so the relabeling below never touches
        # the caller's labels_tensor.
        labels = labels[keep]

        # Map the target class to 1 to obtain a binary {0, 1} problem.
        labels[labels == target_label] = 1

        balanced = self._balanced_indices(labels, device)
        self.spikes_tensor = spikes[balanced]
        self.labels_tensor = labels[balanced]

    @staticmethod
    def _balanced_indices(labels, device):
        """Return indices selecting an equal, randomly drawn number of 0- and 1-labelled samples.

        The minority class is kept in full (randomly permuted); the majority
        class is downsampled to the same size. Label-0 indices come first.
        """
        idx_0 = torch.where(labels == 0)[0].cpu().numpy()
        idx_1 = torch.where(labels == 1)[0].cpu().numpy()
        n_per_class = min(len(idx_0), len(idx_1))

        # Draws from NumPy's global RNG (seeded in __init__ when a seed is given).
        pick_0 = np.random.choice(idx_0, size=n_per_class, replace=False)
        pick_1 = np.random.choice(idx_1, size=n_per_class, replace=False)

        return torch.cat([torch.from_numpy(pick_0), torch.from_numpy(pick_1)]).to(device)

    def __len__(self):
        return len(self.labels_tensor)

    def __getitem__(self, idx):
        spike_data = self.spikes_tensor[idx]
        label = self.labels_tensor[idx]
        return spike_data, label

  from .autonotebook import tqdm as notebook_tqdm


### Dataset and DataLoader Configuration

In [32]:
import numpy as np
from sklearn.model_selection import train_test_split

# Keyword arguments forwarded to CCMKDataset: keep channels 0-15, use original
# label 2 as the positive class, and fix the seed used for class balancing.
dataset_kwargs = {
    'nchannels': list(range(16)),
    'target_label': 2,
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    'seed': 42,
}

# Build the balanced binary dataset (label 0 vs. target label remapped to 1).
dataset = CCMKDataset(spikes_tensor, labels_tensor, **dataset_kwargs)

# CCMKDataset already downsamples the majority class, so `dataset` is balanced
# on construction; re-sampling it here would only permute the indices again.
# Each class is split on its own (a stratified split), so train/val/test keep
# the same 0/1 ratio as the full dataset.
TEST_FRACTION = 0.20  # share of each class held out for testing
VAL_FRACTION = 0.10   # share of each class's remaining train+val pool used for validation
SPLIT_SEED = 42


def split_class_indices(indices, test_size=TEST_FRACTION, val_size=VAL_FRACTION, random_state=SPLIT_SEED):
    """Split one class's sample indices into ``(train, val, test)`` arrays.

    The test part is carved out first and the validation part is then taken
    from the remainder, so ``val_size`` is a fraction of the train+val pool,
    not of ``indices``. ``random_state`` makes both splits reproducible without
    relying on NumPy's global RNG state.
    """
    train_val, test = train_test_split(indices, test_size=test_size, random_state=random_state)
    train, val = train_test_split(train_val, test_size=val_size, random_state=random_state)
    return train, val, test


label_0_indices = torch.where(dataset.labels_tensor == 0)[0]
label_1_indices = torch.where(dataset.labels_tensor == 1)[0]

train_label_0, val_label_0, test_label_0 = split_class_indices(label_0_indices.cpu().numpy())
train_label_1, val_label_1, test_label_1 = split_class_indices(label_1_indices.cpu().numpy())

# Combine the per-class splits back together (class 0 first, then class 1).
train_indices = torch.from_numpy(np.concatenate([train_label_0, train_label_1]))
val_indices = torch.from_numpy(np.concatenate([val_label_0, val_label_1]))
test_indices = torch.from_numpy(np.concatenate([test_label_0, test_label_1]))

# Sanity check: the three splits are disjoint and together cover every sample.
all_split_indices = torch.cat([train_indices, val_indices, test_indices])
assert all_split_indices.numel() == all_split_indices.unique().numel() == len(dataset), \
    "train/val/test splits overlap or miss samples"

# Final subsets for training, validation, and test
train_dataset = Subset(dataset, train_indices)
val_dataset = Subset(dataset, val_indices)
test_dataset = Subset(dataset, test_indices)

# Pinned memory only speeds up CPU -> GPU copies and only works for CPU tensors.
# CCMKDataset may already have moved its tensors to the GPU (its `device`
# argument), and pinning CUDA tensors makes the DataLoader raise a RuntimeError.
pin_memory = torch.cuda.is_available() and dataset.spikes_tensor.device.type == 'cpu'

# Shared DataLoader arguments (training settings)
dataloader_kwargs = dict(
    batch_size=32,
    shuffle=True,
    drop_last=False,
    pin_memory=pin_memory,
    collate_fn=tonic.collation.PadTensors(batch_first=True),
    # Keep at 0: the dataset may hold CUDA tensors, which worker
    # subprocesses cannot safely share.
    num_workers=0,
)

# Shuffling only matters for training; evaluation iterates in a fixed order so
# per-batch outputs are comparable between runs.
eval_loader_kwargs = {**dataloader_kwargs, 'shuffle': False}

# Create DataLoaders
train_loader = DataLoader(train_dataset, **dataloader_kwargs)
val_loader = DataLoader(val_dataset, **eval_loader_kwargs)
test_loader = DataLoader(test_dataset, **eval_loader_kwargs)

In [36]:
# Adapted from Esther's snn.ipynb notebook
# Sanity checks: dataset size, one sample, the first training batch, and the
# label balance of the test split.

print(f"Dataset length: {len(dataset)}")

# A single sample: spikes are (channels x time steps), label is 0 or 1.
first_spikes, first_label = dataset[0]
print(f"Shape of the first spikes tensor: {first_spikes.shape}\n"
      f"Label of the first sample: {first_label}")

# Only the first training batch is inspected; inputs are
# batch size x number of channels x number of time steps.
for batch_idx, (inputs, targets) in enumerate(train_loader):
    print(f"Batch {batch_idx + 1}:\n"
          f" - Inputs shape: {inputs.shape}\n"
          f" - Targets shape: {targets.shape}")
    break

# Count how many test samples fall in each class.
labels = torch.cat([batch_targets for _, batch_targets in test_loader])
print(f'Labels distribution: {torch.bincount(labels.int())}')

Dataset length: 2092
Shape of the first spikes tensor: torch.Size([16, 101])
Label of the first sample: 0
Batch 1:
 - Inputs shape: torch.Size([32, 16, 101])
 - Targets shape: torch.Size([32])
Labels distribution: tensor([210, 210])


### Network Initialization

In [37]:
# Sampling frequency in Hz (samples per second).
fs = 101_640

# Simulation time step = duration of one sample, in seconds.
dt = 1 / fs
print(f"dt = {dt} seconds")

dt = 9.838646202282566e-06 seconds


In [38]:
from rockpool.nn.networks import SynNet
from rockpool.nn.modules import LIFTorch as LIFOtherSpiking

# Binary SNN classifier with three hidden layers of 24 neurons.
# The input width is read from the dataset (dim 1 of its spike tensor is the
# channel axis), so it always matches the channel selection in `dataset_kwargs`.
# The output width is 2 because CCMKDataset reduces the labels to {0, 1}.
# SynNet's `output` and `neuron_model` options are left at their defaults.
# NOTE(review): rockpool modules are expected to take input shaped
# (batch, time, channels), while the loaders yield (batch, channels, time);
# confirm the training loop transposes each batch before calling `net`.
net = SynNet(
    n_channels=dataset.spikes_tensor.shape[1],
    n_classes=2,
    size_hidden_layers=[24, 24, 24],
    time_constants_per_layer=[2, 4, 8],
    dt=dt,
)

print(net)

SynNet  with shape (16, 2) {
    TorchSequential 'seq' with shape (16, 2) {
        LinearTorch '0_LinearTorch' with shape (16, 24)
        LIFTorch '1_LIFTorch' with shape (24, 24)
        TimeStepDropout '2_TimeStepDropout' with shape (24,)
        LinearTorch '3_LinearTorch' with shape (24, 24)
        LIFTorch '4_LIFTorch' with shape (24, 24)
        TimeStepDropout '5_TimeStepDropout' with shape (24,)
        LinearTorch '6_LinearTorch' with shape (24, 24)
        LIFTorch '7_LIFTorch' with shape (24, 24)
        TimeStepDropout '8_TimeStepDropout' with shape (24,)
        LinearTorch '9_LinearTorch' with shape (24, 2)
        LIFTorch '10_LIFTorch' with shape (2, 2)
    }
}
