# In [None]:
# --- Imports ---
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import librosa
import librosa.display
import matplotlib.pyplot as plt
import os, glob, random
from IPython.display import Audio, display

# --- Constants ---
# Audio / STFT settings
TARGET_SR = 44_100     # sample rate (Hz) passed to librosa.load
N_FFT = 1_024          # FFT size used for the STFT and for Griffin-Lim
HOP_LENGTH = 256       # hop between STFT frames, in samples
FIXED_FRAMES = 512     # every spectrogram is padded/truncated to this many frames

# Model
LATENT_DIM = 64        # size of the VAE latent vector

# Training schedule
TOTAL_EPOCHS = 30_000
CHUNK_SIZE = 500       # epochs trained between two previews
NUM_CHUNKS = TOTAL_EPOCHS // CHUNK_SIZE
KL_TARGET = 0.1        # final weight of the KL term in the loss
KL_WARMUP = 2_000      # epochs over which the KL weight ramps up linearly

# --- Dataset ---
# Collect every .wav file sitting directly inside the bass-drum sample folder.
folder_path = '/content/drive/MyDrive/Neural Drum Machine/Samples/01. Bass Drum'
wav_pattern = os.path.join(folder_path, '*.wav')
files_list = glob.glob(wav_pattern)
print(f"Found {len(files_list)} bass drum samples.")

def wav_to_spec(filename, sr=TARGET_SR, n_fft=N_FFT, hop_length=HOP_LENGTH, target_frames=FIXED_FRAMES):
    """Load an audio file and return a normalised log-magnitude spectrogram.

    The audio is loaded at sample rate ``sr``, trimmed of leading/trailing
    silence (``top_db=30``) and padded/truncated to exactly ``sr`` samples,
    i.e. one second. Its STFT magnitude is compressed with ``log1p`` and then
    divided by its own maximum. The frame axis is finally zero-padded or
    truncated to ``target_frames``.

    Args:
        filename: Path of the audio file to load.
        sr: Sample rate the audio is loaded at.
        n_fft: FFT size of the STFT.
        hop_length: Hop between STFT frames, in samples.
        target_frames: Number of frames in the returned spectrogram.

    Returns:
        2-D array of shape ``(1 + n_fft // 2, target_frames)``; with the
        default settings that is (513, 512). The per-file peak used for the
        normalisation is not returned.
    """
    y, _ = librosa.load(filename, sr=sr)
    y, _ = librosa.effects.trim(y, top_db=30)
    y = librosa.util.fix_length(y, size=sr)  # exactly one second of audio

    log_mag = np.log1p(np.abs(librosa.stft(y, n_fft=n_fft, hop_length=hop_length)))

    # Normalise to a peak of 1. A completely silent clip has a peak of 0;
    # dividing would turn the whole spectrogram into NaN, so leave it as zeros.
    peak = log_mag.max()
    if peak > 0:
        log_mag = log_mag / peak

    # Zero-pad or truncate along the frame axis to a fixed length.
    return librosa.util.fix_length(log_mag, size=target_frames, axis=1)

# Fail early with a readable message: np.stack on an empty list only reports
# "need at least one array to stack", which hides the real cause (wrong
# folder, Drive not mounted, ...).
if not files_list:
    raise ValueError(f"No .wav files found in {folder_path!r}; cannot build the training set.")

# One spectrogram per file, stacked along a new leading axis.
SAMPLES = np.stack([wav_to_spec(f) for f in files_list])
print(f"SAMPLES shape: {SAMPLES.shape}")  # (N, 513, 512)

class DrumDataset(torch.utils.data.Dataset):
    """Map-style dataset that serves pre-computed spectrograms as float32 tensors."""

    def __init__(self, specs):
        # Convert once up front so indexing only slices an existing tensor.
        self.specs = torch.tensor(specs, dtype=torch.float32)

    def __len__(self):
        """Return the number of spectrograms (size of the leading axis)."""
        return len(self.specs)

    def __getitem__(self, idx):
        """Return the spectrogram stored at position ``idx``."""
        return self.specs[idx]

# Shuffled mini-batches of 16 spectrograms drawn from SAMPLES.
loader = torch.utils.data.DataLoader(
    dataset=DrumDataset(SAMPLES),
    batch_size=16,
    shuffle=True,
)

