In [1]:
import pandas as pd
import polars as pl
from sqlmodel import Session, select
from sentence_transformers import SentenceTransformer

from undina_llm.db import SessionManager
from undina_llm.models import Response, Prompt, SystemPrompt, DrugLabel

  from tqdm.autonotebook import tqdm, trange


# 1. Load the annotations

In [2]:
manual_df = (
    pl.read_csv("../data/merged_diff.csv")
    .with_columns(
        pl.col("consensus", "additions")
        .str.split(",")
        .list.eval(pl.element().str.strip_chars())
        .list.unique()
    )
    .with_columns(
        manual_annotation=(
            pl.col("consensus")
            .list.set_union(pl.col("additions"))
            .list.unique()
        )
    )
    .select("set_id", "label_id", "spl_version", "title", "section", "label", "manual_annotation")
)

manual_df.with_columns(pl.col("manual_annotation").list.join(", ")).write_csv("../data/manual_annotations.csv")

manual_df.head(2)

set_id,label_id,spl_version,title,section,label,manual_annotation
str,str,i64,str,str,str,list[str]
"""908691b4-7950-4f3e-bbea-ea568f…","""e6f0f0dd-940a-490f-a404-56dd56…",1,"""Isoxsuprine Hydrochloride Tabl…","""DI""",,
"""5a0ba417-8a4a-4d7f-b85a-1839ee…","""8e64b577-1ecb-46f2-a7c8-3577a1…",5,"""These highlights do not includ…","""DI""",,


In [3]:
SQLITE_FILE = "sqlite:///../data/project.db"

manager = SessionManager(SQLITE_FILE)

with Session(manager.engine) as session:
    responses_df = pl.DataFrame(session.exec(select(Response).where((Response.section == "DI") & (Response.prompt_id == 2))).all())
    system_prompts_df = pl.DataFrame(session.exec(select(SystemPrompt)).all())
    prompts_df = pl.DataFrame(session.exec(select(Prompt)).all())
    drug_labels_df = pl.DataFrame(session.exec(select(DrugLabel)).all())

responses_df.head(2)

id,system_prompt_id,prompt_id,drug_label_id,section,model,temperature,response
i64,i64,i64,i64,str,str,f64,str
1469,1,2,1,"""DI""","""gpt-4o-2024-05-13""",0.0,"""maprotiline,tricyclic antidepr…"
1470,1,2,2,"""DI""","""gpt-4o-2024-05-13""",0.0,""""""


In [4]:
gpt_df = (
    responses_df
    .join(drug_labels_df, left_on=["drug_label_id"], right_on=["id"])
    .drop(["system_prompt_id", "prompt_id", "drug_label_id", "section", "model", "temperature"])
    .select(
        "set_id", "label_id", "spl_version", "title", pl.lit("DI").alias("section"), 
        pl.col("DI").alias("label"), pl.col("response").alias("gpt_annotation")
    )
)

gpt_df.write_csv("../data/gpt_annotations.csv")

gpt_df.head(2)

set_id,label_id,spl_version,title,section,label,gpt_annotation
str,str,str,str,str,str,str
"""297f0888-729c-4ce6-8779-6b239a…","""297f0888-729c-4ce6-8779-6b239a…","""1""","""ALBALON® (naphazoline hydrochl…","""DI""","""Drug Interactions: Concurrent…","""maprotiline,tricyclic antidepr…"
"""908691b4-7950-4f3e-bbea-ea568f…","""e6f0f0dd-940a-490f-a404-56dd56…","""1""","""Isoxsuprine Hydrochloride Tabl…","""DI""",,""""""


# 2. Gather the unique annotation strings

In [5]:
manual_annotations = (
    manual_df
    .explode("manual_annotation")
    .select("manual_annotation")
    .drop_nulls()
    .unique()
    ["manual_annotation"]
    .to_list()
)
print(len(manual_annotations))

631


In [6]:
gpt_annotations = (
    gpt_df
    .select(pl.col("gpt_annotation").str.split(","))
    .explode("gpt_annotation")
    .filter(pl.col("gpt_annotation").ne('""'))
    .select(
        pl.col("gpt_annotation")
        .str.strip_chars()
        .str.to_lowercase()
    )
    .unique()
    ["gpt_annotation"]
    .to_list()
)

print(len(gpt_annotations))

1896


# 3. Gather RxNorm ingredient names

In [7]:
rxnorm_df = (
    pl.read_csv("../data/rxnorm_ingredients_synonyms.tsv", separator="\t")
    .melt(id_vars=["concept_id"], value_vars=["concept_name", "concept_synonym_name"], value_name="name")
    .drop("variable")
    .drop_nulls("name")
)

rxnorm_annotations = rxnorm_df.unique("name")["name"].to_list()
print(len(rxnorm_annotations))

rxnorm_df.head(2)

17131


concept_id,name
i64,str
501343,"""hepatitis B immune globulin"""
501488,"""Hepatitis B Vaccine"""


# 4. Compute embeddings for each annotation

In [8]:
embed_model = SentenceTransformer('llmrails/ember-v1')

manual_embeddings = embed_model.encode(manual_annotations)
gpt_embeddings = embed_model.encode(gpt_annotations)
rxnorm_embeddings = embed_model.encode(rxnorm_annotations)



In [9]:
manual_embeddings_df = pd.DataFrame(manual_embeddings, index=manual_annotations)
gpt_embeddings_df = pd.DataFrame(gpt_embeddings, index=gpt_annotations)
rxnorm_embeddings_df = pd.DataFrame(rxnorm_embeddings, index=rxnorm_annotations)

manual_embeddings_df.to_csv("../data/manual_embeddings.csv")
gpt_embeddings_df.to_csv("../data/gpt_embeddings.csv")
rxnorm_embeddings_df.to_csv("../data/rxnorm_embeddings.csv.zst")