Expected format:
{
  "audio": "/content/keyboard_dataset/audio/example_0001.wav",
  "events": [
    {"time": 0.12, "key": "A", "type": "down"},
    {"time": 0.25, "key": "A", "type": "up"},
    {"time": 0.40, "key": "SPACE", "type": "down"},
    {"time": 0.50, "key": "SPACE", "type": "up"}
  ]
}

In [1]:
# Cell 0: Basic setup

import torch

# Run on the GPU when torch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Device:", device)
# Optional guard for GPU-only runs (left disabled so the notebook also runs on CPU):
# assert device == "cuda", "Please enable GPU in Runtime > Change runtime type."


Device: cpu


In [2]:
# Cell 1: Point to your directory and inspect JSON/audio files

import os, glob

DATA_DIR = "../recordings"  # change this to your folder

# Gather the recording metadata (.json) and audio (.webm) files in name order.
json_files = glob.glob(os.path.join(DATA_DIR, "*.json"))
json_files.sort()
webm_files = glob.glob(os.path.join(DATA_DIR, "*.webm"))
webm_files.sort()

print(f"Found {len(json_files)} JSON files")
print(f"Found {len(webm_files)} WEBM files")

# Preview the first few file names of each kind as a quick sanity check.
for heading, paths in (
    ("First 3 JSON files:", json_files),
    ("First 3 WEBM files:", webm_files),
):
    print(heading)
    for path in paths[:3]:
        print("  ", os.path.basename(path))


Found 9 JSON files
Found 10 WEBM files
First 3 JSON files:
   recording_20251108_122658_DELETED.json
   recording_20251108_122843_DELETED.json
   recording_20251108_134844_DELETED.json
First 3 WEBM files:
   recording_20251108_122658_DELETED.webm
   recording_20251108_122843_DELETED.webm
   recording_20251108_134844_DELETED.webm


In [3]:
# Cell 2: Load JSON files into a Dataset and attach audio paths

from datasets import load_dataset, Audio, DatasetDict
import os

# Sampling rate requested for the audio column when it is cast further down.
TARGET_SAMPLING_RATE = 16_000

# Load every JSON file under DATA_DIR into one split.
# Use a non-reserved split name like "data" instead of "all".
data_files = dict(data=os.path.join(DATA_DIR, "*.json"))
raw_all = load_dataset("json", data_files=data_files)["data"]

print(raw_all)
print("Example raw JSON entry:\n", raw_all[0])

# Add an "audio" field pointing to the matching .webm file for each recording
def add_audio_path(example, data_dir=None):
    """Store the on-disk path of a recording's audio file in ``example["audio"]``.

    The file name is read from ``example["audio_file"]`` and joined onto
    ``data_dir``.  If that file does not exist, the ``_DELETED`` variant is tried
    instead (``rec.webm`` -> ``rec_DELETED.webm``).

    Args:
        example: Mapping with an ``"audio_file"`` entry. It is updated in place.
        data_dir: Directory that holds the audio files.  When omitted, the
            notebook-level ``DATA_DIR`` is used, so ``Dataset.map(add_audio_path)``
            keeps working unchanged.

    Returns:
        The same ``example`` object with the ``"audio"`` key set to the resolved path.

    Raises:
        FileNotFoundError: If neither the plain file nor its ``_DELETED`` variant exists.
    """
    if data_dir is None:
        data_dir = DATA_DIR

    audio_filename = example["audio_file"]
    audio_path = os.path.join(data_dir, audio_filename)

    if not os.path.exists(audio_path):
        # Strip only a trailing ".webm".  str.replace(".webm", "") would also remove
        # ".webm" from the middle of a name and build the wrong candidate path.
        suffix = ".webm"
        if audio_filename.endswith(suffix):
            base_name = audio_filename[: -len(suffix)]
        else:
            base_name = audio_filename
        deleted_path = os.path.join(data_dir, base_name + "_DELETED.webm")
        if not os.path.exists(deleted_path):
            raise FileNotFoundError(f"Missing audio file: tried {audio_path} and {deleted_path}")
        audio_path = deleted_path

    example["audio"] = audio_path
    return example

