# Downstream Evaluation

Used to generate plots for downstream performance evaluation.
You need to provide a `.csv`-file with the following columns: `[level,group,fold,metric,value]`.
For example like this:

```csv
level,group,fold,metric,value
overall,overall,0,accuracy,0.6460431654676259
overall,overall,0,balanced_accuracy,0.6463224565592254
overall,overall,0,roc_auc,0.7061986006101605
overall,overall,0,precision,0.663768115942029
overall,overall,0,recall,0.637883008356546
overall,overall,0,f1_score,0.6505681818181818
dataset,INTERNAL,0,accuracy,0.6287051482059283
dataset,INTERNAL,0,balanced_accuracy,0.6272424598511555
dataset,INTERNAL,0,roc_auc,0.6942322757540149
dataset,INTERNAL,0,precision,0.6578171091445427
dataset,INTERNAL,0,recall,0.6463768115942029
dataset,INTERNAL,0,f1_score,0.652046783625731
dataset,BTXRD,0,accuracy,0.6608811748998665
dataset,BTXRD,0,balanced_accuracy,0.6607580856768012
dataset,BTXRD,0,roc_auc,0.7119032000456335
dataset,BTXRD,0,precision,0.6695156695156695
dataset,BTXRD,0,recall,0.6300268096514745
dataset,BTXRD,0,f1_score,0.649171270718232
entity,Enchondrom,0,accuracy,0.6296296296296297
...
```



## Setup Functions and so on

In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import sklearn

### Get Computed Evaluation

These are precomputed metrics on the overall dataset and on each group for each metadata attribute (age, gender, anatomy-site) (as shown above).
This is convenient because we don't have to recompute the metrics in this notebook.

In [None]:
# Select the precomputed metrics CSV of the experiment to evaluate and load it into a dataframe.
# Exactly one `file = ...` assignment should be active; the alternatives stay commented out for quick switching.
# file = "../../evaluation/baseline/only_imaging_baseline_resnet34.csv" # ResNet34 Baseline Only Imaging
# file = "../../evaluation/baseline/imaging_and_clinical_baseline_resnet34.csv" # ResNet34 Baseline Imaging+Clinical
# file = "../../evaluation/baseline/only_imaging_baseline_nest_small.csv" # Nest Small Baseline Only Imaging
# file = "../../evaluation/baseline/imaging_and_clinical_baseline_nest_small.csv" # Nest Small Baseline Imaging+Clinical
# file = "../../evaluation/baseline_pretrained/only_imaging_pretrained_baseline_resnet50.csv" # ResNet50 Pretrained Baseline Only Imaging
# file = "../../evaluation/baseline_pretrained/imaging_and_clinical_pretrained_baseline_resnet50.csv" # ResNet50 Pretrained Baseline Imaging+Clinical
# file = "../../evaluation/vlp/linear_probe_only_imaging_resnet34.csv" # ResNet34 VLP Linear Probe Only Imaging
# file =  "../../evaluation/finetune/only_imaging_finetune_resnet34.csv" # ResNet34 Finetune Only Imaging
file = "../../evaluation/finetune/imaging_and_clinical_finetune_resnet34.csv" # ResNet34 Finetune Imaging+Clinical

df = pd.read_csv(file)
df.head()

In [None]:
# Count the distinct fold ids present in the loaded metrics.
number_of_folds = np.unique(df["fold"]).size
print(f"Number of folds: {number_of_folds}")

### Get Raw Predictions

Compared to the already computed metrics, we also get the raw predictions. We could compute the metrics from these raw predictions, but that's already done by the evaluation script.
Nevertheless, for the creation of a confusion matrix the raw predictions are needed, that's why we load them here as well.

In [None]:
# Predictions were only saved for the two best models overall:
# pretrained baseline resnet50 imaging+clinical and finetuned resnet34 imaging+clinical.

# predictions_file = "../../predictions/baseline_pretrained/imaging_and_clinical_pretrained_baseline_resnet50/predictions_fold_0.csv" # best fold on val (we should not decide based on test set performance for obvious reasons)
predictions_file = "../../predictions/finetune/imaging_and_clinical_finetune_resnet34/predictions_fold_0.csv" # best fold on val (we should not decide based on test set performance for obvious reasons)

pred_df = pd.read_csv(predictions_file)
# Threshold the probabilities at 0.5 to obtain hard 0/1 predictions.
pred_df['pred'] = (pred_df['prob'] >= 0.5).astype(int)
pred_df.head()

### Clean

In [None]:
def clean(df):
    """Return the rows of `df` whose ``value`` is not NaN and print how many rows were dropped."""
    has_value = df["value"].notna()
    kept_rows = df.loc[has_value]
    print(f"Number of rows removed: {len(df) - len(kept_rows)}")
    return kept_rows

# Metrics of the selected experiment with NaN-valued rows removed; used by the averaging step below.
clean_df = clean(df)

#### Get mean and std over folds

In [None]:

def average_over_folds(clean_df):
    """Collapse per-fold rows into one mean/std row per (level, group, metric).

    The result keeps the three key columns and adds the two-level columns
    ("value", "mean") and ("value", "std").

    Raises:
        AssertionError: if the number of result rows differs from the input
            row count divided by the number of distinct folds.
    """
    fold_count = len(np.unique(clean_df["fold"]))
    expected_rows = len(clean_df) / fold_count

    per_key = clean_df.groupby(["level", "group", "metric"])
    summary = per_key.agg({"value": ["mean", "std"]}).reset_index()

    assert len(summary) == expected_rows, f"Expected {expected_rows} rows, but got {len(summary)} rows."

    return summary

# Mean/std over folds for the selected experiment; this frame feeds the single-experiment cells below.
avg_df = average_over_folds(clean_df)
avg_df.head()

### Get occurrences

In [None]:
# Load the occurrence counts of the different attribute values.
# The cells below read the columns `attribute`, `value` and `count` from this file.
occurences_df = pd.read_csv("../../evaluation/occurrences.csv")
occurences_df.head()

In [None]:
# Express every count as a fraction of the dataset size.
# The denominator is taken from the "dataset" attribute; the other attributes are expected to add up to the same total.
is_dataset_row = occurences_df["attribute"] == "dataset"
total_count = occurences_df.loc[is_dataset_row, "count"].sum()
occurences_df["normalized_count"] = occurences_df["count"].div(total_count)
occurences_df.head()


### Metadata

This is similar to the occurrences, but differs in two ways:
-  the occurrences are already aggregated per group (e.g. dataset: INTERNAL, dataset: BTXRD, entity: osteochondroma, etc.), whereas the metadata has one entry per sample
-  the metadata df covers the entire downstream data, whereas the occurrences are based only on the test set

In [None]:
# Per-sample metadata; the 'Unnamed: 0' column (presumably a saved index - verify) is dropped.
metadata_df = pd.read_csv("../../visualizations/data/downstream/metadata.csv").drop(columns=['Unnamed: 0'])
metadata_df.head()

## Single Experiment Evaluation

### Overall

In [None]:
# Print the overall metrics (rows with level == "overall") with their mean and std over folds.
print("Overall metrics:")
overall_metrics = avg_df[avg_df["level"] == "overall"]
print(overall_metrics)

### Plotting Function Preparation

In [None]:
def _metric_display_name(metric):
    """Return the axis/title label for a metric key ('roc_auc' becomes 'AUROC')."""
    return 'AUROC' if metric == 'roc_auc' else metric.replace("_", " ").title()


