In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from dataset import EEGDatasetV2
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

import polars as pl

from sklearn.model_selection import train_test_split

from tqdm.notebook import tqdm

import random
import numpy as np

# Set the seed for Python's built-in random module
random.seed(69)

# Set the seed for NumPy's random number generator
np.random.seed(69)

# Set the seed for PyTorch's random number generators
torch.manual_seed(69)
torch.cuda.manual_seed(69)
torch.cuda.manual_seed_all(69)  # if you are using multi-GPU.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


In [2]:
pl.read_parquet('/home/owner/Documents/DEV/BrainLabyrinth/data/combined_prev_prev_2.parquet')

event_id,orig_marker,time,Fp1,Fpz,Fp2,F7,F3,Fz,F4,F8,FC5,FC1,FC2,FC6,M1,T7,C3,Cz,C4,T8,M2,CP5,CP1,CP2,CP6,P7,P3,Pz,P4,P8,POz,O1,O2,AF7,AF3,AF4,AF8,F5,F1,F2,F6,FC3,FCz,FC4,C5,C1,C2,C6,CP3,CP4,P5,P1,P2,P6,PO5,PO3,PO4,PO6,FT7,FT8,TP7,TP8,PO7,PO8,Oz,prev_prev_marker,prev_marker,marker
i64,str,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,str,str,str
0,"""Stimulus/P""",22.746,-21.377546,-3.329028,-18.758145,-30.73564,-27.979262,-17.733182,-11.122367,10.858938,-44.451512,-12.26531,-11.228322,-7.163317,-14.05838,-30.227744,-31.09242,-12.310457,-37.415239,-10.066313,-17.291664,-34.362408,11.172778,-5.593829,-11.168018,-35.016279,-24.367157,-11.756309,-25.158083,-11.400707,-9.222274,-26.857691,-28.989126,-35.296483,-18.676274,-2.265123,-0.681522,-27.036101,-17.27791,-14.018948,-0.860448,-24.590269,-11.165179,-15.876778,-15.611566,-15.470771,-11.342499,-0.354879,-47.082836,-25.75458,-23.214993,-28.41397,-32.145426,-23.09293,-18.515905,-18.41682,-2.948578,-29.370747,-31.412779,-7.822252,-27.387539,-27.833376,-17.620415,-29.850262,-16.782366,"""Left""","""Right""","""Right"""
0,"""Stimulus/P""",22.748,-21.219501,-3.171698,-18.695978,-30.899023,-27.964396,-17.633559,-10.771049,11.624404,-44.61444,-11.903594,-10.694397,-7.173872,-14.512791,-29.422429,-30.804771,-11.544276,-37.174302,-5.694034,-15.207482,-34.277941,11.258942,-5.45682,-11.079793,-34.895782,-24.831925,-11.852139,-25.596388,-11.766606,-9.404878,-27.45209,-29.890226,-35.150051,-18.198552,-1.781131,-0.387058,-26.465131,-16.855895,-13.527142,0.034495,-24.553398,-10.965987,-15.584374,-15.866586,-15.515992,-11.014299,-0.915622,-47.12713,-25.906854,-23.593264,-28.479127,-32.260117,-23.671744,-18.869104,-18.846811,-3.254924,-30.022285,-31.631908,-6.706748,-27.717578,-26.271575,-17.999255,-30.66878,-17.174672,"""Left""","""Right""","""Right"""
0,"""Stimulus/P""",22.75,-21.107591,-3.092899,-18.653458,-31.489265,-28.012768,-17.474946,-10.519029,12.088246,-45.094286,-11.5613,-10.089097,-7.148177,-15.480021,-28.509242,-30.69717,-10.76844,-36.881722,-3.919959,-13.700795,-34.221406,11.324101,-5.275087,-10.958402,-34.506983,-25.011003,-11.901826,-25.887634,-12.305291,-9.472896,-27.557485,-30.611738,-35.342925,-17.875566,-1.327837,-0.237936,-25.884175,-16.47395,-13.142645,0.856802,-24.602636,-10.661366,-15.28256,-16.268992,-15.581594,-10.649388,-1.149936,-47.226449,-26.028439,-23.705772,-28.537429,-32.317398,-24.173332,-18.842379,-18.952802,-3.564849,-30.51049,-32.262253,-5.861874,-28.14821,-25.338115,-17.990954,-31.322179,-16.965819,"""Left""","""Right""","""Right"""
0,"""Stimulus/P""",22.752,-21.006886,-3.056991,-18.582953,-32.411302,-28.069638,-17.238215,-10.384327,12.253836,-45.787117,-11.235732,-9.436472,-7.044266,-16.833204,-27.524872,-30.745046,-10.0373,-36.523804,-4.812991,-12.846942,-34.155775,11.375401,-5.041548,-10.786247,-33.837596,-24.853686,-11.903567,-25.986433,-12.937813,-9.399739,-27.148928,-31.042835,-35.812823,-17.689045,-0.922058,-0.185608,-25.297602,-16.130204,-12.878529,1.541432,-24.679923,-10.259005,-14.956661,-16.743237,-15.629367,-10.255516,-0.988882,-47.336807,-26.087693,-23.51264,-28.561456,-32.291279,-24.516449,-18.415956,-18.701083,-3.820157,-30.758124,-33.167084,-5.325221,-28.585937,-25.039198,-17.584963,-31.720538,-16.191491,"""Left""","""Right""","""Right"""
0,"""Stimulus/P""",22.754,-20.872785,-3.017698,-18.431995,-33.530736,-28.07893,-16.911503,-10.372313,12.147042,-46.564886,-10.917663,-8.764207,-6.83163,-18.402727,-26.528758,-30.907434,-9.397617,-36.09857,-8.137851,-12.662949,-34.050964,11.419024,-4.761436,-10.570737,-32.913805,-24.350841,-11.86542,-25.886998,-13.604752,-9.183589,-26.25909,-31.129443,-36.458857,-17.601949,-0.578673,-0.170892,-24.714531,-15.811797,-12.733076,2.042612,-24.727222,-9.773674,-14.59811,-17.214512,-15.624975,-9.844416,-0.423997,-47.41706,-26.072529,-23.013341,-28.529877,-32.175476,-24.660389,-17.617972,-18.101536,-3.988149,-30.737496,-34.17095,-5.110807,-28.945567,-25.33534,-16.816198,-31.824498,-14.951163,"""Left""","""Right""","""Right"""
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
2751,"""Stimulus/A""",956.836,-18.873559,-24.227101,-13.486943,-0.403397,-11.98354,-16.557347,-17.518873,-14.536315,-6.105563,-10.793705,-4.853747,-7.456738,10.414136,-2.954412,-5.026366,5.687796,-2.821944,3.946126,4.366952,-0.86041,-2.708516,3.370458,7.267404,2.82558,2.815112,-1.849568,9.221112,7.698055,2.700318,3.765467,4.554788,6.173636,-7.883982,-18.15556,-1.624224,-10.617139,-18.317587,-17.6747,-16.561863,-4.705041,-4.589931,1.813413,-4.839258,-4.689196,0.565952,-2.825098,-4.87889,4.368522,-1.258749,4.727739,7.158111,5.224539,2.009805,2.215446,9.938263,10.332691,0.287628,-6.543679,5.453187,0.753634,3.273086,10.787134,5.02536,"""Left""","""Right""","""Left"""
2751,"""Stimulus/A""",956.838,-18.704103,-24.969565,-14.585048,-1.047141,-11.943701,-16.45773,-17.959561,-15.664664,-7.082881,-11.007107,-5.073525,-8.186207,8.59455,-4.468591,-5.374392,5.653907,-2.543513,2.916711,1.86508,-2.226537,-2.766756,3.701948,7.189953,0.664609,2.01553,-1.72148,9.666301,7.382395,2.65323,2.507406,4.086716,5.953678,-7.897042,-18.654489,-2.907517,-10.716993,-18.135785,-18.222283,-17.620759,-5.103536,-4.95171,1.615306,-5.864539,-4.802275,0.860669,-3.010969,-5.317478,4.972083,-2.620815,4.514949,7.682226,5.731373,0.435993,0.715036,10.155613,10.351784,-0.944356,-7.749362,3.527842,-0.227262,1.671753,10.779771,4.207392,"""Left""","""Right""","""Left"""
2751,"""Stimulus/A""",956.84,-18.061667,-25.302962,-15.126122,-1.543149,-11.638041,-16.168372,-18.234343,-16.689475,-7.844465,-11.080084,-5.22795,-8.864504,6.983161,-5.717354,-5.553369,5.590639,-2.268625,1.711394,-0.604882,-3.276609,-2.707251,3.97878,6.988798,-1.132993,1.500356,-1.638895,10.013091,6.84536,2.529954,1.345095,3.459218,5.952675,-7.653053,-18.907577,-3.972871,-10.577918,-17.735399,-18.616241,-18.516636,-5.395328,-5.275623,1.45044,-6.564501,-4.839488,1.135505,-3.189045,-5.473588,5.565337,-3.532208,4.463409,8.272403,6.121259,-0.84287,-0.486312,10.264764,10.23355,-2.014752,-9.047787,1.966788,-1.368454,0.333598,10.611117,3.436481,"""Left""","""Right""","""Left"""
2751,"""Stimulus/A""",956.842,-17.023192,-25.220666,-15.079919,-1.852834,-11.10314,-15.743369,-18.338312,-17.531294,-8.309149,-11.006394,-5.293303,-9.428791,5.715629,-6.592615,-5.547356,5.501853,-2.008694,0.434555,-2.88224,-3.923841,-2.530573,4.190703,6.70415,-2.430261,1.293626,-1.613443,10.266805,6.137341,2.336787,0.363123,2.739203,6.169374,-7.181344,-18.903561,-4.734289,-10.219783,-17.170485,-18.840451,-19.18934,-5.546836,-5.528552,1.357458,-6.879141,-4.784823,1.373481,-3.355288,-5.336791,6.114987,-3.940727,4.575404,8.896693,6.400383,-1.731706,-1.306744,10.309981,10.048173,-2.819302,-10.318089,0.896298,-2.548241,-0.638066,10.358187,2.779875,"""Left""","""Right""","""Left"""


