<a href="https://colab.research.google.com/github/russabejr/StringAlongHW/blob/main/DDSP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
import numpy as np
import matplotlib.pyplot as plt
import IPython.display as ipd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from scipy import signal
import librosa
import time
from scipy.io import wavfile

!wget https://digitalmusicprocessing.github.io/HW6_StringAlong/data.wav
!wget https://digitalmusicprocessing.github.io/HW6_StringAlong/marvin.wav
!wget https://digitalmusicprocessing.github.io/HW6_StringAlong/adele.wav

!pip install torchcrepe
import torchcrepe # https://github.com/maxrmorrison/torchcrepe

--2025-05-08 20:15:46--  https://digitalmusicprocessing.github.io/HW6_StringAlong/data.wav
Resolving digitalmusicprocessing.github.io (digitalmusicprocessing.github.io)... 185.199.109.153, 185.199.108.153, 185.199.111.153, ...
Connecting to digitalmusicprocessing.github.io (digitalmusicprocessing.github.io)|185.199.109.153|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 25135448 (24M) [audio/wav]
Saving to: ‘data.wav.1’


2025-05-08 20:15:53 (191 MB/s) - ‘data.wav.1’ saved [25135448/25135448]

--2025-05-08 20:15:53--  https://digitalmusicprocessing.github.io/HW6_StringAlong/marvin.wav
Resolving digitalmusicprocessing.github.io (digitalmusicprocessing.github.io)... 185.199.109.153, 185.199.108.153, 185.199.111.153, ...
Connecting to digitalmusicprocessing.github.io (digitalmusicprocessing.github.io)|185.199.109.153|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 328174 (320K) [audio/wav]
Saving to: ‘marvin.wav.1’


2025-05-08 20:15:53 

# Utility Functions (Given)

General-purpose functions that help with specific parts of the pipeline.

In [None]:
def upsample_time(X, hop_length, mode='nearest'):
    """
    Stretch the time axis (axis 1) of a tensor by an integer factor

    Parameters
    ----------
    X: torch.tensor(M, T, N)
        Input tensor whose time axis is axis 1
    hop_length: int
        Factor by which the time axis is lengthened
    mode: string
        Interpolation mode handed to torch's interpolate.  The default,
        'nearest', repeats each frame hop_length times, which avoids
        artifacts where notes in the violin jump by large intervals

    Returns
    -------
    torch.tensor(M, T*hop_length, N)
        The upsampled tensor
    """
    n_frames = X.shape[1]
    # interpolate() resamples the last axis, so move time there temporarily
    time_last = X.permute(0, 2, 1)
    stretched = nn.functional.interpolate(time_last, size=hop_length*n_frames, mode=mode)
    # Restore the original (batch, time, channel) layout
    return stretched.permute(0, 2, 1)

def fftconvolve(x, h):
    """
    Fast linear convolution of two tensors along their last axis via the FFT.

    The DFT computes a circular convolution, so both inputs are zeropadded
    to twice the longer length before transforming; that leaves enough room
    for the full linear convolution with no wrap-around.  Only the first
    max(N1, N2) samples of the result are kept (the leading part of the
    full convolution, not a centered crop).

    Refer to this module for more background:
    https://ursinus-cs372-s2023.github.io/Modules/Module14/Video4

    Parameters
    ----------
    x: torch.tensor(..., N1)
        First tensor
    h: torch.tensor(..., N2)
        Second tensor; leading axes must broadcast against those of x

    Returns
    -------
    torch.tensor(..., max(N1, N2))
        Leading max(N1, N2) samples of the convolution of x and h across
        their last axis
    """
    N = max(x.shape[-1], h.shape[-1])
    # Bring each input up to length 2N with trailing zeros
    x_padded = nn.functional.pad(x, (0, 2*N - x.shape[-1]))
    h_padded = nn.functional.pad(h, (0, 2*N - h.shape[-1]))
    # Multiplication in the frequency domain is convolution in time
    product = torch.fft.rfft(x_padded)*torch.fft.rfft(h_padded)
    return torch.fft.irfft(product)[..., 0:N]


