In [None]:
%%writefile /kaggle/working/nejm-brain-to-text/model_training/train_baseline_model.py
from omegaconf import OmegaConf
from rnn_trainer import BrainToTextDecoder_Trainer
import os
import numpy as np
import torch 
import shutil

# Training config: prefer a local rnn_args.yaml, otherwise fall back to the
# args.yaml shipped with the pretrained baseline checkpoint as a template.
_LOCAL_ARGS_PATH = 'rnn_args.yaml'
_PRETRAINED_ARGS_PATH = "/kaggle/input/brain-to-text-25/t15_pretrained_rnn_baseline/t15_pretrained_rnn_baseline/checkpoint/args.yaml"

if os.path.exists(_LOCAL_ARGS_PATH):
    args_path = _LOCAL_ARGS_PATH
else:
    print(f"Warning: '{_LOCAL_ARGS_PATH}' not found. Using pretrained model's args as template.")
    args_path = _PRETRAINED_ARGS_PATH

args = OmegaConf.load(args_path)

# Force the script to use the correct Kaggle data paths
args.dataset.dataset_dir = "/kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final" 
args.dataset.csv_path = "/kaggle/working/nejm-brain-to-text/data/t15_copyTaskData_description.csv"

# Sessions used by this run. Every session contributes TRAINING data (train
# sampling probability 1.0). Sessions in NO_VAL_SESSIONS get a val sampling
# probability of 0, i.e. they are left out of in-training validation:
#   - t15.2023.08.11 has no val data.
#   - t15.2023.08.20's val split is held out so it can be scored afterwards by
#     the separate evaluation step of this notebook.
args.dataset.sessions = [
    't15.2023.08.11',
    't15.2023.08.13',
    't15.2023.08.18',
    't15.2023.08.20'
]
NO_VAL_SESSIONS = {'t15.2023.08.11', 't15.2023.08.20'}

# Update n_days to match number of sessions
args.model.n_days = len(args.dataset.sessions)

# Per-session sampling probabilities. They are derived from the session list
# (instead of hand-written literals) so their length always matches
# args.dataset.sessions / args.model.n_days when the list is edited.
if hasattr(args.dataset, 'dataset_probability_train'):
    args.dataset.dataset_probability_train = [1.0] * len(args.dataset.sessions)
if hasattr(args.dataset, 'dataset_probability_val'):
    args.dataset.dataset_probability_val = [
        0.0 if session in NO_VAL_SESSIONS else 1.0
        for session in args.dataset.sessions
    ]

# Disable logging of individual day validation PER to avoid division by zero
args.log_individual_day_val_PER = False

args.num_training_batches = 10000
print(f"\nTraining for {args.num_training_batches} batches.")
args.days_per_batch = 2
print(f"\nTraining for {args.days_per_batch} days per batches.")
args.batch_size = 16
print(f"\nTraining for batch size = {args.batch_size}.")


# Point this run at dedicated output/checkpoint directories and start from a
# clean slate: any directory left over from a previous run is deleted.
new_output_dir = "trained_models/baseline_rnn/progressive_training"
new_checkpoint_dir = "trained_models/baseline_rnn/checkpoint/progressive_training"
args.output_dir, args.checkpoint_dir = new_output_dir, new_checkpoint_dir

for dir_label, stale_dir in (("output", args.output_dir), ("checkpoint", args.checkpoint_dir)):
    if not os.path.exists(stale_dir):
        continue
    print(f"Removing existing {dir_label} directory: {stale_dir}")
    shutil.rmtree(stale_dir)

# Verify data files exist. Report, per session, which of the train/val/test
# HDF5 files are present so a missing file can be traced to its session before
# the (long) training run starts. Missing training data only warns; it does not abort.
print("\nVerifying data files:")
for session in args.dataset.sessions:
    session_dir = os.path.join(args.dataset.dataset_dir, session)
    split_present = {
        split: os.path.exists(os.path.join(session_dir, f'data_{split}.hdf5'))
        for split in ('train', 'val', 'test')
    }
    status = ", ".join(
        f"{split}={'yes' if present else 'no'}" for split, present in split_present.items()
    )
    print(f"  {session}: {status}")

    if not split_present['train']:
        print(f"    WARNING: Missing training data for session {session}!")