In [3]:
config = {
    'data_path': '/home/owner/Documents/DEV/BrainLabyrinth/data/combined_prev_prev_2.parquet',
    'batch_size': 32,
    'input_size': 65, # Number of features
    'hidden_size': 32, # Number of features in the hidden state
    'num_layers': 1, # Number of recurrent layers
    'output_size': 1, # Number of output classes (binary classification)
    'learning_rate': 5e-4,
    'epochs': 300,
    'bidirectional': False,
    'dropout': 0.5,
    'log_dir': './runs/RNN',
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
}

In [4]:
train_set = torch.load('train_set.pt', weights_only=False)
val_set = torch.load('val_set.pt', weights_only=False)
test_set = torch.load('test_set.pt', weights_only=False)

def collate_fn(batch):
    """
    Collate function for variable-length EEG feature sequences.

    Each sample is expected to be a tuple (label, feature), where:
    - label is a scalar tensor (or 1D tensor) representing the class/target.
    - feature is a tensor of shape (seq_len, num_channels), where seq_len may vary.

    This function stacks labels and pads features along the time dimension so that
    all sequences in the batch have the same length.
    """
    # Unzip the batch into labels and features
    labels, features, original_labels = zip(*batch)
    
    labels = torch.stack(labels)
    padded_features = pad_sequence(features, batch_first=True)
    original_labels = torch.stack(original_labels)
    
    return labels, padded_features, original_labels


