# Using REVE to classify EEG

In this tutorial, we train a classification head on top of the REVE model to demonstrate how it can be used as a powerful off-the-shelf feature extractor.

## Loading REVE

To load the REVE model, go to the [Hugging Face collection](https://huggingface.co/collections/brain-bzh/reve) and select your model size (e.g. base).

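The models are gated, so you first need to accept the access conditions on the model page. You also need to authenticate before downloading the model, either from the command line (the `hf` CLI ships with recent versions of `huggingface_hub`; older versions provide `huggingface-cli login` instead) or from Python with `huggingface_hub.login`: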

In [1]:
!hf auth login


In [4]:
from huggingface_hub import login
import os
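# Never hard-code an access token in a notebook: read it from the HF_TOKEN
# environment variable, or leave it unset to get an interactive login prompt
login(token=os.environ.get("HF_TOKEN"))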


In [2]:

from transformers import AutoModel
import torch

model = AutoModel.from_pretrained("brain-bzh/reve-base", trust_remote_code=True, torch_dtype="auto")
pos_bank = AutoModel.from_pretrained("brain-bzh/reve-positions", trust_remote_code=True, torch_dtype="auto")


In [5]:
# Show all pre-registered positions
print(pos_bank.get_all_positions())


['A1', 'A2', 'C3', 'C4', 'CZ', 'F3', 'F4', 'F7', 'F8', 'FP1', 'FP2', 'FZ', 'O1', 'O2', 'P3', 'P4', 'PZ', 'T3', 'T4', 'T5', 'T6', 'OZ', 'Cz', 'Fpz', 'Fz', 'P7', 'P8', 'Pz', 'T7', 'T8', 'C1', 'C2', 'C5', 'C6', 'CP1', 'CP2', 'CP3', 'CP4', 'CPz', 'FC1', 'FC2', 'FC3', 'FC4', 'FCz', 'P1', 'P2', 'POz', 'AFz', 'P5', 'P6', 'PO3', 'PO4', 'AF3', 'AF4', 'AF7', 'AF8', 'CP5', 'CP6', 'F1', 'F2', 'F5', 'F6', 'FC5', 'FC6', 'FT7', 'FT8', 'Fp1', 'Fp2', 'Iz', 'Oz', 'P10', 'P9', 'PO7', 'PO8', 'TP7', 'TP8', 'F10', 'F9', 'FT10', 'FT9', 'FTT10h', 'FTT9h', 'PO10', 'PO9', 'TP10', 'TP9', 'TPP10h', 'TPP8h', 'TPP9h', 'TTP7h', 'CCP1h', 'CCP2h', 'CCP3h', 'CCP4h', 'CCP5h', 'CCP6h', 'CPP1h', 'CPP2h', 'CPP3h', 'CPP4h', 'CPP5h', 'CPP6h', 'FCC1h', 'FCC2h', 'FCC3h', 'FCC4h', 'FCC5h', 'FCC6h', 'FFC1h', 'FFC2h', 'FFC3h', 'FFC4h', 'FFC5h', 'FFC6h', 'FTT7h', 'FTT8h', 'PPO1h', 'PPO2h', 'TTP8h', 'T10', 'T9', 'AFF1h', 'AFF2h', 'AFF5h', 'AFF6h', 'AFp1', 'AFp2', 'POO1', 'POO2', 'PO5', 'PO6', 'AFF1', 'AFF2', 'AFp3h', 'AFp4h', 'FFT7

In [6]:
# Inspect the model
print(model)

Reve(
  (transformer): TransformerBackbone(
    (layers): ModuleList(
      (0-21): 22 x ModuleList(
        (0): Attention(
          (norm): RMSNorm()
          (to_qkv): Linear(in_features=512, out_features=1536, bias=False)
          (to_out): Linear(in_features=512, out_features=512, bias=False)
          (attend): ClassicalAttention()
        )
        (1): FeedForward(
          (net): Sequential(
            (0): RMSNorm()
            (1): Linear(in_features=512, out_features=2722, bias=False)
            (2): GEGLU()
            (3): Linear(in_features=1361, out_features=512, bias=False)
          )
        )
      )
    )
  )
  (to_patch_embedding): Sequential(
    (0): Linear(in_features=200, out_features=512, bias=True)
  )
  (fourier4d): FourierEmb4D()
  (mlp4d): Sequential(
    (0): Linear(in_features=4, out_features=512, bias=False)
    (1): GELU(approximate='none')
    (2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (ln): LayerNorm((512,), eps=1e-05, el
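
In the printed feed-forward blocks, the first linear layer outputs 2722 features while the next one takes only 1361. This is consistent with a GEGLU activation, which splits its input into two halves and uses the GELU of one half to gate the other. The sketch below assumes REVE follows the standard GEGLU formulation; `GEGLUSketch` is a hypothetical name, the RMSNorm of the block is left out, and the model's own implementation may differ in details such as which half acts as the gate.

```python
import torch
import torch.nn.functional as F


class GEGLUSketch(torch.nn.Module):
    """Standard GEGLU: split the last dimension in two and gate one half with GELU of the other."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        value, gate = x.chunk(2, dim=-1)  # 2722 features -> two halves of 1361
        return value * F.gelu(gate)


# Mirror the printed feed-forward block of the base model (hidden size 512), without its RMSNorm
ff = torch.nn.Sequential(
    torch.nn.Linear(512, 2722, bias=False),
    GEGLUSketch(),
    torch.nn.Linear(1361, 512, bias=False),
)

tokens = torch.randn(2, 100, 512)  # toy input: 2 sequences of 100 tokens of size 512
hidden = ff[1](ff[0](tokens))      # output of the GEGLU activation
out = ff(tokens)
print(hidden.shape)  # torch.Size([2, 100, 1361])
print(out.shape)     # torch.Size([2, 100, 512])
```

The halving explains why the expansion layer is declared with twice the width that the projection back to 512 receives.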

At this stage, the last layer of the model (`final_layer`) is `Identity()`; we will replace it with a classifier suited to our task.
The dataset has 20 channels and samples of 5 s. The backbone output therefore has shape `[B, 20, 5, D]`, where `B` is the batch size and `D` the hidden dimension (512 for the base model). Once flattened, each sample yields `20*5*D` features, which sets the input size of the final linear layer.
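
The `5` in `[B, 20, 5, D]` is the number of temporal patches per channel. As a worked example, the helpers below recompute the head's input size under stated assumptions: 200-sample patches (the `to_patch_embedding` layer printed above takes 200 inputs), signals sampled at 200 Hz, and non-overlapping patches by default. The sampling rate and the patch step are assumptions for illustration, and `n_patches` / `head_input_dim` are hypothetical helpers, not part of REVE.

```python
from typing import Optional


def n_patches(n_samples: int, patch_size: int, step: int) -> int:
    """Number of full patches obtained by sliding a window of `patch_size` samples by `step`."""
    return (n_samples - patch_size) // step + 1


def head_input_dim(n_channels: int, duration_s: float, sfreq: float,
                   patch_size: int, hidden_dim: int, step: Optional[int] = None) -> int:
    """Flattened size of a [n_channels, n_patches, hidden_dim] backbone output."""
    step = patch_size if step is None else step  # default: non-overlapping patches
    n_samples = int(duration_s * sfreq)
    return n_channels * n_patches(n_samples, patch_size, step) * hidden_dim


# Assumed setting: 20 channels, 5 s at 200 Hz, 200-sample patches, D = 512 (base model)
print(n_patches(5 * 200, 200, 200))          # 5
print(head_input_dim(20, 5, 200, 200, 512))  # 51200
```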

In [7]:
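# Input size of the head: 20 channels x 5 patches x 512 hidden dimensions
dim = 20 * 5 * 512

# Classification head: flatten [B, 20, 5, 512] -> [B, 51200], normalize, regularize, project to 2 classes
model.final_layer = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.RMSNorm(dim),  # torch.nn.RMSNorm requires PyTorch >= 2.4
    torch.nn.Dropout(0.1),
    torch.nn.Linear(dim, 2),
)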

NB: in practice, to use REVE on your own task, you only need to load it from Hugging Face and replace the last layer.

## Training scripts

You can adjust some of the training parameters here.

In [8]:
# Training parameters
batch_size = 64
n_epochs = 20
lr = 1e-3
# Positions of the 20 electrodes; the names must follow the channel order of the data
positions = pos_bank(["Fp1", "Fp2", "F3", "F4", "F7", "F8", "T3", "T4", "C3", "C4", "T5", "T6", "P3", "P4", "O1", "O2", "Fz", "Cz", "Pz", "A2"])

In [9]:
from transformers import set_seed

set_seed(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

`brain-bzh/eegmat-prepro` on Hugging Face is a preprocessed [version](https://huggingface.co/datasets/brain-bzh/eegmat-prepro) of the [EEGMAT](https://physionet.org/content/eegmat/1.0.0/) dataset, originally uploaded to PhysioNet.

In [10]:
from functools import partial
from datasets import load_dataset

dataset = load_dataset("brain-bzh/eegmat-prepro")
dataset.set_format("torch", columns=["data", "labels"])

print(dataset)

def collate(batch, positions):
    # Stack the EEG samples and the labels of the batch into tensors
    x_data = torch.stack([x["data"] for x in batch])
    y_label = torch.tensor([x["labels"] for x in batch])
    # All samples share the same montage: repeat the electrode positions once per sample
    positions = positions.repeat(len(batch), 1, 1)
    return {"sample": x_data, "label": y_label.long(), "pos": positions}

collate_fn = partial(collate, positions=positions)

train_loader = torch.utils.data.DataLoader(dataset["train"], batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_loader = torch.utils.data.DataLoader(dataset["val"], batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
test_loader = torch.utils.data.DataLoader(dataset["test"], batch_size=batch_size, shuffle=False, collate_fn=collate_fn)


DatasetDict({
    train: Dataset({
        features: ['data', 'labels'],
        num_rows: 1343
    })
    val: Dataset({
        features: ['data', 'labels'],
        num_rows: 172
    })
    test: Dataset({
        features: ['data', 'labels'],
        num_rows: 192
    })
})
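
To see what `collate` produces, the sketch below runs the same logic on a synthetic batch. The shapes are assumptions for illustration only: each sample is a `[20, 1000]` tensor (20 channels, 5 s at an assumed 200 Hz) and each electrode has a 3D position, so the positions tensor is `[20, 3]`. The real tensors returned by `load_dataset` and `pos_bank` may have different shapes; `fake_batch` and `fake_positions` are hypothetical stand-ins.

```python
from functools import partial

import torch


def collate(batch, positions):
    # Same logic as the tutorial's collate function
    x_data = torch.stack([x["data"] for x in batch])
    y_label = torch.tensor([x["labels"] for x in batch])
    positions = positions.repeat(len(batch), 1, 1)
    return {"sample": x_data, "label": y_label.long(), "pos": positions}


# Hypothetical stand-ins for dataset rows and electrode positions
fake_batch = [{"data": torch.randn(20, 1000), "labels": torch.tensor(i % 2)} for i in range(8)]
fake_positions = torch.randn(20, 3)

out = partial(collate, positions=fake_positions)(fake_batch)
print(out["sample"].shape)  # torch.Size([8, 20, 1000])
print(out["label"])         # tensor([0, 1, 0, 1, 0, 1, 0, 1])
print(out["pos"].shape)     # torch.Size([8, 20, 3])
```

Repeating the positions in the collate function keeps the dataset itself free of montage information, at the cost of a small copy per batch.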


### Training functions

For simplicity, we only implement a basic training loop that does not include:
- model souping
- LoRA wrappers
- Channel Mixup
- Position augmentation
- Two-stage fine-tuning
- StableAdamW optimizer

The results may therefore differ slightly from those reported in the paper.
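
Of the techniques listed above, model souping is the simplest to illustrate: a uniform soup averages, parameter by parameter, the weights of several models fine-tuned from the same starting point. The helper below is a generic uniform soup, written as an assumption for illustration; it is not the souping procedure used in the REVE paper, and `uniform_soup` is a hypothetical name.

```python
import torch


def uniform_soup(state_dicts):
    """Average several state dicts with identical keys and shapes, entry by entry.

    Entries are cast to float, so integer buffers (e.g. BatchNorm counters)
    would need separate handling.
    """
    return {
        key: torch.stack([sd[key].float() for sd in state_dicts]).mean(dim=0)
        for key in state_dicts[0]
    }


# Two toy heads with constant weights, standing in for separately fine-tuned heads
heads = [torch.nn.Linear(4, 2) for _ in range(2)]
for value, head in zip([1.0, 3.0], heads):
    torch.nn.init.constant_(head.weight, value)
    torch.nn.init.constant_(head.bias, value)

soup = torch.nn.Linear(4, 2)
soup.load_state_dict(uniform_soup([h.state_dict() for h in heads]))
print(soup.weight)  # every entry is 2.0, the mean of 1.0 and 3.0
```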

In [11]:
from tqdm.auto import tqdm
from sklearn.metrics import balanced_accuracy_score, cohen_kappa_score, f1_score, roc_auc_score, average_precision_score

# Note: `device` and `criterion` are defined in the training cell below.

# Loss scaler for float16 mixed precision on GPU; it is disabled (a no-op) on CPU
scaler = torch.amp.GradScaler("cuda", enabled=torch.cuda.is_available())


def train_one_epoch(model, optimizer, loader):
    model.train()
    pbar = tqdm(loader, desc="Training", total=len(loader))

    for batch_data in pbar:
        data, target, pos = (
            batch_data["sample"].to(device, non_blocking=True),
            batch_data["label"].to(device, non_blocking=True),
            batch_data["pos"].to(device, non_blocking=True),
        )
        optimizer.zero_grad()
        # Autocast covers both the forward pass and the loss computation
        with torch.amp.autocast(dtype=torch.float16, device_type="cuda" if torch.cuda.is_available() else "cpu"):
            output = model(data, pos)
            loss = criterion(output, target)
        # Scale the loss to avoid float16 gradient underflow; step() unscales before updating
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        pbar.set_postfix({"loss": loss.item()})


def eval_model(model, loader):
    model.eval()

    y_decisions = []
    y_targets = []
    y_probs = []
    score, count = 0, 0
    pbar = tqdm(loader, desc="Evaluating", total=len(loader))
    with torch.inference_mode():
        for batch_data in pbar:
            data, target, pos = (
                batch_data["sample"].to(device, non_blocking=True),
                batch_data["label"].to(device, non_blocking=True),
                batch_data["pos"].to(device, non_blocking=True),
            )
            with torch.amp.autocast(
                dtype=torch.float16, device_type="cuda" if torch.cuda.is_available() else "cpu"
            ):
                output = model(data, pos)

            decisions = torch.argmax(output, dim=1)
            score += (decisions == target).int().sum().item()
            count += target.shape[0]
            y_decisions.append(decisions)
            y_targets.append(target)
            # AUROC and AUC-PR are computed from class probabilities, not raw logits
            y_probs.append(torch.softmax(output.float(), dim=1))

    gt = torch.cat(y_targets).cpu().numpy()
    pr = torch.cat(y_decisions).cpu().numpy()
    pr_probs = torch.cat(y_probs).cpu().numpy()
    acc = score / count
    balanced_acc = balanced_accuracy_score(gt, pr)
    cohen_kappa = cohen_kappa_score(gt, pr)
    f1 = f1_score(gt, pr, average="weighted")

    # Use the probability of the positive class (label 1) as the ranking score
    auroc = roc_auc_score(gt, pr_probs[:, 1])
    auc_pr = average_precision_score(gt, pr_probs[:, 1])

    return acc, balanced_acc, cohen_kappa, f1, auroc, auc_pr
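
`roc_auc_score` and `average_precision_score` only look at how samples are ranked by the score they receive, so that score should be the class-1 probability rather than the raw class-1 logit. With two classes, the class-1 probability depends on the difference between the two logits, and the class-1 logit alone can rank samples in the wrong order. The toy example below (hand-picked logits, not REVE outputs) makes both predictions correct under argmax, yet ranking by the raw logit gives an AUROC of 0.

```python
import numpy as np
from sklearn.metrics import roc_auc_score

# Two toy samples, both classified correctly by argmax
y_true = np.array([0, 1])
logits = np.array([[10.0, 5.0],   # true class 0: large class-1 logit, but class 0 wins
                   [0.0, 1.0]])   # true class 1: small logits, class 1 wins

# Softmax over the class dimension (shifted by the row max for numerical stability)
shifted = logits - logits.max(axis=1, keepdims=True)
probs = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)

auc_from_logit = roc_auc_score(y_true, logits[:, 1])  # ranks 5.0 above 1.0
auc_from_prob = roc_auc_score(y_true, probs[:, 1])    # ranks about 0.73 above about 0.0067
print(auc_from_logit, auc_from_prob)  # 0.0 1.0
```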

### Train the model

We freeze the backbone and train only the classification head.

At the end of training, the head with the best validation balanced accuracy is evaluated on the test set.
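
Keeping the best head requires a real copy of its weights: `state_dict()` returns tensors that share storage with the module's parameters, so a stored state dict silently follows later in-place optimizer updates. The sketch below shows the pitfall on a toy `Linear` layer (an in-place `add_` stands in for an optimizer step) and the fix with `copy.deepcopy`.

```python
import copy

import torch

layer = torch.nn.Linear(3, 1)
torch.nn.init.zeros_(layer.weight)

aliased = layer.state_dict()                  # shares storage with layer.weight
snapshot = copy.deepcopy(layer.state_dict())  # independent copy

with torch.no_grad():
    layer.weight.add_(1.0)  # stands in for an optimizer step

print(aliased["weight"])   # tensor([[1., 1., 1.]]): followed the update
print(snapshot["weight"])  # tensor([[0., 0., 0.]]): preserved
```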

In [14]:
import copy

# Freeze the backbone so that only the classification head receives gradients
for param in model.parameters():
    param.requires_grad = False
for param in model.final_layer.parameters():
    param.requires_grad = True

optimizer = torch.optim.AdamW(model.final_layer.parameters(), lr=lr)
# Halve the learning rate when the validation balanced accuracy plateaus
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.5, patience=3)

criterion = torch.nn.CrossEntropyLoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

best_val_acc = 0
best_final_layer = None

for epoch in range(n_epochs):
    print(f"Epoch {epoch + 1}/{n_epochs}")
    train_one_epoch(model, optimizer, train_loader)
    _, b_acc, _, _, _, _ = eval_model(model, val_loader)
    if b_acc > best_val_acc:
        best_val_acc = b_acc
        # state_dict() references the live weights, so keep an independent copy
        best_final_layer = copy.deepcopy(model.final_layer.state_dict())
    print(f"Validation balanced accuracy: {b_acc:.4f}, best: {best_val_acc:.4f}")
    scheduler.step(b_acc)

# Restore the best head and evaluate it on the test set
model.final_layer.load_state_dict(best_final_layer)
acc, balanced_acc, cohen_kappa, f1, auroc, auc_pr = eval_model(model, test_loader)

# Results
print("acc", acc)
print("balanced_acc", balanced_acc)
print("cohen_kappa", cohen_kappa)
print("f1", f1)
print("auroc", auroc)
print("auc_pr", auc_pr)


Epoch 1/20
Validation balanced accuracy: 0.5000, best: 0.5000
Epoch 2/20
Validation balanced accuracy: 0.6149, best: 0.6149
Epoch 3/20
Validation balanced accuracy: 0.6166, best: 0.6166
Epoch 4/20
Validation balanced accuracy: 0.6142, best: 0.6166
Epoch 5/20
Validation balanced accuracy: 0.6526, best: 0.6526
Epoch 6/20
Validation balanced accuracy: 0.5974, best: 0.6526
Epoch 7/20
Validation balanced accuracy: 0.6270, best: 0.6526
Epoch 8/20
Validation balanced accuracy: 0.6630, best: 0.6630
Epoch 9/20
Validation balanced accuracy: 0.6190, best: 0.6630
Epoch 10/20
Validation balanced accuracy: 0.6166, best: 0.6630
Epoch 11/20
Validation balanced accuracy: 0.6310, best: 0.6630
Epoch 12/20
Validation balanced accuracy: 0.6630, best: 0.6630
Epoch 13/20
Validation balanced accuracy: 0.6149, best: 0.6630
Epoch 14/20
Validation balanced accuracy: 0.6190, best: 0.6630
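
With `mode="max"`, `factor=0.5` and `patience=3`, `ReduceLROnPlateau` halves the learning rate once the validation balanced accuracy has failed to improve for more than three consecutive epochs. The sketch below replays the validation scores logged above (epochs 1 to 14, rounded to four decimals) through a scheduler attached to a dummy optimizer, to show when the halving happens. It relies on the scheduler's default relative threshold of `1e-4`, so a score that only ties the current best does not count as an improvement; the dummy parameter and optimizer are assumptions for illustration.

```python
import torch

# Validation balanced accuracies from the log above (epochs 1 to 14)
val_scores = [0.5000, 0.6149, 0.6166, 0.6142, 0.6526, 0.5974, 0.6270,
              0.6630, 0.6190, 0.6166, 0.6310, 0.6630, 0.6149, 0.6190]

# Dummy parameter and optimizer: only the learning-rate bookkeeping matters here
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.5, patience=3)

lrs = []
for score in val_scores:
    optimizer.step()        # no gradients, so the parameter is left unchanged
    scheduler.step(score)   # same call as in the training loop
    lrs.append(optimizer.param_groups[0]["lr"])

for epoch, (score, lr) in enumerate(zip(val_scores, lrs), start=1):
    print(f"epoch {epoch:2d}  val {score:.4f}  lr {lr:.1e}")
```

Replaying these rounded values, the learning rate stays at 1e-3 through epoch 11 and drops to 5e-4 after epoch 12, whose score ties the best one (epoch 8) without exceeding it.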