# TUEV + TFMTokenizer

This notebook trains the `TFMTokenizer` on TUEV.

## 1. Environment Setup
Seed randomness, import dependencies, and select compute device.

In [1]:
from __future__ import annotations

import random
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange

from pyhealth.datasets import TUEVDataset
from pyhealth.datasets.splitter import split_by_sample
from pyhealth.datasets.utils import get_dataloader
from pyhealth.models import TFMTokenizer
from pyhealth.tasks import EEGEventsTUEV

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on device: {DEVICE}")

Running on device: cuda


## 2. Configuration
Set dataset path and quick-run hyperparameters.

In [2]:
ROOT = Path(r"../../downloads/tuev/v2.0.1/edf")
CACHE_DIR = "cache"

RESAMPLING_RATE = 200
BATCH_SIZE = 8
EPOCHS = 20
LR = 1e-4

DEBUG_MAX_SAMPLES = 512  # set to None for full dataset

if not ROOT.exists():
    raise FileNotFoundError(f"TUEV root not found: {ROOT}. Update ROOT before running.")

print({
    "ROOT": str(ROOT),
    "RESAMPLING_RATE": RESAMPLING_RATE,
    "BATCH_SIZE": BATCH_SIZE,
    "EPOCHS": EPOCHS,
    "LR": LR,
    "DEBUG_MAX_SAMPLES": DEBUG_MAX_SAMPLES,
})

{'ROOT': '..\\..\\downloads\\tuev\\v2.0.1\\edf', 'RESAMPLING_RATE': 200, 'BATCH_SIZE': 8, 'EPOCHS': 20, 'LR': 0.0001, 'DEBUG_MAX_SAMPLES': 512}


## 3. Load and Split TUEV Dataset
Create task samples using `EEGEventsTUEV`, then split into train/val/test loaders.

In [3]:
dataset = TUEVDataset(root=str(ROOT), subset="both", dev=True)
sample_dataset = dataset.set_task(EEGEventsTUEV(), cache_dir=CACHE_DIR)

if DEBUG_MAX_SAMPLES is not None:
    n = min(DEBUG_MAX_SAMPLES, len(sample_dataset))
    sample_dataset = sample_dataset.subset(list(range(n)))

print(f"Total task samples: {len(sample_dataset)}")
print(f"Input schema: {sample_dataset.input_schema}")
print(f"Output schema: {sample_dataset.output_schema}")

train_ds, val_ds, test_ds = split_by_sample(sample_dataset, [0.7, 0.1, 0.2], seed=SEED)
train_loader = get_dataloader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = get_dataloader(val_ds, batch_size=BATCH_SIZE, shuffle=False) if len(val_ds) else None
test_loader = get_dataloader(test_ds, batch_size=BATCH_SIZE, shuffle=False) if len(test_ds) else None

print(f"Train/Val/Test sizes: {len(train_ds)}, {len(val_ds)}, {len(test_ds)}")