generator = torch.Generator().manual_seed(69)  # Set seed
train_loader = DataLoader(
    train_set,
    batch_size=config['batch_size'],
    shuffle=True,
    generator=generator,  # Add this line
    num_workers=0,
    pin_memory=True,
    # persistent_workers=True,
    collate_fn=collate_fn
)
val_loader = DataLoader(val_set, batch_size=config['batch_size'], collate_fn=collate_fn)
test_loader = DataLoader(test_set, batch_size=config['batch_size'], collate_fn=collate_fn)

len_dataset = len(train_set)
sample = train_set[0]
label_shape = sample[0].shape
feature_shape = sample[1].shape

print(f"train dataset shape: ({len_dataset}, [labels: {label_shape}, features: {list(feature_shape)}])")


# Set up logging
writer = SummaryWriter(log_dir=config['log_dir'])

train dataset shape: (1926, [labels: torch.Size([]), features: [2000, 65]])


In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from torch.optim.lr_scheduler import ReduceLROnPlateau

# torch.cuda.empty_cache() 
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'



class RNNClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size, dropout=0.5, bidirectional=False):
        super(RNNClassifier, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.prelayer_norm = nn.BatchNorm2d(num_features=input_size)
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=bidirectional, dropout=dropout)

        self.layer_norm = nn.LayerNorm(hidden_size)
        self.l_relu = nn.LeakyReLU(hidden_size)
        
        self.dropout = nn.Dropout(dropout)

        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = x.permute(0, 2, 1).unsqueeze(-1)  # Reshape to [batch_size, num_features, seq_len, 1]
        x = self.prelayer_norm(x)
        x = x.squeeze(-1).permute(0, 2, 1)  # Reshape back to [batch_size, seq_len, num_features]

        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)

        out, _ = self.lstm(x, (h0, c0))

        out = self.layer_norm(out)
        out = self.l_relu(out)
        
        out = self.dropout(out)
        out = self.fc(out[:, -1, :])
        return out

