## Helper

In [None]:
import wandb 
import json
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
def select_group(runs, group):
    """Return the runs whose ``group`` attribute equals ``group``.

    The relative order of ``runs`` is preserved and a new list is returned.
    """
    matching = []
    for candidate in runs:
        if candidate.group == group:
            matching.append(candidate)
    return matching
def calc_mean_std_auc(runs_for_auc):
    """Return ``(mean, std)`` of all logged ``auc`` values, each rounded to 3 decimals.

    For every run the ``"auc"`` column of ``run.history()`` is read and its
    missing entries are dropped; the remaining values of all runs are pooled
    before the statistics are computed.

    Args:
        runs_for_auc: iterable of objects exposing ``history()`` that returns a
            DataFrame with an ``"auc"`` column.

    Returns:
        Tuple ``(mean, std)``; both are NaN when no value is available.
    """
    per_run_auc = [run.history()["auc"].dropna().values for run in runs_for_auc]
    # Concatenate instead of stacking: runs may have logged a different number
    # of values, and a ragged list cannot be turned into a 2-D array.
    pooled = np.concatenate(per_run_auc) if per_run_auc else np.array([])
    return round(pooled.mean(), 3), round(pooled.std(), 3)

def plot_loss(runs, metric="loss"):
    """Plot the per-epoch training and validation curve aggregated over ``runs``.

    For each run the train/validation columns of the chosen metric and the
    ``"epoch"`` column are taken from ``run.history()``, rows with missing
    values are dropped, and the rows of all runs are grouped by epoch to
    compute mean and standard deviation. The aggregated frame is melted and
    drawn as a seaborn line plot.

    Args:
        runs: iterable of objects exposing ``history()`` that returns a DataFrame.
        metric: ``"loss"`` (columns ``"Loss train"``/``"Loss val"``) or
            ``"auc"`` (columns ``"auc_train"``/``"auc_val"``).

    Raises:
        ValueError: if ``metric`` is not one of the supported names.
    """
    # metric name -> (train column, validation column, axis/title label)
    metric_columns = {
        "loss": ("Loss train", "Loss val", "Loss"),
        "auc": ("auc_train", "auc_val", "Auc"),
    }
    if metric not in metric_columns:
        # Fail early with a clear message instead of an UnboundLocalError below.
        raise ValueError(
            f"Unsupported metric {metric!r}; expected one of {sorted(metric_columns)}"
        )
    metric_train, metric_val, label = metric_columns[metric]

    loss_train_val = pd.concat([run.history().loc[:,[metric_train,metric_val,"epoch"]].dropna() for run in runs])
    mean_std = loss_train_val.groupby("epoch").agg(["mean","std"])
    mean_std.reset_index(inplace=True)
    # NOTE(review): the mean and std columns are melted together into 'Value';
    # confirm that the resulting 'Metric' hue separates them as intended.
    df_melted = mean_std.melt(id_vars='epoch', var_name='Metric', value_name='Value')

    sns.set(style="whitegrid")
    plt.figure(figsize=(10, 6))
    plot = sns.lineplot(x='epoch', y='Value', hue='Metric', style='Metric', markers=True, data=df_melted)
    plot.set(xlabel='Epoch', ylabel=label, title=f'{label} Train vs. {label} Validation')
    plt.legend(title='Data', loc='upper right')
    plt.show()
    
