In [1]:
import librosa
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from torch.utils.data import DataLoader, Dataset

from pydub import AudioSegment
import soundfile as sf

import matplotlib.pyplot as plt
import librosa.display
import random
import torchaudio

import matplotlib.pyplot as plt
import torchaudio
import torchaudio.transforms as T
import os

In [2]:
def mu_law_encode(x, mu=255):
    mu = float(mu)
    mu_tensor = torch.tensor(mu, dtype=x.dtype, device=x.device)
    x_mu = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / torch.log1p(mu_tensor)
    return ((x_mu + 1) / 2 * mu + 0.5).clamp(0, mu).long()  # clamp to avoid bad values

def mu_law_decode(x, mu=255):
    mu = float(mu)
    x = 2 * (x.float() / mu) - 1
    return torch.sign(x) * (1 / mu) * ((1 + mu) ** torch.abs(x) - 1)


In [3]:
class RawWaveformDataset(Dataset):
    def __init__(self, wav_dir, sample_rate=16000, duration=4.0, preload=True, mu_law=False):
        self.wav_paths = [os.path.join(wav_dir, f) for f in os.listdir(wav_dir) if f.endswith(".wav")]
        self.sample_rate = sample_rate
        self.duration = duration
        self.num_samples = int(duration * sample_rate)
        self.preload = preload
        self.mu_law = mu_law

        if preload:
            print(f"🔁 Preloading {len(self.wav_paths)} audio files...")
            self.cache = [self.process_file(p) for p in self.wav_paths]
        else:
            self.cache = None

    def __len__(self):
        return len(self.wav_paths)

    def __getitem__(self, idx):
        return self.cache[idx] if self.preload else self.process_file(self.wav_paths[idx])

    def process_file(self, path):
        wav, sr = torchaudio.load(path)
        wav = wav.mean(dim=0)  # mono
        if sr != self.sample_rate:
            wav = torchaudio.functional.resample(wav, orig_freq=sr, new_freq=self.sample_rate)

        # Pad or crop to fixed length
        if wav.size(0) < self.num_samples:
            pad_len = self.num_samples - wav.size(0)
            wav = F.pad(wav, (0, pad_len))
        elif wav.size(0) > self.num_samples:
            start = random.randint(0, wav.size(0) - self.num_samples)
            wav = wav[start:start + self.num_samples]

        # Normalize
        wav = wav / wav.abs().max()

        # Return either raw waveform or mu-law encoded
        if self.mu_law:
            return mu_law_encode(wav.unsqueeze(0)).squeeze(0).long()  # shape: (T,)
        else:
            return wav.unsqueeze(0)  # shape: (1, T)


In [4]:
class ResidualBlock(nn.Module):
    def __init__(self, residual_channels, skip_channels, dilation):
        super().__init__()
        self.dilation = dilation
        self.filter_conv = nn.Conv1d(residual_channels, residual_channels, kernel_size=2, dilation=dilation, padding=0)
        self.gate_conv = nn.Conv1d(residual_channels, residual_channels, kernel_size=2, dilation=dilation, padding=0)
        self.res_conv = nn.Conv1d(residual_channels, residual_channels, kernel_size=1)
        self.skip_conv = nn.Conv1d(residual_channels, skip_channels, kernel_size=1)

    def forward(self, x):
        # Manual causal padding
        pad = self.dilation
        x_padded = F.pad(x, (pad, 0))  # pad left only

        filter_out = torch.tanh(self.filter_conv(x_padded))
        gate_out = torch.sigmoid(self.gate_conv(x_padded))
        out = filter_out * gate_out

        res = self.res_conv(out)
        skip = self.skip_conv(out)
        return x + res, skip
    

class WaveNet(nn.Module):
    def __init__(self, in_channels=256, residual_channels=64, skip_channels=128, 
                 dilation_cycles=2, layers_per_cycle=10):
        super().__init__()
        self.input_conv = nn.Conv1d(in_channels, residual_channels, kernel_size=1)
        self.res_blocks = nn.ModuleList()

        for cycle in range(dilation_cycles):
            for layer in range(layers_per_cycle):
                dilation = 2 ** layer
                self.res_blocks.append(ResidualBlock(residual_channels, skip_channels, dilation))

        self.output = nn.Sequential(
            nn.ReLU(),
            nn.Conv1d(skip_channels, skip_channels, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(skip_channels, in_channels, kernel_size=1)
        )

    def forward(self, x):  # x: one-hot encoded waveform: (B, 256, T)
        x = self.input_conv(x)
        skip_total = 0
        for block in self.res_blocks:
            x, skip = block(x)
            skip_total = skip_total + skip if isinstance(skip_total, torch.Tensor) else skip
        return self.output(skip_total)



In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🖥️ Using device: {device}")

# Initialize dataset
dataset = RawWaveformDataset(
    wav_dir="data", 
    sample_rate=16000,
    duration=4.0,
    preload=True,
    mu_law=True
)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4, pin_memory=False)

🖥️ Using device: cuda
🔁 Preloading 1000 audio files...


