<a href="https://colab.research.google.com/github/ssrinivas-berkeley/genai-playground/blob/main/DeskDrummer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# 🥁 DeskDrummer AI — JAX Edition (Colab)

Turn taps on your desk (via upload or mic recording) into a quantized drum+bass loop.  
This version uses **JAX** for synthesis so audio generation runs fast on GPU/TPU and can be extended to differentiable music models.


## 🛠️ Setup

In [9]:

!pip -q install --upgrade librosa soundfile jax jaxlib
!pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html


import jax
import jax.numpy as jnp
import numpy as np
import librosa, soundfile as sf, os, io, base64
from IPython.display import Audio, display
from google.colab import files, output
# Sample rate in Hz used when loading audio and when synthesizing drums.
SR = 22050

# Rendered audio is written under ./results, so make sure it exists.
os.makedirs("results", exist_ok=True)
print("✅ Installed. SR =", SR)


Looking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
✅ Installed. SR = 22050


## 📁 Upload an audio file (or record)

In [2]:

# Default to "nothing selected" before prompting for an upload.
audio_path = None
uploaded = files.upload()

# Only the first uploaded entry is used; stays None when nothing was uploaded.
audio_path = next(iter(uploaded), None)

if not audio_path:
    print("ℹ️ No file uploaded. You can also record with mic in another cell.")
else:
    print("📁 Using uploaded file:", audio_path)


Saving file_example_MP3_700KB.mp3 to file_example_MP3_700KB.mp3
📁 Using uploaded file: file_example_MP3_700KB.mp3


## 🧠 Beat tracking

In [10]:

if not audio_path:
    raise SystemExit("No audio provided")

# Load as mono at the notebook-wide sample rate, then normalize the signal.
y, sr = librosa.load(audio_path, sr=SR, mono=True)
y = librosa.util.normalize(y)

# Track beats from the onset-strength envelope; `beats` holds beat times in
# seconds because units='time'.
onset_env = librosa.onset.onset_strength(y=y, sr=sr)
tempo, beats = librosa.beat.beat_track(onset_envelope=onset_env, sr=sr, units='time')

# beat_track can hand back the tempo as an array rather than a plain float,
# so reduce it to a scalar before formatting or dividing.
tempo_val = float(np.atleast_1d(tempo)[0])

# A non-positive tempo means no usable beat was found. Fail here with a clear
# message instead of deriving an absurdly long beat duration from it.
if tempo_val <= 0:
    raise ValueError(
        "Could not estimate a tempo from the audio; "
        "try a recording with clearer, louder taps."
    )

print(f"Estimated tempo: {tempo_val:.1f} BPM")
beat_sec = 60.0 / tempo_val  # seconds per beat

Estimated tempo: 95.7 BPM


## 🎛️ JAX drum synthesizers

In [11]:

def envelope(n, a=0.005, d=0.2):
    """Exponential-decay amplitude envelope with a linear attack ramp.

    Args:
        n: Number of samples to generate.
        a: Attack time in seconds; converted to samples with the global SR
            and never shorter than one sample.
        d: Decay time constant in seconds (the body of the envelope is
            exp(-t/d)).

    Returns:
        1-D jnp array of length ``n``.
    """
    t = jnp.linspace(0, n/SR, n, endpoint=False)
    e = jnp.exp(-t/d)
    # Use plain Python ints for the ramp length: it is used as an array shape
    # and a slice bound, which must be static values rather than JAX arrays.
    ramp_len = max(1, int(a*SR))
    # When the buffer is shorter than the attack, keep only the part of the
    # ramp that fits instead of failing on a shape mismatch.
    a_n = min(ramp_len, n)
    attack = jnp.linspace(0, 1, ramp_len)[:a_n]
    e = e.at[:a_n].set(e[:a_n] * attack)
    return e

