In [26]:
import pandas as pd

In [27]:
# Task whose results are loaded and tabulated below. It selects both the results
# pickle (results_{TASK}.pkl) and the header labels (field_name_maps[TASK]), so it
# must be one of the keys of `field_name_maps`: "rolling", "forming", "motor",
# "motor_geometric_pointnet", "motor_2D" or "heatsink".
TASK = "heatsink"

In [None]:
# LaTeX column-header labels per task. Inner keys are the <field> part of the
# ``test_loss_<domain>_<field>_<stat>`` result columns; "all" labels the columns
# that carry no field part. Tasks sharing the same physics reuse one label set
# instead of repeating it verbatim.

# Header used by every task for the field-less (overall) loss columns.
_OVERALL_LABEL = {"all": r"\textbf{All Fields Normalized Avg (-)}"}

# Field-independent summary metrics. Every task gets all three so that a result
# file containing them never falls back to a raw, unformatted column key.
_SUMMARY_METRIC_LABELS = {
    "custom": r"\textbf{Rel Custom Error (-)}",
    "mae": r"\textbf{MAE (-)}",
    "r2": r"\textbf{R2 (-)}",
}

# Shared by the "rolling" and "forming" tasks.
_FORMING_FIELD_LABELS = {
    "deformation": r"\textbf{Deformation (mm)}",
    "nodes_LE": r"\textbf{Logarithmic Strain ($\mathbf{\times 10^{-2}}$)}",
    "nodes_PEEQ": r"\textbf{Equivalent Plastic Strain ($\mathbf{\times 10^{-2}}$)}",
    "nodes_mises_stress": r"\textbf{Mises Stress (MPa)}",
    "nodes_stresses": r"\textbf{Stress (MPa)}",
}

# Shared by the "motor", "motor_geometric_pointnet" and "motor_2D" tasks.
_MOTOR_FIELD_LABELS = {
    "deformation": r"\textbf{Deformation (m)}",
    "logarithmic_strain": r"\textbf{Logarithmic Strain ($\mathbf{\times 10^{-2}}$)}",
    "principal_strain": r"\textbf{Principal Strain ($\mathbf{\times 10^{-2}}$)}",
    "stress": r"\textbf{Stress (MPa)}",
    "stress_cauchy": r"\textbf{Cauchy Stress (MPa)}",
    "stress_mises": r"\textbf{Mises Stress (MPa)}",
    "stress_principal": r"\textbf{Principal Stress (MPa)}",
    "total_strain": r"\textbf{Total Strain ($\mathbf{\times 10^{-2}}$)}",
}

_HEATSINK_FIELD_LABELS = {
    "U": r"\textbf{Velocity (m/s)}",
    "p": r"\textbf{Pressure (kPa)}",
    "T": r"\textbf{Temperature (K)}",
}


def _task_labels(field_labels: dict) -> dict:
    """Return a fresh label map: overall label, then field labels, then summary metrics."""
    return {**_OVERALL_LABEL, **field_labels, **_SUMMARY_METRIC_LABELS}


field_name_maps = {
    "rolling": _task_labels(_FORMING_FIELD_LABELS),
    "forming": _task_labels(_FORMING_FIELD_LABELS),
    "motor": _task_labels(_MOTOR_FIELD_LABELS),
    "motor_geometric_pointnet": _task_labels(_MOTOR_FIELD_LABELS),
    "motor_2D": _task_labels(_MOTOR_FIELD_LABELS),
    "heatsink": _task_labels(_HEATSINK_FIELD_LABELS),
}

In [None]:
# Load the per-seed model-selection results for the selected task.
# NOTE(review): unpickling can execute arbitrary code; only load result files
# that come from a trusted source.
df = pd.read_pickle(f"model_selection_results_rebuttal/results_{TASK}.pkl")

# DataFrame.info() prints its summary itself and returns None, so wrapping it
# in print() would only append a stray "None" line.
df.info()

if TASK == "heatsink":
    # The heatsink label map has no "deformation" entry, so these columns are
    # removed before aggregation. errors="ignore" keeps the cell runnable when a
    # results file does not contain them.
    df = df.drop(
        columns=["test_loss_source_deformation", "test_loss_target_deformation"],
        errors="ignore",
    )

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 40 entries, 0 to 39
Data columns (total 20 columns):
 #   Column                          Non-Null Count  Dtype  
