In [None]:
import os
import numpy as np
import pandas as pd
from scipy.special import rel_entr

# ====================================================
# 1. Core metric functions
# ====================================================
# Path to the ground-truth targets; compare_oofs_to_dataframe falls back to
# this module-level value when no explicit path is given.
true_path = "CSIRO/true.npy"
def calculate_weighted_r2_numpy(y_true, y_pred, weights=None):
    """
    Compute the weighted R^2 score used by the Kaggle CSIRO Image2Biomass competition.

    Parameters
    ----------
    y_true : array-like, shape (n_samples, n_targets) or (n_targets,)
        Ground-truth values.
    y_pred : array-like, same shape as ``y_true``
        Predicted values.
    weights : array-like of shape (n_targets,), optional
        Per-column weights. Defaults to the competition weights
        [0.1, 0.1, 0.1, 0.5, 0.2], which assume the column order
            0: Dry_Clover_g (w=0.1)
            1: Dry_Dead_g   (w=0.1)
            2: Dry_Green_g  (w=0.1)
            3: Dry_Total_g  (w=0.5)
            4: GDM_g        (w=0.2)

    Returns
    -------
    float
        ``1 - SS_res / SS_tot``, where
            SS_res = sum( w_j * (y_ij - yhat_ij)^2 )
            SS_tot = sum( w_j * (y_ij - y_bar_w)^2 )
        and ``y_bar_w`` is the single global weighted mean of ``y_true``.
        Returns 0.0 when SS_tot is 0 (constant targets) or the input is empty.

    Raises
    ------
    ValueError
        If ``y_true`` and ``y_pred`` shapes differ, or the number of weights
        does not match the number of target columns.
    """
    if weights is None:
        weights = [0.1, 0.1, 0.1, 0.5, 0.2]
    weights = np.asarray(weights, dtype=float)

    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)

    # Without this check, e.g. (n, 1) vs (n, 5) would broadcast and give a
    # plausible-looking but meaningless score.
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"y_true shape {y_true.shape} does not match y_pred shape {y_pred.shape}"
        )
    if y_true.ndim == 0 or y_true.shape[-1] != weights.shape[0]:
        raise ValueError(
            f"Expected {weights.shape[0]} target columns, got shape {y_true.shape}"
        )

    # Expand the per-column weights to one weight per element, so the same
    # formulas work for a single sample (1-D) as well as a batch (2-D).
    w_full = np.broadcast_to(weights, y_true.shape)
    total_weight = w_full.sum()
    if total_weight == 0:
        # Empty input (or all-zero weights): R^2 is undefined.
        return 0.0

    # Global weighted mean over every (sample, target) element.
    y_bar_w = np.sum(w_full * y_true) / total_weight

    ss_res = np.sum(w_full * (y_true - y_pred) ** 2)
    ss_tot = np.sum(w_full * (y_true - y_bar_w) ** 2)

    # Guard against division by zero when all targets are identical.
    if ss_tot == 0:
        return 0.0

    return float(1.0 - ss_res / ss_tot)

def calculate_kl_numpy(p, q, epsilon=1e-8):
    """
    Mean per-row KL divergence KL(P || Q).

    Every row of ``p`` and ``q`` is first floored at ``epsilon`` (so zero or
    negative entries become tiny positive masses) and then rescaled to sum
    to 1, turning each row into a discrete distribution over the columns.

    Args:
        p: 2-D array-like, the reference distribution (one row per sample).
        q: 2-D array-like with the same shape as ``p``.
        epsilon: lower bound applied before normalization; keeps the log finite.

    Returns:
        The mean over rows of sum_j p_j * log(p_j / q_j).
    """
    def _as_row_distribution(values):
        floored = np.clip(values, epsilon, None)
        return floored / floored.sum(axis=1, keepdims=True)

    p_dist = _as_row_distribution(p)
    q_dist = _as_row_distribution(q)

    # rel_entr(x, y) == x * log(x / y) elementwise for x, y > 0.
    per_row = rel_entr(p_dist, q_dist).sum(axis=1)
    return np.mean(per_row)

# ====================================================
# 2. Main comparison logic
# ====================================================