# Resolve each recording's audio path, then cast the "audio" column to an Audio
# feature so it gets decoded & resampled to TARGET_SAMPLING_RATE on access.
audio_feature = Audio(sampling_rate=TARGET_SAMPLING_RATE)
raw_all = raw_all.map(add_audio_path).cast_column("audio", audio_feature)

print("\nAfter adding audio and casting column:")
print(raw_all)


  from .autonotebook import tqdm as notebook_tqdm


Dataset({
    features: ['recording_id', 'start_timestamp', 'keyboard_session_id', 'control_session_id', 'keystrokes', 'audio_file', 'end_timestamp'],
    num_rows: 9
})
Example raw JSON entry:
 {'recording_id': '8d95d4ad-3192-4487-8499-1c6db012bf72', 'start_timestamp': 1762622818800, 'keyboard_session_id': 'a33ff147-dce3-43fb-bb42-31654d5a8958', 'control_session_id': 'ff21dc90-e7a8-427a-bb36-e83fa7071c62', 'keystrokes': [{'timestamp': 1762622819922, 'key': 'Shift', 'event_type': 'keydown'}, {'timestamp': 1762622819970, 'key': 'H', 'event_type': 'keydown'}, {'timestamp': 1762622820034, 'key': 'Shift', 'event_type': 'keyup'}, {'timestamp': 1762622820073, 'key': 'h', 'event_type': 'keyup'}, {'timestamp': 1762622820127, 'key': 'e', 'event_type': 'keydown'}, {'timestamp': 1762622820179, 'key': 'e', 'event_type': 'keyup'}, {'timestamp': 1762622820198, 'key': 'l', 'event_type': 'keydown'}, {'timestamp': 1762622820261, 'key': 'l', 'event_type': 'keyup'}, {'timestamp': 1762622820350, 'key': 'l

In [4]:
# Cell 3: Create train/validation splits and inspect a couple of examples

# Hold out 10% of the recordings for validation; the fixed seed keeps the split reproducible.
splits = raw_all.train_test_split(test_size=0.1, seed=42)
train_part, validation_part = splits["train"], splits["test"]
raw_datasets = DatasetDict({"train": train_part, "validation": validation_part})

print(raw_datasets)

# Inspect the first training example.
ex = raw_datasets["train"][0]
print("\nOne training example:")
print("recording_id:", ex["recording_id"])

# The decoded audio may already be a plain dict; otherwise copy out the two
# fields reported below so the prints work the same either way.
audio = ex["audio"]
if isinstance(audio, dict):
    audio_info = audio
else:
    audio_info = {
        "sampling_rate": audio["sampling_rate"],
        "array": audio["array"],
    }

print("audio fields:", list(audio_info.keys()))
print("audio sampling rate:", audio_info["sampling_rate"])
print("audio array shape:", audio_info["array"].shape)

keystrokes = ex["keystrokes"]
print("num keystrokes:", len(keystrokes))
print("first 5 keystrokes:", keystrokes[:5])


DatasetDict({
    train: Dataset({
        features: ['recording_id', 'start_timestamp', 'keyboard_session_id', 'control_session_id', 'keystrokes', 'audio_file', 'end_timestamp', 'audio'],
        num_rows: 8
    })
    validation: Dataset({
        features: ['recording_id', 'start_timestamp', 'keyboard_session_id', 'control_session_id', 'keystrokes', 'audio_file', 'end_timestamp', 'audio'],
        num_rows: 1
    })
})

One training example:
recording_id: 8d95d4ad-3192-4487-8499-1c6db012bf72
audio fields: ['sampling_rate', 'array']
audio sampling rate: 16000
audio array shape: (127896,)
num keystrokes: 62
first 5 keystrokes: [{'timestamp': 1762622819922, 'key': 'Shift', 'event_type': 'keydown'}, {'timestamp': 1762622819970, 'key': 'H', 'event_type': 'keydown'}, {'timestamp': 1762622820034, 'key': 'Shift', 'event_type': 'keyup'}, {'timestamp': 1762622820073, 'key': 'h', 'event_type': 'keyup'}, {'timestamp': 1762622820127, 'key': 'e', 'event_type': 'keydown'}]
