In [None]:
import torch
import numpy as np
import signatory
from sklearn.ensemble import BaggingClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report
from collections import defaultdict
from tqdm import tqdm
import joblib
import random
import csv

# --- Load pre-extracted Whisper features, labels and segment indices -------
# Order: train features, test features, train labels, test labels,
# train indices, test indices.
(features1, features2, labels1, labels2,
 indices_train, indices_test) = (
    torch.load(path)
    for path in (
        '/home/sichengyu/text/NCDE/feature_tensor/feature_2020_0.3_whisper_30.pt',
        '/home/sichengyu/text/NCDE/feature_tensor/feature_2020_0.3_test_whisper_30.pt',
        '/home/sichengyu/text/NCDE/feature_tensor/labels1_2020_0.3_train_whisper_30.pt',
        '/home/sichengyu/text/NCDE/feature_tensor/labels2_2020_0.3_test_whisper_30.pt',
        '/home/sichengyu/text/NCDE/feature_tensor/indices_train_2020_0.3_whisper_30.pt',
        '/home/sichengyu/text/NCDE/feature_tensor/indices_test_2020_0.3_whisper_30.pt',
    )
)

# Features and labels are kept as NumPy arrays from here on.
features1, features2, labels1, labels2 = (
    tensor.numpy() for tensor in (features1, features2, labels1, labels2)
)

# Index containers are normalised through NumPy and end up as torch tensors.
indices_train, indices_test = (
    torch.from_numpy(np.array(raw)) for raw in (indices_train, indices_test)
)

print("features1 shape:", features1.shape)
print("features2 shape:", features2.shape)
print("indices_train shape:", indices_train.shape)
print("indices_test shape:", indices_test.shape)

# Segment-level targets for the classifiers.
y_train, y_test = labels1, labels2

# float32 torch views of the features for signatory.
# Presumably (N, T, C) = (segments, time steps, channels) — TODO confirm.
features1_torch = torch.from_numpy(features1).float()
features2_torch = torch.from_numpy(features2).float()

# Column 0 of the test indices — presumably the audio id of each segment.
audio_ids_test = indices_test[:, 0]

# Depth-2 path signatures over the time axis.
print("Calculating signature features (training set)...")
X_train_sig_torch = signatory.signature(features1_torch, depth=2)
print("Calculating signature features (test set)...")
X_test_sig_torch = signatory.signature(features2_torch, depth=2)
X_train_sig, X_test_sig = X_train_sig_torch.numpy(), X_test_sig_torch.numpy()

# Baseline: average-pool each segment over dim 1.
print("Calculating mean features (training set)...")
X_train_mean_torch = features1_torch.mean(dim=1)
print("Calculating mean features (test set)...")
X_test_mean_torch = features2_torch.mean(dim=1)
X_train_mean, X_test_mean = X_train_mean_torch.numpy(), X_test_mean_torch.numpy()

# Audio-level labels for the ADReSS dataset: 0 = healthy, 1 = dementia.
# The first half of each array is 0 and the second half 1 (presumably indexed
# by audio id); adjust the counts for other datasets.
labels_test = np.repeat([0.0, 1.0], 24)
labels_train = np.repeat([0.0, 1.0], 54)

def _majority_vote_by_audio(audio_ids, segment_preds):
    """Collapse segment-level predictions into a single 0/1 label per audio id.

    Segments are grouped by their audio id. An audio gets label 1 only when
    strictly more of its segments are predicted > 0.5 than not (ties -> 0).

    Args:
        audio_ids: per-segment audio ids (tensor/NumPy scalars or ints).
        segment_preds: per-segment predictions, aligned with ``audio_ids``.

    Returns:
        dict mapping int audio id -> 0 or 1, in order of first appearance.
    """
    grouped = defaultdict(list)
    for audio_id, pred in zip(audio_ids, segment_preds):
        # Tensor / NumPy scalars expose .item(); plain ints do not.
        key = int(audio_id.item()) if hasattr(audio_id, 'item') else int(audio_id)
        grouped[key].append(pred)

    votes = {}
    for key, preds in grouped.items():
        positives = sum(1 for p in preds if p > 0.5)
        votes[key] = 1 if positives > len(preds) - positives else 0
    return votes