# CREATE TRAINER
print("\nInitializing trainer...")
trainer = BrainToTextDecoder_Trainer(args) 

train_stats = trainer.train() 

# Validation PERs recorded by the trainer, if any. len() is used instead of
# truthiness so this also works if the trainer returns a NumPy array.
val_per_list = train_stats.get('val_PERs')
if val_per_list is None:
    val_per_list = []
has_val_pers = len(val_per_list) > 0
# 1.0 stays as the sentinel score when nothing was recorded, but it is not
# reported as if it were a measured PER (see the summary below).
val_score = np.min(val_per_list) if has_val_pers else 1.0

print(f"\n" + "="*60)
print(f"--- Base Model Training Finished ---")
print("="*60)
if has_val_pers:
    print(f"Final Best (min) Validation PER: {val_score:.4f}")
else:
    print("Final Best (min) Validation PER: n/a (no validation PERs were recorded)")
print(f"Model checkpoint saved in: {args.checkpoint_dir}")

In [None]:
# Run the training script as a separate process from the model_training
# directory; its relative paths (rnn_args.yaml, trained_models/...) resolve there.
!cd /kaggle/working/nejm-brain-to-text/model_training/ && \
python train_baseline_model.py

In [None]:
%%writefile /kaggle/working/nejm-brain-to-text/model_training/evaluate_without_LLM.py

import os
import torch
import numpy as np
import pandas as pd
import redis
from omegaconf import OmegaConf
import time
from tqdm import tqdm
import editdistance
import argparse

from rnn_model import GRUDecoder
from evaluate_model_helpers import *

def _build_arg_parser():
    """Return the command-line parser for this evaluation script."""
    cli = argparse.ArgumentParser(description='Evaluate a pretrained RNN model on the copy task dataset.')
    cli.add_argument('--model_path', type=str, default='../data/t15_pretrained_rnn_baseline',
                     help='Path to the pretrained model directory (relative to the current working directory).')
    cli.add_argument('--data_dir', type=str, default='../data/hdf5_data_final',
                     help='Path to the dataset directory (relative to the current working directory).')
    cli.add_argument('--eval_type', type=str, default='test', choices=['val', 'test'],
                     help='Evaluation type: "val" for validation set, "test" for test set. '
                          'If "test", ground truth is not available.')
    cli.add_argument('--csv_path', type=str, default='../data/t15_copyTaskData_description.csv',
                     help='Path to the CSV file with metadata about the dataset (relative to the current working directory).')
    cli.add_argument('--gpu_number', type=int, default=1,
                     help='GPU number to use for RNN model inference. Set to -1 to use CPU.')
    cli.add_argument('--session', type=str, default=None,
                     help='Specify a single session to evaluate (e.g., "t15.2023.08.20").')
    return cli


parser = _build_arg_parser()
args = parser.parse_args()
model_path, data_dir, eval_type = args.model_path, args.data_dir, args.eval_type
b2txt_csv_df = pd.read_csv(args.csv_path)

# Model hyper-parameters: <model_path>/args.yaml if present, otherwise the
# checkpoint/ sub-directory layout is assumed.
_top_level_args_path = os.path.join(model_path, 'args.yaml')
if os.path.exists(_top_level_args_path):
    model_args_path = _top_level_args_path
else:
    model_args_path = os.path.join(model_path, 'checkpoint/args.yaml')
print(f"Loading model args from: {model_args_path}")
model_args = OmegaConf.load(model_args_path)

# Inference device: the requested GPU when CUDA is usable and gpu_number >= 0,
# otherwise the CPU. An out-of-range GPU index is an error, not a fallback.
gpu_number = args.gpu_number
if not (torch.cuda.is_available() and gpu_number >= 0):
    if gpu_number >= 0:
        print(f'GPU number {gpu_number} requested but not available.')
    print('Using CPU for model inference.')
    device = torch.device('cpu')
