# Whisper LoRA Keyboard Event Detection

This notebook implements fine-tuning of Whisper with LoRA for keyboard event detection from audio.

## Overview
- **Encoder**: Whisper with LoRA (frozen base weights)
- **Decoder**: Custom autoregressive transformer for keyboard events
- **Tokenization**: Based on KeyboardEvent.code (physical keys)

## 1. Setup and Dependencies

In [None]:
# Install required packages
!pip install -q torch transformers peft datasets librosa soundfile tensorboard accelerate

In [None]:
import json
import os
from glob import glob
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import librosa
import numpy as np

from transformers import (
    WhisperModel,
    WhisperFeatureExtractor,
    WhisperConfig,
)
from peft import LoraConfig, get_peft_model, TaskType

# Report the runtime environment so logs record which backend was used.
_cuda_ok = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {_cuda_ok}")
if _cuda_ok:
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## 2. Keyboard Event Tokenizer

Implements tokenization based on JavaScript KeyboardEvent.code values.
This maps physical keys (e.g., KeyA, Digit1) to token IDs.

In [None]:
class KeyboardEventTokenizer:
    """Tokenizer for keyboard events using KeyboardEvent.code values.

    Maps physical keyboard keys to tokens, with separate tokens for keydown/keyup.
    Example: 'a' and 'A' both map to KeyA, '1' and '!' both map to Digit1.

    Vocabulary layout: ``<BOS>`` = 0, ``<EOS>`` = 1, then, for each entry of
    ``self.codes`` in order, a ``<code>_down`` token followed by ``<code>_up``.
    ``<BOS>`` doubles as the padding token (``pad_token_id``).
    """

    def __init__(self):
        self.event_types = ['down', 'up']
        self.special_tokens = ['<BOS>', '<EOS>']

        # Physical key codes (KeyboardEvent.code)
        self.codes = [
            # Letters (26)
            *[f'Key{chr(i)}' for i in range(ord('A'), ord('Z') + 1)],
            # Digits (10)
            *[f'Digit{i}' for i in range(10)],
            # Punctuation (11)
            'Minus', 'Equal', 'BracketLeft', 'BracketRight', 'Backslash',
            'Semicolon', 'Quote', 'Backquote', 'Comma', 'Period', 'Slash',
            # Modifiers (9)
            'ShiftLeft', 'ShiftRight', 'ControlLeft', 'ControlRight',
            'AltLeft', 'AltRight', 'MetaLeft', 'MetaRight', 'CapsLock',
            # Whitespace & Editing (5)
            'Space', 'Tab', 'Enter', 'Backspace', 'Delete',
            # Navigation (10)
            'ArrowLeft', 'ArrowRight', 'ArrowUp', 'ArrowDown',
            'Home', 'End', 'PageUp', 'PageDown', 'Insert', 'Escape',
            # Function Keys (12)
            *[f'F{i}' for i in range(1, 13)],
        ]

        # Mapping from event.key to event.code for legacy dataset compatibility
        self.key_to_code = {
            # Letters - lowercase map to KeyX
            **{chr(i): f'Key{chr(i).upper()}' for i in range(ord('a'), ord('z') + 1)},
            # Letters - uppercase map to KeyX (same as lowercase)
            **{chr(i): f'Key{chr(i)}' for i in range(ord('A'), ord('Z') + 1)},
            # Digits and their shifted symbols
            '0': 'Digit0', ')': 'Digit0',
            '1': 'Digit1', '!': 'Digit1',
            '2': 'Digit2', '@': 'Digit2',
            '3': 'Digit3', '#': 'Digit3',
            '4': 'Digit4', '$': 'Digit4',
            '5': 'Digit5', '%': 'Digit5',
            '6': 'Digit6', '^': 'Digit6',
            '7': 'Digit7', '&': 'Digit7',
            '8': 'Digit8', '*': 'Digit8',
            '9': 'Digit9', '(': 'Digit9',
            # Punctuation
            '-': 'Minus', '_': 'Minus',
            '=': 'Equal', '+': 'Equal',
            '[': 'BracketLeft', '{': 'BracketLeft',
            ']': 'BracketRight', '}': 'BracketRight',
            '\\': 'Backslash', '|': 'Backslash',
            ';': 'Semicolon', ':': 'Semicolon',
            "'": 'Quote', '"': 'Quote',
            '`': 'Backquote', '~': 'Backquote',
            ',': 'Comma', '<': 'Comma',
            '.': 'Period', '>': 'Period',
            '/': 'Slash', '?': 'Slash',
            # Whitespace & Editing
            ' ': 'Space',
            'Tab': 'Tab',
            'Enter': 'Enter',
            'Backspace': 'Backspace',
            'Delete': 'Delete', 'DELETE': 'Delete',
            # Modifiers
            'Shift': 'ShiftLeft',  # Generic shift maps to left
            'SHIFT_R': 'ShiftRight',
            'Control': 'ControlLeft',
            'CTRL_L': 'ControlLeft',
            'CTRL_R': 'ControlRight',
            'Alt': 'AltLeft',
            'ALT_GR': 'AltRight',
            'CMD': 'MetaLeft',
            'Meta': 'MetaLeft',
            'CAPS_LOCK': 'CapsLock',
            'CapsLock': 'CapsLock',
            # Navigation
            'ArrowLeft': 'ArrowLeft', 'LEFT': 'ArrowLeft',
            'ArrowRight': 'ArrowRight', 'RIGHT': 'ArrowRight',
            'ArrowUp': 'ArrowUp', 'UP': 'ArrowUp',
            'ArrowDown': 'ArrowDown', 'DOWN': 'ArrowDown',
            'Home': 'Home',
            'End': 'End',
            'PageUp': 'PageUp',
            'PageDown': 'PageDown',
            'Insert': 'Insert',
            'Escape': 'Escape',
            # Function Keys
            **{f'F{i}': f'F{i}' for i in range(1, 13)},
        }

        # Build vocabulary
        self.vocab = {}
        self.id_to_token = {}

        # Special tokens take the lowest IDs.
        for i, token in enumerate(self.special_tokens):
            self.vocab[token] = i
            self.id_to_token[i] = token

        # One token per (code, event type) pair, e.g. "KeyA_down", "KeyA_up".
        idx = len(self.special_tokens)
        for code in self.codes:
            for event_type in self.event_types:
                token = f"{code}_{event_type}"
                self.vocab[token] = idx
                self.id_to_token[idx] = token
                idx += 1

        self.vocab_size = len(self.vocab)
        self.pad_token_id = self.vocab['<BOS>']  # Use BOS as pad for simplicity
        self.bos_token_id = self.vocab['<BOS>']
        self.eos_token_id = self.vocab['<EOS>']
        # IDs that carry no key event; decode() drops them. Derived from the
        # vocabulary instead of hard-coding 0/1 so it stays correct if the
        # special-token list changes.
        self.special_token_ids = frozenset(self.vocab[t] for t in self.special_tokens)

    def encode(self, events: List[Dict]) -> List[int]:
        """Convert list of key events to token IDs.

        Handles both legacy event.key format and new event.code format: a
        non-empty 'code' is used directly, otherwise 'key' is mapped through
        ``key_to_code``. Events whose code/event type is not in the
        vocabulary are skipped silently.

        Args:
            events: List of dicts with 'key' or 'code' and 'event_type'
                (e.g. 'keydown' / 'keyup').

        Returns:
            List of token IDs, starting with <BOS> and ending with <EOS>.

        Raises:
            KeyError: if an event with a recognised key/code lacks 'event_type'.
        """
        token_ids = [self.bos_token_id]

        for event in events:
            # Prefer the physical code; fall back to mapping a legacy event.key.
            code = event.get('code') or self.key_to_code.get(event.get('key', ''))
            if not code:
                continue

            event_type = event['event_type'].replace('key', '')  # keydown -> down
            token_id = self.vocab.get(f"{code}_{event_type}")
            if token_id is not None:
                token_ids.append(token_id)
            # Unknown codes / event types are skipped silently.

        token_ids.append(self.eos_token_id)
        return token_ids

    def decode(self, token_ids: List[int]) -> List[Dict]:
        """Convert token IDs back to key events with code values.

        Special tokens (<BOS>/<EOS>, and therefore padding) and IDs outside
        the vocabulary are skipped.

        Args:
            token_ids: List of integer token IDs

        Returns:
            List of event dicts with 'code' and 'event_type' ('keydown'/'keyup')
        """
        events = []
        for token_id in token_ids:
            if token_id in self.special_token_ids:
                continue
            token = self.id_to_token.get(token_id)
            if token and '_' in token:
                # rsplit: only the final "_down"/"_up" suffix is the event type.
                code, event_type = token.rsplit('_', 1)
                events.append({
                    'code': code,
                    'event_type': f'key{event_type}'
                })
        return events