def _rotate_ticks_on_overlap(ax, rotation):
    """Rotate the x tick labels of `ax` to `rotation` degrees if neighbouring labels overlap."""
    # Label extents are only known once the figure has been drawn.
    ax.figure.canvas.draw()
    tick_labels = ax.get_xticklabels()
    bboxes = [label.get_window_extent() for label in tick_labels]
    if any(left.x1 > right.x0 for left, right in zip(bboxes, bboxes[1:])):
        plt.setp(tick_labels, rotation=rotation, ha='right')


def plot_metrics_per_group(avg_dfs, group, group_name=None, occurences_df=None, order_by_occurrence=False, exclude_most_common=False, df_titles=None, show_values=True):
    """
    Plot one bar chart per metric for the groups of one level, for one or multiple experiments.

    Each input dataframe must provide the columns `level`, `group`, `metric` and the
    two-level columns ("value", "mean") and ("value", "std"), e.g. the output of
    `average_over_folds`. Bar height is the mean, the error bar is the std.

    Args:
        avg_dfs: Single dataframe or list of dataframes to plot (each representing a different experiment)
        group: Level to plot metrics for (e.g., 'entity', 'anatomy_site'); rows with `level == group` are used
        group_name: Optional display name for the group (used in title, axis label and colour bar label)
        occurences_df: Optional dataframe with the columns `attribute`, `value` and `count`
            (and optionally `normalized_count`). If it has rows for `group`, bars are coloured
            by occurrence and a colour bar is added.
        order_by_occurrence: Whether to order bars by ascending occurrence (requires occurences_df).
            Groups without an occurrence entry are not plotted in that case.
        exclude_most_common: Whether to draw the most common group in black and leave it out of
            the colour normalization (requires occurences_df)
        df_titles: Optional list of titles for each experiment (defaults to "Experiment 1", ...)
        show_values: Whether to print the mean value above each bar

    Returns:
        Dict mapping each metric name to its matplotlib figure.

    Raises:
        AssertionError: if the number of titles does not match the number of dataframes, or
            if an occurrence-based option is requested without `occurences_df`.
    """
    # Handle single dataframe case
    if not isinstance(avg_dfs, list):
        avg_dfs = [avg_dfs]

    # Setup df_titles if not provided
    n_dfs = len(avg_dfs)
    if df_titles is None:
        df_titles = [f"Experiment {i+1}" for i in range(n_dfs)]
    else:
        assert len(df_titles) == n_dfs, "Number of titles must match number of experiments"

    if order_by_occurrence:
        assert occurences_df is not None, "To order by occurrence, occurrence data must be provided"

    if exclude_most_common:
        assert occurences_df is not None, "To exclude most common, occurrence data must be provided"

    # Keep only the rows of the requested level, tag them with their experiment and stack them
    all_group_values = set()
    all_metrics_dfs = []
    for df_idx, avg_df in enumerate(avg_dfs):
        metrics_df = avg_df[avg_df["level"] == group].copy()
        metrics_df["df_idx"] = df_idx  # Add source dataframe index
        metrics_df["df_title"] = df_titles[df_idx]  # Add source dataframe title
        all_group_values.update(metrics_df["group"].unique())
        all_metrics_dfs.append(metrics_df)
    combined_metrics_df = pd.concat(all_metrics_dfs, ignore_index=True)

    # Handle occurrence data
    use_heatmap = False
    order = None
    most_common_group = None
    count_label = None

    if occurences_df is not None:
        group_occurrences = occurences_df[occurences_df["attribute"] == group]

        # Only colour by occurrence if there is occurrence data for this level
        if len(group_occurrences) > 0:
            use_heatmap = True

            # Use normalized_count if available, otherwise use count
            has_normalized = "normalized_count" in group_occurrences.columns
            count_col = "normalized_count" if has_normalized else "count"
            count_label = "Normalized occurrence" if has_normalized else "Count"

            if exclude_most_common:
                most_common_group = group_occurrences.loc[group_occurrences[count_col].idxmax(), "value"]

            # Attach the occurrence of each group to the metric rows
            count_map = dict(zip(group_occurrences["value"], group_occurrences[count_col]))
            combined_metrics_df["count"] = combined_metrics_df["group"].map(count_map)

            if order_by_occurrence:
                order = group_occurrences.sort_values(by=count_col)["value"].tolist()

    # x-axis groups shared by all metrics in the multi-experiment layout
    if n_dfs > 1:
        if order is not None:
            # Only groups that occur in at least one experiment get a slot
            unique_groups = [g for g in order if g in all_group_values]
        else:
            unique_groups = sorted(all_group_values)

    # Hatch patterns used to tell experiments apart (bar colour encodes occurrence)
    hatch_patterns = ['////', '....', '\\\\\\\\', 'xxxx', '++++', 'oooo', '||||']

    # Create a figure for each metric
    plots = {}
    for metric in combined_metrics_df["metric"].unique():
        metric_data = combined_metrics_df[combined_metrics_df["metric"] == metric]

        # Map each group to a colour based on its occurrence
        norm = None
        cmap = None
        group_colors = None
        if use_heatmap:
            counts = metric_data["count"]
            if exclude_most_common:
                # The dominant group would squash the colour scale of all other groups
                counts = counts[metric_data["group"] != most_common_group]
            norm = plt.Normalize(counts.min(), counts.max())
            cmap = plt.cm.coolwarm

            group_colors = {}
            for g in metric_data["group"].unique():
                if exclude_most_common and g == most_common_group:
                    group_colors[g] = 'black'
                else:
                    g_count = metric_data[metric_data["group"] == g]["count"].iloc[0]
                    group_colors[g] = cmap(norm(g_count))

        if n_dfs > 1:
            # Multiple experiments: draw grouped bars manually, one hatched bar set per experiment.
            # Exactly one figure is created per metric (no stray empty figure).
            plt.figure(figsize=(14, 8))
            ax = plt.gca()

            # Define bar width and positions with some spacing between bars within a group
            bar_width = 0.7 / n_dfs
            group_positions = np.arange(len(unique_groups))
            group_spacing = 0.05

            legend_handles = []

            for df_idx, df_title in enumerate(df_titles):
                # Select by index rather than title so duplicate titles cannot mix experiments
                df_data = metric_data[metric_data["df_idx"] == df_idx]

                # Collect heights, errors and x slots; groups missing from this experiment are skipped
                heights = []
                errors = []
                positions = []
                for i, g in enumerate(unique_groups):
                    group_data = df_data[df_data["group"] == g]
                    if not group_data.empty:
                        heights.append(group_data[("value", "mean")].iloc[0])
                        errors.append(group_data[("value", "std")].iloc[0])
                        positions.append(i)

                # Shift this experiment's bars so the bars of one group are centred on its tick
                bar_positions = np.array(positions) + (df_idx - n_dfs/2 + 0.5) * (bar_width + group_spacing)

                if use_heatmap:
                    colors = [group_colors[unique_groups[pos]] for pos in positions]
                else:
                    # Same neutral colour for every experiment; the hatch distinguishes them
                    colors = [plt.cm.coolwarm(0.5)] * len(heights)

                hatch = hatch_patterns[df_idx % len(hatch_patterns)]

                bars = ax.bar(
                    bar_positions,
                    heights,
                    width=bar_width,
                    yerr=errors,
                    color=colors,
                    alpha=0.8,
                    capsize=5,
                    label=df_title,
                    hatch=hatch,  # Add hatch pattern to distinguish experiments
                    edgecolor="#454545",  # Dark grey edge; the hatch is drawn in the edge colour
                    linewidth=0.5  # Thin line for hatch visibility
                )

                # Separate legend patch: light grey fill and denser hatch so the pattern is readable
                legend_patch = plt.Rectangle(
                    (0, 0), 1, 1,
                    fill=True,
                    hatch=hatch*2,
                    label=df_title,
                    edgecolor='black',
                    linewidth=0.5,
                    facecolor='lightgray'
                )
                legend_handles.append(legend_patch)

                # Add value labels
                if show_values:
                    for bar, height in zip(bars, heights):
                        ax.text(
                            bar.get_x() + bar.get_width()/2.,
                            height + 0.01,
                            f'{height:.3f}',
                            ha='center',
                            va='bottom',
                            fontsize=9
                        )

            plt.legend(
                handles=legend_handles,
                title="Experiments",
                loc='lower right',
                framealpha=1.0,  # Solid legend background
                handlelength=1.8,  # Legend rectangle width
                handleheight=1.2,  # Legend rectangle height
                labelspacing=0.6   # Add more space between legend items
            )

            # Set x-axis ticks and labels; rotate further if they still overlap
            ax.set_xticks(group_positions)
            ax.set_xticklabels(unique_groups, rotation=45, ha='right')
            _rotate_ticks_on_overlap(ax, 60)

        else:
            # Single experiment: seaborn bar plot
            plt.figure(figsize=(14, 8)).set_dpi(600)

            # Only plot groups that have a value for this metric, so bars, error bars and
            # labels are matched by group instead of by position in `order`.
            present_groups = list(metric_data["group"].unique())
            if order is not None:
                present_lookup = set(present_groups)
                plot_order = [g for g in order if g in present_lookup]
            else:
                plot_order = present_groups

            palette = None
            if group_colors is not None:
                palette = {g: group_colors[g] for g in plot_order}

            ax = sns.barplot(
                x="group",
                y=("value", "mean"),
                data=metric_data,
                order=plot_order,
                palette=palette,
                capsize=5,
                alpha=0.8,
                edgecolor="w",  # White border to blend with background
                linewidth=0.1  # Very thin line
            )

            # Add error bars manually (each group has a single precomputed mean/std row)
            for i, (group_val, bar) in enumerate(zip(plot_order, ax.patches)):
                err = metric_data[metric_data["group"] == group_val][("value", "std")].iloc[0]
                height = bar.get_height()

                ax.errorbar(
                    x=i,
                    y=height,
                    yerr=err,
                    fmt='none',
                    ecolor='black',
                    capsize=5
                )

                # Add value label
                if show_values:
                    ax.text(
                        bar.get_x() + bar.get_width()/2.,
                        height + 0.01,
                        f'{height:.3f}',
                        ha='center',
                        va='bottom',
                        fontsize=10
                    )

            # No rotation by default; fall back to 45 degrees if labels overlap
            plt.xticks(rotation=0, ha='center')
            _rotate_ticks_on_overlap(ax, 45)

        # Customize the plot
        metric_label = _metric_display_name(metric)
        group_label = group_name if group_name else group

        plt.title(f'{metric_label} by {group_label}', fontsize=16)
        plt.ylabel(metric_label, fontsize=14)
        plt.xlabel(f'{group_label}', fontsize=14)

        # Set y-axis limit: at least 1.0, with headroom above the tallest bar
        max_height = metric_data[("value", "mean")].max()
        plt.ylim(0, max(1.0, max_height * 1.15))

        plt.grid(axis='y', linestyle='--', alpha=0.7)

        # Add color bar for heatmap
        if use_heatmap:
            sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
            sm.set_array([])
            cbar = plt.colorbar(sm, ax=ax)
            if exclude_most_common:
                cbar.set_label(f'{count_label} of {group_label} (except {most_common_group})', fontsize=12)
            else:
                cbar.set_label(f'{count_label} of {group_label}', fontsize=12)

        # Add note about error bars centered under the plot
        ax.text(0.5, -0.15, 'Error bars represent standard deviation across folds',
                transform=ax.transAxes, ha='center', va='top', fontsize=10, alpha=0.7)

        plt.tight_layout()
        plots[metric] = plt.gcf()
        plt.show()
    return plots