def plot_stft_comparison(F, L, X, Y, reverb, losses=torch.tensor([]), win=1024, sr=16000):
    """
    Compare the STFTs of ground truth and output audio for the first clip in
    a batch, while also plotting the frequency, loudness, and reverb to get an
    idea of what the inputs to the network were that gave rise to these
    outputs.  It's very helpful to call this method while monitoring the
    training of the network.

    All panels are drawn with plt.subplot on the current matplotlib figure,
    so create/size the figure before calling.

    Parameters
    ----------
    F: torch.tensor(n_batches, n_samples/hop_length, 1)
         Tensor holding the pitch estimates for the clips
    L: torch.tensor(n_batches, n_samples/hop_length, 1)
         Tensor holding the loudness estimates for the clips
    X: torch.tensor(n_batches, n_samples, 1)
        Ground truth audio
    Y: torch.tensor(n_batches, n_samples, 1)
        Output audio from the network->decoder
    reverb: torch.tensor(reverb_len)
        The learned reverb
    losses: torch.tensor
        Tensor of losses over epochs (flattened before plotting).  The loss
        panel is skipped when this is empty, which is the default
    win: int
        Window length to use in the STFT
    sr: int
        Sample rate of audio (used to help make proper units for time and frequency)
    """
    hop = 256
    hann = torch.hann_window(win).to(X)
    # Drop only the trailing singleton axis.  A bare squeeze() would also drop
    # the batch axis when n_batches == 1 and break the 3D indexing below
    SX = torch.abs(torch.stft(X.squeeze(-1), win, hop, win, hann, return_complex=True))
    SY = torch.abs(torch.stft(Y.squeeze(-1), win, hop, win, hann, return_complex=True))
    # (left, right, bottom, top) in seconds / hz.  Row 0 of the image (the
    # lowest frequency bin) is drawn at the top edge, so top=0
    extent = (0, SX.shape[2]*hop/sr, SX.shape[1]*sr/win, 0)

    def show_spectrogram(position, S, title):
        # Log-magnitude spectrogram of the first clip in the batch
        plt.subplot(position)
        plt.imshow(torch.log10(S.detach().cpu()[0, :, :]), aspect='auto', cmap='magma', extent=extent)
        plt.title(title)
        plt.ylim([0, 8000])
        plt.xlabel("Time (Sec)")
        plt.ylabel("Frequency (hz)")

    show_spectrogram(321, SX, "Ground Truth")
    show_spectrogram(322, SY, "Synthesized")

    plt.subplot(323)
    plt.plot(F.detach().cpu()[0, :, 0])
    plt.title("Fundamental Frequency")
    plt.xlabel("Window index")
    plt.ylabel("Hz")
    plt.subplot(324)
    plt.plot(L.detach().cpu()[0, :, 0])
    plt.title("Loudness")
    plt.xlabel("Window Index")
    plt.ylabel("Z-normalized dB")
    if torch.numel(losses) > 0:
        plt.subplot(325)
        plt.plot(losses.detach().cpu().numpy().flatten())
        plt.yscale("log")
        plt.title("Losses (Current {:.3f})".format(losses[-1]))
        plt.xlabel("Epoch")
    plt.subplot(326)
    plt.plot(reverb.detach().cpu().flatten())
    plt.title("Impulse Response")
    plt.xlabel("Sample index")

################################################
# Loudness code modified from original Google Magenta DDSP implementation in tensorflow
# https://github.com/magenta/ddsp/blob/86c7a35f4f2ecf2e9bb45ee7094732b1afcebecd/ddsp/spectral_ops.py#L253
# which, like this repository, is licensed under Apache2 by Google Magenta Group, 2020
# Modifications by Chris Tralie, 2023

