Skip to content

Commit

Permalink
add autograd to biquad filters (#1400)
Browse files Browse the repository at this point in the history
  • Loading branch information
yoyololicon committed Mar 31, 2021
1 parent e4a0bd2 commit 52decd2
Show file tree
Hide file tree
Showing 2 changed files with 222 additions and 68 deletions.
106 changes: 102 additions & 4 deletions test/torchaudio_unittest/functional/autograd_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,31 @@


class Autograd(common_utils.TestBaseMixin):
def test_x_grad(self):
def test_lfilter_x(self):
torch.random.manual_seed(2434)
x = torch.rand(2, 4, 256 * 2, dtype=self.dtype, device=self.device)
a = torch.tensor([0.7, 0.2, 0.6], dtype=self.dtype, device=self.device)
b = torch.tensor([0.4, 0.2, 0.9], dtype=self.dtype, device=self.device)
x.requires_grad = True
assert gradcheck(F.lfilter, (x, a, b), eps=1e-10)

def test_a_grad(self):
def test_lfilter_a(self):
torch.random.manual_seed(2434)
x = torch.rand(2, 4, 256 * 2, dtype=self.dtype, device=self.device)
a = torch.tensor([0.7, 0.2, 0.6], dtype=self.dtype, device=self.device)
b = torch.tensor([0.4, 0.2, 0.9], dtype=self.dtype, device=self.device)
a.requires_grad = True
assert gradcheck(F.lfilter, (x, a, b), eps=1e-10)

def test_b_grad(self):
def test_lfilter_b(self):
torch.random.manual_seed(2434)
x = torch.rand(2, 4, 256 * 2, dtype=self.dtype, device=self.device)
a = torch.tensor([0.7, 0.2, 0.6], dtype=self.dtype, device=self.device)
b = torch.tensor([0.4, 0.2, 0.9], dtype=self.dtype, device=self.device)
b.requires_grad = True
assert gradcheck(F.lfilter, (x, a, b), eps=1e-10)

def test_all_grad(self):
def test_lfilter_all_inputs(self):
torch.random.manual_seed(2434)
x = torch.rand(2, 4, 256 * 2, dtype=self.dtype, device=self.device)
a = torch.tensor([0.7, 0.2, 0.6], dtype=self.dtype, device=self.device)
Expand All @@ -38,3 +38,101 @@ def test_all_grad(self):
a.requires_grad = True
x.requires_grad = True
assert gradcheck(F.lfilter, (x, a, b), eps=1e-10)

def test_biquad(self):
torch.random.manual_seed(2434)
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
a = torch.tensor([0.7, 0.2, 0.6], dtype=self.dtype, device=self.device, requires_grad=True)
b = torch.tensor([0.4, 0.2, 0.9], dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.biquad, (x, b[0], b[1], b[2], a[0], a[1], a[2]), eps=1e-10)

def test_band_biquad(self):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.band_biquad, (x, sr, central_freq, Q))

def test_band_biquad_with_noise(self):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.band_biquad, (x, sr, central_freq, Q, True))

def test_bass_biquad(self):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
central_freq = torch.tensor(100, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
gain = torch.tensor(10, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.bass_biquad, (x, sr, gain, central_freq, Q))

def test_treble_biquad(self):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
central_freq = torch.tensor(3000, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
gain = torch.tensor(10, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.treble_biquad, (x, sr, gain, central_freq, Q))

def test_allpass_biquad(self):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.allpass_biquad, (x, sr, central_freq, Q))

def test_lowpass_biquad(self):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
cutoff_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.lowpass_biquad, (x, sr, cutoff_freq, Q))

def test_highpass_biquad(self):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
cutoff_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.highpass_biquad, (x, sr, cutoff_freq, Q))

def test_bandpass_biquad(self):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.bandpass_biquad, (x, sr, central_freq, Q))

def test_bandpass_biquad_with_const_skirt_gain(self):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.bandpass_biquad, (x, sr, central_freq, Q, True))

def test_equalizer_biquad(self):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
gain = torch.tensor(10, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.equalizer_biquad, (x, sr, central_freq, gain, Q))

def test_bandreject_biquad(self):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.bandreject_biquad, (x, sr, central_freq, Q))

0 comments on commit 52decd2

Please sign in to comment.