In [None]:
import os
import pandas as pd
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator
import seaborn as sns

In [None]:
def load_csv(file_path):
    """Read a CSV and normalise every label column to lowercase, stripped text.

    The first column (image identifiers) is returned exactly as pandas parsed
    it. Every other column is cast to ``str``, lowercased and stripped, so that
    ``True``, ``TRUE`` and `` true`` all become ``"true"``. Missing cells become
    the string ``"nan"`` after the ``str`` cast.

    Args:
        file_path: Anything accepted by ``pd.read_csv`` (path or file-like).

    Returns:
        The loaded DataFrame with its label columns normalised.
    """
    def _as_clean_text(series):
        # Cast first so genuine bools, NaN and mixed objects all become text.
        return series.astype(str).str.lower().str.strip()

    table = pd.read_csv(file_path)
    for label_column in table.columns[1:]:  # column 0 holds image names
        table[label_column] = _as_clean_text(table[label_column])
    return table

In [None]:
def calculate_metrics(y_true, y_pred):
    """Calculate precision, recall, F1, accuracy and MCC for string labels.

    Labels are compared element-wise against the literal strings ``"true"``
    and ``"false"``, which is the normalised form produced by ``load_csv``.
    Entries that are neither (e.g. ``"nan"``) land in no confusion-matrix
    cell. They are still counted in ``support`` and in the accuracy
    denominator.

    Args:
        y_true: Array-like of ground-truth label strings.
        y_pred: Array-like of predicted label strings, aligned position-by-position
            with ``y_true``.

    Returns:
        dict with keys 'precision', 'recall', 'f1', 'accuracy', 'mcc'
        (Matthews correlation coefficient) and 'support' (``len(y_true)``).
        Any ratio whose denominator is zero is reported as 0.
    """
    # np.asarray drops any pandas index so the comparison is purely positional,
    # and lets plain Python lists work as well.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    actual_pos = y_true == "true"
    actual_neg = y_true == "false"
    predicted_pos = y_pred == "true"
    predicted_neg = y_pred == "false"

    # Confusion-matrix counts as Python ints (arbitrary precision), so the
    # MCC products below cannot overflow.
    true_pos = int(np.sum(actual_pos & predicted_pos))
    false_pos = int(np.sum(actual_neg & predicted_pos))
    false_neg = int(np.sum(actual_pos & predicted_neg))
    true_neg = int(np.sum(actual_neg & predicted_neg))
    support = len(y_true)

    precision = true_pos / (true_pos + false_pos) if (true_pos + false_pos) > 0 else 0
    recall = true_pos / (true_pos + false_neg) if (true_pos + false_neg) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    # The denominator counts every row, including rows whose labels are neither
    # "true" nor "false"; such rows therefore count as incorrect.
    accuracy = (true_pos + true_neg) / support if support > 0 else 0

    # Matthews correlation coefficient:
    #   (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN)),
    # reported as 0 when any marginal is empty (denominator zero).
    mcc_denominator = (
        float(true_pos + false_pos)
        * float(true_pos + false_neg)
        * float(true_neg + false_pos)
        * float(true_neg + false_neg)
    )
    if mcc_denominator > 0:
        mcc = (true_pos * true_neg - false_pos * false_neg) / mcc_denominator ** 0.5
    else:
        mcc = 0

    return {
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'accuracy': accuracy,
        'mcc': mcc,
        'support': support,
    }

In [None]:
def _index_by_image(df, source_name):
    """Return ``df`` indexed by its first (image-name) column, one row per image.

    Duplicate image names would make row alignment between ground truth and
    predictions ambiguous, so only the first occurrence is kept and a warning
    is printed.
    """
    image_col = df.columns[0]
    duplicated = df[image_col].duplicated(keep='first')
    if duplicated.any():
        print(
            f"Warning: {source_name} has {int(duplicated.sum())} duplicate image rows; "
            f"keeping the first occurrence of each"
        )
    return df.loc[~duplicated].set_index(image_col)


