# Initialization
This project assumes GitHub and Google Cloud Storage info is stored in environmental variables.  
It was developed in Google Colab, but can be run locally by selecting 'local' as the environment.  

For setup details (including required environment variables), see the README.md in the GitHub repository:  
https://github.com/tristan-day-research/NeuroStorm_seizure_detection  

In [1]:
# --- SETUP GITHUB AND GCP ENVIRONMENT VARIABLES ---
# Ensure the following environmental variables are set in Colab user data:
# - GITHUB_PAT: GitHub Personal Access Token
# - GITHUB_EMAIL: GitHub email for commits
# - GITHUB_USER_NAME: GitHub username
# - GCP_EEG_PROJECT_ID: Google Cloud Project ID
# - GCP_EEG_BUCKET_NAME: (Optional) GCP bucket for EEG data

# Select environment
ENVIRONMENT = 'colab'   # Choose 'local' or 'colab'
BRANCH_NAME = 'main'

if ENVIRONMENT == 'colab':
    from google.colab import userdata

    # Retrieve GitHub credentials from Colab user data
    token = userdata.get('GITHUB_PAT')
    github_email = userdata.get('GITHUB_EMAIL')
    github_username = userdata.get('GITHUB_USER_NAME')

    # Clone the repository (done here as the helper file isn't available yet)
    !git clone -b {BRANCH_NAME} https://{token}@github.com/tristan-day-research/NeuroStorm_seizure_detection.git

    # Change to correct directory
    %cd /content/NeuroStorm_seizure_detection/

    # Load the helper file now that the repo is cloned
    from src.setup import configure_environment

    # Run full environment setup
    bucket_name = configure_environment(environment=ENVIRONMENT)


Cloning into 'NeuroStorm_seizure_detection'...
remote: Enumerating objects: 172, done.[K
remote: Counting objects: 100% (172/172), done.[K
remote: Compressing objects: 100% (93/93), done.[K
remote: Total 172 (delta 80), reused 156 (delta 69), pack-reused 0 (from 0)[K
Receiving objects: 100% (172/172), 1.92 MiB | 10.06 MiB/s, done.
Resolving deltas: 100% (80/80), done.
/content/NeuroStorm_seizure_detection
GCP Project Set
Git configured with your user data.


In [None]:
# Standard Library Imports
import os
import gc
import json
import time
import glob
import math
import random
import threading
import subprocess
import logging
from collections import OrderedDict
from datetime import datetime
from functools import partial
from warnings import warn

# Third-Party Imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import colors
plt.rcParams["font.family"] = "serif"
from tqdm import tqdm

# Machine Learning/Deep Learning Imports
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset, random_split
from torch.nn.init import trunc_normal_
from torch.nn.utils import clip_grad_norm_, rnn
from torch.optim.lr_scheduler import StepLR
from torch.cuda.amp import GradScaler, autocast
from torchvision.transforms import Compose

# Check GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
!nvidia-smi

Using device: cpu
/bin/bash: line 1: nvidia-smi: command not found


In [2]:
print(bucket_name)

None


# Preprocessing
Each EEG raw signal is segemnted into patches of a fixed length. Fast Fourier Transform (FFT) spectra are made from these patches which will be used to train the Vector-Quantized Variational Autoencoder (VQ-VAE).

In [None]:
# bucket_name = os.getenv('GCP_EEG_BUCKET_NAME')
!echo $GCP_EEG_BUCKET_NAME

print(bucket_name)


None


In [None]:
import torch
import matplotlib.pyplot as plt
from src.data_and_FFT import EEGDataset

# Visualization Function (Separate from Data Loader)
def visualize_eeg_and_fft(eeg_tensor, fft_tensor, fft_size, sample_idx=0):
    raw_signal = eeg_tensor[sample_idx, :, 0].cpu().numpy()  # [patch_size]
    fft_signal = fft_tensor[sample_idx, :, 0].cpu().numpy()  # [fft_size // 2 + 1]

    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.plot(raw_signal)
    plt.title("Raw EEG Signal (Time Domain)")
    plt.xlabel("Time")
    plt.ylabel("Amplitude")

    plt.subplot(1, 2, 2)
    plt.plot(fft_signal)
    plt.title("FFT of EEG Signal (Frequency Domain)")
    plt.xlabel("Frequency Bin")
    plt.ylabel("Magnitude")

    plt.tight_layout()
    plt.show()


# Parameters
file_prefix = 'kaggle/input/hms-harmful-brain-activity-classification'
patch_size = 200
overlap = 50
fft_size = 256


# Load Dataset and Visualize
def load_and_visualize_samples(bucket_name, file_prefix, patch_size, overlap, fft_size):
    dataset = EEGDataset(
        bucket_name=bucket_name,
        file_prefix=file_prefix,
        patch_size=patch_size,
        overlap=overlap,
        fft_size=fft_size
    )

    print(f"Found {len(dataset)} files. Visualizing...")

    for i in range(min(3, len(dataset))):  # Visualize 3 samples or fewer
        fft_data, mask = dataset[i]
        visualize_eeg_and_fft(fft_data, fft_data, fft_size)


# Run Test
load_and_visualize_samples(
    bucket_name=bucket_name,
    file_prefix=file_prefix,
    patch_size=patch_size,
    overlap=overlap,
    fft_size=fft_size
)


# Tokenizer

In [None]:
from src.vqvae.train import train, validate
from src.vqvae.data import EEGDataset, ToPatches
from torch.optim import AdamW
import torch

# --- Hyperparameters ---
stride = 150
batch_size = 4
num_workers = 4
num_epochs = 7
lr = 1e-4
lr_scheduler_step_size = 1
lr_scheduler_gamma = 0.9
accumulation_steps = 2

model_name = "2025_VQVAE_v1"
codebook_size = 1024
emb_dim = 64

# --- Dataset and Loader ---
transform_to_patches = ToPatches(patch_size=200, stride=stride)
eeg_dataset = EEGDataset(bucket_name=BUCKET_NAME, blob_prefix="train_eegs_HMS_processed",
                         transform=transform_to_patches)
train_size = int(0.8 * len(eeg_dataset))
valid_size = len(eeg_dataset) - train_size
train_dataset, valid_dataset = random_split(eeg_dataset, [train_size, valid_size])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# --- Model, Optimizer, Scheduler ---
model_class = globals()[model_name]
model = model_class(codebook_size=codebook_size, emb_dim=emb_dim)
model.to(device)
optimizer = AdamW(model.parameters(), lr=lr, weight_decay=1e-3)
scheduler = StepLR(optimizer, step_size=lr_scheduler_step_size, gamma=lr_scheduler_gamma)

loss_function = partial(fft_masked_mse_loss, phase_start_batch=200, phase_end_batch=400)

# --- Train and Validate ---
for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    avg_loss = train(model, train_loader, optimizer, loss_function, device,
                     scheduler=scheduler, accum_steps=accumulation_steps)
    print(f"Train Loss: {avg_loss:.4f}")
    val_loss = validate(model, valid_loader, loss_function, device)
    print(f"Validation Loss: {val_loss:.4f}")
#