# --- VAE ---
class Encoder(nn.Module):
    """Convolutional encoder producing the parameters of the latent Gaussian.

    The 513 frequency bins are treated as Conv1d channels and the time frames
    as the sequence axis.
    """

    def __init__(self, latent_dim=LATENT_DIM):
        super().__init__()
        # Two stride-2 convolutions halve the time axis twice; the pooling
        # layer then averages what is left down to a single step.
        layers = [
            nn.Conv1d(in_channels=513, out_channels=256, kernel_size=4, stride=2, padding=1),  # (B, 256, 256)
            nn.ReLU(),
            nn.Conv1d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1),  # (B, 128, 128)
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(output_size=1),                                               # (B, 128, 1)
        ]
        self.conv = nn.Sequential(*layers)
        # Separate linear heads for the mean and the log-variance.
        self.mu = nn.Linear(in_features=128, out_features=latent_dim)
        self.logvar = nn.Linear(in_features=128, out_features=latent_dim)

    def forward(self, x):
        """Encode ``x`` (B, 513, 512) into ``(mu, logvar)``, each (B, latent_dim)."""
        features = self.conv(x)          # (B, 128, 1)
        pooled = features.squeeze(-1)    # (B, 128)
        mean = self.mu(pooled)
        log_variance = self.logvar(pooled)
        return mean, log_variance

class Decoder(nn.Module):
    """Decode a latent vector into a (513, 512) spectrogram."""

    def __init__(self, latent_dim=LATENT_DIM):
        super().__init__()
        # Fully connected stack that expands the latent vector into a flat
        # 512 * 128 seed, reshaped to (B, 512, 128) in forward().
        dense_layers = [
            nn.Linear(in_features=latent_dim, out_features=512),
            nn.ReLU(),
            nn.Linear(in_features=512, out_features=1024),
            nn.ReLU(),
            nn.Linear(in_features=1024, out_features=512 * 128),
            nn.ReLU(),
        ]
        self.fc = nn.Sequential(*dense_layers)

        # Two stride-2 transposed convolutions double the time axis twice
        # (128 -> 256 -> 512); the 1x1 convolution maps channels to 513 bins.
        upsample_layers = [
            nn.ConvTranspose1d(in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1),  # (B, 256, 256)
            nn.ReLU(),
            nn.ConvTranspose1d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1),  # (B, 128, 512)
            nn.ReLU(),
            nn.Conv1d(in_channels=128, out_channels=513, kernel_size=1),                                # (B, 513, 512)
            nn.ReLU(),  # keeps the output non-negative
        ]
        self.deconv = nn.Sequential(*upsample_layers)

    def forward(self, z):
        """Decode ``z`` of shape (B, latent_dim) into a (B, 513, 512) tensor."""
        seed = self.fc(z)
        seed = seed.view(-1, 512, 128)
        return self.deconv(seed)