In [None]:
def print_mean_and_std_over_group(df, group):
    """Print summary statistics of every metric across the groups of one level.

    For each metric found in the rows with `level == group`, this prints the
    mean, std and median of the per-group fold means, plus the groups with the
    lowest and highest fold mean.

    Args:
        df: Fold-averaged metrics with the columns `level`, `group`, `metric`
            and the two-level column ("value", "mean").
        group: Level to summarize (e.g. 'entity', 'sex').
    """
    # Use the dataframe that was passed in, not a notebook-global one
    metrics_df = df[df["level"] == group]

    for metric in metrics_df["metric"].unique():
        print(f"\nMetric: {metric}")

        metric_by_group = metrics_df[metrics_df["metric"] == metric]
        group_means = metric_by_group[("value", "mean")]
        # Read the label through the column Series so a plain scalar is printed
        group_names = metric_by_group["group"]

        # Mean and std of the metric across all group types
        mean_metric = group_means.mean()
        std_metric = group_means.std()
        print(f"Mean ± std {metric} across all {group} types: {mean_metric:.4f} ± {std_metric:.4f}")

        median_metric = group_means.median()
        print(f"Median {metric} across all {group} types: {median_metric:.4f}")

        # Min and max metric with the corresponding group types
        min_metric_idx = group_means.idxmin()
        max_metric_idx = group_means.idxmax()

        min_group = group_names.loc[min_metric_idx]
        min_value = group_means.loc[min_metric_idx]

        max_group = group_names.loc[max_metric_idx]
        max_value = group_means.loc[max_metric_idx]

        print(f"Lowest {metric}: {min_value:.4f} ({min_group})")
        print(f"Highest {metric}: {max_value:.4f} ({max_group})")

### Confusion Matrix

In [None]:
# `import sklearn` alone does not load the `metrics` subpackage on every scikit-learn version,
# so import it explicitly before using `sklearn.metrics.*`.
import sklearn.metrics

# Confusion matrix of the thresholded predictions against the `tumor` column.
labels=['No Tumor', 'Tumor']
confusion_matrix = sklearn.metrics.confusion_matrix(pred_df['tumor'], pred_df['pred'])
disp = sklearn.metrics.ConfusionMatrixDisplay(confusion_matrix=confusion_matrix, display_labels=labels)
disp = disp.plot(cmap=plt.cm.Blues)
disp.figure_.set_dpi(500)

# Save before plt.show() so the current figure is still the confusion matrix.
plt.savefig("../../visualizations/data/downstream/confusion_matrix.svg")
plt.show()

### Per dataset

In [None]:
# Plot every metric per dataset; no occurrence data is passed, so bars are not coloured by occurrence.
plot_metrics_per_group(avg_df, "dataset", group_name="Dataset")

### Per Entity

In [None]:
# Per-entity bars ordered and coloured by occurrence; the most common entity is drawn in black
# and left out of the colour scale.
plot_metrics_per_group(avg_df, "entity", group_name="Entity", occurences_df=occurences_df, order_by_occurrence=True, exclude_most_common=True)

In [None]:
# Summary statistics (mean, std, median, min, max) of each metric across entities.
print_mean_and_std_over_group(avg_df, "entity")

### Per anatomy site

In [None]:
# Per-anatomy-site bars ordered and coloured by occurrence.
plot_metrics_per_group(avg_df, "anatomy_site", occurences_df=occurences_df, order_by_occurrence=True, group_name="Anatomy Site")

In [None]:
# Summary statistics of each metric across anatomy sites.
print_mean_and_std_over_group(avg_df, "anatomy_site")

### Per Sex

In [None]:
# Per-sex bars; no occurrence data is passed, so bars are not coloured by occurrence.
plot_metrics_per_group(avg_df, "sex", group_name="Sex")

