In [1]:
import os
import sys

# Put the parent of the current working directory on sys.path so that the
# project packages resolve when the notebook is launched from examples/.
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

print("PROJECT_ROOT:", PROJECT_ROOT)

# Fail fast, with an installation hint, when torch or pyhealth is missing.
try:
    import torch
except ImportError as exc:
    raise RuntimeError(
        "PyTorch is not installed. Install it into your environment first "
        "(e.g., `pip install torch` matching your CUDA/CPU)."
    ) from exc
else:
    print("✓ PyTorch is installed")

try:
    import pyhealth
except ImportError as exc:
    raise RuntimeError(
        "pyhealth is not importable. From the project root, run\n"
        "    pip install -e .\n"
        "to install PyHealth in editable mode."
    ) from exc
else:
    print("✓ pyhealth is importable, version:", getattr(pyhealth, "__version__", "unknown"))

# Core dataset + MedLink imports
from pyhealth.datasets import MIMIC3Dataset
from pyhealth.tasks import BaseTask
from pyhealth.models.medlink import (
    BM25Okapi,
    convert_to_ir_format,
    filter_by_candidates,
    generate_candidates,
    get_bm25_hard_negatives,
    get_eval_dataloader,
    get_train_dataloader,
    tvt_split,
)
from pyhealth.models.medlink.model import MedLink
from pyhealth.metrics import ranking_metrics_fn


PROJECT_ROOT: /Users/saurabhatri/Downloads/PyHealth
✓ PyTorch is installed
✓ pyhealth is importable, version: 1.1.4


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Downloaded from: https://physionet.org/content/mimiciii-demo/1.4/
#
# The dataset location is machine-specific, so it is taken from the
# MIMIC3_DEMO_ROOT environment variable when that is set; otherwise edit the
# fallback path below. `~` is expanded so home-relative paths work.
MIMIC3_DEMO_ROOT = os.path.expanduser(
    os.environ.get(
        "MIMIC3_DEMO_ROOT",
        "/path/to/mimic-iii-clinical-database-demo-1.4",  # <-- adjust for real
    )
)

print("MIMIC-III demo root:", MIMIC3_DEMO_ROOT)


MIMIC-III demo root: /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4


In [3]:
# STEP 1: Load base MIMIC-III dataset from the demo
#
# Builds the base dataset from the files under MIMIC3_DEMO_ROOT. Only the
# diagnoses table is requested explicitly.
# NOTE(review): the recorded run log shows the patients, admissions and
# icustays tables being scanned as well — presumably library defaults; confirm.

base_dataset = MIMIC3Dataset(
    root=MIMIC3_DEMO_ROOT,
    tables=["diagnoses_icd"],  # matches `diagnoses_icd` in configs/mimic3.yaml
    dev=False,                 # True => small subset of patients
)

# Kept as the last statement of the cell; it reports dataset statistics.
base_dataset.stats()


No config path provided, using default config
Initializing mimic3 dataset from /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4 (dev mode: False)
Scanning table: patients from /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/PATIENTS.csv.gz
Original path does not exist. Using alternative: /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/PATIENTS.csv
Scanning table: admissions from /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/ADMISSIONS.csv.gz
Original path does not exist. Using alternative: /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/ADMISSIONS.csv
Scanning table: icustays from /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/ICUSTAYS.csv.gz
Original path does not exist. Using alternative: /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/ICUSTAYS.csv
Scanning table: diagnoses_icd from /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/DIAGNOSES_

In [4]:
from datetime import datetime
from collections import defaultdict
import math

