## 1. Environment Setup
Seed the random generators, import core dependencies, and detect the training device.

In [1]:
import random

import numpy as np
import torch

from pyhealth.datasets import TUEVDataset
from pyhealth.tasks import EEGEventsTUEV
from pyhealth.datasets.splitter import split_by_sample
from pyhealth.datasets.utils import get_dataloader
from pyhealth.models import TFMTokenizer

SEED = 5
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on device: {device}")

Running on device: cuda


## 2. Load TUEV Dataset
Point to the TUEV dataset root and load the dataset.

In [2]:
dataset = TUEVDataset(
    root='/srv/local/data/TUH/tuh_eeg_events/v2.0.0/edf',  # Update this path
    subset='eval',
    # dev=True
)
dataset.stats()

No config path provided, using default config
Using subset: eval
Using cached metadata from /home/jp65/.cache/pyhealth/tuev
Initializing tuev dataset from /home/jp65/.cache/pyhealth/tuev (dev mode: False)
No cache_dir provided. Using default cache dir: /home/jp65/.cache/pyhealth/2e360290-5ee2-591d-aaca-45052d892fb1
Found cached event dataframe: /home/jp65/.cache/pyhealth/2e360290-5ee2-591d-aaca-45052d892fb1/global_event_df.parquet
Dataset: tuev
Dev mode: False
Number of patients: 80
Number of events: 159


## 3. Prepare PyHealth Dataset
Set the task for the dataset and convert raw samples into PyHealth format for abnormal EEG classification.

In [3]:
sample_dataset = dataset.set_task(EEGEventsTUEV(
    resample_rate=200,    # Resample rate
    bandpass_filter=(0.1, 75.0),    # Bandpass filter
    notch_filter=50.0,    # Notch filter
    normalization='95th_percentile'
))

print(f"Total task samples: {len(sample_dataset)}")
print(f"Input schema: {sample_dataset.input_schema}")
print(f"Output schema: {sample_dataset.output_schema}")

# Inspect a sample
sample = sample_dataset[0]
print(f"\nSample keys: {sample.keys()}")
print(f"Signal shape: {sample['signal'].shape}")
print(f"Label: {sample['label']}")

Setting task EEG_events for tuev base dataset...
Task cache paths: task_df=/home/jp65/.cache/pyhealth/2e360290-5ee2-591d-aaca-45052d892fb1/tasks/EEG_events_107ed06b-ca3c-5d49-87ae-a690ae834ab8/task_df.ld, samples=/home/jp65/.cache/pyhealth/2e360290-5ee2-591d-aaca-45052d892fb1/tasks/EEG_events_107ed06b-ca3c-5d49-87ae-a690ae834ab8/samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld
Found cached processed samples at /home/jp65/.cache/pyhealth/2e360290-5ee2-591d-aaca-45052d892fb1/tasks/EEG_events_107ed06b-ca3c-5d49-87ae-a690ae834ab8/samples_cdbbc602-34e2-5a41-8643-4c76b08829f6.ld, skipping processing.
Total task samples: 29421
Input schema: {'signal': 'tensor'}
Output schema: {'label': 'multiclass'}

Sample keys: dict_keys(['patient_id', 'signal_file', 'signal', 'offending_channel', 'label'])
Signal shape: torch.Size([16, 1000])
Label: 5


In [24]:
test_loader = get_dataloader(
    dataset=sample_dataset,
    batch_size=1024,
    shuffle=False,
)
print(f"Test loader size: {len(test_loader)}")

Test loader size: 29


## 4. Initialize TFM-Tokenizer Model

In [None]:
model = TFMTokenizer(
    dataset=sample_dataset,
    emb_size=64,
    code_book_size=8192,
    trans_freq_encoder_depth=2,
    trans_temporal_encoder_depth=2,
    trans_decoder_depth=8,
    use_classifier=True,
    classifier_depth=4,
)

model = model.to(device)

# tokenizer_checkpoint_path = '[From the TFM-Tokenizer GitHub] - pretrained_weigths/multiple_dataset_settings/Pretrained_tfm_tokenizer_2x2x8/tfm_tokenizer_last.pth'
model.load_pretrained_weights(
    tokenizer_checkpoint_path='/home/jp65/Biosignals_Research/tfm_token_code_for_release/TFM-Tokenizer/pretrained_weigths/multiple_dataset_settings/Pretrained_tfm_tokenizer_2x2x8/tfm_tokenizer_last.pth',
    classifier_checkpoint_path='/home/jp65/PyHealth/TFM_Tokenizer_multiple_finetuned_on_TUEV_1/best_model.pth',
    is_masked_training=False,
    strict=False,
    map_location=device
)

print(f"Model created with {sum(p.numel() for p in model.parameters())} parameters")

✓ Successfully loaded weights from /home/jp65/PyHealth/TFM_Tokenizer_multiple_finetuned_on_TUEV_1/best_model.pth
Model created with 1891114 parameters


## 5. Test Forward Pass

In [26]:
batch = next(iter(test_loader))

with torch.no_grad():
    outputs = model(**batch)

print("Output keys:", outputs.keys())
print(f"Loss: {outputs['loss'].item():.4f}")
print(f"Logits shape: {outputs['logit'].shape}")
print(f"Tokens shape: {outputs['tokens'].shape}")
print(f"Embeddings shape: {outputs['embeddings'].shape}")

Output keys: dict_keys(['recon_loss', 'vq_loss', 'tokens', 'embeddings', 'loss', 'cls_loss', 'y_prob', 'y_true', 'logit'])
Loss: 3.5373
Logits shape: torch.Size([1024, 6])
Tokens shape: torch.Size([1024, 16, 9])
Embeddings shape: torch.Size([1024, 16, 9, 64])


# 6. Inference on Test Dataloader

In [27]:
from pyhealth.trainer import Trainer

trainer = Trainer(model=model, device=device,metrics=["accuracy", "balanced_accuracy", "cohen_kappa"],)


TFMTokenizer(
  (tokenizer): TFM_VQVAE2_deep(
    (freq_patch_embedding): Sequential(
      (0): Conv1d(1, 64, kernel_size=(5,), stride=(5,))
      (1): GELU(approximate='none')
      (2): GroupNorm(16, 64, eps=1e-05, affine=True)
      (3): Conv1d(64, 64, kernel_size=(1,), stride=(1,))
      (4): GELU(approximate='none')
      (5): GroupNorm(16, 64, eps=1e-05, affine=True)
      (6): Conv1d(64, 64, kernel_size=(1,), stride=(1,))
      (7): GELU(approximate='none')
      (8): GroupNorm(16, 64, eps=1e-05, affine=True)
    )
    (trans_freq_encoder): TransformerEncoder(
      (transformer): LinearAttentionTransformer(
        (layers): SequentialSequence(
          (layers): ModuleList(
            (0-1): 2 x ModuleList(
              (0): PreNorm(
                (fn): SelfAttention(
                  (local_attn): LocalAttention(
                    (dropout): Dropout(p=0.2, inplace=False)
                  )
                  (to_q): Linear(in_features=64, out_features=64, bias=False)

In [28]:
trainer.evaluate(test_loader)

Evaluation: 100%|██████████| 29/29 [00:31<00:00,  1.09s/it]


{'accuracy': 0.7779817137418851,
 'balanced_accuracy': 0.5872575097440921,
 'cohen_kappa': 0.5892794513693542,
 'loss': 2.781993134268399}