In [47]:
import torch
import math
from typing import List, Tuple
import torchaudio
from IPython.display import Audio
import dasp_pytorch
import numpy as np
from functools import partial
from time import time

In [None]:
import tensorboard

In [48]:
x, sr = torchaudio.load("examples/voice_raw.wav")
x = x.unsqueeze(0)

Audio(x[0].numpy(), rate=sr)

In [92]:

def noise_shaped_reverberation(
    x: torch.Tensor,
    sample_rate: float,
    band0_gain: torch.Tensor,
    band1_gain: torch.Tensor,
    band2_gain: torch.Tensor,
    band3_gain: torch.Tensor,
    band4_gain: torch.Tensor,
    band5_gain: torch.Tensor,
    band6_gain: torch.Tensor,
    band7_gain: torch.Tensor,
    band8_gain: torch.Tensor,
    band9_gain: torch.Tensor,
    band10_gain: torch.Tensor,
    band11_gain: torch.Tensor,
    band0_decay: torch.Tensor,
    band1_decay: torch.Tensor,
    band2_decay: torch.Tensor,
    band3_decay: torch.Tensor,
    band4_decay: torch.Tensor,
    band5_decay: torch.Tensor,
    band6_decay: torch.Tensor,
    band7_decay: torch.Tensor,
    band8_decay: torch.Tensor,
    band9_decay: torch.Tensor,
    band10_decay: torch.Tensor,
    band11_decay: torch.Tensor,
    mix: torch.Tensor,
    num_samples: int = 65536,
    num_bandpass_taps: int = 1023,
):
    """Artificial reverberation using frequency-band noise shaping.

    This differentiable artificial reverberation model is based on the idea of
    filtered noise shaping, similar to that proposed in [1]. This approach leverages
    the well known idea that a room impulse response (RIR) can be modeled as the direct sound,
    a set of early reflections, and a decaying noise-like tail [2].

    [1] Steinmetz, Christian J., Vamsi Krishna Ithapu, and Paul Calamia.
        "Filtered noise shaping for time domain room impulse response estimation from reverberant speech."
        2021 IEEE Workshop on Applications of Signal Processing to Audio and Acoustics (WASPAA). IEEE, 2021.

    [2] Moorer, James A.
        "About this reverberation business."
        Computer Music Journal (1979): 13-28.

    Args:
        x (torch.Tensor): Input audio signal. Shape (bs, chs, seq_len).
        sample_rate (float): Audio sample rate.
        band0_gain (torch.Tensor): Gain for first octave band on (0,1). Shape (bs, 1).
        band1_gain (torch.Tensor): Gain for second octave band on (0,1). Shape (bs, 1).
        band2_gain (torch.Tensor): Gain for third octave band on (0,1). Shape (bs, 1).
        band3_gain (torch.Tensor): Gain for fourth octave band on (0,1). Shape (bs, 1).
        band4_gain (torch.Tensor): Gain for fifth octave band on (0,1). Shape (bs, 1).
        band5_gain (torch.Tensor): Gain for sixth octave band on (0,1). Shape (bs, 1).
        band6_gain (torch.Tensor): Gain for seventh octave band on (0,1). Shape (bs, 1).
        band7_gain (torch.Tensor): Gain for eighth octave band on (0,1). Shape (bs, 1).
        band8_gain (torch.Tensor): Gain for ninth octave band on (0,1). Shape (bs, 1).
        band9_gain (torch.Tensor): Gain for tenth octave band on (0,1). Shape (bs, 1).
        band10_gain (torch.Tensor): Gain for eleventh octave band on (0,1). Shape (bs, 1).
        band11_gain (torch.Tensor): Gain for twelfth octave band on (0,1). Shape (bs, 1).
        band0_decays (torch.Tensor): Decay parameter for first octave band (0,1). Shape (bs, 1).
        band1_decays (torch.Tensor): Decay parameter for second octave band (0,1). Shape (bs, 1).
        band2_decays (torch.Tensor): Decay parameter for third octave band (0,1). Shape (bs, 1).
        band3_decays (torch.Tensor): Decay parameter for fourth octave band (0,1). Shape (bs, 1).
        band4_decays (torch.Tensor): Decay parameter for fifth octave band (0,1). Shape (bs, 1).
        band5_decays (torch.Tensor): Decay parameter for sixth octave band (0,1). Shape (bs, 1).
        band6_decays (torch.Tensor): Decay parameter for seventh octave band (0,1). Shape (bs, 1).
        band7_decays (torch.Tensor): Decay parameter for eighth octave band (0,1). Shape (bs, 1).
        band8_decays (torch.Tensor): Decay parameter for ninth octave band (0,1). Shape (bs, 1).
        band9_decays (torch.Tensor): Decay parameter for tenth octave band (0,1). Shape (bs, 1).
        band10_decays (torch.Tensor): Decay parameter for eleventh octave band (0,1). Shape (bs, 1).
        band11_decays (torch.Tensor): Decay parameter for twelfth octave band (0,1). Shape (bs, 1).
        mix (torch.Tensor): Mix between dry and wet signal. Shape (bs, 1).
        num_samples (int, optional): Number of samples to use for IR generation. Defaults to 88200.
        num_bandpass_taps (int, optional): Number of filter taps for the octave band filterbank filters. Must be odd. Defaults to 1023.

    Returns:
        y (torch.Tensor): Reverberated signal. Shape (bs, chs, seq_len).

    """
    assert num_bandpass_taps % 2 == 1, "num_bandpass_taps must be odd"

    bs, chs, seq_len = x.size()
    assert chs <= 2, "only mono/stereo signals are supported"

    # if mono copy to stereo
    if chs == 1:
        x = x.repeat(1, 2, 1)
        chs = 2

    # stack gains and decays into a single tensor
    start_time = time()
    band_gains = torch.stack(
        [
            band0_gain,
            band1_gain,
            band2_gain,
            band3_gain,
            band4_gain,
            band5_gain,
            band6_gain,
            band7_gain,
            band8_gain,
            band9_gain,
            band10_gain,
            band11_gain,
        ],
        dim=1,
    )
    band_gains = band_gains.unsqueeze(-1)

    band_decays = torch.stack(
        [
            band0_decay,
            band1_decay,
            band2_decay,
            band3_decay,
            band4_decay,
            band5_decay,
            band6_decay,
            band7_decay,
            band8_decay,
            band9_decay,
            band10_decay,
            band11_decay,
        ],
        dim=1,
    )
    band_decays = band_decays.unsqueeze(-1)
    print("stacking time: ", time() - start_time)

    # create the octave band filterbank filters
    filters = dasp_pytorch.signal.octave_band_filterbank(1023, sample_rate)
    filters = filters.type_as(x)
    num_bands = filters.shape[0]

    # reshape gain, decay, and mix parameters
    band_gains = band_gains.view(bs, 1, num_bands, 1)
    band_decays = band_decays.view(bs, 1, num_bands, 1)
    mix = mix.view(bs, 1, 1)

    # generate white noise for IR generation
    start_time = time()
    pad_size = num_bandpass_taps - 1
    wn = torch.randn(bs * 2, num_bands, num_samples + pad_size).type_as(x)
    print("white noise time: ", time() - start_time)

    # filter white noise signals with each bandpass filter
    start_time = time()
    wn_filt = torch.nn.functional.conv1d(
        wn,
        filters,
        groups=num_bands,
        # padding=self.num_taps -1,
    )
    # shape: (bs * 2, num_bands, num_samples)
    wn_filt = wn_filt.view(bs, 2, num_bands, num_samples)
    print("filtering time: ", time() - start_time)

    # apply bandwise decay parameters (envelope)
    start_time = time()
    t = torch.linspace(0, 1, steps=num_samples).type_as(x)  # timesteps
    band_decays = (band_decays * 10.0) + 1.0
    env = torch.exp(-band_decays * t.view(1, 1, 1, -1))
    wn_filt *= env * band_gains
    print("decay time: ", time() - start_time)

    # sum signals to create impulse shape: bs, 2, 1, num_samp
    w_filt_sum = wn_filt.mean(2, keepdim=True)

    # apply impulse response for each batch item (vectorized)
    start_time = time()

    x_pad = torch.nn.functional.pad(x, (num_samples - 1, 0))
    vconv1d = torch.vmap(partial(torch.nn.functional.conv1d, groups=2), in_dims=0)
    y = vconv1d(x_pad, torch.flip(w_filt_sum, dims=[-1]))
    print("conv time: ", time() - start_time)

    # create a wet/dry mix
    y = (1 - mix) * x + mix * y

    return y

