In [None]:
from datetime import datetime, timedelta
from sklearn.model_selection import cross_validate, StratifiedKFold
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (accuracy_score, precision_score, recall_score,
                            f1_score, roc_auc_score, confusion_matrix,
                            classification_report, make_scorer)
from sklearn.pipeline import Pipeline
import seaborn as sns
from sklearn.model_selection import LeaveOneGroupOut
from sklearn.metrics import (accuracy_score, precision_score, recall_score,
                            f1_score, roc_auc_score, confusion_matrix,
                            classification_report, roc_curve, precision_recall_curve)
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from datetime import datetime, timedelta
import os
import glob

# Time of conflict for different participants (Task 2 - conflict task).
# Keys are integer participant ids; each value is a list of clock times
# ('H:MM:SS', hour not zero-padded) at which a conflict was recorded.
# is_conflict() treats participants missing from this dict as never in conflict.
toc = {
    4: ['15:36:00'], 5: ['18:33:00', '18:44:00'], 6: ['14:57:00'],
    7: ['16:02:00'], 8: ['17:04:00'], 12: ['17:01:00'], 13: ['9:49:00'],
    14: ['11:18:00'], 16: ['14:14:00'], 17: ['15:42:00'],
    18: ['11:22:00', '11:28:00'], 19: ['12:53:00', '13:01:00', '13:02:00'],
    20: ['16:24:00', '16:29:00'], 23: ['12:50:00', '12:55:00'],
    24: ['14:50:00', '14:53:00'], 25: ['16:18:00', '16:20:00', '16:23:00'],
    26: ['10:09:00', '10:16:00'], 27: ['11:14:00', '11:22:00'],
    29: ['13:38:00', '13:43:00']
}

# No-conflict time for Task 1: one clock time ('H:MM:SS') per participant id.
# NOTE(review): participant 10 is listed as '5:45:00' while every other entry
# falls between 9:00 and 19:00 - possibly a 12-hour value (17:45); confirm.
mtt1 = {
    1: '10:34:00', 2: '11:52:00', 3: '14:10:00', 4: '15:24:00',
    5: '18:21:00', 6: '14:45:00', 7: '15:45:00', 8: '16:51:00',
    9: '11:29:00', 10: '5:45:00', 11: '15:00:00', 12: '16:45:00',
    13: '9:34:00', 14: '11:04:00', 15: '12:30:00', 16: '13:53:00',
    17: '15:09:00', 18: '11:09:00', 19: '12:32:00', 20: '16:16:00',
    21: '15:21:00', 23: '12:34:00', 24: '14:34:00', 25: '16:06:00',
    26: '9:56:00', 27: '10:56:00', 28: '12:00:00', 29: '13:22:00',
    30: '15:06:00'
}

# Channel mapping: these numbers are the suffix of the per-channel feature
# files ('features_<task>_<channel>.csv').
# NOTE(review): the electrode list below names P4 twice; one of the two is
# presumably a typo - confirm against the recording montage.
CHANNELS = [1, 4, 5, 6, 7, 8]  # FP2, F3, P4, P3, CZ, P4

# strptime/strftime pattern matching the clock-time strings used above.
time_format = '%H:%M:%S'

def time_to_seconds(time_str):
    """Convert an ``H:MM:SS`` / ``HH:MM:SS`` string to seconds from midnight.

    Args:
        time_str: clock time such as ``'9:49:00'`` or ``'15:36:00'``.

    Returns:
        Integer number of seconds since midnight, or ``None`` when
        ``time_str`` is not a string or does not start with three
        colon-separated integer fields.
    """
    try:
        parts = time_str.split(':')
        h, m, s = int(parts[0]), int(parts[1]), int(parts[2])
    except (AttributeError, IndexError, TypeError, ValueError):
        # AttributeError/TypeError: input is not a str (None, NaN, bytes, ...);
        # IndexError: fewer than three fields; ValueError: non-integer field.
        # Anything else (e.g. KeyboardInterrupt) is deliberately not swallowed.
        return None
    return h * 3600 + m * 60 + s

