In [None]:
import torch
import numpy as np

from google.colab import drive
drive.mount('/content/drive')

# Location of the pre-saved tensors on the mounted Google Drive.
DATA_DIR = "/content/drive/My Drive"

# Load tensors. weights_only=True restricts unpickling to tensors and plain
# containers, so a tampered .pt file cannot execute arbitrary code on load.
labels_tensor = torch.load(f"{DATA_DIR}/labels_tensor.pt", weights_only=True)
spikes_tensor = torch.load(f"{DATA_DIR}/spike_data_tensor.pt", weights_only=True)

# Per-class sample counts (bincount requires a 1-D non-negative integer tensor).
label_distribution = torch.bincount(labels_tensor)
print(f'Original Labels distribution: {label_distribution}')


In [None]:
import torch
import numpy as np

class CCMKDataset(torch.utils.data.Dataset):
    """Binary (target vs. rest) dataset built from labelled spike samples.

    Every sample whose label equals ``target_label`` becomes a positive (label 1.0).
    Every sample whose label is in ``non_target_labels`` is kept as a negative
    (label 0.0). The negative class is then topped up with randomly drawn
    ``background_label`` samples. The top-up makes the negative count equal the
    positive count, provided the non-targets alone do not already exceed it and
    enough background samples exist.

    Args:
        spikes_tensor: Spike data indexed by sample along dim 0 (in this notebook
            each sample is a [16, 101] tensor).
        labels_tensor: Integer label per sample, aligned with ``spikes_tensor``.
        target_label: Label treated as the positive class.
        non_target_labels: Labels always kept as negatives. Entries equal to
            ``target_label`` or ``background_label`` are ignored, so no sample
            is counted twice.
        background_label: Label whose samples are subsampled to pad negatives.
        seed: Optional seed for the background subsampling. ``None`` uses
            torch's global RNG.

    Attributes:
        filtered_spikes: Selected samples; negatives first (non-target labels in
            the given order, then background), followed by the positives.
        filtered_labels: float32 labels in the same order (0.0 / 1.0).
    """

    def __init__(self, spikes_tensor, labels_tensor, target_label=2,
                 non_target_labels=(1, 3, 4), background_label=0, seed=None):
        self.spikes_tensor = spikes_tensor
        self.labels_tensor = labels_tensor
        self.target_label = target_label
        self.background_label = background_label
        # Filter out the target/background labels so no sample is counted twice.
        self.non_target_labels = [lbl for lbl in non_target_labels
                                  if lbl != target_label and lbl != background_label]

        # Positives: every sample carrying the target label.
        target_mask = self.labels_tensor == target_label
        target_samples_count = int(target_mask.sum().item())

        # Always-kept negatives, grouped label by label in the given order.
        non_target_indices = [torch.where(self.labels_tensor == lbl)[0]
                              for lbl in self.non_target_labels]
        non_target_samples_count = sum(len(idx) for idx in non_target_indices)

        # Number of background samples needed to balance the classes. It is clamped
        # to [0, available]: if non-targets already outnumber targets, none are added.
        background_indices = torch.where(self.labels_tensor == background_label)[0]
        required_samples_from_label_0 = min(
            max(target_samples_count - non_target_samples_count, 0),
            len(background_indices),
        )

        # Sample without replacement by taking the head of a random permutation.
        # This stays on the device of the index tensor (no numpy/CPU round-trip).
        generator = torch.Generator().manual_seed(seed) if seed is not None else None
        permutation = torch.randperm(len(background_indices), generator=generator)
        selected_background_noise = background_indices[
            permutation[:required_samples_from_label_0].to(background_indices.device)
        ]

        # All negative indices: non-target labels followed by the sampled background.
        selected_non_target_mask = torch.cat(non_target_indices + [selected_background_noise])

        # Negatives first, then positives; labels below follow the same order.
        self.filtered_spikes = torch.cat([self.spikes_tensor[selected_non_target_mask],
                                          self.spikes_tensor[target_mask]])

        # Non-targets -> 0.0 (negative), target -> 1.0 (positive); float32 as before.
        self.filtered_labels = torch.cat([torch.zeros(len(selected_non_target_mask)),
                                          torch.ones(target_samples_count)])

        # Debugging: Check the distribution of labels after processing
        print(f"Filtered Labels distribution after processing: {torch.bincount(self.filtered_labels.int())}")

    def __len__(self):
        """Number of selected samples (negatives + positives)."""
        return len(self.filtered_labels)

    def __getitem__(self, idx):
        """Return ``(spikes, label)`` for the ``idx``-th selected sample."""
        return self.filtered_spikes[idx], self.filtered_labels[idx]


# Seed both NumPy's and torch's global RNGs so the random background subsampling
# in CCMKDataset and the DataLoader's shuffle order are reproducible on re-run.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Create the dataset
dataset = CCMKDataset(spikes_tensor=spikes_tensor, labels_tensor=labels_tensor)

# Batch the data; shuffle=True draws a new sample order every epoch.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)


Filtered Labels distribution after processing: tensor([1046, 1046])
Inputs shape: torch.Size([32, 16, 101]), Targets shape: torch.Size([32])
Labels distribution: tensor([1046, 1046])