def _audio_level_metrics(audio_predictions, audio_labels):
    """Score audio-level predictions against labels looked up by audio id.

    ``y_true`` is built from the same ids, in the same order, as ``y_pred``.
    The two therefore stay aligned no matter how audio ids are ordered in the
    test indices. Only audios that received a prediction are scored.

    Args:
        audio_predictions: dict audio id -> predicted 0/1 label.
        audio_labels: indexable ground truth, ``audio_labels[audio_id]``.

    Returns:
        (accuracy, macro F1, macro precision, macro recall).

    Raises:
        ValueError: if ``audio_predictions`` is empty.
    """
    if not audio_predictions:
        raise ValueError("no audio-level predictions to score")
    ids = list(audio_predictions)
    y_true = [int(audio_labels[i]) for i in ids]
    y_pred = [audio_predictions[i] for i in ids]
    accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(ids)
    return (
        accuracy,
        f1_score(y_true, y_pred, average='macro'),
        precision_score(y_true, y_pred, average='macro'),
        recall_score(y_true, y_pred, average='macro'),
    )


def run_experiment(seed):
    """Train and evaluate signature- and mean-feature bagging classifiers for one seed.

    For each feature set, a 100-tree BaggingClassifier of decision trees is
    fit on segment-level features. Test-segment predictions are then
    majority-voted per audio id and scored at audio level.

    Reads module-level globals: ``X_train_sig``/``X_test_sig``,
    ``X_train_mean``/``X_test_mean``, ``y_train``, ``indices_test``
    (column 0 used as the audio id) and ``labels_test``.

    Args:
        seed: seed for Python/NumPy/torch RNGs and for the classifiers.

    Returns:
        (acc_sig, f1_sig, precision_sig, recall_sig,
         acc_mean, f1_mean, precision_mean, recall_mean), macro-averaged.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    audio_ids = indices_test[:, 0]
    metrics = []
    # Identical pipeline for both feature sets: signature first, then mean pooling.
    for train_feats, test_feats in ((X_train_sig, X_test_sig),
                                    (X_train_mean, X_test_mean)):
        model = BaggingClassifier(
            estimator=DecisionTreeClassifier(random_state=seed),
            n_estimators=100,
            random_state=seed,
            verbose=0,
            n_jobs=-1,
        )
        model.fit(train_feats, y_train)
        votes = _majority_vote_by_audio(audio_ids, model.predict(test_feats))
        metrics.extend(_audio_level_metrics(votes, labels_test))
    return tuple(metrics)

import os

# Run the experiment over a fixed range of consecutive seeds and collect one
# row of audio-level metrics per seed.
results = []
num_seeds = 50
# Seed list derived from num_seeds so the two cannot drift apart;
# `seeds` is kept as an alias of `all_seeds`.
all_seeds = seeds = list(range(1001, 1001 + num_seeds))

for run_no, seed in enumerate(all_seeds, start=1):
    # run_no is the 1-based position in the sweep (the seed value itself is not).
    print(f"\n===== Starting experiment {run_no}/{num_seeds} (seed={seed}) =====")
    (acc_sig, f1_sig, pre_sig, rec_sig,
     acc_mean, f1_mean, pre_mean, rec_mean) = run_experiment(seed)

    print(f"[Signature method] Audio level: accuracy={acc_sig:.4f}, f1={f1_sig:.4f}, precision={pre_sig:.4f}, recall={rec_sig:.4f}")
    print(f"[Mean method] Audio level: accuracy={acc_mean:.4f}, f1={f1_mean:.4f}, precision={pre_mean:.4f}, recall={rec_mean:.4f}")

    results.append([
        seed,
        acc_sig, f1_sig, pre_sig, rec_sig,
        acc_mean, f1_mean, pre_mean, rec_mean,
    ])

csv_filename = "solution/signature_experiment_results_2020mean_x.csv"
# Make sure the output directory exists; otherwise open() raises after the whole sweep.
os.makedirs(os.path.dirname(csv_filename) or ".", exist_ok=True)
with open(csv_filename, "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow([
        "seed",
        "sign_acc", "sign_f1", "sign_precision", "sign_recall",
        "mean_acc", "mean_f1", "mean_precision", "mean_recall",
    ])
    writer.writerows(results)

print(f"\nAll {num_seeds} experiments completed, results saved to {csv_filename}")

