# Ableton LLM Control

## Setup

### Imports

In [1]:
from typing import Any

import ipywidgets as widgets
import live
import numpy as np
import sounddevice as sd
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    WhisperForConditionalGeneration,
    WhisperProcessor,
    pipeline,
)

### Constants

In [2]:
# Fine-tuned text classifier: its name and the directory it is loaded from.
TEXT_CLASSIFIER_MODEL_NAME = "ableton_osc_matcher"
TEXT_CLASSIFIER_MODEL_PATH = f"../artifacts/{TEXT_CLASSIFIER_MODEL_NAME}_trained"
# Minimum classifier score; run() rejects predictions that score below this.
CONFIDENCE_THRESHOLD = 0.5
# NOTE(review): device selection is currently disabled. Re-enabling it would
# require importing torch, which the imports cell does not do.
# DEVICE = (
#     "mps"
#     if hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
#     else "cpu"
# )
INPUT_CHANNELS = 1  # Number of input channels to record
SAMPLE_RATE = 16_000  # Whisper sample rate
MAX_FRAMES = 5 * SAMPLE_RATE  # Max recording length in frames (5 seconds of audio)

# Register the values as sounddevice defaults so that sd.rec() can be called
# without passing the sample rate or channel count explicitly.
sd.default.samplerate = SAMPLE_RATE
sd.default.channels = INPUT_CHANNELS

## Voice recording

In [3]:
def init_recording() -> np.ndarray:
    """Allocate a zero-filled recording buffer.

    Returns:
        An array of shape ``(MAX_FRAMES, INPUT_CHANNELS)`` filled with zeros.
    """
    buffer_shape = (MAX_FRAMES, INPUT_CHANNELS)
    return np.zeros(buffer_shape)

def start_recording() -> np.ndarray:
    """Start capturing audio into a freshly allocated buffer.

    ``sd.rec`` is called without ``blocking=True``; per sounddevice's
    documented default it returns while the capture continues in the
    background, so the returned buffer is still being filled after this
    function returns.

    Returns:
        The buffer that sounddevice writes the recorded samples into.
    """
    buffer = init_recording()
    sd.rec(out=buffer)
    return buffer

def trim_recording(recording: np.ndarray) -> np.ndarray:
    """Drop the trailing all-zero frames from a recording buffer.

    Args:
        recording: 2-D buffer indexed as ``(frame, channel)``.

    Returns:
        The leading slice of ``recording`` that ends at the last frame with a
        non-zero sample in any channel. If every frame is zero (nothing was
        captured), an empty slice with zero frames is returned.
    """
    non_zero_frames = np.flatnonzero(recording.any(axis=1))
    if non_zero_frames.size == 0:
        # Taking the max of an empty index array would raise ValueError.
        return recording[:0]
    # flatnonzero returns indices in ascending order, so the last is the max.
    return recording[: non_zero_frames[-1] + 1]

def stop_recording(out: np.ndarray) -> np.ndarray:
    """Stop the active sounddevice stream and trim the captured buffer.

    Args:
        out: The buffer that was handed to ``sd.rec``.

    Returns:
        ``out`` with its trailing all-zero frames removed.
    """
    sd.stop()
    trimmed = trim_recording(out)
    return trimmed

## Voice transcription

In [4]:
# Load the "openai/whisper-tiny.en" checkpoint: the processor supplies the
# tokenizer and feature extractor, the model does the generation.
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
speech_recognition_model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-tiny.en"
)

# Bundle the pieces into a single speech-recognition pipeline.
transcriber = pipeline(
    "automatic-speech-recognition",
    model=speech_recognition_model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
)

def transcribe(recording: np.ndarray) -> str:
    """Transcribe a recorded audio buffer to text and log the result.

    Args:
        recording: Recorded samples; singleton dimensions are squeezed away
            before the array is passed to the speech-recognition pipeline.

    Returns:
        The transcribed text with surrounding whitespace stripped.
    """
    samples = recording.squeeze()
    output = transcriber(samples)
    text = output["text"].strip()
    print(f"[Transcriber]: {text}")
    return text

## Text classification

In [5]:
# Load the trained classifier and its tokenizer from the same directory.
text_classifier_model = AutoModelForSequenceClassification.from_pretrained(
    TEXT_CLASSIFIER_MODEL_PATH
)
tokenizer = AutoTokenizer.from_pretrained(TEXT_CLASSIFIER_MODEL_PATH)

# Bundle both into a text-classification pipeline.
classifier = pipeline(
    "text-classification",
    model=text_classifier_model,
    tokenizer=tokenizer,
)

def classify_text(command: str) -> dict[str, Any]:
    """Classify a transcribed command and log the prediction.

    Args:
        command: The text to classify.

    Returns:
        The first prediction produced by the classifier pipeline. It is a
        dict, not a string: ``run`` reads its ``"label"`` and ``"score"``
        entries.
    """
    predictions: list[dict[str, Any]] = classifier(command)
    prediction = predictions[0]
    print(f"[Classifier]: {prediction}")
    return prediction

## Ableton Live connection