---  ------                          --------------  -----  
 0   model_name                      40 non-null     object 
 1   da_algorithm_name               40 non-null     object 
 2   model_selection_algorithm_name  40 non-null     object 
 3   seed                            40 non-null     int64  
 4   test_loss_source                40 non-null     float64
 5   test_loss_target                40 non-null     float64
 6   test_loss_source_mae            40 non-null     float64
 7   test_loss_source_r2             40 non-null     float64
 8   test_loss_target_mae            40 non-null     float64
 9   test_loss_target_r2             40 non-null     float64
 10  test_loss_source_deformation    40 non-null     float64
 11  test_loss_target_deformation    40 non-null     float64
 12  test_loss_source_T              40 non

In [32]:
# One output row per (model, DA algorithm, model-selection algorithm); the
# remaining rows of each group are collapsed into a mean and a std per loss.
group_cols = ["model_name", "da_algorithm_name", "model_selection_algorithm_name"]

# Every metric column produced by the evaluation shares this prefix.
loss_cols = [c for c in df.columns if c.startswith("test_loss_")]

# Aggregation spec: each loss column -> [mean, std].
# pandas' std is the sample standard deviation, so a group with a single row
# yields NaN for its *_std columns.
agg_dict = {col: ["mean", "std"] for col in loss_cols}

agg_df = df.groupby(group_cols)[loss_cols].agg(agg_dict).reset_index()

# Flatten the (column, stat) MultiIndex into "<column>_<stat>". After
# reset_index() the group columns carry an empty stat level and keep their name.
agg_df.columns = [
    "_".join(part for part in key if part)
    for key in agg_df.columns.to_flat_index()
]

# DataFrame.info() prints directly and returns None; print() around it would
# only add a trailing "None".
agg_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 10 entries, 0 to 9
Data columns (total 31 columns):
 #   Column                          Non-Null Count  Dtype  
---  ------                          --------------  -----  
 0   model_name                      10 non-null     object 
 1   da_algorithm_name               10 non-null     object 
 2   model_selection_algorithm_name  10 non-null     object 
 3   test_loss_source_mean           10 non-null     float64
 4   test_loss_source_std            10 non-null     float64
 5   test_loss_target_mean           10 non-null     float64
 6   test_loss_target_std            10 non-null     float64
 7   test_loss_source_mae_mean       10 non-null     float64
 8   test_loss_source_mae_std        10 non-null     float64
 9   test_loss_source_r2_mean        10 non-null     float64
 10  test_loss_source_r2_std         10 non-null     float64
 11  test_loss_target_mae_mean       10 non-null     float64
 12  test_loss_target_mae_std        10 non-

In [33]:
# Convert the aggregated means/stds to the units used in the table headers.
# The conversion mutates agg_df in place, so it is guarded by a flag stored on
# the frame itself: re-running this cell must not rescale the values a second
# time. Re-running the aggregation cell creates a fresh frame without the flag.
if not agg_df.attrs.get("units_rescaled", False):
    if TASK in ["motor", "motor_2D", "motor_geometric_pointnet"]:
        for col in list(agg_df.columns):
            if "stress" in col:
                # Headers label stresses in MPa; the division assumes Pa input
                # -- TODO confirm against the evaluation code.
                agg_df[col] = agg_df[col] / 1e6
            if "strain" in col:
                # Headers label strains in units of 1e-2.
                agg_df[col] = agg_df[col] * 1e2

    if TASK in ["rolling", "forming"]:
        for col in list(agg_df.columns):
            if ("nodes_PEEQ" in col) or ("nodes_LE" in col):
                # Headers label both strain fields in units of 1e-2.
                agg_df[col] = agg_df[col] * 1e2
        # NOTE(review): the rolling/forming stress headers say MPa but no stress
        # conversion is applied here -- confirm the stored values are already MPa.

    # NOTE(review): the heatsink pressure header says kPa but no conversion is
    # applied for that task -- confirm the stored values are already kPa.

    agg_df.attrs["units_rescaled"] = True

