
# Whisper Keyboard Event LoRA Fine-Tuning Notebook

This notebook translates the updated `spec.md` into executable code built on Hugging Face's `transformers` stack for fine-tuning Whisper with LoRA. Each section mirrors the spec: tokenizer creation, processor wiring, dataset preparation, and a minimal training loop that exercises `Seq2SeqTrainer` on a LoRA-adapted model.



The workflow follows these stages:

1. Define the keyboard vocabulary exactly as specified and implement a `PreTrainedTokenizer` subclass for keyboard events.
2. Bind the tokenizer to a `WhisperProcessor` so data pipelines reuse `transformers` utilities end-to-end.
3. Provide a dataset implementation that loads (or, for testing, injects) audio/event pairs and emits tensors expected by Whisper.
4. Configure a lightweight Whisper model, wrap it with `peft` LoRA adapters on the attention query/key/value projections, and fine-tune using `Seq2SeqTrainer`.
5. Verify the pipeline by running a short training/eval cycle and decoding generated event sequences.


In [None]:

import json
import math
import os
import random
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Sequence

import librosa
import types
import transformers
import numpy as np
import torch
import torchaudio
import requests
from peft import LoraConfig, get_peft_model
from transformers import (
    GenerationConfig,
    PreTrainedTokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    WhisperConfig,
    WhisperFeatureExtractor,
    WhisperForConditionalGeneration,
    WhisperProcessor,
)

# Deterministic behavior keeps the demonstration reproducible.
SEED = 7
SAMPLING_RATE = 16000
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")


## Fetch the Dataset
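The recordings and their key-press logs are served from `kasca-esque.fly.dev/recordings`. The cell below assumes that endpoint returns a JSON manifest whose entries carry `audio` and `keylog` relative paths, and downloads each pair into `SAVE_DIR`.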

In [None]:
BASE_URL = "http://kasca-esque.fly.dev/recordings"
SAVE_DIR = Path("/home/matty/Projects/kasca-esque/finetune/recordings")
SAVE_DIR.mkdir(parents=True, exist_ok=True)
print(SAVE_DIR)

resp = requests.get(BASE_URL, timeout=30)
resp.raise_for_status()
manifest = resp.json()
print(manifest)

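def download_file(relative_path: str) -> Path:
    """Stream one file from the recordings endpoint into SAVE_DIR."""
    relative_path = relative_path.lstrip('/')
    url = f"{BASE_URL}/{relative_path}"
    local_path = SAVE_DIR / relative_path
    # Manifest paths may contain subdirectories, so create them first.
    local_path.parent.mkdir(parents=True, exist_ok=True)
    r = requests.get(url, stream=True, timeout=30)
    r.raise_for_status()
    with open(local_path, "wb") as f:
        for chunk in r.iter_content(chunk_size=8192):
            if chunk:
                f.write(chunk)
    return local_path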

downloaded = [{"audio": download_file(entry["audio"]), "keylog": download_file(entry["keylog"])} for entry in manifest]
print(f"Downloaded {len(downloaded)} recordings into {SAVE_DIR}")


## Keyboard Vocabulary

The tokenizer reuses the spec's `KeyboardEvent.code` values and the two event types (`down`, `up`). This cell simply materializes the lists so the tokenizer can reference them without any ad-hoc logic elsewhere in the notebook.


In [2]:

LETTER_CODES = [f"Key{chr(i)}" for i in range(ord('A'), ord('Z') + 1)]
DIGIT_CODES = [f"Digit{i}" for i in range(10)]
PUNCTUATION_CODES = [
    'Minus', 'Equal', 'BracketLeft', 'BracketRight', 'Backslash',
    'Semicolon', 'Quote', 'Backquote', 'Comma', 'Period', 'Slash'
]
MODIFIER_CODES = [
    'ShiftLeft', 'ShiftRight', 'ControlLeft', 'ControlRight',
    'AltLeft', 'AltRight', 'MetaLeft', 'MetaRight', 'CapsLock'
]
WHITESPACE_EDIT_CODES = ['Space', 'Tab', 'Enter', 'Backspace', 'Delete']
NAVIGATION_CODES = [
    'ArrowLeft', 'ArrowRight', 'ArrowUp', 'ArrowDown',
    'Home', 'End', 'PageUp', 'PageDown', 'Insert', 'Escape'
]
FUNCTION_CODES = [f"F{i}" for i in range(1, 13)]

PHYSICAL_KEY_CODES = (
    LETTER_CODES
    + DIGIT_CODES
    + PUNCTUATION_CODES
    + MODIFIER_CODES
    + WHITESPACE_EDIT_CODES
    + NAVIGATION_CODES
    + FUNCTION_CODES
)
EVENT_TYPES = ['down', 'up']
SPECIAL_TOKENS = ['<BOS>', '<EOS>']

print(f"Total physical keys: {len(PHYSICAL_KEY_CODES)}")
print(f"Total tokens including events and specials: {len(PHYSICAL_KEY_CODES) * len(EVENT_TYPES) + len(SPECIAL_TOKENS)}")


Total physical keys: 83
Total tokens including events and specials: 168



## KeyboardEventTokenizer

The tokenizer subclasses `PreTrainedTokenizer` so that it can be bundled into a `WhisperProcessor`. Custom helper methods (`encode_events` / `decode_events`) work directly with keyboard event dictionaries, while the inherited text APIs still function (by treating each event token as whitespace-delimited text) to maintain compatibility with `transformers` utilities.


In [3]:

class KeyboardEventTokenizer(PreTrainedTokenizer):
    """Tokenizer that maps keyboard event dictionaries to token ids.

    The implementation leans entirely on `PreTrainedTokenizer` hooks so that
    Hugging Face processors/trainers can treat it like any other tokenizer.
    Custom encode/decode helpers keep higher-level code ergonomic.
    """

    vocab_files_names = {"vocab_file": "keyboard_vocab.json"}
    model_input_names = ["input_ids"]

    def __init__(
        self,
        codes: Sequence[str] = PHYSICAL_KEY_CODES,
        event_types: Sequence[str] = EVENT_TYPES,
        special_tokens: Sequence[str] = SPECIAL_TOKENS,
        key_to_code: Optional[Dict[str, str]] = None,
        **kwargs,
    ) -> None:
        self.codes = list(codes)
        self.event_types = list(event_types)
        self.special_tokens = list(special_tokens)
        self.key_to_code = key_to_code or {}

        vocab = {}
        for token in self.special_tokens:
            vocab[token] = len(vocab)
        for code in self.codes:
            for event_type in self.event_types:
                vocab[f"{code}_{event_type}"] = len(vocab)

        self._vocab = vocab
        self._id_to_token = {idx: token for token, idx in vocab.items()}

        super().__init__(
            bos_token=self.special_tokens[0],
            eos_token=self.special_tokens[1],
            pad_token=self.special_tokens[1],  # Spec only defines BOS/EOS, so reuse EOS for padding.
            **kwargs,
        )

    @property
    def vocab_size(self) -> int:  # type: ignore[override]
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:  # type: ignore[override]
        return dict(self._vocab, **self.added_tokens_encoder)

    def _tokenize(self, text: str) -> List[str]:  # type: ignore[override]
        return text.strip().split()

    def _convert_token_to_id(self, token: str) -> int:  # type: ignore[override]
        return self._vocab.get(token, self.eos_token_id)

    def _convert_id_to_token(self, index: int) -> str:  # type: ignore[override]
        return self._id_to_token.get(index, self.eos_token)

    def build_inputs_with_special_tokens(self, token_ids: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]:  # type: ignore[override]
        if token_ids_1 is not None:
            raise ValueError("This tokenizer does not support pair encodings.")
        return [self.bos_token_id] + token_ids + [self.eos_token_id]

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None):  # type: ignore[override]
        vocab_file = os.path.join(save_directory, (filename_prefix + '-' if filename_prefix else '') + self.vocab_files_names['vocab_file'])
        with open(vocab_file, 'w', encoding='utf-8') as f:
            json.dump(self._vocab, f, indent=2)
        return (vocab_file,)

    # --- Custom helpers for keyboard events ---------------------------------------------------

    def encode_events(self, events: Sequence[Dict[str, str]]) -> List[int]:
        token_ids = [self.bos_token_id]
        for event in events:
            code = event.get('code') or self.key_to_code.get(event.get('key', ''))
            if not code:
                continue
            event_type = event['event_type']
            normalized_type = event_type.replace('key', '')
            token = f"{code}_{normalized_type}"
            if token in self._vocab:
                token_ids.append(self._vocab[token])
        token_ids.append(self.eos_token_id)
        return token_ids

    def decode_events(self, token_ids: Sequence[int]) -> List[Dict[str, str]]:
        events = []
        for token_id in token_ids:
            if token_id in (self.bos_token_id, self.eos_token_id):
                continue
            token = self._id_to_token.get(token_id)
            if not token or '_' not in token:
                continue
            code, event_type = token.rsplit('_', 1)
            events.append({'code': code, 'event_type': f'key{event_type}'})
        return events

import transformers as _transformers_module
setattr(_transformers_module, 'KeyboardEventTokenizer', KeyboardEventTokenizer)
WhisperProcessor.tokenizer_class = ('KeyboardEventTokenizer', None)

In [4]:

tokenizer = KeyboardEventTokenizer()
print(f"Tokenizer vocab size: {tokenizer.vocab_size}")

sample_events = [
    {'code': 'KeyA', 'event_type': 'keydown'},
    {'code': 'KeyA', 'event_type': 'keyup'},
    {'code': 'Enter', 'event_type': 'keydown'},
    {'code': 'Enter', 'event_type': 'keyup'},
]
encoded = tokenizer.encode_events(sample_events)
decoded = tokenizer.decode_events(encoded)
print('Encoded IDs:', encoded)
print('Decoded events:', decoded)
assert len(encoded) == len(decoded) + 2  # accounting for BOS/EOS
assert decoded == sample_events


Tokenizer vocab size: 168
Encoded IDs: [0, 2, 3, 118, 119, 1]
Decoded events: [{'code': 'KeyA', 'event_type': 'keydown'}, {'code': 'KeyA', 'event_type': 'keyup'}, {'code': 'Enter', 'event_type': 'keydown'}, {'code': 'Enter', 'event_type': 'keyup'}]
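
The encoded IDs above follow directly from the order in which the vocabulary is built: special tokens first, then one `down`/`up` pair per key code. The sketch below reproduces that layout in closed form. `expected_token_id` is a hypothetical helper for illustration, not part of the tokenizer, and the code list is truncated after the whitespace/edit keys because later groups do not affect the indices shown.

```python
# Rebuild the key-code list in the same order as the vocabulary cell
# (navigation and function keys are omitted; they come after these groups).
LETTERS = [f"Key{chr(i)}" for i in range(ord("A"), ord("Z") + 1)]
DIGITS = [f"Digit{i}" for i in range(10)]
PUNCTUATION = ["Minus", "Equal", "BracketLeft", "BracketRight", "Backslash",
               "Semicolon", "Quote", "Backquote", "Comma", "Period", "Slash"]
MODIFIERS = ["ShiftLeft", "ShiftRight", "ControlLeft", "ControlRight",
             "AltLeft", "AltRight", "MetaLeft", "MetaRight", "CapsLock"]
WHITESPACE_EDIT = ["Space", "Tab", "Enter", "Backspace", "Delete"]
CODES = LETTERS + DIGITS + PUNCTUATION + MODIFIERS + WHITESPACE_EDIT

EVENT_TYPES = ["down", "up"]
NUM_SPECIALS = 2  # <BOS> = 0, <EOS> = 1


def expected_token_id(code: str, event_type: str) -> int:
    """Closed-form token id (hypothetical helper, for illustration only)."""
    return (
        NUM_SPECIALS
        + CODES.index(code) * len(EVENT_TYPES)
        + EVENT_TYPES.index(event_type)
    )


print(expected_token_id("KeyA", "down"), expected_token_id("KeyA", "up"))
print(expected_token_id("Enter", "down"), expected_token_id("Enter", "up"))
```

`Enter` sits at index 58 in the code list (26 letters, 10 digits, 11 punctuation keys, 9 modifiers, then `Space` and `Tab`), which is why its `down` token is 2 + 58 × 2 = 118.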



## Whisper Processor

This cell wires the tokenizer into a `WhisperProcessor` so that feature extraction and tokenization stay centralized inside `transformers`. It then passes one second of silence through the processor as a sanity check. Saving and reloading the processor is not exercised here; the `artifacts` directory is only created for later outputs.


In [5]:
artifacts_dir = Path('artifacts')
artifacts_dir.mkdir(exist_ok=True)

feature_extractor = WhisperFeatureExtractor(
    feature_size=80,
    sampling_rate=SAMPLING_RATE,
    chunk_length=30,
    hop_length=160,
    return_attention_mask=True,
)
processor = WhisperProcessor(feature_extractor=feature_extractor, tokenizer=tokenizer)

# Sanity check: process a short silence clip to verify the processor pipeline.
zero_audio = np.zeros(SAMPLING_RATE, dtype=np.float32)
processed = processor(audio=zero_audio, sampling_rate=SAMPLING_RATE, return_tensors='pt')
print('Processor produced input shape:', processed.input_features.shape)


Processor produced input shape: torch.Size([1, 80, 3000])
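
The `[1, 80, 3000]` shape is a consequence of the feature extractor settings rather than of the one-second input: by default the extractor pads (or truncates) every clip to `chunk_length` seconds before computing the log-mel spectrogram. A small arithmetic sketch, using only the values passed to `WhisperFeatureExtractor` above:

```python
SAMPLING_RATE = 16_000   # samples per second
CHUNK_LENGTH_S = 30      # every clip is padded/truncated to this duration
HOP_LENGTH = 160         # samples between consecutive spectrogram frames
N_MELS = 80              # feature_size

samples_per_chunk = SAMPLING_RATE * CHUNK_LENGTH_S
frames_per_chunk = samples_per_chunk // HOP_LENGTH
ms_per_frame = 1000 * HOP_LENGTH / SAMPLING_RATE

expected_shape = (1, N_MELS, frames_per_chunk)
print(expected_shape, f"{ms_per_frame} ms per frame")
```

Each frame therefore covers 10 ms of audio, which bounds how finely key-press timing can be resolved from these features.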



## Dataset utilities

The dataset matches the spec: it pairs `.webm` audio files with `.json` keystroke annotations when a recordings directory is provided. For reproducible testing in this notebook, we also allow injecting synthetic audio/event examples so the machinery can run end-to-end without external data.


In [6]:

@dataclass
class RecordingExample:
    audio: np.ndarray
    sampling_rate: int
    events: List[Dict[str, str]]


class KeyboardEventDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        processor: WhisperProcessor,
        recordings_dir: Optional[Path] = None,
        max_length: int = 1024,
        examples: Optional[List[RecordingExample]] = None,
    ) -> None:
        self.processor = processor
        self.recordings_dir = Path(recordings_dir) if recordings_dir else None
        self.max_length = max_length
        self.examples = examples
        self.pairs = self._find_pairs() if self.examples is None else []

    def _find_pairs(self) -> List[Dict[str, Path]]:
        if not self.recordings_dir:
            return []
        pairs = []
        for json_path in self.recordings_dir.glob('*.json'):
            if 'DELETED' in json_path.name:
                continue
            webm_path = json_path.with_suffix('.webm')
            if webm_path.exists():
                pairs.append({'audio_path': webm_path, 'json_path': json_path})
        return pairs

    def __len__(self) -> int:
        if self.examples is not None:
            return len(self.examples)
        return len(self.pairs)

    def _load_example_from_disk(self, idx: int) -> RecordingExample:
        pair = self.pairs[idx]

        waveform, sampling_rate = torchaudio.load(pair["audio_path"])

        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        if sampling_rate != SAMPLING_RATE:
            resampler = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=SAMPLING_RATE)
            waveform = resampler(waveform)
            sampling_rate = SAMPLING_RATE
        audio = waveform.squeeze(0).numpy()
        
        with open(pair['json_path'], 'r', encoding='utf-8') as f:
            events = json.load(f)['keystrokes']
        return RecordingExample(audio=audio, sampling_rate=SAMPLING_RATE, events=events)

    def _prepare_features(self, audio: np.ndarray, sampling_rate: int) -> torch.Tensor:
        if sampling_rate != SAMPLING_RATE:
            audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=SAMPLING_RATE)
        features = self.processor.feature_extractor(
            audio,
            sampling_rate=SAMPLING_RATE,
            return_tensors='pt'
        ).input_features[0]
        return features

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        example = self.examples[idx] if self.examples is not None else self._load_example_from_disk(idx)
        features = self._prepare_features(example.audio, example.sampling_rate)
        token_ids = self.processor.tokenizer.encode_events(example.events)
        token_ids = token_ids[:self.max_length]
        labels = torch.tensor(token_ids, dtype=torch.long)
        return {
            'input_features': features,
            'labels': labels,
        }
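
One detail of `__getitem__` worth noting: `token_ids[:self.max_length]` trims from the end, so any sequence longer than `max_length` loses its `<EOS>` token. The sketch below shows one possible alternative that keeps the terminator. `truncate_keep_eos` is a hypothetical helper, not something the dataset class above uses.

```python
from typing import List, Sequence


def truncate_keep_eos(token_ids: Sequence[int], max_length: int, eos_id: int) -> List[int]:
    """Truncate to max_length while keeping EOS as the final token.

    Hypothetical helper: sequences that already fit are returned unchanged.
    """
    if max_length < 1:
        raise ValueError("max_length must be at least 1")
    if len(token_ids) <= max_length:
        return list(token_ids)
    # Reserve the last slot for EOS so the target still terminates.
    return list(token_ids[: max_length - 1]) + [eos_id]


ids = [0, 2, 3, 118, 119, 1]          # BOS, KeyA down/up, Enter down/up, EOS
print(ids[:4])                         # plain slicing drops EOS
print(truncate_keep_eos(ids, 4, eos_id=1))
```

Either form of truncation can still separate a `down` token from its matching `up`, so a generous `max_length` remains the safer choice.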



### Synthetic data for testing

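To keep the notebook self-contained (and runnable without external recordings), the next cell generates simple sine-wave audio clips with random keyboard events. As written, the dataset itself is still built from a local recordings directory; pass `examples=synthetic_examples` instead of `recordings_dir` to exercise the dataset implementation without external data.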


In [7]:

def generate_sine_wave(duration_s: float = 0.4, sr: int = SAMPLING_RATE, freq: float = 220.0) -> np.ndarray:
    t = np.linspace(0, duration_s, int(sr * duration_s), endpoint=False)
    return 0.05 * np.sin(2 * math.pi * freq * t)


def random_events(num_events: int = 6) -> List[Dict[str, str]]:
    events: List[Dict[str, str]] = []
    for _ in range(num_events):
        code = random.choice(PHYSICAL_KEY_CODES)
        event_type = random.choice(['keydown', 'keyup'])
        events.append({'code': code, 'event_type': event_type})
    return events


synthetic_examples = [
    RecordingExample(
        audio=generate_sine_wave(freq=220 + 20 * i),
        sampling_rate=SAMPLING_RATE,
        events=random_events(),
    )
    for i in range(4)
]

dataset = KeyboardEventDataset(processor=processor, max_length=64, recordings_dir="/home/matty/Projects/kasca-esque/recordings")
sample = dataset[0]
print('Input feature shape:', sample['input_features'].shape)
print('Label ids:', sample['labels'])


Input feature shape: torch.Size([80, 3000])
Label ids: tensor([0, 1])
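
The label tensor above holds only `<BOS>` and `<EOS>`, which means `encode_events` appended nothing for that recording: either its keystroke list is empty, or every event was skipped. Because `encode_events` drops unusable events silently, a small audit can help tell these cases apart. The sketch below mirrors the tokenizer's skip rules; `audit_events` and the sample events are hypothetical and only illustrate the idea.

```python
from collections import Counter
from typing import Dict, Iterable, Optional, Sequence


def audit_events(
    events: Sequence[Dict[str, str]],
    vocab: Iterable[str],
    key_to_code: Optional[Dict[str, str]] = None,
) -> Counter:
    """Count why events would be kept or dropped (hypothetical helper)."""
    vocab = set(vocab)
    key_to_code = key_to_code or {}
    reasons: Counter = Counter()
    for event in events:
        code = event.get("code") or key_to_code.get(event.get("key", ""))
        if not code:
            reasons["missing_code"] += 1
            continue
        event_type = event.get("event_type")
        if event_type is None:
            # The tokenizer itself would raise KeyError here; the audit only counts it.
            reasons["missing_event_type"] += 1
            continue
        token = f"{code}_{event_type.replace('key', '')}"
        reasons["kept" if token in vocab else "unknown_token"] += 1
    return reasons


sample_vocab = {"KeyA_down", "KeyA_up"}
sample = [
    {"code": "KeyA", "event_type": "keydown"},         # kept
    {"key": "a", "event_type": "keyup"},               # no code, no key mapping
    {"code": "NumpadEnter", "event_type": "keydown"},  # code outside the vocabulary
    {"code": "KeyA"},                                  # no event type
]
print(dict(audit_events(sample, sample_vocab)))
```

In the notebook, `tokenizer.get_vocab()` could be passed as `vocab` and the `keystrokes` list from a recording's JSON as `events`.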



## Data Collator

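`Seq2SeqTrainer` expects batches of tensors with uniform shapes. The collator below pads log-mel features along the time axis (a no-op when the feature extractor has already padded every clip to 3000 frames) and pads labels with the tokenizer's pad token. Because the pad token is `<EOS>`, padded label positions are scored by the loss as additional `<EOS>` targets rather than being ignored.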


In [9]:

def collate_keyboard_batch(batch: Sequence[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    input_features = [item['input_features'] for item in batch]
    labels = [item['labels'] for item in batch]

    max_feat_len = max(feature.shape[-1] for feature in input_features)
    padded_inputs = []
    for feature in input_features:
        pad_amount = max_feat_len - feature.shape[-1]
        padded = torch.nn.functional.pad(feature, (0, pad_amount))
        padded_inputs.append(padded)
    stacked_inputs = torch.stack(padded_inputs)

    padded_labels = torch.nn.utils.rnn.pad_sequence(
        labels,
        batch_first=True,
        padding_value=processor.tokenizer.pad_token_id,
    )

    return {
        'input_features': stacked_inputs,
        'labels': padded_labels,
    }


# batch = collate_keyboard_batch([dataset[i] for i in range(2)])
batch = collate_keyboard_batch(dataset)
print('Batch input shape:', batch['input_features'].shape)
print('Batch labels shape:', batch['labels'].shape)


Batch input shape: torch.Size([59, 80, 3000])
Batch labels shape: torch.Size([59, 2])
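
A common alternative to padding labels with the pad token is to pad with `-100`, the default `ignore_index` of PyTorch's cross-entropy loss, so that padded positions contribute nothing to the loss. Since this tokenizer reuses `<EOS>` as its pad token, masking after the fact with `labels == pad_token_id` would also hide the real `<EOS>`; the padding has to be applied from the sequence lengths instead. The sketch below does that on plain lists. `pad_labels` is a hypothetical helper, and the optional BOS stripping reflects our understanding that Whisper prepends the decoder start token itself when it derives decoder inputs from `labels`.

```python
from typing import List, Sequence

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def pad_labels(
    label_seqs: Sequence[Sequence[int]],
    bos_id: int = 0,
    strip_bos: bool = True,
) -> List[List[int]]:
    """Right-pad label sequences with IGNORE_INDEX (hypothetical helper)."""
    rows = []
    for seq in label_seqs:
        seq = list(seq)
        # Assumption: the model adds its own start token, so a leading BOS is redundant.
        if strip_bos and seq and seq[0] == bos_id:
            seq = seq[1:]
        rows.append(seq)
    width = max(len(row) for row in rows)
    return [row + [IGNORE_INDEX] * (width - len(row)) for row in rows]


print(pad_labels([[0, 2, 3, 1], [0, 1]]))
```

Wrapping the result in `torch.tensor(..., dtype=torch.long)` would give a drop-in replacement for `padded_labels` in the collator.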



## Whisper model + LoRA adapters

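A compact Whisper configuration keeps the demo lightweight; note that the model is built from a config, so its weights are randomly initialized rather than pretrained. We then attach LoRA adapters (via `peft`) to the attention query/key/value projections. Because `target_modules=["q_proj", "k_proj", "v_proj"]` matches on module names, the adapters land on the encoder self-attention and on the decoder self- and cross-attention (18 modules in this 2+2-layer config, which accounts for the 18,432 trainable parameters reported below), and `get_peft_model` freezes every base weight, decoder included. Restricting LoRA to the encoder while leaving the decoder fully trainable, as the spec intends, would need a narrower `target_modules` pattern plus explicitly re-enabling gradients on the decoder.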


In [10]:
config = WhisperConfig(
    vocab_size=processor.tokenizer.vocab_size,
    d_model=64,
    encoder_layers=2,
    decoder_layers=2,
    encoder_attention_heads=4,
    decoder_attention_heads=4,
    encoder_ffn_dim=256,
    decoder_ffn_dim=256,
    num_mel_bins=feature_extractor.feature_size,
    bos_token_id=processor.tokenizer.bos_token_id,
    eos_token_id=processor.tokenizer.eos_token_id,
    pad_token_id=processor.tokenizer.pad_token_id,
)
config.decoder_start_token_id = processor.tokenizer.bos_token_id
model = WhisperForConditionalGeneration(config)
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []

# Allow Trainer to pass either `input_features` (preferred) or `input_ids` (what text models expect).
_original_forward = model.forward

def forward_with_feature_alias(self, input_features=None, input_ids=None, inputs_embeds=None, **kwargs):
    if input_features is None:
        if inputs_embeds is not None:
            input_features = inputs_embeds
        elif input_ids is not None:
            input_features = input_ids
    return _original_forward(input_features=input_features, **kwargs)

model.forward = types.MethodType(forward_with_feature_alias, model)

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["q_proj", "k_proj", "v_proj"],
    bias='none',
    task_type='SEQ_2_SEQ_LM',
)
model = get_peft_model(model, lora_config)
model.to(DEVICE)
model.print_trainable_parameters()


trainable params: 18,432 || all params: 414,976 || trainable%: 4.4417
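
The trainable-parameter count above can be reproduced by hand, which also shows where the adapters were attached. The sketch below does the arithmetic and then illustrates how a regular expression could narrow the match to encoder self-attention only. The module names are assumed to follow Whisper's naming in `transformers` (`model.encoder.layers.N.self_attn.q_proj` and so on), and the pattern relies on our understanding that `peft` full-matches a string `target_modules` against module names; treat both as assumptions to verify against `model.named_modules()`.

```python
import re

D_MODEL, RANK = 64, 8
ENCODER_LAYERS, DECODER_LAYERS = 2, 2
PROJECTIONS = ("q_proj", "k_proj", "v_proj")

# Each adapted d x d linear layer gains A (r x d) and B (d x r).
params_per_adapter = RANK * D_MODEL + D_MODEL * RANK
encoder_modules = ENCODER_LAYERS * len(PROJECTIONS)        # self-attention only
decoder_modules = DECODER_LAYERS * 2 * len(PROJECTIONS)    # self- and cross-attention
total_trainable = (encoder_modules + decoder_modules) * params_per_adapter
print(total_trainable)

# Assumed module names, mirroring Whisper's layout in transformers.
module_names = [
    f"model.{side}.layers.{layer}.{attn}.{proj}"
    for side, attns in (("encoder", ("self_attn",)), ("decoder", ("self_attn", "encoder_attn")))
    for layer in range(2)
    for attn in attns
    for proj in PROJECTIONS
]

# Candidate pattern for an encoder-only LoraConfig(target_modules=...).
ENCODER_ONLY = r".*\.encoder\.layers\.\d+\.self_attn\.(q_proj|k_proj|v_proj)"
encoder_matches = [name for name in module_names if re.fullmatch(ENCODER_ONLY, name)]
print(len(module_names), len(encoder_matches))
```

With the encoder-only pattern, the same arithmetic gives 6 × 1,024 = 6,144 LoRA parameters; any decoder weights meant to stay trainable would still need `requires_grad` re-enabled after `get_peft_model`.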



## Training with `Seq2SeqTrainer`

With the model, processor, dataset, and collator in place, we can leverage Hugging Face's `Seq2SeqTrainer` to handle the training loop, evaluation, scheduling, and generation utilities. As configured, the same dataset serves as both the training and the validation set, so the reported validation loss measures fit to the training data rather than generalization.


In [None]:
train_dataset = dataset
val_dataset = dataset

generation_config = GenerationConfig(
    max_length=64,
    bos_token_id=processor.tokenizer.bos_token_id,
    eos_token_id=processor.tokenizer.eos_token_id,
    pad_token_id=processor.tokenizer.pad_token_id,
)

class KeyboardSeq2SeqTrainer(Seq2SeqTrainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        sanitized_inputs = {
            key: value for key, value in inputs.items() if key not in ('input_ids', 'inputs_embeds')
        }
        sanitized_inputs = self._prepare_inputs(sanitized_inputs)
        outputs = model(**sanitized_inputs)
        loss = outputs.loss if hasattr(outputs, 'loss') and outputs.loss is not None else outputs[0]
        return (loss, outputs) if return_outputs else loss

training_args = Seq2SeqTrainingArguments(
    output_dir=str(artifacts_dir / 'checkpoints'),
    per_device_train_batch_size=16,
    per_device_eval_batch_size=8,
    learning_rate=1e-5,
    num_train_epochs=1,  # Overridden by max_steps below.
    warmup_steps=500,
    max_steps=5000,
    logging_strategy='steps',
    logging_steps=1,
    eval_strategy='steps',
    save_steps=1000,  # No effect while save_strategy='no'.
    eval_steps=1000,
    save_strategy='no',
    gradient_accumulation_steps=1,
    report_to=[],
    predict_with_generate=True,
    generation_max_length=64,
    remove_unused_columns=False,
    metric_for_best_model="wer",  # No compute_metrics is passed, so no WER is actually computed.
)

trainer = KeyboardSeq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=collate_keyboard_batch,
)
print(len(train_dataset))
trainer.model.generation_config = generation_config
train_result = trainer.train()
print(train_result)


59
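
With 59 examples and the arguments above, `max_steps` rather than `num_train_epochs` determines how long training runs. The arithmetic below makes the resulting schedule explicit, assuming a single device and the default behavior of keeping the final partial batch.

```python
import math

num_examples = 59
per_device_train_batch_size = 16
gradient_accumulation_steps = 1
max_steps = 5000
warmup_steps = 500
eval_steps = 1000

# Assumption: one device, last partial batch kept.
examples_per_step = per_device_train_batch_size * gradient_accumulation_steps
steps_per_epoch = math.ceil(num_examples / examples_per_step)
passes_over_data = max_steps / steps_per_epoch
num_evaluations = max_steps // eval_steps
warmup_fraction = warmup_steps / max_steps

print(steps_per_epoch, passes_over_data, num_evaluations, warmup_fraction)
```

In other words, the run makes roughly 1,250 passes over the same 59 recordings, evaluates five times, and spends the first 10% of its steps in learning-rate warmup.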







## Quick generation sanity check

Finally, run `model.generate` on one validation sample to ensure decoding works through the tokenizer and the output maintains the expected structure.


In [15]:

model.eval()
with torch.no_grad():
    batch = collate_keyboard_batch([val_dataset[0]])
    input_features = batch['input_features'].to(DEVICE)
    generated_ids = model.generate(
        input_features=input_features,
        max_length=32,
        bos_token_id=processor.tokenizer.bos_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        pad_token_id=processor.tokenizer.pad_token_id,
    )[0].cpu().tolist()

print('Generated token ids:', generated_ids)
print('Decoded events:', processor.tokenizer.decode_events(generated_ids))



Generated token ids: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Decoded events: []
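
The generated IDs above are all `<BOS>`, so `decode_events` returns an empty list and there is no structure to inspect yet. Once the model emits real events, one plausible reading of "expected structure" is that every `keyup` follows a `keydown` for the same key and that no key is left held at the end. The notebook does not define this check; `find_unbalanced_events` below is a hypothetical sketch of it. Repeated `keydown` events for a held key are deliberately allowed, since keyboards emit those during auto-repeat.

```python
from typing import Dict, List, Sequence, Tuple


def find_unbalanced_events(events: Sequence[Dict[str, str]]) -> List[Tuple[int, str, str]]:
    """Return (position, code, problem) tuples for unpaired events (hypothetical helper)."""
    held = set()
    problems: List[Tuple[int, str, str]] = []
    for position, event in enumerate(events):
        code, kind = event["code"], event["event_type"]
        if kind == "keydown":
            held.add(code)  # repeated keydown while held is tolerated
        elif kind == "keyup":
            if code not in held:
                problems.append((position, code, "keyup without keydown"))
            held.discard(code)
    # Keys still held when the sequence ends are reported at the end position.
    problems.extend((len(events), code, "never released") for code in sorted(held))
    return problems


balanced = [
    {"code": "KeyA", "event_type": "keydown"},
    {"code": "KeyA", "event_type": "keyup"},
]
broken = [
    {"code": "KeyB", "event_type": "keyup"},
    {"code": "ShiftLeft", "event_type": "keydown"},
]
print(find_unbalanced_events(balanced))
print(find_unbalanced_events(broken))
```

Applied to `processor.tokenizer.decode_events(generated_ids)`, a non-empty result would flag sequences that decode cleanly but could not have come from a real keyboard.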



The notebook now mirrors the updated spec: it builds on the `transformers` stack for Whisper, uses `WhisperProcessor` + `WhisperForConditionalGeneration`, applies LoRA via `peft`, drives training with `Seq2SeqTrainer`, and validates the custom tokenizer before entering the training loop.
