In [30]:
import pandas as pd
import numpy as np
# import metrics mae, rmse
import os
from sklearn.metrics import mean_absolute_error, mean_squared_error

ckpt_dir = "/vol/aimspace/users/sdm/Projects/WholeBodyRL/logs/lightning_logs/version_216"
test_pred_df = pd.read_csv(os.path.join(ckpt_dir, "test_preds.csv"))

pred = test_pred_df["pred"].values
target = test_pred_df["target"].values

train_df = pd.read_csv("/u/home/sdm/GitHub/WholeBodyRL/configs/data_files/labels_healthy_age.csv")



In [31]:
print("MAE: ", mean_absolute_error(target, pred))


MAE:  1.2793437926627669


In [32]:
len(pred)

3553

In [33]:
train_df["pred"] = pred
train_df["target"] = target


In [34]:
train_df

Unnamed: 0,eid,age,split,pred,target
0,1000456,68.4,train,67.599335,68.400002
1,1003657,56.0,train,56.461300,56.000000
2,1007569,84.7,train,83.312622,84.699997
3,1010183,64.6,train,63.904839,64.599998
4,1011243,48.0,train,48.592815,48.000000
...,...,...,...,...,...
3548,5974244,69.2,val,64.278450,69.199997
3549,5985123,56.0,val,61.945686,56.000000
3550,6005574,68.2,val,63.663387,68.199997
3551,6008012,54.9,val,61.741096,54.900002


In [36]:
# calculate mae, rmse
for split in ["train", "val"]:
    X = train_df[train_df["split"] == split]["age"]
    y = train_df[train_df["split"] == split]["pred"]
    mae = mean_absolute_error(X, y)
    mae_corrected = mean_absolute_error(X, train_df[train_df["split"] == split]["pred_corrected"])
    print(f"Split: {split}, MAE: {mae}, MAE corrected: {mae_corrected}")


Split: train, MAE: 0.8168295441694304, MAE corrected: 0.710178447747656
Split: val, MAE: 3.1248534234722016, MAE corrected: 2.9970815445540055


In [35]:
from sklearn.linear_model import LinearRegression

def perform_linear_correction(df_results:pd.DataFrame, column_predictions:str, column_target:str):
    """
        This function takes a dataframe with the results in the column column_predictions 
        and the actual age of all subjects and returns a new dataframe with additional 
        columns including:
          - the age gap ("age_gap_raw")
          - the linearly corrected age gap ("age_gap_corrected")
          - the corrected prediction ("<column_predictions>_corrected")
        df_results must additionally include a column called "split" where the split of each data point is indicated

        df_results: data frame with the results
        column_predictions: the column with the predictions
        split_to_use_as_train: the split that is used as training set for the linear regression

        NOTE: the column column_predictions should not contain NaNs
    """

    df_results[column_target + "_gap_raw"] = [x-y for x, y in zip(df_results[column_predictions], df_results[column_target])]
    x_train = df_results[df_results["split"].isin(["train"])]["age"].values.reshape(-1, 1)
    y_train = df_results[df_results["split"].isin(["train"])]["age_gap_raw"].values

    reg_wholebody = LinearRegression().fit(x_train, y_train)

    new_age_gaps = reg_wholebody.predict(df_results[column_target].values.reshape(-1, 1))

    str_pred_corrected = column_predictions + "_corrected"
    df_results[str_pred_corrected] = df_results[column_predictions] - new_age_gaps

    df_results[column_target+"_gap_corrected"] = [x-y for x, y in zip(df_results[str_pred_corrected], df_results[column_target])]

    return df_results

train_df = perform_linear_correction(train_df, "pred", "age")

In [37]:
from scipy.stats import gmean
def discretize_distribution(distribution, bins):
    # Convert continuous distribution to discrete by fixed bins
    counts, bin_edges = np.histogram(distribution, bins=bins)
    discrete_distribution = (bin_edges[:-1] + bin_edges[1:]) / 2
    return discrete_distribution, counts, bin_edges

def classify_bin_counts(counts):
    # Classify each bin's count independently
    return ["few" if count < 20 else "medium" if 20 <= count <= 100 else "many" for count in counts]

def create_grouped_region_mapping(classifications):
    # Create a dictionary grouping indices by their classifications
    mapping = {"few": [], "medium": [], "many": []}
    for index, region in enumerate(classifications):
        mapping[region].append(index)
    return mapping