In [34]:
# Display the aggregated table: one row per (model, DA algorithm, model-selection
# algorithm) group, with a *_mean and a *_std column for every loss.
agg_df

Unnamed: 0,model_name,da_algorithm_name,model_selection_algorithm_name,test_loss_source_mean,test_loss_source_std,test_loss_target_mean,test_loss_target_std,test_loss_source_mae_mean,test_loss_source_mae_std,test_loss_source_r2_mean,...,test_loss_target_U_mean,test_loss_target_U_std,test_loss_source_p_mean,test_loss_source_p_std,test_loss_target_p_mean,test_loss_target_p_std,test_loss_source_custom_mean,test_loss_source_custom_std,test_loss_target_custom_mean,test_loss_target_custom_std
0,PointNet,-,-,0.289102,0.004506,0.483857,0.054346,0.143781,0.002413,0.850773,...,0.04293,0.003839,193.4394,8.965762,909.832581,244.474657,0.009295,0.000255,0.02506,0.005286
1,PointNet,deep_coral,DEV,0.229201,0.016376,0.343492,0.014526,0.152822,0.012538,0.835649,...,0.040473,0.002814,178.709297,3.757132,635.996712,199.52261,0.010537,0.001931,0.03441,0.010235
2,PointNet,deep_coral,IWV,0.236677,0.022541,0.374063,0.033831,0.159047,0.01684,0.825653,...,0.042388,0.004847,193.481995,10.639706,953.772827,382.204503,0.011851,0.003303,0.038422,0.014105
3,PointNet,deep_coral,SB,0.218696,0.004284,0.375178,0.029989,0.145324,0.003884,0.849416,...,0.044353,0.003184,189.003838,14.340158,781.205627,69.485238,0.009397,0.000622,0.034952,0.011381
4,PointNet,deep_coral,TB,0.249648,0.019236,0.336213,0.005373,0.168491,0.015101,0.809169,...,0.038369,0.000723,187.414421,14.548829,771.331429,363.399794,0.01355,0.003511,0.035423,0.01082
5,Transolver,-,-,0.237241,0.002229,0.469901,0.048539,0.110502,0.001041,0.894553,...,0.040473,0.004128,256.344341,15.616404,1623.495636,210.078291,0.006504,0.000257,0.013958,0.00661
6,Transolver,deep_coral,DEV,0.179159,0.002008,0.332794,0.016981,0.111716,0.002129,0.893424,...,0.038187,0.001431,237.457539,4.949937,1600.520752,197.426853,0.006569,0.000346,0.01038,0.003399
7,Transolver,deep_coral,IWV,0.17722,0.000667,0.336497,0.021605,0.110218,0.000459,0.895425,...,0.03839,0.001918,243.133797,6.164716,1697.526215,199.257877,0.006504,0.000154,0.008932,0.001377
8,Transolver,deep_coral,SB,0.177245,0.001886,0.350362,0.005484,0.110388,0.001856,0.896004,...,0.039311,0.000713,263.460526,26.325322,1812.385498,179.982828,0.006572,0.000485,0.011064,0.002679
9,Transolver,deep_coral,TB,0.177635,0.000633,0.323479,0.014063,0.109946,0.000461,0.894366,...,0.037716,0.00189,238.612938,12.179123,1523.076324,64.679712,0.006251,0.000117,0.008357,0.001917


In [35]:
# List the flattened "<loss column>_<stat>" names; these are the names that
# df_to_latex_table_with_highlights matches with its test_loss_* pattern.
agg_df.columns

Index(['model_name', 'da_algorithm_name', 'model_selection_algorithm_name',
       'test_loss_source_mean', 'test_loss_source_std',
       'test_loss_target_mean', 'test_loss_target_std',
       'test_loss_source_mae_mean', 'test_loss_source_mae_std',
       'test_loss_source_r2_mean', 'test_loss_source_r2_std',
       'test_loss_target_mae_mean', 'test_loss_target_mae_std',
       'test_loss_target_r2_mean', 'test_loss_target_r2_std',
       'test_loss_source_T_mean', 'test_loss_source_T_std',
       'test_loss_target_T_mean', 'test_loss_target_T_std',
       'test_loss_source_U_mean', 'test_loss_source_U_std',
       'test_loss_target_U_mean', 'test_loss_target_U_std',
       'test_loss_source_p_mean', 'test_loss_source_p_std',
       'test_loss_target_p_mean', 'test_loss_target_p_std',
       'test_loss_source_custom_mean', 'test_loss_source_custom_std',
       'test_loss_target_custom_mean', 'test_loss_target_custom_std'],
      dtype='object')

