# LSTM Data Preparation: From Patient Sequences to Model Inputs

This notebook demonstrates how to prepare visit-grouped patient sequences for the LSTM baseline model.

**Prediction Objective:** Binary classification: predict whether a patient has diabetes from their EHR sequence.

## Overview

We'll walk through the complete data transformation pipeline:

1. **Load processed sequences** from the main exploration notebook
2. **Encode sequences** to integer IDs using vocabulary
3. **Create labels** for prediction task (diabetes detection)
4. **Prepare batches** with proper padding and masking
5. **Visualize data shapes** at each transformation step
6. **Create LSTM-ready tensors** for model input

See `data_shape_transformations.md` for detailed documentation of all shape changes.

---

## Setup

In [1]:
import sys
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from collections import Counter

# Set plotting style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

# Import ehrsequencing package
from ehrsequencing.data.adapters import SyntheaAdapter
from ehrsequencing.data.visit_grouper import VisitGrouper
from ehrsequencing.data.sequence_builder import PatientSequenceBuilder
from ehrsequencing.models import create_lstm_baseline

print("✅ Imports successful")

✅ Imports successful


---

## 1. Load and Prepare Data

We'll start by loading Synthea data and creating patient sequences (same as notebook 01).

In [2]:
# Path to Synthea data
data_path = Path.home() / 'work' / 'loinc-predictor' / 'data' / 'synthea' / 'all_cohorts'

# Initialize adapter
adapter = SyntheaAdapter(data_path=str(data_path))

print(f"✅ Loaded Synthea data from: {data_path}")

✅ Loaded Synthea data from: /Users/pleiadian53/work/loinc-predictor/data/synthea/all_cohorts


In [3]:
# Load patients (using more for better examples)
patients = adapter.load_patients(limit=50)
patient_ids = [p.patient_id for p in patients]

print(f"Loaded {len(patients)} patients")
print(f"\n📊 Data Shape: List[PatientInfo] with length {len(patients)}")

Loaded 50 patients

📊 Data Shape: List[PatientInfo] with length 50


In [4]:
# Initialize visit grouper
visit_grouper = VisitGrouper(
    strategy='hybrid',
    time_window_hours=24,
    preserve_code_types=True
)

print("✅ VisitGrouper initialized")

✅ VisitGrouper initialized


In [5]:
# Load events and group into visits
print(f"Processing {len(patient_ids)} patients...")

patient_visits = {}
for patient_id in patient_ids:
    events = adapter.load_events(patient_ids=[patient_id])
    visits = visit_grouper.group_events(events, patient_id=patient_id)
    patient_visits[patient_id] = visits

print(f"\n✅ Grouped visits for {len(patient_visits)} patients")
print(f"\n📊 Data Shape: Dict[str, List[Visit]]")
print(f"   - Keys: {len(patient_visits)} patient IDs")
print(f"   - Values: Lists of Visit objects")
print(f"   - Total visits: {sum(len(v) for v in patient_visits.values())}")

Processing 50 patients...

✅ Grouped visits for 50 patients

📊 Data Shape: Dict[str, List[Visit]]
   - Keys: 50 patient IDs
   - Values: Lists of Visit objects
   - Total visits: 3037
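
The `time_window_hours=24` setting above can be pictured with a small standalone sketch. This is not the `VisitGrouper` implementation, and the behavior of the `'hybrid'` strategy is not shown here; `group_by_time_window` and the toy events are hypothetical and only illustrate pure time-window grouping.

```python
from datetime import datetime, timedelta

def group_by_time_window(events, window_hours=24):
    """Group (timestamp, code) pairs into visits.

    A new visit starts when an event falls more than `window_hours`
    after the first event of the current visit.
    """
    visits = []
    window = timedelta(hours=window_hours)
    for ts, code in sorted(events):
        if visits and ts - visits[-1]["start"] <= window:
            visits[-1]["codes"].append(code)
        else:
            visits.append({"start": ts, "codes": [code]})
    return visits

# Toy events: two on the same day, one three days later
events = [
    (datetime(2021, 3, 1, 9, 0), "A"),
    (datetime(2021, 3, 1, 15, 0), "B"),
    (datetime(2021, 3, 4, 10, 0), "C"),
]
toy_visits = group_by_time_window(events)
print([v["codes"] for v in toy_visits])  # [['A', 'B'], ['C']]
```

Anchoring the window at the first event of a visit is one design choice; anchoring at the most recent event would let a visit keep growing as long as events stay close together.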


---

## 2. Build Patient Sequences

Transform visit groups into structured patient sequences with vocabulary.
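
To make the vocabulary step concrete, here is a minimal sketch of frequency-based vocabulary construction. It is not the `PatientSequenceBuilder` code: `build_toy_vocab` and the toy codes are hypothetical, and only the special-token IDs follow the layout this notebook prints.

```python
from collections import Counter

SPECIAL_TOKENS = ["[PAD]", "[UNK]", "[MASK]", "[CLS]", "[SEP]"]

def build_toy_vocab(patients, min_frequency=1):
    """Map each code seen at least `min_frequency` times to an integer ID.

    `patients` is a list of patients, each a list of visits,
    each visit a list of code strings.
    """
    counts = Counter(code for visits in patients for visit in visits for code in visit)
    # Special tokens take IDs 0..4
    vocab = {tok: i for i, tok in enumerate(SPECIAL_TOKENS)}
    # Sort by descending frequency, then alphabetically, so IDs are deterministic
    for code, n in sorted(counts.items(), key=lambda kv: (-kv[1], kv[0])):
        if n >= min_frequency:
            vocab[code] = len(vocab)
    return vocab

toy_patients = [
    [["44054006", "4548-4"], ["44054006"]],
    [["38341003"], ["4548-4", "38341003"]],
]
toy_vocab = build_toy_vocab(toy_patients)
print(len(toy_vocab))  # 8 (5 special tokens + 3 codes)
```

Raising `min_frequency` drops rare codes from the mapping, so they would later be encoded as `[UNK]`.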

In [6]:
# Initialize sequence builder
sequence_builder = PatientSequenceBuilder(
    vocab=None,
    max_visits=50,
    max_codes_per_visit=100,
    use_semantic_order=True
)

print("✅ PatientSequenceBuilder initialized")
print(f"   Max visits per sequence: 50")
print(f"   Max codes per visit: 100")

✅ PatientSequenceBuilder initialized
   Max visits per sequence: 50
   Max codes per visit: 100


In [7]:
# Build vocabulary from all patient visits
print("Building vocabulary...")
vocab = sequence_builder.build_vocabulary(list(patient_visits.values()), min_frequency=1)

print(f"\n✅ Vocabulary built")
print(f"   Vocabulary size: {len(vocab)}")
print(f"   Special tokens: [PAD]=0, [UNK]=1, [MASK]=2, [CLS]=3, [SEP]=4")
print(f"\n📊 Data Shape: Dict[str, int]")
print(f"   - Medical code → Integer ID mapping")
print(f"   - Example: {list(vocab.items())[:5]}")

Building vocabulary...

✅ Vocabulary built
   Vocabulary size: 659
   Special tokens: [PAD]=0, [UNK]=1, [MASK]=2, [CLS]=3, [SEP]=4

📊 Data Shape: Dict[str, int]
   - Medical code → Integer ID mapping
   - Example: [('[PAD]', 0), ('[UNK]', 1), ('[MASK]', 2), ('[CLS]', 3), ('[SEP]', 4)]


In [None]:
# Build patient sequences
print("Building patient sequences...")
sequences = sequence_builder.build_sequences(list(patient_visits.values()), min_visits=2)

print(f"\n✅ Built {len(sequences)} sequences")
print(f"   Filtered out: {len(patient_visits) - len(sequences)} patients (< 2 visits)")
print(f"\n📊 Data Shape: List[PatientSequence]")
print(f"   - Length: {len(sequences)}")
print(f"   - Each PatientSequence contains:")
print(f"     • patient_id: str")
print(f"     • visits: List[Visit]")
print(f"     • sequence_length: int")
print(f"     • metadata: Optional[Dict]")

Building patient sequences...


AttributeError: 'str' object has no attribute 'patient_id'

---

## 3. Examine Raw Patient Sequence

Let's look at a patient sequence before encoding.

In [None]:
# Pick a sample sequence
sample_seq = sequences[0]

print("Sample Patient Sequence (Before Encoding):")
print("=" * 70)
print(f"Patient ID: {sample_seq.patient_id}")
print(f"Number of visits: {sample_seq.sequence_length}")
print(f"\n📊 Data Shape:")
print(f"   - Type: PatientSequence dataclass")
print(f"   - visits: List[Visit] with length {len(sample_seq.visits)}")

# Show first 3 visits
print(f"\nFirst 3 visits:")
for i, visit in enumerate(sample_seq.visits[:3]):
    print(f"\n  Visit {i+1}:")
    print(f"    Timestamp: {visit.timestamp}")
    print(f"    Number of codes: {visit.num_codes()}")
    print(f"    Code types: {list(visit.codes_by_type.keys())}")
    
    # Show some actual codes
    all_codes = visit.get_all_codes()
    print(f"    Sample codes (first 5): {all_codes[:5]}")
    print(f"    📊 Shape: List[str] with length {len(all_codes)}")

---

## 4. Encode Sequences for LSTM

Transform string codes to integer IDs with proper padding and masking.
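
As a rough picture of what encoding with padding and masking involves, the sketch below encodes toy visits in plain Python. `encode_visits` is a hypothetical helper, not `PatientSequenceBuilder.encode_sequence`; it assumes `[PAD]=0` and `[UNK]=1` as in the vocabulary above.

```python
PAD_ID, UNK_ID = 0, 1  # assumed to match the special-token IDs printed earlier

def encode_visits(visits, vocab, max_codes_per_visit):
    """Encode visits (lists of code strings) to padded ID rows plus masks."""
    visit_codes, visit_mask = [], []
    for codes in visits:
        # Unknown codes fall back to [UNK]; overly long visits are truncated
        ids = [vocab.get(c, UNK_ID) for c in codes][:max_codes_per_visit]
        n_pad = max_codes_per_visit - len(ids)
        visit_codes.append(ids + [PAD_ID] * n_pad)
        visit_mask.append([1] * len(ids) + [0] * n_pad)  # 1 = real code, 0 = padding
    return visit_codes, visit_mask

toy_vocab = {"[PAD]": 0, "[UNK]": 1, "A": 5, "B": 6}
codes, mask = encode_visits([["A", "B"], ["B", "Z"], ["A"]], toy_vocab, max_codes_per_visit=3)
print(codes)  # [[5, 6, 0], [6, 1, 0], [5, 0, 0]]
print(mask)   # [[1, 1, 0], [1, 1, 0], [1, 0, 0]]
```

Every row has the same width after encoding, which is what lets the rows be stacked into a rectangular tensor later.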

In [None]:
# Encode the sample sequence
encoded = sequence_builder.encode_sequence(sample_seq, return_tensors=False)

print("Encoded Sequence (LSTM-Ready Format):")
print("=" * 70)
print(f"Patient ID: {encoded['patient_id']}")
print(f"Sequence length: {encoded['sequence_length']} visits")

print(f"\n📊 Data Shapes After Encoding:")
print(f"   visit_codes: {np.array(encoded['visit_codes']).shape}")
print(f"      → [num_visits={len(encoded['visit_codes'])}, max_codes_per_visit={len(encoded['visit_codes'][0])}]")
print(f"      → Type: List[List[int]]")
print(f"\n   visit_mask: {np.array(encoded['visit_mask']).shape}")
print(f"      → [num_visits={len(encoded['visit_mask'])}, max_codes_per_visit={len(encoded['visit_mask'][0])}]")
print(f"      → Type: List[List[int]] (1=real code, 0=padding)")
print(f"\n   sequence_mask: {np.array(encoded['sequence_mask']).shape}")
print(f"      → [num_visits={len(encoded['sequence_mask'])}]")
print(f"      → Type: List[int] (1=real visit, 0=padding)")
print(f"\n   time_deltas: {np.array(encoded['time_deltas']).shape}")
print(f"      → [num_visits-1={len(encoded['time_deltas'])}]")
print(f"      → Type: List[float] (days between consecutive visits)")

In [None]:
# Show actual encoded values for first visit
print("First Visit - Detailed View:")
print("=" * 70)

first_visit_codes = encoded['visit_codes'][0]
first_visit_mask = encoded['visit_mask'][0]

# Count real vs padded codes
num_real_codes = sum(first_visit_mask)
num_padding = len(first_visit_mask) - num_real_codes

print(f"Real codes: {num_real_codes}")
print(f"Padding: {num_padding}")
print(f"\nFirst 10 code IDs: {first_visit_codes[:10]}")
print(f"First 10 mask values: {first_visit_mask[:10]}")
print(f"\nLast 10 code IDs (should be padding): {first_visit_codes[-10:]}")
print(f"Last 10 mask values (should be 0): {first_visit_mask[-10:]}")
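
The `time_deltas` field is described above as the days between consecutive visits, which is why it has one fewer entry than there are visits. A minimal sketch of that computation, assuming visit timestamps are `datetime` objects (`days_between_visits` is a hypothetical helper, not part of the package):

```python
from datetime import datetime

def days_between_visits(timestamps):
    """Return the gap in days between each pair of consecutive visits."""
    ts = sorted(timestamps)
    return [(b - a).total_seconds() / 86400.0 for a, b in zip(ts, ts[1:])]

stamps = [
    datetime(2020, 1, 1, 9, 0),
    datetime(2020, 1, 8, 9, 0),
    datetime(2020, 1, 8, 21, 0),
]
deltas = days_between_visits(stamps)
print(deltas)  # [7.0, 0.5]
```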

---

## 5. Create Labels for Prediction Task

**Task:** Binary classification - predict if patient has diabetes.

We'll use SNOMED-CT codes for diabetes diagnosis.

In [None]:
# Define diabetes-related codes (SNOMED-CT)
diabetes_codes = {
    '44054006',   # Type 2 diabetes mellitus
    '46635009',   # Type 1 diabetes mellitus
    '73211009',   # Diabetes mellitus
    '11687002',   # Gestational diabetes
    '190330002',  # Diabetes mellitus without complication
    '190331003',  # Diabetes mellitus with complication
}

print(f"Diabetes codes: {diabetes_codes}")
print(f"Number of codes: {len(diabetes_codes)}")

In [None]:
# Create labels for all sequences
def has_diabetes(sequence):
    """Check if patient has diabetes based on their visit codes."""
    for visit in sequence.visits:
        all_codes = visit.get_all_codes()
        if any(code in diabetes_codes for code in all_codes):
            return True
    return False

# Create dataset with labels
dataset_items = []
for seq in sequences:
    encoded = sequence_builder.encode_sequence(seq, return_tensors=False)
    label = 1 if has_diabetes(seq) else 0
    
    dataset_items.append({
        'patient_id': seq.patient_id,
        'visit_codes': encoded['visit_codes'],
        'visit_mask': encoded['visit_mask'],
        'sequence_mask': encoded['sequence_mask'],
        'time_deltas': encoded['time_deltas'],
        'label': label
    })

print(f"\n✅ Created {len(dataset_items)} labeled sequences")
print(f"\n📊 Dataset Item Shape:")
print(f"   - Type: List[Dict]")
print(f"   - Each dict contains:")
print(f"     • patient_id: str")
print(f"     • visit_codes: List[List[int]] shape [num_visits, max_codes_per_visit]")
print(f"     • visit_mask: List[List[int]] shape [num_visits, max_codes_per_visit]")
print(f"     • sequence_mask: List[int] shape [num_visits]")
print(f"     • time_deltas: List[float] shape [num_visits-1]")
print(f"     • label: int (0 or 1)")

# Label distribution
num_positive = sum(item['label'] for item in dataset_items)
num_negative = len(dataset_items) - num_positive

print(f"\nLabel Distribution:")
print(f"   Positive (has diabetes): {num_positive} ({num_positive/len(dataset_items)*100:.1f}%)")
print(f"   Negative (no diabetes): {num_negative} ({num_negative/len(dataset_items)*100:.1f}%)")

---

## 6. Create Batched Tensors for LSTM

Convert to PyTorch tensors and demonstrate batching with proper padding.

In [None]:
# Collate function (same as in train_lstm_baseline.py)
def collate_fn(batch):
    """
    Collate function for DataLoader.
    Handles variable-length sequences and creates proper masks.
    """
    # Extract data
    visit_codes = [item['visit_codes'] for item in batch]
    labels = torch.tensor([item['label'] for item in batch], dtype=torch.float32)
    
    # Get dimensions
    batch_size = len(visit_codes)
    max_visits = max(len(seq) for seq in visit_codes)
    max_codes = max(max(len(visit) for visit in seq) for seq in visit_codes)
    
    # Create padded tensors
    padded_codes = torch.zeros(batch_size, max_visits, max_codes, dtype=torch.long)
    visit_mask = torch.zeros(batch_size, max_visits, max_codes, dtype=torch.bool)
    sequence_mask = torch.zeros(batch_size, max_visits, dtype=torch.bool)
    
    # Fill tensors
    for i, seq in enumerate(visit_codes):
        sequence_mask[i, :len(seq)] = 1
        for j, visit in enumerate(seq):
            padded_codes[i, j, :len(visit)] = torch.tensor(visit)
            visit_mask[i, j, :len(visit)] = 1
    
    return {
        'visit_codes': padded_codes,
        'visit_mask': visit_mask,
        'sequence_mask': sequence_mask,
        'labels': labels.unsqueeze(1)
    }

print("✅ Collate function defined")
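
One caveat about length-based masks: if `encode_sequence` already pads every visit to `max_codes_per_visit` (as the shapes in Section 4 suggest), then `len(visit)` includes padding and the masks built above would mark padded positions as real. A hedged alternative is to derive the masks from the padding ID instead. The sketch below assumes `[PAD]` maps to 0; `masks_from_padding` is a hypothetical helper, not part of `train_lstm_baseline.py`.

```python
import torch

PAD_ID = 0  # assumes [PAD] maps to 0, as in the vocabulary printed earlier

def masks_from_padding(padded_codes: torch.Tensor):
    """Derive masks from a [batch, visits, codes] tensor of code IDs."""
    visit_mask = padded_codes != PAD_ID      # True where a real code sits
    sequence_mask = visit_mask.any(dim=-1)   # True where a visit has at least one real code
    return visit_mask, sequence_mask

toy_codes = torch.tensor([
    [[5, 6, 0], [7, 0, 0]],
    [[8, 0, 0], [0, 0, 0]],
])
toy_visit_mask, toy_sequence_mask = masks_from_padding(toy_codes)
print(toy_sequence_mask.tolist())  # [[True, True], [True, False]]
```

The trade-off is that a genuine visit containing no codes is indistinguishable from padding under this scheme, so it fits best when empty visits are filtered out upstream.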

In [None]:
# Create a sample batch
batch_size = 4
sample_batch = dataset_items[:batch_size]

# Collate the batch
batched_data = collate_fn(sample_batch)

print("Sample Batch (LSTM Model Input):")
print("=" * 70)
print(f"Batch size: {batch_size}")
print(f"\n📊 Tensor Shapes:")
print(f"\n   visit_codes: {batched_data['visit_codes'].shape}")
print(f"      → [batch_size={batched_data['visit_codes'].shape[0]}, ")
print(f"         max_visits={batched_data['visit_codes'].shape[1]}, ")
print(f"         max_codes_per_visit={batched_data['visit_codes'].shape[2]}]")
print(f"      → dtype: {batched_data['visit_codes'].dtype}")
print(f"\n   visit_mask: {batched_data['visit_mask'].shape}")
print(f"      → [batch_size, max_visits, max_codes_per_visit]")
print(f"      → dtype: {batched_data['visit_mask'].dtype}")
print(f"\n   sequence_mask: {batched_data['sequence_mask'].shape}")
print(f"      → [batch_size, max_visits]")
print(f"      → dtype: {batched_data['sequence_mask'].dtype}")
print(f"\n   labels: {batched_data['labels'].shape}")
print(f"      → [batch_size, 1]")
print(f"      → dtype: {batched_data['labels'].dtype}")

print(f"\n\nMemory footprint:")
print(f"   visit_codes: {batched_data['visit_codes'].numel() * 8 / 1024:.2f} KB")
print(f"   visit_mask: {batched_data['visit_mask'].numel() / 1024:.2f} KB")
print(f"   sequence_mask: {batched_data['sequence_mask'].numel() / 1024:.2f} KB")

In [None]:
# Visualize batch structure for first patient
print("First Patient in Batch - Detailed View:")
print("=" * 70)

patient_0_codes = batched_data['visit_codes'][0]
patient_0_visit_mask = batched_data['visit_mask'][0]
patient_0_seq_mask = batched_data['sequence_mask'][0]
patient_0_label = batched_data['labels'][0]

# Count real visits
num_real_visits = patient_0_seq_mask.sum().item()
print(f"Number of real visits: {num_real_visits}")
print(f"Label: {patient_0_label.item()} ({'Diabetes' if patient_0_label.item() == 1 else 'No Diabetes'})")

# Show first visit details
print(f"\nFirst visit:")
first_visit_codes = patient_0_codes[0]
first_visit_mask = patient_0_visit_mask[0]
num_real_codes = first_visit_mask.sum().item()
print(f"   Real codes: {num_real_codes}")
print(f"   Code IDs (first 10): {first_visit_codes[:10].tolist()}")
print(f"   Mask (first 10): {first_visit_mask[:10].tolist()}")

---

## 7. Create LSTM Model and Test Forward Pass

Instantiate the LSTM baseline model and run a forward pass to verify shapes.

In [None]:
# Create LSTM model
model = create_lstm_baseline(
    vocab_size=len(vocab),
    task='binary_classification',
    model_size='small'
)

print("LSTM Baseline Model:")
print("=" * 70)
print(f"Vocabulary size: {len(vocab)}")
print(f"Embedding dim: 128")
print(f"Hidden dim: 256")
print(f"Number of layers: 1")
print(f"Task: Binary classification")
print(f"\nTotal parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
# Run forward pass
model.eval()
with torch.no_grad():
    output = model(
        visit_codes=batched_data['visit_codes'],
        visit_mask=batched_data['visit_mask'],
        sequence_mask=batched_data['sequence_mask'],
        return_hidden=True
    )

print("Model Output:")
print("=" * 70)
print(f"\n📊 Output Shapes:")
print(f"\n   logits: {output['logits'].shape}")
print(f"      → [batch_size={output['logits'].shape[0]}, output_dim={output['logits'].shape[1]}]")
print(f"      → Raw predictions before sigmoid")
print(f"\n   predictions: {output['predictions'].shape}")
print(f"      → [batch_size={output['predictions'].shape[0]}, output_dim={output['predictions'].shape[1]}]")
print(f"      → After sigmoid activation (probabilities)")
print(f"\n   hidden_states: {output['hidden_states'].shape}")
print(f"      → [batch_size={output['hidden_states'].shape[0]}, ")
print(f"         num_visits={output['hidden_states'].shape[1]}, ")
print(f"         hidden_dim={output['hidden_states'].shape[2]}]")
print(f"      → LSTM hidden states for each visit")

print(f"\n\nPredictions:")
for i in range(batch_size):
    prob = output['predictions'][i, 0].item()
    true_label = batched_data['labels'][i, 0].item()
    print(f"   Patient {i+1}: P(diabetes) = {prob:.4f}, True label = {int(true_label)}")

---

## 8. Summary: Complete Data Flow

Let's visualize the complete transformation pipeline.

In [None]:
print("Complete Data Transformation Pipeline:")
print("=" * 70)
print("\n1. Raw Synthea CSV Files")
print("   └─> patients.csv, encounters.csv, conditions.csv, etc.")
print("\n2. SyntheaAdapter.load_events()")
print("   └─> List[MedicalEvent]")
print("       • Each event has: patient_id, timestamp, code, code_type")
print("\n3. VisitGrouper.group_events()")
print("   └─> List[Visit] per patient, collected into Dict[str, List[Visit]]")
print("       • Key: patient_id")
print("       • Value: List of Visit objects")
print("       • Each Visit has: visit_id, timestamp, codes_by_type")
print("\n4. PatientSequenceBuilder.build_sequences()")
print("   └─> List[PatientSequence]")
print("       • Each sequence has: patient_id, visits, sequence_length")
print("\n5. PatientSequenceBuilder.encode_sequence()")
print("   └─> Dict with:")
print("       • visit_codes: List[List[int]] - [num_visits, max_codes_per_visit]")
print("       • visit_mask: List[List[int]] - [num_visits, max_codes_per_visit]")
print("       • sequence_mask: List[int] - [num_visits]")
print("\n6. Add Labels")
print("   └─> List[Dict] with encoded data + label")
print("       • label: int (0 or 1 for binary classification)")
print("\n7. collate_fn() - Batch Creation")
print("   └─> Dict with PyTorch tensors:")
print("       • visit_codes: [batch_size, max_visits, max_codes_per_visit]")
print("       • visit_mask: [batch_size, max_visits, max_codes_per_visit]")
print("       • sequence_mask: [batch_size, max_visits]")
print("       • labels: [batch_size, 1]")
print("\n8. LSTM Model Forward Pass")
print("   └─> Dict with:")
print("       • logits: [batch_size, 1] - Raw predictions")
print("       • predictions: [batch_size, 1] - Probabilities (after sigmoid)")
print("       • hidden_states: [batch_size, max_visits, hidden_dim]")
print("\n" + "=" * 70)

In [None]:
# Create a visual diagram
fig, ax = plt.subplots(figsize=(14, 10))
ax.axis('off')

# Define stages
stages = [
    ("Raw Data\n(CSV Files)", "patients.csv\nencounters.csv\nconditions.csv\nobservations.csv"),
    ("Medical Events\n(List[MedicalEvent])", f"{sum(len(adapter.load_events(patient_ids=[pid])) for pid in patient_ids[:5])} events\n(sample)"),
    ("Visit Groups\n(Dict[str, List[Visit]])", f"{len(patient_visits)} patients\n{sum(len(v) for v in patient_visits.values())} visits"),
    ("Patient Sequences\n(List[PatientSequence])", f"{len(sequences)} sequences\nmin_visits ≥ 2"),
    ("Encoded Sequences\n(List[Dict])", f"{len(dataset_items)} items\nwith labels"),
    ("Batched Tensors\n(PyTorch)", f"[{batch_size}, {batched_data['visit_codes'].shape[1]}, {batched_data['visit_codes'].shape[2]}]"),
    ("Model Output\n(Predictions)", f"[{batch_size}, 1]\nprobabilities")
]

y_positions = np.linspace(0.9, 0.1, len(stages))

for i, ((title, desc), y) in enumerate(zip(stages, y_positions)):
    # Draw box
    box = plt.Rectangle((0.2, y-0.05), 0.6, 0.08, 
                        facecolor='lightblue', edgecolor='black', linewidth=2)
    ax.add_patch(box)
    
    # Add text
    ax.text(0.5, y+0.02, title, ha='center', va='center', 
           fontsize=12, fontweight='bold')
    ax.text(0.5, y-0.02, desc, ha='center', va='center', 
           fontsize=9, style='italic')
    
    # Draw arrow to next stage
    if i < len(stages) - 1:
        ax.arrow(0.5, y-0.05, 0, -0.04, head_width=0.03, head_length=0.01,
                fc='black', ec='black', linewidth=2)

ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
plt.title('EHR Data Transformation Pipeline for LSTM', fontsize=16, fontweight='bold', pad=20)
plt.tight_layout()
plt.show()

print("\n✅ Pipeline visualization complete!")

---

## Conclusion

This notebook demonstrated the complete pipeline for preparing EHR sequences for the LSTM baseline model:

1. ✅ **Loaded and grouped** raw Synthea data into visits
2. ✅ **Built patient sequences** with vocabulary
3. ✅ **Encoded sequences** to integer IDs with padding/masking
4. ✅ **Created labels** for diabetes prediction task
5. ✅ **Batched data** into PyTorch tensors
6. ✅ **Ran model forward pass** to verify shapes

### Key Takeaways:

- **Visit-level representation**: Codes within each visit are aggregated (mean/sum/attention)
- **Sequence-level modeling**: LSTM captures temporal dependencies across visits
- **Proper masking**: Essential for handling variable-length sequences
- **Shape transformations**: From raw CSV → tensors → predictions
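
The first three takeaways can be sketched together in a few lines of PyTorch. This is only an illustration on toy tensors, not the `create_lstm_baseline` implementation: mean pooling is just one of the aggregation options named above, and the dimensions, values, and variable names are assumptions.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

torch.manual_seed(0)

# Toy batch: 2 patients, up to 3 visits, up to 4 codes per visit (0 = [PAD])
visit_codes = torch.tensor([
    [[5, 6, 0, 0], [7, 0, 0, 0], [5, 8, 9, 0]],
    [[6, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
])
visit_mask = visit_codes != 0            # [batch, visits, codes]
sequence_mask = visit_mask.any(dim=-1)   # [batch, visits]

embedding = nn.Embedding(num_embeddings=10, embedding_dim=8, padding_idx=0)
lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)

# Visit-level representation: masked mean over the codes in each visit
code_vecs = embedding(visit_codes)                    # [2, 3, 4, 8]
mask = visit_mask.unsqueeze(-1).float()               # [2, 3, 4, 1]
visit_vecs = (code_vecs * mask).sum(dim=2) / mask.sum(dim=2).clamp(min=1.0)  # [2, 3, 8]

# Sequence-level modeling: pack the batch so the LSTM skips padded visits
lengths = sequence_mask.sum(dim=1)                    # real visits per patient
packed = pack_padded_sequence(
    visit_vecs, lengths.cpu(), batch_first=True, enforce_sorted=False
)
_, (h_n, _) = lstm(packed)
patient_vecs = h_n[-1]                                # [2, 16], one vector per patient

print(visit_vecs.shape, patient_vecs.shape)
```

The `clamp(min=1.0)` guards against division by zero for fully padded visits, whose pooled vector stays at zero.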

### Next Steps:

- See `data_shape_transformations.md` for detailed shape documentation
- See `examples/train_lstm_baseline.py` for full training script
- Experiment with different prediction tasks (readmission, mortality, etc.)
- Try different model configurations (attention, bidirectional LSTM, etc.)