# Initialize the model, loss function, and optimizer
model = RNNClassifier(config['input_size'], config['hidden_size'], config['num_layers'], config['output_size'], dropout=config['dropout']).to(config['device'])

# Initialize the loss function with class weights
criterion = torch.nn.BCEWithLogitsLoss()

optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])

# Define the ReduceLROnPlateau scheduler
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=20, threshold=0.001, threshold_mode='abs')

# Initialize variables to track the best validation loss
best_metric = 0

print("Training start")
# Training loop
for epoch in tqdm(range(config['epochs']), desc="Training"):
    # ---------- TRAIN ----------
    model.train()
    train_loss = 0.0
    
    for labels, features, _ in train_loader:
        features = features.to(config['device']).float()
        labels = labels.to(config['device']).float().unsqueeze(1)
        
        optimizer.zero_grad()
        outputs = model(features)
        loss = criterion(outputs, labels)
        loss.backward()
        
        # Gradient clipping (if specified)
        if config.get('grad_clip') is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), config['grad_clip'])
        
        optimizer.step()
        train_loss += loss.item()
    
    train_loss /= len(train_loader)
    
    # ---------- VALIDATION ----------
    model.eval()
    val_loss = 0.0
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for labels, features, _ in val_loader:
            features = features.to(config['device']).float()
            labels = labels.to(config['device']).float().unsqueeze(1)            
            outputs = model(features)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            
            preds = torch.sigmoid(outputs)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    val_loss /= len(val_loader)
    predictions = (np.array(all_preds) > 0.5).astype(int)
    
    # ---------- METRICS ----------
    accuracy = accuracy_score(all_labels, predictions)
    precision = precision_score(all_labels, predictions)
    recall = recall_score(all_labels, predictions)
    f1 = f1_score(all_labels, predictions)
    
    # ---------- SCHEDULER UPDATE ----------
    current_lr = optimizer.param_groups[0]['lr']
    
    if scheduler is not None:
        scheduler.step(val_loss)
    
    # ---------- LOGGING ----------
    writer.add_scalar('LR', current_lr, epoch)
    writer.add_scalar('Loss/Train', train_loss, epoch)
    writer.add_scalar('Loss/Val', val_loss, epoch)
    writer.add_scalar('Accuracy', accuracy, epoch)
    writer.add_scalar('Precision', precision, epoch)
    writer.add_scalar('Recall', recall, epoch)
    writer.add_scalar('F1', f1, epoch)
    
    # You can also combine them in a single dictionary
    metrics = {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1
    }
    writer.add_scalars('Metrics', metrics, epoch)
    
    # ---------- SAVE BEST MODEL ----------
    if accuracy > best_metric:
        best_metric = accuracy
        torch.save(model.state_dict(), f"{config['log_dir']}/best_model.pth")

writer.close()




Training start


Training:   0%|          | 0/300 [00:00<?, ?it/s]