In [37]:
import pandas as pd
import re

def _collect_loss_columns(columns) -> dict:
    """Group ``test_loss_<domain>[_<field>]_<stat>`` column names by field.

    Returns ``{field: {"SOURCE" | "TARGET": {"mean" | "std": column_name}}}``.
    Columns without a field part are filed under the field ``"all"``;
    non-matching (or non-string) column labels are ignored.
    """
    pattern = re.compile(r"^test_loss_(source|target)(?:_(.+?))?_(mean|std)$")
    grouped = {}
    for col in columns:
        match = pattern.match(col) if isinstance(col, str) else None
        if match:
            domain, field, stat = match.groups()
            grouped.setdefault(field or "all", {}).setdefault(domain.upper(), {})[stat] = col
    return grouped


def _format_mean_std(mean_value, std_value, float_fmt: str):
    """Return the ``mean(+-std)`` math text, the bare mean if std is NaN, or None if the mean is NaN."""
    if pd.isna(mean_value):
        return None
    if pd.isna(std_value):
        return float_fmt.format(mean_value)
    return f"{float_fmt.format(mean_value)}(\\pm{float_fmt.format(std_value)})"


def _emphasize(text, level: int, math: bool = False):
    """Wrap ``text`` for LaTeX: level 2 = bold + underline, level 1 = underline, 0 = unchanged."""
    if level == 2:
        bold = "\\mathbf" if math else "\\textbf"
        return f"\\underline{{{bold}{{{text}}}}}"
    if level == 1:
        return f"\\underline{{{text}}}"
    return text


