Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add autograd to biquad filters #1400

Merged
merged 9 commits into from
Mar 31, 2021
106 changes: 102 additions & 4 deletions test/torchaudio_unittest/functional/autograd_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,31 @@


class Autograd(common_utils.TestBaseMixin):
def test_x_grad(self):
def test_lfilter_x(self):
torch.random.manual_seed(2434)
x = torch.rand(2, 4, 256 * 2, dtype=self.dtype, device=self.device)
a = torch.tensor([0.7, 0.2, 0.6], dtype=self.dtype, device=self.device)
b = torch.tensor([0.4, 0.2, 0.9], dtype=self.dtype, device=self.device)
x.requires_grad = True
assert gradcheck(F.lfilter, (x, a, b), eps=1e-10)

def test_a_grad(self):
def test_lfilter_a(self):
torch.random.manual_seed(2434)
x = torch.rand(2, 4, 256 * 2, dtype=self.dtype, device=self.device)
a = torch.tensor([0.7, 0.2, 0.6], dtype=self.dtype, device=self.device)
b = torch.tensor([0.4, 0.2, 0.9], dtype=self.dtype, device=self.device)
a.requires_grad = True
assert gradcheck(F.lfilter, (x, a, b), eps=1e-10)

def test_b_grad(self):
def test_lfilter_b(self):
torch.random.manual_seed(2434)
x = torch.rand(2, 4, 256 * 2, dtype=self.dtype, device=self.device)
a = torch.tensor([0.7, 0.2, 0.6], dtype=self.dtype, device=self.device)
b = torch.tensor([0.4, 0.2, 0.9], dtype=self.dtype, device=self.device)
b.requires_grad = True
assert gradcheck(F.lfilter, (x, a, b), eps=1e-10)

def test_all_grad(self):
def test_lfilter_all_inputs(self):
torch.random.manual_seed(2434)
x = torch.rand(2, 4, 256 * 2, dtype=self.dtype, device=self.device)
a = torch.tensor([0.7, 0.2, 0.6], dtype=self.dtype, device=self.device)
Expand All @@ -38,3 +38,101 @@ def test_all_grad(self):
a.requires_grad = True
x.requires_grad = True
assert gradcheck(F.lfilter, (x, a, b), eps=1e-10)

def test_biquad(self):
torch.random.manual_seed(2434)
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
a = torch.tensor([0.7, 0.2, 0.6], dtype=self.dtype, device=self.device, requires_grad=True)
b = torch.tensor([0.4, 0.2, 0.9], dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.biquad, (x, b[0], b[1], b[2], a[0], a[1], a[2]), eps=1e-10)

def test_band_biquad(self):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.band_biquad, (x, sr, central_freq, Q))

def test_band_biquad_with_noise(self):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.band_biquad, (x, sr, central_freq, Q, True))

def test_bass_biquad(self):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
central_freq = torch.tensor(100, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
gain = torch.tensor(10, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.bass_biquad, (x, sr, gain, central_freq, Q))

def test_treble_biquad(self):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
central_freq = torch.tensor(3000, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
gain = torch.tensor(10, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.treble_biquad, (x, sr, gain, central_freq, Q))

def test_allpass_biquad(self):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.allpass_biquad, (x, sr, central_freq, Q))

def test_lowpass_biquad(self):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
cutoff_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.lowpass_biquad, (x, sr, cutoff_freq, Q))

def test_highpass_biquad(self):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
cutoff_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.highpass_biquad, (x, sr, cutoff_freq, Q))

def test_bandpass_biquad(self):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.bandpass_biquad, (x, sr, central_freq, Q))

def test_bandpass_biquad_with_const_skirt_gain(self):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.bandpass_biquad, (x, sr, central_freq, Q, True))

def test_equalizer_biquad(self):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
gain = torch.tensor(10, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.equalizer_biquad, (x, sr, central_freq, gain, Q))

def test_bandreject_biquad(self):
torch.random.manual_seed(2434)
sr = 22050
x = torch.rand(1024, dtype=self.dtype, device=self.device, requires_grad=True)
central_freq = torch.tensor(800, dtype=self.dtype, device=self.device, requires_grad=True)
Q = torch.tensor(0.7, dtype=self.dtype, device=self.device, requires_grad=True)
assert gradcheck(F.bandreject_biquad, (x, sr, central_freq, Q))