# Build the shared tokenizer instance used by later cells and summarise it.
tokenizer = KeyboardEventTokenizer()
for label, value in (
    ("Vocabulary size", tokenizer.vocab_size),
    ("BOS token ID", tokenizer.bos_token_id),
    ("EOS token ID", tokenizer.eos_token_id),
):
    print(f"{label}: {value}")

n_codes = len(tokenizer.codes)
n_legacy = len(tokenizer.key_to_code)
print(f"\nNumber of physical keys: {n_codes}")
print(f"Number of legacy key mappings: {n_legacy}")

## 3. Test Tokenizer with Sample Data

Load a sample recording and verify tokenization works correctly.

In [None]:
# Load one sample JSON recording to sanity-check tokenization.
recordings_dir = '../recordings'
# Sorted so the "first" sample is reproducible: os.listdir order is arbitrary.
sample_files = sorted(
    f for f in os.listdir(recordings_dir)
    if f.endswith('.json') and 'DELETED' not in f
)

print(f"Found {len(sample_files)} JSON files")
if not sample_files:
    raise FileNotFoundError(f"No usable .json recordings found in {recordings_dir!r}")
print(f"Sample file: {sample_files[0]}")

# Load first file (JSON is UTF-8; don't depend on the platform default encoding).
with open(os.path.join(recordings_dir, sample_files[0]), 'r', encoding='utf-8') as f:
    sample_data = json.load(f)

