In [None]:
import os

import numpy as np
import torch
import xgboost as xgb
from sklearn.metrics import accuracy_score, f1_score, recall_score
from sklearn.model_selection import (
    RepeatedStratifiedKFold,
    cross_val_score,
    cross_validate,
    train_test_split,
)

# Paths to the directories containing the .pt files, one directory per class.
# Files under hc_path are loaded with label 0 and files under mci_path with
# label 1 further down in this script.
# NOTE(review): the '/content/drive/MyDrive' prefix looks like a Colab-mounted
# Google Drive — confirm the mount exists before running elsewhere.
hc_path = '/content/drive/MyDrive/removedspikes_nonan_replace_window10avg/ply_dtrnd/dtrnd-tddr-fit-ply-dtrnd/baselinecorr/ready for segmentation/merged_datas_all_normalized(abslute value max each ch)/whole2_allfunctions_hc_rest_whole_swn+graphfeat+hbr'
mci_path = '/content/drive/MyDrive/removedspikes_nonan_replace_window10avg/ply_dtrnd/dtrnd-tddr-fit-ply-dtrnd/baselinecorr/ready for segmentation/merged_datas_all_normalized(abslute value max each ch)/whole2_allfunctions_mci_rest_whole_swn+graphfeat+hbr'

# Function to load .pt files from a directory and extract node and edge features
def load_data_from_directory(directory, label):
    """Load every ``.pt`` file in *directory* as a flat feature vector.

    Each file is read with ``torch.load`` and must expose ``x`` and
    ``edge_attr`` tensors. Both are flattened and concatenated, node values
    first, then edge values.

    Args:
        directory: Path of the directory to scan (top level only).
        label: Value paired with every sample loaded from this directory.

    Returns:
        A list of ``(combined_features, label)`` tuples ordered by file
        name; empty when the directory contains no ``.pt`` files.
    """
    data_list = []
    # os.listdir() yields entries in arbitrary order. Sorting makes the sample
    # order deterministic, so anything that depends on it (e.g. seeded
    # train/test splits) is reproducible across machines and runs.
    for filename in sorted(os.listdir(directory)):
        if not filename.endswith('.pt'):
            continue
        # map_location keeps files that were saved from a GPU loadable on a
        # CPU-only machine.
        data = torch.load(os.path.join(directory, filename), map_location='cpu')
        # detach()/cpu() make .numpy() safe for tensors that require grad or
        # that still live on another device.
        node_features = data.x.detach().cpu().flatten().numpy()
        edge_features = data.edge_attr.detach().cpu().flatten().numpy()
        combined_features = np.concatenate((node_features, edge_features))
        data_list.append((combined_features, label))
    return data_list

# Load data from both directories (files under hc_path get label 0,
# files under mci_path get label 1)
hc_data = load_data_from_directory(hc_path, 0)
mci_data = load_data_from_directory(mci_path, 1)

# Combine the data and create feature matrix and labels
all_data = hc_data + mci_data
if not all_data:
    # Fail with a clear message instead of the opaque "max() arg is an empty
    # sequence" raised below.
    raise ValueError(f"No .pt files found in {hc_path!r} or {mci_path!r}")

# Find the maximum length of the feature vectors
max_length = max(len(item[0]) for item in all_data)

# Zero-pad every feature vector on the right so they can be stacked into one matrix.
# NOTE(review): if vectors really differ in length, the padded positions no
# longer line up with a fixed node/edge layout — confirm all graphs share one shape.
X = np.array([np.pad(item[0], (0, max_length - len(item[0])), 'constant') for item in all_data])
y = np.array([item[1] for item in all_data])

# Get the shape of the original data from one reference sample.
# The first .pt file (by name) is used: indexing the raw directory listing at
# a fixed position could select a non-.pt entry, or fail outright when the
# directory has fewer entries than expected.
hc_pt_files = sorted(name for name in os.listdir(hc_path) if name.endswith('.pt'))
if not hc_pt_files:
    raise ValueError(f"No .pt files found in {hc_path!r}")
# Only shapes are read from this sample, so loading onto the CPU is sufficient.
sample_data = torch.load(os.path.join(hc_path, hc_pt_files[0]), map_location='cpu')
num_nodes = sample_data.x.shape[0]
num_node_features = sample_data.x.shape[1]
num_edge_features = sample_data.edge_attr.shape[1]
# Number of leading entries of each feature vector that come from node features
total_node_features = num_nodes * num_node_features