In [None]:
# Summary statistics of each metric across sexes.
print_mean_and_std_over_group(avg_df, "sex")

### Per Age (Encoded)

In [None]:
# Per-age-group bars coloured by occurrence but kept in their default order (order_by_occurrence is not set).
plot_metrics_per_group(avg_df, "age_encoded", occurences_df=occurences_df, group_name="Age Group")

In [None]:
# Summary statistics of each metric across encoded age groups.
print_mean_and_std_over_group(avg_df, "age_encoded")

## Multiple Experiment Evaluation

### Bar Charts for each Metric on Different variables (Anatomy Site, Age)

In [None]:
# (metrics CSV, display title) of every experiment to compare; comment entries in or out as needed.
files_and_titles = [
    # ("../../evaluation/baseline/only_imaging_baseline_resnet34.csv", "ResNet34 Baseline Imaging"),
    ("../../evaluation/baseline/imaging_and_clinical_baseline_resnet34.csv", "ResNet34 Baseline Imaging+Clinical"),
    # ("../../evaluation/baseline/only_imaging_baseline_nest_small.csv", "Nest Small Baseline Imaging"),
    # ("../../evaluation/baseline/imaging_and_clinical_baseline_nest_small.csv", "Nest Small Baseline Imaging+Clinical"),
    # ("../../evaluation/baseline_pretrained/only_imaging_pretrained_baseline_resnet50.csv", "Torchxrayvision ResNet50 Imaging"),
    ("../../evaluation/baseline_pretrained/imaging_and_clinical_pretrained_baseline_resnet50.csv", "Torchxrayvision ResNet50 Imaging+Clinical"),
    # ("../../evaluation/vlp/linear_probe_only_imaging_resnet34.csv", "ResNet34 VLP Linear Probe Imaging"),
    # ("../../evaluation/finetune/only_imaging_finetune_resnet34.csv", "ResNet34 Finetune Imaging"),
    ("../../evaluation/finetune/imaging_and_clinical_finetune_resnet34.csv", "VLP Finetune ResNet34 Imaging+Clinical")
]
imaging_files = [path for path, _ in files_and_titles]
imaging_titles = [title for _, title in files_and_titles]

# Load every experiment, drop NaN rows, then average over folds (same pipeline as the single-experiment cells).
raw_experiment_dfs = [pd.read_csv(path) for path in imaging_files]
imaging_experiment_dfs = [clean(raw_frame) for raw_frame in raw_experiment_dfs]
imaging_dfs = [average_over_folds(cleaned_frame) for cleaned_frame in imaging_experiment_dfs]

In [None]:
# Spot-check the rows of the group 'arm' in every experiment.
# The loop variable is deliberately not named `df`, so the raw metrics frame `df`
# loaded at the top of the notebook is not overwritten by this cell.
for experiment_df in imaging_dfs:
    print(experiment_df[experiment_df['group'] == 'arm'])

In [None]:
# Compare the experiments per entity: bars ordered and coloured by occurrence, the most common entity in black, no value labels.
plot_metrics_per_group(imaging_dfs, "entity", occurences_df=occurences_df, group_name="Entity", exclude_most_common=True, order_by_occurrence=True, df_titles=imaging_titles, show_values=False)

In [None]:
# Compare the experiments per anatomy site and keep the returned {metric: figure} dict
# so the AUROC figure can be saved for the write-up.
plots = plot_metrics_per_group(imaging_dfs, "anatomy_site", occurences_df=occurences_df, group_name="Anatomy Site", exclude_most_common=False, order_by_occurrence=True, df_titles=imaging_titles, show_values=False)
plots['roc_auc'].savefig("../../visualizations/results/pretrained_baseline_vs_finetune_auroc_on_anatomy_site.svg")

In [None]:
# Compare the experiments per encoded age group; bars are coloured by occurrence but sorted by group label.
plot_metrics_per_group(imaging_dfs, "age_encoded", occurences_df=occurences_df, group_name="Age Encoded", exclude_most_common=False, order_by_occurrence=False, df_titles=imaging_titles, show_values=False)

In [None]:
# Same comparison for the "age_group" level; bars are sorted by group label.
plot_metrics_per_group(imaging_dfs, "age_group", occurences_df=occurences_df, group_name="Age Grouped", exclude_most_common=False, order_by_occurrence=False, df_titles=imaging_titles, show_values=False)

### Performance Latex Table Generation

In [None]:
# Metrics shown in the table, in column order (f1_score is intentionally left out).
_LATEX_TABLE_METRICS = ('accuracy', 'precision', 'recall', 'roc_auc')

_LATEX_TABLE_HEADER = """\\begin{table*}[htbp]
\\centering
\\caption{Test performance across experiments.}
\\label{tab:test_results}
\\renewcommand{\\arraystretch}{1.2}
\\setlength{\\tabcolsep}{6pt}
\\begin{tabular}{llcccc}
\\toprule
\\textbf{Experiment} & \\textbf{Model} & \\textbf{Acc} & \\textbf{Prec} & \\textbf{Rec} & \\textbf{AUROC} \\\\
\\midrule"""

_LATEX_TABLE_FOOTER = """
\\bottomrule
\\end{tabular}
\\end{table*}"""


def _overall_metric_stats(avg_df, metrics):
    """Return {metric: (mean, std)} read from the 'overall' rows of one averaged frame."""
    overall_metrics = avg_df[avg_df["level"] == "overall"]
    stats = {}
    for metric in metrics:
        metric_row = overall_metrics[overall_metrics["metric"] == metric]
        if metric_row.empty:
            # Kept for backward compatibility: a missing metric renders as 0.000$\pm$0.000.
            stats[metric] = (0.0, 0.0)
        else:
            stats[metric] = (
                metric_row[("value", "mean")].iloc[0],
                metric_row[("value", "std")].iloc[0],
            )
    return stats


def _latex_modality(title):
    """Classify an experiment title as 'imaging', 'clinical' or None (unrecognised)."""
    if "Only Imaging" in title:
        return "imaging"
    if "Clinical" in title:
        return "clinical"
    # Titles such as "Baseline ResNet34 Imaging" carry no "Only" prefix.
    if "Imaging" in title:
        return "imaging"
    return None


def _latex_model_name(title):
    """Map an experiment title to the model name printed in the table."""
    for marker, label in (("ResNet34", "ResNet34"), ("ResNet50", "ResNet50"), ("Nest Small", "NesT-S")):
        if marker in title:
            return label
    return "Unknown"


def _latex_experiment_name(title):
    """Map an experiment title to the experiment-type label printed in the table."""
    if "Baseline" in title and not any(tag in title for tag in ("Pretrained", "VLP", "Finetune")):
        return "Baseline"
    # In this notebook the pretrained-baseline results are titled "Torchxrayvision ...".
    if "Pretrained Baseline" in title or "Torchxrayvision" in title:
        return "Pretrained Baseline"
    if "Linear Probe" in title:
        return "\\ac{VLP} Linear Probe (ours)"
    if "Finetune" in title:
        return "\\ac{VLP} Finetune (ours)"
    return "Unknown"


def _render_latex_section(heading, rows, metrics):
    """Render one table section: a heading line plus one line per experiment row."""
    # The best mean per metric is computed within this section only.
    best = {metric: max(row[metric][0] for row in rows) for metric in metrics}
    parts = ["\n\\multicolumn{6}{l}{\\textbf{" + heading + "}} \\\\"]
    for row in rows:
        title = row['experiment']
        cells = [_latex_experiment_name(title), _latex_model_name(title)]
        for metric in metrics:
            mean_val, std_val = row[metric]
            formatted_val = f"{mean_val:.3f}$\\pm${std_val:.3f}"
            # Bold marks the best mean of the section (ties are all bold).
            if abs(mean_val - best[metric]) < 1e-6:
                formatted_val = f"\\textbf{{{formatted_val}}}"
            cells.append(formatted_val)
        parts.append("\n" + " & ".join(cells) + " \\\\")
    return "".join(parts)


