In [4]:
import os
import pandas as pd
import json
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from sklearn.utils import resample
from sklearn.metrics import (
    accuracy_score, confusion_matrix, classification_report,
    balanced_accuracy_score, f1_score  # <<< NEW >>>
)
from utils import utils

In [5]:
path1 = '/mnt/teams/TM_Lab/Tony/wr_new/data_used/tta_gcamp8s/hand_scored'
path2 = '/mnt/teams/TM_Lab/Tony/wr_new/data_used/tta_gcamp8s/gemini_predictions_full_sys'

# --- Create the first dictionary for the 'hand_scored' files ---
print("Loading files from 'hand_scored' directory...")
hand_scored_data = {} # Initialize an empty dictionary
# Loop through every file in the first directory
for filename in os.listdir(path1):
    # We only care about CSV files
    if filename.endswith('.csv'):
        # Create the full path to the file
        file_path = os.path.join(path1, filename)
        try:
            # Use the filename as the key and the loaded CSV data (a DataFrame) as the value
            hand_scored_data[filename] = pd.read_csv(file_path)
            #print(f"  - Loaded: {filename}")
        except Exception as e:
            print(f"  - FAILED to load {filename}. Error: {e}")
print("Loading files from 'gemini_predictions' directory...")
gemini_predictions_data = {} # Initialize the second empty dictionary

# Loop through every file in the second directory
for filename in os.listdir(path2):
    if filename.endswith('.csv'):
        file_path = os.path.join(path2, filename)
        try:
            gemini_predictions_data[filename] = pd.read_csv(file_path)
            #print(f"  - Loaded: {filename}")
        except Exception as e:
            print(f"  - FAILED to load {filename}. Error: {e}")

Loading files from 'hand_scored' directory...
Loading files from 'gemini_predictions' directory...


In [6]:
# --- 2. Simplify ALL Hand-Scored CSVs ---

print("Processing 'hand_scored_data' to simplify outcomes...")
hand_scored_mapping = {
    'ps': 's',
    'rnd': 'f'
}
for filename, df in hand_scored_data.items():
    if 'outcome' in df.columns:
        df['outcome'] = df['outcome'].replace(hand_scored_mapping)
    else:
        print(f"  - WARNING: 'outcome' column not found in {filename}")
print("Simplification of all hand-scored data is complete.\n")

for filename, df in gemini_predictions_data.items():
    if 'outcome_classification' in df.columns:
        df['outcome_classification'] = df['outcome_classification'].replace(hand_scored_mapping)
    else:
        print(f"  - WARNING: 'outcome' column not found in {filename}")
print("Simplification of all hand-scored data is complete.\n")

# --- 3. Simplify ALL Gemini CSVs ---
'''
print("Processing 'gemini_predictions_data' based on 'percentage_consumed'...")
for filename, df in gemini_predictions_data.items():
    if 'percentage_consumed' in df.columns and 'outcome_classification' in df.columns:
        df['outcome_classification'] = 'f'
        df.loc[df['percentage_consumed'] >= 30, 'outcome_classification'] = 's'
        df.loc[df['percentage_consumed'].isna(), 'outcome_classification'] = np.nan
    else:
        print(f"  - WARNING: Required columns not found in {filename}")
print("Simplification of all Gemini data is complete.")
'''
print("Processing 'hand_scored_data' to simplify outcomes...")
hand_scored_mapping = {
    'ps': 's',
    'rnd': 'f'
}
for filename, df in hand_scored_data.items():
    if 'outcome' in df.columns:
        df['outcome'] = df['outcome'].replace(hand_scored_mapping)
    else:
        print(f"  - WARNING: 'outcome' column not found in {filename}")
print("Simplification of all hand-scored data is complete.\n")

for filename, df in gemini_predictions_data.items():
    if 'outcome' in df.columns:
        df['outcome'] = df['outcome'].replace(hand_scored_mapping)
    else:
        print(f"  - WARNING: 'outcome' column not found in {filename}")
print("Simplification of all hand-scored data is complete.\n")

Processing 'hand_scored_data' to simplify outcomes...
Simplification of all hand-scored data is complete.

Simplification of all hand-scored data is complete.

Processing 'hand_scored_data' to simplify outcomes...
Simplification of all hand-scored data is complete.

Simplification of all hand-scored data is complete.



In [7]:
output_csv_filename = 'results.csv'
sessions = [i for i in  hand_scored_data.keys() if i in gemini_predictions_data.keys()]

all_results = utils.process_all_sessions(
    sessions=sessions,
    hand_scored_data=hand_scored_data,
    gemini_predictions_data=gemini_predictions_data,
    hand_col='outcome_x',
    gemini_col='outcome_y'
)


üöÄ Starting batch processing of 29 sessions...


[1/29] Processing K_R3_2025-02-04_1.csv...

PROCESSING: K_R3_2025-02-04_1.csv

[Step 1] Validating and merging data...
  Total merged rows: 120
  Removing 41 rows with invalid hand-scored values
  Removing 120 rows with invalid Gemini predictions
  ‚ùå No valid data after cleaning

[2/29] Processing FJ_R3_2024-07-15_1.csv...

PROCESSING: FJ_R3_2024-07-15_1.csv

[Step 1] Validating and merging data...
  Total merged rows: 120
  Removing 114 rows with invalid hand-scored values
  Removing 120 rows with invalid Gemini predictions
  ‚ùå No valid data after cleaning

[3/29] Processing AZ_R2_2024-12-14_1.csv...

PROCESSING: AZ_R2_2024-12-14_1.csv

[Step 1] Validating and merging data...
  Total merged rows: 120
  Removing 77 rows with invalid hand-scored values
  Removing 120 rows with invalid Gemini predictions
  ‚ùå No valid data after cleaning

[4/29] Processing K_R3_2025-02-21_1.csv...

PROCESSING: K_R3_2025-02-21_1.csv

[Step 1] Valida

In [49]:
utils.save_results(all_results, output_csv_filename)



‚úÖ Successfully saved 29 results to 'results.csv'


True