In [56]:
# Show current directory
import os
curr_dir = os.getcwd()
print(curr_dir)
import json
# Define general variables
parent = r"C:\__NeuroSpark_Liset_Dataset__\neurospark_mat\CNN_TRAINING_SESSIONS" # Modify this to your data path folder

### HOME PC
# parent=r"E:\neurospark_mat\CNN_TRAINING_SESSIONS"


c:\Users\NCN\Documents\PedroFelix\LAVA_SNN_ripples\snnTorch


In [None]:
# Add parent directory to path (To acess sntt_utils)
import sys

parent_dir = os.path.abspath(os.path.join(curr_dir, os.pardir))
liset_path = os.path.abspath(os.path.join(curr_dir, '../liset_tk'))


# Add the grandparent directory to the system path
# grandparent_dir = os.path.abspath(os.path.join(current_dir, os.pardir, os.pardir))
sys.path.append(parent_dir)
sys.path.append(liset_path)
from liset_aux import ripples_std, middle

print(sys.path)



# Check if Cuda is available

In [None]:
import torch
import numpy as np

# Check CUDA Installation
print(torch.cuda.is_available())

# Get the number of available GPUs
num_gpus = torch.cuda.device_count()
print(f"Number of GPUs: {num_gpus}")

# Get information about each GPU
for i in range(num_gpus):
    device_props = torch.cuda.get_device_properties(i)
    print(f"\nGPU {i}:")
    print(f"  Name: {device_props.name}")
    print(f"  Total memory: {device_props.total_memory / 1024**3:.2f} GB")
    print(f"  Multiprocessor count: {device_props.multi_processor_count}")
    print(f"  Major compute capability: {device_props.major}")
    print(f"  Minor compute capability: {device_props.minor}")


## Define the Device that will be used to train the SNN


In [None]:
# Set the device to be used
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")     # torch.device("cpu") #

print("device: ", device)

In [None]:
dt=1  # Time step in milliseconds (1 ms)


RIPPLE_DETECTION_OFFSET = [18, 45, 31, 20] # it's calculated as 4.5 periods of the ripple wavelet - for 100 Hz and 250 Hz as the limit frequencies
# The Windows for HFO detection are based on the MAX DETECTION OFFSET
RIPPLE_CONFIDENCE_WINDOW = int(round(RIPPLE_DETECTION_OFFSET[1] * 1.8)) 

# in timesteps (ms) - Max time from the Insertion Timing to the GT annotation
MAX_DETECTION_OFFSET =RIPPLE_DETECTION_OFFSET[1]   # in timesteps (ms)

MEAN_DETECTION_OFFSET = RIPPLE_DETECTION_OFFSET[2]   # in timesteps (ms)

WINDOW_SIZE = int(RIPPLE_DETECTION_OFFSET[1]*4)   # in timesteps (ms) - The size of the window to slice the input data

# unit: timesteps (ms) - The number of steps that 2 consecutive windows must overlap to not lose any relevant CBs
# INTERSECT_WINDOW_LEN = int(MAX_DETECTION_OFFSET)
std, mean = ripples_std(parent) # 61 ms
INTERSECT_WINDOW_LEN=int((std+mean)*1000) # in timesteps (ms) - The size of the intersection window to slice the input data

# unit: timesteps (ms) - The number of steps that the window must shift to get the next window
WINDOW_SHIFT = int(WINDOW_SIZE - INTERSECT_WINDOW_LEN)

# unit: timesteps (ms) - The time window after the GT annotation where the network should predict the burst (GT_time, GT_time + PRED_CAUSALITY_WINDOW)
# This is needed to give the network some extra time steps to increase the membrane potential and spike
PRED_CAUSALITY_WINDOW = int(5)     # Giving PRED_CAUSALITY_WINDOW ms for the network to update its inner state and spike   

# unit: timesteps (ms) - The time window around the GT annotation where the network should predict the burst (GT_time - PRED_GT_TOLERANCE, GT_time + PRED_GT_TOLERANCE)
PRED_GT_TOLERANCE = int(RIPPLE_DETECTION_OFFSET[3])  # in timesteps (ms)
print(f"WINDOW_SIZE: {WINDOW_SIZE}")
print(f"INTERSECT_WINDOW_LEN: {INTERSECT_WINDOW_LEN} (MAX_DETECTION_OFFSET)")
print(f"WINDOW_SHIFT: {WINDOW_SHIFT}")
print(f"MEAN DETECTION OFFSET: {MEAN_DETECTION_OFFSET}")
print(f"PRED_GT_TOLERANCE: {PRED_GT_TOLERANCE}")
print(f"PRED_CAUSALITY_WINDOW: {PRED_CAUSALITY_WINDOW}")