def df_to_latex_table_with_highlights(
    df: pd.DataFrame,
    caption: str,
    label: str,
    float_fmt: str = "{:.3f}",
    threshold_factor: float = 100.0,
    da_name_map: dict = None,
    field_name_map: dict = None,
    exclude_selection: str = "TB",
    # Shading controls
    shade_best: bool = True,
    shade_baseline: bool = True,
    best_color_hex: str = "DFF0D8",      # light green
    baseline_color_hex: str = "FFF4CC",  # light beige
    baseline_da_markers = ("-", "", None),
    baseline_sel_markers = ("-", "", None),
    include_color_definitions: bool = True,
) -> str:
    """
    Render an aggregated results frame as a LaTeX table with highlights.

    The frame needs the columns ``model_name``, ``da_algorithm_name`` and
    ``model_selection_algorithm_name`` plus metric columns named
    ``test_loss_<source|target>[_<field>]_<mean|std>``. Columns without a
    ``<field>`` part form the field ``"all"``, which is always printed first;
    the other fields follow in sorted order, each as a SRC/TGT column pair.

    Layout and highlighting:
      * one row per input row, grouped by model via ``\\multirow`` and
        separated by ``\\cmidrule`` whenever the DA algorithm changes,
      * metric cells show ``mean(+-std)``; a cell whose mean exceeds
        ``threshold_factor`` times the column median is replaced by a star,
      * the globally lowest TARGET "all" mean is bold + underlined (metric
        cell, DA and selection cells, and the model name of that group),
      * the lowest TARGET "all" mean within each model is underlined,
      * the best row per model and the baseline row are shaded with
        ``\\cellcolor``; the first (Model) column is never shaded.

    Rows whose selection algorithm equals ``exclude_selection`` are printed
    but never count as best.

    Args:
        df: aggregated results, one row per configuration.
        caption: table caption, inserted verbatim.
        label: table label, inserted verbatim.
        float_fmt: ``str.format`` template applied to every mean and std.
        threshold_factor: multiple of the column median above which a mean is
            shown as a star. Only applied when that median is positive; for a
            zero or negative median (possible for signed metrics such as R2)
            the comparison carries no meaning and would flag ordinary values.
        da_name_map: display names for DA algorithms; unmapped names are
            printed as-is.
        field_name_map: header text per field; unmapped fields use the raw
            field name.
        exclude_selection: selection-algorithm name excluded from the "best"
            computations.
        shade_best: shade the best row of each model with the ``bestrow`` colour.
        shade_baseline: shade baseline rows with the ``baselinerow`` colour
            (a best row keeps the ``bestrow`` colour).
        best_color_hex: HTML hex value of the ``bestrow`` colour.
        baseline_color_hex: HTML hex value of the ``baselinerow`` colour.
        baseline_da_markers: DA values that mark a baseline row.
        baseline_sel_markers: selection values that mark a baseline row; a row
            is a baseline only if both its DA and selection values are markers.
        include_color_definitions: emit the two ``\\definecolor`` lines.

    Returns:
        The complete ``table`` environment as a single string.

    Raises:
        ValueError: if the frame has no TARGET mean column for the "all" field.

    Notes:
        Names, caption and label are inserted verbatim, so LaTeX special
        characters in them must already be escaped by the caller. The emitted
        commands rely on booktabs, multirow, makecell, graphicx and
        ``xcolor`` loaded with the ``table`` option.
        A cell whose mean is missing (NaN, or no such column for that domain)
        is rendered as ``--``; a missing std prints the bare mean.
    """
    if da_name_map is None:
        da_name_map = {"deep_coral": "Deep Coral", "cmd": "CMD", "DANN": "DANN"}
    if field_name_map is None:
        field_name_map = {
            "all": "All Fields Normalized Avg (-)",
            "deformation": "Deformation (mm)",
            "nodes_LE": "Logarithmic Strain (-)",
            "nodes_PEEQ": "Equivalent Plastic Strain (-)",
            "nodes_mises_stress": "Mises Stress (Pa)",
            "nodes_stresses": "Stress (Pa)",
        }

    # Parse field columns
    field_map = _collect_loss_columns(df.columns)
    all_target = field_map.get("all", {}).get("TARGET", {})
    if "mean" not in all_target:
        raise ValueError("Missing 'all' TARGET mean/std columns.")
    all_tgt_mean = all_target["mean"]

    # Column medians used by the outlier (star) rule; only parsed metric
    # columns are ever looked up, so unrelated "*_mean" columns are left out.
    mean_cols = [
        stats["mean"]
        for domains in field_map.values()
        for stats in domains.values()
        if "mean" in stats
    ]
    med = df[mean_cols].median()

    # Bests (excluding certain selection rows)
    eligible = df[df["model_selection_algorithm_name"] != exclude_selection]
    global_min = eligible[all_tgt_mean].min()
    per_model_min = eligible.groupby("model_name")[all_tgt_mean].min().to_dict()

    fields = sorted(field_map.keys(), key=lambda x: (x != "all", x))
    nf = len(fields)
    total_cols = 3 + 2 * nf

    lines = [
        r"\begin{table}[h]",
        r"  \centering",
        f"  \\caption{{{caption}}}",
        f"  \\label{{{label}}}",
        r"  \resizebox{\textwidth}{!}{%",
    ]
    if include_color_definitions:
        lines += [
            r"  % Requires \usepackage[table]{xcolor}",
            fr"  \definecolor{{bestrow}}{{HTML}}{{{best_color_hex}}}",
            fr"  \definecolor{{baselinerow}}{{HTML}}{{{baseline_color_hex}}}",
        ]
    lines += [
        "  \\begin{tabular}{" + "lll" + "c" * (2 * nf) + "}",
        "    \\toprule",
    ]

    # Header rows
    hdr1 = (
        r"    \multirow{2}{*}{\textbf{Model}}"
        r" & \multirow{2}{*}{\makecell{\textbf{DA}\\ \textbf{Algorithm}}}"
        r" & \multirow{2}{*}{\makecell{\textbf{Model}\\ \textbf{Selection}}}"
    )
    for f in fields:
        hdr1 += f" & \\multicolumn{{2}}{{c}}{{{field_name_map.get(f, f)}}}"
    hdr1 += r" \\"
    lines.append(hdr1)

    # One partial rule under each field header, spanning its SRC/TGT pair.
    cm = "    "
    for i in range(nf):
        cm += f"\\cmidrule(lr){{{4 + 2*i}-{5 + 2*i}}} "
    lines.append(cm.strip())

    hdr2 = "      &   &  & " + " & ".join([r"\textbf{SRC} & \textbf{TGT}"] * nf) + r" \\"
    lines.extend([hdr2, "    \\midrule"])

    missing_cell = "--"

    # Body
    for mi, model in enumerate(df["model_name"].unique()):
        if mi > 0:
            # Full rule between consecutive model groups.
            lines.append("    \\midrule")

        sub = df[df["model_name"] == model].sort_values(
            ["da_algorithm_name", "model_selection_algorithm_name"]
        )
        n = len(sub)

        # Does this model contain the global best row?
        model_has_global = bool(
            (
                (sub["model_selection_algorithm_name"] != exclude_selection)
                & (sub[all_tgt_mean] == global_min)
            ).any()
        )

        prev_da = None
        for row_i, (_, row) in enumerate(sub.iterrows()):
            da_raw = row["da_algorithm_name"]
            if prev_da is not None and da_raw != prev_da:
                lines.append(f"    \\cmidrule(lr){{2-{total_cols}}}")

            da_disp = da_name_map.get(da_raw, da_raw)
            sel = row["model_selection_algorithm_name"]

            # Flags
            counts_for_best = sel != exclude_selection
            is_global = counts_for_best and (row[all_tgt_mean] == global_min)
            is_model_best = counts_for_best and (row[all_tgt_mean] == per_model_min.get(model))
            is_baseline = (da_raw in baseline_da_markers) and (sel in baseline_sel_markers)

            # 2 = global best, 1 = best within the model, 0 = no emphasis.
            level = 2 if is_global else (1 if is_model_best else 0)

            # Shading prefix for every cell except the Model column; the
            # best-row colour takes precedence over the baseline colour.
            if shade_best and is_model_best:
                shade_prefix = "\\cellcolor{bestrow}"
            elif shade_baseline and is_baseline:
                shade_prefix = "\\cellcolor{baselinerow}"
            else:
                shade_prefix = ""

            # First column: Model (no shading), printed once per group.
            if row_i == 0:
                model_cell = _emphasize(model, 2 if model_has_global else 0)
                line = f"    \\multirow{{{n}}}{{*}}{{{model_cell}}} & "
            else:
                line = "    & "

            # DA & selection cells
            line += (
                f"{shade_prefix}{_emphasize(da_disp, level)}"
                f" & {shade_prefix}{_emphasize(sel, level)}"
            )

            # Metric cells
            for f in fields:
                for dom in ("SOURCE", "TARGET"):
                    stats = field_map[f].get(dom, {})
                    mc = stats.get("mean")
                    sc = stats.get("std")
                    if mc is None:
                        # Field not reported for this domain.
                        line += f" & {shade_prefix}{missing_cell}"
                        continue

                    mv = row[mc]
                    sv = row[sc] if sc is not None else float("nan")
                    txt = _format_mean_std(mv, sv, float_fmt)
                    median = med[mc]

                    if txt is None:
                        cell = missing_cell
                    elif median > 0 and mv > threshold_factor * median:
                        cell = r"$\star$"
                    else:
                        # Only the TARGET "all" cell carries the best-row emphasis.
                        cell_level = level if (f == "all" and dom == "TARGET") else 0
                        cell = f"${_emphasize(txt, cell_level, math=True)}$"
                    line += f" & {shade_prefix}{cell}"

            line += r" \\"
            lines.append(line)
            prev_da = da_raw

    # Always close the body, even when the frame has no rows.
    lines.append("    \\bottomrule")

    lines.extend([r"  \end{tabular}", r"  }", r"\end{table}"])
    return "\n".join(lines)


In [None]:
# Render the aggregated results of the selected task as a LaTeX table, using
# that task's header labels, and print the source for copy-pasting.
da_display_names = {"deep_coral": "Deep Coral", "cmd": "CMD", "DANN": "DANN"}

table = df_to_latex_table_with_highlights(
    agg_df,
    caption="Performance across different fields.",
    label="tab:fixed",
    da_name_map=da_display_names,
    field_name_map=field_name_maps[TASK],
    threshold_factor=100,
)
print(table)