In [1]:
import numpy as np
import pandas as pd
import polars as pl
import sklearn.metrics.pairwise

In [2]:
manual_embeddings_df = pd.read_csv("../data/manual_embeddings.csv", index_col=0)
gpt_embeddings_df = pd.read_csv("../data/gpt_embeddings.csv", index_col=0)
rxnorm_embeddings_df = pd.read_csv("../data/rxnorm_embeddings.csv.zst", index_col=0)

In [3]:
def compute_cosine_df(**kwargs):
    assert len(kwargs) == 2
    x_name, y_name = list(kwargs.keys())
    X = kwargs[x_name]
    Y = kwargs[y_name]
    
    return (
        pd.DataFrame(
            sklearn.metrics.pairwise.cosine_similarity(X, Y),
            index=X.index,
            columns=Y.index,
        )
        .melt(ignore_index=False, var_name=y_name, value_name="cosine_similarity")
        .rename_axis(index=x_name)
        .reset_index()
        .pipe(pl.DataFrame)
    )

# 1. Direct comparison between GPT and manual annotations using semantic similarity matching

In [4]:
DIRECT_MATCH_CUTOFF = 0.7097613

In [5]:
manual_vs_gpt_cosine_df = compute_cosine_df(manual_annotation=manual_embeddings_df, gpt_annotation=gpt_embeddings_df)

In [6]:
manual_annotations_df = pl.read_csv("../data/manual_annotations.csv")
gpt_annotations_df = pl.read_csv("../data/gpt_annotations.csv")

manual_df = (
    manual_annotations_df
    .drop_nulls("manual_annotation")
    .select(
        "set_id", "label_id", 
        (
            pl.col("manual_annotation")
            .str.split(",")
            .list.eval(pl.element().str.strip_chars())
        )
    )
    .explode("manual_annotation")
)

gpt_df = (
    gpt_annotations_df
    .drop_nulls("gpt_annotation")
    .select(
        "set_id", "label_id", 
        (
            pl.col("gpt_annotation")
            .str.split(",")
            .list.eval(pl.element().str.strip_chars().str.to_lowercase())
        )
    )
    .explode("gpt_annotation")
    .filter(pl.col("gpt_annotation").ne('""') & pl.col("gpt_annotation").str.len_chars().ne(0))
)

unmatched_rows = (
    manual_df
    .join(gpt_df, on=["set_id", "label_id"])
    .join(manual_vs_gpt_cosine_df, on=["manual_annotation", "gpt_annotation"], how="anti")
    .shape[0]
)
assert unmatched_rows == 0

merged_df = (
    manual_df
    .join(gpt_df, on=["set_id", "label_id"], how="full")
    .join(manual_vs_gpt_cosine_df, on=["manual_annotation", "gpt_annotation"])
)

merged_df.head(2)