In [11]:
# Open the connection object used for every message sent to Ableton Live.
livequery = live.Query()
# Smoke-test the connection; the first element of the reply is printed
# (the captured cell output shows "ok").
test_response = livequery.query("/live/test")
print(f"[Live]: {test_response[0]}")

[Live]: ok


In [7]:
# Default OSC arguments, keyed by endpoint, for endpoints that are sent with
# arguments. Endpoints not listed here are sent with no arguments.
# NOTE(review): -1 presumably means "append at the end" in AbletonOSC -- confirm.
DEFAULT_PARAMETERS: dict[str, list[Any]] = {
    "/live/song/create_audio_track": [-1],
    "/live/song/create_midi_track": [-1],
    "/live/song/create_return_track": [-1],
    "/live/song/create_scene": [-1],
}

def get_parameters(osc_endpoint: str) -> list[Any]:
    """Return the default OSC arguments for ``osc_endpoint``.

    Args:
        osc_endpoint: OSC address to look up.

    Returns:
        A new list holding the endpoint's default arguments, or an empty list
        when the endpoint has none. A copy is returned so that callers can
        modify the result without altering ``DEFAULT_PARAMETERS``.
    """
    return list(DEFAULT_PARAMETERS.get(osc_endpoint, ()))

## Full pipeline

In [8]:
def run(recording: np.ndarray) -> None:
    """Run the full voice-command pipeline on one recording.

    The recording is transcribed, the text is classified into an OSC
    endpoint, and the endpoint is sent to Ableton Live with its default
    arguments. Processing stops early, with a log line, when the recording is
    empty, when the classifier score is below ``CONFIDENCE_THRESHOLD``, or
    when the predicted label is ``"none"``.

    Args:
        recording: Recorded audio samples.
    """
    if recording.size == 0:
        # No samples were captured; skip the models instead of feeding them
        # an empty array.
        print("[Transcriber] Empty recording")
        return
    command = transcribe(recording)
    classification = classify_text(command)
    if classification["score"] < CONFIDENCE_THRESHOLD:
        print("[Classifier] Low confidence")
        return
    osc_endpoint = classification["label"]
    if osc_endpoint == "none":
        print("[Live] No matching endpoint")
        return
    parameters = get_parameters(osc_endpoint)
    live_result = livequery.cmd(osc_endpoint, *parameters)
    if live_result is not None:
        print(f"[Live]: {live_result}")

## User interface

In [9]:
# Module-level buffer shared with on_click, which rebinds it through `global`.
recording = init_recording()

# Toggle button: starts in the "Record" state; on_click switches its label,
# tooltip, style and icon between the "Record" and "Done" states.
button = widgets.Button(
    description="Record",
    disabled=False,
    button_style="danger",
    tooltip="Record",
    icon="microphone",
)

def on_click(b: widgets.Button) -> None:
    """Toggle between recording and processing when the button is clicked.

    The button's description doubles as the state flag: in the "Record" state
    a click starts a new recording and switches the button to "Done"; in the
    "Done" state a click stops the recording, restores the "Record"
    appearance and runs the pipeline on the captured audio.

    Args:
        b: The button that was clicked.
    """
    global recording
    if b.description == "Record":
        b.description = "Done"
        b.tooltip = "Done"
        b.button_style = "warning"
        b.icon = "microphone-slash"
        recording = start_recording()
    else:
        b.description = "Record"
        b.tooltip = "Record"
        b.button_style = "danger"
        b.icon = "microphone"
        b.disabled = True
        try:
            recording = stop_recording(recording)
        finally:
            # Re-enable the button even when stopping fails; otherwise an
            # exception would leave the UI permanently disabled.
            b.disabled = False
        run(recording)

# Register the click handler, then leave the widget as the cell's last
# expression so that the notebook displays it.
button.on_click(on_click)
button

Button(button_style='danger', description='Record', icon='microphone', style=ButtonStyle(), tooltip='Record')

[Transcriber]: New MIDI track.
[Classifier]: {'label': '/live/song/create_midi_track', 'score': 0.7960658669471741}
[Transcriber]: Play scene.
[Classifier]: {'label': '/live/view/start_listen/selected_scene', 'score': 0.8248885869979858}
[Transcriber]: Play scene.
[Classifier]: {'label': '/live/view/start_listen/selected_scene', 'score': 0.8248885869979858}
[Transcriber]: Play scene.
[Classifier]: {'label': '/live/view/start_listen/selected_scene', 'score': 0.8248885869979858}
[Transcriber]: Stop playing.
[Classifier]: {'label': '/live/song/stop_playing', 'score': 0.4330361783504486}
[Classifier] Low confidence
[Transcriber]: Start playing.
[Classifier]: {'label': '/live/song/continue_playing', 'score': 0.46374908089637756}
[Classifier] Low confidence
[Transcriber]: From the top.
[Classifier]: {'label': '/live/song/start_playing', 'score': 0.9426731467247009}
[Transcriber]: Hose?
[Classifier]: {'label': 'none', 'score': 0.9755275249481201}
[Live] No matching endpoint
[Transcriber]: Sti

In [10]:
# result = livequery.cmd("/live/song/create_return_track")
# print(result)