def is_conflict(ftime_str, participant_id, window_seconds=60):
    """
    Label a timestamp as conflict (1) or no conflict (0).

    A timestamp counts as conflict when it lies within ±window_seconds of at
    least one conflict time recorded for the participant in ``toc``.

    Args:
        ftime_str: timestamp string (HH:MM:SS)
        participant_id: participant number (key into ``toc``)
        window_seconds: half-width of the window in seconds (default 60)

    Returns:
        1 if conflict, 0 otherwise (including unknown participants and
        timestamps that cannot be parsed)
    """
    conflict_times = toc.get(participant_id)
    if not conflict_times:
        # Participant has no recorded conflict times
        return 0

    sample_seconds = time_to_seconds(ftime_str)
    if sample_seconds is None:
        return 0

    # Unparseable conflict times are skipped rather than treated as a match
    reference_seconds = (time_to_seconds(entry) for entry in conflict_times)
    inside_any_window = any(
        abs(sample_seconds - reference) <= window_seconds
        for reference in reference_seconds
        if reference is not None
    )
    return 1 if inside_any_window else 0

def load_and_merge_features():
    """
    Load the per-channel feature files of both tasks and label every row.

    For each task (1 = no conflict, 2 = conflict) and each channel in
    ``CHANNELS`` the file ``features_<task>_<channel>.csv`` is read from the
    current working directory. Missing files are reported and skipped. Each
    loaded frame gets a ``channel`` and a ``task`` column, and the combined
    frame gets a ``conflict`` column computed by ``is_conflict`` from the
    ``ftime`` and ``participant_id`` columns.

    Returns:
        DataFrame with all features and conflict labels

    Raises:
        ValueError: if none of the feature files could be found.
        KeyError: if the loaded data lacks ``ftime`` or ``participant_id``.
    """
    # (task number, banner printed before that task's files are read)
    task_banners = (
        (1, "Loading Task 1 (No Conflict) features..."),
        (2, "\nLoading Task 2 (Conflict) features..."),
    )

    all_data = []
    for task, banner in task_banners:
        print(banner)
        for channel in CHANNELS:
            file_path = f'features_{task}_{channel}.csv'
            try:
                df = pd.read_csv(file_path)
            except FileNotFoundError:
                print(f"  Warning: {file_path} not found")
                continue
            df['channel'] = channel
            df['task'] = task
            all_data.append(df)
            print(f"  Loaded {len(df)} samples from {file_path}")

    if not all_data:
        # pd.concat([]) would raise a ValueError too, but with a message that
        # does not say what actually went wrong.
        raise ValueError(
            "No feature files found (expected features_<task>_<channel>.csv "
            f"for tasks 1-2 and channels {CHANNELS})"
        )

    # Combine all data
    combined_df = pd.concat(all_data, ignore_index=True)

    missing = [col for col in ('ftime', 'participant_id')
               if col not in combined_df.columns]
    if missing:
        raise KeyError(f"Feature files lack required column(s): {missing}")

    # Add conflict labels based on timestamp. Iterating the two columns
    # directly avoids building a Series per row as DataFrame.apply(axis=1) does.
    print("\nLabeling conflict instances...")
    combined_df['conflict'] = [
        is_conflict(ftime, participant)
        for ftime, participant in zip(combined_df['ftime'],
                                      combined_df['participant_id'])
    ]

    return combined_df

def prepare_features_and_labels(df):
    """
    Build the feature matrix and label vector used for classification.

    Only the statistical feature columns that are actually present in ``df``
    are used; absent ones are silently left out. Labels come from the
    ``conflict`` column.

    Returns:
        X: feature matrix (NumPy array, one column per available feature)
        y: labels (NumPy array)
        df: the dataframe that was passed in, unchanged
    """
    # Statistical features, in the column order used for X
    wanted = (
        'power_1', 'differential_entropy1', 'mean1', 'std1',
        'skew1', 'kurtosis1', 'iqr1', 'median1',
        'hjorth_01', 'hjorth_11',
    )
    present = [name for name in wanted if name in df.columns]

    features = df[present].values
    labels = df['conflict'].values
    return features, labels, df