print(f"\nTotal events in sample: {len(sample_data['keystrokes'])}")
print("\nFirst 5 events:")
for i, event in enumerate(sample_data['keystrokes'][:5]):
    print(f"  {i}: {event}")

In [None]:
# Round-trip the first 20 recorded events through the tokenizer.
sample_events = sample_data['keystrokes'][:20]
encoded = tokenizer.encode(sample_events)

n_original = len(sample_events)
n_tokens = len(encoded)
print(f"Original events count: {n_original}")
print(f"Encoded token count: {n_tokens} (includes BOS/EOS)")
print(f"\nEncoded tokens: {encoded[:10]}...")

# Map the token IDs back into event dicts.
decoded = tokenizer.decode(encoded)
print(f"\nDecoded events count: {len(decoded)}")
print(f"\nFirst 3 decoded events:")
for idx, ev in enumerate(decoded[:3]):
    print(f"  {idx}: {ev}")

In [None]:
# Verify round-trip consistency: every event the tokenizer understands should
# come back with the same physical code and event type, in the same order.
# encode() silently drops events whose key/code is not in the vocabulary, so
# compare against the filtered originals; zipping the raw list would misalign
# every comparison after the first skipped event.
print("Verifying round-trip encoding/decoding:")

# An event is kept by encode() iff encoding it alone yields more than BOS+EOS.
kept_events = [ev for ev in sample_events if len(tokenizer.encode([ev])) > 2]
skipped = len(sample_events) - len(kept_events)
if skipped:
    print(f"  Note: {skipped} event(s) not in the vocabulary were skipped by encode()")

mismatches = 0
if len(kept_events) != len(decoded):
    print(f"  Length mismatch: expected {len(kept_events)} events, decoded {len(decoded)}")
    mismatches += 1

for orig, dec in zip(kept_events, decoded):
    # Same code resolution as tokenizer.encode(): explicit code, else legacy key.
    orig_code = orig.get('code') or tokenizer.key_to_code.get(orig.get('key', ''))
    if (orig_code, orig['event_type']) != (dec['code'], dec['event_type']):
        print(f"  Mismatch: {orig} -> {dec}")
        mismatches += 1

if mismatches == 0:
    print("✓ All event codes and types match after round-trip!")
else:
    print(f"✗ Found {mismatches} mismatches")

In [None]:
# Check that shifted/unshifted characters collapse onto the same physical key.
test_events = [
    {'key': k, 'event_type': 'keydown'}
    for k in ('a', 'A', '1', '!', 'Shift', 'SHIFT_R')
]