def compare_models(ground_truth_path, model_paths):
    """
    Compare model predictions against ground truth with full metrics.

    Both CSVs are read with ``load_csv``. The first column is the image name
    and every remaining column is a task label. Metrics are computed with
    ``calculate_metrics`` on the images present in both files, matched by
    image name.

    Args:
        ground_truth_path: Path to ground truth CSV file
        model_paths: Dictionary of {model_name: model_csv_path}

    Returns:
        Dictionary with keys:
          'tasks'          - ground-truth task columns, in file order
          'model_metrics'  - {model_name: {task: metrics_dict}}
          'missing_images' - {model_name: sorted ground-truth images absent from the model CSV}
          'image_counts'   - {'ground_truth': n, model_name: n} (unique image names)
          'tasks_updated'  - only present if some model's tasks differ from the
                             ground truth: {model_name: evaluated tasks}
        Returns None if the ground truth cannot be loaded. A model whose CSV
        fails to load is skipped with a printed message.
    """
    try:
        gt_df = load_csv(ground_truth_path)
    except Exception as e:
        print(f"Error loading ground truth: {e}")
        return None

    results = {
        'tasks': list(gt_df.columns[1:]),  # Skip image name column
        'model_metrics': defaultdict(dict),
        'missing_images': defaultdict(list),
        'image_counts': {}
    }

    gt_by_image = _index_by_image(gt_df, "Ground truth")
    all_gt_images = set(gt_by_image.index)
    results['image_counts']['ground_truth'] = len(all_gt_images)

    for model_name, model_path in model_paths.items():
        try:
            model_df = load_csv(model_path)
        except Exception as e:
            print(f"Error loading model {model_name} data: {e}")
            continue

        model_by_image = _index_by_image(model_df, f"Model {model_name}")
        model_images = set(model_by_image.index)
        results['missing_images'][model_name] = sorted(all_gt_images - model_images)
        results['image_counts'][model_name] = len(model_images)

        common_images = gt_by_image.index.intersection(model_by_image.index)
        if common_images.empty:
            print(f"Warning: No common images between ground truth and {model_name}")
            continue

        # Selecting the same label list from both (uniquely indexed) frames puts
        # their rows in identical order, so the .values arrays below line up
        # image-by-image.
        gt_common = gt_by_image.loc[common_images]
        model_common = model_by_image.loc[common_images]

        model_tasks = set(model_by_image.columns)
        if set(results['tasks']) != model_tasks:
            print(f"Warning: Model {model_name} has different tasks than ground truth")
            # Evaluate the shared tasks, keeping ground-truth column order so
            # the result is deterministic.
            tasks_to_evaluate = [task for task in results['tasks'] if task in model_tasks]
            results.setdefault('tasks_updated', {})[model_name] = tasks_to_evaluate
        else:
            tasks_to_evaluate = results['tasks']

        for task in tasks_to_evaluate:
            results['model_metrics'][model_name][task] = calculate_metrics(
                gt_common[task].values, model_common[task].values
            )

    return results

In [None]:
def generate_report(results, output_file=None, max_missing_shown=5):
    """Build a plain-text comparison report and optionally write it to a file.

    The report has three sections:
      - image counts per source;
      - a per-task metrics table, with models sorted by F1 (best first);
      - up to ``max_missing_shown`` example missing images per model.

    Args:
        results: Dictionary returned by ``compare_models`` (reads 'tasks',
            'image_counts', 'missing_images' and 'model_metrics').
        output_file: Optional path; if given, the report text is also written there.
        max_missing_shown: How many missing image names to list per model
            before summarising the rest as "... and N more".

    Returns:
        The report as a single string, or "No results to report." when
        ``results`` is empty/None.
    """
    if not results:
        return "No results to report."

    report = [
        "Model Comparison Report",
        "=" * 40,
        f"Tasks evaluated: {', '.join(map(str, results['tasks']))}",
    ]

    # Image counts and number of missing images per model
    report.append("\nImage Counts:")
    report.append("-" * 20)
    report.append(f"Ground truth images: {results['image_counts']['ground_truth']}")
    for model, count in results['image_counts'].items():
        if model != 'ground_truth':
            report.append(f"{model}: {count} (missing {len(results['missing_images'][model])})")

    # Per-task metric tables; header and rows share the same column widths.
    report.append("\nDetailed Metrics by Task:")
    report.append("-" * 20)
    header = (
        f"{'Model':<15} {'Accuracy':>9} {'Precision':>9} "
        f"{'Recall':>9} {'F1':>9} {'Support':>8}"
    )

    for task in results['tasks']:
        title = f"Task: {task}"
        report.append(f"\n{title}")
        report.append("-" * len(title))

        task_rows = [
            (model, per_task[task])
            for model, per_task in results['model_metrics'].items()
            if task in per_task
        ]
        task_rows.sort(key=lambda row: row[1]['f1'], reverse=True)

        report.append(header)
        report.append("-" * len(header))
        for model, m in task_rows:
            report.append(
                f"{model:<15} {m['accuracy']:>9.4f} {m['precision']:>9.4f} "
                f"{m['recall']:>9.4f} {m['f1']:>9.4f} {m['support']:>8}"
            )

    # Missing images, truncated per model to keep the report readable
    models_with_missing = {
        model: missing
        for model, missing in results['missing_images'].items()
        if len(missing) > 0
    }
    if models_with_missing:
        report.append("\nMissing Images by Model:")
        report.append("-" * 20)
        for model, missing in models_with_missing.items():
            report.append(f"\n{model} missing {len(missing)} images:")
            shown = missing[:max_missing_shown]
            # str() because image ids may be parsed as numbers by read_csv
            report.append(", ".join(str(image) for image in shown))
            if len(missing) > len(shown):
                report.append(f"... and {len(missing) - len(shown)} more")

    full_report = "\n".join(report)

    if output_file:
        with open(output_file, 'w', encoding='utf-8') as f:
            f.write(full_report)

    return full_report

