# TupleTimeTextProcessor — Demo Notebook

This notebook demonstrates `TupleTimeTextProcessor`, a PyHealth processor for clinical text paired
with temporal information (time-since-admission). It supports two modes:

| Mode | `tokenizer_model` | Output |
|------|------------------|--------|
| **Raw** | `None` (default) | `(List[str], time_tensor, type_tag)` |
| **Tokenized** | e.g. `"prajjwal1/bert-tiny"` | `(input_ids, attention_mask, token_type_ids, time_tensor, type_tag)` |

**Sections:**
1. Environment setup
2. Raw mode — temporal text handling
3. Tokenized mode — HuggingFace integration
4. Schema & API inspection
5. Integration with `EmbeddingModel`
6. End-to-end training step with `RNN`
7. Choosing your tokenizer
8. Summary

## 1. Environment Setup

In [None]:
import torch
import warnings
warnings.filterwarnings("ignore")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on: {device}")
print(f"PyTorch: {torch.__version__}")

try:
    import transformers
    print(f"Transformers: {transformers.__version__}")
except ImportError:
    print("⚠️  transformers not installed — tokenized mode will be skipped.")
    print("   Install with: pip install transformers")

## 2. Raw Mode — Temporal Text Handling

Without a tokenizer, the processor acts as a lightweight wrapper that:
- Validates input structure
- Converts time diffs to a float tensor
- Attaches a `type_tag` for modality routing downstream
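
As a rough illustration of those three steps, the sketch below mimics raw mode in plain Python. `process_raw_sketch` is a hypothetical stand-in, not the PyHealth implementation: it returns a list of floats where the real processor returns a float tensor, and the exact validation rules shown here are assumptions.

```python
from typing import List, Sequence, Tuple


def process_raw_sketch(
    value: Tuple[Sequence[str], Sequence[float]],
    type_tag: str = "note",
) -> Tuple[List[str], List[float], str]:
    """Hypothetical stand-in for raw mode; not the PyHealth implementation."""
    # 1. Validate the input structure: a (texts, time_diffs) pair of equal length.
    if not isinstance(value, tuple) or len(value) != 2:
        raise ValueError("expected a (texts, time_diffs) tuple")
    texts, time_diffs = value
    if len(texts) != len(time_diffs):
        raise ValueError("texts and time_diffs must have the same length")
    if not all(isinstance(t, str) for t in texts):
        raise ValueError("every note must be a string")
    # 2. Convert time diffs to floats (the real processor builds a float tensor).
    times = [float(t) for t in time_diffs]
    # 3. Attach the type tag so downstream code can route by modality.
    return list(texts), times, type_tag


texts_demo, times_demo, tag_demo = process_raw_sketch(
    (["Admission note.", "Progress note."], [0, 24]),
    type_tag="clinical_note",
)
print(times_demo, tag_demo)
```

The point of the sketch is only that raw mode leaves the text untouched and normalizes the time axis; the actual output of the real processor is shown by the cell below.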

In [None]:
from pyhealth.processors import TupleTimeTextProcessor

# Default: no tokenizer
processor = TupleTimeTextProcessor(type_tag="clinical_note")

# Simulate three notes from a single patient visit
notes = [
    "Patient admitted with acute chest pain and dyspnea.",
    "Day 1: Troponin elevated. Initiated anticoagulation.",
    "Day 3: Symptoms resolved. Discharge planned.",
]
time_diffs = [0.0, 24.0, 72.0]  # hours since admission (the first note is at t = 0)

texts_out, time_tensor, tag = processor.process((notes, time_diffs))

print(f"Output texts:       {texts_out}")
print(f"Time tensor:        {time_tensor}")
print(f"Time tensor shape:  {time_tensor.shape}")
print(f"Type tag:           '{tag}'")
print(f"Schema:             {processor.schema()}")
print(f"is_token():         {processor.is_token()}")

## 3. Tokenized Mode — HuggingFace Integration

When `tokenizer_model` is provided, the processor runs HuggingFace's `AutoTokenizer` internally
and returns **pre-tokenized tensors** instead of raw strings.
Because only tensors are serialized, PyHealth's dataset pipeline avoids the overhead of pickling Python strings.

We use `prajjwal1/bert-tiny` here to keep downloads fast (4.4 MB).
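
To make the shape of those pre-tokenized outputs concrete, here is a toy sketch of padding and truncating notes to a fixed `max_length`. The whitespace tokenizer, the `toy_vocab` mapping, and the helper `encode_fixed_length` are all hypothetical; a real HuggingFace tokenizer uses subword units and adds special tokens, so only the `N × L` layout and the role of the attention mask carry over.

```python
from typing import Dict, List, Tuple


def encode_fixed_length(
    texts: List[str],
    vocab: Dict[str, int],
    max_length: int,
    pad_id: int = 0,
    unk_id: int = 1,
) -> Tuple[List[List[int]], List[List[int]]]:
    """Toy whitespace tokenizer that pads or truncates every text to max_length."""
    input_ids, attention_mask = [], []
    for text in texts:
        ids = [vocab.get(tok, unk_id) for tok in text.lower().split()]
        ids = ids[:max_length]                      # truncate long notes
        n_pad = max_length - len(ids)
        input_ids.append(ids + [pad_id] * n_pad)    # pad short notes
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return input_ids, attention_mask


toy_vocab = {"chest": 2, "pain": 3, "resolved": 4}
toy_ids, toy_mask = encode_fixed_length(
    ["Chest pain", "Symptoms resolved today now"], toy_vocab, max_length=3
)
print(toy_ids)   # [[2, 3, 0], [1, 4, 1]]
print(toy_mask)  # [[1, 1, 0], [1, 1, 1]]
```

Every row has the same length, which is what lets the notes be stacked into a single `(N, L)` tensor; the mask marks which positions hold real tokens rather than padding.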

In [None]:
tok_processor = TupleTimeTextProcessor(
    tokenizer_model="prajjwal1/bert-tiny",
    max_length=32,
    type_tag="clinical_note",
)

result = tok_processor.process((notes, time_diffs))
input_ids, attention_mask, token_type_ids, time_tensor, tag = result

N, L = input_ids.shape
print(f"Notes (N):          {N}")
print(f"Max length (L):     {L}")
print(f"input_ids shape:    {input_ids.shape}   (N × L)")
print(f"attention_mask:     {attention_mask.shape}")
print(f"token_type_ids:     {token_type_ids.shape}")
print(f"time_tensor:        {time_tensor}")
print(f"type_tag:           '{tag}'")
print()
print("Example input_ids (note 0):")
print(input_ids[0])
print("Attention mask (note 0):")
print(attention_mask[0])

## 4. Schema & API Inspection

The schema defines how downstream models interpret each element of the output tuple.
`"value"` and `"mask"` are canonical keys understood by `BaseModel` and `EmbeddingModel`.
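
One way to read the schema is as a list of names for the positions of the output tuple. The sketch below pairs the tokenized-mode schema (as listed in this notebook's summary) with a placeholder result; `fake_result` holds made-up values in plain lists, whereas the real processor returns tensors.

```python
# Schema as listed in this notebook's summary for tokenized mode.
schema_keys = ("value", "mask", "token_type_ids", "time", "type_tag")

# Placeholder output for a single note; real outputs are tensors.
fake_result = (
    [[101, 2023, 102]],   # input_ids
    [[1, 1, 1]],          # attention_mask
    [[0, 0, 0]],          # token_type_ids
    [0.0],                # time diffs
    "clinical_note",      # type tag
)

# Look elements up by name instead of by position.
named = dict(zip(schema_keys, fake_result))
mask_index = schema_keys.index("mask")

print(named["value"])  # [[101, 2023, 102]]
print(mask_index)      # 1
```

Looking up `"value"` and `"mask"` by key rather than by a hard-coded index is what would let a model stay agnostic to the exact tuple layout.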

In [None]:
print("=== Raw processor ===")
print(f"schema(): {processor.schema()}")
print(f"dim():    {processor.dim()}")
print(f"is_token(): {processor.is_token()}")
print(f"repr:     {repr(processor)}")

print()
print("=== Tokenized processor ===")
print(f"schema(): {tok_processor.schema()}")
print(f"dim():    {tok_processor.dim()}")
print(f"is_token(): {tok_processor.is_token()}")
print(f"repr:     {repr(tok_processor)}")

print()
print("Schema key mapping (tokenized mode):")
schema = tok_processor.schema()
for i, key in enumerate(schema):
    print(f"  result[{i}]  ->  '{key}'")

## 5. Integration with `EmbeddingModel`

`EmbeddingModel` automatically detects `is_token() == True` and loads the corresponding
HuggingFace `AutoModel` as the embedding backend.

For a 3D input of shape `(B, N, L)` (batch size, notes per sample, tokens per note), it:
1. Flattens to `(B*N, L)`
2. Runs the BERT encoder
3. Pools via the `[CLS]` token → `(B*N, H)`
4. Unflattens back to `(B, N, H)` — a sequence of *note embeddings*
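
The four steps above can be sketched with nested lists so that the shape bookkeeping stays visible. Everything here is illustrative: `toy_encoder` is a hypothetical stand-in for the BERT encoder (it maps each token id to a 2-dimensional vector), and `embed_notes` is not `EmbeddingModel`'s code, which operates on tensors.

```python
from typing import Callable, List

Vector = List[float]


def toy_encoder(token_ids: List[int]) -> List[Vector]:
    """Hypothetical encoder: maps each token id to a 2-dim vector (H = 2)."""
    return [[float(t), float(t) * 0.5] for t in token_ids]


def embed_notes(
    input_ids: List[List[List[int]]],
    encode: Callable[[List[int]], List[Vector]],
) -> List[List[Vector]]:
    batch_size, n_notes = len(input_ids), len(input_ids[0])
    # 1. Flatten (B, N, L) -> (B*N, L)
    flat = [note for sample in input_ids for note in sample]
    # 2. Encode every note -> (B*N, L, H)
    hidden = [encode(note) for note in flat]
    # 3. Pool by taking the first position (where [CLS] sits) -> (B*N, H)
    pooled = [states[0] for states in hidden]
    # 4. Unflatten -> (B, N, H)
    return [pooled[b * n_notes:(b + 1) * n_notes] for b in range(batch_size)]


# B = 2 samples, N = 2 notes each, L = 3 tokens per note
toy_batch = [
    [[101, 5, 6], [101, 7, 0]],
    [[101, 8, 9], [101, 0, 0]],
]
toy_embedded = embed_notes(toy_batch, toy_encoder)
print(len(toy_embedded), len(toy_embedded[0]), len(toy_embedded[0][0]))  # 2 2 2
```

Flattening lets the encoder treat every note as an independent sequence; restoring the `(B, N, ...)` layout afterwards is what turns the result into one sequence of note embeddings per sample.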

In [None]:
from pyhealth.datasets import SampleDataset
from pyhealth.models.embedding import EmbeddingModel

# Build a tiny SampleDataset
samples = [
    {
        "patient_id": f"p{i}",
        "visit_id": f"v{i}",
        "notes": (
            ["Admission note.", "Progress note.", "Discharge note."][:i+1],
            [0.0, 24.0, 72.0][:i+1],
        ),
        "label": i % 2,
    }
    for i in range(3)
]

dataset = SampleDataset(
    samples=samples,
    input_schema={"notes": tok_processor},
    output_schema={"label": "binary"},
)

embedding_dim = 64
emb_model = EmbeddingModel(dataset, embedding_dim=embedding_dim)

# Build a batch of two samples with PyHealth's dataloader
from pyhealth.datasets import get_dataloader
loader = get_dataloader(dataset, batch_size=2, shuffle=False)
batch = next(iter(loader))

notes_feature = batch["notes"]          # Tuple from the processor
print(f"Feature tuple length:  {len(notes_feature)}")
print(f"input_ids batch shape: {notes_feature[0].shape}  (B, N, L)")

# Run through EmbeddingModel, passing the attention mask explicitly
input_ids_batch = notes_feature[0]      # (B, N, L)
attention_mask_batch = notes_feature[1] # (B, N, L)

with torch.no_grad():
    embedded = emb_model(
        {"notes": input_ids_batch},
        masks={"notes": attention_mask_batch},
    )

note_embeddings = embedded["notes"]
print(f"\nEmbedded shape: {note_embeddings.shape}  (B, N, embedding_dim)")
print(f"Expected:       (2, N, {embedding_dim})")

## 6. End-to-End Training Step with `RNN`

The `RNN` model treats each note embedding as one timestep of the sequence
and automatically extracts and propagates the masks from the feature tuple.
The cell below runs one full training step: forward pass, loss, backward pass, and optimizer update.
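
One plausible way to picture the mask handling is sketched below. It is an assumption about the mechanism, not a description of PyHealth's code: a note counts as a valid timestep if at least one of its tokens is attended, so padded notes drop out of the sequence. Both helpers are hypothetical and use plain lists instead of tensors.

```python
from typing import List


def note_level_mask(attention_mask: List[List[List[int]]]) -> List[List[int]]:
    """Collapse a token-level mask (B, N, L) into a note-level mask (B, N)."""
    return [
        [1 if any(tokens) else 0 for tokens in sample]
        for sample in attention_mask
    ]


def last_valid_index(note_mask: List[int]) -> int:
    """Index of the last real note in one sample, or -1 if there is none."""
    valid = [i for i, m in enumerate(note_mask) if m]
    return valid[-1] if valid else -1


# B = 2 samples, padded to N = 3 notes of L = 3 tokens each
toy_attention = [
    [[1, 1, 0], [0, 0, 0], [0, 0, 0]],   # one real note, two padded
    [[1, 1, 1], [1, 0, 0], [0, 0, 0]],   # two real notes, one padded
]
toy_note_mask = note_level_mask(toy_attention)
print(toy_note_mask)                                  # [[1, 0, 0], [1, 1, 0]]
print([last_valid_index(m) for m in toy_note_mask])   # [0, 1]
```

With a note-level mask like this, a recurrent model can read its final hidden state at the last real note of each sample rather than at a padded position.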

In [None]:
from pyhealth.models import RNN

model = RNN(
    dataset=dataset,
    feature_keys=["notes"],
    label_key="label",
    embedding_dim=64,
    hidden_dim=32,
    rnn_type="GRU",
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

print("Running one training step...")
model.train()
optimizer.zero_grad()

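def to_device(obj, device):
    """Recursively move tensors, including those nested in tuples or lists, to `device`."""
    if isinstance(obj, torch.Tensor):
        return obj.to(device)
    if isinstance(obj, (tuple, list)):
        return type(obj)(to_device(o, device) for o in obj)
    return obj

# batch["notes"] is a tuple of tensors, so a flat isinstance check would skip it
out = model(**{k: to_device(v, device) for k, v in batch.items()})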

loss = out["loss"]
loss.backward()
optimizer.step()

print(f"Loss:    {loss.item():.4f}")
print(f"y_prob:  {out['y_prob'].detach().cpu().numpy().round(3)}")
print(f"y_true:  {out['y_true'].cpu().numpy()}")
print("\n✅ Forward + backward pass complete!")

## 7. Choosing Your Tokenizer

Any HuggingFace checkpoint that loads through `AutoTokenizer` (and `AutoModel`, for embedding) can serve as the tokenizer/encoder backend.
Here is a quick comparison of options commonly used for clinical NLP:

| Model | Size | Best for |
|-------|------|----------|
| `prajjwal1/bert-tiny` | 4.4 MB | Fast prototyping / unit tests |
| `emilyalsentzer/Bio_ClinicalBERT` | 418 MB | General clinical notes |
| `microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract` | 418 MB | Biomedical literature |
| `allenai/longformer-base-4096` | 580 MB | Very long discharge summaries |

To switch, change only the `tokenizer_model` string (and `max_length`, if the new model's context window differs); the rest of the pipeline stays the same.
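
When switching models, `max_length` is the one setting worth double-checking, since it should not exceed the encoder's context window. The sketch below shows a small guard for that. The helper `safe_max_length` and the `MAX_POSITIONS` table are hypothetical: the limits listed are the usual values for these checkpoints (512 positions for BERT-style models, 4096 for this Longformer) and should be confirmed against each model's own configuration.

```python
# Assumed context limits; verify against each model's config before relying on them.
MAX_POSITIONS = {
    "prajjwal1/bert-tiny": 512,
    "emilyalsentzer/Bio_ClinicalBERT": 512,
    "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract": 512,
    "allenai/longformer-base-4096": 4096,
}


def safe_max_length(model_name: str, requested: int) -> int:
    """Clamp the requested max_length to the model's assumed context limit."""
    if model_name not in MAX_POSITIONS:
        raise KeyError(f"no known context limit for {model_name!r}")
    return min(requested, MAX_POSITIONS[model_name])


print(safe_max_length("emilyalsentzer/Bio_ClinicalBERT", 2048))  # 512
print(safe_max_length("allenai/longformer-base-4096", 2048))     # 2048
```

The clamped value could then be passed as `max_length` when constructing the processor, so that long discharge summaries are truncated by the tokenizer instead of overflowing the encoder.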

In [None]:
# Example: switch to Bio_ClinicalBERT (requires ~418 MB download on first run)
# clinical_processor = TupleTimeTextProcessor(
#     tokenizer_model="emilyalsentzer/Bio_ClinicalBERT",
#     max_length=512,
#     type_tag="clinical_note",
# )

# Preview the repr
print(repr(tok_processor))
print(f"Vocab size: {tok_processor.tokenizer.vocab_size:,} tokens")
print("\nAll APIs:")
print(f"  .schema()   -> {tok_processor.schema()}")
print(f"  .is_token() -> {tok_processor.is_token()}")
print(f"  .dim()      -> {tok_processor.dim()}")
print(f"  .size()     -> {tok_processor.size():,}")

## Summary

| Capability | API |
|------------|-----|
| Raw temporal text | `TupleTimeTextProcessor(type_tag="note")` |
| Tokenized text | `TupleTimeTextProcessor(tokenizer_model="...", max_length=128)` |
| Schema inspection | `.schema()` → `('value', 'mask', 'token_type_ids', 'time', 'type_tag')` |
| Modality detection | `.is_token()` → `True/False` |
| Auto-encoder selection | `EmbeddingModel` reads `is_token()` and loads `AutoModel` |
| 3D note handling | `EmbeddingModel` flattens `(B,N,L)` → encode → pool → `(B,N,H)` |
| Downstream use | `RNN`, `MLP`, `Transformer`, `MultimodalRNN` all handle it natively |

**Design philosophy**: preprocessing logic lives in the processor, not the model.
This keeps models generic and enables swapping encoders without touching training code.