print("Testing key variations:")
for event in test_events:
    encoded = tokenizer.encode([event])
    decoded = tokenizer.decode(encoded)
    code = decoded[0]['code'] if decoded else 'NONE'
    print(f"  {event['key']:10s} -> token {encoded[1]:3d} -> {code}")


def _first_event_token(key):
    """Return the token ID for a lone keydown of ``key`` (index 1 skips <BOS>)."""
    return tokenizer.encode([{'key': key, 'event_type': 'keydown'}])[1]


# 'a' and 'A' should share the KeyA token.
a_lower, a_upper = _first_event_token('a'), _first_event_token('A')
print(f"\n'a' and 'A' map to same token: {a_lower == a_upper}")

# '1' and '!' should share the Digit1 token.
one_digit, one_symbol = _first_event_token('1'), _first_event_token('!')
print(f"'1' and '!' map to same token: {one_digit == one_symbol}")

## 4. Dataset Implementation

Implements a dataset class that loads paired audio recordings (.webm) and keyboard-event logs (.json).

In [None]:
class KeyboardEventDataset(Dataset):
    """Dataset for keyboard event detection from audio.

    Pairs each ``<name>.webm`` recording with the ``<name>.json`` keystroke log
    beside it; JSON paths containing ``DELETED`` are ignored.
    """

    def __init__(
        self,
        recordings_dir: str,
        tokenizer: KeyboardEventTokenizer,
        feature_extractor: WhisperFeatureExtractor,
        max_length: int = 1024,
        sample_rate: int = 16000,
    ):
        """
        Args:
            recordings_dir: Directory containing .webm and .json files
            tokenizer: KeyboardEventTokenizer instance
            feature_extractor: Whisper feature extractor
            max_length: Fixed token sequence length (truncate / pad to this)
            sample_rate: Audio sample rate (Whisper expects 16kHz)
        """
        self.recordings_dir = recordings_dir
        self.tokenizer = tokenizer
        self.feature_extractor = feature_extractor
        self.max_length = max_length
        self.sample_rate = sample_rate

        # Find all audio-json pairs
        self.pairs = self._find_pairs()
        print(f"Found {len(self.pairs)} audio-json pairs")

    def _find_pairs(self) -> List[Tuple[str, str]]:
        """Return sorted (webm_path, json_path) pairs that both exist on disk."""
        pairs = []
        # Sorted so dataset indices are reproducible (glob order is filesystem-dependent).
        json_files = sorted(glob(os.path.join(self.recordings_dir, '*.json')))

        for json_path in json_files:
            if 'DELETED' in json_path:
                continue
            # splitext strips only the final extension; str.replace('.json', '')
            # would also rewrite a '.json' appearing earlier in the path.
            webm_path = os.path.splitext(json_path)[0] + '.webm'

            if os.path.exists(webm_path):
                pairs.append((webm_path, json_path))

        return pairs

    def __len__(self) -> int:
        return len(self.pairs)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Load and process a single audio-event pair.

        Returns:
            Dictionary with:
                - input_features: Features from the Whisper feature extractor
                - labels: Token IDs, padded with pad_token_id to max_length
                - attention_mask: 1 for real tokens (including the leading
                  <BOS>), 0 for padding
        """
        audio_path, json_path = self.pairs[idx]

        # Load audio
        try:
            audio, _ = librosa.load(audio_path, sr=self.sample_rate)
        except Exception as e:
            # Best-effort fallback: one second of silence. NOTE: the labels below
            # are still the real keystrokes, so such samples are noisy.
            print(f"Error loading audio {audio_path}: {e}")
            audio = np.zeros(self.sample_rate, dtype=np.float32)

        # Extract features (mel spectrogram)
        features = self.feature_extractor(
            audio,
            sampling_rate=self.sample_rate,
            return_tensors="pt"
        )

        # Load events (JSON is UTF-8; don't rely on the platform default encoding).
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Tokenize events
        token_ids = self.tokenizer.encode(data['keystrokes'])

        # Truncate to max_length (a truncated sequence loses its trailing <EOS>),
        # then pad with pad_token_id.
        token_ids = token_ids[:self.max_length]
        num_real = len(token_ids)
        num_pad = self.max_length - num_real
        token_ids = token_ids + [self.tokenizer.pad_token_id] * num_pad

        # Build the mask from the real length, NOT by comparing token values:
        # pad_token_id is <BOS>, so a value-based mask would also hide the leading
        # <BOS>. Combined with the decoder's causal mask, that leaves position 0
        # with no key it may attend to (an all-masked softmax, which can yield NaN).
        attention_mask = [1] * num_real + [0] * num_pad

        return {
            'input_features': features.input_features[0],
            'labels': torch.tensor(token_ids, dtype=torch.long),
            'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
        }

# Feature extractor from the same checkpoint as the encoder used later.
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")

# Pair every recording in recordings_dir with its keystroke log.
dataset_kwargs = dict(
    recordings_dir=recordings_dir,
    tokenizer=tokenizer,
    feature_extractor=feature_extractor,
    max_length=1024,
)
dataset = KeyboardEventDataset(**dataset_kwargs)

print(f"\nDataset size: {len(dataset)}")

In [None]:
# Inspect one processed example from the dataset.
sample = dataset[0]

print("Sample data shapes:")
for field in ('input_features', 'labels', 'attention_mask'):
    print(f"  {field}: {sample[field].shape}")

print(f"\nFirst 20 label tokens: {sample['labels'][:20].tolist()}")
print(f"First 20 attention mask: {sample['attention_mask'][:20].tolist()}")

# Positions the mask marks as real (1) rather than padding (0).
non_pad_count = int(sample['attention_mask'].sum())
print(f"\nNon-padded tokens: {non_pad_count} / {len(sample['labels'])}")

## 5. Custom Decoder Architecture

Implements a custom transformer decoder for generating keyboard event sequences.

In [None]:
class KeyboardEventDecoder(nn.Module):
    """Custom transformer decoder for keyboard event prediction.

    Embeds token IDs, adds a learned positional embedding, and runs a stack of
    pre-LN ``nn.TransformerDecoderLayer``s with causal self-attention and
    cross-attention over the encoder output, then projects to vocabulary logits.

    NOTE(review): with ``norm_first=True`` the stack has no final LayerNorm
    (``nn.TransformerDecoder``'s ``norm`` is left unset). Pre-LN designs usually
    add one before the output projection; doing so would change state_dict keys.
    """

    def __init__(
        self,
        vocab_size: int,
        hidden_dim: int = 384,  # Match whisper-tiny
        num_layers: int = 4,
        num_heads: int = 6,
        ff_dim: int = 1536,  # 4x hidden_dim
        dropout: float = 0.1,
        max_seq_len: int = 1024,
    ):
        """
        Args:
            vocab_size: Size of token vocabulary
            hidden_dim: Hidden dimension (must match the encoder output width)
            num_layers: Number of decoder layers
            num_heads: Number of attention heads (must divide hidden_dim)
            ff_dim: Feed-forward dimension
            dropout: Dropout probability
            max_seq_len: Maximum decoder sequence length (size of the
                positional embedding table)
        """
        super().__init__()

        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size

        # Token embedding
        self.token_embedding = nn.Embedding(vocab_size, hidden_dim)

        # Learned positional embedding, one row per position.
        self.pos_encoding = nn.Parameter(torch.randn(1, max_seq_len, hidden_dim))

        # Transformer decoder layers
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=ff_dim,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True,  # Pre-LN
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers)

        # Output projection
        self.output_proj = nn.Linear(hidden_dim, vocab_size)

        # Dropout applied to the summed embeddings
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        input_ids: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            input_ids: Token IDs [batch, seq_len]
            encoder_hidden_states: Encoder outputs [batch, enc_seq_len, hidden_dim]
            attention_mask: 1 for real tokens, 0 for padding [batch, seq_len]

        Returns:
            Logits [batch, seq_len, vocab_size]

        Raises:
            ValueError: if seq_len exceeds the positional embedding length.
        """
        seq_len = input_ids.size(1)
        max_len = self.pos_encoding.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_seq_len {max_len}"
            )

        # Embed tokens and add positional embeddings
        x = self.token_embedding(input_ids)  # [batch, seq_len, hidden_dim]
        x = x + self.pos_encoding[:, :seq_len, :]
        x = self.dropout(x)

        # Boolean causal mask (True = may NOT attend), built directly on the
        # target device. Using bool here matches the bool key-padding mask
        # below; mixing a float attn mask with a bool padding mask is deprecated
        # in PyTorch.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device),
            diagonal=1,
        )

        # True marks padding positions that must be ignored as keys.
        key_padding_mask = None if attention_mask is None else (attention_mask == 0)

        x = self.decoder(
            tgt=x,
            memory=encoder_hidden_states,
            tgt_mask=causal_mask,
            tgt_key_padding_mask=key_padding_mask,
        )

        # Project to vocabulary
        return self.output_proj(x)  # [batch, seq_len, vocab_size]