In [7]:
f"{config['log_dir']}/best_model.pth"

'./runs/RNN/best_model.pth'

In [9]:
import torch
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

best_model = RNNClassifier(config['input_size'], config['hidden_size'], config['num_layers'], config['output_size'], dropout=config['dropout']).to(config['device'])

# Load the state dictionary
state_dict = torch.load(f"{config['log_dir']}/best_model.pth", map_location=config['device'])
best_model.load_state_dict(state_dict)

# Move model to the correct device
best_model = best_model.to(config['device'])

# Set model to evaluation mode
best_model.eval()

test_loss = 0
all_test_markers = []
all_test_predictions = []
all_test_original_markers = []
with torch.no_grad():
    for markers, features, original_markers in tqdm(test_loader):
        features = features.to(config['device'])
        markers = markers.to(config['device'])

        outputs = best_model(features)

        # Collect markers and predictions for metrics calculation
        all_test_markers.extend(markers.cpu().numpy().flatten())
        all_test_predictions.extend(torch.sigmoid(outputs).cpu().numpy().flatten())
        all_test_original_markers.extend(original_markers.cpu().numpy().flatten())

# Calculate test metrics
test_accuracy = accuracy_score(all_test_markers, [1 if p > 0.5 else 0 for p in all_test_predictions])
test_precision = precision_score(all_test_markers, [1 if p > 0.5 else 0 for p in all_test_predictions])
test_recall = recall_score(all_test_markers, [1 if p > 0.5 else 0 for p in all_test_predictions])
test_f1 = f1_score(all_test_markers, [1 if p > 0.5 else 0 for p in all_test_predictions])
test_roc_auc = roc_auc_score(all_test_markers, all_test_predictions)

# Log test metrics to TensorBoard
writer.add_scalar('Metrics/test_accuracy', test_accuracy, 1)
writer.add_scalar('Metrics/test_precision', test_precision, 1)
writer.add_scalar('Metrics/test_recall', test_recall, 1)
writer.add_scalar('Metrics/test_f1', test_f1, 1)
writer.add_scalar('Metrics/test_roc_auc', test_roc_auc, 1)

# Close the TensorBoard writer
writer.close()



  0%|          | 0/13 [00:00<?, ?it/s]

In [10]:
print(f"""
{test_accuracy=}
{test_precision=}
{test_recall=}
{test_f1=}
{test_roc_auc=}
"""
)


test_accuracy=0.6473429951690821
test_precision=0.5989847715736041
test_recall=0.6378378378378379
test_f1=0.6178010471204188
test_roc_auc=np.float64(0.6705771273456863)



In [11]:
from sklearn.metrics import f1_score
import numpy as np
best_threshold = 0.0
best_f1 = 0.0
thresholds = np.arange(0.1, 1.0, 0.01)

for threshold in tqdm(thresholds):
    binary_predictions = (all_test_predictions > threshold).astype(int)
    current_recall = f1_score(all_test_markers, binary_predictions)

    if current_recall > best_f1:
        best_f1 = current_recall
        best_threshold = threshold

print(f"{best_threshold=}")
print(f"{best_f1=}")

  0%|          | 0/90 [00:00<?, ?it/s]

best_threshold=np.float64(0.33999999999999986)
best_f1=0.6487603305785123


In [12]:
from sklearn.metrics import accuracy_score
import numpy as np
best_threshold = 0.1
best_accuracy = 0.0
thresholds = np.arange(0.005, 1.0, 0.005)

for threshold in tqdm(thresholds):
    binary_predictions = (all_test_predictions > threshold).astype(int) 
    current_recall = accuracy_score(all_test_markers, binary_predictions)

    if current_recall > best_accuracy:
        best_accuracy = current_recall
        best_threshold = threshold
        precision = precision_score(all_test_markers, [1 if p > threshold else 0 for p in all_test_predictions])
        recall = recall_score(all_test_markers, [1 if p > threshold else 0 for p in all_test_predictions])
        f1 = f1_score(all_test_markers, [1 if p > threshold else 0 for p in all_test_predictions])
        roc_auc = roc_auc_score(all_test_markers, all_test_predictions)