def power_to_db(power, ref_db=0.0, range_db=80.0, use_tf=True):
    """
    Converts power from linear scale to decibels.

    Parameters
    ----------
    power: ndarray
        Power values on a linear scale
    ref_db: float
        Reference level (in dB) subtracted from the result
    range_db: float
        Dynamic range; results are never lower than -range_db
    use_tf: bool
        Unused here; kept in the signature for compatibility with callers

    Returns
    -------
    ndarray
        Power in decibels, clamped below at -range_db
    """
    # Floor the input so that log10 never sees anything below -range_db dB
    power_floor = 10**(-range_db/10)
    level = 10.0*np.log10(np.maximum(power, power_floor))
    # Shift by the reference, then re-apply the dynamic range limit
    level -= ref_db
    return np.maximum(level, -range_db)

def extract_loudness(x, sr, hop_length, n_fft=512):
    """
    Estimate loudness in dB over time from an A-weighted power spectrogram
    (section B.1 of the paper)

    Parameters
    ----------
    x: ndarray(N)
        Audio samples
    sr: int
        Sample rate (used to figure out frequencies for A-weighting)
    hop_length: int
        Hop length between loudness estimates
    n_fft: int
        Number of samples to use in each window

    Returns
    -------
    ndarray(n_frames), dtype float32
        Loudness in dB, one value per STFT frame
    """
    # Centered STFT, converted straight to a power spectrogram
    stft = librosa.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=n_fft, center=True)
    power_spec = np.abs(stft)**2

    # A-weighting (in dB) evaluated at each bin's frequency, shaped as a
    # column so it broadcasts across the frames
    bin_freqs = np.arange(stft.shape[0])*sr/n_fft
    weights_db = librosa.A_weighting(bin_freqs)[:, None]

    # The weighting is in decibels, so convert it to a linear power gain
    weighted_power = power_spec * 10**(weights_db/10)

    # Mean weighted power over the frequency axis, then to decibels
    frame_power = np.mean(weighted_power, axis=0)
    return np.array(power_to_db(frame_power), dtype=np.float32)

################################################

# Part 1: Dataset

## FM Synthesis Dataset (Given)

For debugging, if you need it.  Your network should at least be able to learn these very simple sounds, so if it can't, figure out what the problem is before you move on to training your network on the real violin audio.

In [None]:
class FMDataset(Dataset):
    def __init__(self, sr, hop_length, samples_per_batch=1000):
        """
        Instantiate an fm dataset

        Parameters
        ----------
        sr: int
            Sample rate
        hop_length: int
            Samples between loudness and pitch frames
        samples_per_batch: int
            The length of this object
        """
        self.sr = sr
        self.hop_length = hop_length
        self.samples_per_batch = samples_per_batch

    def __len__(self):
        return self.samples_per_batch

    def __getitem__(self, idx):
        """
        Generate a random FM plucked string note over a duration of
        4 seconds.  The note is drawn uniformly from the 24 semitones
        starting at A3 (220hz), i.e. A3 up to one semitone below A5.

        The clip is trimmed to a whole number of hops, so the audio always
        has exactly hop_length samples per pitch/loudness frame, even when
        hop_length does not divide sr*4.

        Parameters
        ----------
        idx: int
            Index of example (ignored because data is random)

        Returns
        -------
        x: torch.tensor(K*hop_length, 1), where K = sr*4//hop_length
            Audio samples (float64, as produced by numpy)
        pitch: torch.tensor(K, 1)
            The pitch in hz (a constant line since this is one solid note)
        loudness: torch.tensor(K, 1)
            Loudness of the decay envelope in dB, sampled once per hop
        """
        note = np.random.randint(-12, 12)
        sr = self.sr
        hop_length = self.hop_length
        ratio = 1     # Modulator frequency as a multiple of the carrier
        I = 8         # Modulation index
        lam = 3       # Decay rate of the envelope
        duration = 4  # Seconds
        N = int(duration*sr)
        K = N//hop_length
        ts = np.arange(N)/sr
        envelope = np.exp(-lam*np.arange(N)/sr)
        f = 440*2**(note/12)
        fm = f*ratio
        x = np.cos(2*np.pi*f*ts + envelope*I*(np.cos(2*np.pi*fm*ts)))
        x = x*envelope
        # Keep exactly K frames and K*hop_length samples.  An open-ended
        # stride slice yields ceil(N/hop_length) values, which does not fit
        # a (K, 1) view when hop_length does not divide N
        x = x[0:K*hop_length]
        loudness = np.array(envelope[0:K*hop_length:hop_length], dtype=np.float32)
        loudness = 10*np.log10(loudness**2+1e-8)
        loudness = torch.from_numpy(loudness).view(K, 1)
        x = torch.from_numpy(x).view(x.size, 1)
        pitch = 440*(2**(note/12))*torch.ones(K)
        pitch = pitch.view(K, 1)
        return x, pitch, loudness