def train_knn_classifier_5fold(X, y):
    """
    Train k-NN classifier with 5-fold cross-validation

    Scaling is part of the pipeline, so the scaler is re-fitted on the
    training portion of every fold.

    Args:
        X: feature matrix
        y: labels

    Returns:
        Dictionary with cross-validation results, as returned by
        ``cross_validate`` (``test_<metric>`` / ``train_<metric>`` arrays for
        accuracy, precision, recall, f1 and roc_auc, plus timing entries)
    """
    # Create pipeline with scaling and k-NN
    pipeline = Pipeline([
        ('scaler', StandardScaler()),
        ('knn', KNeighborsClassifier(n_neighbors=5))
    ])

    # Define scoring metrics
    scoring = {
        'accuracy': make_scorer(accuracy_score),
        'precision': make_scorer(precision_score, zero_division=0),
        'recall': make_scorer(recall_score, zero_division=0),
        'f1': make_scorer(f1_score, zero_division=0),
        # Use the built-in scorer name: make_scorer(..., needs_proba=True) is
        # deprecated/removed in recent scikit-learn, while 'roc_auc' scores
        # the classifier's continuous output on every version.
        'roc_auc': 'roc_auc'
    }

    # 5-fold stratified cross-validation
    cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

    # Perform cross-validation
    cv_results = cross_validate(
        pipeline, X, y,
        cv=cv,
        scoring=scoring,
        return_train_score=True,
        n_jobs=-1
    )

    return cv_results

def print_cv_results(cv_results):
    """
    Print mean ± std of every cross-validation metric.

    Test scores are additionally listed per fold; train scores are only
    summarised.

    Args:
        cv_results: mapping with ``test_<metric>`` and ``train_<metric>``
            score sequences for accuracy, precision, recall, f1 and roc_auc
    """
    metric_names = ('accuracy', 'precision', 'recall', 'f1', 'roc_auc')

    def summary_line(name, scores):
        # One aligned "NAME: mean ± std" line
        return f"  {name.upper():12s}: {np.mean(scores):.3f} ± {np.std(scores):.3f}"

    print("\nTest Performance:")
    for name in metric_names:
        fold_scores = cv_results[f'test_{name}']
        print(summary_line(name, fold_scores))
        per_fold = [f'{score:.3f}' for score in fold_scores]
        print(f"    Per fold: {per_fold}")

    print("\nTrain Performance:")
    for name in metric_names:
        print(summary_line(name, cv_results[f'train_{name}']))

def train_final_model_and_evaluate(X, y):
    """
    Fit k-NN on a stratified 80/20 hold-out split and report test metrics.

    The scaler is fitted on the training part only. The confusion matrix and
    the classification report of the held-out 20% are printed.

    Args:
        X: feature matrix
        y: binary labels (0 = no conflict, 1 = conflict)
    """
    from sklearn.model_selection import train_test_split

    # Split data for final evaluation
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )

    # Train model
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    knn = KNeighborsClassifier(n_neighbors=5)
    knn.fit(X_train_scaled, y_train)

    # Predictions
    y_pred = knn.predict(X_test_scaled)

    # Pin the label order so the matrix is always 2x2 and the target names
    # always line up, even if one class is absent from the test predictions.
    class_labels = [0, 1]

    # Confusion matrix (previously computed but never displayed)
    cm = confusion_matrix(y_test, y_pred, labels=class_labels)
    print("Confusion matrix (rows = true, columns = predicted):")
    print(cm)

    print(classification_report(y_test, y_pred,
                               labels=class_labels,
                               target_names=['No Conflict', 'Conflict'],
                               zero_division=0))