class PatientLinkageMIMIC3Task(BaseTask):
    """
    Patient linkage task for MIMIC-III using the new Patient/Event API.

    For each eligible patient, one sample is emitted that pairs the most
    recent admission (the "query") with the admission immediately before it
    (the "document").

    It produces samples with the same keys as the old
    `patient_linkage_mimic3_fn`, so that medlink.utils.convert_to_ir_format
    works unchanged.
    """

    task_name = "patient_linkage_mimic3"
    # MedLink actually consumes `conditions` / `d_conditions` as sequences,
    # but we don't rely heavily on the feature processors here.
    input_schema = {
        "conditions": "sequence",
        "d_conditions": "sequence",
    }
    # No supervised label for MedLink retrieval
    output_schema = {}

    def __call__(self, patient):
        """
        Process a single patient into MedLink samples.

        Requirements (same as original task):
          - At least 2 visits (admissions) that carry a timestamp
          - Age >= 18 at both visits
          - Non-empty conditions for both visits

        Args:
            patient: object exposing `patient_id` and
                `get_events(event_type=...)`; the returned events are read
                through their `timestamp` and `attr_dict` attributes.

        Returns:
            A list holding exactly one sample dict, or an empty list when the
            patient does not meet the requirements above.
        """
        # --- 1) Get admissions (visits), sorted by time ---
        # Admissions without a timestamp cannot be ordered: comparing None
        # with a datetime inside the sort raises TypeError. Drop them up
        # front instead of crashing on the whole patient.
        admissions = [
            adm
            for adm in patient.get_events(event_type="admissions")
            if adm.timestamp is not None
        ]
        if len(admissions) < 2:
            return []

        admissions.sort(key=lambda adm: adm.timestamp)
        q_visit = admissions[-1]   # last visit (query)
        d_visit = admissions[-2]   # second last visit (document)

        # --- 2) Get patient demographics (gender, dob) ---
        patients_events = patient.get_events(event_type="patients")
        if not patients_events:
            return []
        demo = patients_events[0]

        gender = str(demo.attr_dict.get("gender") or "")

        dob_raw = demo.attr_dict.get("dob")
        birth_dt = None
        if isinstance(dob_raw, datetime):
            birth_dt = dob_raw
        elif dob_raw is not None:
            # In the MIMIC CSV it's a string like "2111-04-20 00:00:00"
            try:
                birth_dt = datetime.fromisoformat(str(dob_raw))
            except ValueError:
                # Unparseable DOB: age stays unknown and the patient is
                # excluded by the age check below.
                birth_dt = None

        def compute_age(ts):
            """Return the age in whole (rough) years at `ts`, or None if unknown."""
            if birth_dt is None or ts is None:
                return None
            # rough years
            return int((ts - birth_dt).days // 365.25)

        q_age = compute_age(q_visit.timestamp)
        d_age = compute_age(d_visit.timestamp)

        # NOTE(review): MIMIC-III reportedly shifts the DOB of patients older
        # than 89, which would yield ages around 300 here — confirm that
        # downstream consumers of `age` / `d_age` tolerate this.

        # Exclude under 18 or missing age
        if q_age is None or d_age is None or q_age < 18 or d_age < 18:
            return []

        # --- 3) Collect diagnosis codes per admission (hadm_id) ---
        diag_events = patient.get_events(event_type="diagnoses_icd")
        hadm_to_codes = defaultdict(list)
        for ev in diag_events:
            hadm = ev.attr_dict.get("hadm_id")
            code = ev.attr_dict.get("icd9_code")
            if hadm is None or code is None:
                continue
            # Keys are stringified so they match the stringified admission
            # ids used for the lookups below.
            hadm_to_codes[str(hadm)].append(str(code))

        q_hadm = str(q_visit.attr_dict.get("hadm_id"))
        d_hadm = str(d_visit.attr_dict.get("hadm_id"))

        # `.get` (not indexing) so the defaultdict is not grown by a miss.
        q_conditions = hadm_to_codes.get(q_hadm, [])
        d_conditions = hadm_to_codes.get(d_hadm, [])

        # Exclude if any side has no conditions
        if len(q_conditions) == 0 or len(d_conditions) == 0:
            return []

        # --- 4) Identifier strings (gender + admin attributes) ---
        def clean(x):
            """Return `x` as a string, mapping None/NaN to the empty string."""
            # mimic old NaN handling: empty string if missing/NaN
            if x is None:
                return ""
            if isinstance(x, float) and math.isnan(x):
                return ""
            return str(x)

        def build_identifiers(adm_event):
            """Join gender and the admission's admin attributes with '+'."""
            insurance = clean(adm_event.attr_dict.get("insurance"))
            language = clean(adm_event.attr_dict.get("language"))
            religion = clean(adm_event.attr_dict.get("religion"))
            marital_status = clean(adm_event.attr_dict.get("marital_status"))
            ethnicity = clean(adm_event.attr_dict.get("ethnicity"))
            return "+".join(
                [gender, insurance, language, religion, marital_status, ethnicity]
            )

        q_identifiers = build_identifiers(q_visit)
        d_identifiers = build_identifiers(d_visit)

        # --- 5) Build sample dict (same keys as old function) ---
        # Both condition lists are prefixed with an empty-string token, as in
        # the keys/values layout of the old function.
        sample = {
            "patient_id": patient.patient_id,
            "visit_id": q_hadm,              # query visit_id
            "conditions": [""] + q_conditions,
            "age": q_age,
            "identifiers": q_identifiers,

            "d_visit_id": d_hadm,            # document visit_id
            "d_conditions": [""] + d_conditions,
            "d_age": d_age,
            "d_identifiers": d_identifiers,
        }

        return [sample]


In [5]:
# STEP 3: Set the patient linkage task and build the sample dataset
#
# `set_task` applies the task defined above to the base dataset and returns
# the sample dataset used by all later cells.

patient_linkage_task = PatientLinkageMIMIC3Task()
sample_dataset = base_dataset.set_task(task=patient_linkage_task)

# Report how many samples were produced and show one of them, if any.
# NOTE(review): in the recorded output `conditions` / `d_conditions` are
# printed as integer tensors rather than the raw code lists the task returns —
# presumably encoded according to `input_schema`; confirm.
print("Number of samples generated:", len(sample_dataset.samples))
if sample_dataset.samples:
    print("Example sample:\n", sample_dataset.samples[0])


Setting task patient_linkage_mimic3 for mimic3 base dataset...
Generating samples with 1 worker(s)...


Generating samples for patient_linkage_mimic3 with 1 worker: 100%|██████████| 100/100 [00:00<00:00, 1499.64it/s]
Processing samples: 100%|██████████| 14/14 [00:00<00:00, 35246.25it/s]

Generated 14 samples for task patient_linkage_mimic3
Number of samples generated: 14
Example sample:
 {'patient_id': '42346', 'visit_id': '175880', 'conditions': tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
        19, 20]), 'age': 88, 'identifiers': 'F+Medicare+ENGL+NOT SPECIFIED+SINGLE+WHITE', 'd_visit_id': '180391', 'd_conditions': tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
        19, 20, 21, 22, 23]), 'd_age': 88, 'd_identifiers': 'F+Medicare+ENGL+NOT SPECIFIED+SINGLE+WHITE'}