else:
    available_gpus = torch.cuda.device_count()
    if gpu_number >= available_gpus:
        raise ValueError(f'GPU number {gpu_number} is out of range. Available GPUs: {available_gpus}')
    device = torch.device(f'cuda:{gpu_number}')
    print(f'Using {device} for model inference.')

# Rebuild the decoder with the hyper-parameters stored in the loaded args.yaml.
# n_days equals the number of configured sessions; a session's index in that
# list is later passed to runSingleDecodingStep as its input-layer id.
model = GRUDecoder(
    neural_dim = model_args['model']['n_input_features'],
    n_units = model_args['model']['n_units'], 
    n_days = len(model_args['dataset']['sessions']),
    n_classes = model_args['dataset']['n_classes'],
    rnn_dropout = model_args['model']['rnn_dropout'],
    input_dropout = model_args['model']['input_network']['input_layer_dropout'],
    n_layers = model_args['model']['n_layers'],
    patch_size = model_args['model']['patch_size'],
    patch_stride = model_args['model']['patch_stride'],
)

# Same lookup order as for args.yaml: <model_path>/best_checkpoint, else
# <model_path>/checkpoint/best_checkpoint.
checkpoint_path = os.path.join(model_path, 'best_checkpoint')
if not os.path.exists(checkpoint_path):
    checkpoint_path = os.path.join(model_path, 'checkpoint/best_checkpoint')
print(f"Loading model checkpoint from: {checkpoint_path}")
# NOTE(review): weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
checkpoint = torch.load(checkpoint_path, weights_only=False, map_location=device)

# Strip wrapper prefixes from parameter names so they match the bare model:
# "module." (DataParallel / DistributedDataParallel) and "_orig_mod." (torch.compile).
# Each key is renamed at most once with both prefixes removed. The previous loop
# popped the original key twice, which raised KeyError for any "module."-prefixed key.
# The dict object is mutated in place so any metadata attached to it is kept.
state_dict = checkpoint['model_state_dict']
for old_key in list(state_dict.keys()):
    new_key = old_key.replace("module.", "").replace("_orig_mod.", "")
    if new_key != old_key:
        state_dict[new_key] = state_dict.pop(old_key)
model.load_state_dict(state_dict)  
model.to(device) 
model.eval()

# Only used for eval_type == 'val' to report a per-trial CTC loss; class 0 is the blank.
ctc_loss = torch.nn.CTCLoss(blank=0, reduction='mean', zero_infinity=False)

# Sessions to evaluate: every session in the model config, or just --session
# when given (it must be one of the configured sessions).
configured_sessions = model_args['dataset']['sessions']
sessions_to_evaluate = configured_sessions
if args.session:
    if args.session not in configured_sessions:
        raise ValueError(f"Session '{args.session}' not found in model config list.")
    sessions_to_evaluate = [args.session]
    print(f"--- Evaluating ONLY specified session: {args.session} ---")

# Load data_<eval_type>.hdf5 for each selected session. Sessions whose folder is
# missing are reported and skipped; sessions without that split are skipped silently.
test_data = {}
total_test_trials = 0
split_file_name = f'data_{eval_type}.hdf5'
for session in sessions_to_evaluate:
    session_dir = os.path.join(data_dir, session)
    if not os.path.exists(session_dir):
        print(f"Session folder not found: {session_dir}, skipping.")
        continue

    hdf5_names = {name for name in os.listdir(session_dir) if name.endswith('.hdf5')}
    if split_file_name not in hdf5_names:
        continue

    session_data = load_h5py_file(os.path.join(session_dir, split_file_name), b2txt_csv_df)
    test_data[session] = session_data
    n_trials = len(session_data["neural_features"])
    total_test_trials += n_trials
    print(f'Loaded {n_trials} {eval_type} trials for session {session}.')

print(f'Total number of {eval_type} trials: {total_test_trials}')
print()