# Smoke test: build a standalone decoder sized for whisper-tiny.
decoder_config = {
    'vocab_size': tokenizer.vocab_size,
    'hidden_dim': 384,  # whisper-tiny dimension
    'num_layers': 4,
    'num_heads': 6,
}
decoder = KeyboardEventDecoder(**decoder_config)

n_decoder_params = sum(param.numel() for param in decoder.parameters())
print(f"Decoder parameters: {n_decoder_params:,}")
print(f"Decoder vocab size: {decoder.vocab_size}")

## 6. Complete Model with Whisper Encoder + LoRA

Combines Whisper encoder (with LoRA) and custom decoder into a single model.

In [None]:
class WhisperKeyboardModel(nn.Module):
    """Whisper encoder (with LoRA adapters) followed by KeyboardEventDecoder.

    The pretrained Whisper encoder is wrapped with ``peft.get_peft_model`` so
    its base weights are frozen and only the LoRA adapters (plus the custom
    decoder) are trainable.
    """

    def __init__(
        self,
        whisper_model_name: str = "openai/whisper-tiny",
        vocab_size: int = 256,
        lora_rank: int = 8,
        lora_alpha: int = 16,
        lora_dropout: float = 0.1,
        decoder_num_layers: int = 4,
        decoder_num_heads: Optional[int] = None,
    ):
        """
        Args:
            whisper_model_name: Hugging Face checkpoint to take the encoder from.
            vocab_size: Decoder vocabulary size.
            lora_rank: LoRA rank ``r``.
            lora_alpha: LoRA scaling ``alpha``.
            lora_dropout: Dropout inside the LoRA adapters.
            decoder_num_layers: Number of custom decoder layers.
            decoder_num_heads: Decoder attention heads; defaults to the Whisper
                encoder's own head count, which always divides ``d_model``
                (the previous hard-coded 6 does not divide e.g. 512 or 1024).
        """
        super().__init__()

        # Load pretrained Whisper; only its encoder is kept.
        print(f"Loading {whisper_model_name}...")
        whisper = WhisperModel.from_pretrained(whisper_model_name)
        self.encoder = whisper.encoder

        # Decoder width must match the encoder output width.
        config = whisper.config
        self.hidden_dim = config.d_model

        print(f"Encoder hidden dim: {self.hidden_dim}")

        if decoder_num_heads is None:
            decoder_num_heads = config.encoder_attention_heads

        # Apply LoRA to encoder attention projections
        lora_config = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_alpha,
            target_modules=["q_proj", "v_proj", "k_proj"],  # Attention projections
            lora_dropout=lora_dropout,
            bias="none",
        )

        print("Applying LoRA to encoder...")
        self.encoder = get_peft_model(self.encoder, lora_config)

        # Custom decoder; feed-forward width follows the usual 4x hidden_dim
        # (1536 for whisper-tiny, matching the decoder's default).
        self.decoder = KeyboardEventDecoder(
            vocab_size=vocab_size,
            hidden_dim=self.hidden_dim,
            num_layers=decoder_num_layers,
            num_heads=decoder_num_heads,
            ff_dim=4 * self.hidden_dim,
        )

        print(f"Total model parameters: {sum(p.numel() for p in self.parameters()):,}")
        print(f"Trainable parameters: {sum(p.numel() for p in self.parameters() if p.requires_grad):,}")

    def forward(
        self,
        input_features: torch.Tensor,
        labels: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Teacher-forced forward pass.

        Args:
            input_features: Whisper input features [batch, n_mels, time]
            labels: Target token IDs [batch, seq_len], starting with <BOS>
            attention_mask: 1 for real label tokens, 0 for padding [batch, seq_len]

        Returns:
            (loss, logits) where logits is [batch, seq_len - 1, vocab_size]
        """
        # Encode audio
        encoder_outputs = self.encoder(input_features)
        encoder_hidden_states = encoder_outputs.last_hidden_state

        # Shift for teacher forcing:
        #   decoder input: <BOS> token_1 token_2 ...
        #   target:        token_1 token_2 ... <EOS>
        decoder_input_ids = labels[:, :-1]
        decoder_attention_mask = attention_mask[:, :-1] if attention_mask is not None else None

        logits = self.decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=decoder_attention_mask,
        )

        # ID 0 is the tokenizer's <BOS>/pad token; it only appears in targets as
        # padding, so ignoring it excludes padded positions from the loss.
        targets = labels[:, 1:]
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            targets.reshape(-1),
            ignore_index=0,
        )

        return loss, logits

# Pick the accelerator if present, build the full model, and move it there.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

model = WhisperKeyboardModel(
    whisper_model_name="openai/whisper-tiny",
    vocab_size=tokenizer.vocab_size,
    lora_rank=8,
    lora_alpha=16,
).to(device)  # nn.Module.to moves parameters in place and returns the module

In [None]:
# Smoke test: one gradient-free forward pass on the dataset sample (batch of 1).
batch_keys = ('input_features', 'labels', 'attention_mask')
sample_batch = {k: sample[k].unsqueeze(0).to(device) for k in batch_keys}

print("Testing forward pass...")
with torch.no_grad():
    loss, logits = model(**sample_batch)

seq_len = sample['labels'].shape[0]
print(f"\nLoss: {loss.item():.4f}")
print(f"Logits shape: {logits.shape}")
print(f"Expected: [batch_size, seq_len-1, vocab_size] = [1, {seq_len-1}, {tokenizer.vocab_size}]")

## 7. Training Configuration

In [None]:
import math

from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR

# Training hyperparameters
BATCH_SIZE = 4
LEARNING_RATE = 5e-4
NUM_EPOCHS = 10
WARMUP_STEPS = 100
GRADIENT_CLIP = 1.0

# Create dataloader
train_loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,  # Use 0 for notebook compatibility
)

print(f"Training batches: {len(train_loader)}")

# Only parameters with requires_grad (LoRA adapters + custom decoder) are
# optimized; frozen Whisper weights are not handed to AdamW.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = AdamW(trainable_params, lr=LEARNING_RATE, weight_decay=0.01)

# Learning-rate schedule: linear warmup for WARMUP_STEPS optimizer steps, then
# cosine decay to 0 over the remaining steps. WARMUP_STEPS was previously
# defined but never applied. Warmup is capped at total_steps for short runs.
total_steps = len(train_loader) * NUM_EPOCHS
warmup_steps = min(WARMUP_STEPS, total_steps)


def lr_lambda(step: int) -> float:
    """Multiplier on LEARNING_RATE after ``step`` scheduler steps (0-based)."""
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))


scheduler = LambdaLR(optimizer, lr_lambda)

print(f"Total training steps: {total_steps}")
print(f"Warmup steps: {warmup_steps}")

## 8. Training Loop

In [None]:
from tqdm.auto import tqdm


def train_epoch(model, train_loader, optimizer, scheduler, device, epoch):
    """Train for one epoch and return the mean per-batch loss.

    Steps the optimizer and the LR scheduler once per batch and clips gradients
    to the global ``GRADIENT_CLIP`` norm.

    Args:
        model: Module returning ``(loss, logits)`` from
            ``model(input_features=..., labels=..., attention_mask=...)``.
        train_loader: Iterable of batches with 'input_features', 'labels',
            'attention_mask' tensors.
        optimizer: Optimizer over the model's trainable parameters.
        scheduler: LR scheduler stepped after every optimizer step.
        device: Device the batch tensors are moved to.
        epoch: 0-based epoch index (used only for the progress-bar label).

    Returns:
        Mean loss over the epoch's batches, or NaN if the loader is empty.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
    for batch in pbar:
        # Move to device
        input_features = batch['input_features'].to(device)
        labels = batch['labels'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        # Forward pass
        optimizer.zero_grad()
        loss, _ = model(
            input_features=input_features,
            labels=labels,
            attention_mask=attention_mask,
        )

        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP)
        optimizer.step()
        scheduler.step()

        # Track loss
        total_loss += loss.item()
        num_batches += 1

        pbar.set_postfix({
            'loss': f"{loss.item():.4f}",
            'avg_loss': f"{total_loss/num_batches:.4f}",
            'lr': f"{scheduler.get_last_lr()[0]:.6f}"
        })

    # An empty loader (e.g. no recordings found) would otherwise divide by zero.
    if num_batches == 0:
        return float('nan')
    return total_loss / num_batches

# Run the full schedule, reporting the mean loss after each epoch.
print("Starting training...\n")
for epoch_idx in range(NUM_EPOCHS):
    epoch_loss = train_epoch(model, train_loader, optimizer, scheduler, device, epoch_idx)
    print(f"Epoch {epoch_idx + 1}/{NUM_EPOCHS} - Average Loss: {epoch_loss:.4f}\n")

print("Training complete!")

## 9. Evaluation and Testing

In [None]:
# Evaluate on the sample with teacher forcing: the decoder is fed the
# ground-truth prefix at every position, so the argmax predictions below show
# next-token accuracy, not what free-running autoregressive generation would emit.
model.eval()

with torch.no_grad():
    sample_batch = {
        'input_features': sample['input_features'].unsqueeze(0).to(device),
        'labels': sample['labels'].unsqueeze(0).to(device),
        'attention_mask': sample['attention_mask'].unsqueeze(0).to(device),
    }

    loss, logits = model(**sample_batch)
    predictions = logits.argmax(dim=-1)

print(f"Evaluation loss: {loss.item():.4f}")
print(f"\nPredicted tokens (first 20): {predictions[0, :20].cpu().tolist()}")
print(f"Ground truth tokens (first 20): {sample['labels'][1:21].tolist()}")

# Prediction i targets labels[i + 1]. Only real targets (up to and including
# <EOS>) are meaningful; later positions predict padding and are ignored by the
# loss. Padding is pad_token_id (== <BOS>), which encode() only places at
# index 0, so it never occurs as a genuine target.
targets = sample['labels'][1:]
num_valid = int((targets != tokenizer.pad_token_id).sum())
valid_preds = predictions[0, :num_valid].cpu()

if num_valid:
    token_acc = (valid_preds == targets[:num_valid]).float().mean().item()
    print(f"Teacher-forced token accuracy: {token_acc:.2%} over {num_valid} tokens")

# Cut the prediction at its first <EOS>, where a generation loop would stop.
pred_ids = valid_preds.tolist()
if tokenizer.eos_token_id in pred_ids:
    pred_ids = pred_ids[:pred_ids.index(tokenizer.eos_token_id)]

pred_events = tokenizer.decode(pred_ids)
true_events = tokenizer.decode(sample['labels'].tolist())

print("\nPredicted events (first 10):")
for i, event in enumerate(pred_events[:10]):
    print(f"  {i}: {event}")

print("\nGround truth events (first 10):")
for i, event in enumerate(true_events[:10]):
    print(f"  {i}: {event}")

## 10. Save Model

In [None]:
# Save model checkpoint (includes optimizer and scheduler state for resuming).
checkpoint_path = "whisper_keyboard_model.pt"

torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'vocab_size': tokenizer.vocab_size,
}, checkpoint_path)

print(f"Model saved to {checkpoint_path}")

# Save a self-describing tokenizer config: token IDs depend on the order of
# special tokens, codes and event types, so record all of them plus the
# resulting vocabulary and the legacy key->code mapping used by encode().
tokenizer_config = {
    'codes': tokenizer.codes,
    'vocab_size': tokenizer.vocab_size,
    'event_types': tokenizer.event_types,
    'special_tokens': tokenizer.special_tokens,
    'pad_token_id': tokenizer.pad_token_id,
    'bos_token_id': tokenizer.bos_token_id,
    'eos_token_id': tokenizer.eos_token_id,
    'vocab': tokenizer.vocab,
    'key_to_code': tokenizer.key_to_code,
}

with open('tokenizer_config.json', 'w', encoding='utf-8') as f:
    json.dump(tokenizer_config, f, indent=2)

print("Tokenizer config saved to tokenizer_config.json")