# Main execution: load features, label them, run 5-fold CV and a hold-out
# evaluation, then summarise the cross-validated test scores.
if __name__ == "__main__":

    # Load and merge features
    print("\n1. Loading and merging features from all channels...")
    df = load_and_merge_features()

    print(f"\nTotal samples: {len(df)}")
    print(f"Participants: {df['participant_id'].nunique()}")
    print(f"Tasks: {sorted(df['task'].unique())}")
    print(f"Channels: {sorted(df['channel'].unique())}")

    # Prepare features and labels
    print("\n2. Preparing features and labels...")
    X, y, df_labeled = prepare_features_and_labels(df)

    print(f"\nFeature matrix shape: {X.shape}")
    print(f"Class distribution:")
    print(f"  No Conflict (0): {np.sum(y == 0)} samples ({np.sum(y == 0)/len(y)*100:.1f}%)")
    print(f"  Conflict (1):    {np.sum(y == 1)} samples ({np.sum(y == 1)/len(y)*100:.1f}%)")

    # Train and evaluate with 5-fold CV
    print("\n3. Training k-NN classifier with 5-fold cross-validation...")
    cv_results = train_knn_classifier_5fold(X, y)

    # Print results
    print_cv_results(cv_results)

    # Train final model and show confusion matrix
    print("\n4. Training final model for visualization...")
    train_final_model_and_evaluate(X, y)

    # Summarise results (mean/std of the test folds per metric).
    # NOTE: the summary is only built in memory; nothing is written to disk.
    results_dict = {
        'metric': ['accuracy', 'precision', 'recall', 'f1_score', 'roc_auc'],
        'mean': [],
        'std': []
    }

    # cross_validate names its outputs after the scoring-dict keys, so the
    # reported name 'f1_score' has to be looked up as 'test_f1'
    # (indexing 'test_f1_score' raised KeyError).
    scorer_keys = {'f1_score': 'f1'}

    for metric in results_dict['metric']:
        test_scores = cv_results[f"test_{scorer_keys.get(metric, metric)}"]
        results_dict['mean'].append(np.mean(test_scores))
        results_dict['std'].append(np.std(test_scores))




In [None]:
# Function to convert time string to datetime object
def time_to_datetime(time_str):
    """Parse an ``HH:MM:SS`` string into a datetime.

    Only the time of day is meaningful; the date part is whatever default
    ``datetime.strptime`` supplies. Raises ``ValueError`` for strings that do
    not match the pattern.
    """
    clock_pattern = '%H:%M:%S'
    parsed = datetime.strptime(time_str, clock_pattern)
    return parsed


def parse_datetime_to_seconds(datetime_str):
    """Parse datetime string and extract time in seconds from midnight

    Accepts either a bare clock time (``'H:MM:SS'``) or a value whose second
    whitespace-separated token is the clock time (``'<date> H:MM:SS'``).
    Non-string input is converted with ``str()`` first.

    Args:
        datetime_str: timestamp value to parse.

    Returns:
        Integer seconds since midnight, or ``None`` if no ``H:M:S`` triple of
        integers can be extracted (e.g. ``None``, NaN, empty string).
    """
    text = str(datetime_str).strip()

    # Split on any run of whitespace so repeated spaces between date and
    # time do not yield an empty time token.
    tokens = text.split()
    time_part = tokens[1] if len(tokens) > 1 else text

    try:
        # Too few fields and non-integer fields both surface as ValueError.
        h, m, s = (int(field) for field in time_part.split(':')[:3])
    except ValueError:
        return None
    return h * 3600 + m * 60 + s


def is_within_window(datetime_str, target_seconds, window_seconds=60):
    """Check if datetime/timestamp is within window of target time

    Args:
        datetime_str: timestamp accepted by ``parse_datetime_to_seconds``.
        target_seconds: centre of the window, in seconds from midnight.
        window_seconds: half-width of the window in seconds (inclusive).

    Returns:
        True when the timestamp lies in
        ``[target_seconds - window_seconds, target_seconds + window_seconds]``;
        False otherwise, including when the timestamp cannot be parsed or the
        bounds are not numeric (e.g. ``target_seconds`` is None).
    """
    ts_seconds = parse_datetime_to_seconds(datetime_str)
    if ts_seconds is None:
        return False
    try:
        return (target_seconds - window_seconds) <= ts_seconds <= (target_seconds + window_seconds)
    except TypeError:
        # Non-numeric target/window: report "not in window" instead of
        # failing, but let unrelated errors propagate.
        return False