# Run the RNN on every loaded trial. Raw logits are kept for decoding later; on
# the val split the per-trial CTC loss against the reference labels is also kept.
# Inputs are bfloat16 on GPU and float32 on CPU (decided once, outside the loop).
input_dtype = torch.float32 if device.type == 'cpu' else torch.bfloat16

with tqdm(total=total_test_trials, desc='Predicting phoneme sequences', unit='trial') as pbar:
    for session, data in test_data.items():
        data['logits'] = []
        data['pred_seq'] = []
        data['losses'] = []
        day_index = model_args['dataset']['sessions'].index(session)

        for trial, trial_features in enumerate(data['neural_features']):
            # Add a leading batch axis of size 1.
            batched = np.expand_dims(trial_features, axis=0)
            neural_input_tensor = torch.tensor(batched, device=device, dtype=input_dtype)

            with torch.no_grad():
                # runSingleDecodingStep hands back a numpy array of logits
                trial_logits = runSingleDecodingStep(neural_input_tensor, day_index, model, model_args, device)
            data['logits'].append(trial_logits)

            if eval_type == 'val':
                label_len = data['seq_len'][trial]
                logits_tensor = torch.tensor(trial_logits, device=device)
                # CTCLoss wants (time, batch, classes) log-probabilities.
                log_probs = logits_tensor.log_softmax(2).permute(1, 0, 2)
                targets = torch.tensor(data['seq_class_ids'][trial][0:label_len], device=device, dtype=torch.long)
                target_lengths = torch.tensor([label_len], device=device, dtype=torch.long)
                input_lengths = torch.tensor([logits_tensor.shape[1]], device=device, dtype=torch.long)
                data['losses'].append(ctc_loss(log_probs, targets, input_lengths, target_lengths).item())

            pbar.update(1)

def ctc_greedy_decode(frame_logits):
    """Best-path (greedy) CTC decoding of a single trial's logits.

    Takes the argmax class at every time step, collapses runs of identical
    classes, then drops the blank class 0 (the same blank index used by the
    CTC loss in this script). Collapsing must come BEFORE blank removal: a
    phoneme repeated across a blank (A, blank, A) is two tokens, whereas
    removing blanks first would wrongly merge it into one.

    Args:
        frame_logits: per-time-step class scores; assumed (time, n_classes).
    Returns:
        list[int]: decoded class ids, blanks removed.
    """
    best_path = np.argmax(frame_logits, axis=-1)
    decoded = []
    previous = None
    for class_id in best_path:
        class_id = int(class_id)
        # Emit only at the start of a new run, and never for the blank.
        if class_id != previous and class_id != 0:
            decoded.append(class_id)
        previous = class_id
    return decoded


for session, data in test_data.items():
    data['pred_seq'] = []
    for trial in range(len(data['logits'])):
        # Stored logits carry a leading batch axis (assumed size 1); [0] takes that single item.
        logits = data['logits'][trial][0]
        pred_seq = [LOGIT_TO_PHONEME[p] for p in ctc_greedy_decode(logits)]
        data['pred_seq'].append(pred_seq)
        
        if eval_type == 'val':
            block_num = data['block_num'][trial]
            trial_num = data['trial_num'][trial]
            print(f'Session: {session}, Block: {block_num}, Trial: {trial_num}')
            sentence_label = data['sentence_label'][trial]
            true_seq = data['seq_class_ids'][trial][0:data['seq_len'][trial]]
            true_seq = [LOGIT_TO_PHONEME[p] for p in true_seq]
            print(f'Sentence label:      {sentence_label}')
            print(f'True sequence:       {" ".join(true_seq)}')
            print(f'Predicted Sequence:  {" ".join(pred_seq)}')
            print()

# Write one CSV row per evaluated trial: the predicted phonemes and, on the val
# split, the reference phonemes plus a per-trial accuracy (1 - edit distance / length).
# Running totals for the aggregate PER are accumulated along the way.
output_dir = "output"
os.makedirs(output_dir, exist_ok=True)
run_stamp = time.strftime("%Y%m%d_%H%M%S")
output_file = os.path.join(output_dir, f'phoneme_predictions_{run_stamp}.csv')

