In [1]:
# 1. Imports & paths

from pathlib import Path
import zipfile, re, random, math
from collections import Counter, defaultdict

from lxml import etree
import pandas as pd
from sklearn.model_selection import train_test_split

DATA_DIR  = Path("../data")
XML_PATH  = DATA_DIR / "full database.xml"

In [None]:
# 2. Parse DrugBank -> name, indication, categories

root = etree.parse(XML_PATH.open("rb")).getroot()
ns   = {"db": "http://www.drugbank.ca"}

def clean_indication(text: str) -> str:
    if not text:
        return ""

    # 1a) normalise & strip leading boiler‑plate
    text = text.strip().lower()
    text = re.sub(
        r"""
        ^
        (?:  # alternative openings
            (?:for|to|as)\s+(?:the\s+)?(?:treatment|management|prevention|relief)\s+(?:of|in)\s+
          | indicated\s+for\s+(?:the\s+)?(?:treatment|management|prevention)\s+(?:of|in)\s+
          | indicated\s+(?:in|for)\s+
          | investigated\s+for\s+(?:use|treatment)\s+(?:in|of)\s+
          | used\s+as\s+(?:adjunct|an\s+adjunct)\s+(?:therapy\s+)?(?:in|for)\s+
          | intended\s+for\s+
        )
        """,
        "",
        text,
        flags=re.I | re.VERBOSE,
    )

    # 1b) drop bracketed citations like “[FDA label]”, “[db12345]”
    text = re.sub(r"\[[^\]]+\]", "", text)

    # 2) keep first ≤3 sentences, cap at 350 chars
    sents = re.split(r"[.?!]\s+", text)
    cleaned = ". ".join(sents[:3]).strip()
    return cleaned[:350]

raw = []
for d in root.findall("db:drug", ns):
    name = d.findtext("db:name", namespaces=ns)
    ind  = clean_indication(d.findtext("db:indication", namespaces=ns))

    if not (name and ind):
        continue

    cats = [c.text for c in d.findall("db:categories/db:category/db:category", ns) if c.text]

    raw.append({"drug": name.lower(), "indication": ind.lower(), "categories": cats})

df = pd.DataFrame(raw)
print(f"Stage-1 rows: {len(df):,}")

Stage‑1 rows: 4,393


In [None]:
# 3. Build stop‑list of ultra‑common categories

cat_freq = Counter(c for cats in df["categories"] for c in cats)
TOTAL_DRUGS = df["drug"].nunique()

STOPLIST = {
    c for c, freq in cat_freq.items()
    if freq / TOTAL_DRUGS > 0.30  # appears in >30 % of drugs
}
print(f"Stop-list size: {len(STOPLIST)} (examples: {list(STOPLIST)[:5]})")

# Per‑drug TF‑IDF -> top‑5  (after stop‑list)

def tfidf_topk(cats, k=5):
    weighted = [
        (c, math.log(TOTAL_DRUGS / (cat_freq[c] or 1)))
        for c in cats if c not in STOPLIST
    ]
    weighted.sort(key=lambda x: x[1], reverse=True)
    return [c for c, _ in weighted[:k]]

df["pruned_cats"] = df["categories"].apply(tfidf_topk, k=5)

Stop‑list size: 0 (examples: [])


In [4]:
# 4. Medical‑synonym augmentation

MED_SYNONYMS = {
    # metabolic & endocrine
    "diabetes": ["diabetes mellitus", "high blood sugar"],
    "hyperthyroidism": ["overactive thyroid"],
    "hypothyroidism": ["underactive thyroid"],
    "osteoporosis": ["bone loss", "weak bones"],
    # cardiovascular
    "hypertension": ["high blood pressure"],
    "hypotension": ["low blood pressure"],
    "myocardial infarction": ["heart attack"],
    "congestive heart failure": ["heart failure"],
    "angina": ["chest pain"],
    "arrhythmia": ["irregular heartbeat"],
    # respiratory & infectious
    "influenza": ["flu"],
    "pneumonia": ["lung infection"],
    "tuberculosis": ["tb"],
    "urinary tract infection": ["uti", "bladder infection"],
    # neuro / psych
    "epilepsy": ["seizure disorder"],
    "schizophrenia": ["psychotic disorder"],
    "depression": ["major depressive disorder"],
    "parkinson": ["parkinson's disease"],
    # dermatology
    "acne": ["acne vulgaris", "pimples"],
    "eczema": ["atopic dermatitis"],
    "psoriasis": ["psoriatic disease"],
    # oncology
    "leukemia": ["blood cancer"],
    "lymphoma": ["lymphatic cancer"],
    "melanoma": ["skin cancer"],
    "carcinoma": ["cancer"],
    # gastrointestinal
    "gastroesophageal reflux": ["acid reflux", "gerd"],
    "peptic ulcer": ["stomach ulcer"],
    "hepatitis c": ["hcv infection"],
}

