# preprocess data
split the eeg into small chunks

In [23]:
from tqdm.notebook import tqdm
from scipy.signal import butter, lfilter, sosfilt, iirnotch
import pandas as pd
import numpy as np
import torch as t

import matplotlib.pyplot as plt

In [24]:
OUT_DIR = './eeg-band-1-70-notch-60/'
BASE_PATH = './hms-harmful-brain-activity-classification/'
DATA_PATH = './hms-harmful-brain-activity-classification/train_eegs/'

In [25]:
# data has shape (sequence_len=10000, probes=20)

def clip(data, bound=300):
    return np.clip(data, -bound, bound)

def robust_norm(data):
    median = np.median(data, axis=0)
    q75, q25 = np.percentile(data, [75 ,25], axis=0)
    iqr = q75 - q25
    iqr[iqr < 1e-6] = 1e-6 # numerical stability
    return (data - median) / iqr

def band_filter(data, low=1, high=70, fs=200, order=4):
    sos = butter(N=order, Wn=[low, high], btype='bandpass', fs=fs, output='sos')
    return sosfilt(sos, data, axis=0)

def notch_filter(data, notch_freq=60, Q=30, fs=200):
    b_notch, a_notch = iirnotch(notch_freq, Q, fs)
    return lfilter(b_notch, a_notch, data, axis=0)

# def logify(data):
#     log_data = np.log1p(np.abs(data))
#     log_data[data < 0] *= -1
#     return log_data

# def normalize_signals(data):
#     mean = data.mean(axis=0, keepdims=True)
#     std = data.std(axis=0, keepdims=True)
#     std_adjusted = np.where(std > 1e-10, std, 1) # numerical stability
#     return (data - mean) / std_adjusted

# def butter_bandpass_filter(data, lowcut=1. , highcut =40., fs=200, order=6):
#     b, a = butter(order, [lowcut / (0.5 * fs), highcut / (0.5 * fs)], btype='band')
#     return lfilter(b, a, data)

def filters(data):
    data = band_filter(data)
    data = notch_filter(data)
    # data = robust_norm(data)
    return data

In [26]:
def preprocess(df):
    sample_rate = 200
    duration = 10_000
    for eeg_id, group in tqdm(df.groupby('eeg_id')):
        parquet_file = f'{DATA_PATH}{eeg_id}.parquet'
        raw_eeg = pd.read_parquet(parquet_file)
        for _, row in group.iterrows():
            eeg_sub_id = row['eeg_sub_id']
            offset = int(row['eeg_label_offset_seconds'] * sample_rate)
            eeg = raw_eeg.iloc[offset:offset + duration]
            eeg = eeg.ffill(axis=0).fillna(0)
            filtered_eeg = filters(eeg.values)
            data = t.tensor(filtered_eeg).float()
            # plt.plot(filtered_eeg[:1000, 0] - filtered_eeg[:1000, 1])
            # plt.show()
            # plt.plot(filtered_eeg[:1000, :19])
            # plt.show()
            # plt.plot(filtered_eeg[:1000, :18] - filtered_eeg[:1000, 1:19])
            # plt.show()
            t.save(data, f'{OUT_DIR}{eeg_id}_{eeg_sub_id}.pt')
df = pd.read_csv(f'{BASE_PATH}train.csv')
preprocess(df)

  0%|          | 0/17089 [00:00<?, ?it/s]

In [27]:
import random
random.choice([0.01, 0.05, 0.1, 0.15, 0.2, 0.3])

0.15

In [28]:
def soft_rounding(input_tensor, threshold=0.01, steepness=10):
    # Apply a sigmoid function to softly threshold values close to 0
    # The steepness parameter controls how sharp the transition is
    # Values much larger than the threshold will be near 1 after this, and those much smaller will be near 0
    weight = t.sigmoid(steepness * (input_tensor - threshold))

    # Apply the weight, values close to 0 get diminished
    adjusted_tensor = input_tensor * weight

    # Renormalize to ensure the sum is 1
    normalized_tensor = adjusted_tensor / adjusted_tensor.sum()

    return normalized_tensor

# Example usage
input_tensor = t.tensor([0.005, 0.995, 0.0001, 0.0003], dtype=t.float32)
output_tensor = soft_rounding(input_tensor)
print(output_tensor)


tensor([2.4434e-03, 9.9737e-01, 4.7643e-05, 1.4308e-04])


In [29]:
def round_zero(res, threshold=1e-2):
    res[res < threshold] = 0
    print(res.sum(dim=-1, keepdim=True))
    res = res / res.sum(dim=-1, keepdim=True)
    return res

res = t.tensor([
    [0.0001, 0.002, 0.997, 0.0001],
    [0.0001, 0.02, 0.45, 0.44]
])

round_zero(res)

tensor([[0.9970],
        [0.9100]])


tensor([[0.0000, 0.0000, 1.0000, 0.0000],
        [0.0000, 0.0220, 0.4945, 0.4835]])