In [None]:
%%writefile /kaggle/working/nejm-brain-to-text/model_training/train_baseline_model.py
from omegaconf import OmegaConf
from rnn_trainer import BrainToTextDecoder_Trainer
import os
import numpy as np
import torch 
import shutil

# Prefer the local training config; when it is absent, fall back to the
# args.yaml shipped with the pretrained baseline checkpoint as a template.
_local_args_path = 'rnn_args.yaml'
_pretrained_args_path = "/kaggle/input/brain-to-text-25/t15_pretrained_rnn_baseline/t15_pretrained_rnn_baseline/checkpoint/args.yaml"

if os.path.exists(_local_args_path):
    args_path = _local_args_path
else:
    print(f"Warning: '{_local_args_path}' not found. Using pretrained model's args as template.")
    args_path = _pretrained_args_path

args = OmegaConf.load(args_path)

# Force the script to use the correct Kaggle data paths
args.dataset.dataset_dir = "/kaggle/working/brain-to-text-25-minimal-shirley/t15_copyTask_neuralData/hdf5_data_final"

# Train on all 4 sessions (data_train.hdf5); each session's data_test.hdf5 is
# held out for the separate evaluation step.
args.dataset.sessions = [
    't15.2023.08.11',
    't15.2023.08.13',
    't15.2023.08.18',
    't15.2023.08.20'
]

# Keep n_days in sync with the number of sessions
args.model.n_days = len(args.dataset.sessions)

# Per-session validation sampling probability, aligned index-for-index with
# args.dataset.sessions. Sessions in this set have no val data, so they must get
# probability 0. Deriving the list keeps it correct if the session list changes.
sessions_without_val = {'t15.2023.08.11'}
if 'dataset_probability_val' in args.dataset:
    args.dataset.dataset_probability_val = [
        0.0 if session in sessions_without_val else 1.0
        for session in args.dataset.sessions
    ]

# Disable logging of individual day validation PER to avoid division by zero
args.log_individual_day_val_PER = False

args.num_training_batches = 10000
print(f"\nTraining for {args.num_training_batches} batches.")

# days_per_batch / batch_size live in the dataset section (the config summary
# below reads args.dataset.batch_size). Setting them on the top level would only
# add unrelated keys. NOTE(review): confirm against rnn_trainer's config reads.
args.dataset.days_per_batch = 2
print(f"\nTraining for {args.dataset.days_per_batch} days per batches.")
args.dataset.batch_size = 16
print(f"\nTraining for batch size = {args.dataset.batch_size}.")


# Destinations for this run's outputs and checkpoints (relative to the cwd
# the script is launched from).
new_output_dir = "trained_models/baseline_rnn/progressive_training"
new_checkpoint_dir = "trained_models/baseline_rnn/checkpoint/progressive_training"
args.output_dir, args.checkpoint_dir = new_output_dir, new_checkpoint_dir

# Start from a clean slate: delete anything left over from a previous run.
for stale_dir, dir_kind in ((args.output_dir, "output"), (args.checkpoint_dir, "checkpoint")):
    if os.path.exists(stale_dir):
        print(f"Removing existing {dir_kind} directory: {stale_dir}")
        shutil.rmtree(stale_dir)

print("\nConfiguration:")
print(f"  Training sessions: {args.dataset.sessions}")
print(f"  Training: All {len(args.dataset.sessions)} sessions using data_train.hdf5")
print("  Validation: Sessions 2-4 using data_val.hdf5 (Session 1 has no val data)")
print("  Reserved: data_test.hdf5 from all sessions (for future testing)")
print("  Optimizer: AdamW (default)")
print(f"  Learning rate: {args.lr_max}")
print(f"  Batch size: {args.dataset.batch_size}")
print(f"  Output dir: {args.output_dir}")

# Report which split files exist for every session. Only a missing training
# file is flagged, since val/test files are legitimately absent for some sessions.
print("\nVerifying data files:")
for session in args.dataset.sessions:
    session_dir = os.path.join(args.dataset.dataset_dir, session)
    split_found = {
        split: os.path.exists(os.path.join(session_dir, f'data_{split}.hdf5'))
        for split in ('train', 'val', 'test')
    }
    status = ", ".join(
        f"{split}={'found' if found else 'missing'}" for split, found in split_found.items()
    )
    print(f"  {session}: {status}")
    if not split_found['train']:
        print(f"    WARNING: Missing training data for {session}!")

# Build the trainer from the fully patched config and run training.
print("\nInitializing trainer...")
trainer = BrainToTextDecoder_Trainer(args)

print("Starting base model training...")
train_stats = trainer.train()

# Summarize with the lowest validation PER recorded during training;
# 1.0 is reported when no validation PERs were recorded.
val_per_list = train_stats.get('val_PERs', [])
if val_per_list:
    val_score = np.min(val_per_list)