In [93]:
band_gains = [torch.rand((1,1)) for _ in range(12)]
band_decays = [torch.rand((1,1)) for _ in range(12)]
mix = torch.rand((1,1))

In [94]:
y = noise_shaped_reverberation(x, sr, *band_gains, *band_decays, mix)

stacking time:  9.942054748535156e-05
white noise time:  0.008884668350219727
filtering time:  0.042407989501953125
decay time:  0.0008034706115722656
conv time:  0.00077056884765625


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

tensor([[[-0.0059, -0.0064, -0.0058,  ..., -0.0184,  0.0228,  0.0301],
         [-0.0057, -0.0057, -0.0063,  ...,  0.0222, -0.0022, -0.0021]]])

In [27]:
Audio(y[0].numpy(), rate=sr)


In [23]:
y.abs().max()

tensor(0.1662)

In [20]:
y_dist = dasp_pytorch.distortion(x, sr, torch.tensor([[[40]]]))
Audio(y_dist[0].numpy(), rate=sr)

In [2]:
def biqaud(
    gain_dB: torch.Tensor,
    cutoff_freq: torch.Tensor,
    q_factor: torch.Tensor,
    sample_rate: float,
    filter_type: str = "peaking",
):

    # convert inputs to Tensors if needed
    # gain_dB = torch.tensor([gain_dB])
    # cutoff_freq = torch.tensor([cutoff_freq])
    # q_factor = torch.tensor([q_factor])

    A = 10 ** (gain_dB / 40.0)
    w0 = 2 * math.pi * (cutoff_freq / sample_rate)
    alpha = torch.sin(w0) / (2 * q_factor)
    cos_w0 = torch.cos(w0)
    sqrt_A = torch.sqrt(A)

    if filter_type == "high_shelf":
        b0 = A * ((A + 1) + (A - 1) * cos_w0 + 2 * sqrt_A * alpha)
        b1 = -2 * A * ((A - 1) + (A + 1) * cos_w0)
        b2 = A * ((A + 1) + (A - 1) * cos_w0 - 2 * sqrt_A * alpha)
        a0 = (A + 1) - (A - 1) * cos_w0 + 2 * sqrt_A * alpha
        a1 = 2 * ((A - 1) - (A + 1) * cos_w0)
        a2 = (A + 1) - (A - 1) * cos_w0 - 2 * sqrt_A * alpha
    elif filter_type == "low_shelf":
        b0 = A * ((A + 1) - (A - 1) * cos_w0 + 2 * sqrt_A * alpha)
        b1 = 2 * A * ((A - 1) - (A + 1) * cos_w0)
        b2 = A * ((A + 1) - (A - 1) * cos_w0 - 2 * sqrt_A * alpha)
        a0 = (A + 1) + (A - 1) * cos_w0 + 2 * sqrt_A * alpha
        a1 = -2 * ((A - 1) + (A + 1) * cos_w0)
        a2 = (A + 1) + (A - 1) * cos_w0 - 2 * sqrt_A * alpha
    elif filter_type == "peaking":
        b0 = 1 + alpha * A
        b1 = -2 * cos_w0
        b2 = 1 - alpha * A
        a0 = 1 + (alpha / A)
        a1 = -2 * cos_w0
        a2 = 1 - (alpha / A)
    else:
        raise ValueError(f"Invalid filter_type: {filter_type}.")

    b = torch.stack([b0, b1, b2], dim=0).view(-1)
    a = torch.stack([a0, a1, a2], dim=0).view(-1)

    # normalize
    b = b.type_as(gain_dB) / a0
    a = a.type_as(gain_dB) / a0

    return b, a