def map_test_samples_to_regions(test_samples, bin_edges, region_mapping,bin_max):
    # Determine which bin each test sample falls into
    bin_indices = np.digitize(test_samples, bin_edges, right=False) - 1 
    regions = []
    for bin_index in bin_indices:
        if bin_index < 0 or bin_index >= bin_max:
            # Directly assign out-of-range values to the 'few' category.
            regions.append('few')
        else:
            # Find the appropriate region for in-range values.
            region = next((key for key, indices in region_mapping.items() if bin_index in indices), 'few')
            regions.append(region)
    return regions

def testing_shots_regions(pred,label,training_distribution):
    if(training_distribution.shape[0]<= 1000):
        Bins = 10
    else:
        Bins = 50
    
    if(len(list(set(training_distribution)))<Bins):
        Bins = len(list(set(training_distribution.numpy())))
        
    # Discretize the training distribution
    discrete_training, counts_training, bin_edges = discretize_distribution(training_distribution, bins=Bins)
    # Classify each bin based on its count
    classification = classify_bin_counts(counts_training)
    # Create a mapping from index to region
    region_mapping  = create_grouped_region_mapping(classification)
    
    # Map test samples to regions
    label_category = np.array(map_test_samples_to_regions(label, bin_edges, region_mapping,bin_max = Bins))
    
    pred_shot = {'many': [], 'medium': [], 'few': [], 'overall': []}
    label_shot = {'many': [], 'medium': [], 'few': [], 'overall': []}
    metric = {'many': {}, 'medium': {}, 'few': {}, 'overall': {}}
    
    for shot in ['overall', 'many', 'medium', 'few']:
        pred_shot[shot] = np.array(pred)[label_category == shot] if shot != 'overall' else np.array(pred)
        label_shot[shot] = np.array(label)[label_category == shot] if shot != 'overall' else np.array(label)
        metric[shot]['rmse'] = np.sqrt(np.mean((pred_shot[shot] - label_shot[shot]) ** 2)) if pred_shot[shot].size > 0 else 0.
        metric[shot]['mae'] = np.mean(np.abs(pred_shot[shot] - label_shot[shot])) if pred_shot[shot].size > 0 else 0.
        if pred_shot[shot].size <= 0:
            metric[shot]['gmean'] = 0.
        else:
            diff = np.abs(pred_shot[shot] - label_shot[shot])
            if diff[diff == 0.].size:
                diff[diff == 0.] += 1e-10
                metric[shot]['gmean'] = gmean(diff) if pred_shot[shot].size > 0 else 0.
            else:
                metric[shot]['gmean'] = gmean(np.abs(pred_shot[shot] - label_shot[shot])) if pred_shot[shot].size > 0 else 0.
        metric[shot]['num_samples'] = pred_shot[shot].size
    task_metrics = metric
    
    print(f" * Overall: RMSE {task_metrics['overall']['rmse']:.3f}\tG-Mean {task_metrics['overall']['gmean']:.3f}\tMAE {task_metrics['overall']['mae']:.3f}")
    print(f" * Many: RMSE {task_metrics['many']['rmse']:.3f}\t"
          f"G-Mean {task_metrics['many']['gmean']:.3f}"
          f"\tMAE {task_metrics['many']['mae']:.3f}")
    print(f" * Median: RMSE {task_metrics['medium']['rmse']:.3f}\t"
          f"G-Mean {task_metrics['medium']['gmean']:.3f}"
          f"\tMAE {task_metrics['medium']['mae']:.3f}")
    print(f" * Low: RMSE {task_metrics['few']['rmse']:.3f}\t"
          f"G-Mean {task_metrics['few']['gmean']:.3f}"
          f"\tMAE {task_metrics['few']['mae']:.3f}")

    return task_metrics

In [40]:
res = testing_shots_regions(train_df[train_df["split"]=="val"]["pred"], train_df[train_df["split"]=="val"]["age"], train_df[train_df["split"]=="train"]["age"])

 * Overall: RMSE 3.923	G-Mean 2.033	MAE 3.125
 * Many: RMSE 3.758	G-Mean 2.070	MAE 3.027
 * Median: RMSE 3.924	G-Mean 1.983	MAE 3.107
 * Low: RMSE 5.559	G-Mean 3.815	MAE 4.947


In [41]:
# visualize mae for each region as dataframe with columns overall, many, medium, few and mae as rows
mae = pd.DataFrame(res).T[["mae"]].T
# change the order of the columns
mae = mae[["overall", "many", "medium", "few"]]
# only show 2 decimal places
mae = mae.round(2)
mae

Unnamed: 0,overall,many,medium,few
mae,3.12,3.03,3.11,4.95