else:
    val_score = 1.0

print("Base Model Training Finished")
print(f"Final Best (min) Validation PER: {val_score:.4f}")
print(f"Model checkpoint saved in: {args.checkpoint_dir}")

In [None]:
# Run the training script written by the previous cell from its own directory,
# so `from rnn_trainer import ...`, 'rnn_args.yaml' and the relative
# trained_models/ output paths all resolve against model_training/.
!cd /kaggle/working/nejm-brain-to-text/model_training/ && \
python train_baseline_model.py

In [None]:
import os

# Evaluate the trained checkpoint on the held-out split of the sessions below.
# The script is run from model_training/ so its relative output/ directory
# lands where the analysis cell looks for it.
import subprocess

eval_script = "/kaggle/working/nejm-brain-to-text/model_training/evaluate_sessions.py"
model_path = "/kaggle/working/nejm-brain-to-text/model_training/trained_models/baseline_rnn/checkpoint/progressive_training"
data_dir = "/kaggle/working/brain-to-text-25-minimal-shirley/t15_copyTask_neuralData/hdf5_data_final"
eval_type = "test"
gpu_number = 0

target_sessions = ["t15.2023.08.13", "t15.2023.08.18", "t15.2023.08.20"]

# Pass arguments as a list: no shell parsing, so paths need no quoting.
cmd = [
    "python", eval_script,
    "--model_path", model_path,
    "--data_dir", data_dir,
    "--eval_type", eval_type,
    "--gpu_number", str(gpu_number),
    "--sessions", *target_sessions,
]

print(f"Model: {model_path}")
print(f"Eval type: {eval_type}")
print("\nEvaluating on:")
for session in target_sessions:
    print(f"Evaluating {session}: data_{eval_type}.hdf5")

result = subprocess.run(cmd, cwd="/kaggle/working/nejm-brain-to-text/model_training/")
# Fail loudly so "Run All" stops here instead of analyzing a stale predictions CSV.
if result.returncode != 0:
    raise RuntimeError(f"Evaluation failed with exit code {result.returncode}")
print("Evaluation finished!")

In [None]:
import pandas as pd
def analyze_predictions(output_file):
    """Print and return summary metrics for a phoneme-predictions CSV.

    The CSV must provide the columns ``trial_acc``, ``true_phoneme`` (the
    reference phonemes joined with '-'), ``trial_ed`` (per-trial edit distance)
    and ``trial_ctc_loss``.

    Args:
        output_file: Path to the predictions CSV.

    Returns:
        dict with ``overall_avg_acc``, ``aggregate_per`` and ``avg_loss``.
        ``aggregate_per`` is NaN when there are no reference phonemes
        (e.g. an empty CSV) instead of raising ZeroDivisionError.
    """
    df = pd.read_csv(output_file)

    overall_avg = df['trial_acc'].mean()
    print(f"Overall average trial accuracy: {overall_avg:.4f}")

    # Aggregate PER = total edit distance / total reference length, so longer
    # trials weigh proportionally more than in a mean of per-trial PERs.
    # Only the reference length is needed; the edit distance is precomputed.
    true_lengths = df['true_phoneme'].str.split('-').str.len()
    total_edit_distance = df['trial_ed'].sum()
    total_phoneme_length = true_lengths.sum()
    if total_phoneme_length:
        aggregate_per = total_edit_distance / total_phoneme_length
    else:
        aggregate_per = float('nan')
    print(f"Aggregate Phoneme Error Rate (PER): {aggregate_per:.4f}")

    avg_loss = df['trial_ctc_loss'].mean()
    print(f"Average Validation Loss: {avg_loss:.4f}")

    return {
        'overall_avg_acc': float(overall_avg),
        'aggregate_per': float(aggregate_per),
        'avg_loss': float(avg_loss),
    }

# Analyze the most recent predictions CSV produced by the evaluation cell.
# Each evaluation run writes a new file named phoneme_predictions_YYYYMMDD_HHMMSS.csv,
# so a hard-coded name goes stale; the timestamp format sorts chronologically.
from pathlib import Path

predictions_dir = Path("/kaggle/working/nejm-brain-to-text/model_training/output")
prediction_files = sorted(predictions_dir.glob("phoneme_predictions_*.csv"))
if not prediction_files:
    raise FileNotFoundError(f"No phoneme_predictions_*.csv found in {predictions_dir}")
output_file = str(prediction_files[-1])
print(f"Analyzing: {output_file}")

df = pd.read_csv(output_file)
# A bare df.head() mid-cell is discarded; display() actually renders it.
display(df.head())
prediction_metrics = analyze_predictions(output_file)