# Function to evaluate model and get feature importances
def evaluate_model(X_train, X_test, y_train, y_test, random_state):
    """Fit an XGBoost classifier on one split and score it on the held-out part.

    Args:
        X_train, y_train: Features and labels used for fitting.
        X_test, y_test: Features and labels used for scoring.
        random_state: Seed passed to ``XGBClassifier``.

    Returns:
        A 5-tuple ``(accuracy, f1, recall, node_importance, edge_importance)``.
        The two importance arrays hold the model's feature importances summed
        per node-feature column and per edge-feature column respectively.
        The layout is taken from the module-level ``total_node_features``,
        ``num_nodes``, ``num_node_features`` and ``num_edge_features``.
    """
    classifier = xgb.XGBClassifier(n_estimators=100, random_state=random_state)
    classifier.fit(X_train, y_train)
    predictions = classifier.predict(X_test)

    # Accuracy, F1 and recall, in that order, on the held-out split
    scores = tuple(
        metric(y_test, predictions)
        for metric in (accuracy_score, f1_score, recall_score)
    )

    # The first total_node_features entries belong to node features, the rest
    # to edge features; fold each part back to (items, features) and sum over
    # the items to get one value per feature column.
    node_part, edge_part = np.split(classifier.feature_importances_, [total_node_features])
    per_node_feature = node_part.reshape(num_nodes, num_node_features).sum(axis=0)
    per_edge_feature = edge_part.reshape(-1, num_edge_features).sum(axis=0)

    return (*scores, per_node_feature, per_edge_feature)

# Cross-validation: 5 stratified folds repeated 10 times (50 fits in total).
cv = RepeatedStratifiedKFold(n_splits=5, n_repeats=10, random_state=42)
model = xgb.XGBClassifier(n_estimators=100)

# Score all three metrics in a single pass. Calling cross_val_score once per
# metric refits the same model on the same folds three times over; one
# cross_validate call produces the same per-fold scores with a third of the fits.
cv_results = cross_validate(model, X, y, cv=cv, scoring=('accuracy', 'f1', 'recall'))
cv_accuracy_scores = cv_results['test_accuracy']
cv_f1_scores = cv_results['test_f1']
cv_recall_scores = cv_results['test_recall']

# The reported spread is two standard deviations of the per-fold scores.
print(f"Cross-validation accuracy: {cv_accuracy_scores.mean():.4f} (+/- {cv_accuracy_scores.std() * 2:.4f})")
print(f"Cross-validation F1 score: {cv_f1_scores.mean():.4f} (+/- {cv_f1_scores.std() * 2:.4f})")
print(f"Cross-validation recall: {cv_recall_scores.mean():.4f} (+/- {cv_recall_scores.std() * 2:.4f})")

# Repeated stratified 70/30 hold-out evaluation; the run index doubles as the
# seed for both the split and the classifier.
n_runs = 10
accuracies, f1_scores, recalls = [], [], []
node_importances, edge_importances = [], []

# Destination lists, in the same order as evaluate_model's return tuple
result_sinks = (accuracies, f1_scores, recalls, node_importances, edge_importances)

for i in range(n_runs):
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, stratify=y, random_state=i
    )
    run_result = evaluate_model(X_train, X_test, y_train, y_test, i)
    accuracy, f1, recall, node_importance, edge_importance = run_result

    # Append each returned value to its matching list
    for sink, value in zip(result_sinks, run_result):
        sink.append(value)

# Print results (mean over the runs, +/- one standard deviation)
print(f"\nAverage accuracy: {np.mean(accuracies):.4f} (+/- {np.std(accuracies):.4f})")
print(f"Average F1 score: {np.mean(f1_scores):.4f} (+/- {np.std(f1_scores):.4f})")
print(f"Average recall: {np.mean(recalls):.4f} (+/- {np.std(recalls):.4f})")

# Per-feature importances averaged over the runs; computed once and reused below
node_importance_mean = np.mean(node_importances, axis=0)
edge_importance_mean = np.mean(edge_importances, axis=0)

print("\nAverage Node Feature Importances:")
for i, importance in enumerate(node_importance_mean):
    print(f"Node Feature {i+1}: {importance:.4f}")

print("\nAverage Edge Feature Importances:")
for i, importance in enumerate(edge_importance_mean):
    print(f"Edge Feature {i+1}: {importance:.4f}")