def extract_features_for_window(participant_id, target_times, task_type, label,
                                window_seconds=60):
    """
    Extract features from CSV files within ±window_seconds of target times

    For every target time, the rows of each per-channel feature file whose
    ``ftime`` falls inside the window are selected, the feature columns are
    prefixed with ``file<N>_``, and the per-file selections are inner-joined
    on the timestamp. A target time only contributes samples when every
    feature file had at least one row in the window.

    Args:
        participant_id: int, participant number
        target_times: list of time strings
        task_type: 1 or 2 (for features_1_*.csv or features_2_*.csv)
        label: 0 for no conflict, 1 for conflict
        window_seconds: half-width of the window around each target time,
            in seconds (default 60, i.e. ±1 minute)

    Returns: DataFrame with features, label, participant_id and target_time,
        or None when the participant folder is missing or no target time
        produced a complete set of files
    """
    feature_file_numbers = [1, 4, 5, 6, 7, 8]
    n_expected = len(feature_file_numbers)
    key_cols = ['ftime', 'time_seconds']
    all_samples = []

    # NOTE(review): `datadir` is not defined in this function; it is read
    # from module scope and must be set before this is called.
    participant_folder = os.path.join(datadir, str(participant_id))
    if not os.path.exists(participant_folder):
        print(f"  WARNING: Folder not found for participant {participant_id}")
        return None

    # NOTE(review): target times closer together than 2 * window_seconds give
    # overlapping windows, so the same rows are emitted once per target time.
    for target_time in target_times:
        target_seconds = time_to_seconds(target_time)

        if target_seconds is None:
            print(f"  WARNING: Invalid target time format: {target_time}")
            continue

        window_start = target_seconds - window_seconds
        window_end = target_seconds + window_seconds

        # One filtered, column-prefixed frame per feature file
        file_dataframes = []

        for file_num in feature_file_numbers:
            file_path = os.path.join(participant_folder, f"features_{task_type}_{file_num}.csv")

            if not os.path.exists(file_path):
                continue

            # Best effort: a file that cannot be read or filtered is reported
            # and skipped; the target time is then treated as incomplete.
            try:
                df = pd.read_csv(file_path)

                # Use 'ftime' as timestamp column
                if 'ftime' not in df.columns:
                    print(f"  ERROR: 'ftime' column not found in {file_path}")
                    continue

                # Filter rows within the time window (bounds inclusive)
                df['time_seconds'] = df['ftime'].apply(parse_datetime_to_seconds)
                matching_rows = df[(df['time_seconds'] >= window_start) &
                                  (df['time_seconds'] <= window_end)]

                if matching_rows.empty:
                    continue

                # Prefix every non-key column with the file number so columns
                # from different files do not collide in the merge.
                feature_cols = [col for col in df.columns if col not in key_cols]
                rename_dict = {col: f"file{file_num}_{col}" for col in feature_cols}
                matching_rows = matching_rows[key_cols + feature_cols].rename(
                    columns=rename_dict)

                file_dataframes.append(matching_rows)

            except Exception as e:
                print(f"  ERROR reading {file_path}: {e}")

        if len(file_dataframes) != n_expected:
            if file_dataframes:
                print(f"  WARNING: Incomplete data for participant {participant_id}, "
                      f"time {target_time} (only {len(file_dataframes)}/{n_expected} files had matching data)")
            continue

        # Inner-join all files on both timestamp columns, so only timestamps
        # present in every file survive.
        merged_df = file_dataframes[0].copy()
        for other_df in file_dataframes[1:]:
            merged_df = merged_df.merge(other_df, on=key_cols, how='inner')

        # Add metadata columns
        merged_df['label'] = label
        merged_df['participant_id'] = participant_id
        merged_df['target_time'] = target_time

        # Sort by time_seconds to ensure chronological order
        merged_df = merged_df.sort_values('time_seconds')

        all_samples.append(merged_df)

    if all_samples:
        return pd.concat(all_samples, ignore_index=True)
    return None




In [None]:
# Split the assembled dataset into targets, CV groups and model inputs.
# NOTE(review): every column other than 'label' and 'participant_id' stays in
# X - verify that no timestamp/metadata columns remain in new_final_dataset.
y = new_final_dataset['label']
groups = new_final_dataset['participant_id']
X = new_final_dataset.drop(columns=['label', 'participant_id'])