def kick(seed=0, length=0.25):
    """Synthesize a kick drum as a sine with an exponentially decaying pitch term.

    Args:
        seed: Unused by this function.
        length: Duration in seconds.

    Returns:
        1-D jnp array of int(length*SR) samples, scaled by 0.6.
    """
    num_samples = int(length*SR)
    times = jnp.linspace(0, length, num_samples, endpoint=False)
    # Pitch term starts at 100 and decays exponentially over time.
    # NOTE(review): the phase is pitch*t rather than the integral of the
    # pitch curve; confirm this sweep is the intended sound.
    pitch = 100*jnp.exp(-times*10)
    tone = jnp.sin(2*jnp.pi*pitch*times)
    amp = envelope(num_samples, a=0.002, d=0.15)
    return 0.6 * tone * amp

def snare(seed=1, length=0.15):
    """Synthesize a snare hit: Gaussian noise shaped by a fast-decaying envelope.

    Args:
        seed: Integer used to build the PRNG key for the noise burst.
        length: Duration in seconds.

    Returns:
        1-D jnp array of int(length*SR) samples, scaled by 0.3.
    """
    num_samples = int(length*SR)
    key = jax.random.PRNGKey(seed)
    burst = jax.random.normal(key, (num_samples,))
    shape = envelope(num_samples, a=0.001, d=0.12)
    return 0.3 * burst * shape

def hat(seed=2, length=0.05):
    """Synthesize a hi-hat tick: a very short, quiet burst of Gaussian noise.

    Args:
        seed: Integer used to build the PRNG key for the noise.
        length: Duration in seconds.

    Returns:
        1-D jnp array of int(length*SR) samples, scaled by 0.15.
    """
    count = int(length*SR)
    rng_key = jax.random.PRNGKey(seed)
    hiss = jax.random.normal(rng_key, (count,))
    return 0.15 * hiss * envelope(count, a=0.001, d=0.05)

def bass(seed=3, length=0.5, freq=55.0):
    """Synthesize a bass note: a fixed-frequency sine under a decaying envelope.

    Args:
        seed: Unused by this function.
        length: Duration in seconds.
        freq: Sine frequency in Hz.

    Returns:
        1-D jnp array of int(length*SR) samples, scaled by 0.25.
    """
    num_samples = int(length*SR)
    times = jnp.linspace(0, length, num_samples, endpoint=False)
    phase = 2*jnp.pi*freq*times
    amp = envelope(num_samples, a=0.005, d=0.25)
    return 0.25 * jnp.sin(phase) * amp


## 🎼 Sequence pattern & mix

In [12]:

def _mix_in(buf, start, seg):
    """Return `buf` with `seg` added from sample `start`, trimming any overrun."""
    end = min(start + seg.shape[0], buf.shape[0])
    if end <= start:
        return buf  # segment would start at or past the end of the buffer
    return buf.at[start:end].add(seg[:end - start])


if len(beats) == 0:
    raise ValueError("No beats were detected, so there is nothing to sequence.")

# Pattern length: the beats spanned by the tracked audio plus four extra
# beats; the buffer gets two more beats so the last hits can ring out.
length_beats = int(beats[-1] / beat_sec) + 4
dur = int((length_beats+2)*beat_sec*SR)
mix = jnp.zeros(dur)

half_beat = int(0.5*beat_sec*SR)  # offset of the off-beat hi-hat, in samples

for i in range(length_beats):
    t0 = int(i*beat_sec*SR)
    # Kick on downbeats (first beat of every group of four)
    if i % 4 == 0:
        mix = _mix_in(mix, t0, kick(i))
    # Snare on 2 & 4, i.e. the odd zero-based beat indices
    if i % 2 == 1:
        mix = _mix_in(mix, t0, snare(i))
    # Hi-hats on every beat and on the half-beat in between
    mix = _mix_in(mix, t0, hat(i))
    mix = _mix_in(mix, t0 + half_beat, hat(i+999))
    # Bass every other beat
    if i % 2 == 0:
        mix = _mix_in(mix, t0, bass(i))


## 💾 Normalize & export

In [13]:

# Peak-normalize; the small epsilon keeps the division safe for a silent mix.
peak = jnp.max(jnp.abs(mix))
mix = mix / (peak + 1e-6)

# Write the loop to disk, then play back the saved file inline.
wav_path = "results/deskdrummer_jax.wav"
sf.write(wav_path, np.array(mix), SR)

display(Audio(wav_path, rate=SR))
print("✅ Saved results/deskdrummer_jax.wav")


✅ Saved results/deskdrummer_jax.wav