# Identify most important features: argsort is ascending, so take the last
# three indices and reverse them to list the largest first. Both slices take
# three entries so the output matches the "Top 3" headings below.
top_node_features = np.argsort(node_importance_mean)[-3:][::-1]
top_edge_features = np.argsort(edge_importance_mean)[-3:][::-1]

print("\nTop 3 Most Important Node Features:")
for i, feature in enumerate(top_node_features):
    print(f"{i+1}. Node Feature {feature+1}: {node_importance_mean[feature]:.4f}")

print("\nTop 3 Most Important Edge Features:")
for i, feature in enumerate(top_edge_features):
    print(f"{i+1}. Edge Feature {feature+1}: {edge_importance_mean[feature]:.4f}")

# Display labels for the node feature columns.
# NOTE(review): the order is presumably the column order of each graph's `x`
# tensor — confirm against the code that built the .pt files.
node_feature_names = [
    'HbO Max',
    'HbO Min',
    'HbO Mean',
    'HbO Std',
    'HbO Slope',
    'HbO Wavelet',
    'HbR Max',
    'HbR Min',
    'HbR Mean',
    'HbR Std',
    'HbR Slope',
    'HbR Wavelet',
]

# Display labels for the edge feature columns.
# NOTE(review): the order is presumably the column order of `edge_attr` — confirm.
# NOTE(review): ' Dynamic Time Wraping Distance' has a leading space and looks
# like a misspelling of "Warping"; left unchanged because it is a runtime label.
edge_feature_names = [
    'Covariance',
    'Pearson Correlation',
    'Spearman Correlation',
    'Kendall Tau',
    'Distance Correlation',
    ' Dynamic Time Wraping Distance',
    'PLI',
    'Coherence',
    'PLV',
    'Cross Correlation',
]

# Average each feature's importance over all runs
node_importance_mean = np.asarray(node_importances).mean(axis=0)
edge_importance_mean = np.asarray(edge_importances).mean(axis=0)

def plot_feature_importance(feature_names, importances, title):
    """Show a horizontal bar chart of importances, ranked high to low, and print the top five.

    Args:
        feature_names: One label per entry of *importances*. On a length
            mismatch a warning is printed and generic "Feature N" labels
            are used instead.
        importances: 1-D NumPy array of importance values (it is indexed
            with an index array, so a plain list will not work).
        title: Chart title; also used in the printed "Top 5" heading.

    Returns:
        None. The figure is displayed with ``plt.show()`` and the ranking is
        written to stdout.
    """
    # NOTE(review): relies on a module-level `plt` (matplotlib.pyplot) that is
    # not imported in the visible part of this file — confirm it is imported
    # elsewhere before this function runs.

    # Fall back to generic labels when the names cannot be matched one-to-one
    if len(feature_names) != len(importances):
        print(f"Warning: Number of feature names ({len(feature_names)}) doesn't match the number of importances ({len(importances)}).")
        feature_names = [f"Feature {i+1}" for i in range(len(importances))]

    # Rank from most to least important
    sorted_idx = np.argsort(importances)[::-1]
    sorted_importances = importances[sorted_idx]
    sorted_names = [feature_names[i] for i in sorted_idx]
    bar_positions = range(len(sorted_importances))

    # Draw the bars and colour each one
    plt.figure(figsize=(10, 6))
    for bar in plt.barh(bar_positions, sorted_importances):
        bar.set_color('skyblue')

    # Axis labelling
    plt.yticks(bar_positions, sorted_names)
    plt.xlabel('Feature Importance')
    plt.title(title)

    # Adjust layout and display the plot
    plt.tight_layout()
    plt.show()

    # Text summary of the (up to) five highest-ranked features
    print(f"\nTop 5 Most Important {title}:")
    for i, _ in enumerate(sorted_names[:5]):
        print(f"{i+1}. {sorted_names[i]}: {sorted_importances[i]:.4f}")

# One ranked chart per feature family: node features first, then edge features
for names, mean_importance, chart_title in (
    (node_feature_names, node_importance_mean, "Node Feature Importance for MCI vs HC Classification"),
    (edge_feature_names, edge_importance_mean, "Edge Feature Importance for MCI vs HC Classification"),
):
    plot_feature_importance(names, mean_importance, chart_title)

# Debug aid: report how many importance values each family has
print(f"\nShape of node_importance_mean: {node_importance_mean.shape}")
print(f"Shape of edge_importance_mean: {edge_importance_mean.shape}")
