# Whisper keystroke fine-tuning blueprint

This notebook turns the earlier prose outline into an executable Colab-style plan. It keeps Whisper's encoder, trains a decoder whose tokens map **exactly** to observed `KeyboardEvent.code` values, and adds a reproducible data-ingest stage that syncs assets from the `/recordings` route before any modeling work begins.

## Feasibility, risks, and objectives
- Whisper's encoder is pretrained on a large and acoustically varied audio corpus, so its features plausibly carry over to short percussive events such as keystrokes; the decoder is retrained from scratch to emit keyboard tokens instead of language text.
- Expect to need hours of paired (audio, keylog) data; diversity across typists, hardware, and mic placements will decide accuracy.
- Labels must stay perfectly ordered. Training uses only the sequence of `event.code` entries, not their absolute times, but that sequence is recovered by sorting on capture timestamps, so consistent timestamps from the web app remain critical.
- We'll start with a small Whisper checkpoint (e.g., `tiny`/`small`) plus LoRA adapters so the encoder can adapt slightly without overfitting when data is scarce.

## Data acquisition via `/recordings`
1. The notebook discovers data by hitting the deployed SvelteKit `/recordings` route (same domain as the typing UI). That endpoint returns `{ audio, keylog }` pairs when both `.webm` and `.json` files exist.
2. We mirror that list into a `recordings_cache/` folder that sits alongside this notebook. The sync logic:
   - Fetches the remote manifest once per run.
   - Downloads any missing files.
   - Deletes local files that are no longer advertised by the server so the cache always matches `/recordings` exactly.
3. Subsequent preprocessing (manifest building, token set discovery, etc.) always works off this synchronized cache, so there's a single source of truth and no stale artifacts.
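
The mirroring rule in step 2 reduces to two set differences. The following is a minimal sketch, assuming filenames are unique within the cache; `plan_sync` is a hypothetical helper used only for illustration, and the filenames are made up.

```python
from typing import Iterable, Set, Tuple


def plan_sync(remote: Iterable[str], local: Iterable[str]) -> Tuple[Set[str], Set[str]]:
    """Return (to_download, to_delete) so the local cache mirrors the remote listing."""
    remote_set, local_set = set(remote), set(local)
    to_download = remote_set - local_set  # advertised by the server, missing locally
    to_delete = local_set - remote_set    # present locally, no longer advertised
    return to_download, to_delete


# Toy filenames, made up for illustration.
to_download, to_delete = plan_sync(
    remote=["a.webm", "a.json", "b.webm", "b.json"],
    local=["a.webm", "a.json", "old.webm"],
)
print("download:", sorted(to_download))
print("delete:", sorted(to_delete))
```

Computing the plan before touching the filesystem makes it easy to print or review what a refresh would change.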

In [None]:
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)
assert device == "cuda", "Please enable GPU (e.g., Colab > Runtime > Change runtime type)."

In [None]:
%%capture
!uv pip install --upgrade --quiet \
    requests \
    "transformers>=4.44.0" \
    "datasets>=2.19.0" \
    "accelerate>=0.33.0" \
    "peft>=0.12.0" \
    "soundfile" \
    "librosa" \
    "tokenizers>=0.15.0" \
    "evaluate"

In [None]:
import json
import os
import random
from pathlib import Path
from typing import Dict, List
from urllib.parse import urljoin

import requests
import soundfile as sf

from datasets import Audio, load_dataset

import numpy as np
import torch
from torch import nn
from dataclasses import dataclass

from transformers import (
    WhisperConfig,
    WhisperFeatureExtractor,
    WhisperForConditionalGeneration,
    WhisperProcessor,
    PreTrainedTokenizerFast,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace

from peft import LoraConfig, TaskType, get_peft_model

NOTEBOOK_DIR = Path.cwd()
CACHE_DIR = NOTEBOOK_DIR / "recordings_cache"
MANIFEST_DIR = NOTEBOOK_DIR / "manifests"
MANIFEST_DIR.mkdir(exist_ok=True)

BASE_URL = os.environ.get("S5_BASE_URL", "http://kasca-esque.fly.dev")
LIST_ENDPOINT = "/recordings"
DOWNLOAD_ENDPOINT = "/recordings/"
TARGET_SAMPLING_RATE = 16_000
BASE_MODEL_NAME = "openai/whisper-small"
TRAIN_SPLIT = 0.9
RNG = random.Random(1337)

### Sync `/recordings` into the local cache
Set `S5_BASE_URL` (or edit `BASE_URL` above) so that `BASE_URL + /recordings` resolves to the deployed app. Run the cell whenever you want to refresh the cache; it keeps the folder contents identical to what the server currently exposes.
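
How `urljoin` combines the base URL with an endpoint depends on whether the endpoint starts with a slash. The sketch below uses placeholder hosts (`example.test`), not the deployed app, to show the difference.

```python
from urllib.parse import urljoin

# Placeholder hosts for illustration only.
root_base = "https://example.test"
nested_base = "https://example.test/app/"

# A leading slash makes the endpoint absolute, so it replaces any path on the base.
root_url = urljoin(root_base, "/recordings")
nested_url = urljoin(nested_base, "/recordings")

# Without the leading slash, the endpoint is resolved relative to the base path.
relative_url = urljoin(nested_base, "recordings")

print(root_url)
print(nested_url)
print(relative_url)
```

If the app is ever served under a sub-path, the leading slash in `LIST_ENDPOINT` would discard that sub-path, so `BASE_URL` should point at the site root.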

In [None]:
def sync_recordings():
    CACHE_DIR.mkdir(exist_ok=True)
    list_url = urljoin(BASE_URL, LIST_ENDPOINT)
    response = requests.get(list_url, timeout=30)
    response.raise_for_status()
    payload = response.json()
    if isinstance(payload, dict) and "error" in payload:
        raise RuntimeError(f"Server error: {payload['error']}")

    # Build set of remote filenames and download missing ones
    remote_files = set()
    for item in payload:
        for field in ("audio", "keylog"):
            filename = item[field]
            remote_files.add(filename)
            dest = CACHE_DIR / filename
            if not dest.exists():
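                file_url = urljoin(BASE_URL, DOWNLOAD_ENDPOINT + filename)
                print(f"Downloading {filename}...")
                # Stream to a temporary name first so an interrupted download
                # never leaves a truncated file that looks complete.
                tmp_path = dest.with_name(dest.name + ".part")
                with requests.get(file_url, stream=True, timeout=60) as r:
                    r.raise_for_status()
                    with open(tmp_path, "wb") as fh:
                        for chunk in r.iter_content(chunk_size=1 << 16):
                            fh.write(chunk)
                tmp_path.replace(dest)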

    # Delete local files that are no longer available upstream
    for local_path in CACHE_DIR.iterdir():
        if local_path.is_file() and local_path.name not in remote_files:
            print(f"Removing {local_path.name} (not advertised by server)...")
            local_path.unlink()

    print(f"Synced {len(remote_files)} files into {CACHE_DIR}")

sync_recordings()

### Build train/eval manifests from the synchronized cache
Each `.json` file holds the ordered keystrokes for its sibling `.webm`. This utility normalizes timestamps (seconds offset from the first event, assuming the log stores milliseconds), drops events that lack a code or an event type, skips recordings with no usable events, and emits two JSONL manifests under `manifests/`.

In [None]:
def build_manifests():
    entries: List[Dict] = []
    for json_path in CACHE_DIR.glob("*.json"):
        audio_path = json_path.with_suffix(".webm")
        if not audio_path.exists():
            continue
        with open(json_path, "r") as fh:
            data = json.load(fh)
        keystrokes = data.get("keystrokes") or data.get("events") or []
        if not keystrokes:
            continue
        start_ts = keystrokes[0]["timestamp"]
        example_events = []
        for event in keystrokes:
            code = event.get("key")
            etype = event.get("event_type")
            if not code or not etype:
                continue
            example_events.append(
                {
                    "time": (event["timestamp"] - start_ts) / 1000.0,
                    "code": code,
                    "type": "down" if etype == "keydown" else "up",
                }
            )
        if not example_events:
            continue
        entries.append(
            {
                "audio": str(audio_path),
                "events": example_events,
            }
        )

    if not entries:
        raise RuntimeError("No paired recordings found in cache.")

    RNG.shuffle(entries)
    split_idx = int(len(entries) * TRAIN_SPLIT)
    train_entries = entries[:split_idx]
    eval_entries = entries[split_idx:] or entries[-1:]

    train_path = MANIFEST_DIR / "train.jsonl"
    eval_path = MANIFEST_DIR / "eval.jsonl"

    for path, subset in ((train_path, train_entries), (eval_path, eval_entries)):
        with open(path, "w") as fh:
            for row in subset:
                fh.write(json.dumps(row) + "\n")
        print(f"Wrote {len(subset)} rows to {path}")

build_manifests()
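
To make the manifest format concrete, here is a sketch that applies the same normalization to a toy keylog. The entries are made up for illustration, `normalize_events` is a hypothetical standalone helper, and the `key` / `event_type` / `timestamp` field names simply follow what `build_manifests` reads above.

```python
import json
from typing import Dict, List


def normalize_events(keystrokes: List[Dict]) -> List[Dict]:
    """Convert raw keylog entries into manifest events timed from the first entry."""
    if not keystrokes:
        return []
    start_ts = keystrokes[0]["timestamp"]
    events = []
    for event in keystrokes:
        code, etype = event.get("key"), event.get("event_type")
        if not code or not etype:
            continue  # skip entries that lack a code or an event type
        events.append(
            {
                "time": (event["timestamp"] - start_ts) / 1000.0,  # assumes millisecond timestamps
                "code": code,
                "type": "down" if etype == "keydown" else "up",
            }
        )
    return events


# Toy keylog entries, made up for illustration; the third one has no code and is dropped.
raw = [
    {"timestamp": 1000, "key": "KeyH", "event_type": "keydown"},
    {"timestamp": 1080, "key": "KeyH", "event_type": "keyup"},
    {"timestamp": 1250, "key": None, "event_type": "keydown"},
    {"timestamp": 1500, "key": "KeyI", "event_type": "keydown"},
]
row = {"audio": "recordings_cache/example.webm", "events": normalize_events(raw)}
print(json.dumps(row))
```

Each line of `train.jsonl` and `eval.jsonl` has this shape: one audio path plus its time-ordered event list.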

### Derive the exact `event.code` vocabulary from the cached logs
We now recompute the vocabulary straight from the synchronized JSON files so that every decoder token corresponds to a real `event.code`. Up/down states are expressed by duplicating each code into `_DOWN` / `_UP` variants, but no other synthetic key names are introduced.

In [None]:
def collect_event_codes(recordings_dir: Path) -> List[str]:
    codes = set()
    for json_path in recordings_dir.glob("*.json"):
        with open(json_path, "r") as fh:
            data = json.load(fh)
        keystrokes = data.get("keystrokes") or data.get("events") or []
        for event in keystrokes:
            code = event.get("key")
            if code:
                codes.add(code)
    if not codes:
        raise RuntimeError("No event.code entries found; double-check the cache.")
    return sorted(codes)

EVENT_CODES = collect_event_codes(CACHE_DIR)
print(f"Discovered {len(EVENT_CODES)} distinct event.code values.")
print(EVENT_CODES)

special_tokens = ["<pad>", "<s>", "</s>", "<unk>"]
event_tokens = [
    f"{code}_{suffix}"
    for code in EVENT_CODES
    for suffix in ("DOWN", "UP")
]

vocab = {tok: idx for idx, tok in enumerate(special_tokens + event_tokens)}
print(f"Vocabulary size: {len(vocab)}")

backend_tokenizer = Tokenizer(WordLevel(vocab=vocab, unk_token="<unk>"))
backend_tokenizer.pre_tokenizer = Whitespace()

tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=backend_tokenizer,
    bos_token="<s>",
    eos_token="</s>",
    unk_token="<unk>",
    pad_token="<pad>",
)

# Set padding side to right (default for most seq2seq models)
tokenizer.padding_side = "right"

tokenizer.save_pretrained(NOTEBOOK_DIR / "key_tokenizer")
print(f"Tokenizer padding side: {tokenizer.padding_side}")
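
The resulting id layout can be checked without the `tokenizers` library. This is a pure-Python sketch on a toy set of two codes; `build_vocab` and `encode` are hypothetical helpers that mimic the word-level lookup and are not the tokenizer itself.

```python
from typing import Dict, List

SPECIAL_TOKENS = ["<pad>", "<s>", "</s>", "<unk>"]


def build_vocab(event_codes: List[str]) -> Dict[str, int]:
    """Special tokens first, then a DOWN/UP pair for every code in sorted order."""
    event_tokens = [
        f"{code}_{suffix}"
        for code in sorted(event_codes)
        for suffix in ("DOWN", "UP")
    ]
    return {tok: idx for idx, tok in enumerate(SPECIAL_TOKENS + event_tokens)}


def encode(text: str, vocab: Dict[str, int]) -> List[int]:
    """Whitespace split plus lookup; unknown tokens fall back to <unk>."""
    return [vocab.get(tok, vocab["<unk>"]) for tok in text.split()]


toy_vocab = build_vocab(["Space", "KeyA"])
ids = encode("KeyA_DOWN KeyA_UP Space_DOWN Enter_DOWN", toy_vocab)
print(toy_vocab)
print(ids)  # Enter_DOWN is not in the toy vocabulary, so it maps to <unk>
```

The vocabulary size is always `4 + 2 * len(EVENT_CODES)`, which is the value later passed as `vocab_size`.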

### Load Whisper encoder, swap in the new decoder + tokenizer
We clone the base Whisper config, override `vocab_size`, and copy only the encoder weights. Language/task forcing is disabled because we're no longer predicting text.

In [None]:
feature_extractor = WhisperFeatureExtractor.from_pretrained(BASE_MODEL_NAME)
base_model = WhisperForConditionalGeneration.from_pretrained(BASE_MODEL_NAME)

new_config = WhisperConfig(
    vocab_size=len(vocab),
    d_model=base_model.config.d_model,
    encoder_layers=base_model.config.encoder_layers,
    encoder_attention_heads=base_model.config.encoder_attention_heads,
    decoder_layers=base_model.config.decoder_layers,
    decoder_attention_heads=base_model.config.decoder_attention_heads,
    decoder_ffn_dim=base_model.config.decoder_ffn_dim,
    encoder_ffn_dim=base_model.config.encoder_ffn_dim,
    max_source_positions=base_model.config.max_source_positions,
    max_target_positions=base_model.config.max_target_positions,
    dropout=base_model.config.dropout,
    attention_dropout=base_model.config.attention_dropout,
    activation_dropout=base_model.config.activation_dropout,
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    decoder_start_token_id=tokenizer.bos_token_id,  # Important for generation
)

new_config.forced_decoder_ids = None
new_config.suppress_tokens = []

model = WhisperForConditionalGeneration(new_config)
model.model.encoder.load_state_dict(base_model.model.encoder.state_dict())

# Free up memory by deleting the base model
del base_model
import gc
gc.collect()
if device == "cuda":
    torch.cuda.empty_cache()

# Configure generation settings
model.generation_config.max_length = 256
model.generation_config.max_new_tokens = None  # Use max_length instead
model.generation_config.num_beams = 1  # Greedy decoding for speed
model.generation_config.do_sample = False
model.generation_config.decoder_start_token_id = tokenizer.bos_token_id

# Note: We can't use WhisperProcessor with a custom tokenizer, so we'll use feature_extractor and tokenizer separately
# Device placement happens after LoRA is applied (see cell below)
print("Model initialized with pretrained encoder + fresh decoder.")
print(f"Generation config: max_length={model.generation_config.max_length}, num_beams={model.generation_config.num_beams}")
print(f"Decoder start token ID: {model.config.decoder_start_token_id}")

### (Optional) LoRA adapters on the encoder
Adapters let us nudge the encoder toward keyboard acoustics without updating every weight. Disable this if you prefer to keep the encoder frozen.
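
For a sense of scale, the arithmetic below counts the parameters one LoRA adapter adds to a single projection matrix. It is a back-of-the-envelope sketch that assumes a hidden size of 768 (the `whisper-small` value) and the rank and alpha used in this notebook; `lora_param_count` is a hypothetical helper.

```python
def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
    """Parameters in one LoRA adapter: A is (rank x d_in), B is (d_out x rank)."""
    return rank * d_in + d_out * rank


D_MODEL = 768  # assumed hidden size (whisper-small)
RANK = 16
ALPHA = 32

per_matrix = lora_param_count(D_MODEL, D_MODEL, RANK)
full_matrix = D_MODEL * D_MODEL
scaling = ALPHA / RANK  # LoRA scales the low-rank update B @ A by alpha / r

print(f"adapter params per projection: {per_matrix:,}")
print(f"full projection params:        {full_matrix:,}")
print(f"ratio:                         {per_matrix / full_matrix:.2%}")
print(f"update scaling (alpha / r):    {scaling}")
```

Under these assumptions each adapted projection trains roughly 4% of the weights a full update would touch, which is why adapters suit a small dataset.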

In [None]:
use_lora = True
if use_lora:
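    lora_config = LoraConfig(
        # No task_type: PEFT's seq2seq wrapper forwards `input_ids`, which
        # Whisper's forward() does not accept (it takes `input_features`).
        r=16,
        lora_alpha=32,
        lora_dropout=0.1,
        # Note: these names also match the decoder's attention projections.
        target_modules=["q_proj", "v_proj"],
    )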
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
else:
    print("LoRA disabled; the encoder weights remain as-is.")

# Move model to device after LoRA is applied
model.to(device)
print(f"Model moved to device: {device}")

### Freeze the base encoder weights (except LoRA) so we mainly train the decoder

In [None]:
# get_peft_model() prefixes parameter names (e.g. "base_model.model.model.encoder...")
# and freezes every base weight, including the fresh decoder. Match by substring
# and re-enable gradients for everything outside the encoder.
for name, param in model.named_parameters():
    if "lora_" in name:
        continue  # LoRA adapter weights stay trainable
    param.requires_grad = "model.encoder." not in name

# Disable the KV cache during training (it conflicts with gradient checkpointing)
model.config.use_cache = False

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable params: {trainable:,} / {total:,} ({trainable/total:.2%})")
print("use_cache disabled for training")

### Load the manifests and attach audio/label processing
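We use Hugging Face `datasets` to read the JSONL manifests, resample audio to 16 kHz, and convert each event list into a whitespace-separated token string. Whisper's feature extractor pads or truncates every clip to 30 seconds, so recordings longer than that need to be split upstream; otherwise the labels cover keystrokes the model never hears.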

In [None]:
data_files = {
    "train": str(MANIFEST_DIR / "train.jsonl"),
    "validation": str(MANIFEST_DIR / "eval.jsonl"),
}

datasets = load_dataset("json", data_files=data_files)
datasets = datasets.cast_column("audio", Audio(sampling_rate=TARGET_SAMPLING_RATE))
print(datasets)

In [None]:
def events_to_text(example):
    ordered = sorted(example["events"], key=lambda e: e["time"])
    tokens = []
    for evt in ordered:
        code = evt["code"]
        state = evt["type"].upper()
        token = f"{code}_{state}"
        if token not in vocab:
            token = "<unk>"
        tokens.append(token)
    example["labels_text"] = " ".join(tokens)
    return example

textualized = datasets.map(events_to_text)
print(textualized["train"][0]["labels_text"][:120])

In [None]:
def prepare_example(batch):
    audio = batch["audio"]
    inputs = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
    batch["input_features"] = inputs.input_features[0]
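    # The bare WordLevel tokenizer has no post-processor, so add_special_tokens does
    # not append </s>. Add it explicitly so the decoder learns when to stop. No <s> is
    # needed: the model prepends decoder_start_token_id when it shifts the labels.
    labels = tokenizer(batch["labels_text"], add_special_tokens=False).input_ids
    batch["labels"] = labels + [tokenizer.eos_token_id]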
    return batch

vectorized = textualized.map(
    prepare_example,
    remove_columns=textualized["train"].column_names,
    num_proc=4,
)
print(vectorized)

In [None]:
@dataclass
class DataCollatorSpeechSeq2Seq:
    feature_extractor: WhisperFeatureExtractor
    tokenizer: PreTrainedTokenizerFast

    def __call__(self, features: List[Dict]) -> Dict[str, torch.Tensor]:
        input_feats = [{"input_features": f["input_features"]} for f in features]
        batch = self.feature_extractor.pad(input_feats, return_tensors="pt")
        label_feats = [{"input_ids": f["labels"]} for f in features]
        labels_batch = self.tokenizer.pad(label_feats, padding=True, return_tensors="pt")
        labels = labels_batch["input_ids"]
        labels[labels == self.tokenizer.pad_token_id] = -100
        batch["labels"] = labels
        return batch

collator = DataCollatorSpeechSeq2Seq(feature_extractor=feature_extractor, tokenizer=tokenizer)

In [None]:
def compute_metrics(eval_pred):
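    preds, labels = eval_pred
    if isinstance(preds, tuple):
        preds = preds[0]
    # When predict_with_generate=True, preds are already token IDs, not logits.
    # Both arrays can be padded with -100, which cannot be decoded, so map it to the pad token.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)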
    
    # Decode predictions and labels
    pred_str = tokenizer.batch_decode(preds, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    # Calculate exact match accuracy
    exact_matches = sum(p == l for p, l in zip(pred_str, label_str))
    
    # Also calculate token-level accuracy for more granular metrics
    pred_tokens = [p.split() for p in pred_str]
    label_tokens = [l.split() for l in label_str]
    
    correct_tokens = 0
    total_tokens = 0
    for p_toks, l_toks in zip(pred_tokens, label_tokens):
        # Count matching tokens at each position
        for i in range(min(len(p_toks), len(l_toks))):
            if p_toks[i] == l_toks[i]:
                correct_tokens += 1
        total_tokens += max(len(p_toks), len(l_toks))
    
    return {
        "sequence_accuracy": exact_matches / len(pred_str) if len(pred_str) > 0 else 0,
        "token_accuracy": correct_tokens / total_tokens if total_tokens > 0 else 0,
    }

In [None]:
output_dir = NOTEBOOK_DIR / "whisper_eventcode_lora"

# Estimate optimizer steps from the dataset size, rounding up.
# E.g. 55 train samples, batch_size=4, grad_accum=4 -> 55 / 16 is about 3.4 -> 4 steps/epoch -> 40 steps for 10 epochs
steps_per_epoch = max(1, -(-len(textualized["train"]) // (4 * 4)))
total_steps = steps_per_epoch * 10

training_args = Seq2SeqTrainingArguments(
    output_dir=str(output_dir),
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=1e-4,
    warmup_steps=min(10, total_steps // 4),  # Warmup for ~25% of training or 10 steps max
    num_train_epochs=10,
    eval_strategy="epoch",  # Evaluate every epoch (renamed from the deprecated evaluation_strategy)
    save_strategy="epoch",  # Save every epoch
    logging_steps=max(1, steps_per_epoch // 2),  # Log twice per epoch
    save_total_limit=3,
    fp16=device == "cuda",  # Only use fp16 on GPU
    predict_with_generate=True,  # Generate sequences during evaluation for proper metrics
    generation_max_length=256,  # Max sequence length for generation
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="sequence_accuracy",
    greater_is_better=True,
    gradient_checkpointing=False,  # Set to True if running out of memory
    optim="adamw_torch",  # Use PyTorch's AdamW
)
print(f"Steps per epoch: {steps_per_epoch}, Total steps: {total_steps}")
print(f"Warmup steps: {training_args.warmup_steps}")
print(training_args)

In [None]:
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=vectorized["train"],
    eval_dataset=vectorized["validation"],
    data_collator=collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
trainer.train()

In [None]:
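save_dir = NOTEBOOK_DIR / "whisper_eventcode_artifacts"
# With LoRA enabled, trainer.save_model() stores only the adapter weights, which
# would omit the new decoder. Merging the adapters first saves one complete model.
final_model = trainer.model
if hasattr(final_model, "merge_and_unload"):
    final_model = final_model.merge_and_unload()
final_model.save_pretrained(save_dir)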
tokenizer.save_pretrained(save_dir / "tokenizer")
feature_extractor.save_pretrained(save_dir / "feature_extractor")
print(f"Artifacts stored under {save_dir}")

In [None]:
import librosa

def decode_recording(audio_path: str, max_length: int = 256):
    # librosa.load downmixes to mono and resamples in one step. Decoding .webm
    # relies on librosa's audioread fallback, which assumes ffmpeg is installed.
    audio_array, sr = librosa.load(audio_path, sr=TARGET_SAMPLING_RATE, mono=True)
    inputs = feature_extractor(audio_array, sampling_rate=sr, return_tensors="pt")
    input_features = inputs.input_features.to(device)
    model.eval()
    with torch.no_grad():
        # Pass by keyword so the call also works through PEFT wrappers.
        generated = model.generate(
            input_features=input_features,
            max_length=max_length,
        )
    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
    return decoded.split()

# Example usage (update path after training):
# predicted_tokens = decode_recording(str(CACHE_DIR / "example.webm"))
# print(predicted_tokens)

## Next steps
- Expand the dataset (more typists, keyboards, and mics) so the decoder sees diverse acoustics.
- Experiment with different Whisper checkpoints (`tiny`, `base`, `small`) and LoRA ranks to balance speed vs. accuracy.
- Add richer evaluation metrics (per-token F1, timing alignment) or streaming decoding logic once the offline model is reliable.
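
As one candidate for a richer metric, a token error rate based on edit distance tolerates a single dropped or inserted event, whereas the positional `token_accuracy` above penalizes every token after the slip. The sketch below is a plain Levenshtein implementation on toy sequences; `edit_distance` and `token_error_rate` are hypothetical helpers, not part of the notebook.

```python
from typing import Sequence


def edit_distance(ref: Sequence[str], hyp: Sequence[str]) -> int:
    """Levenshtein distance over tokens (insertions, deletions, substitutions cost 1)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            cost = 0 if r == h else 1
            curr.append(min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + cost))
        prev = curr
    return prev[-1]


def token_error_rate(ref: Sequence[str], hyp: Sequence[str]) -> float:
    """Edit distance normalized by reference length."""
    if not ref:
        return 0.0 if not hyp else 1.0
    return edit_distance(ref, hyp) / len(ref)


# Toy sequences: the prediction misses exactly one event (KeyA_UP).
reference = ["KeyA_DOWN", "KeyA_UP", "KeyB_DOWN", "KeyB_UP"]
predicted = ["KeyA_DOWN", "KeyB_DOWN", "KeyB_UP"]

positional_hits = sum(r == p for r, p in zip(reference, predicted))
positional_accuracy = positional_hits / max(len(reference), len(predicted))
ter = token_error_rate(reference, predicted)

print(f"positional accuracy: {positional_accuracy:.2f}")
print(f"token error rate:    {ter:.2f}")
```

On this toy pair the positional metric credits only one of four tokens, while the edit-distance view records a single error.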