def plot_auc(runs_diff_group, auc_var='auc'):
    """Compare the AUC of several run groups in one strip/point plot.

    For every run the first non-missing value of ``auc_var`` in
    ``run.history()`` is used. Each group gets its own x position, showing the
    individual values plus a point with a one-standard-deviation error bar.
    The mean and standard deviation of each group (rounded to 4 decimals) are
    printed, and a dashed reference line is drawn at 0.5.

    Args:
        runs_diff_group: sequence of run groups; each group is a non-empty
            sequence of runs exposing ``group`` and ``history()``. The group
            name is taken from the first run of each group.
        auc_var: name of the history column holding the AUC value.

    Raises:
        ValueError: if a group is empty or a run has no logged ``auc_var`` value.
    """
    names = []
    auc_values_list = []
    for position, runs in enumerate(runs_diff_group):
        if len(runs) == 0:
            # Without a run there is neither a group name nor a value to plot.
            raise ValueError(f"Run group at position {position} contains no runs")
        group_values = []
        for run in runs:
            logged = run.history()[auc_var].dropna().values
            if len(logged) == 0:
                raise ValueError(
                    f"Run {getattr(run, 'name', run)!r} has no logged value for {auc_var!r}"
                )
            group_values.append(logged[0])
        names.append(runs[0].group)
        auc_values_list.append(group_values)

    sns.set(style="whitegrid")
    plt.figure(figsize=(10, 6))

    for i, auc_values in enumerate(auc_values_list):
        print(np.round(np.array(auc_values).mean(),4))
        print(np.round(np.array(auc_values).std(),4))
        sns.stripplot(x=[i] * len(auc_values), y=auc_values, jitter=False, dodge=False, alpha=1, size=8)
        sns.pointplot(x=[i] * len(auc_values), y=auc_values, errorbar="sd", markers="o", capsize=0.2, label=names[i])

    # Draw a horizontal line at 0.5
    plt.axhline(y=0.5, color='red', linestyle='--', label='Random (0.5)')

    # Customize the plot
    # Split names longer than max_chars into two lines (a single break after max_chars characters)
    max_chars = 9
    formatted_labels = [name[:max_chars] + '\n' + name[max_chars:] if len(name) > max_chars else name for name in names]
    plt.xticks(range(len(auc_values_list)), formatted_labels)

    plt.xlabel('Runs')
    plt.ylabel('AUC')
    plt.title('AUC-values with 1 standard deviation')
    plt.ylim(0, 1)  # Set y-axis limits to start from 0 and go up to 1
    plt.legend()
    plt.show()

## MRI

In [None]:
# Fetch the runs of the MRI project once; every variable below holds the runs
# of one run group (see select_group).
runs = wandb.Api().runs("pro5d-classification-prolactinoma/MRI-Modell")
baseline_mlp_runs = select_group(runs, "Baseline MLP")
baseline_mlp_test_runs = select_group(runs, "Baseline_MLP_Test")

baseMed3d_weighted = select_group(runs, "BaseMed3d_weighted")

baseline_resnet = select_group(runs, "Baseline Resnet")
# Within the "Weighted" group, keep only the runs whose config names "Resnet 18"
# as model architecture.
resnet_weighted = [run for run in select_group(runs, "Weighted") if run.config["model architecture"] == "Resnet 18"]

resnet_augmented_weighted = select_group(runs, "Augmented_weighted")
resnet_augmented_weighted_freezed = select_group(runs, "Augmented_weighted_freezed")

resnet_augmented_weighted_test = select_group(runs, "Augmented_weighted_Test")

## Baseline MLP

In [None]:
# Train/validation curves of the baseline MLP: loss first, then AUC.
plot_loss(baseline_mlp_runs, metric="loss")
plot_loss(baseline_mlp_runs, metric="auc")

+ ROC Curve Test

### MED3D

In [None]:
# Train/validation curves of the weighted Med3D runs: loss first, then AUC.
plot_loss(baseMed3d_weighted, metric="loss")
plot_loss(baseMed3d_weighted, metric="auc")

## Resnet

In [None]:
# Train/validation loss of the augmented, weighted ResNet runs.
plot_loss(resnet_augmented_weighted, metric="loss")


In [None]:
# Train/validation AUC of the augmented, weighted ResNet runs.
plot_loss(resnet_augmented_weighted, metric="auc")

## Test AUC

In [None]:
# AUC comparison across the MRI run groups (history key "auc").
plot_auc(
    [
        baseline_mlp_runs,
        baseline_resnet,
        resnet_weighted,
        resnet_augmented_weighted,
        resnet_augmented_weighted_freezed,
    ],
    auc_var='auc',
)

