# Smart Processor: Clinical Text + Temporal Data in PyHealth

This notebook demonstrates the **Smart Processor** pattern in PyHealth, where tokenization logic lives inside the `Processor` class rather than being scattered across models.

**Contents:**
1. `TupleTimeTextProcessor`: raw vs. tokenized output
2. `EmbeddingModel`: 3D input handling, `(B, N_notes, L)` → `(B, N_notes, H)`
3. Dataset and custom collate function
4. `MLP` with tokenized clinical notes
5. `Transformer` with tokenized clinical notes
6. `RNN` with tokenized clinical notes
7. `MultimodalRNN`: notes **+** EHR codes
8. End-to-end backward pass
9. Summary

> Uses `prajjwal1/bert-tiny` (4.4M params) — runs on CPU in under a minute.

In [None]:
# !pip install pyhealth transformers -q

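import random
import warnings

import torch
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

warnings.filterwarnings('ignore')  # keep notebook output free of library warnings

EMBEDDING_DIM = 128
MODEL_NAME    = 'prajjwal1/bert-tiny'
random.seed(42)
torch.manual_seed(42)  # also seed torch so random tensors and model init are repeatable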

print(f'Device  : {"cuda" if torch.cuda.is_available() else "cpu"}')
print(f'PyTorch : {torch.__version__}')

---
## 1. TupleTimeTextProcessor

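Takes `(List[str], List[float])`: clinical notes paired with hours-since-admission timestamps.  
When `tokenizer_model` is set, `process()` returns a 5-tuple `(input_ids, attn_mask, token_type_ids, time, type_tag)`: the first four elements are tensors and `type_tag` is a string. Without a tokenizer it returns `(texts, time, type_tag)` with the raw strings passed through.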
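
To make the tokenized output shapes concrete, here is a minimal sketch that pads a list of notes to a fixed length and builds the matching attention mask. It uses a toy whitespace vocabulary (`toy_tokenize` is a hypothetical helper, not part of PyHealth or HuggingFace), so the ids will differ from what a real tokenizer produces. Only the shapes and the mask convention are meant to carry over.

```python
import torch

def toy_tokenize(notes, max_length=8, pad_id=0):
    """Hypothetical stand-in for a HuggingFace tokenizer: whitespace split + padding."""
    vocab = {}
    ids, mask = [], []
    for note in notes:
        # Assign ids starting at 1 so that 0 stays reserved for padding.
        toks = [vocab.setdefault(w, len(vocab) + 1) for w in note.lower().split()]
        toks = toks[:max_length]                      # truncate long notes
        n_pad = max_length - len(toks)
        ids.append(toks + [pad_id] * n_pad)           # right-pad to max_length
        mask.append([1] * len(toks) + [0] * n_pad)    # 1 = real token, 0 = padding
    return torch.tensor(ids), torch.tensor(mask)

toy_notes = ['Admission: chest pain.', 'Day 3: improved.', 'Discharge: stable.']
toy_times = torch.tensor([0.0, 72.0, 120.0])          # hours since admission

toy_ids, toy_mask = toy_tokenize(toy_notes)
print(toy_ids.shape, toy_mask.shape, toy_times.shape)
```

Each note becomes one row of length `max_length`, and the timestamps stay a separate 1D tensor with one entry per note.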

In [None]:
from pyhealth.processors import TupleTimeTextProcessor

notes = ['Admission: chest pain.', 'Day 3: improved.', 'Discharge: stable.']
times = [0.0, 72.0, 120.0]

# ── Without tokenizer ──────────────────────────────────────────────────────
proc_raw = TupleTimeTextProcessor(type_tag='note')
texts_out, time_t, tag = proc_raw.process((notes, times))
print('Without tokenizer')
print(f'  schema()   : {proc_raw.schema()}')
print(f'  is_token() : {proc_raw.is_token()}')
print(f'  output     : {texts_out}, time={time_t}, tag="{tag}"')

In [None]:
# ── With HuggingFace tokenizer ─────────────────────────────────────────────
proc_tok = TupleTimeTextProcessor(tokenizer_model=MODEL_NAME, max_length=32)
input_ids, attn_mask, token_type_ids, time_t, tag = proc_tok.process((notes, times))

print('With tokenizer')
print(f'  schema()        : {proc_tok.schema()}')
print(f'  is_token()      : {proc_tok.is_token()}')
print(f'  input_ids shape : {input_ids.shape}   # (N_notes, max_length)')
print(f'  attn_mask shape : {attn_mask.shape}')
print(f'  sample tokens   : {input_ids[0, :8].tolist()} ...')

---
## 2. EmbeddingModel — 3D Input Handling

`EmbeddingModel` detects `is_token()=True` and loads the HuggingFace model automatically. It handles **3D** inputs:
```
(B, N_notes, L) → flatten(B·N, L) → BERT → CLS[:, 0, :] → unflatten → (B, N_notes, H)
```
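
The reshape pipeline above can be sketched without downloading a model. In this sketch `ToyNoteEncoder` and `embed_notes` are hypothetical names, and an `nn.Embedding` stands in for BERT. This is an illustration of the flatten / pool / unflatten idea under those assumptions, not the actual `EmbeddingModel` code.

```python
import torch
import torch.nn as nn

class ToyNoteEncoder(nn.Module):
    """Hypothetical stand-in for BERT: maps (M, L) token ids to (M, L, H) states."""
    def __init__(self, vocab_size=1000, hidden=16):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, hidden)

    def forward(self, input_ids, attention_mask):
        # A real encoder uses the mask inside attention; here we only zero out padding.
        return self.emb(input_ids) * attention_mask.unsqueeze(-1)

def embed_notes(encoder, input_ids, attention_mask):
    B, N, L = input_ids.shape
    flat_ids = input_ids.reshape(B * N, L)          # (B*N, L)
    flat_mask = attention_mask.reshape(B * N, L)    # (B*N, L)
    hidden = encoder(flat_ids, flat_mask)           # (B*N, L, H)
    cls = hidden[:, 0, :]                           # first-token ("CLS") vector per note
    return cls.reshape(B, N, -1)                    # (B, N, H)

ids = torch.randint(1, 1000, (2, 3, 32))
mask = torch.ones(2, 3, 32, dtype=torch.long)
note_emb = embed_notes(ToyNoteEncoder(), ids, mask)
print(note_emb.shape)  # torch.Size([2, 3, 16])
```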

In [None]:
from pyhealth.datasets import create_sample_dataset
from pyhealth.models.embedding import EmbeddingModel

# Binary label processor requires at least one sample per class
ds_emb = create_sample_dataset(
    samples=[
        {'patient_id': 'p0', 'visit_id': 'v0',
         'notes': (['Chest pain.', 'Day 3.', 'Discharge.'], [0., 72., 120.]),
         'label': 1},
        {'patient_id': 'p1', 'visit_id': 'v1',
         'notes': (['Pneumonia.', 'Antibiotics.', 'Stable.'], [0., 24., 72.]),
         'label': 0},
    ],
    input_schema={'notes': proc_tok},
    output_schema={'label': 'binary'},
    dataset_name='emb_demo',
)

embedding_model = EmbeddingModel(ds_emb, embedding_dim=EMBEDDING_DIM)
print(f'Embedding layer : {type(embedding_model.embedding_layers["notes"]).__name__}')

# Simulate a padded batch: 2 patients × 3 notes × 32 tokens
B, N, L = 2, 3, 32
fake_ids  = torch.randint(1, 1000, (B, N, L))
fake_mask = torch.ones(B, N, L, dtype=torch.long)

with torch.no_grad():
    out = embedding_model({'notes': fake_ids}, masks={'notes': fake_mask})

print(f'\nInput  : {fake_ids.shape}   → (B, N_notes, Seq_len)')
print(f'Output : {out["notes"].shape}  → (B, N_notes, Embedding_dim)')
assert out['notes'].shape == (B, N, EMBEDDING_DIM)
print('\n✓  3D → CLS pool → 3D works.')

---
## 3. Dataset + Custom Collate

Our processor returns **5-element tuples**. The default PyHealth collate handles 2-element tuples only.  
We provide a `collate_smart_processor` that properly stacks each element across the batch.
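
The core of "stacking each element across the batch" is transposing a list of per-sample tuples with `zip`. A minimal sketch with two toy samples follows; the 3-element tuples and shapes are chosen for illustration and are shorter than the real 5-element output.

```python
import torch

# Two samples, each an (input_ids, time, type_tag) tuple with 3 notes of 4 tokens.
sample_a = (torch.ones(3, 4, dtype=torch.long), torch.tensor([0., 24., 72.]), 'note')
sample_b = (torch.zeros(3, 4, dtype=torch.long), torch.tensor([0., 12., 48.]), 'note')

# zip groups the i-th element of every sample together.
ids_group, time_group, tag_group = zip(sample_a, sample_b)

batched_ids = torch.stack(ids_group)     # (B, N_notes, L) = (2, 3, 4)
batched_time = torch.stack(time_group)   # (B, N_notes)    = (2, 3)
batched_tags = list(tag_group)           # strings stay a plain list

print(batched_ids.shape, batched_time.shape, batched_tags)
```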

In [None]:
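def collate_smart_processor(batch):
    """Batch collator for TupleTimeTextProcessor 5-element tuples.

    For tuple-valued features, zips across the batch and collates each
    element separately: same-shaped tensors are stacked into
    (B, N_notes, ...) tensors, differently shaped tensors are padded
    along their first dimension with pad_sequence, and non-tensor
    elements (e.g. type_tag strings) are kept as lists.
    """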
    result = {}
    for key in batch[0].keys():
        vals = [s[key] for s in batch]

        if isinstance(vals[0], tuple):
            collated = []
            for elem_list in zip(*vals):
                if isinstance(elem_list[0], torch.Tensor):
                    if all(e.shape == elem_list[0].shape for e in elem_list):
                        collated.append(torch.stack(list(elem_list)))
                    else:
                        collated.append(pad_sequence(list(elem_list), batch_first=True))
                else:
                    collated.append(list(elem_list))   # e.g. type_tag strings
            result[key] = tuple(collated)

        elif isinstance(vals[0], torch.Tensor):
            if all(v.shape == vals[0].shape for v in vals):
                result[key] = torch.stack(vals)
            else:
                result[key] = pad_sequence(vals, batch_first=True)
        else:
            result[key] = vals

    return result

print('collate_smart_processor defined ✓')

In [None]:
CLINICAL_NOTES = [
    'Patient presents with acute chest pain.',
    'Labs: elevated troponin. ST elevation on ECG.',
    'Hemodynamically stable post-PCI procedure.',
    'Discharge summary: no complications. Follow-up in 2 weeks.',
    'Readmission: persistent cough, possible pneumonia.',
    'CXR: bilateral infiltrates. Started antibiotics.',
]

def make_sample(pid, vid, label):
    """3 notes per patient — constant N so tensors collate cleanly."""
    ns = random.choices(CLINICAL_NOTES, k=3)
    return {'patient_id': pid, 'visit_id': vid,
            'notes': (ns, [0., 24., 72.]), 'label': label}

samples = [
    make_sample('p0', 'v0', 1),
    make_sample('p1', 'v1', 0),
    make_sample('p2', 'v2', 1),
    make_sample('p3', 'v3', 0),
    make_sample('p4', 'v4', 1),
]

dataset = create_sample_dataset(
    samples=samples,
    input_schema={'notes': proc_tok},
    output_schema={'label': 'binary'},
    dataset_name='smart_proc_demo',
)
dataset.set_shuffle(False)

loader = DataLoader(dataset, batch_size=2, collate_fn=collate_smart_processor)
batch  = next(iter(loader))

# notes is a 5-tuple: (input_ids, attn_mask, token_type_ids, time, [type_tags])
input_ids_b, attn_mask_b, _, time_b, tags_b = batch['notes']

print(f'Batch keys         : {list(batch.keys())}')
print(f'input_ids shape    : {input_ids_b.shape}   # (B, N_notes, max_length)')
print(f'attn_mask shape    : {attn_mask_b.shape}')
print(f'time shape         : {time_b.shape}')
print(f'labels             : {batch["label"]}')

---
## 4. MLP Model

Note embeddings `(B, N, H)` are pooled across the note sequence → linear classification head.
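
As a rough illustration of pooling followed by a classification head, the sketch below uses a masked mean over the note axis. The choice of mean pooling, the mask values, and the two-layer head are assumptions made for this example; the pooling that `MLP` actually applies may differ.

```python
import torch
import torch.nn as nn

B, N, H = 2, 3, 128
note_emb = torch.randn(B, N, H)               # per-note embeddings
note_mask = torch.tensor([[1, 1, 1],
                          [1, 1, 0]])         # second patient has one padded note

# Masked mean over the note axis: padded notes contribute nothing.
weights = note_mask.unsqueeze(-1).float()     # (B, N, 1)
pooled = (note_emb * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)

head = nn.Sequential(nn.Linear(H, 64), nn.ReLU(), nn.Linear(64, 1))
logits = head(pooled)                         # (B, 1)
y_prob = torch.sigmoid(logits)                # binary-label probability
print(pooled.shape, y_prob.shape)
```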

In [None]:
from pyhealth.models import MLP

mlp_model = MLP(dataset=dataset, embedding_dim=EMBEDDING_DIM, hidden_dim=64)
mlp_model.eval()

with torch.no_grad():
    out = mlp_model(**batch)

print('MLP')
print(f'  loss   : {out["loss"].item():.4f}')
print(f'  y_prob : {out["y_prob"].squeeze().tolist()}')
print(f'  y_true : {out["y_true"].tolist()}')
print('\n✓  MLP succeeded.')

---
## 5. Transformer Model

Self-attention over the note embedding sequence (`N_notes` steps), capturing inter-note relationships.
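
To show what attention over notes looks like in isolation, the sketch below runs `torch.nn.MultiheadAttention` on a `(B, N, H)` tensor. It is a stand-in built from plain PyTorch, not PyHealth's `Transformer` block, and the padding mask is invented for the example.

```python
import torch
import torch.nn as nn

B, N, H = 2, 3, 128
note_emb = torch.randn(B, N, H)

attn = nn.MultiheadAttention(embed_dim=H, num_heads=4, batch_first=True)

# True marks padded notes that attention should ignore as keys.
pad_mask = torch.tensor([[False, False, False],
                         [False, False, True]])

# Self-attention: queries, keys, and values are all the note embeddings.
out, attn_weights = attn(note_emb, note_emb, note_emb, key_padding_mask=pad_mask)

print(out.shape)           # (B, N, H): one contextualized vector per note
print(attn_weights.shape)  # (B, N, N): note-to-note weights, averaged over heads
```

Each row of `attn_weights` says how much one note attends to every other note, and the column for a padded note receives zero weight.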

In [None]:
from pyhealth.models import Transformer

# Note: Transformer uses heads= (not num_heads=)
tfm_model = Transformer(
    dataset=dataset, embedding_dim=EMBEDDING_DIM,
    heads=4, num_layers=2, dropout=0.1,
)
tfm_model.eval()

with torch.no_grad():
    out = tfm_model(**batch)

print('Transformer')
print(f'  loss   : {out["loss"].item():.4f}')
print(f'  y_prob : {out["y_prob"].squeeze().tolist()}')
print('\n✓  Transformer succeeded.')

---
## 6. RNN Model

A GRU processes the note embedding sequence in chronological order.
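
A minimal sketch of a recurrent pass over note embeddings, using plain `torch.nn.GRU` with packed sequences so padded notes are skipped. The per-patient lengths are made up for the example, and this is not PyHealth's `RNN` implementation.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

B, N, H = 2, 3, 128
note_emb = torch.randn(B, N, H)
lengths = torch.tensor([3, 2])         # valid notes per patient (CPU int64 tensor)

gru = nn.GRU(input_size=H, hidden_size=64, batch_first=True)
packed = pack_padded_sequence(note_emb, lengths, batch_first=True, enforce_sorted=False)
_, h_n = gru(packed)                   # h_n: (num_layers, B, hidden)
patient_repr = h_n[-1]                 # (B, 64): state after each patient's last real note
print(patient_repr.shape)
```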

In [None]:
from pyhealth.models import RNN

rnn_model = RNN(
    dataset=dataset, embedding_dim=EMBEDDING_DIM,
    hidden_dim=64, rnn_type='GRU',
)
rnn_model.eval()

with torch.no_grad():
    out = rnn_model(**batch)

print('RNN  (GRU over note sequence)')
print(f'  loss   : {out["loss"].item():.4f}')
print(f'  y_prob : {out["y_prob"].squeeze().tolist()}')
print('\n✓  RNN succeeded.')

---
## 7. MultimodalRNN — Clinical Notes + EHR Codes

**Flagship use case.** A single model fuses two modalities:
- **Notes** → BERT-tiny → `(B, 3, H)` note embeddings  
- **EHR codes** → learned embeddings → `(B, T, H)` code sequence  

Both are independently embedded and then concatenated before the shared GRU.
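
One plausible reading of "concatenated before the shared GRU" is concatenation along the sequence axis, which the sketch below illustrates with plain PyTorch. That axis choice, the toy vocabulary size, and the integer code ids are all assumptions for this example; `MultimodalRNN` may fuse the modalities differently.

```python
import torch
import torch.nn as nn

B, T, H = 2, 3, 128
note_emb = torch.randn(B, 3, H)                    # stand-in for BERT-tiny note embeddings

code_table = nn.Embedding(num_embeddings=20, embedding_dim=H, padding_idx=0)
code_ids = torch.tensor([[1, 2, 3],
                         [1, 4, 3]])               # toy integer-encoded EHR codes
code_emb = code_table(code_ids)                    # (B, T, H)

# Assumption: fuse by concatenating along the sequence axis.
fused = torch.cat([note_emb, code_emb], dim=1)     # (B, 3 + T, H)

gru = nn.GRU(input_size=H, hidden_size=64, batch_first=True)
_, h_n = gru(fused)
logit = nn.Linear(64, 1)(h_n[-1])                  # (B, 1)
print(fused.shape, logit.shape)
```

Because both modalities share the hidden size `H`, no extra projection is needed before the concatenation in this sketch.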

In [None]:
from pyhealth.processors import SequenceProcessor
from pyhealth.models import MultimodalRNN

code_proc = SequenceProcessor()

EHR_CODES = [
    ['I21.0', 'Z87.39', 'I10'],     # STEMI + musculoskeletal disease hx + HTN
    ['I21.0', 'R07.9',  'I10'],     # STEMI + chest pain + HTN
    ['J18.9', 'R05',    'Z87.01'],  # Pneumonia + cough + pneumonia hx
    ['I21.0', 'I25.10', 'Z95.1'],   # STEMI + CAD + coronary bypass graft
    ['J18.9', 'J96.00', 'R09.02'],  # Pneumonia + acute respiratory failure + hypoxemia
]

def make_mm_sample(pid, vid, codes, label):
    ns = random.choices(CLINICAL_NOTES, k=3)
    return {
        'patient_id': pid, 'visit_id': vid,
        'notes': (ns, [0., 24., 72.]),
        'ehr_codes': codes,
        'label': label,
    }

mm_samples = [
    make_mm_sample('p0', 'v0', EHR_CODES[0], 1),
    make_mm_sample('p1', 'v1', EHR_CODES[1], 0),
    make_mm_sample('p2', 'v2', EHR_CODES[2], 1),
    make_mm_sample('p3', 'v3', EHR_CODES[3], 0),
    make_mm_sample('p4', 'v4', EHR_CODES[4], 1),
]

mm_dataset = create_sample_dataset(
    samples=mm_samples,
    input_schema={'notes': proc_tok, 'ehr_codes': code_proc},
    output_schema={'label': 'binary'},
    dataset_name='multimodal_demo',
)
mm_dataset.set_shuffle(False)

mm_loader = DataLoader(mm_dataset, batch_size=2, collate_fn=collate_smart_processor)
mm_batch  = next(iter(mm_loader))

print(f'Batch keys               : {list(mm_batch.keys())}')
print(f'notes (input_ids) shape  : {mm_batch["notes"][0].shape}   # (B, N_notes, max_len)')
ehr_val = mm_batch['ehr_codes']
print(f'ehr_codes                : {type(ehr_val).__name__}')

In [None]:
mm_rnn = MultimodalRNN(
    dataset=mm_dataset, embedding_dim=EMBEDDING_DIM,
    hidden_dim=64, rnn_type='GRU',
)
mm_rnn.eval()

with torch.no_grad():
    out = mm_rnn(**mm_batch)

print('MultimodalRNN  (BERT notes + EHR codes → GRU → prediction)')
print(f'  loss   : {out["loss"].item():.4f}')
print(f'  y_prob : {out["y_prob"].squeeze().tolist()}')
print(f'  y_true : {out["y_true"].tolist()}')
print('\n✓  MultimodalRNN fused clinical text + EHR codes!')

---
## 8. End-to-End Backpropagation

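Verify that gradients flow from the classification loss back into the text-embedding path. The check inspects the `notes_proj` projection layer when one exists, and otherwise the first BERT parameter that received a gradient.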
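
The idea behind a gradient-flow check can be shown on a toy model: run a forward pass, call `backward()` on the loss, then confirm that the encoder's parameters hold a non-zero `.grad`. Here an `nn.Embedding` stands in for the text encoder, and all sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
encoder = nn.Embedding(100, 16)                    # stand-in for the text encoder
head = nn.Linear(16, 1)

ids = torch.randint(1, 100, (4, 8))                # (B, L) token ids
labels = torch.tensor([[1.], [0.], [1.], [0.]])

cls = encoder(ids)[:, 0, :]                        # first-token pooling, (B, 16)
loss = nn.functional.binary_cross_entropy_with_logits(head(cls), labels)
loss.backward()

# A parameter received gradient if .grad is set; its norm shows whether it is non-zero.
grad_norms = {name: p.grad.norm().item()
              for name, p in [('encoder', encoder.weight), ('head', head.weight)]
              if p.grad is not None}
print(grad_norms)
```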

In [None]:
print('Gradient flow — all 4 models')
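print('─' * 55)

all_ok = True
for name, model, b in [
    ('MLP',           mlp_model, batch),
    ('Transformer',   tfm_model, batch),
    ('RNN',           rnn_model, batch),
    ('MultimodalRNN', mm_rnn,    mm_batch),
]:
    model.train()
    model.zero_grad()
    model(**b)['loss'].backward()

    # nn.ModuleDict has no .get(), so cast to a plain dict first
    el   = dict(model.embedding_model.embedding_layers)
    proj = el.get('notes_proj', None)
    if proj is not None:
        # Guard against a projection layer that received no gradient
        g     = proj.weight.grad.norm().item() if proj.weight.grad is not None else 0.0
        label = 'projection layer'
    else:
        # First BERT parameter that received a gradient, or 0.0 if none did
        g = next(
            (p.grad.norm().item() for p in el['notes'].parameters() if p.grad is not None),
            0.0
        )
        label = 'BERT parameter'

    all_ok = all_ok and g > 0
    print(f'  {"✓" if g > 0 else "✗ NO GRAD"}  {name:<15} {label} grad_norm = {g:.4f}')

# Only report success if every model showed a non-zero gradient
if all_ok:
    print('\nAll models received gradients on the checked text-embedding layer. ✓')
else:
    print('\nAt least one model received no gradient on the checked layer. ✗')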

---
## 9. Summary

| Model | Text encoder | EHR codes | Forward | Backward |
|-------|-------------|-----------|---------|----------|
| `MLP` | BERT-tiny | — | ✓ | ✓ |
| `Transformer` | BERT-tiny | — | ✓ | ✓ |
| `RNN` | BERT-tiny | — | ✓ | ✓ |
| `MultimodalRNN` | BERT-tiny | ✓ | ✓ | ✓ |

### Key Design Principles

| Principle | Detail |
|-----------|--------|
| **Smart Processor** | Tokenization in `TupleTimeTextProcessor`, not in models |
| **Pre-tokenization** | Text → `int` tensors at build time, no `List[str]` in DataLoader |
| **3D Handling** | `EmbeddingModel`: `(B, N, L)` → BERT → CLS → `(B, N, H)` |
| **Mask propagation** | Attention masks forwarded to BERT, padding ignored |
| **RNN mask reduction** | 3D token mask reduced to 2D note mask for `pack_padded_sequence` |
| **Backward-compatible** | `RETAIN`, `ConCare`, `StageAttentionNet` unchanged |
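
The "RNN mask reduction" row can be illustrated in a few lines: a note counts as real if any of its tokens is real, and the number of real notes per patient gives the lengths that `pack_padded_sequence` expects. The mask values below are invented for the example.

```python
import torch

# Token-level mask (B, N_notes, L): 1 = real token, 0 = padding.
token_mask = torch.tensor([
    [[1, 1, 1, 0], [1, 1, 0, 0], [1, 0, 0, 0]],   # patient 0: three real notes
    [[1, 1, 0, 0], [1, 1, 1, 1], [0, 0, 0, 0]],   # patient 1: third note is all padding
])

note_mask = token_mask.bool().any(dim=-1)   # (B, N_notes): True where a note has any token
lengths = note_mask.sum(dim=1)              # real notes per patient

print(note_mask.tolist())
print(lengths.tolist())
```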

### Production usage (Bio_ClinicalBERT)

```python
from pyhealth.processors import TupleTimeTextProcessor
from pyhealth.models import MultimodalRNN

note_proc = TupleTimeTextProcessor(
    tokenizer_model='emilyalsentzer/Bio_ClinicalBERT',
    max_length=512,
)
# Build dataset and models exactly as shown above
```