# ECoG Foundation Model Training
This is meant to be a minimal notebook capable of running model training in a free Google Colab notebook. Feel free to adapt it as you see fit for your experiments.

## Repo and data setup (if not already on your machine)

In [None]:
# Clone the repository into ./ECoG-foundation-model (relative to the notebook's
# working directory). git refuses to clone into an existing non-empty directory,
# so re-running this cell after a successful clone reports an error.
!git clone https://github.com/leoniekerken/ECoG-foundation-model.git

Now, go into the repo you just downloaded and replace the Hugging Face user access token in the Makefile with your personal access token. If you don't want to do this every time, you could instead upload the code to your personal drive and change the `path_to_github_repo` variable below, although then you risk your copy of the code becoming out of date.

In [None]:
# Download the data via the repo Makefile's `download-data` target. This needs
# the Hugging Face token in the Makefile to be set (see the instructions above).
# The `cd` only applies to this shell command; it does not change the
# notebook's own working directory.
!cd ECoG-foundation-model && make download-data

## Installs and imports

In [None]:
# Install third-party packages required by the training code that may be
# missing from the runtime (e.g. a fresh Colab VM).
!pip install accelerate
!pip install einops
!pip install mne
!pip install mne-bids
!pip install pyEDFlib

In [None]:
# Local path to the root of the GitHub repo; it must be reachable from this
# notebook's working directory. The default "../" assumes the notebook runs
# from a subdirectory of the repo (matching dataset_path='../dataset' in the
# configuration). If you cloned the repo with the cell above instead, it lives
# at "ECoG-foundation-model" relative to the working directory — TODO confirm
# where this notebook runs and set the path accordingly.
path_to_github_repo = "../"

In [None]:
# Add import for ECoG code.
import sys
import os
import torch
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
import ipywidgets as widgets
from IPython.display import display, clear_output
import time
import matplotlib.pyplot as plt
%matplotlib inline
sys.path.append(os.path.join(path_to_github_repo, 'ECoG_MAE'))

# Other imports
from dataclasses import dataclass

from config import VideoMAEExperimentConfig, VideoMAETaskConfig, ViTConfig, TrainerConfig, ECoGDataConfig
import math
from ecog_setup import system_setup, model_setup
from loader import dl_setup
import constants
from train import train_model
import utils
from mae_st_util.models_mae import MaskedAutoencoderViT
from plot import plot_multi_band_reconstruction
from mask import get_padding_mask

## Configuration

In [None]:
# Experiment configuration. See the config class definitions for every
# available field and its docstring.
experiment_config = VideoMAEExperimentConfig(
    video_mae_task_config=VideoMAETaskConfig(
        vit_config=ViTConfig(
            # Embedding width inside the encoder blocks.
            dim=64,
            # Embedding width inside the decoder blocks.
            decoder_embed_dim=32,
            # Hidden width of each transformer block's MLP, as a multiple of
            # the block's embedding width.
            mlp_ratio=4.0,
            # Number of transformer blocks in the encoder.
            depth=6,
            # Number of transformer blocks in the decoder.
            decoder_depth=3,
            # Attention heads per encoder block.
            num_heads=8,
            # Attention heads per decoder block.
            decoder_num_heads=4,
            # Spatial side length of a patch in electrodes; each patch spans
            # patch_size x patch_size electrodes of the grid.
            patch_size=2,
            # Number of frames covered by each patch.
            frame_patch_size=4,
            # If True, prepend a cls token whose embedding can be used for
            # classification.
            use_cls_token=False,
            # If True, the decoder gets its own separate position embedding.
            sep_pos_embed=True,
            # If True, use truncated-normal weight initialization.
            trunc_init=False,
            # If True, attention query/key/value projections have no bias.
            no_qkv_bias=False,
        ),
        # Fraction of patches masked out of the encoder input (see "Masked
        # Autoencoders As Spatiotemporal Learners" for details).
        encoder_mask_ratio=0.25,
        # Fraction of the mask tokens passed to the decoder for reconstruction.
        pct_masks_to_decode=1.0,
        # If True, normalize the target before computing the loss. The input is
        # already normalized before it is passed in, so this is likely
        # unnecessary.
        norm_pix_loss=False,
    ),
    trainer_config=TrainerConfig(
        # Peak learning rate for torch's OneCycleLR schedule.
        max_learning_rate=5e-4,
    ),
    ecog_data_config=ECoGDataConfig(
        # Fraction of the data to train on.
        data_size=1.0,
        # Batch size for both training and evaluation.
        batch_size=32,
        # If True, convert the data to its power envelope (magnitude of the
        # Hilbert transform).
        env=False,
        # Frequency bands [low, high] to transform the data into.
        bands=[[4, 8], [8, 13], [13, 30], [30, 55], [70, 200]],
        # Sampling rate of the original recordings.
        original_fs=512,
        # Rate the data is resampled to (frames per second of model input).
        new_fs=120,
        # Relative path to the dataset root directory.
        dataset_path='../dataset',
        # Fraction of the data in the training split; the rest is the test split.
        train_data_proportion=0.9,
        # Length of one training example, in seconds.
        sample_length=1,
        # If True, shuffle the data before splitting it into train and eval.
        shuffle=False,
    ),
    # job_name='test_run',
)