set_id,label_id,manual_annotation,set_id_right,label_id_right,gpt_annotation,cosine_similarity
str,str,str,str,str,str,f64
"""cdfbe0cd-eb15-45a1-ac17-531bcd…","""1d6c9e9d-e17d-4609-91fa-75e5bd…","""cyp1a2 substrates""","""cdfbe0cd-eb15-45a1-ac17-531bcd…","""1d6c9e9d-e17d-4609-91fa-75e5bd…","""colestipol""",0.42328
"""cdfbe0cd-eb15-45a1-ac17-531bcd…","""1d6c9e9d-e17d-4609-91fa-75e5bd…","""ocaliva""","""cdfbe0cd-eb15-45a1-ac17-531bcd…","""1d6c9e9d-e17d-4609-91fa-75e5bd…","""colestipol""",0.450723


In [7]:
manual_annotations_best_match_df = (
    merged_df
    .group_by("set_id", "label_id", "manual_annotation")
    .agg(pl.col("cosine_similarity").max())
    .with_columns(match=pl.col("cosine_similarity").ge(DIRECT_MATCH_CUTOFF))
)

gpt_annotations_best_match_df = (
    merged_df
    .group_by("set_id", "label_id", "gpt_annotation")
    .agg(pl.col("cosine_similarity").max())
    .with_columns(match=pl.col("cosine_similarity").ge(DIRECT_MATCH_CUTOFF))
)

In [8]:
(
    manual_annotations_best_match_df
    .select(
        N=pl.len(),
        P=pl.col("match").sum(),
    )
    .with_columns(recall=pl.col("P").truediv(pl.col("N")))
)

N,P,recall
u32,u32,f64
1130,827,0.731858


In [9]:
(
    gpt_annotations_best_match_df
    .select(
        N=pl.len(),
        P=pl.col("match").sum(),
    )
    .with_columns(precision=pl.col("P").truediv(pl.col("N")))
)

N,P,precision
u32,u32,f64
729,721,0.989026


# 2. Evaluation using RxNorm semantic similarity. Best match only!

In [10]:
RXNORM_MATCH_CUTOFF = 0.9

In [11]:
rxnorm_df = (
    pl.read_csv("../data/rxnorm_ingredients_synonyms.tsv", separator="\t")
    .melt(id_vars=["concept_id"], value_vars=["concept_name", "concept_synonym_name"], value_name="name")
    .drop("variable")
    .drop_nulls("name")
)

In [12]:
manual_cosine_df = compute_cosine_df(manual=manual_embeddings_df, rxnorm=rxnorm_embeddings_df)
gpt_cosine_df = compute_cosine_df(gpt=gpt_embeddings_df, rxnorm=rxnorm_embeddings_df)

In [13]:
manual_eval_df = (
    manual_cosine_df
    .filter(pl.col("cosine_similarity").ge(RXNORM_MATCH_CUTOFF))
    .sort("cosine_similarity")
    .filter(pl.col("cosine_similarity").eq(pl.col("cosine_similarity").max().over("manual")))
    .join(rxnorm_df, left_on=["rxnorm"], right_on=["name"])
)

gpt_eval_df = (
    gpt_cosine_df
    .filter(pl.col("cosine_similarity").ge(RXNORM_MATCH_CUTOFF))
    .sort("cosine_similarity")
    .filter(pl.col("cosine_similarity").eq(pl.col("cosine_similarity").max().over("gpt")))
    .join(rxnorm_df, left_on=["rxnorm"], right_on=["name"])
)

In [14]:
manual_rxnorm_df = (
    manual_annotations_df
    .with_columns(
        pl.col("manual_annotation").str.split(",").list.eval(pl.element().str.strip_chars())
    )
    .explode("manual_annotation")
    .join(manual_eval_df, left_on=["manual_annotation"], right_on=["manual"])
    .select("set_id", "label_id", "concept_id")
)

gpt_rxnorm_df = (
    gpt_annotations_df
    .with_columns(
        pl.col("gpt_annotation").str.split(",").list.eval(pl.element().str.strip_chars())
    )
    .explode("gpt_annotation")
    .join(gpt_eval_df, left_on=["gpt_annotation"], right_on=["gpt"])
    .select("set_id", "label_id", "concept_id")
)

In [15]:
(
    manual_rxnorm_df
    .join(gpt_rxnorm_df, on=["set_id", "label_id"], suffix="_gpt", how="full")
    .with_columns(match=pl.col("concept_id").eq(pl.col("concept_id_gpt")))
    .group_by("set_id", "label_id", "concept_id")
    .agg(pl.col("match").any().cast(pl.UInt32))
    .select(
        N=pl.len(),
        P=pl.col("match").sum(),
    )
    .with_columns(recall=pl.col("P").truediv(pl.col("N")))
)

N,P,recall
u32,u32,f64
674,399,0.591988


In [16]:
(
    gpt_rxnorm_df
    .join(manual_rxnorm_df, on=["set_id", "label_id"], suffix="_manual")
    .with_columns(match=pl.col("concept_id").eq(pl.col("concept_id_manual")))
    .group_by("set_id", "label_id", "concept_id")
    .agg(pl.col("match").any().cast(pl.UInt32))
    .select(
        N=pl.len(),
        P=pl.col("match").sum(),
    )
    .with_columns(precision=pl.col("P").truediv(pl.col("N")))
)

N,P,precision
u32,u32,f64
430,399,0.927907
