# TextEmbedding Tutorial

This notebook demonstrates how to use the `TextEmbedding` module for encoding clinical text in PyHealth.

**Overview:**
- Initialize TextEmbedding with Bio_ClinicalBERT
- Demonstrate 128-token chunking for long clinical notes
- Show different pooling modes (none, cls, mean)
- Verify mask output and backward compatibility
- Use TupleTimeTextProcessor for temporal text data

## 1. Environment Setup

In [None]:
import torch
import warnings

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on device: {device}")

## 2. Basic Usage

Initialize `TextEmbedding` and encode sample clinical text.

In [None]:
from pyhealth.models import TextEmbedding

# Initialize with default Bio_ClinicalBERT
encoder = TextEmbedding(
    embedding_dim=128,
    chunk_size=128,
    pooling="none",
)

print(f"Model: {encoder.model_name}")
print(f"Embedding dim: {encoder.embedding_dim}")
print(f"Chunk size: {encoder.chunk_size}")
print(f"Pooling: {encoder.pooling}")

In [None]:
# Encode sample clinical notes
texts = [
    "Patient presents with chest pain and shortness of breath.",
    "Follow-up visit for diabetes management. Blood glucose stable."
]

embeddings, mask = encoder(texts)

print(f"Input: {len(texts)} texts")
print(f"Embeddings shape: {embeddings.shape}  # [batch, tokens, embedding_dim]")
print(f"Mask shape: {mask.shape}  # [batch, tokens]")
print(f"Mask dtype: {mask.dtype}")
print(f"Valid tokens per sample: {mask.sum(dim=1).tolist()}")

## 3. Chunking Behavior

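Long texts are split into non-overlapping chunks. Each chunk holds at most `chunk_size - 2` content tokens (126 when `chunk_size=128`), because two positions are reserved for the [CLS] and [SEP] special tokens.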
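The chunk arithmetic can be sketched in plain Python. `chunk_token_ids` is a hypothetical helper, not part of PyHealth. The [CLS]/[SEP] ids 101/102 are the usual BERT vocabulary values and are an assumption here, since in practice they would come from the tokenizer. `TextEmbedding`'s internal implementation may differ.

```python
import math

def chunk_token_ids(token_ids, chunk_size=128, cls_id=101, sep_id=102):
    """Split content token ids into non-overlapping chunks of
    [CLS] + at most (chunk_size - 2) ids + [SEP].
    cls_id/sep_id are assumed BERT vocabulary ids (hypothetical defaults)."""
    body = chunk_size - 2  # two positions are reserved for [CLS] and [SEP]
    if body <= 0:
        raise ValueError("chunk_size must be greater than 2")
    return [
        [cls_id] + token_ids[start:start + body] + [sep_id]
        for start in range(0, len(token_ids), body)
    ]

ids = list(range(1000, 1300))  # stand-in for 300 word-piece ids
chunks = chunk_token_ids(ids, chunk_size=128)
print(len(chunks), math.ceil(len(ids) / (128 - 2)))  # 3 3
print([len(c) for c in chunks])                      # [128, 128, 50]
```

With 300 content tokens, two full chunks carry 126 tokens each and the last carries the remaining 48, so the number of chunks is `ceil(n / (chunk_size - 2))`.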

In [None]:
# Create a long clinical note by repeating text
long_note = "Patient admitted with acute myocardial infarction. " * 50
print(f"Long note word count: {len(long_note.split())}")

embeddings, mask = encoder([long_note])

print(f"\nEmbeddings shape: {embeddings.shape}")
print(f"Total tokens encoded: {mask.sum().item()}")
print("\nNote: Multiple chunks were created and concatenated along the sequence dimension.")

## 4. Pooling Modes

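Three pooling modes are available (B = batch size, T = total tokens across all chunks, C = number of chunks, E' = `embedding_dim`):
- `"none"`: every token embedding, [B, T, E']
- `"cls"`: the [CLS] token embedding of each chunk, [B, C, E']
- `"mean"`: the mean-pooled embedding of each chunk, [B, C, E']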
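To make the difference between the modes concrete, here is a plain-Python sketch for a single text, where E' is the output embedding size. `pool_chunks` is a hypothetical helper, not PyHealth code. It assumes that `"mean"` averages every unmasked position in a chunk, including [CLS]/[SEP]; whether `TextEmbedding` includes the special tokens in the mean is not stated here.

```python
def pool_chunks(chunk_embs, chunk_masks, pooling):
    """Pool one text's chunks (hypothetical sketch).

    chunk_embs: list of chunks, each a list of token vectors (lists of floats).
    chunk_masks: matching lists of booleans, True = real token, False = padding.
    """
    if pooling == "none":
        # Keep every valid token, concatenated across chunks -> [T, E']
        return [vec for chunk, mask in zip(chunk_embs, chunk_masks)
                for vec, keep in zip(chunk, mask) if keep]
    pooled = []
    for chunk, mask in zip(chunk_embs, chunk_masks):
        if pooling == "cls":
            pooled.append(chunk[0])  # position 0 holds [CLS]
        elif pooling == "mean":
            valid = [vec for vec, keep in zip(chunk, mask) if keep]
            # Average each embedding dimension over the valid tokens only
            pooled.append([sum(dim) / len(valid) for dim in zip(*valid)])
        else:
            raise ValueError(f"unknown pooling mode: {pooling!r}")
    return pooled  # one vector per chunk -> [C, E']

# Two toy chunks with E' = 2; the last position of chunk 1 is padding.
chunk_embs = [
    [[1.0, 0.0], [3.0, 2.0], [9.0, 9.0]],
    [[2.0, 2.0], [4.0, 0.0], [6.0, 4.0]],
]
chunk_masks = [[True, True, False], [True, True, True]]

print(len(pool_chunks(chunk_embs, chunk_masks, "none")))  # 5
print(pool_chunks(chunk_embs, chunk_masks, "cls"))   # [[1.0, 0.0], [2.0, 2.0]]
print(pool_chunks(chunk_embs, chunk_masks, "mean"))  # [[2.0, 1.0], [4.0, 2.0]]
```

`"none"` grows with text length (T), while `"cls"` and `"mean"` grow only with the number of chunks (C). This is why the chunk-level modes give much shorter sequences for long notes.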

In [None]:
# Compare pooling modes on the same long text
long_text = "Patient presents with symptoms. " * 100

for pooling in ["none", "cls", "mean"]:
    enc = TextEmbedding(embedding_dim=128, pooling=pooling)
    emb, mask = enc([long_text])
    print(f"pooling='{pooling}': shape={tuple(emb.shape)}, valid_positions={mask.sum().item()}")

## 5. Performance Guardrails

The `max_chunks` parameter caps how many chunks are encoded per text. Chunks beyond the cap are dropped with a `UserWarning`, which bounds memory use on very long notes.
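Each chunk has a fixed length, so encoding cost grows roughly linearly with the number of chunks. Capping the chunk count therefore caps memory. The following is a minimal sketch of the truncate-and-warn pattern. `cap_chunks` and its warning text are hypothetical. Only the fact that truncation raises a `UserWarning` comes from this notebook.

```python
import warnings

def cap_chunks(chunks, max_chunks):
    """Keep at most max_chunks chunks, warning when anything is dropped (hypothetical)."""
    if max_chunks is not None and len(chunks) > max_chunks:
        warnings.warn(
            f"Text produced {len(chunks)} chunks; truncating to max_chunks={max_chunks}.",
            UserWarning,
        )
        return chunks[:max_chunks]
    return chunks

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # record every warning, even repeats
    kept = cap_chunks([[i] for i in range(48)], max_chunks=10)

print(len(kept), len(caught), caught[0].category.__name__)  # 10 1 UserWarning
```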

In [None]:
# Very long text that exceeds max_chunks
very_long_text = "Clinical observation. " * 2000  # several thousand tokens, far more than 10 chunks can hold

encoder_limited = TextEmbedding(
    embedding_dim=128,
    pooling="cls",  # one vector per chunk, so mask.sum() counts chunks
    max_chunks=10,  # Limit to 10 chunks
)

# This will trigger a UserWarning about truncation
with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    emb, mask = encoder_limited([very_long_text])
    if w:
        print(f"Warning: {w[0].message}")

print(f"\nOutput shape: {emb.shape}")
print(f"Chunks retained: {mask.sum().item()} (capped at max_chunks=10)")

## 6. Backward Compatibility

By default (`return_mask=True`), the encoder returns an `(embeddings, mask)` tuple. Set `return_mask=False` to get only the embedding tensor, as older code expects.
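Downstream code that must accept either setting can normalize the result with a small adapter. `split_output` is a hypothetical helper. It assumes that `return_mask=True` yields a plain tuple, as the comment in the cell below states.

```python
def split_output(result):
    """Normalize a TextEmbedding result to (embeddings, mask_or_None) (hypothetical)."""
    if isinstance(result, tuple):
        embeddings, mask = result  # return_mask=True: (embeddings, mask)
        return embeddings, mask
    return result, None            # return_mask=False: embeddings only

# Stand-in values; in the notebook these would be tensors.
print(split_output(("emb", "mask")))  # ('emb', 'mask')
print(split_output("emb"))            # ('emb', None)
```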

In [None]:
# New API (default): returns tuple
encoder_new = TextEmbedding(return_mask=True)
result = encoder_new(["Test text"])
print(f"return_mask=True: type={type(result)}, len={len(result)}")

# Backward compatible: returns tensor only
encoder_compat = TextEmbedding(return_mask=False)
result = encoder_compat(["Test text"])
print(f"return_mask=False: type={type(result)}, shape={result.shape}")

## 7. Mask Convention

The mask uses `True` for valid positions and `False` for padding, compatible with PyHealth's TransformerLayer.
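The following sketch shows how a right-padded batch and its `True` = valid mask line up. `pad_with_mask` is a hypothetical helper. Note that `torch.nn.MultiheadAttention`'s `key_padding_mask` uses the opposite convention (`True` = ignore), so a mask like this one must be inverted (`~mask` on a bool tensor) before passing it there.

```python
def pad_with_mask(sequences, pad_value=0):
    """Right-pad variable-length sequences and build a True=valid mask (hypothetical)."""
    max_len = max(len(seq) for seq in sequences)
    padded = [seq + [pad_value] * (max_len - len(seq)) for seq in sequences]
    mask = [[i < len(seq) for i in range(max_len)] for seq in sequences]
    return padded, mask

padded, mask = pad_with_mask([[7, 8], [1, 2, 3, 4]])
print(padded)  # [[7, 8, 0, 0], [1, 2, 3, 4]]
print(mask)    # [[True, True, False, False], [True, True, True, True]]

# PyTorch's key_padding_mask marks positions to IGNORE, so invert the mask.
key_padding_mask = [[not valid for valid in row] for row in mask]
print(key_padding_mask)  # [[False, False, True, True], [False, False, False, False]]
```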

In [None]:
# Different length texts to show padding behavior
texts = [
    "Short note.",
    "This is a longer clinical note with more content to encode into embeddings."
]

encoder = TextEmbedding(embedding_dim=128)
embeddings, mask = encoder(texts)

print(f"Batch shape: {embeddings.shape}")
print(f"Mask shape: {mask.shape}")
print()
print(f"Sample 1 valid tokens: {mask[0].sum().item()}")
print(f"Sample 2 valid tokens: {mask[1].sum().item()}")
print()
print("Mask convention: True=valid, False=padding")
print(f"Sample 1 mask: {mask[0].tolist()[:15]}...")

## 8. Freezing Pretrained Weights

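For multimodal fusion, freezing the pretrained transformer (`freeze=True`) leaves only the projection layer trainable. This reduces the number of trainable parameters and the risk of overfitting.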
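For a sense of scale, the trainable parameters in the frozen setting can be estimated by hand. This calculation assumes that the projection `fc` is a single linear layer with bias, mapping the 768-dimensional hidden state of a BERT-base model such as Bio_ClinicalBERT to `embedding_dim`. The actual layer may differ, so compare the estimate against the printed count in the next cell.

```python
# Assumption: `fc` is a single nn.Linear(hidden_size, embedding_dim) with a bias.
hidden_size = 768     # hidden width of BERT-base models such as Bio_ClinicalBERT
embedding_dim = 128   # matches TextEmbedding(embedding_dim=128) below

weights = hidden_size * embedding_dim  # 98,304
biases = embedding_dim                 # 128
projection_params = weights + biases
print(f"{projection_params:,}")  # 98,432
```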

In [None]:
# Frozen encoder (recommended for multimodal fusion)
encoder_frozen = TextEmbedding(embedding_dim=128, freeze=True)

transformer_params = sum(p.numel() for p in encoder_frozen.transformer.parameters())
trainable_params = sum(p.numel() for p in encoder_frozen.transformer.parameters() if p.requires_grad)
projection_params = sum(p.numel() for p in encoder_frozen.fc.parameters())

print(f"Transformer parameters: {transformer_params:,}")
print(f"Trainable transformer params: {trainable_params:,}")
print(f"Projection layer params: {projection_params:,} (always trainable)")

## 9. Eval Mode Determinism

In eval mode, dropout is disabled and outputs are deterministic.
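Dropout is the reason `eval()` matters. PyTorch's `nn.Dropout` uses inverted dropout: during training it zeroes each activation with probability `p` and rescales the survivors by `1 / (1 - p)`, and in eval mode it is the identity. The following is a plain-Python sketch of that rule. The `dropout` function here is illustrative and is not the PyTorch implementation.

```python
import random

def dropout(values, p, training, rng):
    """Inverted dropout: in training, zero each value with probability p and
    scale survivors by 1 / (1 - p); in eval mode, return the input unchanged."""
    if not training or p == 0.0:
        return list(values)
    scale = 1.0 / (1.0 - p)
    return [0.0 if rng.random() < p else v * scale for v in values]

x = [1.0, 2.0, 3.0, 4.0]
rng = random.Random(0)

# Eval mode: identical on every call, regardless of the random state
print(dropout(x, 0.5, training=False, rng=rng) == dropout(x, 0.5, training=False, rng=rng))  # True

# Training mode: each call draws a fresh random mask, so outputs generally differ
print(dropout(x, 0.5, training=True, rng=rng))
print(dropout(x, 0.5, training=True, rng=rng))
```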

In [None]:
encoder = TextEmbedding(embedding_dim=128)
encoder.eval()

text = ["Patient stable."]

with torch.no_grad():
    emb1, _ = encoder(text)
    emb2, _ = encoder(text)

is_equal = torch.allclose(emb1, emb2)
print(f"Outputs identical in eval mode: {is_equal}")

## 10. TupleTimeTextProcessor for Temporal Data

The `TupleTimeTextProcessor` handles clinical text paired with temporal information (one time difference per text). It also returns its `type_tag` alongside the processed data, so multimodal models can route each feature to the appropriate encoder automatically.
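The time values passed to the processor are plain floats. If notes carry absolute timestamps, one way to produce "hours since admission" is shown below. `hours_since` is a hypothetical preprocessing helper, the timestamps are made-up illustration values, and this notebook does not specify how PyHealth itself derives the time differences.

```python
from datetime import datetime

def hours_since(reference, timestamps):
    """Convert absolute timestamps to hours elapsed since `reference` (hypothetical)."""
    return [(ts - reference).total_seconds() / 3600.0 for ts in timestamps]

admit = datetime(2024, 1, 1, 8, 0)  # illustrative admission time
note_times = [
    datetime(2024, 1, 1, 8, 0),
    datetime(2024, 1, 2, 8, 0),
    datetime(2024, 1, 4, 8, 0),
]
print(hours_since(admit, note_times))  # [0.0, 24.0, 72.0]
```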

In [None]:
from pyhealth.processors import TupleTimeTextProcessor

# Initialize processor with type tag for modality routing
processor = TupleTimeTextProcessor(type_tag="clinical_note")

# Clinical notes with time differences (e.g., hours since admission)
texts = [
    "Patient admitted with chest pain.",
    "Follow-up: symptoms improved.",
    "Discharge: stable condition."
]
time_diffs = [0.0, 24.0, 72.0]  # hours

# Process tuple
processed_texts, time_tensor, modality_tag = processor.process((texts, time_diffs))

print(f"Texts: {processed_texts}")
print(f"Time tensor: {time_tensor}")
print(f"Time tensor shape: {time_tensor.shape}")
print(f"Modality tag: '{modality_tag}'")
print(f"\nProcessor repr: {repr(processor)}")

### Multimodal Fusion Example

The `type_tag` enables automatic routing in multimodal pipelines without hardcoding feature names:
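The routing idea can be sketched as a lookup from tag to encoder. Because only the tag is consulted, feature names can change without touching the model. `route_by_tag`, the feature names, and the lambda "encoders" are all hypothetical stand-ins. The `(texts, times, tag)` triples mirror the processor's output in the cell above.

```python
def route_by_tag(features, encoders):
    """Send each (texts, times, tag) triple to the encoder registered for its tag."""
    outputs = {}
    for name, (texts, times, tag) in features.items():
        if tag not in encoders:
            raise KeyError(f"no encoder registered for modality tag '{tag}'")
        outputs[name] = encoders[tag](texts, times)
    return outputs

# Stand-in encoders keyed by modality tag
encoders = {
    "note": lambda texts, times: f"text-encoder({len(texts)} items)",
    "ehr": lambda texts, times: f"event-encoder({len(texts)} items)",
}
# Arbitrary feature names: routing depends only on the tag
features = {
    "discharge_notes": (["Admission note", "Progress note"], [0.0, 24.0], "note"),
    "events": (["Lab ordered", "Medication given"], [2.0, 6.0], "ehr"),
}
print(route_by_tag(features, encoders))
```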

In [None]:
# Different modality types with different processors
note_processor = TupleTimeTextProcessor(type_tag="note")
ehr_processor = TupleTimeTextProcessor(type_tag="ehr")
observation_processor = TupleTimeTextProcessor(type_tag="observation")

# Process different data types
notes = ["Admission note", "Progress note"]
note_times = [0.0, 24.0]
_, note_tensor, note_tag = note_processor.process((notes, note_times))

ehr_events = ["Lab ordered", "Medication given"]
ehr_times = [2.0, 6.0]
_, ehr_tensor, ehr_tag = ehr_processor.process((ehr_events, ehr_times))

# Tags can be used for automatic routing in models
print(f"Note modality tag: '{note_tag}'")
print(f"EHR modality tag: '{ehr_tag}'")
print(f"\nNote times: {note_tensor}")
print(f"EHR times: {ehr_tensor}")

# Concatenate the time tensors for joint temporal modeling (kept in input order, not sorted by time)
combined_times = torch.cat([note_tensor, ehr_tensor])
print(f"\nCombined times: {combined_times}")
print(f"Combined shape: {combined_times.shape}")
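`torch.cat` keeps the streams in input order, so the combined times above read `[0, 24, 2, 6]` rather than a chronological timeline. To model one patient timeline, sort the merged events by time, for example with `torch.sort`, which also returns the permutation indices. The following plain-Python sketch shows the merge. `merge_by_time` is a hypothetical helper.

```python
def merge_by_time(*streams):
    """Merge (tag, texts, times) streams into one list of (time, tag, text), oldest first."""
    events = [
        (t, tag, text)
        for tag, texts, times in streams
        for text, t in zip(texts, times)
    ]
    return sorted(events, key=lambda e: e[0])  # stable sort keeps input order on ties

timeline = merge_by_time(
    ("note", ["Admission note", "Progress note"], [0.0, 24.0]),
    ("ehr", ["Lab ordered", "Medication given"], [2.0, 6.0]),
)
for t, tag, text in timeline:
    print(f"{t:5.1f}h  [{tag}] {text}")
```

Keeping the tag with each event preserves the modality information after merging, so the routing shown earlier still applies.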

## Summary

The `TextEmbedding` module and `TupleTimeTextProcessor` provide:

| Feature | Description |
|---------|-------------|
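| **Chunking** | Splits long texts into `chunk_size`-token chunks (128 in this tutorial), with 2 positions reserved for [CLS]/[SEP] |
| **Pooling** | `none`/`cls`/`mean` modes for token-level or chunk-level outputs |
| **Mask** | Boolean tensor (`True` = valid) compatible with `TransformerLayer` |
| **Guardrails** | `max_chunks` caps the chunks encoded per text to prevent out-of-memory errors |
| **Compatibility** | `return_mask=False` returns the embedding tensor alone for legacy code |
| **Temporal Processing** | `TupleTimeTextProcessor` pairs texts with their time differences |
| **Modality Routing** | Type tags let multimodal models route features without hardcoded names |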