In [61]:
import os
import numpy as np
import torch
import scipy.io
import scipy.signal
import picard
import sklearn.preprocessing

In [62]:
show = __name__ == "__main__"

In [63]:
dataset = torch.load("dataset.pt", weights_only=True)

# Import the relevant metadata
for key, value in dataset["metadata"].items():
    locals()[key] = value

data = dataset['data']
labels = dataset['labels']

for key, value in dataset["metadata"].items():
    locals()[key] = value

In [64]:
def normalize(data, dim=None):
    means = data.mean(dim=dim, keepdim=True)
    stds = data.std(dim=dim, keepdim=True)
    data -= means
    data /= stds
    return data
    
if show:
    data_normal = normalize(data, dim=2)

In [65]:
def common_average_referencing(data, dim=None, residual=False):
    means = data.mean(dim=dim, keepdim=True)
    data -= means
    if residual:
        return torch.cat([data, means], dim=dim)
    return data

if show:
    data_car = common_average_referencing(data, dim=1)

In [66]:
if show:
    data_normal_car = normalize(common_average_referencing(data, dim=1), dim=2)
    print(data_normal_car.shape)    

torch.Size([9315, 6, 4096])


In [67]:
# Super naive spike thesholdning. Everything beyond +/- 1 SD is a spike.
def abs_threshold(data, threshold=1):
    return (data.abs() > 1).float()


In [68]:
if show:
    spikes = abs_threshold(data_normal_car, threshold=1)
    print(spikes.shape)

torch.Size([9315, 6, 4096])


In [69]:
if show:
    print(spikes)

tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 0.],
         [1., 1., 1.,  ..., 1., 1., 1.]],

        [[0., 0., 0.,  ..., 1., 1., 0.],
         [1., 1., 1.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 1., 1., 1.],
         [0., 0., 0.,  ..., 1., 1., 1.]],

        [[0., 0., 0.,  ..., 1., 1., 1.],
         [0., 0., 0.,  ..., 1., 1., 1.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [0., 0., 0.,  ..., 1., 1., 1.]],

        ...,

        [[1., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [1., 1., 1.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [0., 0., 0.,  ..., 0., 0., 0

In [70]:
# Naive spike thesholdning. Everything beyond +/- 1 SD is a spike.
def pol_threshold(data, threshold=1, dim=1):
    return torch.cat(((data > threshold).float(), (data < -threshold).float()), dim=dim)

In [71]:
if show:
    spikes_pol = pol_threshold(data_normal_car, threshold=1, dim=1)
    spikes_pol.shape

In [72]:
def delta_coding(data, delta=0.1):
    diff = torch.diff(data, dim=2, prepend=data[:,:,:1])
    spikes_diff = (diff / delta).to(int)
    # This wasn't exactly obvious. floor(1.5) is 1 while floor (-1.5) is -2
    spikes_delta = torch.cat((
        torch.floor(spikes_diff).clamp(min=0),
        -torch.ceil(spikes_diff).clamp(max=0),
    ), dim=1).to(torch.float32)
    return spikes_delta


In [73]:
if show:
    spikes_delta = delta_coding(data_normal_car, delta=0.1)
    print(spikes_delta.shape)

torch.Size([9315, 12, 4096])


In [74]:
if show:
    print(spikes_delta[0,0,1000:1100], spikes_delta[0,7,1000:1100])

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 1., 1., 1., 2., 2., 2., 2., 2., 2., 2., 2., 2., 1., 1., 1., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]) tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
        1., 1., 1., 1., 1., 2., 2., 2., 1., 1.])


In [75]:
if show:
    print(spikes_delta[0,0,:].sum(), spikes_delta[0,7,:].sum())

tensor(855.) tensor(650.)


In [76]:
def rate_coding(data, sensitivity=0.1):
    spikes_quant = data / sensitivity
    spikes_rate = torch.cat((
        torch.floor(spikes_quant).clamp(min=0),
        -torch.ceil(spikes_quant).clamp(max=0),
    ), dim=1).to(int).to(torch.float32)
    return spikes_rate


In [77]:
if show:
    spikes_rate = rate_coding(data_normal_car, sensitivity=0.1)
    print(spikes_rate[0,0,:].sum(), spikes_rate[0,7,:].sum())

tensor(15121.) tensor(15308.)


In [78]:
if show:
    print(spikes_rate.shape)

torch.Size([9315, 12, 4096])


In [79]:
if show:
    spikes_rate[0,0,1000:1100], spikes_rate[0,7,1000:1100]

In [80]:
def bin_sum(data, size, dim=-1):
    '''Bin the data into fewer bins by summation, analogous to Tonic.ToFrame. Implementation by ChatGPT.'''
    total_indices = data.size(dim)
    quotient, remainder = divmod(total_indices, size)
    bin_sizes = [quotient + (1 if i >= size - remainder else 0) for i in range(size)]
    splits = torch.split(torch.arange(total_indices), bin_sizes)
    return torch.stack([data.index_select(dim, idx).sum(dim=dim) for idx in splits], dim=dim)

In [81]:
def collate(batch):
    events, targets = zip(*batch)
    events = torch.stack(events).permute(2, 0, 1)
    targets = torch.tensor(targets)
    return events, targets

In [82]:
def downsample(data, in_rate=rate, out_rate=rate//8):
    samples = data.size(-1) * out_rate // in_rate
    result = scipy.signal.resample(data.numpy(), samples, axis=-1)
    return torch.tensor(result, device=data.device, dtype=data.dtype)

In [83]:
if show:
    data_downsampled = downsample(data)
    print(data_downsampled.shape, data_downsampled)

torch.Size([9315, 6, 512]) tensor([[[-1.7901, -1.2316, -0.2249,  ..., -0.4548, -0.1657, -0.8302],
         [-0.7497, -1.2089, -0.7197,  ...,  0.2524,  0.0408,  0.4718],
         [ 0.1546, -0.2013, -1.2267,  ..., -0.2937, -0.8689, -0.4492],
         [ 1.6666,  0.2609, -0.1395,  ...,  2.1496,  2.2918,  2.3698],
         [ 0.5850,  2.0446,  1.3731,  ..., -1.4338, -1.6507, -1.8923],
         [ 1.6648,  1.1493,  1.0155,  ...,  0.6244,  1.0931,  1.4547]],

        [[-0.1296,  0.3201, -0.0952,  ...,  0.2230, -0.2758, -1.3174],
         [-1.2068, -1.2258, -1.0657,  ...,  0.2338, -0.4884, -0.7140],
         [ 0.0492,  1.0897,  1.3169,  ...,  0.5872,  0.8851, -0.2034],
         [-0.4806, -0.6937, -1.0926,  ..., -0.1032, -0.9620, -0.8518],
         [ 1.2148,  0.9044,  0.7386,  ..., -0.1825,  0.8674,  1.8076],
         [ 0.3329, -0.4422,  0.1485,  ..., -0.6295, -0.0833,  1.0682]],

        [[ 1.3713,  0.3089, -0.2299,  ...,  2.9085,  2.2313,  1.8590],
         [ 0.9662,  0.3754,  0.6685,  ...,  1.

In [84]:
import warnings

# Never mind the warnings...
warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn.base")

def ica(data, max_iter=1000, tol=1e-6):
    results = []
    for d in data:
        d = d.transpose(0, 1)
        transform = picard.Picard(d.size(1), max_iter=1000, tol=1e-6)
        result = torch.tensor(transform.fit_transform(d),
                              dtype=data.dtype,
                              device=data.device).transpose(0, 1)
        results.append(result)
    return torch.stack(results)

In [85]:
if show:
    y = ica(data_downsampled)
    print(y.shape, y[0])

torch.Size([9315, 6, 512]) tensor([[ 0.0103, -0.0627, -0.0557,  ...,  0.0420,  0.0457,  0.0889],
        [ 0.0667,  0.0152, -0.0047,  ...,  0.0820,  0.0873,  0.0897],
        [ 0.0460,  0.0295, -0.0066,  ...,  0.0190,  0.0036,  0.0094],
        [-0.0461, -0.0633, -0.0741,  ...,  0.0075, -0.0168, -0.0154],
        [-0.0178, -0.0377,  0.0181,  ...,  0.0685,  0.0834,  0.0499],
        [ 0.0052,  0.0203,  0.0342,  ..., -0.0011, -0.0159, -0.0094]])


In [86]:
# This was provided by o1 based on the disussion here:
# https://chatgpt.com/share/67669e9c-ebd8-8007-92f0-70579f58cc9a

def ica_fit_and_transform(
    data: torch.Tensor,      # shape [N, C, T]
    train_mask: torch.Tensor,# shape [N], boolean or long indices
    max_iter: int = 1000,
    tol: float = 1e-6
) -> torch.Tensor:
    """
    Learn a single ICA decomposition on all training-set trials, 
    then apply that unmixing to the entire dataset.

    Args:
        data (torch.Tensor): 
            EEG data of shape [num_trials, num_channels, num_timepoints].
        train_mask (torch.Tensor): 
            Boolean or index mask of shape [num_trials] indicating which trials are training.
        max_iter (int): 
            Maximum number of Picard iterations.
        tol (float): 
            Tolerance for the stopping criterion in Picard.

    Returns:
        torch.Tensor: 
            ICA-transformed data of shape [num_trials, num_channels, num_timepoints].
            Each trial has the same shape as input but in the ICA component space.
    """

    device = data.device
    dtype = data.dtype

    # ------------------------------------------
    # 1) EXTRACT & CONCATENATE ALL TRAINING TRIALS
    # ------------------------------------------
    # data[train_mask] has shape [n_train, C, T]
    train_data = data[train_mask]
    n_train, C, T = train_data.shape

    # We want shape [n_samples, n_features] = [n_train*T, C]
    # So permute to [n_train, T, C], then reshape
    # => shape [n_train*T, C]
    train_data_2d = train_data.permute(0, 2, 1).reshape(-1, C)

    # Move to CPU if necessary, because Picard typically runs on NumPy
    train_data_np = train_data_2d.cpu().numpy()

    # ------------------------------------------
    # 2) FIT ICA ON ALL TRAINING DATA
    # ------------------------------------------
    ica_transform = picard.Picard(
        n_components=C, 
        max_iter=max_iter, 
        tol=tol
    )
    ica_transform.fit(train_data_np)

    # ------------------------------------------
    # 3) APPLY ICA TO ALL TRIALS (TRAIN/VAL/TEST)
    # ------------------------------------------
    results = []
    N = data.shape[0]
    for i in range(N):
        # Each trial: shape [C, T]
        trial = data[i]  
        # Transpose to [T, C] for transform()
        trial_t = trial.transpose(0, 1).cpu().numpy()

        # Transform with the learned unmixing
        trial_ica = ica_transform.transform(trial_t)  
        # => shape [T, C]

        # Convert back to Torch, and transpose to [C, T] if desired
        trial_ica_torch = torch.tensor(
            trial_ica, dtype=dtype, device=device
        ).transpose(0, 1)

        results.append(trial_ica_torch)

    # Stack back into [N, C, T]
    return torch.stack(results, dim=0)

In [87]:
if show:
    train_mask = torch.rand((len(data_downsampled),)) < 0.8
    z = ica_fit_and_transform(data_downsampled, train_mask=train_mask)
    print(z.shape, z[0])

torch.Size([9315, 6, 512]) tensor([[ 7.7457e-04,  5.2590e-04,  3.4034e-04,  ...,  4.3600e-04,
          4.8140e-04,  6.1021e-04],
        [-2.0970e-04,  2.6050e-04,  8.5429e-05,  ..., -7.0999e-05,
         -2.7817e-04, -7.3878e-04],
        [-4.4021e-04,  5.1835e-04,  1.0887e-03,  ..., -7.2686e-04,
         -4.6791e-04, -8.5118e-04],
        [-2.5562e-04,  7.5510e-05,  3.3148e-04,  ..., -1.4841e-05,
          3.0130e-04, -2.7047e-04],
        [-1.9194e-03, -1.8411e-03, -9.6671e-04,  ..., -2.4568e-04,
         -2.1815e-04, -5.0176e-04],
        [ 3.0886e-04, -2.6873e-04, -4.9061e-05,  ...,  9.5303e-04,
          1.2210e-03,  1.0832e-03]])


In [88]:
def bandpass_filter(data, low=2, high=40, rate=1024, order=5):
    sos = scipy.signal.butter(order, [low / rate / 2, high / rate / 2], btype='band', output='sos')
    return torch.tensor(scipy.signal.sosfilt(sos, np.ascontiguousarray(data)), dtype=data.dtype, device=data.device)

In [89]:
if show:
    a = bandpass_filter(data)
    print(a.shape, a[0])

torch.Size([9315, 6, 4096]) tensor([[-4.5239e-08, -4.8647e-07, -2.6267e-06,  ...,  7.2935e-01,
          7.1904e-01,  7.0679e-01],
        [-2.7523e-08, -2.9697e-07, -1.6099e-06,  ..., -3.5150e-02,
         -3.9813e-02, -4.4669e-02],
        [ 7.8006e-09,  8.3659e-08,  4.5010e-07,  ..., -1.0281e+00,
         -1.0309e+00, -1.0312e+00],
        [ 2.9645e-08,  3.1735e-07,  1.7043e-06,  ...,  4.2283e-01,
          4.1938e-01,  4.1686e-01],
        [ 3.3837e-08,  3.6639e-07,  1.9946e-06,  ..., -4.5932e-01,
         -4.5563e-01, -4.5226e-01],
        [ 3.7081e-08,  3.9848e-07,  2.1501e-06,  ...,  5.1451e-02,
          7.4699e-02,  9.8546e-02]])


In [90]:
def robust_scaler(data):
    scaled = sklearn.preprocessing.RobustScaler().fit_transform(data.cpu().numpy().reshape(-1, data.shape[-1]))
    return torch.tensor(scaled.reshape(data.shape), dtype=data.dtype, device=data.device)

In [91]:
if show:
    b = robust_scaler(data)
    print(b.shape, b[0])

torch.Size([9315, 6, 4096]) tensor([[-1.7777, -1.6704, -1.5620,  ..., -0.9362, -0.9044, -0.8454],
        [-1.0794, -1.0530, -1.0252,  ...,  0.1270,  0.0442, -0.0486],
        [ 0.3131,  0.2875,  0.2523,  ..., -0.1105, -0.1279, -0.1621],
        [ 1.1742,  1.0497,  0.9236,  ...,  1.6232,  1.5689,  1.5189],
        [ 1.3394,  1.3612,  1.3808,  ..., -0.8492, -0.7490, -0.6496],
        [ 1.4673,  1.3717,  1.2776,  ...,  1.1639,  1.1454,  1.1282]])