def generate_latex_table(avg_dfs, titles):
    """
    Generate LaTeX table with overall performance metrics.

    The table has an "Imaging" and an "Imaging + Clinical" section. Each row shows
    mean$\\pm$std for accuracy, precision, recall and AUROC, with the best mean of
    every metric printed in bold within its section.

    Args:
        avg_dfs: Frames with the columns `level`, `metric`, `('value', 'mean')` and
            `('value', 'std')`, one per experiment. Only `level == "overall"` rows are read.
        titles: Experiment titles, parallel to `avg_dfs`. The title decides the section
            (contains "Clinical" or "Imaging"), the model name and the experiment label.

    Returns:
        The LaTeX source of the table as a string.

    Experiments whose title names neither modality are skipped with a warning.
    """
    import warnings

    metrics = _LATEX_TABLE_METRICS

    # Extract overall metrics for all experiments
    table_data = []
    for i, avg_df in enumerate(avg_dfs):
        row_data = {'experiment': titles[i]}
        row_data.update(_overall_metric_stats(avg_df, metrics))
        table_data.append(row_data)

    # Group experiments by imaging type
    sections = {"imaging": [], "clinical": []}
    for row in table_data:
        modality = _latex_modality(row['experiment'])
        if modality is None:
            warnings.warn(
                f"Experiment {row['experiment']!r} names neither 'Imaging' nor 'Clinical' "
                "and is left out of the LaTeX table."
            )
            continue
        sections[modality].append(row)

    imaging_only = sections["imaging"]
    imaging_clinical = sections["clinical"]

    latex_code = _LATEX_TABLE_HEADER
    if imaging_only:
        latex_code += _render_latex_section("Imaging", imaging_only, metrics)
    if imaging_only and imaging_clinical:
        latex_code += "\n\\midrule"
    if imaging_clinical:
        latex_code += _render_latex_section("Imaging + Clinical", imaging_clinical, metrics)
    latex_code += _LATEX_TABLE_FOOTER

    return latex_code

# Render the LaTeX table and print it between two rulers for easy copying.
latex_table = generate_latex_table(imaging_dfs, imaging_titles)
print("LaTeX Table Code:")
print("="*50)
print(latex_table)
print("="*50)

# Plain-text version of the same overall numbers, one row per experiment.
metrics_of_interest = ['accuracy', 'precision', 'recall', 'roc_auc']  # f1_score is left out on purpose
summary_data = []

for i, avg_df in enumerate(imaging_dfs):
    overall_metrics = avg_df[avg_df["level"] == "overall"]
    row_data = {'Experiment': imaging_titles[i]}

    for metric in metrics_of_interest:
        column_title = metric.replace('_', ' ').title()
        metric_row = overall_metrics[overall_metrics["metric"] == metric]
        if metric_row.empty:
            # The metric is not available for this experiment.
            row_data[column_title] = "N/A"
            continue
        mean_val = metric_row[("value", "mean")].iloc[0]
        std_val = metric_row[("value", "std")].iloc[0]
        row_data[column_title] = f"{mean_val:.3f} ± {std_val:.3f}"

    summary_data.append(row_data)

summary_df = pd.DataFrame(summary_data)
print("\nSummary Table (for reference):")
print(summary_df.to_string(index=False))

### Line Charts for each Metric on Anatomy Site

In [None]:
def _load_averaged(csv_path):
    """Read one evaluation CSV and pass it through `clean` and `average_over_folds`."""
    return average_over_folds(clean(pd.read_csv(csv_path)))


# Baselines: ResNet34 and Nest Small, each as Only Imaging and Imaging+Clinical
baseline_only_imaging_resnet34 = _load_averaged("../../evaluation/baseline/only_imaging_baseline_resnet34.csv")
baseline_imaging_clinical_resnet34 = _load_averaged("../../evaluation/baseline/imaging_and_clinical_baseline_resnet34.csv")
baseline_only_imaging_nest = _load_averaged("../../evaluation/baseline/only_imaging_baseline_nest_small.csv")
baseline_imaging_clinical_nest = _load_averaged("../../evaluation/baseline/imaging_and_clinical_baseline_nest_small.csv")

# Pretrained baselines: ResNet50, Only Imaging and Imaging+Clinical
pretrained_baseline_only_imaging_resnet50 = _load_averaged("../../evaluation/baseline_pretrained/only_imaging_pretrained_baseline_resnet50.csv")
pretrained_baseline_imaging_clinical_resnet50 = _load_averaged("../../evaluation/baseline_pretrained/imaging_and_clinical_pretrained_baseline_resnet50.csv")

# VLP linear probe: ResNet34, Only Imaging
linear_probe_only_imaging_resnet34 = _load_averaged("../../evaluation/vlp/linear_probe_only_imaging_resnet34.csv")

# Finetune: ResNet34, Only Imaging and Imaging+Clinical
# (the misspelled variable names are referenced by later cells and therefore kept)
finetune_ony_imaging_resnet34 = _load_averaged("../../evaluation/finetune/only_imaging_finetune_resnet34.csv")
finetune_imaging_clincial_resnet34 = _load_averaged("../../evaluation/finetune/imaging_and_clinical_finetune_resnet34.csv")

In [None]:
# (averaged frame, legend title) pairs; comment a pair in or out to change the line charts.
_line_chart_experiments = [
    (baseline_only_imaging_resnet34, 'Baseline ResNet34 Imaging'),
    (baseline_imaging_clinical_resnet34, 'Baseline ResNet34 Imaging+Clinical'),
    # (baseline_only_imaging_nest, 'Baseline Only Imaging NesT-S'),
    # (baseline_imaging_clinical_nest, 'Baseline Imaging+Clinical NesT-S'),
    (pretrained_baseline_only_imaging_resnet50, 'Torchxrayvision ResNet50 Imaging'),
    (pretrained_baseline_imaging_clinical_resnet50, 'Torchxrayvision ResNet50 Imaging+Clinical'),
    (finetune_ony_imaging_resnet34, 'VLP Finetune ResNet34 Imaging'),
    (finetune_imaging_clincial_resnet34, 'VLP Finetune ResNet34 Imaging+Clinical'),
    # (linear_probe_only_imaging_resnet34, 'VLP Linear Probe Only Imaging ResNet34'),
]

# Tuple of two parallel lists: (frames, titles).
dfs_and_titles = (
    [frame for frame, _ in _line_chart_experiments],
    [title for _, title in _line_chart_experiments],
)