class VAE(nn.Module):
    """Variational autoencoder assembled from ``Encoder`` and ``Decoder``."""

    def __init__(self, latent_dim=LATENT_DIM):
        super().__init__()
        self.latent_dim = latent_dim
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Returns:
            ``(recon, mu, logvar)``: the decoded sample plus the encoder
            outputs. A latent is sampled on every call, regardless of
            train/eval mode.
        """
        mu, logvar = self.encoder(x)
        # Reparameterisation trick: z = mu + sigma * eps with eps ~ N(0, I).
        sigma = (0.5 * logvar).exp()
        eps = torch.randn_like(sigma)
        z = mu + sigma * eps
        return self.decoder(z), mu, logvar

# Build the model, move it to the GPU, then hand its parameters to Adam.
vae = VAE()
vae = vae.cuda()
opt = optim.Adam(params=vae.parameters(), lr=1e-3)

def kl_loss(mu, logvar):
    """Return KL(N(mu, exp(logvar)) || N(0, I)) summed over every element.

    The sum runs over all dimensions of the inputs, so the result is a
    scalar tensor.
    """
    per_element = 1 + logvar - mu.pow(2) - logvar.exp()
    return -0.5 * per_element.sum()

def spec_to_audio(log_mag, sr=TARGET_SR, n_fft=N_FFT, hop_length=HOP_LENGTH):
    """Turn a normalised log-magnitude spectrogram into a waveform via Griffin-Lim.

    The input is multiplied by its own maximum, passed through ``expm1``
    (the inverse of ``log1p``) and handed to ``librosa.griffinlim``.

    NOTE(review): ``wav_to_spec`` divides by the original peak and does not
    keep it, so the rescale here cannot restore the original level.
    ``sr`` is accepted but not used.
    """
    rescaled = log_mag * log_mag.max()
    linear_mag = np.expm1(rescaled)  # undo log1p
    waveform = librosa.griffinlim(linear_mag, hop_length=hop_length, n_fft=n_fft)
    return waveform

# --- Training ---
# Trains in chunks of CHUNK_SIZE epochs; after each chunk a spectrogram plot
# and three audio previews (original, reconstruction, random sample) are shown.

# Follow whatever device the model already lives on instead of hard-coding
# .cuda() at every transfer.
device = next(vae.parameters()).device

start_epoch = 0
for chunk in range(NUM_CHUNKS):
    print(f"\n--- Training chunk {chunk+1}/{NUM_CHUNKS} (Epochs {start_epoch+1} to {start_epoch+CHUNK_SIZE}) ---")

    for epoch in range(start_epoch, start_epoch + CHUNK_SIZE):
        vae.train()
        total = 0
        # Linear KL warm-up: 0 at epoch 0, reaching KL_TARGET after KL_WARMUP
        # epochs and staying there.
        kl_weight = min(KL_TARGET, KL_TARGET * epoch / KL_WARMUP)

        for batch in loader:  # (B, 513, 512)
            batch = batch.to(device)
            opt.zero_grad()
            recon, mu, logvar = vae(batch)
            # Both terms are summed (not averaged) over the batch.
            recon_loss = F.l1_loss(recon, batch, reduction='sum')
            kl = kl_loss(mu, logvar)
            loss = recon_loss + kl_weight * kl
            loss.backward()
            opt.step()
            total += loss.item()

        if (epoch + 1) % 50 == 0:
            # Summed loss divided by the number of samples.
            print(f"Epoch {epoch+1} | Loss: {total / len(SAMPLES):.4f} | KL Weight: {kl_weight:.4f}")

    # --- Preview ---
    vae.eval()
    with torch.no_grad():
        # Reconstruct one random item from a freshly drawn batch.
        batch = next(iter(loader)).to(device)
        idx = random.randint(0, batch.size(0) - 1)
        real = batch[idx:idx+1]  # slice keeps the batch axis: (1, 513, 512)
        recon, _, _ = vae(real)
        real_np = real.squeeze().cpu().numpy()
        recon_np = recon.squeeze().cpu().numpy()

        # Decode a latent vector drawn from the standard normal prior.
        z = torch.randn(1, LATENT_DIM).to(device)
        fake = vae.decoder(z).cpu().squeeze().numpy()

    plt.figure(figsize=(15, 4))
    plt.subplot(1, 3, 1)
    librosa.display.specshow(real_np, sr=TARGET_SR, hop_length=HOP_LENGTH, x_axis='time', y_axis='linear')
    plt.title("Original")
    plt.subplot(1, 3, 2)
    librosa.display.specshow(recon_np, sr=TARGET_SR, hop_length=HOP_LENGTH, x_axis='time', y_axis='linear')
    plt.title("Reconstructed")
    plt.subplot(1, 3, 3)
    librosa.display.specshow(fake, sr=TARGET_SR, hop_length=HOP_LENGTH, x_axis='time', y_axis='linear')
    plt.title("Random Sample")
    plt.suptitle(f"Chunk {chunk+1}/{NUM_CHUNKS} — Epoch {start_epoch+CHUNK_SIZE}")
    plt.tight_layout()
    plt.show()

    print("Original Kick")
    display(Audio(spec_to_audio(real_np), rate=TARGET_SR))
    print("Reconstructed Kick")
    display(Audio(spec_to_audio(recon_np), rate=TARGET_SR))
    print("Randomly Generated Kick")
    display(Audio(spec_to_audio(fake), rate=TARGET_SR))

    start_epoch += CHUNK_SIZE