In [None]:
# Constants for the Refractory LIF Process
confidence_window=int(RIPPLE_DETECTION_OFFSET[1])
# We know that 2 relevant events do not occur within the confidence window of a ripple event, so we set the refractory period accordingly
refrac_period = np.floor(confidence_window / dt)   # Number of time-steps for the refractory period
print("Refractory Period: ", refrac_period, "ms")

# Read the concatenated data and GT

In [54]:
freq=1000 # Frequency of the signal in Hz (1 kHz)
bandpass=[100, 250] # Bandpass filter range in Hz (100-250 Hz)


# Generate the SNN's Input Data and Labels (GT)
We need to transform the input data into a format that we can feed into the SNN. To allow learning through BPTT, we will split the spike trains into time windows of WINDOW_SIZE ms. Since a relevant HFO can occur in-between 2 time windows, we will introduce an overlap of INTERSECT_WINDOW_LEN ms between the time windows -- equal to the maximum duration of an HFO.

We are opting for this windowing strategy because it is a simple way to implement learning in the SNN. Another option would be to feed the data in real-time to the SNN without windows, but this would disable the possibility of having batch_size > 1.

# Split the Input into Time Windows and Calculate the Ground Truth

In [79]:
# Only run this block if first time
if 'windowed_inpuit_data' not in locals():
    # Split the Input Data and Ground Truth into Windows
    windowed_input_data = []    # Input Data Windows
    windowed_gt = []        # Ground Truth Windows (spike time if HFO, -1 if no HFO)
    total_windows_count = 0
    skipped_hfo_count = 0   # Counts the nº of skipped HFOs due to no input activations
    total_hfos=0
    # curr_ripple_times = ripples_concat[curr_ripple_id]    # Get the GT times for the current sEEG source

    # LOAD THE DATA
    # Iterate over the datasets
    for dataset in os.listdir(parent):
        data_dir = os.path.abspath(os.path.join(curr_dir, os.pardir,"extract_Nripples","train_pedro","dataset_up_down",dataset,str(freq)))
        data=np.load(os.path.join(data_dir,f'data_up_down_{bandpass[0]}_{bandpass[1]}.npy'),allow_pickle=True)
        ripples=np.load(os.path.join(data_dir,"ripples.npy"),allow_pickle=True)
        with open(os.path.join(data_dir, f'params_{bandpass[0]}_{bandpass[1]}.json'), 'r') as f:
            parameters=json.load(f)
            thresholds=parameters["threshold"]
        print(data_dir)
        print("data_concat shape: ", data.shape)
        print("ripples_concat shape: ", ripples.shape)
        # print("Head of data_concat: ", data[:10][:])
        # print("Head of ripples_concat: ", ripples[:10])
        ripples = ripples[np.argsort(ripples[:, 0])]
    
        for channel in range(data.shape[1]):
            if thresholds[channel]>0.1:
                curr_ripple_id = 0     # Keep track of the current GT event index since it is monotonically increasing the timestep
                for i in range(0, data.shape[0], WINDOW_SHIFT):
                    left, right = i, i+WINDOW_SIZE
                    # Get the current input window
                    curr_window = data[left:right,channel,:]
                    # Increment the total windows count
                    total_windows_count += 1
                    # Check if the current window is smaller than the expected size
                    if curr_window.shape[0] < WINDOW_SIZE:
                        # If the current window is smaller than the expected size, break the loop
                        print(f"[WARNING] Current window [{left}, {right}] is smaller than the expected size. Breaking the loop...")
                        break

                    # OPTIMIZATION STEP: Skip windows with no activations - The gradient will be zero (at least when using MSE Spike Rate)
                    if np.sum(curr_window) == 0:
                        # print(f"Window [{left}:{right}] has no Input activations. Skipping...")
                        cur_gt_time=[-1, -1]    # Default value for Spike Time (no HFO)
                        if curr_ripple_id < ripples.shape[0]:
                            cur_gt_time = ripples[curr_ripple_id]    # Convert to closest integer (floor)
                        if (cur_gt_time[1] >= left) and (cur_gt_time[0] <= right):
                            if cur_gt_time[1] <= right:
                                print(f"[WARNING] Window [{left}:{right}] has a GT event at {cur_gt_time} and NO Input activations. Skipping...")
                                # Update the curr_gt_idx to the next GT event
                                skipped_hfo_count += 1
                            curr_ripple_id += 1
                        continue   
                    
                    '''
                    Check if there is a GT event in the current window
                    '''

                    curr_gt = -1    # Default value for Spike Time (no HFO)
                    
                    # Check if the current GT event is within the current window
                    while curr_ripple_id<ripples.shape[0] and ripples[curr_ripple_id][1] < left:
                        # Ripple ends before the window starts → skip it
                        curr_ripple_id += 1
                    
                    if curr_ripple_id >= ripples.shape[0]:
                        curr_ripple_id=ripples.shape[0]-1
                
                    cur_gt_time = ripples[curr_ripple_id]      
                    if (cur_gt_time[1] >= left) and (cur_gt_time[0] <= right):
                        '''
                            Check if the current window encapsulates the whole HFO Causality Window
                            The Network may spike in the interval [GT_time, GT_time + MEAN_HFO_DURATION + PRED_GT_TOLERANCE]
                            
                            However, we are using an upper limit for the HFO Duration of MAX_HFO_DURATION. This way, the Ground Truth
                            Timestamps will be clamped uppwards by WINDOW_SIZE - MAX_HFO_DURATION + MEAN_HFO_DURATION
                        '''
                        if cur_gt_time[1] <= right and cur_gt_time[0]>=left: # If the GT event is completely within the current window
                            '''The Network should predict the HFO -> Calculate the spike time
                            Let's assume the network should spike at the end of the relevant event. We have no way of knowing
                            the exact end time, so we use the mean duration of the event to calculate the spike time.
                            '''
                            avg_spike_time = cur_gt_time[0] + MEAN_DETECTION_OFFSET   # The network should spike at the end of the relevant event
                            
                            # Subtract the left offset to get the spike time in the current window
                            relative_spike_time = avg_spike_time - left
                            if relative_spike_time > WINDOW_SIZE:
                                # If the spike time is greater than the window size, we want to skip the window
                                print(f"[WARNING] Spike time {relative_spike_time} is greater than the window size {WINDOW_SIZE}. Adjusting...")
                                relative_spike_time= cur_gt_time[1]-left

                            curr_gt = relative_spike_time   # Update the curr_gt value

                            # Update the curr_gt_idx to the next GT event
                            curr_ripple_id += 1
                            
                        elif cur_gt_time[1] > right or cur_gt_time[0] < left:
                            continue
                            # If the GT event is not completely within the current window, we want to skip the window
                    
                    # Append the current window    
                    windowed_input_data.append(curr_window)            
                    # Append the current GT Spike Time to the windowed GT
                    windowed_gt.append(curr_gt)
                total_hfos+=ripples.shape[0]
            else:
                print(f"[WARNING] Channel {channel} has a very low threshold. Skipping...")
    # Convert to numpy array
    windowed_input_data = np.array(windowed_input_data)
    windowed_gt = np.array(windowed_gt, dtype=np.float32)