In [None]:
def plot_metric_comparison_by_var(dfs, titles, metric, var, metadata_df=None):
    """
    Plot one metric per group of `var`, with one line per experiment.

    Args:
        dfs: Result frames, one per experiment. Each needs the columns `level`,
            `group`, `metric` and `('value', 'mean')`. The frames are not modified.
        titles: Experiment names, parallel to `dfs`; they label the lines and fix
            the hue order.
        metric: Value of the `metric` column to plot, e.g. 'roc_auc'.
        var: Value of the `level` column to plot. If `metadata_df` is given it is
            also the name of the metadata column whose values are counted.
        metadata_df: Optional frame with a column named `var`. If given, the groups
            are ordered by ascending number of occurrences and each group's share
            of all rows is drawn on a secondary y-axis.

    Returns:
        The matplotlib figure holding the plot.
    """
    # Tag copies so the caller's frames do not gain an 'experiment' column.
    tagged_dfs = []
    for i, df in enumerate(dfs):
        tagged = df.copy()
        tagged['experiment'] = titles[i]
        tagged_dfs.append(tagged)

    combined_df = pd.concat(tagged_dfs, ignore_index=True)
    # Keep only the requested metric and level; copy because 'group' may be reassigned below.
    keep = (combined_df['metric'] == metric) & (combined_df['level'] == var)
    combined_df = combined_df[keep].copy()

    # If metadata_df is provided, sort the groups ascending by their occurrence.
    order = None
    percentage_data = None
    if metadata_df is not None:
        group_occurrences = metadata_df[[var]].groupby(var).size().reset_index(name='count')
        group_occurrences = group_occurrences.sort_values(by='count')
        order = group_occurrences[var].tolist()
        combined_df['group'] = pd.Categorical(combined_df['group'], categories=order, ordered=True)
        combined_df = combined_df.sort_values(by='group')

        # Share of each group, as a fraction between 0 and 1, for the secondary axis.
        total_count = group_occurrences['count'].sum()
        percentage_data = (group_occurrences.set_index(var)['count'] / total_count)

    fig, ax1 = plt.subplots(figsize=(12, 8))

    # "Paired" palette without slots 6, 7 and 10, cut to one colour per title so
    # seaborn does not warn about surplus palette entries.
    palette = [c for i, c in enumerate(sns.color_palette("Paired", len(titles) + 3)) if i not in (6, 7, 10)]
    palette = palette[:len(titles)]

    # Primary plot (metric values)
    sns.lineplot(
        data=combined_df,
        x='group',
        y=('value', 'mean'),
        hue='experiment',
        hue_order=titles,
        palette=palette,
        marker='.',
        legend='full',
        sort=False,  # rows are already ordered; seaborn would otherwise re-sort them
        linewidth=1,
        ax=ax1
    )

    metric_label = metric.replace('_', ' ').title() if metric != 'roc_auc' else 'AUROC'
    ax1.set_xlabel(var.replace('_', ' ').title())
    ax1.set_ylim(0, 1.05)
    ax1.set_ylabel(metric_label)
    ax1.grid(True)
    plt.xticks(rotation=45)
    ax1.legend(loc='lower center', bbox_to_anchor=(0.5, 1.02), frameon=True)

    # Add secondary y-axis for percentage data
    if percentage_data is not None:
        ax2 = ax1.twinx()

        # One point per category, in the same order as the categorical x-axis.
        percentages = [percentage_data.get(group, 0) for group in order]

        ax2.plot(range(len(order)), percentages,
                 color='gray', linestyle='--', linewidth=1, marker='o',
                 markersize=3, label='Percentage of Downstream Dataset', alpha=0.4)

        ax2.set_ylabel('Percentage', color='gray')
        ax2.tick_params(axis='y', labelcolor='gray')
        ax2.set_ylim(0, max(percentages) * 1.2)

        # Remove grid lines from secondary axis
        ax2.grid(False)

        # Add legend for percentage line
        ax2.legend(loc='lower right')

    fig.set_dpi(1000)
    plt.tight_layout()
    plt.show()
    return fig

In [None]:
# AUROC per anatomy site for all selected experiments, exported as SVG.
dfs, titles = dfs_and_titles
plot = plot_metric_comparison_by_var(
    dfs,
    titles,
    metric='roc_auc',
    var='anatomy_site',
    metadata_df=metadata_df,
)
plot.savefig("../../visualizations/results/auroc_over_anatomy_sites.svg")

In [None]:
# Recall per anatomy site for all selected experiments.
dfs, titles = dfs_and_titles
plot = plot_metric_comparison_by_var(
    dfs,
    titles,
    metric='recall',
    var='anatomy_site',
    metadata_df=metadata_df,
)

In [None]:
# Precision per anatomy site for all selected experiments.
dfs, titles = dfs_and_titles
plot = plot_metric_comparison_by_var(
    dfs,
    titles,
    metric='precision',
    var='anatomy_site',
    metadata_df=metadata_df,
)

In [None]:
# Balanced accuracy per anatomy site for all selected experiments.
dfs, titles = dfs_and_titles
plot = plot_metric_comparison_by_var(
    dfs,
    titles,
    metric='balanced_accuracy',
    var='anatomy_site',
    metadata_df=metadata_df,
)

In [None]:
# Accuracy per anatomy site for all selected experiments.
dfs, titles = dfs_and_titles
plot = plot_metric_comparison_by_var(
    dfs,
    titles,
    metric='accuracy',
    var='anatomy_site',
    metadata_df=metadata_df,
)

### Comparing Imaging and Imaging+Clinical Experiments on Age

In [None]:
# (csv path, display title) pairs of the experiments that use imaging only.
imaging_files_and_titles = [
    ("../../evaluation/baseline/only_imaging_baseline_resnet34.csv", "ResNet34 Baseline Only Imaging"),
    ("../../evaluation/baseline/only_imaging_baseline_nest_small.csv", "Nest Small Baseline Only Imaging"),
    ("../../evaluation/baseline_pretrained/only_imaging_pretrained_baseline_resnet50.csv", "ResNet50 Pretrained Baseline Only Imaging"),
    ("../../evaluation/finetune/only_imaging_finetune_resnet34.csv", "ResNet34 Finetune Only Imaging"),
]
# (csv path, display title) pairs of the same experiments with imaging plus clinical data.
imaging_and_clinical_files_and_titles = [
    ("../../evaluation/baseline/imaging_and_clinical_baseline_resnet34.csv", "ResNet34 Baseline Imaging+Clinical"),
    ("../../evaluation/baseline/imaging_and_clinical_baseline_nest_small.csv", "Nest Small Baseline Imaging+Clinical"),
    ("../../evaluation/baseline_pretrained/imaging_and_clinical_pretrained_baseline_resnet50.csv", "ResNet50 Pretrained Baseline Imaging+Clinical"),
    ("../../evaluation/finetune/imaging_and_clinical_finetune_resnet34.csv", "ResNet34 Finetune Imaging+Clinical")
]

def read_clean_and_average(files_and_titles):
    """
    Load the experiments named by a list of (csv path, title) pairs.

    Every CSV is read with pandas and then passed through `clean` and
    `average_over_folds`.

    Returns:
        A `(dfs, titles)` tuple of two parallel lists.
    """
    paths = [entry[0] for entry in files_and_titles]
    titles = [entry[1] for entry in files_and_titles]
    raw_frames = [pd.read_csv(path) for path in paths]
    cleaned_frames = [clean(frame) for frame in raw_frames]
    return [average_over_folds(frame) for frame in cleaned_frames], titles

# NOTE: this rebinds `imaging_dfs` / `imaging_titles`, which earlier cells also assign.
imaging_dfs, imaging_titles = read_clean_and_average(imaging_files_and_titles)
imaging_and_clinical_dfs, imaging_and_clinical_titles = read_clean_and_average(imaging_and_clinical_files_and_titles)

In [None]:
from scipy.stats import wilcoxon