# Train on the GPU when one is available.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Total number of training steps.
max_iters = 10000
# Batches averaged per split each time performance is estimated.
eval_iters = 50
# Check performance on the eval split every this many training steps.
eval_interval = 100

## Dataloader Setup

In [None]:
# Build the train/test dataloaders from experiment_config. The third return
# value (number of training samples) is not used by the cells shown here.
train_dl, test_dl, num_train_samples = dl_setup(experiment_config)

In [None]:
# Peek at one training batch to check its layout. Batches have shape
# (b, c, t, h, w), where
# b = batch size,
# c = number of frequency bands,
# t = number of frames within a sample (new_fs frames per second),
# h = height of the electrode grid (currently 8),
# w = width of the electrode grid (currently 8).
# Uses the public iterator protocol instead of DataLoader's private
# _get_iterator().
print(next(iter(train_dl)).shape)

## Model Setup

In [None]:
model_config = experiment_config.video_mae_task_config.vit_config
data_config = experiment_config.ecog_data_config
task_config = experiment_config.video_mae_task_config

# Frames per training example: seconds per sample times the resampled rate.
num_frames = data_config.sample_length * data_config.new_fs

frame_patch_size = model_config.frame_patch_size
# A patch spans patch_size x patch_size electrodes and frame_patch_size frames,
# so the token count is (spatial patches per frame) * (temporal patches).
# patch_size must be squared here, matching the patch dimensionality below.
num_spatial_patches = (constants.GRID_SIZE // model_config.patch_size) ** 2
num_patches = int(num_spatial_patches * (num_frames // frame_patch_size))

num_encoder_patches = int(num_patches * (1 - task_config.encoder_mask_ratio))
num_decoder_patches = int(num_patches * task_config.pct_masks_to_decode)
print("num_patches", num_patches)
print("num_encoder_patches", num_encoder_patches)
print("num_decoder_patches", num_decoder_patches)
print("patch dimensionality", frame_patch_size * model_config.patch_size * model_config.patch_size * len(data_config.bands))
print("encoder embedding dimensionality", model_config.dim)
print("decoder embedding dimensionality", model_config.decoder_embed_dim)

model = MaskedAutoencoderViT(
    img_size=constants.GRID_SIZE,
    patch_size=model_config.patch_size,
    in_chans=len(data_config.bands),
    embed_dim=model_config.dim,
    depth=model_config.depth,
    num_heads=model_config.num_heads,
    decoder_embed_dim=model_config.decoder_embed_dim,
    decoder_depth=model_config.decoder_depth,
    decoder_num_heads=model_config.decoder_num_heads,
    mlp_ratio=model_config.mlp_ratio,
    norm_pix_loss=task_config.norm_pix_loss,
    num_frames=num_frames,
    t_patch_size=model_config.frame_patch_size,
    no_qkv_bias=model_config.no_qkv_bias,
    sep_pos_embed=model_config.sep_pos_embed,
    trunc_init=model_config.trunc_init,
    cls_embed=model_config.use_cls_token,
    pred_t_dim=num_frames // model_config.frame_patch_size,
    img_mask=None,
    pct_masks_to_decode=task_config.pct_masks_to_decode,
)
utils.count_params(model)


def _build_param_groups(module, weight_decay=1e-2):
    """Split ``module``'s parameters into weight-decay and no-weight-decay groups.

    Biases (any parameter whose name contains "bias") and all parameters owned
    by ``torch.nn.LayerNorm`` submodules get no weight decay; everything else
    gets ``weight_decay``. LayerNorm parameters are detected by module type,
    not by name: name patterns such as "LayerNorm.weight" only match models
    whose LayerNorm submodules are literally named "LayerNorm".

    Returns:
        A list of two param-group dicts suitable for a torch optimizer.
    """
    layer_norm_param_ids = {
        id(p)
        for m in module.modules()
        if isinstance(m, torch.nn.LayerNorm)
        for p in m.parameters(recurse=False)
    }
    decay, no_decay = [], []
    for name, p in module.named_parameters():
        if "bias" in name or id(p) in layer_norm_param_ids:
            no_decay.append(p)
        else:
            decay.append(p)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


opt_grouped_parameters = _build_param_groups(model)

optimizer = torch.optim.AdamW(
    opt_grouped_parameters, lr=experiment_config.trainer_config.max_learning_rate
)

# One scheduler step per training step; the LR peaks at max_learning_rate.
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=experiment_config.trainer_config.max_learning_rate,
    total_steps=max_iters
)

## Training Loop

### Model Output Visualization

In [None]:
class TrainingVisualizer:
    """Interactive widget for browsing reconstruction plots captured during training.

    Two dropdowns select a data split ("train"/"val") and an evaluation
    iteration; the stored figure and loss for that pair are shown below them.
    ``iteration_nums``, ``plots[split]`` and ``losses[split]`` are filled in by
    the caller, one entry per evaluation.
    """

    def __init__(self):
        self.plots = {'train': [], 'val': []}
        self.iteration_nums = []
        self.losses = {'train': [], 'val': []}

        # Split selector.
        self.split_dropdown = widgets.Dropdown(
            options=['train', 'val'],
            description='Split:',
            value='train'
        )

        # Iteration selector (a dropdown rather than a slider); its options are
        # filled in as evaluations happen.
        self.iter_select = widgets.Dropdown(
            options=[],
            description='Iteration:',
            value=None
        )

        self.output = widgets.Output()

        # Controls stacked above the plot area.
        controls = widgets.VBox([self.split_dropdown, self.iter_select])
        self.widget_layout = widgets.VBox([controls, self.output])
        display(self.widget_layout)

        # Re-render whenever either dropdown changes.
        self.interactive = widgets.interactive_output(
            self.update_plot,
            {'split': self.split_dropdown, 'iteration': self.iter_select}
        )

    def _plot_index(self, split, iteration):
        """Return the storage index for (split, iteration), or None if nothing is stored yet."""
        if iteration is None or iteration not in self.iteration_nums:
            return None
        idx = self.iteration_nums.index(iteration)
        # The iteration list may be updated before that evaluation has stored
        # its figure and loss; treat that as "not available yet" instead of
        # raising IndexError.
        if idx >= len(self.plots[split]) or idx >= len(self.losses[split]):
            return None
        return idx

    def update_plot(self, split, iteration):
        """Show the stored figure and loss for ``split`` at ``iteration``, if available."""
        idx = self._plot_index(split, iteration)
        if idx is None:
            return

        with self.output:
            clear_output(wait=True)
            display(self.plots[split][idx])
            plt.close()
            print(f"Loss at iteration {iteration}: {self.losses[split][idx]:.4f}")

    def force_update(self):
        """Force the widget to update with current values"""
        self.update_plot(self.split_dropdown.value, self.iter_select.value)
        time.sleep(0.1)  # Small delay to allow display to update

# Create a global visualizer instance. Its constructor displays the widgets
# immediately; estimate_loss() appends figures and losses to it.
visualizer = TrainingVisualizer()

### Train Loss Visualization

In [None]:
class LossPlotVisualizer:
    """Live plot of the window-averaged training loss and the learning rate."""

    def __init__(self, window_size=20):
        # Number of consecutive loss values averaged into one plotted point.
        self.window_size = window_size
        self.output = widgets.Output()
        display(self.output)

    @staticmethod
    def _windowed_mean(values, window_size):
        """Average ``values`` over consecutive, non-overlapping windows of ``window_size``.

        Trailing values that do not fill a whole window are dropped.

        Returns:
            ``(steps, means)``: ``steps[i]`` is the index of the first value in
            window ``i``, and ``means`` is a 1-D float tensor of window averages.
        """
        num_windows = len(values) // window_size
        # Slice with an explicit end index. ``values[:-remainder]`` would be
        # empty whenever the remainder is 0, which would blank the plot.
        trimmed = values[:num_windows * window_size]
        means = torch.tensor(trimmed, dtype=torch.float32).view(num_windows, window_size).mean(1)
        steps = [i * window_size for i in range(num_windows)]
        return steps, means

    def update_plot(self, loss_values, lrs):
        """Redraw both panels from the full per-step loss and learning-rate histories."""
        steps, averaged_loss = self._windowed_mean(loss_values, self.window_size)
        with self.output:
            clear_output(wait=True)
            fig, (ax1, ax2) = plt.subplots(nrows=2, sharex=True)
            fig.set_figwidth(10)
            fig.set_figheight(8)
            ax1.plot(steps, averaged_loss)
            ax1.set_title('Training Loss')
            ax1.set_xlabel('Steps (averaged over {} iterations)'.format(self.window_size))
            ax1.set_ylabel('Loss')
            ax1.grid(True)
            ax2.plot(lrs)
            ax2.set_title('Learning Rate')
            ax2.set_xlabel('Steps')
            ax2.set_ylabel('Learning Rate')
            ax2.grid(True)
            plt.show()
            plt.close(fig)

# Create the loss visualizer. Its output area is displayed immediately and is
# redrawn by the training loop at each evaluation.
loss_visualizer = LossPlotVisualizer()

### Util Functions

In [None]:
@torch.no_grad()
def estimate_loss():
    """Estimate the mean loss on the train and val splits and record reconstructions.

    For each split, ``eval_iters`` batches are run through the model in eval
    mode and their losses averaged. The first batch of each split is also
    rendered with ``plot_multi_band_reconstruction``; that figure and the
    split's mean loss are appended to the global ``visualizer``. The model is
    put back into training mode afterwards, even if evaluation raises.

    Returns:
        dict mapping "train" and "val" to a scalar tensor with the mean loss.
    """
    out = {}
    mask_ratio = experiment_config.video_mae_task_config.encoder_mask_ratio
    plot_frame_patch_size = experiment_config.video_mae_task_config.vit_config.frame_patch_size
    model.eval()
    try:
        for split in ["train", "val"]:
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                # Move to the model's device, as the training loop does
                # (a no-op if the batch is already there).
                batch = get_batch(split).to(device)
                # Padding mask is derived from the raw batch, before NaNs are zeroed.
                padding_mask = get_padding_mask(batch, device)
                model.initialize_mask(padding_mask)
                batch = torch.nan_to_num(batch)
                loss, pred, mask, _latent, _correlation = model(batch, mask_ratio=mask_ratio)
                if k == 0:  # Only store the first reconstruction of each evaluation
                    norm_batch = model.forward_input_norm(batch)
                    pred_signal_np = model.unpatchify(pred).detach().cpu().numpy()
                    masked_signal = apply_mask_to_batch(model, norm_batch, mask)
                    fig = plot_multi_band_reconstruction(
                        norm_batch.detach().cpu().numpy(),
                        pred_signal_np,
                        plot_frame_patch_size,
                        seen_signal=masked_signal.detach().cpu().numpy(),
                    )
                    visualizer.plots[split].append(fig)
                losses[k] = loss.item()
            out[split] = losses.mean()
            visualizer.losses[split].append(out[split])
    finally:
        model.train()
    return out

# Persistent per-split (dataloader, iterator) pairs. Successive get_batch calls
# walk through the data instead of building a new iterator and returning its
# first batch every time.
_batch_iterators = {}


def get_batch(split):
    """Return the next batch for ``split`` ("train" reads train_dl, anything else test_dl).

    Iterators persist across calls and are restarted once exhausted, so
    repeated calls cycle through the whole dataset. A cached iterator is
    discarded if the global dataloader for that split has been replaced
    (e.g. after re-running the dataloader setup cell).
    """
    import builtins

    dataloader = train_dl if split == "train" else test_dl
    cached = _batch_iterators.get(split)
    if cached is None or cached[0] is not dataloader:
        # builtins.iter: a notebook cell may shadow `iter` in the global
        # namespace (e.g. with a loop variable named `iter`).
        cached = (dataloader, builtins.iter(dataloader))
        _batch_iterators[split] = cached
    try:
        return next(cached[1])
    except StopIteration:
        # End of a pass over the data: start a new one.
        fresh = builtins.iter(dataloader)
        _batch_iterators[split] = (dataloader, fresh)
        return next(fresh)

def apply_mask_to_batch(model, batch, mask):
    """Return ``batch`` with NaN written into every patch selected by ``mask``.

    ``batch`` is patchified with ``model.patchify``. Each mask column is spread
    over its group of patch tokens and over every value inside those tokens.
    Positions where the mask is nonzero/True are set to NaN, and the result is
    mapped back with ``model.unpatchify``. With an MAE-style mask (1 = removed),
    this keeps only the patches the encoder saw — TODO confirm the mask
    convention of the model.

    Args:
        model: object providing ``patchify``/``unpatchify`` (the MAE model here).
        batch: signal tensor, presumably (batch_size, num_channels, frames,
            electrode_height, electrode_width).
        mask: (batch_size, t) tensor; t must evenly divide the number of patch
            tokens, otherwise the reshape below fails.
    """
    patches = model.patchify(batch)
    num_items, num_mask_cols = mask.shape
    _, num_tokens, patch_dim = patches.shape
    tokens_per_mask_col = num_tokens // num_mask_cols
    # Broadcast each mask value over the tokens it covers and over every
    # element inside those tokens, giving a mask shaped like ``patches``.
    expanded = mask.repeat_interleave(tokens_per_mask_col * patch_dim, dim=1)
    expanded = expanded.view(num_items, num_tokens, patch_dim).to(torch.bool)
    return model.unpatchify(patches.masked_fill(expanded, torch.nan))

### Training

In [None]:
model = model.to(device)
print(sum(p.numel() for p in model.parameters()) / 1e6, "M parameters")
# Per-step training loss and learning rate, fed to the live loss plot.
loss_i = []
lr_i = []

model.train()
# `step` rather than `iter`, so the builtin iter() is not shadowed in the
# notebook's global namespace.
for step in range(max_iters):
    optimizer.zero_grad()

    # Periodic evaluation (also on the final step), before this step's update.
    if step % eval_interval == 0 or step == max_iters - 1:
        # Evaluate first so this step's figure and loss are stored before the
        # iteration dropdown selects it (selecting it triggers a redraw).
        losses = estimate_loss()
        visualizer.iteration_nums.append(step)
        visualizer.iter_select.options = visualizer.iteration_nums
        visualizer.iter_select.value = step
        print(
            f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
        )
        visualizer.force_update()
        loss_visualizer.update_plot(loss_i, lr_i)

    batch = get_batch("train")
    signal = batch.to(device)
    # Padding mask is computed from the raw signal, before NaNs are zeroed.
    padding_mask = get_padding_mask(signal, device)
    signal = torch.nan_to_num(signal)
    model.initialize_mask(padding_mask)

    loss, pred, mask, latent, correlation = model(
        signal, mask_ratio=experiment_config.video_mae_task_config.encoder_mask_ratio
    )
    loss.backward()
    optimizer.step()
    lr_scheduler.step()

    loss_i.append(loss.item())
    # get_last_lr() is the supported accessor after step() (get_lr() warns).
    # Both parameter groups follow the same single-max_lr schedule, so record
    # the first.
    lr_i.append(lr_scheduler.get_last_lr()[0])

## Encoding

In [None]:
from downstream_tasks.encoding_decoding.config import EncodingDecodingExperimentConfig, EncodingDecodingTaskConfig, EncodingDecodingDataConfig
from downstream_tasks.encoding_decoding.utils import run_encoding_task, run_decoding_task

In [None]:
# Settings for the downstream encoding task: where the word embeddings and the
# preprocessed neural data live, and how the task is run.
encoding_experiment_config = EncodingDecodingExperimentConfig(
    encoding_data_config=EncodingDecodingDataConfig(
        conversation_data_df_path=os.path.join(path_to_github_repo, "word-embeddings/gpt2-layer-8-emb.pkl"),
        encoding_neural_data_folder=os.path.join(path_to_github_repo, "preprocessed-highgamma"),
        electrode_glob_path="NY*_*_Part*_conversation*_electrode_preprocess_file_{elec_id}.mat",
        lag=0,
    ),
    encoding_task_config=EncodingDecodingTaskConfig(
        model_path="",  # Unused here: the in-memory model is passed to the task directly.
        # Use the training device rather than a hard-coded "cuda", so this
        # also runs on CPU-only runtimes.
        embedding_device=device,
        embedding_batch_size=8,
        num_folds=2,
    ),
)

In [None]:
# Run the encoding task on the trained model, using the same ECoG data config as
# pretraining. NOTE(review): mspe is presumably mean squared prediction error —
# confirm in downstream_tasks.encoding_decoding.utils.
pearson_correlations, mspe = run_encoding_task(encoding_experiment_config, experiment_config.ecog_data_config, model)

In [None]:
# Display the Pearson correlations returned by the encoding task.
pearson_correlations

In [None]:
# Display the mspe value returned by the encoding task.
mspe

## Decoding

In [None]:
# Decoding counterpart of the encoding cell: this section must call the imported
# run_decoding_task, not run_encoding_task again.
# NOTE(review): assumes run_decoding_task takes the same arguments and returns
# the same (pearson_correlations, mspe) pair as run_encoding_task — confirm.
pearson_correlations, mspe = run_decoding_task(encoding_experiment_config, experiment_config.ecog_data_config, model)

In [None]:
# Display the Pearson correlations returned by the cell above.
pearson_correlations

In [None]:
# Display the mspe value returned by the cell above.
mspe