<a href="https://colab.research.google.com/github/ssrinivas-berkeley/genai-playground/blob/main/DeskDrummer_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 🥁 DeskDrummer AI — GenAI + JAX Edition (Colab)

Upload a recording of your desk taps (record it with any voice-memo app first), extract its tempo, and then:
- Generate a **style-aware backing track** with [MusicGen](https://huggingface.co/facebook/musicgen-small) (Generative AI).
- Overlay **procedural JAX drums** synced to the tempo of your taps.
- Export a polished loop.

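This combines **Generative AI** (music model) with **JAX DSP synthesis**. Note that MusicGen is conditioned only on the text prompt, so its tempo is not locked to your taps; only the JAX drum overlay follows the detected tempo.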

In [1]:
!pip -q install --upgrade librosa soundfile jax jaxlib torch accelerate transformers

import jax, jax.numpy as jnp
import numpy as np
import librosa, soundfile as sf, os
from IPython.display import Audio, display
from google.colab import files

os.makedirs("results", exist_ok=True)
SR = 22050
print("✅ Installed. SR =", SR)

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.0/42.0 kB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.8/2.8 MB[0m [31m32.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m81.2/81.2 MB[0m [31m30.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m374.9/374.9 kB[0m [31m21.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.3/11.3 MB[0m [31m135.3 MB/s[0m eta [36m0:00:00[0m
[?25h



✅ Installed. SR = 22050


In [2]:
# 📁 Upload a tap recording
# Only the first uploaded file is used; audio_path stays None if nothing was uploaded.
audio_path = None
uploaded = files.upload()
audio_path = next(iter(uploaded), None)

if not audio_path:
    print("ℹ️ No file uploaded.")
else:
    print("📁 Using uploaded file:", audio_path)

Saving file_example_MP3_700KB.mp3 to file_example_MP3_700KB.mp3
📁 Using uploaded file: file_example_MP3_700KB.mp3


In [3]:
# 🧠 Beat tracking
# Estimate a global tempo (BPM) and beat times (in seconds, via units='time') from the taps.
if not audio_path:
    # A regular exception, not SystemExit: SystemExit signals interpreter shutdown,
    # not a failed precondition in a notebook cell.
    raise ValueError("No audio provided: upload a tap recording in the previous cell.")

y, sr = librosa.load(audio_path, sr=SR, mono=True)
y = librosa.util.normalize(y)

onset_env = librosa.onset.onset_strength(y=y, sr=sr)
tempo, beats = librosa.beat.beat_track(onset_envelope=onset_env, sr=sr, units='time')
# `tempo` may come back as an array depending on the librosa version; coerce to a float.
tempo_val = float(np.atleast_1d(tempo)[0])
if len(beats) == 0:
    # Fail here with a clear message instead of an IndexError on beats[-1] downstream.
    raise ValueError("No beats detected in the recording; try louder or more regular taps.")

print(f"Estimated tempo: {tempo_val:.1f} BPM")
beat_sec = 60.0 / max(tempo_val, 1e-6)  # seconds per beat; epsilon guards a 0 BPM estimate

Estimated tempo: 95.7 BPM


In [4]:
# 🎶 Generative AI Backing Track with MusicGen
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import torch

MUSICGEN_MODEL = "facebook/musicgen-small"
# Generation may sample (per the model's generation config); seeding torch makes reruns repeatable.
MUSICGEN_SEED = 0
# Length budget for the generated clip, in decoder tokens (not seconds).
MUSICGEN_MAX_NEW_TOKENS = 256

# Use a GPU when the runtime has one; the model is large and slow on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

processor = AutoProcessor.from_pretrained(MUSICGEN_MODEL)
model = MusicgenForConditionalGeneration.from_pretrained(MUSICGEN_MODEL).to(device)

prompt = "lofi hip hop beat with jazzy hi-hats and deep kick"
inputs = processor(text=[prompt], padding=True, return_tensors="pt").to(device)

torch.manual_seed(MUSICGEN_SEED)
gen_audio = model.generate(**inputs, max_new_tokens=MUSICGEN_MAX_NEW_TOKENS)
# Take the first batch item / first channel (presumably shape (batch, channels, samples)).
gen_np = gen_audio[0, 0].cpu().numpy()
sr_gen = model.config.audio_encoder.sampling_rate
sf.write("results/musicgen.wav", gen_np, sr_gen)
print("✅ Generated backing track with MusicGen")
display(Audio("results/musicgen.wav", rate=sr_gen))

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


preprocessor_config.json:   0%|          | 0.00/275 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/2.36G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/224 [00:00<?, ?B/s]

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


✅ Generated backing track with MusicGen


In [8]:
# 🥁 JAX Drum Synthesizers
def envelope(n, a=0.005, d=0.2):
    """Linear-attack / exponential-decay amplitude envelope sampled at SR.

    Args:
        n: number of samples.
        a: attack time in seconds (linear ramp from 0 to 1).
        d: decay time constant in seconds (the envelope follows exp(-t/d)).

    Returns:
        A jnp array of length n.
    """
    t = jnp.linspace(0, n / SR, n, endpoint=False)
    env = jnp.exp(-t / d)
    # Plain Python int (usable as a static slice bound), clamped to n so a hit shorter
    # than its attack does not cause a shape mismatch in the update below.
    attack_len = min(max(1, int(a * SR)), n)
    return env.at[:attack_len].multiply(jnp.linspace(0, 1, attack_len))


def kick(seed=0, length=0.25, f0=100.0, decay=10.0):
    """Sine kick whose pitch falls as f(t) = f0 * exp(-decay * t).

    Args:
        seed: accepted but unused (the kick is deterministic); callers pass it
            positionally like the noise voices.
        length: duration in seconds.
        f0: starting frequency in Hz.
        decay: pitch-drop rate in 1/s (must be > 0).

    Returns:
        A jnp array of int(length * SR) samples.
    """
    n = int(length * SR)
    t = jnp.linspace(0, length, n, endpoint=False)
    # The phase must be the integral of the frequency: 2*pi*f0*(1 - exp(-decay*t))/decay.
    # Using f(t)*t instead yields an instantaneous frequency of f0*exp(-decay*t)*(1 - decay*t),
    # which turns negative after t = 1/decay and folds the sweep back on itself.
    phase = 2 * jnp.pi * f0 * (1 - jnp.exp(-decay * t)) / decay
    return 0.6 * jnp.sin(phase) * envelope(n, a=0.002, d=0.15)


def _noise_hit(seed, length, gain, a, d):
    """Seeded Gaussian-noise burst of int(length * SR) samples, shaped by `envelope`."""
    n = int(length * SR)
    noise = jax.random.normal(jax.random.PRNGKey(seed), (n,))
    return gain * noise * envelope(n, a=a, d=d)


def snare(seed=1, length=0.15):
    """Noise snare; `seed` selects the (reproducible) noise realization."""
    return _noise_hit(seed, length, gain=0.3, a=0.001, d=0.12)


def hat(seed=2, length=0.05):
    """Short, quiet noise hi-hat; `seed` selects the (reproducible) noise realization."""
    return _noise_hit(seed, length, gain=0.15, a=0.001, d=0.05)

In [6]:
# 🎼 Sequence pattern & mix
# Lay a fixed 4-beat groove on a uniform grid at the detected tempo (grid starts at t=0):
# kick on beat 1, snare on the backbeat (beats 2 and 4), hi-hat on every beat.
if len(beats) == 0:
    raise ValueError("No beats were detected, so the drum overlay length cannot be derived.")

length_beats = int(beats[-1] / beat_sec) + 4   # span up to the last detected beat, plus 4 beats
dur = int((length_beats + 2) * beat_sec * SR)  # 2 extra beats of tail room for the final hits
overlay = jnp.zeros(dur)

for i in range(length_beats):
    t0 = int(i * beat_sec * SR)
    hits = []
    if i % 4 == 0:
        hits.append(kick(i))
    if i % 2 == 1:
        # Odd beat indices = beats 2 and 4, so the snare does not stack on the kick at beat 1.
        hits.append(snare(i))
    hits.append(hat(i))
    for seg in hits:
        # Trim a hit that would run past the buffer so the update matches the slice length.
        seg = seg[: dur - t0]
        overlay = overlay.at[t0:t0 + seg.shape[0]].add(seg)

In [7]:
# 🎚️ Combine MusicGen + JAX overlay
# The drum overlay was synthesized at SR, while the MusicGen audio uses the model's own
# sampling rate. Resample the overlay onto that rate before mixing; otherwise the drums
# play back at the wrong speed and no longer match the detected tempo.
OVERLAY_GAIN = 0.4  # drum level relative to the MusicGen track, before peak normalization

sr_gen = model.config.audio_encoder.sampling_rate
target_len = len(gen_np)
overlay_np = np.asarray(overlay, dtype=np.float32)
if sr_gen != SR:
    overlay_np = librosa.resample(overlay_np, orig_sr=SR, target_sr=sr_gen)
# Zero-pad or trim so the overlay has exactly as many samples as the backing track.
overlay_np = librosa.util.fix_length(overlay_np, size=target_len)

mix = gen_np + OVERLAY_GAIN * overlay_np
mix = mix / (np.max(np.abs(mix)) + 1e-6)  # peak-normalize; epsilon avoids 0/0 on silence

sf.write("results/deskdrummer_genai.wav", mix, sr_gen)
display(Audio("results/deskdrummer_genai.wav", rate=sr_gen))
print("✅ Saved results/deskdrummer_genai.wav")

✅ Saved results/deskdrummer_genai.wav
