In [2]:
import os
from metrics import *
import numpy as np
from sklearn.model_selection import StratifiedKFold
from sklearn.ensemble import RandomForestClassifier




In [3]:
# Read the feature array and the matching label array stored one directory up.
X, y = np.load('../features.npy'), np.load('../labels.npy')

In [4]:
# Collapse all non-sample axes so each sample becomes one flat row:
# X ends up 2-D with the sample count preserved on axis 0.
X = np.reshape(X, (X.shape[0], -1))
(X.shape, y.shape)

((8724, 2048), (8724,))

In [5]:
# Class-balance check: sorted label values and how many samples carry each.
print('Number of samples in each class:')
label_summary = np.unique(y, return_counts=True)  # (label values, counts)
print(label_summary)

Number of samples in each class:
(array([0, 1, 2]), array([6232,  256, 2236]))


In [6]:
# Create the output directories if they are missing; existing ones are kept.
for output_dir in ('logs', 'logs/best', 'figures'):
    os.makedirs(output_dir, exist_ok=True)

In [7]:
# Initialize 10-fold stratified cross-validation, shuffled with a fixed seed
# so the fold assignment is reproducible.
skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)

# Initialize a classifier (e.g., Random Forest). The same estimator object is
# refit on each fold's training split inside the loop below.
clf = RandomForestClassifier(random_state=42)


def per_class_f1(y_true, y_pred, labels):
    """Per-class (one-vs-rest) F1 scores, returned in the order of ``labels``.

    F1 is the harmonic mean of precision and recall, computed as
    2*TP / (2*TP + FP + FN). A label with a zero denominator (never present
    and never predicted) scores 0.0 instead of raising or producing NaN.

    Parameters
    ----------
    y_true, y_pred : array-like of shape (n_samples,)
        Ground-truth and predicted labels.
    labels : iterable
        Label values to score, e.g. ``clf.classes_``.

    Returns
    -------
    list of float
        One F1 score per entry of ``labels``.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    scores = []
    for label in labels:
        is_true = y_true == label
        is_pred = y_pred == label
        true_pos = np.sum(is_true & is_pred)
        false_pos = np.sum(~is_true & is_pred)
        false_neg = np.sum(is_true & ~is_pred)
        denom = 2 * true_pos + false_pos + false_neg
        scores.append(float(2 * true_pos / denom) if denom else 0.0)
    return scores


# Per-fold metric histories, averaged after the loop.
accuracies = []
auc_scores = []
f1_scores = []
sensitivities = []
specificities = []
# Define class names for label values 0, 1, 2 (in that order).
# NOTE(review): assumes labels.npy encodes 0 -> "Few", 1 -> "Many",
# 2 -> "None" — confirm against the code that produced the labels.
class_names = ["Few", "Many", "None"]

for fold, (train_index, val_index) in enumerate(skf.split(X, y), 1):
    X_train, X_val = X[train_index], X[val_index]
    y_train, y_val = y[train_index], y[val_index]

    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_val)
    # Probability columns follow clf.classes_ (the sorted training labels).
    y_prob = clf.predict_proba(X_val)

    # Calculate metrics
    accuracy, class_metrics, auc, f1, cm, avg_sensitivity, avg_specificity = calculate_metrics(y_val, y_pred, y_prob)

    # Per-class F1 must come from precision and recall; the harmonic mean of
    # sensitivity and specificity is a different quantity, so it is computed
    # directly from the fold's predictions. Order follows clf.classes_.
    class_f1 = per_class_f1(y_val, y_pred, clf.classes_)

    # Log metrics to custom logger. `class_metrics` is assumed to be ordered
    # like `class_names`; each entry is read for "sensitivity"/"specificity".
    metrics = {
        'fold': fold,
        'val_accuracy': accuracy,
        'val_auc': auc,
        'val_f1': f1,
        'avg_sensitivity': avg_sensitivity,
        'avg_specificity': avg_specificity,
        **{f'class_{class_names[i]}_sensitivity': class_metric["sensitivity"] for i, class_metric in enumerate(class_metrics)},
        **{f'class_{class_names[i]}_specificity': class_metric["specificity"] for i, class_metric in enumerate(class_metrics)},
        **{f'class_{class_names[i]}_f1': score for i, score in enumerate(class_f1)},
    }

    custom_log(metrics, model_name='random_forest', log_dir='logs')

    # Optionally, plot confusion matrix.
    # NOTE(review): epoch_num=0 because this model has no training epochs;
    # presumably only used for naming the output — confirm in plot_confusion_matrix.
    plot_confusion_matrix(cm, class_names=class_names, epoch_num=0, model_name='random_forest', fold_num=fold)

    accuracies.append(accuracy)
    auc_scores.append(auc)
    f1_scores.append(f1)
    sensitivities.append(avg_sensitivity)
    specificities.append(avg_specificity)

# Print average metrics
print(f'Average Accuracy: {np.mean(accuracies)}')
print(f'Average AUC: {np.mean(auc_scores)}')
print(f'Average F1 Score: {np.mean(f1_scores)}')
print(f'Average Sensitivity: {np.mean(sensitivities)}')
print(f'Average Specificity: {np.mean(specificities)}')

Average Accuracy: 0.9732867786920563
Average AUC: 0.9886172476045987
Average F1 Score: 0.972588539016568
Average Sensitivity: 0.9521983838836576
Average Specificity: 0.9757101040131102