def augment_indication(ind: str) -> list[str]:
    """Return [original, *synonyms found] (deduplicated)."""
    alts = []
    for key, syns in MED_SYNONYMS.items():
        if key in ind:
            alts.extend(syns)
    return list(dict.fromkeys([ind, *alts]))  # preserves order & dedupes

In [None]:
# 5. Build helper dict: indication → list of rows (drug, pruned_cats)

by_indication = defaultdict(list)
for _, row in df.iterrows():
    by_indication[row["indication"]].append(
        {"drug": row["drug"], "cats": row["pruned_cats"]}
    )

In [6]:
# 6. Build prompts (≤5 drugs / chunk) with capped cats

random.seed(42)

PROMPTS = []
BASE_INSTR = (
    "Given the following drug(s) and their categories, "
    "predict the associated health condition.\n"
    "Output answer in the form: 'Indication: <condition>'."
)
K_PROMPT = 12

def cat_union(drug_rows):
    pooled = set().union(*(r["cats"] for r in drug_rows))
    cats = tfidf_topk(list(pooled), k=K_PROMPT)
    return cats                               # may be []


for ind, rows in by_indication.items():

    # 6.a  single‑drug prompts
    for r in rows:
        cat_str = "; ".join(r["cats"]) if r["cats"] else "unknown"
        if cat_str == "unknown":                # ← skip un‑informative prompt
            continue

        PROMPTS.append(
            {
                "instruction": BASE_INSTR,
                "input": f"drug: {r['drug']} || categories: {cat_str}",
                "output": ind,
                "alt_outputs": augment_indication(ind),
            }
        )

    # 6.b  one multi‑drug prompt (if ≥2 drugs share this indication)
    if len(rows) >= 2:
        sample = random.sample(rows, k=min(len(rows), 5))
        cats_list = cat_union(sample)
        if not cats_list:                       # ← skip if all categories pruned
            continue
        cats_str  = "; ".join(cats_list)
        drug_list = ", ".join(r["drug"] for r in sample)

        PROMPTS.append(
            {
                "instruction": BASE_INSTR,
                "input": f"drugs: {drug_list} || categories: {cats_str}",
                "output": ind,
                "alt_outputs": augment_indication(ind),
            }
        )

prompts_df = pd.DataFrame(PROMPTS)
print(
    f"Total prompts: {len(prompts_df):,}  |  "
    f"Unique indications: {prompts_df['output'].nunique():,}"
)

def _norm_ind(ind: str) -> str:
    ind = ind.strip()
    return ind[:-1] if ind.endswith(".") else ind  # drop trailing period

prompts_df["output"] = prompts_df["output"].apply(_norm_ind)
prompts_df["alt_outputs"] = prompts_df["alt_outputs"].apply(
    lambda lst: [_norm_ind(x) for x in lst]
)

Total prompts: 4,060  |  Unique indications: 3,554


In [None]:
# 7. Split 80 / 10 / 10  – stratify when possible, fall back otherwise

min_per_class = 4  # keep classes with ≥4 rows in stratified split

counts = prompts_df["output"].value_counts()
strat_mask = prompts_df["output"].isin(counts[counts >= min_per_class].index)

strat_df = prompts_df[strat_mask]
remainder_df = prompts_df[~strat_mask]   # singletons / doubletons

# stratified 80 / 20 on the larger set
train_big, holdout = train_test_split(
    strat_df,
    test_size=0.20,
    random_state=42,
    stratify=strat_df["output"],
)

# 50 / 50 random split of hold‑out -> val / test
val, test = train_test_split(
    holdout, test_size=0.5, random_state=42, shuffle=True
)

# final training = strat‑train + the rare rows
train = pd.concat([train_big, remainder_df], ignore_index=True)

for name, split in [("train", train), ("val", val), ("test", test)]:
    split.to_json(DATA_DIR / f"llm_{name}.json", orient="records", lines=True)
    print(f"Saved {name:<5}: {len(split):5} rows | classes: {split['output'].nunique():4}")

prompts_df.to_json(DATA_DIR / "llm_all.json", orient="records", lines=True)


Saved train:  3985 rows | classes: 3548
Saved val  :    37 rows | classes:   31
Saved test :    38 rows | classes:   33