def plot_grouped_metric_comparison_by_var(group_1_dfs, group_2_dfs, group_1_title, group_2_title, metric='precision', var='age_encoded'):
    """
    Compare one metric between two groups of experiments across the values of `var`.

    For every value of `var` the line shows the mean of the experiments' mean
    values within a group; the band shows their standard deviation.

    Args:
        group_1_dfs: List of dataframes with first group experiment results (e.g., imaging-only)
        group_2_dfs: List of dataframes with second group experiment results (e.g., imaging+clinical)
        group_1_title: Legend label of the first group
        group_2_title: Legend label of the second group
        metric: Metric to analyze (default: 'precision')
        var: Value of the `level` column to plot on the x-axis (default: 'age_encoded')

    Returns:
        The matplotlib figure holding the plot.
    """
    import seaborn as sns
    import matplotlib.pyplot as plt

    def stack_with_ids(frames, first_id):
        # Copy each frame, tag it with a running experiment id and stack them.
        tagged = []
        for offset, frame in enumerate(frames):
            tagged_frame = frame.copy()
            tagged_frame['experiment'] = first_id + offset
            tagged.append(tagged_frame)
        return pd.concat(tagged, ignore_index=True)

    def rows_for(stacked, label):
        # Keep the rows of the requested level and metric and label them with the group.
        wanted = (stacked['level'] == var) & (stacked['metric'] == metric)
        selected = stacked[wanted].copy()
        selected['Comparison Group'] = label
        return selected

    group_1_experiment_dfs = stack_with_ids(group_1_dfs, 0)
    # Ids of the second group continue after those of the first group.
    group_2_experiment_dfs = stack_with_ids(group_2_dfs, len(group_1_dfs))

    combined_data = pd.concat(
        [
            rows_for(group_1_experiment_dfs, group_1_title),
            rows_for(group_2_experiment_dfs, group_2_title),
        ],
        ignore_index=True,
    )

    # The x column is named after the variable and the value column after the metric.
    if var == 'age_encoded':
        age_level_titel = 'Age'
    else:
        age_level_titel = var.replace('_', ' ').title()
    combined_data = combined_data.rename(columns={'group': age_level_titel, 'value': metric})

    fig = plt.figure(figsize=(10, 5))
    sns.set_style("whitegrid")
    plot = sns.lineplot(
        data=combined_data,
        x=age_level_titel,
        y=(metric, 'mean'),
        hue='Comparison Group',
        estimator=np.mean,
        marker='o',
        errorbar='sd',
        err_style='band',
        legend='full'
    )

    if var == 'age_encoded':
        # Replace the numeric bin codes by readable age ranges.
        age_ranges = ['0-9', '10-19', '20-29', '30-39', '40-49', '50-59', '60+']
        plot.set_xticks([0, 1, 2, 3, 4, 5, 6], age_ranges)

    y_label = 'AUROC' if metric == 'roc_auc' else metric.replace("_", " ").title()
    plot.set_ylim(0, 1)
    plot.set_ylabel(y_label)
    plot.legend(loc='lower left')

    return fig

In [None]:
# Precision across the encoded age bins, imaging vs. imaging + clinical.
# Pass var='age_group' instead to plot by age group.
plot = plot_grouped_metric_comparison_by_var(
    group_1_dfs=imaging_dfs,
    group_2_dfs=imaging_and_clinical_dfs,
    group_1_title='Imaging',
    group_2_title='Imaging + Clinical',
    metric='precision',
    var='age_encoded',
)
plot.savefig("../../visualizations/results/imaging_vs_imaging_and_clinical_across_age_encoded_precision.svg")

In [None]:
# Recall across the encoded age bins, imaging vs. imaging + clinical.
# Pass var='age_group' instead to plot by age group.
plot = plot_grouped_metric_comparison_by_var(
    group_1_dfs=imaging_dfs,
    group_2_dfs=imaging_and_clinical_dfs,
    group_1_title='Imaging',
    group_2_title='Imaging + Clinical',
    metric='recall',
    var='age_encoded',
)
plot.savefig("../../visualizations/results/imaging_vs_imaging_and_clinical_across_age_encoded_recall.svg")

In [None]:
# F1 score across the encoded age bins, imaging vs. imaging + clinical.
# Pass var='age_group' instead to plot by age group.
plot = plot_grouped_metric_comparison_by_var(
    group_1_dfs=imaging_dfs,
    group_2_dfs=imaging_and_clinical_dfs,
    group_1_title='Imaging',
    group_2_title='Imaging + Clinical',
    metric='f1_score',
    var='age_encoded',
)
plot.savefig("../../visualizations/results/imaging_vs_imaging_and_clinical_across_age_encoded_f1_score.svg")

In [None]:
# AUROC across the encoded age bins, imaging vs. imaging + clinical.
# Pass var='age_group' instead to plot by age group.
plot = plot_grouped_metric_comparison_by_var(
    group_1_dfs=imaging_dfs,
    group_2_dfs=imaging_and_clinical_dfs,
    group_1_title='Imaging',
    group_2_title='Imaging + Clinical',
    metric='roc_auc',
    var='age_encoded',
)
plot.savefig("../../visualizations/results/imaging_vs_imaging_and_clinical_across_age_encoded_auroc.svg")

In [None]:
# Imaging vs. imaging + clinical across anatomy sites: precision, recall and AUROC.
plot_grouped_metric_comparison_by_var(
    group_1_dfs=imaging_dfs,
    group_2_dfs=imaging_and_clinical_dfs,
    group_1_title='Imaging',
    group_2_title='Imaging + Clinical',
    metric='precision',
    var='anatomy_site',
)
plot_grouped_metric_comparison_by_var(
    group_1_dfs=imaging_dfs,
    group_2_dfs=imaging_and_clinical_dfs,
    group_1_title='Imaging',
    group_2_title='Imaging + Clinical',
    metric='recall',
    var='anatomy_site',
)
plot_grouped_metric_comparison_by_var(
    group_1_dfs=imaging_dfs,
    group_2_dfs=imaging_and_clinical_dfs,
    group_1_title='Imaging',
    group_2_title='Imaging + Clinical',
    metric='roc_auc',
    var='anatomy_site',
)

#### Have a look at precision over age compared between Scratch and Finetuning

In [None]:
# (csv path, display title) pairs of the ResNet34 baseline, i.e. the "Scratch" group.
scratch_files_and_titles = [
    ("../../evaluation/baseline/only_imaging_baseline_resnet34.csv", "ResNet34 Baseline Only Imaging"),
    ("../../evaluation/baseline/imaging_and_clinical_baseline_resnet34.csv", "ResNet34 Baseline Imaging+Clinical"),
    # ("../../evaluation/baseline/only_imaging_baseline_nest_small.csv", "Nest Small Baseline Only Imaging"),
    # ("../../evaluation/baseline_pretrained/only_imaging_pretrained_baseline_resnet50.csv", "ResNet50 Pretrained Baseline Only Imaging"),
]
# (csv path, display title) pairs of the finetuned ResNet34, i.e. the "Finetuning" group.
finetune_files_and_titles = [
    # ("../../evaluation/baseline/imaging_and_clinical_baseline_nest_small.csv", "Nest Small Baseline Imaging+Clinical"),
    # ("../../evaluation/baseline_pretrained/imaging_and_clinical_pretrained_baseline_resnet50.csv", "ResNet50 Pretrained Baseline Imaging+Clinical"),
    ("../../evaluation/finetune/only_imaging_finetune_resnet34.csv", "ResNet34 Finetune Only Imaging"),
    ("../../evaluation/finetune/imaging_and_clinical_finetune_resnet34.csv", "ResNet34 Finetune Imaging+Clinical")
]

# Load both groups as parallel (frames, titles) lists.
scratch_dfs, scratch_titles = read_clean_and_average(scratch_files_and_titles)
finetune_dfs, finetune_titles = read_clean_and_average(finetune_files_and_titles)

In [None]:
# Precision across the encoded age bins, from-scratch training vs. finetuning.
plot_grouped_metric_comparison_by_var(
    group_1_dfs=scratch_dfs,
    group_2_dfs=finetune_dfs,
    group_1_title='Scratch',
    group_2_title='Finetuning',
    metric='precision',
    var='age_encoded',
)