In [6]:
# Convert samples to IR format and split train/val/test
# (these two names are also imported in the first cell; the re-import is
# redundant but harmless)
from pyhealth.models.medlink import convert_to_ir_format, tvt_split

# Produces the retrieval-style structures used below: a corpus, queries,
# query-to-document relevance (qrels) and per-side metadata.
corpus, queries, qrels, corpus_meta, queries_meta = convert_to_ir_format(
    sample_dataset.samples
)

# Split queries and their qrels into train / validation / test portions.
# NOTE(review): no seed is passed here — confirm whether tvt_split is
# deterministic if the split must be reproducible.
tr_queries, va_queries, te_queries, tr_qrels, va_qrels, te_qrels = tvt_split(queries, qrels)

print(f"Corpus / Query / Qrel summary: corpus={len(corpus)}, queries={len(queries)}, qrels={len(qrels)}")
print(f"Train queries: {len(tr_queries)}, Val queries: {len(va_queries)}, Test queries: {len(te_queries)}")


Corpus / Query / Qrel summary: corpus=14, queries=14, qrels=14
Train queries: 9, Val queries: 2, Test queries: 3


In [7]:
# Toggle: when True, the training qrels are replaced by the result of
# get_bm25_hard_negatives before the training dataloader is built.
USE_BM25_HARDNEGS = False

# Optionally refine training qrels with BM25-based hard negatives
if USE_BM25_HARDNEGS:
    bm25_model = BM25Okapi(corpus)
    tr_qrels = get_bm25_hard_negatives(
        bm25_model, corpus, tr_queries, tr_qrels
    )

# STEP 4: Dataloaders for training / validation / test
train_dataloader = get_train_dataloader(
    corpus=corpus,
    queries=tr_queries,
    qrels=tr_qrels,
    batch_size=32,
    shuffle=True,
)