def freqz(b, a, n_fft: int = 512):

    B = torch.fft.rfft(b, n_fft)
    A = torch.fft.rfft(a, n_fft)

    H = B / A

    return H


def freq_domain_filter(x, H, n_fft):

    X = torch.fft.rfft(x, n_fft)

    # move H to same device as input x
    H = H.type_as(X)

    Y = X * H

    y = torch.fft.irfft(Y, n_fft)

    return y


def approx_iir_filter(b, a, x):
    """Approimxate the application of an IIR filter.

    Args:
        b (Tensor): The numerator coefficients.

    """

    # round up to nearest power of 2 for FFT
    # n_fft = 2 ** math.ceil(math.log2(x.shape[-1] + x.shape[-1] - 1))

    n_fft = 2 ** torch.ceil(torch.log2(torch.tensor(x.shape[-1] + x.shape[-1] - 1)))
    n_fft = n_fft.int()

    # move coefficients to same device as x
    b = b.type_as(x).view(-1)
    a = a.type_as(x).view(-1)

    # compute complex response
    H = freqz(b, a, n_fft=n_fft).view(-1)

    # apply filter
    y = freq_domain_filter(x, H, n_fft)

    # crop
    y = y[: x.shape[-1]]

    return y


def approx_iir_filter_cascade(
    b_s: List[torch.Tensor],
    a_s: List[torch.Tensor],
    x: torch.Tensor,
):
    """Apply a cascade of IIR filters.

    Args:
        b (list[Tensor]): List of tensors of shape (3)
        a (list[Tensor]): List of tensors of (3)
        x (torch.Tensor): 1d Tensor.
    """

    if len(b_s) != len(a_s):
        raise RuntimeError(
            f"Must have same number of coefficients. Got b: {len(b_s)} and a: {len(a_s)}."
        )

    # round up to nearest power of 2 for FFT
    # n_fft = 2 ** math.ceil(math.log2(x.shape[-1] + x.shape[-1] - 1))
    n_fft = 2 ** torch.ceil(torch.log2(torch.tensor(x.shape[-1] + x.shape[-1] - 1)))
    n_fft = n_fft.int()

    # this could be done in parallel
    b = torch.stack(b_s, dim=0).type_as(x)
    a = torch.stack(a_s, dim=0).type_as(x)

    H = freqz(b, a, n_fft=n_fft)
    H = torch.prod(H, dim=0).view(-1)

    # apply filter
    y = freq_domain_filter(x, H, n_fft)

    # crop
    y = y[: x.shape[-1]]

    return y