### Significance test, whether pretrained baseline and finetune differ significantly in their predictions

>Note: only considering Imaging+Clinical here

In [None]:
# Take fold number 1 of the pretrained baseline, since it has the best AUROC on the
# test set compared to the other folds of that model.
# NOTE: a second read_csv (ResNet34 baseline trained from scratch, fold 0) used to
# overwrite this frame right away, so the tests below did not compare the pretrained
# baseline at all. It was removed to match this section's title; re-run the cells
# below to refresh the reported p-values.
probs_df_pretrained_baseine = pd.read_csv('../../predictions/baseline_pretrained/imaging_and_clinical_pretrained_baseline_resnet50/predictions_fold_1.csv')
probs_pretrained_baseline = probs_df_pretrained_baseine['prob']


# Take fold number 2 of the finetuned model (ours), for the same reason.
probs_df_finetuned = pd.read_csv('../../predictions/finetune/imaging_and_clinical_finetune_resnet34/predictions_fold_2.csv')
probs_finetuned = probs_df_finetuned['prob']

# A paired test needs the same samples in the same order in both prediction files.
# Sanity check: the files must have equal length and everything except the
# probabilities must be identical.
if len(probs_df_pretrained_baseine) != len(probs_df_finetuned):
    raise ValueError(
        f"Prediction files differ in length: {len(probs_df_pretrained_baseine)} vs {len(probs_df_finetuned)} rows."
    )
cols_to_check = ['dataset', 'entity', 'anatomy_site', 'sex', 'age', 'age_encoded', 'tumor']
for col in cols_to_check:
    if not all(probs_df_pretrained_baseine[col] == probs_df_finetuned[col]):
        raise ValueError(f"Column {col} does not match between the two prediction files.")

In [None]:
from scipy.stats import normaltest

# Test the paired differences for normality before choosing between t-test and Wilcoxon.
prob_differences = probs_pretrained_baseline - probs_finetuned
stat, p_value = normaltest(prob_differences)
print("\nNormality test for difference of probabilities:")
print(f"Statistic: {stat}, p-value: {p_value}")

p-value < 0.05, so we reject the null hypothesis that the differences between the paired probabilities are normally distributed -> the paired t-test is not applicable.
Instead, we use the Wilcoxon signed-rank test.

In [None]:
from scipy.stats import wilcoxon

# Two-sided paired test on the predicted probabilities of both models.
stat, p_value = wilcoxon(
    probs_pretrained_baseline,
    probs_finetuned,
    alternative='two-sided',
)
print("Wilcoxon Signed Rank Test for difference in pretrained baseline and finetune")
print(f"Statistic: {stat}, p-value: {p_value}")

p_value < 0.05, so we reject the null hypothesis of the Wilcoxon signed-rank test (that the paired differences are distributed symmetrically around zero) -> The two models differ significantly in their predicted probabilities.

## Investigating Metadata Groups

I'm interested in whether the fusion approach, which uses the metadata age, sex_encoded, and anatomy site, can exploit metadata combinations for which only a single label occurs at all (e.g. for male, age_encoded 4, knee -> there might only be non-tumor samples). To find out, I check whether such groups exist.

In [None]:
def encode_age(age: int):
    """
    Map an age to its 10-year bin.

    Bins 1 to 6 cover the ages below 10, 20, 30, 40, 50 and 60; every age of
    60 and above falls into bin 7.

    Raises:
        ValueError: If `age` is negative.
    """
    # NOTE(review): this yields bins 1..7, while the 'age_encoded' x-ticks of the
    # grouped line plot are placed at 0..6 - confirm which encoding the CSVs use.
    if age < 0:
        raise ValueError(f"Age must be a positive integer, got {age}")

    # Exclusive upper bounds of bins 1..6.
    upper_bounds = (10, 20, 30, 40, 50, 60)
    for bin_index, upper_bound in enumerate(upper_bounds, start=1):
        if age < upper_bound:
            return bin_index

    # Everything else shares the last bin.
    return 7

# Add (or overwrite) the binned age column in place on the shared metadata frame.
metadata_df['age_encoded'] = metadata_df['age'].apply(encode_age)
metadata_df.head()

In [None]:
# Drop dataset, entity, raw age and set, since the models never see these columns.
columns_not_seen_by_models = ['dataset', 'entity', 'age', 'set']
skinny_metadata_df = metadata_df.drop(columns=columns_not_seen_by_models)
skinny_metadata_df.head()

In [None]:
# Per metadata combination: sum of the tumor labels ('sum') and number of samples ('count').
metadata_groups = (
    skinny_metadata_df
    .groupby(['anatomy_site', 'sex', 'age_encoded'])
    .agg({'tumor': ['sum', 'count']})
)
metadata_groups.head(15)

There are two scenarios in which only one tumor label occurs for a metadata combination: if the sum is 0, all samples have tumor label 0; if sum == count, all samples have tumor label 1.

In [None]:
tumor_sum = metadata_groups['tumor']['sum']
tumor_count = metadata_groups['tumor']['count']

# A combination has a single label if no sample is a tumor (sum == 0) or all are (sum == count).
only_one_label_present = (tumor_sum == 0) | (tumor_sum == tumor_count)
print(f"{sum(only_one_label_present)}/{len(only_one_label_present)} metadata combinations have only one tumor label present")

# `&` binds tighter than `|`: the size filter has to be applied to the already
# combined mask, otherwise all-zero groups of any size are counted as well.
only_one_label_present_and_more_than_2 = only_one_label_present & (tumor_count > 2)
print(f"{sum(only_one_label_present_and_more_than_2)}/{len(only_one_label_present_and_more_than_2)} metadata combinations with 3 or more representatives, have only one tumor label present")


In [None]:
# Number of samples whose metadata combination has only one label present.
sample_counts = metadata_groups['tumor']['count']
samples_with_only_one_label_present = sample_counts[only_one_label_present].sum()

print(f"{samples_with_only_one_label_present}/{metadata_groups['tumor']['count'].sum()} samples have a combination of metadata features with only one tumor label present")

## Calculating Specificity and Sensitivity

In [None]:
from sklearn.metrics import confusion_matrix

# Prediction files of the pretrained baseline (Imaging+Clinical), one per fold.
files = [
    '../../predictions/baseline_pretrained/imaging_and_clinical_pretrained_baseline_resnet50/predictions_fold_0.csv',
    '../../predictions/baseline_pretrained/imaging_and_clinical_pretrained_baseline_resnet50/predictions_fold_1.csv',
    '../../predictions/baseline_pretrained/imaging_and_clinical_pretrained_baseline_resnet50/predictions_fold_2.csv',
    '../../predictions/baseline_pretrained/imaging_and_clinical_pretrained_baseline_resnet50/predictions_fold_3.csv',
]
dfs = [pd.read_csv(f) for f in files]

sensitivities, specificities, precisions = [], [], []
for df in dfs:
    # Threshold the predicted probability at 0.5 to get a hard prediction.
    df['pred'] = df['prob'] >= 0.5
    tn, fp, fn, tp = confusion_matrix(df['tumor'], df['pred']).ravel()
    sensitivities.append(tp / (tp + fn))
    specificities.append(tn / (tn + fp))
    precisions.append(tp / (tp + fp))

# Average over folds +- std
for label, fold_values in (
    ("Sensitivity", sensitivities),
    ("Specificity", specificities),
    ("Precision", precisions),
):
    print(f"{label}: {np.mean(fold_values):.4f} ± {np.std(fold_values):.4f}")
