In [None]:
import os
import sys

# Make the parent of the current working directory importable; this is what
# lets the notebook run from examples/ (project root goes first on sys.path).
parent_dir = os.path.join(os.getcwd(), "..")
PROJECT_ROOT = os.path.abspath(parent_dir)
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

print("PROJECT_ROOT:", PROJECT_ROOT)

# Fail early, with a readable message, if either required package is missing.
try:
    import torch
except ImportError as err:
    raise RuntimeError(
        "PyTorch is not installed. Install it into your environment first "
    ) from err
else:
    print("PyTorch is installed")

try:
    import pyhealth
except ImportError as err:
    raise RuntimeError(
        "pyhealth is not importable."
    ) from err
else:
    # __version__ may be absent, so fall back to a placeholder.
    print("pyhealth is importable, version:", getattr(pyhealth, "__version__", "unknown"))

# Core dataset + MedLink imports
from pyhealth.datasets import MIMIC3Dataset
from pyhealth.tasks import BaseTask
from pyhealth.models.medlink import (
    BM25Okapi,
    convert_to_ir_format,
    filter_by_candidates,
    generate_candidates,
    get_bm25_hard_negatives,
    get_eval_dataloader,
    get_train_dataloader,
    tvt_split,
)
from pyhealth.models.medlink.model import MedLink
from pyhealth.metrics import ranking_metrics_fn


In [None]:
# Location of the MIMIC-III demo files.
# Downloaded from: https://physionet.org/content/mimiciii-demo/1.4/
#
# Set the MIMIC3_DEMO_ROOT environment variable to point at your local copy so
# that no machine-specific path has to be committed to the notebook. When the
# variable is unset, the placeholder below is used and must be adjusted by hand.
MIMIC3_DEMO_ROOT = os.environ.get(
    "MIMIC3_DEMO_ROOT",
    "/path/to/mimic-iii-clinical-database-demo-1.4",  # <-- adjust for real
)
print("MIMIC-III demo root:", MIMIC3_DEMO_ROOT)


In [None]:
# Instantiate the base MIMIC-III dataset on top of the demo files.
base_dataset = MIMIC3Dataset(
    root=MIMIC3_DEMO_ROOT,
    # Only the diagnoses table is requested (matches in configs/mimic3.yaml).
    tables=["diagnoses_icd"],
    # Flip to True to work with a small subset of patients.
    dev=False,
)

# Call the dataset's stats() report as the cell's final expression.
base_dataset.stats()


In [None]:
from pyhealth.tasks.patient_linkage_mimic3 import PatientLinkageMIMIC3Task
from datetime import datetime
from collections import defaultdict
import math


In [None]:
# Apply the patient-linkage task to the base dataset to obtain the sample dataset.
patient_linkage_task = PatientLinkageMIMIC3Task()
sample_dataset = base_dataset.set_task(task=patient_linkage_task)

# Report how many samples came out, and show the first one when any exist.
print("Number of samples generated:", len(sample_dataset.samples))
if not sample_dataset.samples:
    pass
else:
    print("Example sample:\n", sample_dataset.samples[0])


In [None]:
# Turn the task samples into IR-style structures, then split queries/qrels
# into train / validation / test parts.
from pyhealth.models.medlink import convert_to_ir_format, tvt_split

ir_parts = convert_to_ir_format(sample_dataset.samples)
corpus, queries, qrels, corpus_meta, queries_meta = ir_parts

(
    tr_queries,
    va_queries,
    te_queries,
    tr_qrels,
    va_qrels,
    te_qrels,
) = tvt_split(queries, qrels)

# Size summary of the full collections and of each query split.
print(f"Corpus / Query / Qrel summary: corpus={len(corpus)}, queries={len(queries)}, qrels={len(qrels)}")
print(f"Train queries: {len(tr_queries)}, Val queries: {len(va_queries)}, Test queries: {len(te_queries)}")


In [None]:
# Toggle: when True, the training qrels are refined with BM25-based hard negatives.
USE_BM25_HARDNEGS = False

if USE_BM25_HARDNEGS:
    bm25_model = BM25Okapi(corpus)
    tr_qrels = get_bm25_hard_negatives(bm25_model, corpus, tr_queries, tr_qrels)