In [3]:
def compute_biquad_coeff(f, R, mLP, mBP, mHP):
    if f.dim() == 1:
        f = torch.unsqueeze(f, dim=-1)
        R = torch.unsqueeze(R, dim=-1)
        mLP = torch.unsqueeze(mLP, dim=-1)
        mBP = torch.unsqueeze(mBP, dim=-1)
        mHP = torch.unsqueeze(mHP, dim=-1)

    K = f.size(1) # number of cascaded filters
    bs = f.size(0) # batch size  
    beta = torch.zeros(bs, 3, K, device=f.device)     
    alpha = torch.zeros(bs, 3, K, device=f.device)  

    beta[:,0,:] = (f**2) * mLP + f * mBP + mHP
    beta[:,1,:] = 2*(f**2) * mLP - 2 * mHP
    beta[:,2,:] = (f**2) * mLP - f * mBP + mHP
    alpha[:,0,:] = f**2 + 2*R*f + 1
    alpha[:,1,:] = 2* (f**2) - 2
    alpha[:,2,:] = f**2 - 2*R*f + 1  

    beta = torch.complex(beta, torch.zeros(beta.size(), device=f.device))
    alpha = torch.complex(alpha, torch.zeros(alpha.size(), device=f.device))
    return beta, alpha

In [4]:
def SAP(x, m, gamma):
    ''' differentiable Schoreder allpass filter 
    x: frequency sampling points 
    m: delay lengths
    gamma: feed-forward/back gains'''
    M = m.size(0) # number of channels
    K = m.size(1) # number of cascaded filters 
    bs = gamma.size(0) # batch size
    # compute transfer function of first filter
    zK = torch.pow(x.expand(M,-1).permute(1, 0),-m[:,0]).expand(bs, -1, -1)     # this step instroudces some numerical errors so that the absolute value is no longer <= 1
    gammaK = (gamma[:,:,0].expand(x.size(0), -1, -1)).permute(1, 0, 2)
    H = torch.div(
        (gammaK + zK), 
        (1 + gammaK * zK))
    # compute all the other SAP filters in series 
    for k in range(1, K):
        zK = torch.pow(x.expand(M,-1).permute(1, 0),-m[:,k]).expand(bs, -1, -1)
        gammaK = (gamma[:,:,k].expand(x.size(0), -1, -1)).permute(1, 0, 2)
        Hi = torch.div(
            (gammaK + zK),
            (1 + gammaK * zK))
        # element-wise mul to compute overall system's transfer function
        H = torch.mul(H, Hi)    
    return H