In [None]:
def plot_metrics(results, output_dir="figures", metrics_to_plot=('f1',)):
    """
    Generate academic-report-quality grouped bar charts, one figure per metric.

    For each metric, tasks are placed on the x-axis, ordered by the best
    model's score on that task (descending), with one bar per model. Each
    figure is saved as ``<output_dir>/<metric>_performance.png`` at 300 dpi.

    Args:
        results: Dictionary returned by ``compare_models``; only
            ``results['model_metrics']`` is read.
        output_dir: Directory to save the figures (created if missing).
        metrics_to_plot: Metric keys to plot. The y-axis is fixed to
            [0, 1.01], so only metrics in [0, 1] render sensibly.
    """
    if not results:
        print("No results to plot.")
        return

    os.makedirs(output_dir, exist_ok=True)

    for metric_name in metrics_to_plot:
        # Collect all rows first and build the frame once (concat in a loop is quadratic).
        records = [
            {'Model': model_name, 'Task': task_name, 'Value': task_metrics[metric_name]}
            for model_name, tasks_metrics in results['model_metrics'].items()
            for task_name, task_metrics in tasks_metrics.items()
            if metric_name in task_metrics
        ]
        if not records:
            print(f"No {metric_name} values to plot.")
            continue
        df_metric = pd.DataFrame(records)

        # Order task groups by their best score across models, highest first.
        task_order = (
            df_metric.groupby('Task')['Value'].max()
            .sort_values(ascending=False)
            .index
        )
        df_metric['Task'] = pd.Categorical(df_metric['Task'], categories=task_order, ordered=True)
        n_tasks = len(task_order)
        n_models = df_metric['Model'].nunique()

        sns.set_theme(style="whitegrid", palette="viridis")
        fig, ax = plt.subplots(figsize=(12, 10))
        sns.barplot(data=df_metric, x='Task', y='Value', hue='Model',
                    edgecolor=".2", linewidth=0.5, ax=ax)

        ax.set_xlabel('Abnormality', fontsize=14)
        ax.set_ylabel(metric_name.replace("_", " ").title(), fontsize=14)
        ax.set_ylim(0, 1.01)  # metrics are between 0 and 1
        # Categorical bars are centred at x = 0 .. n_tasks - 1. Derive the right
        # limit from the task count; the former hard-coded 14.5 only fit 15 tasks.
        ax.set_xlim(-0.8, n_tasks - 0.5)
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right', fontsize=14)
        ax.tick_params(axis='y', labelsize=14)

        ax.yaxis.set_major_locator(MultipleLocator(0.2))
        ax.yaxis.set_minor_locator(MultipleLocator(0.1))

        # Single-row legend above the axes.
        ax.legend(
            loc='upper center',
            bbox_to_anchor=(0.5, 1.15),
            ncol=n_models,
            frameon=False,
            fontsize=14
        )
        fig.tight_layout(rect=[0, 0, 0.88, 1])

        ax.grid(True, axis='y', alpha=0.7)
        ax.grid(which='major', linestyle='-', linewidth=0.5, color='black')
        ax.grid(which='minor', linestyle='-', linewidth=0.2, color='black')
        ax.grid(False, axis='x')

        figure_path = os.path.join(output_dir, f'{metric_name}_performance.png')
        fig.savefig(figure_path, dpi=300, bbox_inches='tight')
        plt.close(fig)
        print(f"Generated figure: {figure_path}")