No config path provided, using default config
Using both train and eval subsets
Initializing tuev dataset from ..\..\downloads\tuev\v2.0.1\edf (dev mode: True)
Setting task EEG_events for tuev base dataset...
Found cached processed samples at cache\samples_47c27255-9fc0-5271-bd99-638ffecdb1cc.ld, skipping processing.
Total task samples: 512
Input schema: {'signal': 'tensor'}
Output schema: {'label': 'multiclass'}
Train/Val/Test sizes: 358, 51, 103


  return SampleDataset(


## 4. STFT Function (Torch)
Use the exact STFT helper for multi-channel EEG batches.

In [4]:
def get_stft_torch(X, resampling_rate = 200):
    B,C,T = X.shape
    x_temp = rearrange(X, 'B C T -> (B C) T')
    window = torch.hann_window(resampling_rate).to(x_temp.device)
    x_stft_temp = torch.abs(torch.stft(x_temp, n_fft=resampling_rate, hop_length=resampling_rate//2, 
                          onesided = True,
                          return_complex=True, center = False,#normalized = True,
                          window = window)[:,:resampling_rate//2,:])
    
    x_stft_temp = rearrange(x_stft_temp, '(B C) F T -> B C F T', B=B)
    
    return x_stft_temp

## 5. Inspect One Batch
Verify raw EEG batch shape and derived STFT shape before training.

In [5]:
batch0 = next(iter(train_loader))
x0 = batch0["signal"].float().to(DEVICE)
x0_stft = get_stft_torch(x0, resampling_rate=RESAMPLING_RATE)

print("signal shape:", tuple(x0.shape), "(B, C, T)")
print("stft shape:", tuple(x0_stft.shape), "(B, C, F, T_stft)")
print("label shape:", tuple(batch0["label"].shape))

signal shape: (8, 16, 1000) (B, C, T)
stft shape: (8, 16, 100, 9) (B, C, F, T_stft)
label shape: (8,)


## 6. Initialize TFMTokenizer
Instantiate `TFMTokenizer` on the TUEV task dataset and optimizer.

In [6]:
model = TFMTokenizer(
    dataset=sample_dataset,
    emb_size=64,
    code_book_size=8192,
    trans_freq_encoder_depth=2,
    trans_temporal_encoder_depth=2,
    trans_decoder_depth=8,
    use_classifier=True,
)
model = model.to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=LR)
criterion = model.get_loss_function()

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

Model parameters: 1,891,114


## 7. Training Utilities
Use STFT + multi-channel flattening in each training/eval step.

In [7]:
def run_step(model, batch, training=True):
    x = batch["signal"].float().to(DEVICE)          # (B, C, T)
    y = batch["label"].long().to(DEVICE)

    x_stft = get_stft_torch(x, resampling_rate=RESAMPLING_RATE)

    x_stft = rearrange(x_stft, "B C F T -> (B C) F T")
    x = rearrange(x, "B C T -> (B C) T")

    reconstructed, tokens_flat, quant_out, quant_in = model.tokenizer(x_stft, x)
    recon_loss = F.mse_loss(reconstructed, x_stft)
    vq_loss, _, _ = model.tokenizer.vec_quantizer_loss(quant_in, quant_out)

    B, C = batch["signal"].shape[0], batch["signal"].shape[1]
    tokens_mc = tokens_flat.reshape(B, C, -1)
    logits = model.classifier(tokens_mc, num_ch=C)

    cls_loss = criterion(logits, y)
    total_loss = recon_loss + vq_loss + cls_loss

    if training:
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

    preds = torch.argmax(logits, dim=1)
    acc = (preds == y).float().mean().item()

    return {
        "loss": total_loss.item(),
        "cls_loss": cls_loss.item(),
        "recon_loss": recon_loss.item(),
        "vq_loss": vq_loss.item(),
        "acc": acc,
    }

@torch.no_grad()
def evaluate(model, loader):
    if loader is None:
        return None

    model.eval()
    stats = []
    for batch in loader:
        stats.append(run_step(model, batch, training=False))

    keys = stats[0].keys()
    return {k: float(np.mean([s[k] for s in stats])) for k in keys}

## 8. Train
Run a short training loop and report train/val metrics per epoch.

In [8]:
for epoch in range(1, EPOCHS + 1):
    model.train()
    train_stats = []

    for batch in train_loader:
        train_stats.append(run_step(model, batch, training=True))

    train_mean = {k: float(np.mean([s[k] for s in train_stats])) for k in train_stats[0].keys()}
    val_mean = evaluate(model, val_loader)

    print(f"Epoch {epoch}/{EPOCHS}")
    print("  Train:", train_mean)
    if val_mean is not None:
        print("  Val:  ", val_mean)

Epoch 1/20
  Train: {'loss': 11417.664046223957, 'cls_loss': 0.9261853244569567, 'recon_loss': 11416.726513671874, 'vq_loss': 0.011308995944758257, 'acc': 0.7638888888888888}
  Val:   {'loss': 7567.515973772322, 'cls_loss': 0.6936458434377398, 'recon_loss': 7566.812430245535, 'vq_loss': 0.009907172726733344, 'acc': 0.7738095266478402}
Epoch 2/20
  Train: {'loss': 11375.016617838543, 'cls_loss': 0.6516648951503966, 'recon_loss': 11374.355457899306, 'vq_loss': 0.009574467399054104, 'acc': 0.796296297179328}
  Val:   {'loss': 7536.031389508928, 'cls_loss': 0.6579780748912266, 'recon_loss': 7535.36408342634, 'vq_loss': 0.009364819686327661, 'acc': 0.7738095266478402}
Epoch 3/20
  Train: {'loss': 11349.942599826389, 'cls_loss': 0.6224249478843477, 'recon_loss': 11349.3109375, 'vq_loss': 0.009151072489718597, 'acc': 0.7972222222222223}
  Val:   {'loss': 7508.786865234375, 'cls_loss': 0.6490445605346135, 'recon_loss': 7508.128557477678, 'vq_loss': 0.009073269353913409, 'acc': 0.77380952664784

## 9. Test Evaluation
Evaluate the trained model on the test split.

In [9]:
test_mean = evaluate(model, test_loader)
print("Test metrics:", test_mean)

Test metrics: {'loss': 10736.539438100961, 'cls_loss': 0.09225075351647459, 'recon_loss': 10736.439077524039, 'vq_loss': 0.007972102218235914, 'acc': 0.9711538461538461}


## 10. Extract Tokens (Optional)
Extract flattened multi-channel tokens from one test batch for downstream conformal workflows.

In [10]:
@torch.no_grad()
def extract_tokens_one_batch(model, loader):
    model.eval()
    batch = next(iter(loader))

    x = batch["signal"].float().to(DEVICE)
    x_stft = get_stft_torch(x, resampling_rate=RESAMPLING_RATE)

    B, C = x.shape[0], x.shape[1]
    x_stft = rearrange(x_stft, "B C F T -> (B C) F T")
    x = rearrange(x, "B C T -> (B C) T")

    _, tokens_flat, _, _ = model.tokenizer(x_stft, x)
    tokens_mc = tokens_flat.reshape(B, C, -1)
    return tokens_mc.cpu()

if test_loader is not None and len(test_ds) > 0:
    token_batch = extract_tokens_one_batch(model, test_loader)
    print("Token batch shape:", tuple(token_batch.shape), "(B, C, T_tokens)")
else:
    print("No test data available for token extraction.")

Token batch shape: (8, 16, 9) (B, C, T_tokens)
