In [22]:
import numpy as np
import pandas as pd

In [31]:
model_metadata = pd.read_csv("../../config/models.csv")
# Lookup tables built from the metadata:
#   palette:        display description -> plot color
#   model_renaming: internal model name -> display description
palette = dict(zip(model_metadata["description"], model_metadata["color"]))
model_renaming = dict(zip(model_metadata["name"], model_metadata["description"]))
model_metadata

Unnamed: 0,name,description,color
0,CADD,CADD,C0
1,GPN-MSA,GPN-MSA,C1
2,Borzoi,Borzoi,C2
3,Enformer,Enformer,C3
4,GPN,gLM-Promoter,C4
5,NucleotideTransformer,NT,C5
6,HyenaDNA,HyenaDNA,C6
7,Caduceus,Caduceus,C7
8,CADD+GPN-MSA+Borzoi,Ensemble,C8


In [23]:
def _read_header_less_list(path, **read_csv_kwargs):
    """Read a header-less CSV/TXT file and flatten all of its cells into one list."""
    table = pd.read_csv(path, header=None, **read_csv_kwargs)
    return table.values.ravel().tolist()


# Traits evaluated per dataset. Mendelian IDs are read as strings so they match
# the string keys of `mendelian_trait_renaming` below.
mendelian_traits = _read_header_less_list("../../config/omim/filtered_traits.txt", dtype=str)
complex_traits = _read_header_less_list("../../config/gwas/independent_traits_filtered.csv")
#complex_traits_n30 = _read_header_less_list("../../config/gwas/independent_traits_filtered_n30.csv")

# trait -> human-readable description, taken from the UKBB traits table.
ukbb_trait_table = pd.read_csv(
    "../../results/gwas/raw/release1.1/UKBB_94traits_release1.traits",
    sep="\t",
    usecols=["trait", "description"],
)
complex_trait_renaming = dict(zip(ukbb_trait_table["trait"], ukbb_trait_table["description"]))

mendelian_trait_renaming = {
    "600886": "Hyperferritinemia",
    "613985": "Beta-thalassemia",
    "614743": "Pulmonary fibrosis",
    "306900": "Hemophilia B",
    "250250": "Cartilage-hair hypoplasia",
    "174500": "Preaxial polydactyly II",
    "143890": "Hypercholesterolemia-1",
    "210710": "Dwarfism (MOPD1)",
}

# Combined lookup; on a key clash the Mendelian name wins.
trait_renaming = dict(complex_trait_renaming)
trait_renaming.update(mendelian_trait_renaming)

dataset_renaming = {
    "mendelian_traits_matched_9": "Mendelian traits",
    "complex_traits_matched_9": "Complex traits",
}

# Subsets are traits, so they share the trait lookup (same dict object).
subset_renaming = trait_renaming

In [48]:
datasets = ["mendelian_traits_matched_9", "complex_traits_matched_9"]

# Traits (subsets) evaluated within each dataset, keyed in the order of `datasets`.
subsets = dict(zip(datasets, [mendelian_traits, complex_traits]))

#linear_probing_subsets = [f"non_coding_AND_{trait}" for trait in complex_traits_n30]

modalities = ["Zero-shot", "Linear probing"]

# The last entry is the CADD+GPN-MSA+Borzoi combination (displayed as "Ensemble").
models = ["Borzoi", "GPN-MSA", "CADD", "CADD+GPN-MSA+Borzoi"]