In [None]:
# AUC comparison of the two "Test" run groups (history key "auc").
plot_auc(
    [
        baseline_mlp_test_runs,
        resnet_augmented_weighted_test,
    ],
    auc_var='auc',
)

## LAB

In [None]:
# Fetch the runs of the tabular-data project; each variable holds one run group.
# select_group already returns a fresh list, so no extra copy is needed.
runs = wandb.Api().runs("pro5d-classification-prolactinoma/tabular-data")
baseline_logreg = select_group(runs, "Tab-Data-LogReg-Data-Pairs")
baseline_rf = select_group(runs, "Tab-Data-RandomForest-Data-Pairs")
xgboost = select_group(runs, "Tab-Data-XGBoost-Data-Pairs")
xgboost_ = select_group(runs, "Tab-Data-XGBoost-Data-Pairs-noFT4")  # group "...-noFT4"
xgboost__ = select_group(runs, "Tab-Data-XGBoost-Data-Pairs-ohneCOR")  # group "...-ohneCOR"


In [None]:
# Test-AUC comparisons of the tabular models (history key 'auc-test').
plot_auc([baseline_logreg, baseline_rf], auc_var='auc-test')
plot_auc([baseline_logreg, baseline_rf, xgboost_], auc_var='auc-test')
plot_auc([xgboost, xgboost__, xgboost_], auc_var='auc-test')

In [None]:
runs = wandb.Api().runs("pro5d-classification-prolactinoma/tabular-data")
learning_curve_runs = select_group(runs, "Tab-Data-XGBoost-Learning_curve")
learning_curve_sizes = []
auc_values = []

# Both history columns are read below, so a run must have logged both to contribute.
learning_curve_columns = {"learning_curve_train_size", "learning_curve_val_auc"}

for run in learning_curve_runs:
    # Collect the logged training-set sizes and AUC values of each run
    run_data = run.history()
    if learning_curve_columns.issubset(run_data.columns):
        learning_curve_sizes.append(run_data["learning_curve_train_size"].dropna().values)
        auc_values.append(run_data["learning_curve_val_auc"].dropna().values)

In [None]:
# Pool the per-run arrays into two flat columns. np.concatenate is used instead
# of np.array(...).reshape(-1) so that runs with a different number of logged
# points do not break the flattening; an empty list yields an empty column.
values = pd.DataFrame({
    "learning_curve_sizes": np.concatenate(learning_curve_sizes) if len(learning_curve_sizes) else np.array([]),
    "auc_values": np.concatenate(auc_values) if len(auc_values) else np.array([]),
})
# Mean and standard deviation of the AUC per training-set size.
grouped_by_size = values.groupby("learning_curve_sizes")
mean = grouped_by_size.mean().reset_index()
std = grouped_by_size.std().reset_index()

In [None]:
# Use a separate name for the x-values: rebinding `learning_curve_sizes` here
# would overwrite the raw per-run list and break re-running the aggregation cell.
train_sizes = mean["learning_curve_sizes"]
mean_auc_values = mean["auc_values"]
std_auc_values = std["auc_values"]

plt.figure(figsize=(10, 6))
# Mean AUC per training-set size with error bars of one standard deviation
plt.errorbar(x=train_sizes, y=mean_auc_values, yerr=std_auc_values, fmt='o', color='black', label='Std dev')

# Add legend
plt.legend()

# Set labels for axes
plt.xlabel('Train data size')
plt.ylabel('AUC')

# Set title for the plot
plt.title('Learning curve for XGBoost')

# Show the plot
plt.show()



In [None]:
runs = wandb.Api().runs("pro5d-classification-prolactinoma/tabular-data")
permutation_curve_runs = select_group(runs, "Tab-Data-XGBoost-Data-Pairs-Permutation")