# The validation loader is built with the same helper as the training loader
# (only the split differs and shuffling is off).
# NOTE(review): presumably intentional so validation batches have the same
# layout as training batches — confirm.
val_dataloader = get_train_dataloader(
    corpus=corpus,
    queries=va_queries,
    qrels=va_qrels,
    batch_size=32,
    shuffle=False,
)

# Evaluation uses two loaders: one over the full corpus and one over the
# test queries.
test_corpus_dataloader, test_queries_dataloader = get_eval_dataloader(
    corpus=corpus,
    queries=te_queries,
    batch_size=32,
)

# Peek at one training batch: print each field's name, type and length
# (None when the value has no __len__).
batch = next(iter(train_dataloader))
for k, v in batch.items():
    print(k, type(v), (len(v) if hasattr(v, "__len__") else None))


Loaded 9 training pairs.
Loaded 2 training pairs.
Loaded 14 eval corpus.
Loaded 3 eval queries.
query_id <class 'list'> 9
id_p <class 'list'> 9
s_q <class 'list'> 9
s_p <class 'list'> 9


In [8]:
# Build train_loader for MedLink (run this before the Step 5 MedLink cell)
#
# The train split (`tr_queries` / `tr_qrels`) produced by the earlier
# `tvt_split` call is reused on purpose. Splitting again here would overwrite
# `tr_qrels` and silently drop the BM25 hard negatives added when
# USE_BM25_HARDNEGS is True, and it would decouple this loader from the split
# that the validation / test loaders were built from.

from pyhealth.models.medlink import get_train_dataloader, tvt_split

train_loader = get_train_dataloader(
    corpus=corpus,
    queries=tr_queries,
    qrels=tr_qrels,
    batch_size=32,
    shuffle=True,
)

# quick sanity check: show which fields a training batch carries
batch = next(iter(train_loader))
print(batch.keys())


Loaded 9 training pairs.
dict_keys(['query_id', 'id_p', 's_q', 's_p'])


In [9]:
import torch
from pyhealth.models import BaseModel
from pyhealth.datasets import SampleDataset
from pyhealth.models.medlink.model import MedLink

# ---------------------------------------------------------
# 1) Patch BaseModel.__init__ so MedLink's legacy kwargs are ignored
# ---------------------------------------------------------
# The hasattr guard makes the patch idempotent: re-running this cell does not
# wrap the already-patched __init__ a second time.
# WARNING: this replaces BaseModel.__init__ on the class itself, so for the
# rest of the kernel session every call that reaches BaseModel.__init__
# forwards only `dataset` and drops all other positional/keyword arguments.
if not hasattr(BaseModel, "_orig_init_for_medlink"):
    # Keep a handle on the original initializer so the patch can delegate to it.
    BaseModel._orig_init_for_medlink = BaseModel.__init__

    def _patched_bm_init(self, dataset=None, *args, **kwargs):
        # MedLink passes feature_keys, label_key, mode; ignore them here
        return BaseModel._orig_init_for_medlink(self, dataset=dataset)

    BaseModel.__init__ = _patched_bm_init

# ---------------------------------------------------------
# 2) Patch SampleDataset.get_all_tokens used in MedLink.__init__
# ---------------------------------------------------------
# Only installed when SampleDataset does not already provide the method.
if not hasattr(SampleDataset, "get_all_tokens"):
    def _get_all_tokens(self, key, remove_duplicates=True, sort=False):
        """
        Collect the tokens stored under `key` across `self.samples`.

        Values may be arbitrarily nested lists/tuples and/or torch tensors;
        they are flattened to individual tokens. Every token is returned as a
        string so the vocabulary matches the string tokens that
        `_normalize_seqs` feeds to the model.

        Args:
            key: sample field to read; samples lacking it are skipped.
            remove_duplicates: keep only the first occurrence of each token.
            sort: return the tokens in sorted order.

        Returns:
            list[str]: tokens in first-seen order (or sorted when requested).
        """
        tokens = []

        for sample in self.samples:
            if key not in sample:
                continue

            # Depth-first flatten. Children are pushed in reverse so that
            # popping from the end of the stack yields them in original order.
            stack = [sample[key]]
            while stack:
                cur = stack.pop()
                if torch.is_tensor(cur):
                    # Unwrap tensors; otherwise the whole tensor object would
                    # be recorded as a single "token".
                    cur = cur.tolist()
                if isinstance(cur, (list, tuple)):
                    stack.extend(reversed(cur))
                else:
                    tokens.append(str(cur))

        if remove_duplicates:
            # dict preserves insertion order -> first occurrence wins.
            tokens = list(dict.fromkeys(tokens))

        if sort:
            # All tokens are strings at this point, so sorting cannot fail on
            # mixed types.
            tokens = sorted(tokens)

        return tokens

    SampleDataset.get_all_tokens = _get_all_tokens