else:
    print("Code Block already run. Skipping...")

print("Windowed Input Data Shape: ", windowed_input_data.shape)
print("Windowed GT Shape: ", windowed_gt.shape)

removed_windows = total_windows_count - windowed_input_data.shape[0]
print(f"Removed {removed_windows}/{total_windows_count} ({round((removed_windows / total_windows_count)*100, 2)}%) windows with no input activations")
print(f"Skipped {skipped_hfo_count} HFOs due to no input activations")
print(f"Total HFOs (theoretical): {total_hfos}")
# Save the windowed data


c:\Users\NCN\Documents\PedroFelix\LAVA_SNN_ripples\extract_Nripples\train_pedro\dataset_up_down\Amigo2_1_hippo_2019-07-11_11-57-07_1150um\1000
data_concat shape:  (2398857, 8, 2)
ripples_concat shape:  (1309, 2)
c:\Users\NCN\Documents\PedroFelix\LAVA_SNN_ripples\extract_Nripples\train_pedro\dataset_up_down\Som_2_hippo_2019-07-24_12-01-49_1530um\1000
data_concat shape:  (1036254, 8, 2)
ripples_concat shape:  (485, 2)
Windowed Input Data Shape:  (150791, 180, 2)
Windowed GT Shape:  (150791,)
Removed 19663/170454 (11.54%) windows with no input activations
Skipped 253 HFOs due to no input activations
Total HFOs (theoretical): 10425