# Plot label -> suffix of the logged history key "permutation_importance_auc_<suffix>".
# NOTE(review): testosterone is read from the key ending in "TEST", and the gender
# key carries a "_male" suffix here but not in the Random Forest cell — confirm
# both against what was actually logged.
importance_prefix = "permutation_importance_auc_"
feature_suffixes = {
    'LH': "LH",
    'CORTISOL': "COR",
    'FSH': "FSH",
    'FT4': "FT4",
    'IGF1': "IGF1",
    'PROLACTIN': "PROL",
    'Patient_age': "Patient_age",
    'TESTOSTERONE': "TEST",
    'Patient_gender': "Patient_gender_male",
}
importance_values = {label: [] for label in feature_suffixes}

for run in permutation_curve_runs:
    run_data = run.history()

    # Skip runs that logged no permutation-importance key at all.
    if not run_data.columns.str.contains(importance_prefix, regex=False).any():
        continue
    # Collect the non-missing permutation importance values of each feature
    for label, suffix in feature_suffixes.items():
        importance_values[label].append(run_data[importance_prefix + suffix].dropna().values)

# Flatten the per-run arrays; concatenation also works when runs logged a
# different number of values, and an empty list yields an empty column.
data = {
    label: np.concatenate(chunks) if chunks else np.array([])
    for label, chunks in importance_values.items()
}

# Creating the DataFrame
df = pd.DataFrame(data)
# Calculate mean and standard deviation for each column
mean_values = df.mean()
std_values = df.std()
# Features ordered by descending mean importance
sorted_features = mean_values.sort_values(ascending=False).index
# Plotting with plt.errorbar
plt.figure(figsize=(10, 6))
plt.errorbar(x=sorted_features, y=mean_values[sorted_features].values, yerr=std_values[sorted_features], fmt='o', color='black', label='Std dev')
plt.title('Permutation Feature Importance - XGBoost')
plt.ylabel('Feature Importance')
plt.xticks(rotation=45, ha='right')
plt.legend()

# Show the plot
plt.tight_layout()
plt.show()

In [None]:
runs = wandb.Api().runs("pro5d-classification-prolactinoma/tabular-data")
permutation_curve_runs = select_group(runs, "Tab-Data-RandomForest-Data-Pairs-Permutation")

# Plot label -> suffix of the logged history key "permutation_importance_auc_<suffix>".
# NOTE(review): testosterone is read from the key ending in "TEST", and the gender
# key has no "_male" suffix here, unlike the XGBoost cell — confirm both against
# what was actually logged.
importance_prefix = "permutation_importance_auc_"
feature_suffixes = {
    'LH': "LH",
    'CORTISOL': "COR",
    'FSH': "FSH",
    'FT4': "FT4",
    'IGF1': "IGF1",
    'PROLACTIN': "PROL",
    'Patient_age': "Patient_age",
    'TESTOSTERONE': "TEST",
    'Patient_gender': "Patient_gender",
}
importance_values = {label: [] for label in feature_suffixes}

for run in permutation_curve_runs:
    run_data = run.history()

    # Skip runs that logged no permutation-importance key at all.
    if not run_data.columns.str.contains(importance_prefix, regex=False).any():
        continue
    # Collect the non-missing permutation importance values of each feature
    for label, suffix in feature_suffixes.items():
        importance_values[label].append(run_data[importance_prefix + suffix].dropna().values)

# Flatten the per-run arrays; concatenation also works when runs logged a
# different number of values, and an empty list yields an empty column.
data = {
    label: np.concatenate(chunks) if chunks else np.array([])
    for label, chunks in importance_values.items()
}

# Creating the DataFrame
df = pd.DataFrame(data)
# Calculate mean and standard deviation for each column
mean_values = df.mean()
std_values = df.std()
# Features ordered by descending mean importance
sorted_features = mean_values.sort_values(ascending=False).index
# Plotting with plt.errorbar
plt.figure(figsize=(10, 6))
plt.errorbar(x=sorted_features, y=mean_values[sorted_features].values, yerr=std_values[sorted_features], fmt='o', color='black', label='Std dev')
plt.title('Permutation Feature Importance - Random Forest')
plt.ylabel('Feature Importance')
plt.xticks(rotation=45, ha='right')
plt.legend()

# Show the plot
plt.tight_layout()
plt.show()