# Setup

In [None]:
!git clone https://github.com/Neuroprosthetics-Lab/nejm-brain-to-text.git

In [None]:
!pip install redis numpy pandas h5py omegaconf editdistance tqdm

# Evaluate Baseline Model

In [None]:
%%writefile /kaggle/working/nejm-brain-to-text/model_training/evaluate_without_LLM.py

import os
import torch
import numpy as np
import pandas as pd
import redis
from omegaconf import OmegaConf
import time
from tqdm import tqdm
import editdistance
import argparse

from rnn_model import GRUDecoder
from evaluate_model_helpers import *

# ---------------------------------------------------------------------------
# Command-line options. Each entry is (flag, add_argument keyword options).
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='Evaluate a pretrained RNN model on the copy task dataset.')
_CLI_OPTIONS = (
    ('--model_path', {
        'type': str,
        'default': '../data/t15_pretrained_rnn_baseline',
        'help': 'Path to the pretrained model directory (relative to the current working directory).',
    }),
    ('--data_dir', {
        'type': str,
        'default': '../data/hdf5_data_final',
        'help': 'Path to the dataset directory (relative to the current working directory).',
    }),
    ('--eval_type', {
        'type': str,
        'default': 'test',
        'choices': ['val', 'test'],
        'help': ('Evaluation type: "val" for validation set, "test" for test set. '
                 'If "test", ground truth is not available.'),
    }),
    ('--csv_path', {
        'type': str,
        'default': '../data/t15_copyTaskData_description.csv',
        'help': 'Path to the CSV file with metadata about the dataset (relative to the current working directory).',
    }),
    ('--gpu_number', {
        'type': int,
        'default': 1,
        'help': 'GPU number to use for RNN model inference. Set to -1 to use CPU.',
    }),
)
for flag, options in _CLI_OPTIONS:
    parser.add_argument(flag, **options)
args = parser.parse_args()

model_path = args.model_path
data_dir = args.data_dir
eval_type = args.eval_type
b2txt_csv_df = pd.read_csv(args.csv_path)

# Training args are looked up at <model_path>/args.yaml first, then under
# <model_path>/checkpoint/.
model_args_path = os.path.join(model_path, 'args.yaml')
if not os.path.exists(model_args_path):
    model_args_path = os.path.join(model_path, 'checkpoint/args.yaml')
print(f"Loading model args from: {model_args_path}")
model_args = OmegaConf.load(model_args_path)

# Device: the requested CUDA device when CUDA is available and gpu_number >= 0,
# otherwise the CPU. An out-of-range GPU index is an error.
gpu_number = args.gpu_number
use_gpu = torch.cuda.is_available() and gpu_number >= 0
if not use_gpu:
    if gpu_number >= 0:
        print(f'GPU number {gpu_number} requested but not available.')
    print('Using CPU for model inference.')
    device = torch.device('cpu')
else:
    n_gpus = torch.cuda.device_count()
    if gpu_number >= n_gpus:
        raise ValueError(f'GPU number {gpu_number} is out of range. Available GPUs: {n_gpus}')
    device = torch.device(f'cuda:{gpu_number}')
    print(f'Using {device} for model inference.')

# ---------------------------------------------------------------------------
# Model: GRU decoder rebuilt from the architecture recorded in args.yaml.
# n_days is the number of sessions listed under dataset.sessions.
# ---------------------------------------------------------------------------
model = GRUDecoder(
    neural_dim = model_args['model']['n_input_features'],
    n_units = model_args['model']['n_units'],
    n_days = len(model_args['dataset']['sessions']),
    n_classes = model_args['dataset']['n_classes'],
    rnn_dropout = model_args['model']['rnn_dropout'],
    input_dropout = model_args['model']['input_network']['input_layer_dropout'],
    n_layers = model_args['model']['n_layers'],
    patch_size = model_args['model']['patch_size'],
    patch_stride = model_args['model']['patch_stride'],
)

# ---------------------------------------------------------------------------
# Checkpoint: <model_path>/best_checkpoint, falling back to
# <model_path>/checkpoint/best_checkpoint.
# ---------------------------------------------------------------------------
checkpoint_path = os.path.join(model_path, 'best_checkpoint')
if not os.path.exists(checkpoint_path):
    checkpoint_path = os.path.join(model_path, 'checkpoint/best_checkpoint')
print(f"Loading model checkpoint from: {checkpoint_path}")
# NOTE: weights_only=False lets torch.load unpickle arbitrary objects; only load
# checkpoints from a trusted source.
checkpoint = torch.load(checkpoint_path, weights_only=False, map_location=device)

# Strip the "module." (DataParallel / DistributedDataParallel) and "_orig_mod."
# (torch.compile) prefixes so the keys match a plain GRUDecoder. Both prefixes
# are removed in one rename and each entry is popped exactly once. Two separate
# pop(key) calls on the original key raise KeyError as soon as the first rename
# has moved a key that contained "module.".
state_dict = checkpoint['model_state_dict']
for key in list(state_dict.keys()):
    state_dict[key.replace("module.", "").replace("_orig_mod.", "")] = state_dict.pop(key)
model.load_state_dict(state_dict)
model.to(device)
model.eval()  # inference mode (disables dropout)

# CTC loss with blank index 0; only used to score the "val" split.
ctc_loss = torch.nn.CTCLoss(blank=0, reduction='mean', zero_infinity=False)

# Load data_<eval_type>.hdf5 from every session listed in the model's args,
# skipping sessions whose folder or file is missing.
test_data = {}
total_test_trials = 0
eval_filename = f'data_{eval_type}.hdf5'
for session in model_args['dataset']['sessions']:
    session_dir = os.path.join(data_dir, session)
    if not os.path.exists(session_dir):
        print(f"Session folder not found: {session_dir}, skipping.")
        continue
    hdf5_files = {name for name in os.listdir(session_dir) if name.endswith('.hdf5')}
    if eval_filename not in hdf5_files:
        continue
    data = load_h5py_file(os.path.join(session_dir, eval_filename), b2txt_csv_df)
    test_data[session] = data
    n_trials = len(data["neural_features"])
    total_test_trials += n_trials
    print(f'Loaded {n_trials} {eval_type} trials for session {session}.')
print(f'Total number of {eval_type} trials: {total_test_trials}')
print()

# Run the decoder one trial at a time. Raw logits (numpy) are stored for the
# decoding step below; on the "val" split the CTC loss is computed as well.
# The input dtype is loop-invariant, so it is chosen once.
dtype = torch.bfloat16 if device.type != 'cpu' else torch.float32

with tqdm(total=total_test_trials, desc='Predicting phoneme sequences', unit='trial') as pbar:
    for session, data in test_data.items():
        data['logits'] = []
        data['pred_seq'] = []
        data['losses'] = []
        # The session's position in args.yaml selects its input (day) layer.
        input_layer = model_args['dataset']['sessions'].index(session)

        for trial in range(len(data['neural_features'])):
            # Add a leading batch dimension of size 1.
            neural_input = np.expand_dims(data['neural_features'][trial], axis=0)
            neural_input_tensor = torch.tensor(neural_input, device=device, dtype=dtype)

            with torch.no_grad():
                # runSingleDecodingStep returns a numpy array
                logits_numpy = runSingleDecodingStep(neural_input_tensor, input_layer, model, model_args, device)
            data['logits'].append(logits_numpy)

            if eval_type == 'val':
                seq_len = data['seq_len'][trial]
                logits_tensor = torch.tensor(logits_numpy, device=device)
                true_seq = torch.tensor(data['seq_class_ids'][trial][0:seq_len], device=device, dtype=torch.long)
                true_len = torch.tensor([seq_len], device=device, dtype=torch.long)
                # assumes logits are (batch=1, time, classes) -- TODO confirm;
                # shape[1] is then the number of output frames.
                adjusted_lens = torch.tensor([logits_tensor.shape[1]], device=device, dtype=torch.long)
                # CTCLoss expects (time, batch, classes) log-probabilities.
                log_probs = torch.permute(logits_tensor.log_softmax(2), [1, 0, 2])
                loss = ctc_loss(log_probs, true_seq, adjusted_lens, true_len)
                data['losses'].append(loss.item())

            pbar.update(1)
# The `with` block has already closed the progress bar.

def _ctc_greedy_decode(frame_logits):
    """Greedy CTC decode of one trial's logits into class ids.

    Takes the arg-max class of every frame and collapses runs of the same
    class. Only then are blanks (class 0) removed. Doing it in this order keeps
    a phoneme that really is repeated with a blank in between
    (A, blank, A -> A A). Dropping blanks first would merge it into a single A.

    Args:
        frame_logits: per-frame class scores with classes on the last axis
            (presumably (time, n_classes) -- TODO confirm).

    Returns:
        list[int] of decoded class ids, blanks excluded.
    """
    decoded = []
    previous = None
    for class_id in np.argmax(frame_logits, axis=-1):
        class_id = int(class_id)
        if class_id != previous and class_id != 0:
            decoded.append(class_id)
        previous = class_id
    return decoded


for session, data in test_data.items():
    data['pred_seq'] = []
    for trial in range(len(data['logits'])):
        # [0] drops the size-1 batch dimension the input was given before inference.
        logits = data['logits'][trial][0]
        pred_seq = [LOGIT_TO_PHONEME[p] for p in _ctc_greedy_decode(logits)]
        data['pred_seq'].append(pred_seq)

        if eval_type == 'val':
            block_num = data['block_num'][trial]
            trial_num = data['trial_num'][trial]
            print(f'Session: {session}, Block: {block_num}, Trial: {trial_num}')
            sentence_label = data['sentence_label'][trial]
            true_seq = data['seq_class_ids'][trial][0:data['seq_len'][trial]]
            true_seq = [LOGIT_TO_PHONEME[p] for p in true_seq]
            print(f'Sentence label:      {sentence_label}')
            print(f'True sequence:       {" ".join(true_seq)}')
            print(f'Predicted Sequence:  {" ".join(pred_seq)}')
            print()

# ---------------------------------------------------------------------------
# Per-trial CSV report, plus summary metrics when ground truth exists ("val").
# ---------------------------------------------------------------------------
output_dir = "output"
os.makedirs(output_dir, exist_ok=True)
output_file = os.path.join(output_dir, f'phoneme_predictions_{time.strftime("%Y%m%d_%H%M%S")}.csv')

ids, all_pred_phonemes, all_true_phonemes, trial_accuracy, all_losses = [], [], [], [], []
total_edit_distance = total_phoneme_length = 0
has_labels = eval_type == 'val'

trial_id = 0
for session, data in test_data.items():
    if has_labels:
        all_losses.extend(data['losses'])

    for trial_idx, pred_seq in enumerate(data['pred_seq']):
        pred_phonemes = ' '.join(pred_seq)
        true_phonemes, acc = None, None
        if has_labels:
            label_ids = data['seq_class_ids'][trial_idx][0:data['seq_len'][trial_idx]]
            true_phonemes_list = [LOGIT_TO_PHONEME[p] for p in label_ids]
            true_phonemes = ' '.join(true_phonemes_list)
            ed = editdistance.eval(pred_seq, true_phonemes_list)
            true_len = len(true_phonemes_list)
            total_edit_distance += ed
            total_phoneme_length += true_len
            # Trial accuracy = 1 - edit distance / reference length (0 for an empty reference).
            acc = 1 - ed / true_len if true_len else 0

        ids.append(trial_id)
        all_pred_phonemes.append(pred_phonemes)
        all_true_phonemes.append(true_phonemes)
        trial_accuracy.append(acc)
        trial_id += 1

df_out = pd.DataFrame({
    'id': ids,
    'pred_phonemes': all_pred_phonemes,
    'true_phonemes': all_true_phonemes,
    'trial_accuracy': trial_accuracy,
})
df_out.to_csv(output_file, index=False)

print(f"Saved phoneme predictions, true sequences, and trial-level accuracy to {output_file}")

if has_labels:
    if not trial_accuracy:
        print("No trials found to calculate accuracy.")
    else:
        scored = [a for a in trial_accuracy if a is not None]
        avg_accuracy = sum(scored) / len(scored)
        print(f"Average trial-level phoneme accuracy: {avg_accuracy:.4f}")

    if not total_phoneme_length:
        print("Could not calculate Aggregate PER: no validation phonemes found.")
    else:
        aggregate_per = total_edit_distance / total_phoneme_length
        print(f"Aggregate Phoneme Error Rate (PER): {aggregate_per:.4f}")

    if not all_losses:
        print("Could not calculate average validation loss.")
    else:
        loss_sum = sum(all_losses)
        avg_loss = loss_sum / len(all_losses)
        print(f"Sum Validation Loss: {loss_sum:.4f}")
        print(f"Average Validation Loss: {avg_loss:.4f}")

In [None]:
%%writefile /kaggle/working/nejm-brain-to-text/model_training/evaluate_without_LLM.py

import os
import torch
import numpy as np
import pandas as pd
import redis
from omegaconf import OmegaConf
import time
from tqdm import tqdm
import editdistance
import argparse

from rnn_model import GRUDecoder
from evaluate_model_helpers import *

# ---------------------------------------------------------------------------
# Command-line options. Each entry is (flag, add_argument keyword options).
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='Evaluate a pretrained RNN model on the copy task dataset.')
_CLI_OPTIONS = (
    ('--model_path', {
        'type': str,
        'default': '../data/t15_pretrained_rnn_baseline',
        'help': 'Path to the pretrained model directory (relative to the current working directory).',
    }),
    ('--data_dir', {
        'type': str,
        'default': '../data/hdf5_data_final',
        'help': 'Path to the dataset directory (relative to the current working directory).',
    }),
    ('--eval_type', {
        'type': str,
        'default': 'test',
        'choices': ['val', 'test'],
        'help': ('Evaluation type: "val" for validation set, "test" for test set. '
                 'If "test", ground truth is not available.'),
    }),
    ('--csv_path', {
        'type': str,
        'default': '../data/t15_copyTaskData_description.csv',
        'help': 'Path to the CSV file with metadata about the dataset (relative to the current working directory).',
    }),
    ('--gpu_number', {
        'type': int,
        'default': 1,
        'help': 'GPU number to use for RNN model inference. Set to -1 to use CPU.',
    }),
)
for flag, options in _CLI_OPTIONS:
    parser.add_argument(flag, **options)
args = parser.parse_args()

model_path = args.model_path
data_dir = args.data_dir
eval_type = args.eval_type
b2txt_csv_df = pd.read_csv(args.csv_path)

# Training args are looked up at <model_path>/args.yaml first, then under
# <model_path>/checkpoint/.
model_args_path = os.path.join(model_path, 'args.yaml')
if not os.path.exists(model_args_path):
    model_args_path = os.path.join(model_path, 'checkpoint/args.yaml')
print(f"Loading model args from: {model_args_path}")
model_args = OmegaConf.load(model_args_path)

# Device: the requested CUDA device when CUDA is available and gpu_number >= 0,
# otherwise the CPU. An out-of-range GPU index is an error.
gpu_number = args.gpu_number
use_gpu = torch.cuda.is_available() and gpu_number >= 0
if not use_gpu:
    if gpu_number >= 0:
        print(f'GPU number {gpu_number} requested but not available.')
    print('Using CPU for model inference.')
    device = torch.device('cpu')
else:
    n_gpus = torch.cuda.device_count()
    if gpu_number >= n_gpus:
        raise ValueError(f'GPU number {gpu_number} is out of range. Available GPUs: {n_gpus}')
    device = torch.device(f'cuda:{gpu_number}')
    print(f'Using {device} for model inference.')

# ---------------------------------------------------------------------------
# Model: GRU decoder rebuilt from the architecture recorded in args.yaml.
# n_days is the number of sessions listed under dataset.sessions.
# ---------------------------------------------------------------------------
model = GRUDecoder(
    neural_dim = model_args['model']['n_input_features'],
    n_units = model_args['model']['n_units'],
    n_days = len(model_args['dataset']['sessions']),
    n_classes = model_args['dataset']['n_classes'],
    rnn_dropout = model_args['model']['rnn_dropout'],
    input_dropout = model_args['model']['input_network']['input_layer_dropout'],
    n_layers = model_args['model']['n_layers'],
    patch_size = model_args['model']['patch_size'],
    patch_stride = model_args['model']['patch_stride'],
)

# ---------------------------------------------------------------------------
# Checkpoint: <model_path>/best_checkpoint, falling back to
# <model_path>/checkpoint/best_checkpoint.
# ---------------------------------------------------------------------------
checkpoint_path = os.path.join(model_path, 'best_checkpoint')
if not os.path.exists(checkpoint_path):
    checkpoint_path = os.path.join(model_path, 'checkpoint/best_checkpoint')
print(f"Loading model checkpoint from: {checkpoint_path}")
# NOTE: weights_only=False lets torch.load unpickle arbitrary objects; only load
# checkpoints from a trusted source.
checkpoint = torch.load(checkpoint_path, weights_only=False, map_location=device)

# Strip the "module." (DataParallel / DistributedDataParallel) and "_orig_mod."
# (torch.compile) prefixes so the keys match a plain GRUDecoder. Both prefixes
# are removed in one rename and each entry is popped exactly once. Two separate
# pop(key) calls on the original key raise KeyError as soon as the first rename
# has moved a key that contained "module.".
state_dict = checkpoint['model_state_dict']
for key in list(state_dict.keys()):
    state_dict[key.replace("module.", "").replace("_orig_mod.", "")] = state_dict.pop(key)
model.load_state_dict(state_dict)
model.to(device)
model.eval()  # inference mode (disables dropout)

# CTC loss with blank index 0; only used to score the "val" split.
ctc_loss = torch.nn.CTCLoss(blank=0, reduction='mean', zero_infinity=False)

# Load data_<eval_type>.hdf5 from every session listed in the model's args,
# skipping sessions whose folder or file is missing.
test_data = {}
total_test_trials = 0
eval_filename = f'data_{eval_type}.hdf5'
for session in model_args['dataset']['sessions']:
    session_dir = os.path.join(data_dir, session)
    if not os.path.exists(session_dir):
        print(f"Session folder not found: {session_dir}, skipping.")
        continue
    hdf5_files = {name for name in os.listdir(session_dir) if name.endswith('.hdf5')}
    if eval_filename not in hdf5_files:
        continue
    data = load_h5py_file(os.path.join(session_dir, eval_filename), b2txt_csv_df)
    test_data[session] = data
    n_trials = len(data["neural_features"])
    total_test_trials += n_trials
    print(f'Loaded {n_trials} {eval_type} trials for session {session}.')
print(f'Total number of {eval_type} trials: {total_test_trials}')
print()


# Run the decoder one trial at a time. runSingleDecodingStep returns numpy
# logits. They are stored as-is for decoding and converted back to a tensor
# only for the validation CTC loss. The input dtype is loop-invariant, so it is
# chosen once.
dtype = torch.bfloat16 if device.type != 'cpu' else torch.float32

with tqdm(total=total_test_trials, desc='Predicting phoneme sequences', unit='trial') as pbar:
    for session, data in test_data.items():
        data['logits'] = []
        data['pred_seq'] = []
        data['losses'] = []
        # The session's position in args.yaml selects its input (day) layer.
        input_layer = model_args['dataset']['sessions'].index(session)

        for trial in range(len(data['neural_features'])):
            # Add a leading batch dimension of size 1.
            neural_input = np.expand_dims(data['neural_features'][trial], axis=0)
            neural_input_tensor = torch.tensor(neural_input, device=device, dtype=dtype)

            with torch.no_grad():
                # runSingleDecodingStep returns a numpy array
                logits_numpy = runSingleDecodingStep(neural_input_tensor, input_layer, model, model_args, device)
            data['logits'].append(logits_numpy)

            if eval_type == 'val':
                seq_len = data['seq_len'][trial]
                logits_tensor = torch.tensor(logits_numpy, device=device)
                true_seq = torch.tensor(data['seq_class_ids'][trial][0:seq_len], device=device, dtype=torch.long)
                true_len = torch.tensor([seq_len], device=device, dtype=torch.long)
                # assumes logits are (batch=1, time, classes) -- TODO confirm;
                # shape[1] is then the number of output frames.
                adjusted_lens = torch.tensor([logits_tensor.shape[1]], device=device, dtype=torch.long)
                # CTCLoss expects (time, batch, classes) log-probabilities.
                log_probs = torch.permute(logits_tensor.log_softmax(2), [1, 0, 2])
                loss = ctc_loss(log_probs, true_seq, adjusted_lens, true_len)
                data['losses'].append(loss.item())

            pbar.update(1)
# The `with` block has already closed the progress bar.


def _ctc_greedy_decode(frame_logits):
    """Greedy CTC decode of one trial's logits into class ids.

    Takes the arg-max class of every frame and collapses runs of the same
    class. Only then are blanks (class 0) removed. Doing it in this order keeps
    a phoneme that really is repeated with a blank in between
    (A, blank, A -> A A). Dropping blanks first would merge it into a single A.

    Args:
        frame_logits: per-frame class scores with classes on the last axis
            (presumably (time, n_classes) -- TODO confirm).

    Returns:
        list[int] of decoded class ids, blanks excluded.
    """
    decoded = []
    previous = None
    for class_id in np.argmax(frame_logits, axis=-1):
        class_id = int(class_id)
        if class_id != previous and class_id != 0:
            decoded.append(class_id)
        previous = class_id
    return decoded


for session, data in test_data.items():
    data['pred_seq'] = []
    for trial in range(len(data['logits'])):
        # [0] drops the size-1 batch dimension the input was given before inference.
        logits = data['logits'][trial][0]
        pred_seq = [LOGIT_TO_PHONEME[p] for p in _ctc_greedy_decode(logits)]
        data['pred_seq'].append(pred_seq)

        if eval_type == 'val':
            block_num = data['block_num'][trial]
            trial_num = data['trial_num'][trial]
            print(f'Session: {session}, Block: {block_num}, Trial: {trial_num}')
            sentence_label = data['sentence_label'][trial]
            true_seq = data['seq_class_ids'][trial][0:data['seq_len'][trial]]
            true_seq = [LOGIT_TO_PHONEME[p] for p in true_seq]
            print(f'Sentence label:      {sentence_label}')
            print(f'True sequence:       {" ".join(true_seq)}')
            print(f'Predicted Sequence:  {" ".join(pred_seq)}')
            print()

# ---------------------------------------------------------------------------
# Per-trial CSV report, plus summary metrics when ground truth exists ("val").
# ---------------------------------------------------------------------------
output_dir = "output"
os.makedirs(output_dir, exist_ok=True)
output_file = os.path.join(output_dir, f'phoneme_predictions_{time.strftime("%Y%m%d_%H%M%S")}.csv')

ids, all_pred_phonemes, all_true_phonemes, trial_accuracy, all_losses = [], [], [], [], []
total_edit_distance = total_phoneme_length = 0
has_labels = eval_type == 'val'

trial_id = 0
for session, data in test_data.items():
    if has_labels:
        all_losses.extend(data['losses'])

    for trial_idx, pred_seq in enumerate(data['pred_seq']):
        pred_phonemes = ' '.join(pred_seq)
        true_phonemes, acc = None, None
        if has_labels:
            label_ids = data['seq_class_ids'][trial_idx][0:data['seq_len'][trial_idx]]
            true_phonemes_list = [LOGIT_TO_PHONEME[p] for p in label_ids]
            true_phonemes = ' '.join(true_phonemes_list)
            ed = editdistance.eval(pred_seq, true_phonemes_list)
            true_len = len(true_phonemes_list)
            total_edit_distance += ed
            total_phoneme_length += true_len
            # Trial accuracy = 1 - edit distance / reference length (0 for an empty reference).
            acc = 1 - ed / true_len if true_len else 0

        ids.append(trial_id)
        all_pred_phonemes.append(pred_phonemes)
        all_true_phonemes.append(true_phonemes)
        trial_accuracy.append(acc)
        trial_id += 1

df_out = pd.DataFrame({
    'id': ids,
    'pred_phonemes': all_pred_phonemes,
    'true_phonemes': all_true_phonemes,
    'trial_accuracy': trial_accuracy,
})
df_out.to_csv(output_file, index=False)

print(f"Saved phoneme predictions, true sequences, and trial-level accuracy to {output_file}")

if has_labels:
    if not trial_accuracy:
        print("No trials found to calculate accuracy.")
    else:
        scored = [a for a in trial_accuracy if a is not None]
        avg_accuracy = sum(scored) / len(scored)
        print(f"Average trial-level phoneme accuracy: {avg_accuracy:.4f}")

    if not total_phoneme_length:
        print("Could not calculate Aggregate PER: no validation phonemes found.")
    else:
        aggregate_per = total_edit_distance / total_phoneme_length
        print(f"Aggregate Phoneme Error Rate (PER): {aggregate_per:.4f}")

    if not all_losses:
        print("Could not calculate average validation loss.")
    else:
        loss_sum = sum(all_losses)
        avg_loss = loss_sum / len(all_losses)
        print(f"Sum Validation Loss: {loss_sum:.4f}")
        print(f"Average Validation Loss: {avg_loss:.4f}")

# I got:
# Average trial-level phoneme accuracy: 0.9027
# Aggregate Phoneme Error Rate (PER): 0.1020
# Sum Validation Loss: 1007.7086
# Average Validation Loss: 0.7067

# Compare AdamW and SGD

In [None]:
%%writefile /kaggle/working/nejm-brain-to-text/model_training/train_model_optimize.py

import optuna
from omegaconf import OmegaConf
from rnn_trainer import BrainToTextDecoder_Trainer
import os
import numpy as np
import torch 
import shutil

# --- 1. Define the Objective Function for Optuna ---
def build_param_groups(named_parameters, lr, args):
    """Split trainable parameters into optimizer param groups.

    Groups, in order:
      * ``bias``      -- names containing ``gru.bias`` or ``out.bias``;
                         weight decay 0, learning rate ``lr``.
      * ``day_layer`` -- names containing ``day_``; learning rate
                         ``args.lr_max_day``, weight decay
                         ``args.weight_decay_day``. Omitted when empty.
      * ``other``     -- everything else; learning rate ``lr``, weight decay
                         ``args.weight_decay``.

    Every group sets its own ``weight_decay`` explicitly. Without it a group
    silently falls back to the optimizer's constructor default (AdamW: 1e-2,
    SGD: 0), which would make optimizer comparisons unequal.

    Args:
        named_parameters: iterable of ``(name, parameter)`` pairs, e.g.
            ``model.named_parameters()``; consumed once. Parameters with
            ``requires_grad=False`` are skipped.
        lr: learning rate for the ``bias`` and ``other`` groups.
        args: config providing ``lr_max_day``, ``weight_decay_day`` and
            ``weight_decay``.

    Returns:
        list of param-group dicts for a ``torch.optim`` optimizer.
    """
    bias_params, day_params, other_params = [], [], []
    for name, param in named_parameters:
        if not param.requires_grad:
            continue
        if 'gru.bias' in name or 'out.bias' in name:
            bias_params.append(param)
        elif 'day_' in name:
            day_params.append(param)
        else:
            other_params.append(param)

    param_groups = [{'params': bias_params, 'weight_decay': 0, 'group_type': 'bias', 'lr': lr}]
    if day_params:
        param_groups.append({'params': day_params, 'lr': args.lr_max_day,
                             'weight_decay': args.weight_decay_day, 'group_type': 'day_layer'})
    param_groups.append({'params': other_params, 'weight_decay': args.weight_decay,
                         'group_type': 'other', 'lr': lr})
    return param_groups


def objective(trial):
    """Optuna objective: briefly train the decoder and return its best validation PER.

    Samples an optimizer ("AdamW" or "SGD") and a log-uniform learning rate in
    [1e-5, 1e-2]. Trains for 1000 batches with a freshly built optimizer and LR
    scheduler, and returns the minimum of ``train_stats['val_PERs']`` (1.0 if
    the trainer reported none). Each trial writes to its own ``trial_<n>``
    output and checkpoint directories, which are wiped first.
    """
    args_path = 'rnn_args.yaml'
    if not os.path.exists(args_path):
        print(f"Warning: '{args_path}' not found. Using pretrained model's args as template.")
        args_path = "/kaggle/input/brain-to-text-25/t15_pretrained_rnn_baseline/t15_pretrained_rnn_baseline/checkpoint/args.yaml"
    args = OmegaConf.load(args_path)

    # Force the script to use the correct Kaggle data paths
    args.dataset.dataset_dir = "/kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final"
    args.dataset.csv_path = "/kaggle/working/nejm-brain-to-text/data/t15_copyTaskData_description.csv"

    # Give each trial a unique, clean output and checkpoint directory.
    trial_id_str = f"trial_{trial.number}"
    args.output_dir = os.path.join(args.output_dir, trial_id_str)
    args.checkpoint_dir = os.path.join(args.checkpoint_dir, trial_id_str)
    if os.path.exists(args.output_dir):
        print(f"Removing existing output directory: {args.output_dir}")
        shutil.rmtree(args.output_dir)
    if os.path.exists(args.checkpoint_dir):
        print(f"Removing existing checkpoint directory: {args.checkpoint_dir}")
        shutil.rmtree(args.checkpoint_dir)

    # Search space. TODO: also search the architecture later, e.g.
    #   n_layers = trial.suggest_int('n_layers', 1, 4)
    #   n_units = trial.suggest_int('n_units', 256, 1024, step=128)
    #   rnn_dropout = trial.suggest_float('rnn_dropout', 0.1, 0.5)
    optimizer_name = trial.suggest_categorical("optimizer", ["AdamW", "SGD"])
    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
    args.lr_max = lr

    # Override to train faster.
    # NOTE(review): only num_training_batches is shortened. lr_warmup_steps and
    # lr_decay_steps keep their args.yaml values. If those exceed 1000, the LR
    # schedule never completes (possibly not even warm-up) within a trial.
    # Confirm against the yaml before reading too much into the comparison.
    args.num_training_batches = 1000

    print(f"\n--- Starting Trial {trial.number} ---")
    print(f"Params: optimizer={optimizer_name}, lr={lr} (Using default model architecture)")
    print(f"Training for {args.num_training_batches} batches.")
    print(f"Output dir: {args.output_dir}")

    # Create trainer (initializes model), then swap in the sampled optimizer.
    trainer = BrainToTextDecoder_Trainer(args)
    param_groups = build_param_groups(trainer.model.named_parameters(), lr, args)

    # lr and weight_decay are set per group by build_param_groups.
    if optimizer_name == "AdamW":
        trainer.optimizer = torch.optim.AdamW(
            param_groups,
            betas=(args.beta0, args.beta1),
            eps=args.epsilon,
            fused=True,  # requires the parameters to live on CUDA
        )
    elif optimizer_name == "SGD":
        trainer.optimizer = torch.optim.SGD(param_groups, momentum=0.9)

    # An LR scheduler is bound to the optimizer it was built with, so rebuild
    # it for the new optimizer. Never keep the old one silently.
    if args.lr_scheduler_type == 'linear':
        trainer.learning_rate_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer=trainer.optimizer,
            start_factor=1.0,
            end_factor=args.lr_min / args.lr_max,
            total_iters=args.lr_decay_steps,
        )
    elif args.lr_scheduler_type == 'cosine':
        # Handles both the 2-group (no day params) and 3-group layouts.
        trainer.learning_rate_scheduler = trainer.create_cosine_lr_scheduler(trainer.optimizer)
    else:
        raise ValueError(f"Unsupported lr_scheduler_type: {args.lr_scheduler_type}")

    train_stats = trainer.train()

    # Score the trial by its best (minimum) validation PER; 1.0 if none recorded.
    val_per_list = train_stats.get('val_PERs', [])
    val_score = np.min(val_per_list) if val_per_list else 1.0

    print(f"Trial {trial.number} finished. Best (min) PER: {val_score}")

    return val_score

# --- 2. Create the Study and Run It ---
if __name__ == "__main__":
    # Lower PER is better, so the study minimizes the objective's return value.
    study = optuna.create_study(direction='minimize')

    # Kept small so the search fits the notebook's time budget.
    study.optimize(objective, n_trials=10)

    print("\n--- Optimization Finished ---")
    print("Best trial:")
    best = study.best_trial

    print(f"  Value (Min PER): {best.value}")
    print("  Params: ")
    for param_name, param_value in best.params.items():
        print(f"    {param_name}: {param_value}")

In [None]:
!cd /kaggle/working/nejm-brain-to-text/model_training/ && \
python train_model_optimize.py

Result: 

Trial 0: optimizer=SGD, lr=0.00116, Final PER = 0.6147

Trial 1: optimizer=AdamW, lr=8.36e-05, Final PER = 1.0000

Trial 2: optimizer=AdamW, lr=0.00025, Final PER = 1.0000

Trial 3: optimizer=SGD, lr=4.95e-05, Final PER = 1.0000

Trial 4: optimizer=SGD, lr=3.34e-05, Final PER = 1.0000

Trial 5: optimizer=SGD, lr=5.31e-05, Final PER = 1.0000

Trial 6: optimizer=SGD, lr=0.00031, Final PER = 0.9975

Trial 7: optimizer=SGD, lr=0.00344, Final PER = 0.4412 

Trial 8: optimizer=AdamW, lr=0.00038, Final PER = 0.9980

Trial 9: optimizer=SGD, lr=1.13e-05, Final PER = 0.9360

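Conclusion: SGD with a learning rate that is not too small works best (the two best trials used lr ≈ 1e-3 to 3.5e-3). Note that all three AdamW trials sampled lr < 4e-4.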


# Test SGD with Different Hyperparameters

Tried 2 sets: 

- Set 1: Conclusion - high momentum combined with Nesterov updates works best

Trial 0: lr=0.0028, momentum=0.944, nesterov=True -> PER = 0.3585

Trial 1: lr=0.00026, momentum=0.934, nesterov=True -> PER = 0.9957

Trial 2: lr=0.00028, momentum=0.821, nesterov=False -> PER = 1.0000

Trial 3: lr=0.00062, momentum=0.916, nesterov=False -> PER = 0.8413

Trial 4: lr=0.00548, momentum=0.871, nesterov=False -> PER = 0.4291

Trial 5: lr=0.00192, momentum=0.815, nesterov=True -> PER = 0.6237

Trial 6: lr=0.00071, momentum=0.986, nesterov=True -> PER = 0.3942

Trial 7: lr=0.00064, momentum=0.932, nesterov=True -> PER = 0.6098

Trial 8: lr=0.00030, momentum=0.967, nesterov=True -> PER = 0.6250

Trial 9: lr=0.00053, momentum=0.952, nesterov=True -> PER = 0.9360

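- Set 2: Conclusion - with lr fixed near 0.0026-0.0032, higher momentum (about 0.97-0.99) gave the lowest PERs (~0.28). Weight decay between 1e-5 and 1e-2 showed no consistent effect.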

Trial 0: lr=0.0029, momentum=0.956, wd=2.5e-05, nesterov=True -> PER = 0.3118 

Trial 1: lr=0.0030, momentum=0.945, wd=8.9e-05, nesterov=True -> PER = 0.3393

Trial 2: lr=0.0029, momentum=0.954, wd=0.0014, nesterov=True -> PER = 0.3212

Trial 3: lr=0.0030, momentum=0.974, wd=0.0079, nesterov=True -> PER = 0.2858 

Trial 4: lr=0.0026, momentum=0.977, wd=0.0002, nesterov=True -> PER = 0.2848 

Trial 5: lr=0.0026, momentum=0.947, wd=0.0004, nesterov=True -> PER = 0.3501

Trial 6: lr=0.0028, momentum=0.989, wd=3.8e-05, nesterov=True -> PER = 0.2815

Trial 7: lr=0.0026, momentum=0.957, wd=6.6e-05, nesterov=True -> PER = 0.3226

Trial 8: lr=0.0030, momentum=0.964, wd=0.0023, nesterov=True -> PER = 0.2970

Trial 9: lr=0.0028, momentum=0.984, wd=4.1e-05, nesterov=True -> PER = 0.2933

Trial 10: lr=0.0032, momentum=0.990, wd=1.2e-05, nesterov=True -> PER = 0.2796 (Best)

Trial 11: lr=0.0032, momentum=0.990, wd=1.2e-05, nesterov=True -> PER = 0.2825

In [None]:
%%writefile /kaggle/working/nejm-brain-to-text/model_training/train_model_optimize_sgd.py

import optuna
from omegaconf import OmegaConf
from rnn_trainer import BrainToTextDecoder_Trainer
import os
import numpy as np
import torch
import shutil

# --- 1. Define the Objective Function for Optuna ---
def build_sgd_param_groups(model, lr, weight_decay, day_cfg):
    """Split a model's trainable parameters into SGD parameter groups.

    Grouping is by parameter name:
      * names containing 'gru.bias' or 'out.bias' -> ``lr``, no weight decay;
      * other names containing 'day_'              -> ``day_cfg.lr_max_day`` /
        ``day_cfg.weight_decay_day``;
      * everything else                            -> ``lr`` / ``weight_decay``.
    Parameters with ``requires_grad=False`` are skipped.

    Args:
        model: a ``torch.nn.Module``.
        lr: learning rate for the bias and "other" groups.
        weight_decay: weight decay for the "other" group.
        day_cfg: object exposing ``lr_max_day`` and ``weight_decay_day``
            (normally the run's args); only read when trainable 'day_'
            parameters exist.

    Returns:
        ``[bias, day, other]`` group dicts, or ``[bias, other]`` when there
        are no trainable 'day_' parameters.
    """
    bias_params, day_params, other_params = [], [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if 'gru.bias' in name or 'out.bias' in name:
            bias_params.append(param)
        elif 'day_' in name:
            day_params.append(param)
        else:
            other_params.append(param)

    groups = [{'params': bias_params, 'weight_decay': 0, 'lr': lr}]
    if day_params:
        # Day-specific parameters keep their own LR / weight decay from args.
        groups.append({'params': day_params,
                       'lr': day_cfg.lr_max_day,
                       'weight_decay': day_cfg.weight_decay_day})
    groups.append({'params': other_params, 'lr': lr, 'weight_decay': weight_decay})
    return groups


def build_lr_scheduler(trainer, args):
    """Create an LR scheduler bound to ``trainer.optimizer``.

    Needed after replacing the trainer's optimizer: a PyTorch scheduler keeps
    a reference to the optimizer it was built with, so the trainer's original
    scheduler would keep adjusting the discarded optimizer.

    Raises:
        ValueError: if ``args.lr_scheduler_type`` is neither 'linear' nor
            'cosine' (instead of silently keeping the stale scheduler).
    """
    if args.lr_scheduler_type == 'linear':
        return torch.optim.lr_scheduler.LinearLR(
            optimizer=trainer.optimizer,
            start_factor=1.0,
            end_factor=args.lr_min / args.lr_max,
            total_iters=args.lr_decay_steps,
        )
    if args.lr_scheduler_type == 'cosine':
        return trainer.create_cosine_lr_scheduler(trainer.optimizer)
    raise ValueError(
        f"Unsupported lr_scheduler_type {args.lr_scheduler_type!r}; "
        "expected 'linear' or 'cosine'."
    )


def objective(trial):
    """Optuna objective: one short SGD run, scored by its best validation PER.

    Samples ``lr``, ``momentum`` and ``weight_decay`` (Nesterov is always on),
    trains for 1000 batches in fresh per-trial output/checkpoint directories,
    and returns the minimum of ``train_stats['val_PERs']`` (1.0 if the trainer
    reported none). Lower is better.
    """
    args_path = 'rnn_args.yaml'
    if not os.path.exists(args_path):
        print(f"Warning: '{args_path}' not found. Using pretrained model's args as template.")
        args_path = "/kaggle/input/brain-to-text-25/t15_pretrained_rnn_baseline/t15_pretrained_rnn_baseline/checkpoint/args.yaml"
    args = OmegaConf.load(args_path)
    args.dataset.dataset_dir = "/kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final"
    args.dataset.csv_path = "/kaggle/working/nejm-brain-to-text/data/t15_copyTaskData_description.csv"

    # Each trial writes to its own sub-directory, wiped before training.
    trial_id_str = f"trial_{trial.number}"
    args.output_dir = os.path.join(args.output_dir, trial_id_str)
    args.checkpoint_dir = os.path.join(args.checkpoint_dir, trial_id_str)
    if os.path.exists(args.output_dir):
        print(f"Removing existing output directory: {args.output_dir}")
        shutil.rmtree(args.output_dir)
    if os.path.exists(args.checkpoint_dir):
        print(f"Removing existing checkpoint directory: {args.checkpoint_dir}")
        shutil.rmtree(args.checkpoint_dir)

    # --- Suggest focused SGD hyperparameters ---
    lr = trial.suggest_float("lr", 0.0026, 0.0032, log=False)
    momentum = trial.suggest_float("momentum", 0.94, 0.99)
    weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-2, log=True)
    nesterov = True

    # lr_max is also what the LinearLR end_factor is computed against.
    args.lr_max = lr
    args.num_training_batches = 1000  # Keep training short for testing
    # NOTE(review): lr_decay_steps (and any warmup) still come from the YAML;
    # if they are much longer than 1000 batches the schedule barely moves
    # within a trial — confirm against args.yaml.

    print(f"\n--- Starting Trial {trial.number} ---")
    print(f"Params: optimizer=SGD, lr={lr}, momentum={momentum}, weight_decay={weight_decay}, nesterov={nesterov} (Using default model architecture)")
    print(f"Training for {args.num_training_batches} batches.")
    print(f"Output dir: {args.output_dir}")

    trainer = BrainToTextDecoder_Trainer(args)

    # Replace the trainer's optimizer with Nesterov SGD over the bias / day /
    # other grouping (weight decay is set per group), then rebuild the LR
    # scheduler so it drives the new optimizer.
    trainer.optimizer = torch.optim.SGD(
        build_sgd_param_groups(trainer.model, lr, weight_decay, args),
        lr=lr,
        momentum=momentum,
        nesterov=nesterov,
    )
    trainer.learning_rate_scheduler = build_lr_scheduler(trainer, args)

    train_stats = trainer.train()
    val_per_list = train_stats.get('val_PERs', [])
    # len() works for both lists and arrays (array truthiness is ambiguous).
    val_score = np.min(val_per_list) if len(val_per_list) > 0 else 1.0
    print(f"Trial {trial.number} finished. Best (min) PER: {val_score}")
    return val_score
    
# --- Entry point: run the SGD search and report the best trial ---
if __name__ == "__main__":
    sgd_study = optuna.create_study(direction='minimize')  # objective returns PER
    sgd_study.optimize(objective, n_trials=20)  # Keep n_trials low for testing
    print("\n--- Optimization Finished ---")
    print("Best trial:")
    winner = sgd_study.best_trial
    print(f"  Value (Min PER): {winner.value}")
    print("  Params: ")
    for hp_name, hp_value in winner.params.items():
        print(f"    {hp_name}: {hp_value}")

In [None]:
# Launch the Optuna SGD search written above. cd first: the script is run by
# relative name and looks for 'rnn_args.yaml' in the working directory.
!cd /kaggle/working/nejm-brain-to-text/model_training/ && \
python train_model_optimize_sgd.py

# Train the Full Model with the Tuned SGD Parameters

New best validation PER: 0.1714 (at 10k batches). Average trial-level phoneme accuracy: 0.8345 (measured on a 13,800-batch run, not the 10k-batch run).


In [None]:
%%writefile /kaggle/working/nejm-brain-to-text/model_training/train_sgd_model.py

from omegaconf import OmegaConf
from rnn_trainer import BrainToTextDecoder_Trainer
import os
import numpy as np
import torch 
import shutil

# --- 1. Winning SGD hyperparameters ---
# (From Optuna Trial 10, the best one)
BEST_LR = 0.00316
BEST_MOMENTUM = 0.990
BEST_WEIGHT_DECAY = 1.22e-05
BEST_NESTEROV = True


def build_sgd_param_groups(model, lr, weight_decay, day_cfg):
    """Split a model's trainable parameters into SGD parameter groups.

    Grouping is by parameter name:
      * names containing 'gru.bias' or 'out.bias' -> ``lr``, no weight decay;
      * other names containing 'day_'              -> ``day_cfg.lr_max_day`` /
        ``day_cfg.weight_decay_day``;
      * everything else                            -> ``lr`` / ``weight_decay``.
    Parameters with ``requires_grad=False`` are skipped.

    Args:
        model: a ``torch.nn.Module``.
        lr: learning rate for the bias and "other" groups.
        weight_decay: weight decay for the "other" group.
        day_cfg: object exposing ``lr_max_day`` and ``weight_decay_day``
            (normally the run's args); only read when trainable 'day_'
            parameters exist.

    Returns:
        ``[bias, day, other]`` group dicts, or ``[bias, other]`` when there
        are no trainable 'day_' parameters.
    """
    bias_params, day_params, other_params = [], [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if 'gru.bias' in name or 'out.bias' in name:
            bias_params.append(param)
        elif 'day_' in name:
            day_params.append(param)
        else:
            other_params.append(param)

    groups = [{'params': bias_params, 'weight_decay': 0, 'lr': lr}]
    if day_params:
        # Day-specific parameters keep their own LR / weight decay from args.
        groups.append({'params': day_params,
                       'lr': day_cfg.lr_max_day,
                       'weight_decay': day_cfg.weight_decay_day})
    groups.append({'params': other_params, 'lr': lr, 'weight_decay': weight_decay})
    return groups


def build_lr_scheduler(trainer, args):
    """Create an LR scheduler bound to ``trainer.optimizer``.

    Needed after replacing the trainer's optimizer: a PyTorch scheduler keeps
    a reference to the optimizer it was built with, so the trainer's original
    scheduler would keep adjusting the discarded optimizer.

    Raises:
        ValueError: if ``args.lr_scheduler_type`` is neither 'linear' nor
            'cosine' (instead of silently keeping the stale scheduler).
    """
    if args.lr_scheduler_type == 'linear':
        return torch.optim.lr_scheduler.LinearLR(
            optimizer=trainer.optimizer,
            start_factor=1.0,
            end_factor=args.lr_min / args.lr_max,
            total_iters=args.lr_decay_steps,
        )
    if args.lr_scheduler_type == 'cosine':
        return trainer.create_cosine_lr_scheduler(trainer.optimizer)
    raise ValueError(
        f"Unsupported lr_scheduler_type {args.lr_scheduler_type!r}; "
        "expected 'linear' or 'cosine'."
    )


def main():
    """Train the decoder for 10k batches with the tuned SGD settings.

    Loads 'rnn_args.yaml' (or the pretrained model's args), points it at the
    Kaggle dataset, writes to fresh 'best_sgd_model' output/checkpoint
    directories, replaces the trainer's optimizer with Nesterov SGD plus a
    matching LR scheduler, trains, and prints the best validation PER.
    """
    print("--- Starting Full Training Run with Optimized SGD Params ---")

    args_path = 'rnn_args.yaml'
    if not os.path.exists(args_path):
        print(f"Warning: '{args_path}' not found. Using pretrained model's args as template.")
        args_path = "/kaggle/input/brain-to-text-25/t15_pretrained_rnn_baseline/t15_pretrained_rnn_baseline/checkpoint/args.yaml"
    args = OmegaConf.load(args_path)

    # Force the script to use the correct Kaggle data paths
    args.dataset.dataset_dir = "/kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final"
    args.dataset.csv_path = "/kaggle/working/nejm-brain-to-text/data/t15_copyTaskData_description.csv"

    # --- 2. Training length ---
    # 10k batches (the original model used 120000) so the run matches the
    # 10k-batch AdamW baseline.
    args.num_training_batches = 10000
    print(f"Training for {args.num_training_batches} batches.")
    # NOTE(review): lr_decay_steps / warmup still come from the YAML; if they
    # were sized for the full 120k-batch run the LR schedule will not finish
    # decaying within 10k batches — confirm against args.yaml.

    # --- 3. Dedicated, freshly emptied output directories ---
    args.output_dir = "trained_models/baseline_rnn/best_sgd_model"
    args.checkpoint_dir = "trained_models/baseline_rnn/checkpoint/best_sgd_model"
    if os.path.exists(args.output_dir):
        print(f"Removing existing output directory: {args.output_dir}")
        shutil.rmtree(args.output_dir)
    if os.path.exists(args.checkpoint_dir):
        print(f"Removing existing checkpoint directory: {args.checkpoint_dir}")
        shutil.rmtree(args.checkpoint_dir)

    # --- 4. Overwrite args with the tuned values ---
    # lr_max is also what the LinearLR end_factor is computed against.
    args.lr_max = BEST_LR
    print(f"Params: optimizer=SGD, lr={BEST_LR}, momentum={BEST_MOMENTUM}, weight_decay={BEST_WEIGHT_DECAY}, nesterov={BEST_NESTEROV}")
    print(f"Output dir: {args.output_dir}")

    trainer = BrainToTextDecoder_Trainer(args)

    # --- 5. Swap in Nesterov SGD and a scheduler bound to it ---
    trainer.optimizer = torch.optim.SGD(
        build_sgd_param_groups(trainer.model, BEST_LR, BEST_WEIGHT_DECAY, args),
        lr=BEST_LR,
        momentum=BEST_MOMENTUM,
        nesterov=BEST_NESTEROV,
    )
    trainer.learning_rate_scheduler = build_lr_scheduler(trainer, args)

    # --- 6. Train and report ---
    print("Starting full model training...")
    train_stats = trainer.train()
    val_per_list = train_stats.get('val_PERs', [])
    # len() works for both lists and arrays (array truthiness is ambiguous).
    val_score = np.min(val_per_list) if len(val_per_list) > 0 else 1.0

    print("\n--- Full Training Finished ---")
    print(f"Final Best (min) PER: {val_score}")
    print(f"Your new model is saved in: {args.checkpoint_dir}")


if __name__ == "__main__":
    main()

In [None]:
# Full 10k-batch SGD training run. cd first: the script is run by relative
# name, looks for 'rnn_args.yaml', and writes to relative trained_models/ paths.
!cd /kaggle/working/nejm-brain-to-text/model_training/ && \
python train_sgd_model.py

In [None]:
import os
import subprocess
import sys

# --- 1. Define all the paths ---
# Evaluation-script variant for the SGD model; it is expected to find
# 'args.yaml' directly inside model_path (TODO confirm — script not shown here).
eval_script = "/kaggle/working/nejm-brain-to-text/model_training/evaluate_without_LLM_sgd.py"
model_path = "/kaggle/working/nejm-brain-to-text/model_training/trained_models/baseline_rnn/checkpoint/best_sgd_model"
data_dir = "/kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final"
csv_path = "/kaggle/working/nejm-brain-to-text/data/t15_copyTaskData_description.csv"
workdir = "/kaggle/working/nejm-brain-to-text/model_training/"

# --- 2. Set evaluation options ---
eval_type = "val"
gpu_number = 0

# --- 3. Build and run the command ---
# An argument list (no shell string) avoids quoting problems; cwd runs it from
# model_training/. sys.executable is the interpreter running this notebook.
cmd = [
    sys.executable, eval_script,
    "--model_path", model_path,
    "--data_dir", data_dir,
    "--csv_path", csv_path,
    "--eval_type", eval_type,
    "--gpu_number", str(gpu_number),
]

print(f"--- Running evaluation on new model: {model_path} ---")
result = subprocess.run(cmd, cwd=workdir, check=False)
# Check the exit status so a crashed evaluation is not reported as finished.
if result.returncode == 0:
    print("--- Evaluation finished. ---")
else:
    print(f"--- Evaluation FAILED with exit code {result.returncode}. ---")

# Train Baseline Model with Fewer Batches (10,000) to Compare with Previous SGD

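Baseline (AdamW) model with 10k batches: Final Best (min) PER: 0.12702937424182892; Average trial-level phoneme accuracy: 0.8778. At the same 10k-batch budget, the default AdamW setup still beats the tuned SGD model (PER 0.1714).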

In [None]:
%%writefile /kaggle/working/nejm-brain-to-text/model_training/train_baseline_model.py

from omegaconf import OmegaConf
from rnn_trainer import BrainToTextDecoder_Trainer
import os
import numpy as np
import torch 
import shutil

# Baseline run: same data, batch budget and directory handling as the SGD
# script, but the trainer keeps the optimizer and LR scheduler it builds itself.
print("--- Starting Full Training Run with ORIGINAL BASELINE (AdamW) Params ---")

cfg_path = 'rnn_args.yaml'
have_local_cfg = os.path.exists(cfg_path)
if not have_local_cfg:
    print(f"Warning: '{cfg_path}' not found. Using pretrained model's args as template.")
    cfg_path = "/kaggle/input/brain-to-text-25/t15_pretrained_rnn_baseline/t15_pretrained_rnn_baseline/checkpoint/args.yaml"

cfg = OmegaConf.load(cfg_path)
# Point the config at the Kaggle copies of the data.
cfg.dataset.dataset_dir = "/kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final"
cfg.dataset.csv_path = "/kaggle/working/nejm-brain-to-text/data/t15_copyTaskData_description.csv"

# Same 10k-batch budget as the SGD run so the two are comparable.
cfg.num_training_batches = 10000
print(f"Training for {cfg.num_training_batches} batches.")

# Dedicated output/checkpoint directories, emptied before training.
cfg.output_dir = "trained_models/baseline_rnn/best_baseline_model"
cfg.checkpoint_dir = "trained_models/baseline_rnn/checkpoint/best_baseline_model"
if os.path.exists(cfg.output_dir):
    print(f"Removing existing output directory: {cfg.output_dir}")
    shutil.rmtree(cfg.output_dir)
if os.path.exists(cfg.checkpoint_dir):
    print(f"Removing existing checkpoint directory: {cfg.checkpoint_dir}")
    shutil.rmtree(cfg.checkpoint_dir)

# Nothing is overridden: lr_max and the optimizer stay at their defaults.
print(f"Params: optimizer=AdamW (default), lr={cfg.lr_max} (default)")
print(f"Output dir: {cfg.output_dir}")

# The trainer constructs its own default optimizer and scheduler.
trainer = BrainToTextDecoder_Trainer(cfg)

print("Starting baseline model training...")
stats = trainer.train()

val_pers = stats.get('val_PERs', [])
if val_pers:
    best_val_per = np.min(val_pers)
else:
    best_val_per = 1.0

print("\n--- Full Training Finished ---")
print(f"Final Best (min) PER: {best_val_per}")
print(f"Your new baseline model is saved in: {cfg.checkpoint_dir}")

In [None]:
import os
import subprocess
import sys

eval_script = "/kaggle/working/nejm-brain-to-text/model_training/evaluate_without_LLM.py"
data_dir = "/kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final"
csv_path = "/kaggle/working/nejm-brain-to-text/data/t15_copyTaskData_description.csv"
model_path_baseline = "/kaggle/working/nejm-brain-to-text/model_training/trained_models/baseline_rnn/checkpoint/best_baseline_model"

# Argument list (no shell string); the process runs from model_training/ via
# cwd. sys.executable is the interpreter running this notebook.
cmd_baseline = [
    sys.executable, eval_script,
    "--model_path", model_path_baseline,
    "--data_dir", data_dir,
    "--csv_path", csv_path,
    "--eval_type", "val",
    "--gpu_number", "0",
]
print("--- Evaluating BASELINE AdamW Model (trained 10k batches) ---")
result = subprocess.run(
    cmd_baseline,
    cwd="/kaggle/working/nejm-brain-to-text/model_training/",
    check=False,
)
# Surface failures instead of ignoring the exit status.
if result.returncode != 0:
    print(f"--- Baseline evaluation FAILED with exit code {result.returncode}. ---")

# Train SGD with Different Learning-Rate Schedulers

Trial 5 (Best): PER = 0.2363 Params: lr=0.0028, momentum=0.983, scheduler=linear

Trial 0: PER = 0.2367 Params: lr=0.0031, momentum=0.976, scheduler=linear

Trial 16: PER = 0.2406 Params: lr=0.0030, momentum=0.976, scheduler=linear

Trial 13: PER = 0.2394 Params: lr=0.0028, momentum=0.981, scheduler=linear

Trial 1: PER = 0.2493 Params: lr=0.0026, momentum=0.979, scheduler=linear

Trial 15: PER = 0.2523 Params: lr=0.0030, momentum=0.970, scheduler=step, step_size=800

Trial 14: PER = 0.2573 Params: lr=0.0027, momentum=0.986, scheduler=linear

Trial 10: PER = 0.2796 Params: lr=0.0032, momentum=0.990, scheduler=step, step_size=800

Trial 6: PER = 0.2815 Params: lr=0.0028, momentum=0.989, scheduler=linear

Trial 11: PER = 0.2825 Params: lr=0.0032, momentum=0.990, scheduler=linear

Trial 7: PER = 0.2836 Params: lr=0.0031, momentum=0.988, scheduler=cosine

Trial 19: PER = 0.2872 Params: lr=0.0029, momentum=0.977, scheduler=cosine

Trial 4: PER = 0.2970 Params: lr=0.0029, momentum=0.974, scheduler=step, step_size=500

Trial 8: PER = 0.2415 Params: lr=0.0032, momentum=0.985, scheduler=linear

Trial 9: PER = 0.3084 Params: lr=0.0026, momentum=0.976, scheduler=step, step_size=500

Trial 2: PER = 0.2803 Params: lr=0.0028, momentum=0.990, scheduler=cosine

Trial 12: PER = 0.2387 Params: lr=0.0030, momentum=0.974, scheduler=linear

Trial 18: PER = 0.3793 Params: lr=0.0031, momentum=0.979, scheduler=step, step_size=300

Trial 17: PER = 0.2461 Params: lr=0.0027, momentum=0.984, scheduler=linear

Trial 3: PER = 0.2415 Params: lr=0.0027, momentum=0.978, scheduler=linear

Trained the full model with Trial 5's settings for 10k batches: Final Best (min) PER: 0.14324022829532623 (phoneme accuracy was not evaluated because the session timed out).

(vs baseline model with 10k batches: Final Best (min) PER: 0.12702937424182892; Average trial-level phoneme accuracy: 0.8778)

In [None]:
%%writefile /kaggle/working/nejm-brain-to-text/model_training/train_model_optimize_sgd.py

import optuna
from omegaconf import OmegaConf
from rnn_trainer import BrainToTextDecoder_Trainer
import os
import numpy as np
import torch
import shutil

# --- 1. Define the Objective Function for Optuna ---
def build_sgd_param_groups(model, lr, weight_decay, day_cfg):
    """Split a model's trainable parameters into SGD parameter groups.

    Grouping is by parameter name:
      * names containing 'gru.bias' or 'out.bias' -> ``lr``, no weight decay;
      * other names containing 'day_'              -> ``day_cfg.lr_max_day`` /
        ``day_cfg.weight_decay_day``;
      * everything else                            -> ``lr`` / ``weight_decay``.
    Parameters with ``requires_grad=False`` are skipped.

    Args:
        model: a ``torch.nn.Module``.
        lr: learning rate for the bias and "other" groups.
        weight_decay: weight decay for the "other" group.
        day_cfg: object exposing ``lr_max_day`` and ``weight_decay_day``
            (normally the run's args); only read when trainable 'day_'
            parameters exist.

    Returns:
        ``[bias, day, other]`` group dicts, or ``[bias, other]`` when there
        are no trainable 'day_' parameters.
    """
    bias_params, day_params, other_params = [], [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if 'gru.bias' in name or 'out.bias' in name:
            bias_params.append(param)
        elif 'day_' in name:
            day_params.append(param)
        else:
            other_params.append(param)

    groups = [{'params': bias_params, 'weight_decay': 0, 'lr': lr}]
    if day_params:
        # Day-specific parameters keep their own LR / weight decay from args.
        groups.append({'params': day_params,
                       'lr': day_cfg.lr_max_day,
                       'weight_decay': day_cfg.weight_decay_day})
    groups.append({'params': other_params, 'lr': lr, 'weight_decay': weight_decay})
    return groups


def build_lr_scheduler(trainer, args, scheduler_type, step_size=None, gamma=0.1):
    """Create the trial's LR scheduler, bound to ``trainer.optimizer``.

    Args:
        trainer: object with ``optimizer`` and ``create_cosine_lr_scheduler``.
        args: run args; ``lr_min``, ``lr_max`` and ``lr_decay_steps`` are read
            for the linear schedule.
        scheduler_type: 'linear', 'cosine' or 'step'.
        step_size, gamma: StepLR settings (used only for 'step').

    Raises:
        ValueError: for an unknown ``scheduler_type``.
    """
    if scheduler_type == 'linear':
        print("Using LinearLR scheduler.")
        return torch.optim.lr_scheduler.LinearLR(
            optimizer=trainer.optimizer,
            start_factor=1.0,
            end_factor=args.lr_min / args.lr_max,
            total_iters=args.lr_decay_steps,  # Uses default decay steps
        )
    if scheduler_type == 'cosine':
        print("Using CosineLR scheduler.")
        return trainer.create_cosine_lr_scheduler(trainer.optimizer)
    if scheduler_type == 'step':
        print(f"Using StepLR scheduler with step_size={step_size}, gamma={gamma}.")
        return torch.optim.lr_scheduler.StepLR(
            optimizer=trainer.optimizer,
            step_size=step_size,
            gamma=gamma,
        )
    raise ValueError(f"Unknown scheduler type: {scheduler_type!r}")


def objective(trial):
    """Optuna objective: short SGD run with a sampled LR scheduler.

    Samples ``lr``, ``momentum`` and the scheduler type (plus ``step_size``
    for StepLR); weight decay stays at ``args.weight_decay`` and Nesterov is
    always on. Trains 1000 batches in fresh per-trial directories and returns
    the minimum of ``train_stats['val_PERs']`` (1.0 if none). Lower is better.
    """
    args_path = 'rnn_args.yaml'
    if not os.path.exists(args_path):
        print(f"Warning: '{args_path}' not found. Using pretrained model's args as template.")
        args_path = "/kaggle/input/brain-to-text-25/t15_pretrained_rnn_baseline/t15_pretrained_rnn_baseline/checkpoint/args.yaml"
    args = OmegaConf.load(args_path)
    args.dataset.dataset_dir = "/kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final"
    args.dataset.csv_path = "/kaggle/working/nejm-brain-to-text/data/t15_copyTaskData_description.csv"

    # Each trial writes to its own sub-directory, wiped before training.
    trial_id_str = f"trial_{trial.number}"
    args.output_dir = os.path.join(args.output_dir, trial_id_str)
    args.checkpoint_dir = os.path.join(args.checkpoint_dir, trial_id_str)
    if os.path.exists(args.output_dir):
        print(f"Removing existing output directory: {args.output_dir}")
        shutil.rmtree(args.output_dir)
    if os.path.exists(args.checkpoint_dir):
        print(f"Removing existing checkpoint directory: {args.checkpoint_dir}")
        shutil.rmtree(args.checkpoint_dir)

    # --- Suggest focused SGD hyperparameters ---
    lr = trial.suggest_float("lr", 0.0026, 0.0032, log=False)  # Narrowed LR range
    momentum = trial.suggest_float("momentum", 0.97, 0.99)  # Narrowed momentum range
    weight_decay = args.weight_decay  # Use default weight decay
    nesterov = True

    scheduler_type = trial.suggest_categorical("scheduler", ["linear", "cosine", "step"])
    # Always bound, so the StepLR settings are never referenced undefined.
    step_size, gamma = None, 0.1
    if scheduler_type == "step":
        # step_size must be < num_training_batches to have an effect
        step_size = trial.suggest_int("step_size", 300, 800, step=100)

    args.lr_max = lr
    args.num_training_batches = 1000  # Keep training short for testing
    # NOTE(review): linear/cosine schedules use lr_decay_steps (and any warmup)
    # from the YAML; if those are much longer than 1000 batches, they barely
    # progress within a trial, which biases the scheduler comparison toward
    # StepLR — confirm against args.yaml.

    print(f"\n--- Starting Trial {trial.number} ---")
    print(f"Params: optimizer=SGD, lr={lr}, momentum={momentum}, wd={weight_decay} (default), nesterov={nesterov}")
    print(f"Scheduler: {scheduler_type}")
    if scheduler_type == "step":
        print(f"StepLR Params: step_size={step_size}, gamma={gamma}")
    print(f"Training for {args.num_training_batches} batches.")
    print(f"Output dir: {args.output_dir}")

    trainer = BrainToTextDecoder_Trainer(args)

    # Replace the trainer's optimizer with Nesterov SGD (weight decay per
    # group) and build the trial's scheduler around the new optimizer.
    trainer.optimizer = torch.optim.SGD(
        build_sgd_param_groups(trainer.model, lr, weight_decay, args),
        lr=lr,
        momentum=momentum,
        nesterov=nesterov,
    )
    trainer.learning_rate_scheduler = build_lr_scheduler(
        trainer, args, scheduler_type, step_size=step_size, gamma=gamma
    )

    train_stats = trainer.train()
    val_per_list = train_stats.get('val_PERs', [])
    # len() works for both lists and arrays (array truthiness is ambiguous).
    val_score = np.min(val_per_list) if len(val_per_list) > 0 else 1.0
    print(f"Trial {trial.number} finished. Best (min) PER: {val_score}")
    return val_score
    
if __name__ == "__main__":
    # objective() returns a PER, so lower is better.
    scheduler_study = optuna.create_study(direction='minimize')
    scheduler_study.optimize(objective, n_trials=20)

    print("\n--- Optimization Finished ---")
    print("Best trial:")
    top = scheduler_study.best_trial

    print(f"  Value (Min PER): {top.value}")
    print("  Params: ")
    for setting, chosen in top.params.items():
        print(f"    {setting}: {chosen}")