def get_model_path(model, modality, dataset, subset):
    """Build the path of the metrics CSV for one model/modality/dataset/subset.

    The file lives under
    ``../../results/dataset/{dataset}/metrics_by_chrom_weighted_average/{subset}/``
    and is named after a "predictor" string that depends on the modality:

    * "Linear probing": ``{model}.LogisticRegression.chrom.subset_from_all``
    * "Zero-shot":
        - CADD: ``CADD.plus.RawScore``
        - Enformer / Borzoi: ``{model}_L2_L2.plus.all``
        - any other model: ``{model}_LLR.minus.score`` when ``dataset`` contains
          "mendelian", ``{model}_absLLR.plus.score`` when it contains "complex".

    Args:
        model: model name as used in the results file names.
        modality: "Linear probing" or "Zero-shot".
        dataset: dataset directory name.
        subset: subset directory name (a trait identifier in this notebook).

    Returns:
        The relative path to the CSV, as a string.

    Raises:
        ValueError: if ``modality`` is not recognised, or if an LLR-based
            zero-shot model is requested for a dataset whose name contains
            neither "mendelian" nor "complex". Previously both cases crashed
            with an opaque UnboundLocalError.
    """
    if modality == "Linear probing":
        predictor = f"{model}.LogisticRegression.chrom.subset_from_all"
    elif modality == "Zero-shot":
        if model == "CADD":
            predictor = "CADD.plus.RawScore"
        elif model in ("Enformer", "Borzoi"):
            predictor = f"{model}_L2_L2.plus.all"
        else:
            # LLR variant and sign are chosen from the dataset name.
            if "mendelian" in dataset:
                llr_version, sign = "LLR", "minus"
            elif "complex" in dataset:
                llr_version, sign = "absLLR", "plus"
            else:
                raise ValueError(
                    f"Cannot choose an LLR variant for dataset {dataset!r}: "
                    "expected 'mendelian' or 'complex' in its name"
                )
            predictor = f"{model}_{llr_version}.{sign}.score"
    else:
        raise ValueError(
            f"Unknown modality {modality!r}; expected 'Linear probing' or 'Zero-shot'"
        )
    return f"../../results/dataset/{dataset}/metrics_by_chrom_weighted_average/{subset}/{predictor}.csv"

In [49]:
# Gather one row per (dataset, trait, modality, model) from the metrics CSVs.
rows = []
for dataset in datasets:
    for subset in subsets[dataset]:
        for modality in modalities:
            for model in models:
                # Zero-shot scores are not looked up for combined ("+") models.
                if "+" in model and modality == "Zero-shot":
                    continue
                # For the Mendelian dataset the combined model is stored under a
                # more explicit name.
                model_for_path = (
                    "CADD+GPN-MSA_LLR+Borzoi_L2_L2"
                    if model == "CADD+GPN-MSA+Borzoi" and "mendelian" in dataset
                    else model
                )
                path = get_model_path(model_for_path, modality, dataset, subset)
                metrics = pd.read_csv(path).iloc[0]
                rows.append([
                    dataset_renaming.get(dataset, dataset),
                    subset_renaming.get(subset, subset),
                    modality,
                    model_renaming.get(model, model),
                    metrics["score"],
                    metrics["se"],
                ])
df = pd.DataFrame(rows, columns=["dataset", "subset", "modality", "model", "score", "se"])
df

Unnamed: 0,dataset,subset,modality,model,score,se
0,Mendelian traits,Hyperferritinemia,Zero-shot,Borzoi,0.131633,2.776946e-17
1,Mendelian traits,Hyperferritinemia,Zero-shot,GPN-MSA,0.964481,0.000000e+00
2,Mendelian traits,Hyperferritinemia,Zero-shot,CADD,0.956952,2.221557e-16
3,Mendelian traits,Hyperferritinemia,Linear probing,Borzoi,0.314713,0.000000e+00
4,Mendelian traits,Hyperferritinemia,Linear probing,GPN-MSA,0.964731,0.000000e+00
...,...,...,...,...,...,...
177,Complex traits,Blood clot in the leg,Zero-shot,CADD,0.481151,9.997103e-02
178,Complex traits,Blood clot in the leg,Linear probing,Borzoi,0.573701,7.382674e-02
179,Complex traits,Blood clot in the leg,Linear probing,GPN-MSA,0.487698,9.540585e-02
180,Complex traits,Blood clot in the leg,Linear probing,CADD,0.498392,1.041968e-01


In [50]:
def best_modality(df):
    """Return the ``score`` and ``se`` of the row with the largest ``score``.

    Used with ``groupby(...).apply``: each group holds one
    (dataset, subset, model) combination across modalities, and only the
    best-scoring modality is kept.

    NOTE(review): the modality is chosen using the same scores that are then
    reported, so each reported number is the maximum over modalities.
    """
    top = df.sort_values(by="score", ascending=False).iloc[0]
    return top.loc[["score", "se"]]

