diff --git a/.gitignore b/.gitignore index 6a2606d9..f9c33561 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,9 @@ coverage.xml # docs _build/ +# tests +*.wav + # tools tools/**/ diff --git a/Makefile b/Makefile index 607e2980..b45a0ad2 100644 --- a/Makefile +++ b/Makefile @@ -16,6 +16,7 @@ PROJECT := diffsptk MODULE := +OPT := PYTHON_VERSION := 3.9 TORCH_VERSION := 2.0.0 @@ -63,7 +64,7 @@ format: tool test: tool [ -n "$(MODULE)" ] && module=tests/test_$(MODULE).py || module=; \ - . ./venv/bin/activate && export PATH=tools/SPTK/bin:$$PATH && python -m pytest $$module + . ./venv/bin/activate && export PATH=tools/SPTK/bin:$$PATH && python -m pytest $$module $(OPT) test-clean: rm -rf tests/__pycache__ diff --git a/README.md b/README.md index 7d3b1882..ee4803a8 100644 --- a/README.md +++ b/README.md @@ -90,6 +90,41 @@ diffsptk.write("voiced.wav", x_voiced, sr) diffsptk.write("unvoiced.wav", x_unvoiced, sr) ``` +### LPC analysis and synthesis + +```python +import diffsptk + +fl = 400 # Frame length. +fp = 80 # Frame period. +M = 24 # LPC dimensions. + +# Read waveform. +x, sr = diffsptk.read("assets/data.wav") + +# Estimate LPC of x. +frame = diffsptk.Frame(frame_length=fl, frame_period=fp) +window = diffsptk.Window(in_length=fl) +lpc = diffsptk.LPC(frame_length=fl, lpc_order=M, eps=1e-6) +a = lpc(window(frame(x))) + +# Convert to inverse filter coefficients. +norm0 = diffsptk.AllPoleToAllZeroDigitalFilterCoefficients(filter_order=M) +b = norm0(a) + +# Reconstruct x. +zerodf = diffsptk.AllZeroDigitalFilter(filter_order=M, frame_period=fp) +poledf = diffsptk.AllPoleDigitalFilter(filter_order=M, frame_period=fp) +x_hat = poledf(zerodf(x, b), a) + +# Write reconstructed waveform. +diffsptk.write("reconst.wav", x_hat, sr) + +# Compute error. +error = (x_hat - x).abs().sum() +print(error) +``` + ### Mel-spectrogram, MFCC, and PLP extraction ```python @@ -108,7 +143,7 @@ x, sr = diffsptk.read("assets/data.wav") stft = diffsptk.STFT(frame_length=fl, frame_period=fp, fft_length=n_fft) X = stft(x) -# Extract mel-spectrogram. +# Extract log mel-spectrogram. fbank = diffsptk.FBANK( n_channel=n_channel, fft_length=n_fft, @@ -196,6 +231,32 @@ error = (x_hat - x).abs().sum() print(error) ``` +### Modified discrete cosine transform + +```python +import diffsptk + +fl = 512 # Frame length. + +# Read waveform. +x, sr = diffsptk.read("assets/data.wav") + +# Transform x. +mdct = diffsptk.MDCT(fl) +c = mdct(x) + +# Reconstruct x. +imdct = diffpstk.IMDCT(fl) +x_hat = imdct(c, out_length=x.size(0)) + +# Write reconstructed waveform. +diffsptk.write("reconst.wav", x_hat, sr) + +# Compute error. +error = (x_hat - x).abs().sum() +print(error) +``` + ### Vector quantization ```python diff --git a/diffsptk/functional.py b/diffsptk/functional.py index 9062a259..150b6dd0 100644 --- a/diffsptk/functional.py +++ b/diffsptk/functional.py @@ -1005,7 +1005,7 @@ def lar2par(g): return nn.LogAreaRatioToParcorCoefficients._func(g) -def levdur(r): +def levdur(r, eps=1e-6): """Solve a Yule-Walker linear system. Parameters @@ -1013,13 +1013,16 @@ def levdur(r): r : Tensor [shape=(..., M+1)] Autocorrelation. + eps : float >= 0 + A small value to improve numerical stability. + Returns ------- out : Tensor [shape=(..., M+1)] Gain and LPC coefficients. """ - return nn.LevinsonDurbin._func(r) + return nn.LevinsonDurbin._func(r, eps=eps) def linear_intpl(x, upsampling_factor=80): @@ -1042,7 +1045,7 @@ def linear_intpl(x, upsampling_factor=80): return nn.LinearInterpolation._func(x, upsampling_factor=upsampling_factor) -def lpc(x, lpc_order): +def lpc(x, lpc_order, eps=1e-6): """Compute LPC coefficients. Parameters @@ -1053,13 +1056,16 @@ def lpc(x, lpc_order): lpc_order : int >= 0 Order of LPC, :math:`M`. + eps : float >= 0 + A small value to improve numerical stability. + Returns ------- out : Tensor [shape=(..., M+1)] Gain and LPC coefficients. """ - return nn.LinearPredictiveCodingAnalysis._func(x, lpc_order=lpc_order) + return nn.LinearPredictiveCodingAnalysis._func(x, lpc_order=lpc_order, eps=eps) def lpc2lsp(a, log_gain=False, sample_rate=None, out_format="radian"): diff --git a/diffsptk/modules/acorr.py b/diffsptk/modules/acorr.py index fa9814cd..dc15e6ce 100644 --- a/diffsptk/modules/acorr.py +++ b/diffsptk/modules/acorr.py @@ -84,7 +84,7 @@ def _forward(x, acr_order, norm, const): fft_length = x.size(-1) + acr_order if fft_length % 2 == 1: fft_length += 1 - X = torch.square(torch.fft.rfft(x, n=fft_length).abs()) + X = torch.fft.rfft(x, n=fft_length).abs().square() r = torch.fft.irfft(X)[..., : acr_order + 1] * const if norm: r = r / r[..., :1] diff --git a/diffsptk/modules/acr2csm.py b/diffsptk/modules/acr2csm.py index 8395fa92..9fb27aa1 100644 --- a/diffsptk/modules/acr2csm.py +++ b/diffsptk/modules/acr2csm.py @@ -76,6 +76,7 @@ def forward(self, r): @staticmethod def _forward(r, C): + assert r.dtype == torch.double u = torch.matmul(r, C) u1, u2 = torch.tensor_split(u, 2, dim=-1) diff --git a/diffsptk/modules/levdur.py b/diffsptk/modules/levdur.py index f29a6fe0..dcc08ef6 100644 --- a/diffsptk/modules/levdur.py +++ b/diffsptk/modules/levdur.py @@ -30,14 +30,18 @@ class LevinsonDurbin(nn.Module): lpc_order : int >= 0 Order of LPC coefficients, :math:`M`. + eps : float >= 0 + A small value to improve numerical stability. + """ - def __init__(self, lpc_order): + def __init__(self, lpc_order, eps=0): super().__init__() assert 0 <= lpc_order self.lpc_order = lpc_order + self.register_buffer("eye", self._precompute(self.lpc_order, eps)) def forward(self, r): """Solve a Yule-Walker linear system. @@ -65,14 +69,14 @@ def forward(self, r): """ check_size(r.size(-1), self.lpc_order + 1, "dimension of autocorrelation") - return self._forward(r) + return self._forward(r, self.eye) @staticmethod - def _forward(r): + def _forward(r, eye): r0, r1 = torch.split(r, [1, r.size(-1) - 1], dim=-1) # Make Toeplitz matrix. - R = symmetric_toeplitz(r[..., :-1]) # [..., M, M] + R = symmetric_toeplitz(r[..., :-1]) + eye # [..., M, M] # Solve system. a = torch.matmul(R.inverse(), -r1.unsqueeze(-1)).squeeze(-1) @@ -83,4 +87,13 @@ def _forward(r): a = torch.cat((K, a), dim=-1) return a - _func = _forward + @staticmethod + def _func(r, eps): + eye = LevinsonDurbin._precompute( + r.size(-1) - 1, eps, dtype=r.dtype, device=r.device + ) + return LevinsonDurbin._forward(r, eye) + + @staticmethod + def _precompute(order, eps, dtype=None, device=None): + return torch.eye(order, dtype=dtype, device=device) * eps diff --git a/diffsptk/modules/lpc.py b/diffsptk/modules/lpc.py index 300a867b..a72c8057 100644 --- a/diffsptk/modules/lpc.py +++ b/diffsptk/modules/lpc.py @@ -22,7 +22,7 @@ class LinearPredictiveCodingAnalysis(nn.Module): """See `this page `_ - for details. This module is a simple cascade of acorr and levdur. + for details. Double precision is recommended. Parameters ---------- @@ -32,14 +32,17 @@ class LinearPredictiveCodingAnalysis(nn.Module): lpc_order : int >= 0 Order of LPC, :math:`M`. + eps : float >= 0 + A small value to improve numerical stability. + """ - def __init__(self, frame_length, lpc_order): + def __init__(self, frame_length, lpc_order, eps=1e-6): super().__init__() self.lpc = nn.Sequential( Autocorrelation(frame_length, lpc_order), - LevinsonDurbin(lpc_order), + LevinsonDurbin(lpc_order, eps=eps), ) def forward(self, x): @@ -59,7 +62,7 @@ def forward(self, x): -------- >>> x = diffsptk.nrand(4) tensor([ 0.8226, -0.0284, -0.5715, 0.2127, 0.1217]) - >>> lpc = diffsptk.LPC(2, 5) + >>> lpc = diffsptk.LPC(5, 2) >>> a = lpc(x) >>> a tensor([0.8726, 0.1475, 0.5270]) @@ -68,7 +71,7 @@ def forward(self, x): return self.lpc(x) @staticmethod - def _func(x, lpc_order): + def _func(x, lpc_order, eps): r = Autocorrelation._func(x, lpc_order) - a = LevinsonDurbin._func(r) + a = LevinsonDurbin._func(r, eps) return a diff --git a/tests/test_acr2csm.py b/tests/test_acr2csm.py index 29575a64..357331c4 100644 --- a/tests/test_acr2csm.py +++ b/tests/test_acr2csm.py @@ -15,6 +15,7 @@ # ------------------------------------------------------------------------ # import pytest +import torch import diffsptk import tests.utils as U @@ -23,6 +24,9 @@ @pytest.mark.parametrize("device", ["cpu", "cuda"]) @pytest.mark.parametrize("module", [False, True]) def test_compatibility(device, module, M=25, L=100, B=2): + if torch.get_default_dtype() != torch.double: # pragma: no cover + return + acr2csm = U.choice( module, diffsptk.AutocorrelationToCompositeSinusoidalModelCoefficients, diff --git a/tests/test_csm2acr.py b/tests/test_csm2acr.py index 0675eb0a..8f3b5cd1 100644 --- a/tests/test_csm2acr.py +++ b/tests/test_csm2acr.py @@ -15,6 +15,7 @@ # ------------------------------------------------------------------------ # import pytest +import torch import diffsptk import tests.utils as U @@ -23,6 +24,9 @@ @pytest.mark.parametrize("device", ["cpu", "cuda"]) @pytest.mark.parametrize("module", [False, True]) def test_compatibility(device, module, M=25, L=100, B=2): + if torch.get_default_dtype() != torch.double: # pragma: no cover + return + csm2acr = U.choice( module, diffsptk.CompositeSinusoidalModelCoefficientsToAutocorrelation, diff --git a/tests/test_imdct.py b/tests/test_imdct.py index d9f7267f..54a0e9f1 100644 --- a/tests/test_imdct.py +++ b/tests/test_imdct.py @@ -15,6 +15,7 @@ # ------------------------------------------------------------------------ # import pytest +import torch import diffsptk import tests.utils as U @@ -34,9 +35,10 @@ def test_compatibility(device, module, window, L=512): mdct_params, ) + # torch.round is for float precision. U.check_compatibility( device, - [imdct, mdct], + [torch.round, imdct, mdct], [], "x2x +sd tools/SPTK/asset/data.short", "sopr", diff --git a/tests/test_istft.py b/tests/test_istft.py index b50cb389..1f7e51a1 100644 --- a/tests/test_istft.py +++ b/tests/test_istft.py @@ -17,6 +17,7 @@ from operator import itemgetter import pytest +import torch import diffsptk import tests.utils as U @@ -41,9 +42,10 @@ def test_compatibility(device, module, T=19200): stft_params, ) + # torch.round is for float precision. U.check_compatibility( device, - [itemgetter(slice(0, T)), istft, stft], + [torch.round, itemgetter(slice(0, T)), istft, stft], [], "x2x +sd tools/SPTK/asset/data.short", "sopr", diff --git a/tests/test_lpc.py b/tests/test_lpc.py index 2f29f90f..68f71a8e 100644 --- a/tests/test_lpc.py +++ b/tests/test_lpc.py @@ -20,8 +20,9 @@ import tests.utils as U +@pytest.mark.parametrize("device", ["cpu", "cuda"]) @pytest.mark.parametrize("module", [False, True]) -def test_compatibility(module, M=14, L=30, B=2): +def test_compatibility(device, module, M=14, L=30, B=2): lpc = U.choice( module, diffsptk.LPC, @@ -31,7 +32,7 @@ def test_compatibility(module, M=14, L=30, B=2): ) U.check_compatibility( - "cpu", + device, lpc, [], f"nrand -l {B*L}", @@ -40,3 +41,5 @@ def test_compatibility(module, M=14, L=30, B=2): dx=L, dy=M + 1, ) + + U.check_differentiability(device, lpc, [B, L]) diff --git a/tests/test_poledf.py b/tests/test_poledf.py index dba937e8..7189e660 100644 --- a/tests/test_poledf.py +++ b/tests/test_poledf.py @@ -14,6 +14,8 @@ # limitations under the License. # # ------------------------------------------------------------------------ # +import os + import numpy as np import pytest @@ -24,7 +26,7 @@ @pytest.mark.parametrize("device", ["cpu", "cuda"]) @pytest.mark.parametrize("module", [False, True]) @pytest.mark.parametrize("ignore_gain", [False, True]) -def test_compatibility(device, module, ignore_gain, M=3, T=100, P=10): +def test_compatibility(device, module, ignore_gain, M=24, P=80, L=400): poledf = U.choice( module, diffsptk.AllPoleDigitalFilter, @@ -36,11 +38,19 @@ def test_compatibility(device, module, ignore_gain, M=3, T=100, P=10): tmp1 = "poledf.tmp1" tmp2 = "poledf.tmp2" + T = os.path.getsize("tools/SPTK/asset/data.short") // 2 + cmd1 = f"nrand -l {T} > {tmp1}" + cmd2 = ( + f"x2x +sd tools/SPTK/asset/data.short | " + f"frame -p {P} -l {L} | " + f"window -w 1 -n 1 -l {L} | " + f"lpc -m {M} -l {L} > {tmp2}" + ) opt = "-k" if ignore_gain else "" U.check_compatibility( device, poledf, - [f"nrand -l {T} > {tmp1}", f"nrand -l {T//P*(M+1)} > {tmp2}"], + [cmd1, cmd2], [f"cat {tmp1}", f"cat {tmp2}"], f"poledf {tmp2} < {tmp1} -m {M} -p {P} {opt}", [f"rm {tmp1} {tmp2}"], @@ -48,4 +58,4 @@ def test_compatibility(device, module, ignore_gain, M=3, T=100, P=10): eq=lambda a, b: np.corrcoef(a, b)[0, 1] > 0.99, ) - U.check_differentiability(device, poledf, [(T,), (T // P, M + 1)]) + U.check_differentiability(device, poledf, [(P,), (1, M + 1)]) diff --git a/tools/Makefile b/tools/Makefile index adaaf1a7..ee94d007 100644 --- a/tools/Makefile +++ b/tools/Makefile @@ -14,7 +14,7 @@ # limitations under the License. # # ------------------------------------------------------------------------ # -TAPLO_VERSION := 0.8.1 +TAPLO_VERSION := 0.9.2 YAMLFMT_VERSION := 0.13.0 all: SPTK taplo yamlfmt