The code block above outputs:

1. A list of time windows of shape =  `(num_windows, window_size, input_neurons) -- windowed_input_data`
2. A list of labels of shape = `(num_windows, ) -- windowed_gt`

In [80]:
# Define mask for windows with an HFO (spike time >= 0) in the GT

# TODO: Some ripples are detect outside the window - this is not compatible...

# See GT Class Distribution
# Set print options to see more elements
np.set_printoptions(linewidth=100, threshold=50, edgeitems=20)
print(f"Ground Truth Class Distribution: {np.unique(windowed_gt, return_counts=True)}")



GT_HFO_MASK = windowed_gt >= 0
# print(windowed_gt[0:1000])
# Define the number of windows with an HFO
num_hfo_windows = np.sum(GT_HFO_MASK)
print(f"Number of windows with an HFO: {num_hfo_windows}")
print(f"Percentage of windows with an HFO: {num_hfo_windows / windowed_gt.shape[0] * 100:.2f}%")

Ground Truth Class Distribution: (array([ -1.,  31.,  32.,  33.,  34.,  35.,  36.,  37.,  38.,  39.,  40.,  41.,  42.,  43.,  44.,
        45.,  46.,  47.,  48.,  49., ..., 161., 162., 163., 164., 165., 166., 167., 168., 169.,
       170., 171., 172., 173., 174., 175., 176., 177., 178., 179., 180.], dtype=float32), array([140979,     21,     25,      6,     27,      5,     18,     15,      5,     57,     37,
           17,     28,     23,     23,     49,     56,     46,     28,     67, ...,     44,     51,
           86,     57,     45,     29,     39,     50,     39,     43,     54,     54,     21,
           33,      8,     10,     28,     24,     22,      8], dtype=int64))
Number of windows with an HFO: 9812
Percentage of windows with an HFO: 6.51%


**Note**: It's a good sign that the GT time is distributed along the time window.

For example, if the GT annotation could only occur on the first 25% timesteps of the window, it could converge the network toward not spiking or spiking initially to minimize the loss.


**QUESTION** : How to deal with edge cases?

## Interpreting the Ground Truth Timestamp values
- Lower limit: `MEAN_HFO_DURATION`
- Upper limit: `RIPPLE[0] + MEAN_HFO_DURATION`

Such that we only consider the ripples that are completely within the window.

The Marker annotation is a timestamp that indicates the approximated end of the HFO event. It is bounded upwards and downwards by the variables mentioned above.

# Class Balancing
We can see that Class 0 (No HFO) is much more frequent than Class 1 (HFO). This is expected since HFOs are rare events. However, we need to be careful with **class imbalance**, as it can lead to model overfitting and poor generalization.

- 6.51% of Windows -> HFO
- 93.49% of Windows -> No HFO

In [None]:
from utils.training import undersample_majority, oversample_minority
intermediate_input, intermediate_gt = windowed_input_data, windowed_gt  # Default: No Balancing
balance=True
if balance:
    intermediate_input, intermediate_gt = undersample_majority(windowed_input_data, windowed_gt, GT_HFO_MASK)  

# Print the number of samples in each class
print(f"Intermediate GT Class Distribution: {np.unique(intermediate_gt, return_counts=True)}")
print(f"Intermediate Window Input Data Shape: {intermediate_input.shape}")


Intermediate GT Class Distribution: (array([ -1.,  31.,  32.,  33.,  34.,  35.,  36.,  37.,  38.,  39.,  40.,  41.,  42.,  43.,  44.,
        45.,  46.,  47.,  48.,  49., ..., 161., 162., 163., 164., 165., 166., 167., 168., 169.,
       170., 171., 172., 173., 174., 175., 176., 177., 178., 179., 180.], dtype=float32), array([9812,   21,   25,    6,   27,    5,   18,   15,    5,   57,   37,   17,   28,   23,   23,
         49,   56,   46,   28,   67, ...,   44,   51,   86,   57,   45,   29,   39,   50,   39,
         43,   54,   54,   21,   33,    8,   10,   28,   24,   22,    8], dtype=int64))
Intermediate Window Input Data Shape: (19624, 180, 2)