In [None]:
ground_truth_file = "out/ground_truth/ct_findings_classification.csv"
model_files = {
    "DINO-B": "out/cls_frozen_onestep_gen/ct_findings_classification.csv",
    "DINO-L": "out/cls_frozen_onestep_large_gen/ct_findings_classification.csv",
    "DINO-B(CT-RATE)": "out/cls_frozen_onestep_ct_rate_gen/ct_findings_classification.csv",
    "CT-CLIP": "out/ct_clip_onestep_gen/ct_findings_classification.csv",
    "CT-FM": "out/ct_fm_onestep_gen/ct_findings_classification.csv",
}

# Verify inputs exist. Raise instead of calling exit(): inside a Jupyter
# kernel exit() ends the kernel (losing all state) rather than just stopping
# this cell, whereas an exception halts "Run All" here with a clear message.
if not os.path.exists(ground_truth_file):
    raise FileNotFoundError(f"Ground truth file {ground_truth_file} not found")

missing_models = [name for name, path in model_files.items() if not os.path.exists(path)]
if missing_models:
    raise FileNotFoundError(f"Missing model files for {', '.join(missing_models)}")

# Compare models; compare_models returns None when the ground truth fails to load.
comparison_results = compare_models(ground_truth_file, model_files)
if comparison_results is None:
    raise RuntimeError("Model comparison failed; see the error printed above")

# Generate the text report (also written to disk)
report = generate_report(comparison_results, output_file="model_comparison_report.txt")

# Generate and save academic quality figures
plot_metrics(comparison_results, output_dir="out/figures")

In [None]:
metrics = comparison_results['model_metrics']
tasks = comparison_results['tasks']


def _f1_by_task(per_task_metrics):
    """Keep only the F1 score from each task's metric dict."""
    return {task_name: task_metrics["f1"] for task_name, task_metrics in per_task_metrics.items()}


# {model_name: {task_name: f1}}: same nesting as `metrics`, other scores dropped
f1_only = {model_name: _f1_by_task(per_task) for model_name, per_task in metrics.items()}

In [None]:
# NOTE(review): 0.717 is an unexplained constant. The "±" value printed below
# is the population std (np.std default, ddof=0) of a model's per-task F1
# scores multiplied by 0.717. It is NOT the textbook standard error
# std(ddof=1) / sqrt(n_tasks), and it does not adapt to the number of tasks.
# Confirm what this factor is meant to represent before reporting it.
SE_SCALE = 0.717

for model_label, task_f1 in f1_only.items():
    scores = list(task_f1.values())
    mean_f1 = np.mean(scores)
    spread = np.std(scores) * SE_SCALE
    # model name on one line, then "mean±spread" on the next
    print(f"{model_label}\n{mean_f1:.03f}±{spread:.03f}")

In [None]:
# One LaTeX table row per task: "task & f1_model1 & f1_model2 ... \\"
label_perf = {label: [] for label in tasks}

# Table columns, in this exact order (alternative set: ["CT-CLIP", "CT-FM"]).
model_names = ["DINO-B", "DINO-L", "DINO-B(CT-RATE)"]

# Iterate over model_names rather than f1_only so column order matches the
# list above, and a model without results leaves a visible "--" column instead
# of silently shifting the remaining columns left.
for model_name in model_names:
    if model_name in f1_only:
        print(model_name)
    else:
        print(f"Warning: no results for {model_name}; its column is filled with '--'")
    tasks_performance = f1_only.get(model_name, {})

    for label in tasks:
        # A model evaluated on a subset of tasks gets "--" instead of a KeyError.
        score = tasks_performance.get(label)
        label_perf[label].append("--" if score is None else f"{score:.03f}")

for label, row in label_perf.items():
    print(f"{label} & " + " & ".join(row) + "\\\\")