In [None]:
def visualize_reconstruction_live(model, waveform, step, sample_rate=16000):
    model.eval()
    with torch.no_grad():
        waveform = waveform.unsqueeze(0).to(model.input_conv.weight.device)
        x = mu_law_encode(waveform).long()
        x_input = F.one_hot(x[:, :-1], 256).permute(0, 2, 1).float()

        y_pred = model(x_input)
        y_class = torch.argmax(y_pred, dim=1)
        decoded = mu_law_decode(y_class.squeeze(0).cpu())

    # Decode original for fair comparison
    decoded_original = mu_law_decode(mu_law_encode(waveform.cpu().squeeze(0)))

    # Plot waveforms
    plt.figure(figsize=(12, 4))
    plt.plot(decoded_original[:1000], label="Original", alpha=0.7)
    plt.plot(decoded[:1000], label="Decoded", alpha=0.7)
    plt.title(f"Waveform Comparison @ Step {step}")
    plt.legend()
    plt.tight_layout()
    plt.show()

    # Plot spectrograms
    spec_transform = T.MelSpectrogram(sample_rate=sample_rate, n_fft=1024, hop_length=256, n_mels=80)
    original_spec = spec_transform(decoded_original.unsqueeze(0))[0]
    decoded_spec = spec_transform(decoded.unsqueeze(0))[0]

    plt.figure(figsize=(10, 4))
    plt.subplot(1, 2, 1)
    plt.imshow(original_spec.log2().numpy(), origin="lower", aspect="auto")
    plt.title("Original Mel")

    plt.subplot(1, 2, 2)
    plt.imshow(decoded_spec.log2().numpy(), origin="lower", aspect="auto")
    plt.title("Decoded Mel")

    plt.suptitle(f"Mel Spectrogram @ Step {step}")
    plt.tight_layout()
    plt.show()


In [None]:
model = WaveNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

num_epochs = 10
step = 0

for epoch in range(num_epochs):
    for waveform in dataloader:  # waveform: (B, T) in [-1, 1]
        waveform = waveform.to(device)
        x = mu_law_encode(waveform).long()  # (B, T)

        x_input = F.one_hot(x[:, :-1], 256).permute(0, 2, 1).float().to(device)  # (B, 256, T-1)
        x_target = x[:, 1:].to(device)  # (B, T-1)

        y_pred = model(x_input)  # (B, 256, T-1)
        loss = F.cross_entropy(y_pred, x_target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        torch.cuda.empty_cache()

        # ✅ Track training
        if step % 10 == 0:
            print(f"[Epoch {epoch} | Step {step}] Loss: {loss.item():.4f} | CUDA: {torch.cuda.memory_allocated() / 1e6:.1f} MB")

        if step % 200 == 0:
            sample_wave = dataset[0]  # Or random
            visualize_reconstruction_live(model, sample_wave, step)

        step += 1


[Epoch 0 | Step 0] Loss: 5.6779 | CUDA: 676.5 MB
[Epoch 0 | Step 10] Loss: 5.4018 | CUDA: 676.5 MB
[Epoch 0 | Step 20] Loss: 5.0894 | CUDA: 676.5 MB
[Epoch 0 | Step 30] Loss: 4.4664 | CUDA: 676.5 MB
[Epoch 0 | Step 40] Loss: 3.0480 | CUDA: 676.5 MB
[Epoch 0 | Step 50] Loss: 0.8183 | CUDA: 676.5 MB
[Epoch 0 | Step 60] Loss: 0.0471 | CUDA: 676.5 MB
[Epoch 0 | Step 70] Loss: 0.0064 | CUDA: 676.5 MB
[Epoch 0 | Step 80] Loss: 0.0024 | CUDA: 676.5 MB
[Epoch 0 | Step 90] Loss: 0.0017 | CUDA: 676.5 MB
[Epoch 0 | Step 100] Loss: 0.0013 | CUDA: 676.5 MB
[Epoch 0 | Step 110] Loss: 0.0016 | CUDA: 676.5 MB
[Epoch 0 | Step 120] Loss: 0.0017 | CUDA: 676.5 MB
[Epoch 0 | Step 130] Loss: 0.0009 | CUDA: 676.5 MB
[Epoch 0 | Step 140] Loss: 0.0009 | CUDA: 676.5 MB
[Epoch 0 | Step 150] Loss: 0.0046 | CUDA: 676.5 MB
[Epoch 0 | Step 160] Loss: 0.0012 | CUDA: 676.5 MB
[Epoch 0 | Step 170] Loss: 0.0006 | CUDA: 676.5 MB
[Epoch 0 | Step 180] Loss: 0.0008 | CUDA: 676.5 MB
[Epoch 0 | Step 190] Loss: 0.0011 | CUDA: 

KeyboardInterrupt: 

In [None]:
def generate(model, seed, steps=16000):
    model.eval()
    output = []

    x = mu_law_encode(seed).long()  # (T,)
    x = F.one_hot(x, 256).float().unsqueeze(0).permute(0, 2, 1)  # (1, 256, T)

    with torch.no_grad():
        for _ in range(steps):
            y = model(x[:, :, -model.receptive_field:])  # Only last context
            probs = F.softmax(y[:, :, -1], dim=-1)       # last timestep
            sample = torch.multinomial(probs, num_samples=1).squeeze(-1)
            sample_one_hot = F.one_hot(sample, 256).float().unsqueeze(2)
            x = torch.cat([x, sample_one_hot], dim=2)
            output.append(sample.item())

    return mu_law_decode(torch.tensor(output))