In [None]:
# --- Sanitise X: infinities -> NaN -> column means, then clip outliers ---

# Step 1: locate infinite entries
print("\nChecking for infinity values...")
inf_mask = np.isinf(X.values)
n_inf = inf_mask.sum()
print(f"Number of infinity values: {n_inf}")

if n_inf > 0:
    # Report the affected columns, then turn ±inf into NaN so the
    # imputation step below handles them as well
    inf_cols = X.columns[inf_mask.any(axis=0)]
    print(f"Columns with infinity: {list(inf_cols)}")

    X = X.replace([np.inf, -np.inf], np.nan)
    print("✓ Replaced infinity values with NaN")

# Step 2: impute missing values with the per-column mean
# NOTE(review): the means are taken over all rows of X, i.e. before any
# train/test split - confirm this is acceptable for the evaluation below.
n_nan = X.isna().sum().sum()
print(f"\nNumber of NaN values: {n_nan}")

if n_nan > 0:
    column_means = X.mean()
    X = X.fillna(column_means)
    print("✓ Filled NaN values with column means")

# Step 3: report the overall value range
print("\nChecking for extremely large values...")
max_val = X.max().max()
min_val = X.min().min()
print(f"Maximum value in dataset: {max_val}")
print(f"Minimum value in dataset: {min_val}")

# Step 4 (optional): clip values beyond ±1e10; adjust the bounds to the data
if max_val > 1e10 or min_val < -1e10:
    print("WARNING: Extremely large values detected!")
    X = X.clip(lower=-1e10, upper=1e10)
    print("✓ Clipped extreme values")

In [None]:
# Leave-one-participant-out (LOPO) cross-validation: each fold holds out all
# samples of one participant (one group) and trains on everyone else.
logo = LeaveOneGroupOut()
n_splits = logo.get_n_splits(groups=groups)
print(f"\nNumber of LOPO folds: {n_splits}")


# Per-fold results, one entry per held-out participant
results = {
    'participant_id': [],
    'accuracy': [],
    'precision': [],
    'recall': [],
    'f1_score': [],
    'auc': [],
    'n_samples_test': [],
    'n_class0_test': [],
    'n_class1_test': []
}

# Predictions pooled over all folds, used for the overall metrics
all_y_true = []
all_y_pred = []
all_y_proba = []

# Initialize model (re-fitted in every fold)
model = KNeighborsClassifier(n_neighbors=5)

print("\n" + "="*80)
print("LOPO CROSS-VALIDATION")
print("="*80)

for fold, (train_idx, test_idx) in enumerate(logo.split(X, y, groups=groups), start=1):
    # Get train and test data
    X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
    y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]
    # All test rows share one group, so the first row identifies it
    test_participant = groups.iloc[test_idx].iloc[0]

    # Standardize features (fit on train, transform both)
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    # Train model
    model.fit(X_train_scaled, y_train)

    # Predict labels and the probability of the second class (column 1)
    y_pred = model.predict(X_test_scaled)
    y_proba = model.predict_proba(X_test_scaled)[:, 1]

    # Calculate metrics
    acc = accuracy_score(y_test, y_pred)
    prec = precision_score(y_test, y_pred, zero_division=0)
    rec = recall_score(y_test, y_pred, zero_division=0)
    f1 = f1_score(y_test, y_pred, zero_division=0)

    # AUC is undefined when the held-out participant has only one class;
    # scikit-learn signals that with ValueError. Other errors are real bugs
    # and must not be hidden.
    try:
        auc = roc_auc_score(y_test, y_proba)
    except ValueError:
        auc = np.nan

    # Store results
    results['participant_id'].append(test_participant)
    results['accuracy'].append(acc)
    results['precision'].append(prec)
    results['recall'].append(rec)
    results['f1_score'].append(f1)
    results['auc'].append(auc)
    results['n_samples_test'].append(len(y_test))
    results['n_class0_test'].append(int((y_test == 0).sum()))
    results['n_class1_test'].append(int((y_test == 1).sum()))

    # Aggregate predictions
    all_y_true.extend(y_test)
    all_y_pred.extend(y_pred)
    all_y_proba.extend(y_proba)

    # A NaN AUC is printed as 'nan'
    print(f"Fold {fold:2d} | Participant {test_participant:2d} | "
          f"Samples: {len(y_test):3d} | Acc: {acc:.3f} | "
          f"Prec: {prec:.3f} | Rec: {rec:.3f} | F1: {f1:.3f} | "
          f"AUC: {auc:.3f}")