def compare_oofs_to_dataframe(baseline_path, comparison_dict, y_true_path=None):
    """
    Compare several OOF prediction files against a baseline OOF and the ground truth.

    For every model in ``comparison_dict`` three numbers are reported:

    * ``Weighted_R2_Similarity``: weighted R^2 with the baseline treated as
      the target. Closer to 1 means the model's predictions track the baseline's.
    * ``KL_Divergence``: mean per-row KL(baseline || model) over the
      row-normalized target columns. Closer to 0 means a more similar
      composition across the targets.
    * ``True(CV)``: weighted R^2 against the ground-truth targets.

    Args:
        baseline_path: path to the baseline ``oof.npy``.
        comparison_dict: ``{model_name: oof_path}`` mapping. Missing files and
            files whose shape differs from the baseline are reported and skipped.
        y_true_path: path to the ground-truth ``.npy``. Defaults to the
            module-level ``true_path``.

    Returns:
        pandas.DataFrame indexed by ``Model`` and sorted by
        ``Weighted_R2_Similarity`` in descending order. If no model could be
        compared, an empty DataFrame (without the ``Model`` index) is returned.

    Raises:
        FileNotFoundError: if the baseline or ground-truth file is missing.
        ValueError: if the ground-truth shape differs from the baseline shape.
    """
    # Load the baseline.
    if not os.path.exists(baseline_path):
        raise FileNotFoundError(f"Baseline file not found: {baseline_path}")

    baseline_pred = np.load(baseline_path)
    print(f"✅ Baseline Loaded: {os.path.basename(baseline_path)} | Shape: {baseline_pred.shape}")

    # Fall back to the module-level constant so existing two-argument calls keep working.
    if y_true_path is None:
        y_true_path = true_path
    if not os.path.exists(y_true_path):
        raise FileNotFoundError(f"Ground-truth file not found: {y_true_path}")
    y_true = np.load(y_true_path)

    # Every accepted model must match the baseline's shape (checked below), so
    # validating the truth once here ensures True(CV) never broadcasts
    # mismatched arrays.
    if y_true.shape != baseline_pred.shape:
        raise ValueError(
            f"Ground-truth shape {y_true.shape} does not match baseline shape {baseline_pred.shape}"
        )

    results = []

    for name, path in comparison_dict.items():
        if not os.path.exists(path):
            print(f"⚠️ Warning: File not found for [{name}], skipping...")
            continue

        current_pred = np.load(path)

        # Shape check against the baseline.
        if current_pred.shape != baseline_pred.shape:
            print(f"❌ Shape mismatch [{name}]: {current_pred.shape} vs {baseline_pred.shape}")
            continue

        # 1. Weighted R2 with the baseline as "truth" and this model as "prediction".
        similarity_r2 = calculate_weighted_r2_numpy(baseline_pred, current_pred)

        # 2. KL divergence between the per-row target compositions.
        diff_kl = calculate_kl_numpy(baseline_pred, current_pred)

        # 3. Cross-validation score against the real targets.
        true_r2 = calculate_weighted_r2_numpy(y_true, current_pred)

        results.append({
            "Model": name,
            "Weighted_R2_Similarity": similarity_r2,
            "KL_Divergence": diff_kl,
            "True(CV)": true_r2
        })

    # Build the summary table.
    df = pd.DataFrame(results)
    if not df.empty:
        df = df.set_index("Model")
        # Most baseline-like models first.
        df = df.sort_values(by="Weighted_R2_Similarity", ascending=False)

    return df

In [None]:
# ====================================================
# 3. Run configuration (edit to match your actual paths)
# ====================================================

# [Example]
BASELINE_PATH = "CSIRO/output/2026-01-15_23:00:56_vit_large_patch16_dinov3.lvd1689m_output/oof.npy"

# Display name -> OOF path. The baseline itself is included as a sanity row
# (expected similarity 1.0 / KL 0.0). Timestamped labels must match the
# timestamp in their path, otherwise the results table is mislabeled.
COMPARISON_MODELS = {
    "BASELINE(8patch)": BASELINE_PATH,
    "2026-01-14_01:18:11" : "CSIRO/output/2026-01-14_01:18:11_vit_small_plus_patch16_dinov3.lvd1689m_output/oof.npy",
    "2026-01-13_21:45:33" : "CSIRO/output/2026-01-13_21:45:33_vit_base_patch16_dinov3.lvd1689m_output/oof.npy",
    "2026-01-12_22:22:54" : "CSIRO/output/2026-01-12_22:22:54_vit_base_patch16_dinov3.lvd1689m_output/oof.npy",
    "2026-01-14_13:39:06(vit_base8patch+EMA+Mamba)" : "CSIRO/output/2026-01-14_13:39:06_vit_base_patch16_dinov3.lvd1689m_output/oof.npy",
    "2026-01-14_19:15:08(vit_huge+8patch+EMA+Mamba)" : "CSIRO/output/2026-01-14_19:15:08_vit_huge_plus_patch16_dinov3.lvd1689m_output/oof.npy",
    "2026-01-15_14:30:30_convnextv2_tiny" : "CSIRO/output/2026-01-15_14:30:30_convnextv2_tiny.fcmae_ft_in22k_in1k_output/oof.npy",
    "2026-01-15_20:19:30_vit_large_patch16_dinov3" : "CSIRO/output/2026-01-15_20:19:30_vit_large_patch16_dinov3.lvd1689m_output/oof.npy",
    # This is the 22:56:47 run, distinct from the 20:19:30 run above.
    "2026-01-15_22:56:47_vit_large_patch16_dinov3(Mamba+EMA)" : "CSIRO/output/2026-01-15_22:56:47_vit_large_patch16_dinov3.lvd1689m_output/oof.npy",
}

df = compare_oofs_to_dataframe(BASELINE_PATH, COMPARISON_MODELS)
df

✅ Baseline Loaded: oof.npy | Shape: (357, 5)


Unnamed: 0_level_0,Weighted_R2_Similarity,KL_Divergence,True(CV)
Model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
BASELINE(8patch),1.0,0.0,0.860987
2026-01-14_13:39:06(vit_base8patch+EMA+Mamba),0.980384,0.02283,0.843581
2026-01-15_20:19:30_vit_large_patch16_dinov3(Mamba+EMA),0.975329,0.044618,0.853224
2026-01-14_01:18:11,0.969712,0.028855,0.833518
2026-01-14_19:15:08(vit_huge+8patch+EMA+Mamba),0.968011,0.072083,0.857842
2026-01-13_21:45:33,0.967207,0.026342,0.83757
2026-01-15_20:19:30_vit_large_patch16_dinov3,0.965239,0.052297,0.851112
2026-01-12_22:22:54,0.950515,0.09481,0.834901
2026-01-15_14:30:30_convnextv2_tiny,0.92189,0.076614,0.80492