In [None]:
# Overall number of samples held by the dataset
print(f"Dataset length: {len(dataset)}")

# Peek at a single sample: a [channels, time] spike tensor and its float label
first_spikes, first_label = dataset[0]
print(f"Shape of the first spikes tensor: {first_spikes.shape}")
print(f"Label of the first sample: {first_label}")

# Peek at one batch only. The batch layout is
# [batch size, number of channels, number of time steps], e.g. [32, 16, 101].
batch_idx, (inputs, targets) = next(enumerate(dataloader))
print(f"Batch {batch_idx + 1}:")
print(f" - Inputs shape: {inputs.shape}")
print(f" - Targets shape: {targets.shape}")

# Class balance over one full pass of the DataLoader
labels = torch.cat([batch_targets for _, batch_targets in dataloader])
print(f'Labels distribution: {torch.bincount(labels.int())}')

Dataset length: 2092
Shape of the first spikes tensor: torch.Size([16, 101])
Label of the first sample: 0.0
Batch 1:
 - Inputs shape: torch.Size([32, 16, 101])
 - Targets shape: torch.Size([32])
Labels distribution: tensor([1046, 1046])


In [None]:
!pip install rockpool


Collecting rockpool
  Downloading rockpool-2.8.tar.gz (501 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 501.4/501.4 kB 7.9 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: rockpool
  Building wheel for rockpool (setup.py) ... done
  Created wheel for rockpool: filename=rockpool-2.8-py3-none-any.whl size=602817 sha256=040f6a0b22bd692a944a5aacbb2d8e7a99526a3c9c2ef4357f856b5c16ebf701
  Stored in directory: /root/.cache/pip/wheels/86/da/b5/d6b52866a6c79c247eb46105b77bb3dab82eb3f567594ccc85
Successfully built rockpool
Installing collected packages: rockpool
Successfully installed rockpool-2.8


In [None]:
import torch
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from rockpool.nn.networks import SynNet
from tqdm.notebook import trange  # For progress bar

# Dataset / training dimensions (each spike sample is 16 channels x 101 time steps)
n_channels, n_classes = 16, 2   # input channels; binary target-vs-rest classes
n_time, batch_size = 101, 32    # time steps per sample; DataLoader batch size

# Spiking network: three hidden layers of 24 neurons each, with per-layer
# time constants 2, 4 and 8 (units as defined by rockpool's SynNet).
net = SynNet(
    n_classes=n_classes,
    n_channels=n_channels,
    time_constants_per_layer=[2, 4, 8],
    size_hidden_layers=[24] * 3,
)

# Cross-entropy over the output channels, optimised with Adam (lr = 1e-3).
# `.astorch()` yields the rockpool parameters in a form torch's Adam accepts.
loss_fun = CrossEntropyLoss()
optimizer = Adam(net.parameters().astorch(), lr=1e-3)


In [None]:
n_epochs = 10    # Number of epochs for training

# NOTE: re-running this cell keeps training the same `net`/`optimizer`;
# re-run the model-setup cell first to start from scratch.
for epoch in trange(n_epochs):
    running_loss = 0.0   # sum of per-sample losses this epoch
    n_seen = 0           # number of samples seen this epoch
    for inputs, targets in dataloader:
        optimizer.zero_grad()  # Clear the gradients

        # DataLoader yields [batch, channels, time]; feed the net [batch, time, channels].
        inputs = inputs.transpose(1, 2)

        # Forward pass; the first element of the returned 3-tuple is the output,
        # indexed below as [batch_size, n_time, n_classes].
        output, _, _ = net(inputs)

        # Classify from the last time step only.
        # NOTE(review): if SynNet emits spikes rather than membrane potentials,
        # last-step outputs are weak logits; consider aggregating over time — confirm.
        output = output[:, -1, :]

        # CrossEntropyLoss expects integer class indices; the dataset yields floats.
        targets = targets.long()

        loss = loss_fun(output, targets)  # mean loss over the batch

        loss.backward()   # Compute gradients
        optimizer.step()  # Update model parameters

        # Weight by batch size: the last batch can be smaller (2092 % 32 = 12 in
        # this run), so averaging per-batch means would over-weight it.
        batch_n = targets.shape[0]
        running_loss += loss.item() * batch_n
        n_seen += batch_n

    # Per-sample mean loss over the epoch (guarded against an empty loader).
    print(f'Epoch {epoch + 1}, Loss: {running_loss / max(n_seen, 1)}')

print('Training completed')


  0%|          | 0/10 [00:00<?, ?it/s]

Epoch 1, Loss: 0.6555271744728088
Epoch 2, Loss: 0.6422403233520912
Epoch 3, Loss: 0.5844872684189768
Epoch 4, Loss: 0.5483888220606428
Epoch 5, Loss: 0.5298814606485944
Epoch 6, Loss: 0.5434634170748971
Epoch 7, Loss: 0.5179341286420822
Epoch 8, Loss: 0.5232368012269338
Epoch 9, Loss: 0.5206198380752043
Epoch 10, Loss: 0.5180452371185477
Training completed