# Convert results to DataFrame (one row per held-out participant)
results_df = pd.DataFrame(results)

print("\n" + "="*80)
print("PER-PARTICIPANT RESULTS")
print("="*80)
print(results_df.to_string(index=False))

# Calculate aggregate statistics over folds; folds with an undefined (NaN)
# metric are excluded from that metric's mean/std.
print("\n" + "="*80)
print("AGGREGATE STATISTICS (Mean ± Std)")
print("="*80)
for metric in ['accuracy', 'precision', 'recall', 'f1_score', 'auc']:
    values = results_df[metric].dropna()
    mean = values.mean()
    std = values.std()
    print(f"{metric.capitalize():15s}: {mean:.3f} ± {std:.3f}")

# Overall confusion matrix on the pooled predictions. The label order is
# pinned so the matrix is always 2x2 and the [row, col] indexing below
# cannot fail when one class never occurs.
cm = confusion_matrix(all_y_true, all_y_pred, labels=[0, 1])
print("\n" + "="*80)
print("OVERALL CONFUSION MATRIX")
print("="*80)
print(cm)
print(f"\nTrue Negatives:  {cm[0,0]}")
print(f"False Positives: {cm[0,1]}")
print(f"False Negatives: {cm[1,0]}")
print(f"True Positives:  {cm[1,1]}")

# Overall metrics on the pooled predictions (zero_division=0 matches the
# per-fold computation)
print("\n" + "="*80)
print("OVERALL PERFORMANCE METRICS")
print("="*80)
print(f"Accuracy:  {accuracy_score(all_y_true, all_y_pred):.3f}")
print(f"Precision: {precision_score(all_y_true, all_y_pred, zero_division=0):.3f}")
print(f"Recall:    {recall_score(all_y_true, all_y_pred, zero_division=0):.3f}")
print(f"F1-Score:  {f1_score(all_y_true, all_y_pred, zero_division=0):.3f}")
print(f"AUC-ROC:   {roc_auc_score(all_y_true, all_y_proba):.3f}")

# ROC curve of the pooled LOPO predictions
fpr, tpr, _ = roc_curve(all_y_true, all_y_proba)
roc_auc = roc_auc_score(all_y_true, all_y_proba)

roc_fig, roc_ax = plt.subplots(figsize=(8, 6))
roc_ax.plot(fpr, tpr, color='darkorange', lw=2,
            label=f'ROC curve (AUC = {roc_auc:.3f})')
# Diagonal = chance-level reference
roc_ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random')
roc_ax.set_xlim([0.0, 1.0])
roc_ax.set_ylim([0.0, 1.05])
roc_ax.set_xlabel('False Positive Rate')
roc_ax.set_ylabel('True Positive Rate')
roc_ax.set_title('ROC Curve - LOPO Cross-Validation')
roc_ax.legend(loc="lower right")
roc_ax.grid(alpha=0.3)
roc_fig.savefig('roc_curve_lopo.png', dpi=300, bbox_inches='tight')
print("✓ ROC curve saved to: roc_curve_lopo.png")

# Precision-recall curve of the same pooled predictions
precision, recall, _ = precision_recall_curve(all_y_true, all_y_proba)

pr_fig, pr_ax = plt.subplots(figsize=(8, 6))
pr_ax.plot(recall, precision, color='blue', lw=2)
pr_ax.set_xlabel('Recall')
pr_ax.set_ylabel('Precision')
pr_ax.set_title('Precision-Recall Curve - LOPO Cross-Validation')
pr_ax.grid(alpha=0.3)
pr_fig.savefig('pr_curve_lopo.png', dpi=300, bbox_inches='tight')