In [5]:
def biquad_to_tf(x, beta, alpha):
    # TODO: too many transpose operations. they can be removed
        H = torch.div(
                torch.matmul(
                    torch.pow(x.expand((3 ,-1)).transpose(1, 0), 
                        torch.tensor([0, -1, -2], device=x.device)),
                    beta.transpose(1,0)),
                torch.matmul(
                    torch.pow(x.expand((3 ,-1)).transpose(1, 0), 
                        torch.tensor([0, -1, -2], device=x.device)),
                    alpha.transpose(1,0))
            )
        return H.transpose(1, 0) 

In [6]:
def PEQ(x, f, R, G):
    ''' differentiable parameteric equilizer
    x: frequency sampling points
    f: cut off frequency 
    R: resonance 
    G: component gain
    f entries at initialization must be in ascending order '''
    # TODO: for some reason only when coefficents alpha and beta are swapped 
    # we obtain the proper response

    K = f.size(1)   # number of filters in series
    bs = f.size(0)   # number of channels
    # prevent the shelf filters magniture response to spkie over 1
    R[:,0] = R[:,0] + 1/torch.sqrt(torch.tensor(2, device=x.device)) 
    R[:,-1] = R[:,-1] + 1/torch.sqrt(torch.tensor(2, device=x.device)) 
    # TODO: it seems that low values of R give better looking respononses (low: close to 1/sqrt(2))
    # maybe the activation should not push R to high values 
    # Watch out that the formula used for the  biquad coeff of LP and HP filters are swapped 
    # low shelf filter (as flipped high shelf)
    betaLP, alphaLP = compute_biquad_coeff(
        # [:,0], R[:,0], torch.tensor(1, device=get_device()), 2 * R[:,0] * torch.sqrt(G[:,0]), G[:,0] )
        f[:,0], R[:,0], G[:,0], 2 * R[:,0] * torch.sqrt(G[:,0]), torch.tensor(1, device=x.device) )
    HHP = biquad_to_tf(x, betaLP[:,:,0], alphaLP[:,:,0]) 
    H = HHP 
    # high shelf filter (as flipped low shelf)
    
    betaHP, alphaHP = compute_biquad_coeff(
        # f[:,-1], R[:,-1], G[:,-1], 2 * R[:,-1] * torch.sqrt(G[:,-1]), torch.tensor(1, device=get_device()) )
        f[:,-1], R[:,-1], torch.tensor(1, device=x.device), 2 * R[:,-1] * torch.sqrt(G[:,-1]), G[:,-1] )  
    HLP = biquad_to_tf(x, betaHP[:,:,0], alphaHP[:,:,0])
    H = H*HLP

    # K - 2 peaking filter 
    for k in range(1, K-1):
        beta, alpha = compute_biquad_coeff(
            f[:,k], R[:,k], torch.tensor(1, device=x.device), 2*R[:,k]*G[:,k], torch.tensor(1, device=x.device))
        Hp = biquad_to_tf(x, beta[:,:,0], alpha[:,:,0])
        H = H*Hp
    return H

