In [1]:
from datasets import load_dataset

# load the "drug_metadata" config
ds = load_dataset("tahoebio/Tahoe-100M", "drug_metadata")  

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
ds = ds["train"]

# extract the unique drug column
all_drugs = list(set(ds["drug"]))
print(f"{len(all_drugs)} unique drugs found")

379 unique drugs found


In [4]:
from dotenv import load_dotenv
load_dotenv()  
import os
api_key = os.getenv("LAMBDA_API_KEY")

In [20]:
from openai import OpenAI
import os

client = OpenAI(
    api_key=os.environ["LAMBDA_API_KEY"],
    base_url="https://api.lambda.ai/v1",
)

def summarize_drug(name: str) -> list[str]:
    messages = [
        {"role":"system","content":(
            "You’re a concise biomedical summarizer. "
            "Output exactly three bullets, each starting with “•”. "
            "Bullet 1: list only the molecular targets (no verbs). "
            "Bullet 2: list only the main downstream pathways (no targets). "
            "Bullet 3: list only the approved indications. "
            "Do NOT repeat the drug name, add headings, or give extra text."
        )},
        {"role":"user","content":(
            f"For the drug **{name}**, give these bullets:\n"
            "• Primary molecular target(s)\n"
            "• Key pathways modulated\n"
            "• Approved therapeutic indications\n"
        )}
    ]
    resp = client.chat.completions.create(
        model="llama3.3-70b-instruct-fp8",
        messages=messages,
        max_tokens=200,
        temperature=0.0,
        stop=["\n\n"]
    )
    # split into lines and strip leading bullets
    lines = resp.choices[0].message.content.strip().split("\n")
    return [l.lstrip("• ").strip() for l in lines if l.startswith("•")]

# example
bullets = summarize_drug("Imatinib")
for b in bullets:
    print("- ", b)

-  BCR-ABL, c-Kit, PDGFR
-  Ras/Raf/MEK/ERK, PI3K/AKT, JAK/STAT
-  Chronic myeloid leukemia, gastrointestinal stromal tumor, acute lymphoblastic leukemia, myeloproliferative neoplasms, hypereosinophilic syndrome, dermatofibrosarcoma protuberans


In [21]:
import json, time, os

summaries = {}
path = "drug_summaries.json"


for drug in all_drugs:
    if drug in summaries:
        continue
    try:
        text = summarize_drug(drug)  
        summaries[drug] = text
        print(f"[V] {drug}")
    except Exception as e:
        print(f"[X] {drug} → {e}")
        time.sleep(5)  # back‑off on errors
    time.sleep(0.3)     # throttle to avoid rate limits

# write out cache
with open(path, "w") as f:
    json.dump(summaries, f, indent=2)

[V] Decitabine
[V] Methotrexate
[V] Delamanid
[V] Demeclocycline
[V] 5-Fluorouracil
[V] LB-100
[V] Glycyrrhizic acid
[V] (S)-Crizotinib
[V] Nafamostat (mesylate)
[V] Diammonium Glycyrrhizinate
[V] Vinblastine (sulfate)
[V] AZD-7648
[V] Dinaciclib
[V] Erlotinib
[V] Almonertinib (hydrochloride)
[V] Daptomycin
[V] Sulfatinib
[V] Tolmetin
[V] Idarubicin (hydrochloride)
[V] Dexamethasone
[V] Loperamide (hydrochloride)
[V] Fostamatinib (disodium hexahydrate)
[V] APTO-253
[V] Quinidine (15% dihydroquinidine)
[V] Lapatinib ditosylate
[V] Pioglitazone
[V] Nimesulide
[V] Paclitaxel
[V] Lenalidomide (hemihydrate)
[V] AT7519
[V] Drospirenone
[V] Gefitinib
[V] ETC-206
[V] Binimetinib
[V] Relugolix
[V] Resveratrol
[V] MK-8353
[V] Belumosudil (mesylate)
[V] Canagliflozin (hemihydrate)
[V] Radotinib
[V] Valdecoxib
[V] Ouabain (Octahydrate)
[V] Oxaliplatin
[V] Apalutamide
[V] Rapamycin
[V] Altretamine
[V] Cinacalcet
[V] Afatinib
[V] Methylthiouracil
[V] Irinotecan (hydrochloride)
[V] Clotrimazole
[V] E

In [23]:
import numpy as np
from sentence_transformers import SentenceTransformer

# Load summaries & sort so we know the order
with open("drug_summaries.json") as f:
    summaries = json.load(f)
drug_list = sorted(summaries.keys())
texts     = [summaries[d] for d in drug_list]

# Embed
embedder = SentenceTransformer("all-mpnet-base-v2", device="cuda")
embeddings = embedder.encode(texts, batch_size=32, show_progress_bar=True)
embeddings.shape == (len(drug_list), 768)

# Save
np.save("drug_summary_embeddings.npy", embeddings)
with open("drug_index.json","w") as f:
    json.dump(drug_list, f, indent=2)

Batches: 100%|██████████| 12/12 [00:00<00:00, 25.66it/s]


In [25]:
embeddings.shape

(379, 768)

In [26]:
# dimension reduction
from sklearn.decomposition import PCA
embs = np.load("drug_summary_embeddings.npy")

pca = PCA(n_components=50)
lowd = pca.fit_transform(embs)  # shape (N,50)
np.save("drug_summary_lowd.npy", lowd)