print(f"{best_threshold=}")
print(f"""
{best_accuracy=}
{precision=}
{recall=}
{f1=}
{roc_auc=}
""")

  0%|          | 0/199 [00:00<?, ?it/s]

best_threshold=np.float64(0.515)

best_accuracy=0.6521739130434783
precision=0.6073298429319371
recall=0.6270270270270271
f1=0.6170212765957447
roc_auc=np.float64(0.6705771273456863)



In [13]:
from collections import Counter

statuses = []
test_rightness =  all_test_markers == np.array([1 if p > best_threshold else 0 for p in all_test_predictions])
for original_marker, positive_verdict in zip(all_test_original_markers, test_rightness):
    if original_marker == 1:
        if positive_verdict:
            statuses.append('Stimulus/P right')
        else:
            statuses.append('Stimulus/P wrong')
    else:
        if positive_verdict:
            statuses.append('Stimulus/A right')
        else:
            statuses.append('Stimulus/A wrong')


results_counter = Counter(statuses)



# --- Extract counts ---
# Use .get() for safety in case a key is missing (though not in this example)
stim_A_right = results_counter.get('Stimulus/A right', 0)
stim_A_wrong = results_counter.get('Stimulus/A wrong', 0)
stim_P_right = results_counter.get('Stimulus/P right', 0)
stim_P_wrong = results_counter.get('Stimulus/P wrong', 0)

# --- Calculate totals ---
total_A_events = stim_A_right + stim_A_wrong
total_P_events = stim_P_right + stim_P_wrong

total_right_predictions = stim_A_right + stim_P_right
total_wrong_predictions = stim_A_wrong + stim_P_wrong
total_events = total_right_predictions + total_wrong_predictions

# --- Calculate Metrics ---

# 1. Overall Accuracy of the Rules
# (How often did the rule correctly predict the outcome?)
overall_accuracy = total_right_predictions / total_events if total_events > 0 else 0.0

# 2. Accuracy of Rule A (Flip)
# (When Stimulus/A was shown, how often was flipping the correct strategy?)
accuracy_A = stim_A_right / total_A_events if total_A_events > 0 else 0.0

# 3. Accuracy of Rule P (Persist)
# (When Stimulus/P was shown, how often was persisting the correct strategy?)
accuracy_P = stim_P_right / total_P_events if total_P_events > 0 else 0.0

# --- Print Results ---
print("--- Metrics Based on Rule Application Success ---")
print(f"Total Events Analyzed (A + P): {total_events}")
print("-" * 20)
print(f"Stimulus/A Events: {total_A_events}")
print(f"  - Rule 'Flip' Correct: {stim_A_right}")
print(f"  - Rule 'Flip' Incorrect: {stim_A_wrong}")
print(f"  - Accuracy of 'Flip' Rule: {accuracy_A:.4f}")
print("-" * 20)
print(f"Stimulus/P Events: {total_P_events}")
print(f"  - Rule 'Persist' Correct: {stim_P_right}")
print(f"  - Rule 'Persist' Incorrect: {stim_P_wrong}")
print(f"  - Accuracy of 'Persist' Rule: {accuracy_P:.4f}")
print("-" * 20)
print(f"Overall Accuracy (Rule matched outcome): {overall_accuracy:.4f}")
print("-" * 20)

print("\nNote:")
print("These metrics evaluate the success rate of the simple 'Flip'/'Persist' heuristics.")
print("They are NOT standard classification metrics like Precision, Recall, or F1-Score ")

--- Metrics Based on Rule Application Success ---
Total Events Analyzed (A + P): 414
--------------------
Stimulus/A Events: 268
  - Rule 'Flip' Correct: 249
  - Rule 'Flip' Incorrect: 19
  - Accuracy of 'Flip' Rule: 0.9291
--------------------
Stimulus/P Events: 146
  - Rule 'Persist' Correct: 21
  - Rule 'Persist' Incorrect: 125
  - Accuracy of 'Persist' Rule: 0.1438
--------------------
Overall Accuracy (Rule matched outcome): 0.6522
--------------------

Note:
These metrics evaluate the success rate of the simple 'Flip'/'Persist' heuristics.
They are NOT standard classification metrics like Precision, Recall, or F1-Score 
