# TFM-Tokenizer for EEG Signal Tokenization

This notebook demonstrates the TFM-Tokenizer model for tokenizing EEG signals into discrete tokens and continuous embeddings.

**Note**: This example uses random dummy data, so the losses, tokens, embeddings, and clusters it produces carry no physiological meaning. The EEG-specific processor for generating STFT features from raw signals is under development.

## 1. Environment Setup

In [1]:
import random
import numpy as np
import torch

from pyhealth.datasets import create_sample_dataset, get_dataloader
from pyhealth.models import TFMTokenizer, get_tfm_tokenizer_2x2x8

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on device: {device}")

## 2. Create Sample Dataset

TFM-Tokenizer expects two inputs:
- `stft`: STFT spectrogram of shape (n_freq, n_time), e.g., (100, 60)
- `signal`: Raw temporal signal of shape (n_samples,), e.g., (1280,)
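Until the EEG processor is available, here is a minimal sketch of deriving an `stft` input from a raw `signal` with `torch.stft`. The window length, hop size, and log-magnitude scaling are assumptions, not the settings TFM-Tokenizer was trained with. Adjust them, or crop/pad the time axis, so the output matches the shape your model expects.

```python
import torch


def compute_stft(signal: torch.Tensor, n_fft: int = 198, hop_length: int = 20) -> torch.Tensor:
    """Log-magnitude STFT of a 1-D signal, returned as (n_freq, n_time).

    n_fft and hop_length are illustrative assumptions; n_fft=198 yields
    n_fft // 2 + 1 = 100 frequency bins.
    """
    window = torch.hann_window(n_fft)
    spec = torch.stft(
        signal,
        n_fft=n_fft,
        hop_length=hop_length,
        window=window,
        center=False,  # frames start at sample 0, no edge padding
        return_complex=True,
    )
    # Magnitude, compressed with log1p to reduce the dynamic range
    return torch.log1p(spec.abs())


raw = torch.randn(1280)  # dummy raw signal, same length as in the samples below
stft = compute_stft(raw)
print(stft.shape)  # torch.Size([100, 55]) with these settings
```

With `center=False` the frame count is `1 + (n_samples - n_fft) // hop_length`, which is `1 + (1280 - 198) // 20 = 55` here. Reaching the demo's (100, 60) shape would need different settings or padding/cropping of the time axis.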

For demonstration, we'll use dummy data:

In [None]:
# Create dummy samples (in practice, these would come from EEG preprocessing)
samples = [
    {
        "patient_id": f"patient-{i}",
        "visit_id": "visit-0",
        "stft": torch.randn(100, 60).numpy().tolist(),  # STFT spectrogram
        "signal": torch.randn(1280).numpy().tolist(),   # Raw signal
        "label": i % 6,  # 6 classes for TUEV events
    }
    for i in range(50)
]

input_schema = {
    "stft": "tensor",
    "signal": "tensor",
}
output_schema = {"label": "multiclass"}

dataset = create_sample_dataset(
    samples=samples,
    input_schema=input_schema,
    output_schema=output_schema,
    dataset_name="tfm_demo",
)

print(f"Created dataset with {len(dataset)} samples")
print(f"Input schema: {dataset.input_schema}")
print(f"Output schema: {dataset.output_schema}")

## 3. Split Dataset

In [None]:
from pyhealth.datasets.splitter import split_by_sample

train_data, val_data, test_data = split_by_sample(dataset, [0.7, 0.15, 0.15], seed=SEED)

print(f"Train: {len(train_data)} samples")
print(f"Val: {len(val_data)} samples")
print(f"Test: {len(test_data)} samples")

train_loader = get_dataloader(train_data, batch_size=8, shuffle=True)
val_loader = get_dataloader(val_data, batch_size=8, shuffle=False)
test_loader = get_dataloader(test_data, batch_size=8, shuffle=False)

## 4. Initialize TFM-Tokenizer Model

In [None]:
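# Frequency and temporal encoders are 2 layers deep each, the decoder is 8 layers
# deep, and use_classifier adds a classification head of depth 4
model = TFMTokenizer(
    dataset=dataset,
    emb_size=64,          # embedding dimension
    code_book_size=8192,  # number of discrete tokens in the codebook
    trans_freq_encoder_depth=2,
    trans_temporal_encoder_depth=2,
    trans_decoder_depth=8,
    use_classifier=True,
    classifier_depth=4,
)

model = model.to(device)
n_params = sum(p.numel() for p in model.parameters())
print(f"Model created with {n_params:,} parameters")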
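A single total hides where the capacity sits (for example, an 8-layer decoder versus 2-layer encoders). The helper below is a sketch that uses only standard `torch.nn.Module` methods to break the count down by top-level submodule. The names it reports depend on how `TFMTokenizer` is built internally, which this sketch does not assume.

```python
import torch.nn as nn


def parameter_breakdown(module: nn.Module) -> dict:
    """Count parameters per direct child module, plus any owned directly."""
    counts = {
        name: sum(p.numel() for p in child.parameters())
        for name, child in module.named_children()
    }
    # Parameters registered on the module itself rather than inside a child
    own = sum(p.numel() for p in module.parameters(recurse=False))
    if own:
        counts["<own>"] = own
    return counts


# Toy usage; in the notebook, call parameter_breakdown(model) instead
toy = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
for name, n in parameter_breakdown(toy).items():
    print(f"{name:>10}: {n:,}")
```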

## 5. Test Forward Pass

In [None]:
batch = next(iter(train_loader))

model.eval()  # disable dropout (if any) so the check is deterministic
with torch.no_grad():
    outputs = model(**batch)

print("Output keys:", outputs.keys())
print(f"Loss: {outputs['loss'].item():.4f}")
print(f"Logits shape: {outputs['logit'].shape}")
print(f"Tokens shape: {outputs['tokens'].shape}")
print(f"Embeddings shape: {outputs['embeddings'].shape}")
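The `tokens` and `embeddings` outputs are linked through a codebook with `code_book_size` entries. A standard way to turn continuous vectors into discrete tokens is nearest-neighbor vector quantization. The plain PyTorch sketch below illustrates that general technique only. It is not TFM-Tokenizer's internal code, and the `(batch, tokens, emb_size)` shape is a hypothetical example.

```python
import torch


def quantize(z: torch.Tensor, codebook: torch.Tensor):
    """Nearest-neighbor vector quantization.

    z:        (..., emb_size) continuous vectors
    codebook: (code_book_size, emb_size) code vectors
    Returns token IDs shaped z.shape[:-1] and quantized vectors shaped like z.
    """
    flat = z.reshape(-1, z.shape[-1])
    # Euclidean distance from every vector to every code
    dists = torch.cdist(flat, codebook)
    ids = dists.argmin(dim=1)  # index of the closest code = discrete token
    quantized = codebook[ids].reshape(z.shape)
    return ids.reshape(z.shape[:-1]), quantized


torch.manual_seed(0)
codebook = torch.randn(8192, 64)  # code_book_size=8192, emb_size=64 as configured above
z = torch.randn(8, 60, 64)        # hypothetical (batch, tokens, emb_size) encoder output
tokens, z_q = quantize(z, codebook)
print(tokens.shape, z_q.shape)    # torch.Size([8, 60]) torch.Size([8, 60, 64])
```

During training, VQ models usually pass gradients through the non-differentiable `argmin` with a straight-through estimator. That detail is omitted here.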

## 6. Train Model (Optional)

Train the model using PyHealth's Trainer:

In [None]:
from pyhealth.trainer import Trainer

trainer = Trainer(model=model, device=device)
trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    epochs=3,
    monitor="accuracy",
)

## 7. Extract Embeddings for Analysis

Extract patient embeddings for downstream tasks like clustering or conformal prediction:

In [None]:
# Extract embeddings from the test set (eval mode disables dropout, if any)
model.eval()
test_embeddings = model.get_embeddings(test_loader)
print(f"Test embeddings shape: {test_embeddings.shape}")

# Get patient-level representation (mean pooling)
patient_embeddings = test_embeddings.mean(dim=1)
print(f"Patient-level embeddings shape: {patient_embeddings.shape}")
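For conformal prediction, the simplest variant is split conformal classification. It calibrates a score threshold on held-out data, then returns every class whose score clears it. Under exchangeability, the resulting sets contain the true label with probability at least `1 - alpha` on average. This sketch works on softmax probabilities (for example, computed from `outputs['logit']` on a calibration split) rather than embeddings. Embedding-based variants, which additionally weight calibration samples by similarity, are not shown. The function name and `alpha` value are illustrative, and `method=` requires NumPy 1.22 or later.

```python
import numpy as np


def split_conformal_sets(cal_probs, cal_labels, test_probs, alpha=0.1):
    """Split conformal prediction sets using the score 1 - p(true class).

    cal_probs:  (n_cal, n_classes) softmax probabilities on a calibration set
    cal_labels: (n_cal,) integer labels
    test_probs: (n_test, n_classes)
    Returns a boolean (n_test, n_classes) mask of the classes in each set.
    """
    n = len(cal_labels)
    scores = 1.0 - cal_probs[np.arange(n), cal_labels]
    # Finite-sample corrected quantile level, capped at 1
    level = min(1.0, np.ceil((n + 1) * (1 - alpha)) / n)
    qhat = np.quantile(scores, level, method="higher")
    return (1.0 - test_probs) <= qhat


rng = np.random.default_rng(0)
logits = rng.normal(size=(200, 6))  # dummy scores for 6 classes
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
labels = rng.integers(0, 6, size=200)
sets = split_conformal_sets(probs[:100], labels[:100], probs[100:], alpha=0.1)
coverage = sets[np.arange(100), labels[100:]].mean()
print(f"Average set size: {sets.sum(axis=1).mean():.2f}, coverage: {coverage:.2f}")
```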

## 8. Extract Discrete Tokens

In [None]:
# Extract tokens from test set
test_tokens = model.get_tokens(test_loader)
print(f"Test tokens shape: {test_tokens.shape}")

# Analyze token vocabulary usage
unique_tokens = torch.unique(test_tokens)
print(f"Active tokens: {len(unique_tokens)} / {model.code_book_size}")
print(f"Token usage: {len(unique_tokens) / model.code_book_size * 100:.2f}%")
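Counting distinct IDs shows how many codes appear at least once, but not how evenly they are used. Codebook perplexity, the exponential of the entropy of the token histogram, is a common complementary metric for vector-quantized models. It equals the number of active codes when usage is uniform and falls toward 1 as usage concentrates on a few codes. The sketch below assumes the token tensor holds non-negative integer code IDs with no padding values.

```python
import torch


def codebook_stats(tokens: torch.Tensor, code_book_size: int):
    """Return (number of active codes, codebook perplexity) for token IDs."""
    counts = torch.bincount(tokens.flatten().long(), minlength=code_book_size).float()
    probs = counts / counts.sum()
    nonzero = probs[probs > 0]  # 0 * log(0) is treated as 0
    entropy = -(nonzero * nonzero.log()).sum()
    return int((counts > 0).sum()), float(entropy.exp())


# Toy check: 4 codes used equally often in a codebook of 8192
toy_tokens = torch.tensor([0, 1, 2, 3] * 25)
active, perplexity = codebook_stats(toy_tokens, 8192)
print(active, round(perplexity, 3))  # 4 4.0
```

In the notebook this would be called as `codebook_stats(test_tokens, model.code_book_size)`.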

## 9. Patient Clustering (Example)

Use embeddings for k-means clustering:

In [None]:
from sklearn.cluster import KMeans

# Cluster patients based on embeddings; n_init is set explicitly because
# its default changed across scikit-learn versions
kmeans = KMeans(n_clusters=3, n_init=10, random_state=SEED)
clusters = kmeans.fit_predict(patient_embeddings.cpu().numpy())

print("Cluster distribution:")
unique, counts = np.unique(clusters, return_counts=True)
for cluster_id, count in zip(unique, counts):
    print(f"  Cluster {cluster_id}: {count} patients ({count/len(clusters)*100:.1f}%)")
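`n_clusters=3` is a fixed choice. One common way to compare candidate values of k is scikit-learn's silhouette score, sketched below on synthetic data standing in for `patient_embeddings`. With random dummy embeddings no k should stand out, so this only becomes informative on real EEG data. `silhouette_score` also requires between 2 and `n_samples - 1` clusters, which limits the range of k on a test split this small.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics import silhouette_score


def best_k_by_silhouette(X: np.ndarray, k_values, seed: int = 42):
    """Fit k-means for each k and return (best k, {k: silhouette score})."""
    scores = {}
    for k in k_values:
        labels = KMeans(n_clusters=k, n_init=10, random_state=seed).fit_predict(X)
        scores[k] = silhouette_score(X, labels)
    # Higher silhouette means tighter, better-separated clusters
    return max(scores, key=scores.get), scores


# Synthetic stand-in: 4 well-separated groups in 64 dimensions
X, _ = make_blobs(n_samples=120, centers=4, n_features=64, random_state=0)
best_k, scores = best_k_by_silhouette(X, range(2, 7))
print(best_k, {k: round(s, 3) for k, s in scores.items()})
```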

## 10. Loading Pre-trained Weights

Load pre-trained weights from a checkpoint:

In [None]:
# Uncomment and set the path to load pre-trained weights
# model.load_pretrained_weights("path/to/tfm_encoder_best_model.pth")
print("To load pre-trained weights:")
print("model.load_pretrained_weights('tfm_encoder_best_model.pth')")
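If `load_pretrained_weights` rejects a checkpoint, or you want to see which keys actually match, a generic PyTorch fallback is to load the state dict yourself with `strict=False` and inspect the missing and unexpected keys it reports. The layouts handled below are assumptions about how the file might have been saved, not a documented format of the TFM-Tokenizer checkpoints: an optional `"state_dict"` or `"model_state_dict"` wrapper, and a `"module."` prefix left by `nn.DataParallel`. The helper name is hypothetical.

```python
import os
import tempfile

import torch
import torch.nn as nn


def load_state_dict_loosely(model: nn.Module, path: str):
    """Hypothetical helper: load weights non-strictly and report key mismatches."""
    ckpt = torch.load(path, map_location="cpu")
    # Unwrap common container formats (assumed; adjust to your checkpoint)
    for key in ("state_dict", "model_state_dict"):
        if isinstance(ckpt, dict) and key in ckpt:
            ckpt = ckpt[key]
            break
    # Strip the "module." prefix that nn.DataParallel adds to parameter names
    state = {k[len("module."):] if k.startswith("module.") else k: v for k, v in ckpt.items()}
    # strict=False skips missing/unexpected keys; shape mismatches still raise
    result = model.load_state_dict(state, strict=False)
    print(f"missing: {len(result.missing_keys)}, unexpected: {len(result.unexpected_keys)}")
    return result


# Toy round trip; in the notebook: load_state_dict_loosely(model, "path/to/tfm_encoder_best_model.pth")
src = nn.Linear(4, 2)
dst = nn.Linear(4, 2)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "ckpt.pth")
    torch.save({"state_dict": {f"module.{k}": v for k, v in src.state_dict().items()}}, path)
    result = load_state_dict_loosely(dst, path)
```

Because `strict=False` ignores mismatched names silently, check that both printed counts are what you expect before trusting the loaded model.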