# ---------------------------------------------------------
# 3) Helper: normalize sequences so tokenizer sees lists, not tensors
# ---------------------------------------------------------
def _normalize_seqs(obj):
    """
    Convert batch field (tensor or list of tensors/lists) into
    List[List[str]] as expected by Tokenizer.batch_encode_2d.
    """
    if torch.is_tensor(obj):
        obj = obj.tolist()  # -> list[list[int]]

    seqs_out = []
    for seq in obj:
        if torch.is_tensor(seq):
            seq = seq.tolist()
        # at this point seq is list[int] or list[str]
        seqs_out.append([str(tok) for tok in seq])
    return seqs_out

# ---------------------------------------------------------
# 4) Instantiate MedLink and run a single forward/backward pass
# ---------------------------------------------------------
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# sample_dataset and train_loader must already be defined in earlier cells
model = MedLink(
    dataset=sample_dataset,
    feature_keys=["conditions"],
    embedding_dim=128,
).to(device)

# Take one batch from the MedLink train dataloader
batch = next(iter(train_loader))
print("Raw batch keys:", batch.keys())

# Normalize the sequence fields so AdmissionPrediction/Tokenizer work
# (`s_n` is normalized only when the batch carries it and it is not None).
if "s_q" in batch:
    batch["s_q"] = _normalize_seqs(batch["s_q"])
if "s_p" in batch:
    batch["s_p"] = _normalize_seqs(batch["s_p"])
if "s_n" in batch and batch["s_n"] is not None:
    batch["s_n"] = _normalize_seqs(batch["s_n"])

# Smoke test: the batch dict is passed to the model as keyword arguments and
# the returned dict is expected to contain a "loss" entry.
model.train()
outputs = model(**batch)
print("MedLink outputs keys:", outputs.keys())
print("Loss:", float(outputs["loss"]))

# Only checks that gradients can be computed; no optimizer step is taken in
# this cell.
outputs["loss"].backward()
print("Backward pass completed.")


Raw batch keys: dict_keys(['query_id', 'id_p', 's_q', 's_p'])
MedLink outputs keys: dict_keys(['loss'])
Loss: 32.6289176940918
Backward pass completed.


In [10]:
# Sanity check: a short training loop (3 epochs, Adam, lr=1e-3) that prints
# the average loss per epoch. Relies on `model`, `train_loader` and
# `_normalize_seqs` from the earlier cells.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(3):
    total = 0.0  # running sum of batch losses for this epoch
    n = 0        # number of batches seen in this epoch
    for batch in train_loader:
        # normalize s_q / s_p as before
        batch["s_q"] = _normalize_seqs(batch["s_q"])
        batch["s_p"] = _normalize_seqs(batch["s_p"])
        if "s_n" in batch and batch["s_n"] is not None:
            batch["s_n"] = _normalize_seqs(batch["s_n"])

        # Clear gradients first (this also discards any left on the
        # parameters by an earlier backward pass), then do one update.
        optimizer.zero_grad()
        out = model(**batch)
        loss = out["loss"]
        loss.backward()
        optimizer.step()

        total += float(loss)
        n += 1
    # max(n, 1) avoids a division by zero if the loader yields no batches.
    print(f"epoch {epoch}: avg loss = {total / max(n,1):.4f}")


epoch 0: avg loss = 30.2854
epoch 1: avg loss = 34.0946
epoch 2: avg loss = 27.1799


In [22]:
!python /Users/saurabhatri/Downloads/PyHealth/tests/core/test_medlink.py

Processing samples: 100%|███████████████████████| 2/2 [00:00<00:00, 6641.81it/s]
Processing samples: 100%|██████████████████████| 2/2 [00:00<00:00, 60787.01it/s]
Processing samples: 100%|██████████████████████| 2/2 [00:00<00:00, 64527.75it/s]
.
----------------------------------------------------------------------
Ran 3 tests in 0.037s

OK