In [41]:
def delay_network(x : torch.Tensor,
                  sample_rate : float,
                  d : torch.Tensor,
                  dAP : torch.Tensor,
                  f_C1 : torch.Tensor,
                  R_C1 : torch.Tensor,
                  m_C1_LP : torch.Tensor,
                  m_C1_HP : torch.Tensor,
                  m_C1_BP : torch.Tensor,
                  f_CDelta : torch.Tensor,
                  R_CDelta : torch.Tensor,
                  G_CDelta : torch.Tensor,
                  b : torch.Tensor,
                  c : torch.Tensor,
                  gamma : torch.Tensor,
                  h0 : torch.Tensor,
                  Q0 : torch.Tensor,) -> torch.Tensor:
    """ 
    x.shape = (1, T)
    d.shape = (1, M)
    dAP.shape = (M, K_U)
    f_C1.shape = (1, K_C1)
    R_C1.shape = (1, K_C1)
    m_C1_LP.shape = (1, K_C1)
    m_C1_HP.shape = (1, K_C1)
    m_C1_BP.shape = (1, K_C1)
    f_CDelta.shape = (1, K_CDelta)
    R_CDelta.shape = (1, K_CDelta)
    G_CDelta.shape = (1, K_CDelta)
    b.shape = (1, M)
    c.shape = (1, M)
    gamma.shape =  (M, K_U)   
    h0.shape = (Z)
    Q0.shape = (1, M, M)
    """
    n_fft = 2 ** torch.ceil(torch.log2(torch.tensor(x.shape[-1] + x.shape[-1] - 1)))
    n_fft = n_fft.int()

    M = d.size(1)
    
    # Construct C_1 filter
    beta_C1, alpha_C1 = compute_biquad_coeff(f_C1, R_C1, m_C1_LP, m_C1_BP, m_C1_HP)
    beta_C1 = torch.real(beta_C1.transpose_(2, 1))
    alpha_C1 = torch.real(alpha_C1.transpose_(2, 1))
    H_C1 = freqz(beta_C1, alpha_C1, n_fft=n_fft).squeeze(0)
    C_1 = torch.prod(H_C1, dim=0).unsqueeze(1).unsqueeze(2)
    # C_1.shape = (n_fft//2+1, 1, 1)

    # Construct D_N
    angle = 2*torch.pi*torch.arange(0, n_fft//2+1, device=x.device)/n_fft
    A_N = torch.exp(1j * torch.einsum('i,j->ij', angle, d.squeeze()))
    D_N_inv = torch.diag_embed(A_N, dim1=-2, dim2=-1)
    # D_N_inv.shape = (n_fft//2+1, M, M)

    # Construct U_N
    z = torch.exp(-1j * angle)
    U_N = SAP(z, dAP, gamma.unsqueeze(0)).transpose(1, 0).squeeze(1)
    # U_N.shape = (n_fft//2+1, M)

    # Construct C_Delta
    H_CDelta = PEQ(z, f_CDelta, R_CDelta, G_CDelta)
    C_delta = torch.einsum('ki,j->ijk', H_CDelta, torch.ones(M, device=x.device)).squeeze(2)
    # C_delta.shape = (n_fft//2+1, M)
    
    # batch-wise matrix multiplication
    UCDelta = torch.diag_embed(U_N * C_delta, dim1=-1, dim2=-2)
    batch_inv = torch.inverse(D_N_inv - torch.matmul(Q0 + 0j, UCDelta))
    H_z = torch.matmul(C_1, torch.matmul(torch.matmul(c + 0j, batch_inv), b.transpose(1,0) + 0j)).squeeze(1).squeeze(1)
    print(C_1.isnan().sum())
    
    # Bypass FIR H_0
    H_0 = torch.fft.rfft(h0, n_fft)

    # Final response
    H = H_z + H_0
    X = torch.fft.rfft(x, n_fft)
    Y = H * X
    y = torch.fft.irfft(Y.squeeze(), n_fft)
    y = y[: x.shape[-1]]

    return y


In [152]:

x, fs = torchaudio.load('examples/paganini.wav')


d = torch.tensor([[233, 311, 421, 461, 587, 613]])
tau = torch.tensor([[131, 151, 337, 353], [103, 173, 331, 373], [89, 181, 307, 401], [79, 197, 281, 419], [61, 211, 257, 431], [47, 229, 251, 443]])

K_C1 = 8
K_CDelta = 8
K_U = 4
M = 6

f_C1 = torch.tensor([[28., 54., 134., 256., 521., 634., 768., 1098.] ])
R_C1 = torch.abs(torch.randn(1, K_C1))
m_C1_LP = torch.abs(torch.randn(1, K_C1))
m_C1_HP = torch.abs(torch.randn(1, K_C1))
m_C1_BP = torch.abs(torch.randn(1, K_C1))

f_CDelta = torch.tensor([[28., 54., 134., 256., 521., 634., 768., 1098.] ])
R_CDelta = torch.abs(torch.randn(1, K_CDelta)) + math.sqrt(2)/2
G_CDelta = 10**(-torch.abs(torch.randn(1, K_CDelta)))

Q_0 = 1/3 * torch.ones(6,6) - torch.eye(6)


b = torch.randn(1, M)
c = torch.randn(1, M)
gamma = torch.rand(M, K_U)
h0 = torch.randn(100)

y = delay_network(x, fs, d, tau, f_C1, R_C1, m_C1_LP, m_C1_HP, m_C1_BP, f_CDelta, R_CDelta, G_CDelta, b, c, gamma, h0, Q_0)

tensor(0)


In [44]:
y

tensor([0.0054, 0.0086, 0.0242,  ..., 0.0042, 0.0020, 0.0016])

In [45]:
from IPython.display import display, Audio

x_np = x.squeeze().numpy()
Audio(x_np, rate=fs)

In [46]:
y_np = y.squeeze().numpy()
Audio(y_np, rate=fs)

In [171]:
def schroeder(x : torch.Tensor,
              fs : float,
              Tr_comb : torch.Tensor,
              Tr_ap : torch.Tensor,
              m_comb : torch.Tensor,
              m_ap : torch.Tensor,):
    
    n_fft = 2 ** torch.ceil(torch.log2(torch.tensor(x.shape[-1] + x.shape[-1] - 1)))
    n_fft = n_fft.int()

    g_comb= 10**(-3/Tr_comb/fs * m_comb)
    # g_comb.shape = (1, K_Comb)
    g_ap = (1 - 7/Tr_ap/fs)**m_ap
    # g_ap.shape = (1, K_AP)

    z = torch.exp(1j*2*torch.pi*torch.arange(0, n_fft//2+1, device=x.device)/n_fft)

    z_comb = torch.pow(z.expand(m_comb.size(1), -1).permute(1, 0), -m_comb)
    H_comb = z_comb / (1 - g_comb * z_comb)
    H_comb = torch.sum(H_comb, dim=1)

    z_ap = torch.pow(z.expand(m_ap.size(1), -1).permute(1, 0), -m_ap)
    H_ap = (-g_ap + z_ap) / (1 - g_ap * z_ap)
    H_ap = torch.prod(H_ap, dim=1)

    H = H_comb * H_ap

    X = torch.fft.rfft(x, n_fft)
    Y = H * X
    y = torch.fft.irfft(Y.squeeze(), n_fft)
    y = y[: x.shape[-1]]

    return y

In [181]:
x, fs = torchaudio.load('examples/voice_raw.wav')


In [197]:
Tr_comb = .6
Tr_ap = .5
m_comb = (torch.tensor([[29.7, 37.1, 41.4, 43.7]]) * 1e-3 * fs).to(int)
m_ap = (torch.tensor([[96.83, 32.92]]) * 1e-3 * fs).to(int)

In [198]:
y = schroeder(x, fs, Tr_comb, Tr_ap, m_comb, m_ap)

In [199]:
y_np = y.squeeze().numpy()
Audio(y_np, rate=fs)