ids = []
all_pred_phonemes = []
all_true_phonemes = []
trial_accuracy = []
all_losses = []

total_edit_distance = 0
total_phoneme_length = 0

has_labels = eval_type == 'val'
for session, data in test_data.items():
    if has_labels:
        all_losses.extend(data['losses'])

    for trial_idx, pred_seq in enumerate(data['pred_seq']):
        true_phonemes = None
        acc = None
        if has_labels:
            label_len = data['seq_len'][trial_idx]
            reference = [LOGIT_TO_PHONEME[p] for p in data['seq_class_ids'][trial_idx][0:label_len]]
            ed = editdistance.eval(pred_seq, reference)
            total_edit_distance += ed
            total_phoneme_length += len(reference)
            true_phonemes = ' '.join(reference)
            acc = 0 if not reference else 1 - ed / len(reference)

        # Trial ids are a running counter across all sessions.
        ids.append(len(ids))
        all_pred_phonemes.append(' '.join(pred_seq))
        all_true_phonemes.append(true_phonemes)
        trial_accuracy.append(acc)

df_out = pd.DataFrame({
    'id': ids,
    'pred_phonemes': all_pred_phonemes,
    'true_phonemes': all_true_phonemes,
    'trial_accuracy': trial_accuracy
})
df_out.to_csv(output_file, index=False)

print(f"Saved phoneme predictions, true sequences, and trial-level accuracy to {output_file}")

# Val-split summary: mean per-trial accuracy, aggregate PER (total edits /
# total reference phonemes) and the CTC loss totals. Nothing is printed for 'test'.
if eval_type == 'val':
    if not trial_accuracy:
        print("No trials found to calculate accuracy.")
    else:
        scored_trials = [a for a in trial_accuracy if a is not None]
        avg_accuracy = sum(scored_trials) / len(scored_trials)
        print(f"Average trial-level phoneme accuracy: {avg_accuracy:.4f}")

    if total_phoneme_length <= 0:
        print("Could not calculate Aggregate PER: no validation phonemes found.")
    else:
        aggregate_per = total_edit_distance / total_phoneme_length
        print(f"Aggregate Phoneme Error Rate (PER): {aggregate_per:.4f}")

    if not all_losses:
        print("Could not calculate average validation loss.")
    else:
        loss_total = sum(all_losses)
        avg_loss = loss_total / len(all_losses)
        print(f"Sum Validation Loss: {loss_total:.4f}")
        print(f"Average Validation Loss: {avg_loss:.4f}")

In [None]:
import os
import subprocess

# Evaluate the progressively trained checkpoint on a single session.
eval_script = "/kaggle/working/nejm-brain-to-text/model_training/evaluate_without_LLM.py"
model_path = "/kaggle/working/nejm-brain-to-text/model_training/trained_models/baseline_rnn/checkpoint/progressive_training"
data_dir = "/kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final"
csv_path = "/kaggle/working/nejm-brain-to-text/data/t15_copyTaskData_description.csv"
work_dir = "/kaggle/working/nejm-brain-to-text/model_training/"

# USE VAL INSTEAD OF TEST
eval_type = "val"  # This has ground truth labels
gpu_number = 0
target_session = "t15.2023.08.20"

# Pass arguments as a list: no shell string, quoting, or line-continuation
# backslashes to get wrong. The script runs from work_dir, so its relative
# "output" directory lands there.
cmd = [
    "python", eval_script,
    "--model_path", model_path,
    "--data_dir", data_dir,
    "--csv_path", csv_path,
    "--eval_type", eval_type,
    "--gpu_number", str(gpu_number),
    "--session", target_session,
]
print(f"Model: {model_path}")
print(f"Eval type: {eval_type}")

# check=True raises CalledProcessError on a non-zero exit; the previous
# os.system call discarded the exit code, so failures went unnoticed.
eval_result = subprocess.run(cmd, cwd=work_dir, check=True)