# Training and validation loaders both come from get_train_dataloader; they
# differ only in the query/qrel split they receive and in the shuffle flag.
train_dataloader = get_train_dataloader(
    corpus=corpus, queries=tr_queries, qrels=tr_qrels, batch_size=32, shuffle=True
)
val_dataloader = get_train_dataloader(
    corpus=corpus, queries=va_queries, qrels=va_qrels, batch_size=32, shuffle=False
)

# get_eval_dataloader returns a pair: a corpus loader and a test-query loader.
test_corpus_dataloader, test_queries_dataloader = get_eval_dataloader(
    corpus=corpus, queries=te_queries, batch_size=32
)

# Peek at the first training batch: field name, Python type, and length
# (None for values that do not define __len__).
batch = next(iter(train_dataloader))
for field, value in batch.items():
    size = len(value) if hasattr(value, "__len__") else None
    print(field, type(value), size)


In [None]:
# Build train_loader for MedLink
#
# This cell reuses the train split produced earlier (tr_queries / tr_qrels)
# instead of calling tvt_split a second time. Re-splitting here would rebind
# tr_qrels and silently discard the BM25 hard negatives added when
# USE_BM25_HARDNEGS is True, and it would also rebind the val/test splits
# after the validation and test loaders were already built from them.

from pyhealth.models.medlink import get_train_dataloader, tvt_split

train_loader = get_train_dataloader(
    corpus=corpus,
    queries=tr_queries,
    qrels=tr_qrels,
    batch_size=32,
    shuffle=True,
)

# quick sanity check: pull one batch and list its fields
batch = next(iter(train_loader))
print(batch.keys())


In [None]:
import torch
from pyhealth.models import BaseModel
from pyhealth.datasets import SampleDataset
from pyhealth.models.medlink.model import MedLink

# normalize sequences so tokenizer sees lists, not tensors
def _normalize_seqs(obj):
    """
    Convert batch field (tensor or list of tensors/lists) into
    List[List[str]] as expected by Tokenizer.batch_encode_2d.
    """
    if torch.is_tensor(obj):
        obj = obj.tolist()  # -> list[list[int]]

    seqs_out = []
    for seq in obj:
        if torch.is_tensor(seq):
            seq = seq.tolist()
        # at this point seq is list[int] or list[str]
        seqs_out.append([str(tok) for tok in seq])
    return seqs_out

# Instantiate MedLink and push a single batch through forward and backward.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Depends on sample_dataset and train_loader created by earlier cells.
model = MedLink(dataset=sample_dataset, feature_keys=["conditions"], embedding_dim=128)
model = model.to(device)

# Grab the first batch from the MedLink train dataloader.
batch = next(iter(train_loader))
print("Raw batch keys:", batch.keys())

# Convert the sequence fields to lists of strings for the tokenizer.
# "s_q" and "s_p" are converted whenever present; "s_n" is additionally
# skipped when its value is None.
for key, skip_when_none in (("s_q", False), ("s_p", False), ("s_n", True)):
    if key not in batch:
        continue
    if skip_when_none and batch[key] is None:
        continue
    batch[key] = _normalize_seqs(batch[key])

model.train()
outputs = model(**batch)
print("MedLink outputs keys:", outputs.keys())

loss = outputs["loss"]
print("Loss:", float(loss))

loss.backward()
print("Backward pass completed.")


In [None]:
# Sanity check: a short training run over train_loader that prints the
# average loss per epoch.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Set training mode explicitly so this cell does not depend on whatever mode
# an earlier (or re-run) cell left the model in.
model.train()

for epoch in range(3):
    total = 0.0  # running sum of batch losses for this epoch
    n = 0        # number of batches seen this epoch
    for batch in train_loader:
        # Convert sequence fields to lists of strings before the forward pass;
        # "s_n" is optional and may be None.
        batch["s_q"] = _normalize_seqs(batch["s_q"])
        batch["s_p"] = _normalize_seqs(batch["s_p"])
        if "s_n" in batch and batch["s_n"] is not None:
            batch["s_n"] = _normalize_seqs(batch["s_n"])

        optimizer.zero_grad()
        out = model(**batch)
        loss = out["loss"]
        loss.backward()
        optimizer.step()

        total += float(loss)
        n += 1
    # max(n, 1) avoids a division by zero when the loader yields no batches.
    print(f"epoch {epoch}: avg loss = {total / max(n,1):.4f}")


In [None]:
# To run the unit tests: pytest tests/core/test_medlink.py