## An example of a dataset
sr = 16000  # Sample rate, passed to the dataset and to the audio player
dataset = FMDataset(sr, 160)  # 160 samples between pitch/loudness frames
loader = DataLoader(dataset, batch_size=16, shuffle=True)
# Pull a single batch: audio X, pitch F, and loudness L
(X, F, L) = next(iter(loader))
# Listen to the first clip in the batch
ipd.Audio(X[0, :, 0], rate=sr)

## Instrument Dataset

In [None]:
## TODO: Fill this in!
class InstrumentData(Dataset):
    def __init__(self, x, sr, hop_length, samples_per_batch=5000, clip_seconds=4, device=None):
        """
        Precompute frame-level loudness and pitch for a recording so that
        random fixed-length clips can be served during training

        Parameters
        ----------
        x: ndarray(N)
            Audio samples of the full recording
        sr: int
            Sample rate
        hop_length: int
            Samples between loudness and pitch frames
        samples_per_batch: int
            The length of this object (clips are random, so this only
            sets how many clips make up one pass over the dataset)
        clip_seconds: float
            Duration of each clip, rounded down to a whole number of hops
        device: string
            Device on which to run pitch estimation.  If None, use 'cuda'
            when it is available and fall back to 'cpu' otherwise

        Raises
        ------
        ValueError
            If the recording is shorter than a single clip
        """
        self.sr = sr
        self.hop_length = hop_length
        self.samples_per_batch = samples_per_batch
        # Clips are a whole number of frames so audio and features line up
        self.frames_per_clip = int(clip_seconds*sr)//hop_length

        # float32 matches the default dtype of torch modules
        x = np.asarray(x, dtype=np.float32)

        # Loudness, z-normalized.  The statistics are kept so that other
        # audio can be normalized the same way later
        loudness = extract_loudness(x, sr, hop_length)
        self.loudness_mean = loudness.mean()
        self.loudness_stand_dev = loudness.std()
        loudness = (loudness - self.loudness_mean) / self.loudness_stand_dev

        # Pitch in hz from CREPE, searching between 50hz and 2000hz with
        # the 'full' model
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        audio = torch.from_numpy(x).view(1, -1)
        pitch = torchcrepe.predict(audio, sr, hop_length, 50, 2000, 'full', batch_size=2048, device=device)
        # Store as a float32 numpy array on the cpu, like the loudness
        pitch = pitch.detach().cpu().numpy().flatten().astype(np.float32)

        # Keep only frames that have loudness, pitch, and a full hop of audio
        n_frames = min(len(loudness), len(pitch), len(x)//hop_length)
        if n_frames < self.frames_per_clip:
            raise ValueError(
                "Audio has {} frames, but a clip needs {}".format(n_frames, self.frames_per_clip)
            )
        self.loudness = loudness[:n_frames]
        self.pitch = pitch[:n_frames]
        self.x = x[:n_frames*hop_length]

    def __len__(self):
        return self.samples_per_batch

    def __getitem__(self, idx):
        """
        Extract a random clip along with its pitch and loudness frames

        Parameters
        ----------
        idx: int
            Index of example (ignored because clips are chosen at random)

        Returns
        -------
        x: torch.tensor(frames_per_clip*hop_length, 1)
            Audio samples
        pitch: torch.tensor(frames_per_clip, 1)
            Pitch estimates in hz
        loudness: torch.tensor(frames_per_clip, 1)
            Z-normalized loudness
        """
        n_frames = self.frames_per_clip
        # Pick the start on a frame boundary so that sample k*hop_length of
        # the clip corresponds exactly to frame k of the features
        start_frame = np.random.randint(0, len(self.pitch) - n_frames + 1)
        end_frame = start_frame + n_frames
        start_sample = start_frame*self.hop_length
        end_sample = end_frame*self.hop_length

        x = torch.from_numpy(self.x[start_sample:end_sample]).view(-1, 1)
        pitch = torch.from_numpy(self.pitch[start_frame:end_frame]).view(-1, 1)
        loudness = torch.from_numpy(self.loudness[start_frame:end_frame]).view(-1, 1)
        return x, pitch, loudness

# Part 2a: Decoder Architecture

Section B.2 of the paper


In [7]:
def modified_sigmoid(x):
    """
    Squash a tensor into strictly positive values

    Parameters
    ----------
    x: torch.tensor
        Raw network outputs, any shape

    Returns
    -------
    torch.tensor
        2*sigmoid(x)^ln(10) + 1e-7, elementwise; same shape as x
    """
    squashed = torch.sigmoid(x)
    # The small offset keeps the output away from exactly zero
    return 2*squashed**np.log(10) + 1e-7

## TODO: Fill this in!
class MLP(nn.Module):
    """
    Three stacked blocks of Linear -> LayerNorm -> LeakyReLU, applied
    across the last axis of the input
    """
    def __init__(self, input_size, output_size):
        """
        Parameters
        ----------
        input_size: int
            Number of features along the last axis of the input
        output_size: int
            Number of features produced by every block
        """
        super(MLP, self).__init__()
        self.sequential = nn.Sequential(
            nn.Linear(input_size, output_size),
            nn.LayerNorm(output_size),
            nn.LeakyReLU(),

            # Every block after the first consumes the previous block's
            # output, so its input width is output_size, not input_size
            nn.Linear(output_size, output_size),
            nn.LayerNorm(output_size),
            nn.LeakyReLU(),

            nn.Linear(output_size, output_size),
            nn.LayerNorm(output_size),
            nn.LeakyReLU()
        )

    def forward(self, x):
        """
        Parameters
        ----------
        x: torch.tensor(..., input_size)
            Input features

        Returns
        -------
        torch.tensor(..., output_size)
            Transformed features
        """
        return self.sequential(x)

class DDSPDecoder(nn.Module):
    """
    Decoder network that maps pitch and loudness frames to synthesizer
    controls, and that holds a learned reverb impulse response
    """
    def __init__(self, n_units, sr, n_harmonics, n_bands, reverb_len):
        """
        Parameters
        ----------
        n_units: int
            Width of every MLP and of the GRU hidden state
        sr: int
            Sample rate (accepted but not used by this class)
        n_harmonics: int
            Number of harmonic amplitudes to output per frame
        n_bands: int
            Number of subtractive synthesis filter values to output per frame
        reverb_len: int
            Number of samples in the learned impulse response
        """
        super(DDSPDecoder, self).__init__()
        # Separate encoders for the two scalar inputs
        self.pitch_layer = MLP(1, n_units)
        self.loudness_layer = MLP(1, n_units)
        # Recurrent layer over both encodings; time is axis 1
        self.gru = nn.GRU(input_size=n_units*2, hidden_size=n_units, batch_first=True)
        # Mixes both encodings with the GRU output
        self.joint_layer = MLP(3*n_units, n_units)

        # Linear heads for each group of synthesizer controls
        self.harmonics_decoder = nn.Linear(n_units, n_harmonics)
        self.amplitude_decoder = nn.Linear(n_units, 1)
        self.sub_filt_decoder = nn.Linear(n_units, n_bands)

        # Impulse response starts as tiny noise in [-0.5e-4, 0.5e-4)
        self.reverb = nn.Parameter(torch.rand(reverb_len)*1e-4-0.5e-4)

    def forward(self, F, L):
        """
        Parameters
        ----------
        F: torch.tensor(n_batches, n_frames, 1)
            Pitch for each frame
        L: torch.tensor(n_batches, n_frames, 1)
            Loudness for each frame

        Returns
        -------
        A: torch.tensor(n_batches, n_frames, 1)
            Overall amplitude
        C: torch.tensor(n_batches, n_frames, n_harmonics)
            Harmonic amplitudes, normalized to sum to 1 in each frame
        S: torch.tensor(n_batches, n_frames, n_bands)
            Subtractive synthesis filter values
        reverb: torch.tensor(reverb_len)
            Impulse response, squashed to (-1, 1) with tanh
        """
        pitch_feats = self.pitch_layer(F)
        loudness_feats = self.loudness_layer(L)

        # The GRU sees both encodings; only its output sequence is used
        encoded = torch.cat([pitch_feats, loudness_feats], dim=2)
        recurrent, _ = self.gru(encoded)

        latent = self.joint_layer(torch.cat([encoded, recurrent], dim=2))

        # Every head goes through the same positive squashing function
        A = modified_sigmoid(self.amplitude_decoder(latent))
        S = modified_sigmoid(self.sub_filt_decoder(latent))
        C = modified_sigmoid(self.harmonics_decoder(latent))

        # Harmonics form a distribution; A carries the overall level
        C = C/(1e-8 + C.sum(dim=2, keepdim=True))

        return A, C, S, torch.tanh(self.reverb)

# Part 2b: Synthesizer
Sections 3.2, 3.3, and B.5 of the paper

Use the outputs of the decoder network to create audio samples, *using only torch methods*


## Subtractive Synthesizer (Given)

In [None]:
def subtractive_synthesis(S, hop_length):
    """
    Filter uniform random noise, one hop at a time, with windowed impulse
    responses derived from frequency domain transfer functions

    Parameters
    ----------
    S: n_batches x time x n_bands
        Subtractive synthesis parameters
    hop_length: int
        Hop length between subtractive synthesis windows

    Returns
    -------
    torch.tensor(n_batches, time*hop_length, 1)
        Subtractive synthesis audio components for each clip
    """
    # Treat the band values as a purely real frequency response by pairing
    # them with a zero imaginary part along a new last axis
    # https://pytorch.org/docs/stable/generated/torch.view_as_complex.html
    transfer = torch.view_as_complex(torch.stack([S, torch.zeros_like(S)], -1))
    # Inverse real DFT (assuming symmetry) gives one impulse response per frame
    impulse = torch.fft.irfft(transfer)
    ir_len = impulse.shape[-1]

    # Rotate by half the length to center the response, then taper it
    # with a hann window
    impulse = torch.roll(impulse, ir_len//2, -1)
    impulse = impulse*torch.hann_window(ir_len, dtype=impulse.dtype, device=impulse.device)
    # Bring the response to one hop in length and rotate it back
    impulse = nn.functional.pad(impulse, (0, hop_length-ir_len))
    impulse = torch.roll(impulse, -ir_len//2, -1)

    # One hop of uniform noise in [-1, 1] per frame, filtered by that
    # frame's impulse response
    n_batches, n_frames = impulse.shape[0], impulse.shape[1]
    noise = torch.rand(n_batches, n_frames, hop_length).to(impulse.device)
    noise = noise*2 - 1
    filtered = fftconvolve(noise, impulse).contiguous()

    # Lay the nonoverlapping hops end to end as one contiguous stream
    return filtered.reshape(filtered.shape[0], filtered.shape[1]*filtered.shape[2], 1)

## Additive Synthesizer / Putting It Together

In [None]:
## TODO: Fill this in!
def synthesize(self, F, C, A, S, reverb):
    """
    Turn decoder outputs into audio: an additive bank of harmonic
    sinusoids plus filtered noise, all convolved with the reverb.
    Uses only torch methods, so gradients flow back to every input.

    Note the argument order (F, C, A, S, reverb); it is not the order in
    which DDSPDecoder.forward returns its outputs.

    Parameters
    ----------
    self: object
        Any object with attributes `sr` (sample rate) and `hop_length`
        (samples per frame)
    F: torch.tensor(n_batches, n_frames, 1)
        Fundamental frequency in hz for each frame
    C: torch.tensor(n_batches, n_frames, n_harmonics)
        Relative amplitude of each harmonic
    A: torch.tensor(n_batches, n_frames, 1)
        Overall amplitude of the harmonic component
    S: torch.tensor(n_batches, n_frames, n_bands)
        Subtractive synthesis parameters
    reverb: torch.tensor(reverb_len)
        Impulse response to apply to the result

    Returns
    -------
    torch.tensor(n_batches, n_frames*hop_length, 1)
        Synthesized audio samples
    """
    sr = self.sr
    hop_length = self.hop_length

    # Bring the frame-rate controls up to audio rate
    F_upsampled = upsample_time(F, hop_length)
    C_upsampled = upsample_time(C, hop_length)
    A_upsampled = upsample_time(A, hop_length)
    n_samples = F_upsampled.shape[1]

    # Frequency of every harmonic at every sample: (batch, samples, harmonics)
    n_harmonics = C_upsampled.shape[-1]
    harmonic_numbers = torch.arange(
        1, n_harmonics + 1, dtype=F_upsampled.dtype, device=F_upsampled.device
    ).view(1, 1, -1)
    harmonic_freqs = harmonic_numbers*F_upsampled

    # Silence harmonics at or above Nyquist, which would otherwise alias
    C_upsampled = C_upsampled*(harmonic_freqs < sr/2).to(C_upsampled.dtype)

    # Phase is the running integral of frequency.  Integrate the
    # fundamental once, then scale by the harmonic number
    phase = 2*torch.pi*torch.cumsum(F_upsampled/sr, dim=1)
    harmonics = torch.sin(harmonic_numbers*phase)*C_upsampled

    # keepdim leaves the sum as (batch, samples, 1) so that it lines up with
    # A; without it the product broadcasts to (batch, samples, samples)
    additive = A_upsampled*torch.sum(harmonics, dim=2, keepdim=True)

    subtractive = subtractive_synthesis(S, hop_length)

    Y = additive + subtractive

    # Apply reverb across time.  Trim in case the impulse response is longer
    # than the clip, then restore the trailing singleton axis
    Y = fftconvolve(Y.squeeze(-1), reverb.view(1, -1))[..., :n_samples]
    return Y.unsqueeze(-1)

# Part 3: Loss Function

Implement Multi-Scale Spectral Loss (DDSP Section 4.2.1)

Use torch.stft to help you.  Don't forget to squeeze() the input to the STFT to get rid of the singleton dimension at the end


In [None]:
## TODO: Fill this in!

# Part 4: Testing Example Code

In [None]:
## TODO: Fill this in!

# Part 5: Train Loop

Put it all together!

In [None]:
## TODO: Fill this in!

# Musical Statement

In [None]:
## TODO: Fill this in!  Have fun!