# Keep, for every (dataset, subset, model), the modality with the highest score.
# Selecting ["score", "se"] right after groupby keeps the grouping columns out of
# the frames handed to best_modality. Passing grouping columns to
# DataFrameGroupBy.apply is deprecated in recent pandas. This selection also
# keeps the resulting score/se columns numeric instead of object dtype.
df = (
    df.groupby(["dataset", "subset", "model"], sort=False)[["score", "se"]]
    .apply(best_modality)
    .reset_index()
)
df

Unnamed: 0,dataset,subset,model,score,se
0,Mendelian traits,Hyperferritinemia,Borzoi,0.314713,0.000000e+00
1,Mendelian traits,Hyperferritinemia,GPN-MSA,0.964731,0.000000e+00
2,Mendelian traits,Hyperferritinemia,CADD,0.980622,1.110779e-16
3,Mendelian traits,Hyperferritinemia,Ensemble,0.985000,0.000000e+00
4,Mendelian traits,Beta-thalassemia,Borzoi,0.926512,2.221557e-16
...,...,...,...,...,...
99,Complex traits,Balding Type 4,Ensemble,0.625071,7.998706e-02
100,Complex traits,Blood clot in the leg,Borzoi,0.573701,7.382674e-02
101,Complex traits,Blood clot in the leg,GPN-MSA,0.550735,1.096933e-01
102,Complex traits,Blood clot in the leg,CADD,0.498392,1.041968e-01


In [51]:
def format_score(x):
    """Render fractional scores as zero-padded integer percentages (0.05 -> "05")."""
    percent = (x * 100).round().astype(int)
    return percent.map("{:02d}".format)

def format_se(x):
    """Render standard errors as zero-padded integer percentages (0.01 -> "01").

    Asserts that every value is below 1 (i.e. fits in two digits once scaled).
    """
    percent = x * 100
    assert percent.max() < 100
    return percent.round().astype(int).map("{:02d}".format)

# Alternative renderings of the table cells, kept for reference:
#df["value"] = format_score(df.score) + "$\pm$" + format_se(df.se)
#df["value"] = format_score(df.score)
#df["value"] = df.score.apply(lambda x: f"{x:.2f}") + "$\pm$" + df.se.apply(lambda x: f"{x:.2f}")
# Current rendering: score only, three decimals (0.98062 -> "0.981").
df["value"] = df["score"].map("{:.3f}".format)
df

Unnamed: 0,dataset,subset,model,score,se,value
0,Mendelian traits,Hyperferritinemia,Borzoi,0.314713,0.000000e+00,0.315
1,Mendelian traits,Hyperferritinemia,GPN-MSA,0.964731,0.000000e+00,0.965
2,Mendelian traits,Hyperferritinemia,CADD,0.980622,1.110779e-16,0.981
3,Mendelian traits,Hyperferritinemia,Ensemble,0.985000,0.000000e+00,0.985
4,Mendelian traits,Beta-thalassemia,Borzoi,0.926512,2.221557e-16,0.927
...,...,...,...,...,...,...
99,Complex traits,Balding Type 4,Ensemble,0.625071,7.998706e-02,0.625
100,Complex traits,Blood clot in the leg,Borzoi,0.573701,7.382674e-02,0.574
101,Complex traits,Blood clot in the leg,GPN-MSA,0.550735,1.096933e-01,0.551
102,Complex traits,Blood clot in the leg,CADD,0.498392,1.041968e-01,0.498


In [52]:
# Reshape to one row per trait ("subset") and one column per model; each cell is
# the formatted score string, and missing combinations become "-".
df = df.pivot_table(
    index=["subset"],
    columns=["model"],
    values="value",
    aggfunc="first",
    sort=False,
).fillna("-")
df

model,Borzoi,GPN-MSA,CADD,Ensemble
subset,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
Hyperferritinemia,0.315,0.965,0.981,0.985
Beta-thalassemia,0.927,0.796,0.926,0.955
Pulmonary fibrosis,0.564,0.948,1.0,1.0
Hemophilia B,0.914,0.709,1.0,0.991
Cartilage-hair hypoplasia,0.594,0.987,0.923,0.918
Preaxial polydactyly II,0.546,0.959,0.969,0.967
Hypercholesterolemia-1,0.844,0.974,0.887,0.938
Dwarfism (MOPD1),0.484,1.0,1.0,1.0
Adult height,0.292,0.383,0.407,0.339
Platelet count,0.426,0.309,0.397,0.478


In [56]:
def boldface_best_model(x, threshold=0.01):
    """Wrap the best and near-best entries of one table row in a LaTeX textbf command.

    Args:
        x: one row of the pivoted results table. Entries are formatted score
            strings such as "0.981", or the "-" placeholder for missing results.
        threshold: an entry is boldfaced when its score is less than
            ``threshold`` below the row's best score (default 0.01).

    Returns:
        A copy of ``x`` in which the selected entries are wrapped in textbf.
        Non-numeric placeholders are never selected and are returned
        unchanged. ``x`` itself is not modified.
    """
    # The table is filled with "-" for missing results, which made the previous
    # `x.astype(float)` raise. Coerce such placeholders to NaN instead; NaN
    # compares False, so they are never boldfaced.
    y = pd.to_numeric(x, errors="coerce")
    best_score = y.max()
    best_models = y[(best_score - y) < threshold].index
    res = x.copy()
    for best_model in best_models:
        res[best_model] = r"\textbf{" + str(res[best_model]) + "}"
    return res

# Row by row (one row per trait), highlight the best and near-best models.
df = df.apply(boldface_best_model, axis="columns")
df

model,Borzoi,GPN-MSA,CADD,Ensemble
subset,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
Hyperferritinemia,0.315,0.965,\textbf{0.981},\textbf{0.985}
Beta-thalassemia,0.927,0.796,0.926,\textbf{0.955}
Pulmonary fibrosis,0.564,0.948,\textbf{1.000},\textbf{1.000}
Hemophilia B,0.914,0.709,\textbf{1.000},\textbf{0.991}
Cartilage-hair hypoplasia,0.594,\textbf{0.987},0.923,0.918
Preaxial polydactyly II,0.546,0.959,\textbf{0.969},\textbf{0.967}
Hypercholesterolemia-1,0.844,\textbf{0.974},0.887,0.938
Dwarfism (MOPD1),0.484,\textbf{1.000},\textbf{1.000},\textbf{1.000}
Adult height,0.292,0.383,\textbf{0.407},0.339
Platelet count,0.426,0.309,0.397,\textbf{0.478}


In [57]:
# Clear the axis names left over from the pivot ("subset" rows, "model" columns).
# With MultiIndex axes the equivalents would be `df.index.names = [None, None]`
# and `df.columns.names = [None, None]`.
df.index.name, df.columns.name = None, None

In [58]:
# escape=False so cells containing LaTeX markup (e.g. \textbf{...}) are emitted verbatim.
latex_table = df.to_latex(escape=False, multicolumn_format="c")
print(latex_table)

\begin{tabular}{lllll}
\toprule
 & Borzoi & GPN-MSA & CADD & Ensemble \\
\midrule
Hyperferritinemia & 0.315 & 0.965 & \textbf{0.981} & \textbf{0.985} \\
Beta-thalassemia & 0.927 & 0.796 & 0.926 & \textbf{0.955} \\
Pulmonary fibrosis & 0.564 & 0.948 & \textbf{1.000} & \textbf{1.000} \\
Hemophilia B & 0.914 & 0.709 & \textbf{1.000} & \textbf{0.991} \\
Cartilage-hair hypoplasia & 0.594 & \textbf{0.987} & 0.923 & 0.918 \\
Preaxial polydactyly II & 0.546 & 0.959 & \textbf{0.969} & \textbf{0.967} \\
Hypercholesterolemia-1 & 0.844 & \textbf{0.974} & 0.887 & 0.938 \\
Dwarfism (MOPD1) & 0.484 & \textbf{1.000} & \textbf{1.000} & \textbf{1.000} \\
Adult height & 0.292 & 0.383 & \textbf{0.407} & 0.339 \\
Platelet count & 0.426 & 0.309 & 0.397 & \textbf{0.478} \\
Estimated heel bone mineral density & 0.308 & \textbf{0.432} & 0.422 & 0.406 \\
Mean corpuscular volume & 0.434 & 0.319 & 0.391 & \textbf{0.454} \\
Monocyte count & \textbf{0.561} & 0.404 & 0.375 & 0.535 \\
Hemoglobin A1c & 0.475 & 0.375 &