Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ format: tool

test: tool
[ -n "$(MODULE)" ] && module=tests/test_$(MODULE).py || module=; \
. ./venv/bin/activate && export PATH=tools/SPTK/bin:$$PATH NUMBA_CUDA_LOW_OCCUPANCY_WARNINGS=0 && \
. ./venv/bin/activate && export PATH=./tools/SPTK/bin:$$PATH CUDA_HOME=/usr/local/cuda-11.8 NUMBA_CUDA_LOW_OCCUPANCY_WARNINGS=0 && \
python -m pytest $$module $(OPT)

test-clean:
Expand Down
48 changes: 48 additions & 0 deletions diffsptk/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,54 @@ def grpdelay(b=None, a=None, *, fft_length=512, alpha=1, gamma=1, **kwargs):
)


def hilbert(x, fft_length=None, dim=-1):
    """Compute the analytic signal of the input using the Hilbert transform.

    This is the functional counterpart of the ``HilbertTransform`` module; the
    computation is delegated to ``HilbertTransform._func``.

    Parameters
    ----------
    x : Tensor [shape=(..., T, ...)]
        Input signal.

    fft_length : int >= 1 or None
        Number of FFT bins. If None, set to :math:`T`.

    dim : int
        Dimension along which to take the Hilbert transform.

    Returns
    -------
    out : Tensor [shape=(..., T, ...)]
        Analytic signal, whose real part is the input signal and whose imaginary
        part is the Hilbert transform of the input signal.

    """
    analytic = nn.HilbertTransform._func(x, fft_length=fft_length, dim=dim)
    return analytic


def hilbert2(x, fft_length=None, dim=(-2, -1)):
    """Compute the analytic signal of the input using the 2-D Hilbert transform.

    This is the functional counterpart of the ``TwoDimensionalHilbertTransform``
    module; the computation is delegated to
    ``TwoDimensionalHilbertTransform._func``.

    Parameters
    ----------
    x : Tensor [shape=(..., T1, T2, ...)]
        Input signal.

    fft_length : int, list[int], or None
        Number of FFT bins. If None, set to (:math:`T1`, :math:`T2`).

    dim : list[int]
        Dimensions along which to take the Hilbert transform.

    Returns
    -------
    out : Tensor [shape=(..., T1, T2, ...)]
        Analytic signal, whose real part is the input signal and whose imaginary
        part is the Hilbert transform of the input signal.

    """
    analytic = nn.TwoDimensionalHilbertTransform._func(
        x, fft_length=fft_length, dim=dim
    )
    return analytic


def histogram(x, n_bin=10, lower_bound=0, upper_bound=1, norm=False, softness=1e-3):
"""Compute histogram.

Expand Down
2 changes: 2 additions & 0 deletions diffsptk/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
from .gmm import GaussianMixtureModeling as GMM
from .gnorm import GeneralizedCepstrumGainNormalization
from .grpdelay import GroupDelay
from .hilbert import HilbertTransform
from .hilbert2 import TwoDimensionalHilbertTransform
from .histogram import Histogram
from .ialaw import ALawExpansion
from .icqt import InverseConstantQTransform
Expand Down
3 changes: 0 additions & 3 deletions diffsptk/modules/decimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,6 @@ def forward(self, x):
x : Tensor [shape=(..., T, ...)]
Signal.

dim : int
Dimension along which to decimate the tensors.

Returns
-------
out : Tensor [shape=(..., T/P-S, ...)]
Expand Down
96 changes: 96 additions & 0 deletions diffsptk/modules/hilbert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group #
# #
# Licensed under the Apache License, Version 2.0 (the "License"); #
# you may not use this file except in compliance with the License. #
# You may obtain a copy of the License at #
# #
# http://www.apache.org/licenses/LICENSE-2.0 #
# #
# Unless required by applicable law or agreed to in writing, software #
# distributed under the License is distributed on an "AS IS" BASIS, #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and #
# limitations under the License. #
# ------------------------------------------------------------------------ #

import torch
from torch import nn

from ..misc.utils import to


class HilbertTransform(nn.Module):
    """Module that computes the analytic signal via the Hilbert transform.

    Parameters
    ----------
    fft_length : int >= 1
        Number of FFT bins, should be :math:`T`.

    dim : int
        Dimension along which to take the Hilbert transform.

    """

    def __init__(self, fft_length, dim=-1):
        super().__init__()

        # The frequency-domain weights depend only on the FFT length, so they
        # are built once and stored as a buffer named "h".
        self.register_buffer("h", self._precompute(fft_length))
        self.dim = dim

    def forward(self, x):
        """Compute the analytic signal of the input using the Hilbert transform.

        Parameters
        ----------
        x : Tensor [shape=(..., T, ...)]
            Input signal.

        Returns
        -------
        out : Tensor [shape=(..., T, ...)]
            Analytic signal, whose real part is the input signal and whose
            imaginary part is the Hilbert transform of the input signal.

        Examples
        --------
        >>> x = diffsptk.nrand(3)
        >>> x
        tensor([ 1.1809, -0.2834, -0.4169,  0.3883])
        >>> hilbert = diffsptk.HilbertTransform(4)
        >>> z = hilbert(x)
        >>> z.real
        tensor([ 1.1809, -0.2834, -0.4169,  0.3883])
        >>> z.imag
        tensor([ 0.3358,  0.7989, -0.3358, -0.7989])

        """
        weights, axis = self.h, self.dim
        return self._forward(x, weights, axis)

    @staticmethod
    def _forward(x, h, dim):
        """Weight the spectrum of ``x`` along ``dim`` by ``h`` and invert it."""
        n_fft = len(h)

        # Give the 1-D weights singleton axes everywhere except `dim` so that
        # they broadcast against the spectrum of x.
        broadcast_shape = [1] * x.dim()
        broadcast_shape[dim] = n_fft
        weights = h.view(*broadcast_shape)

        spectrum = torch.fft.fft(x, n=n_fft, dim=dim)
        return torch.fft.ifft(spectrum * weights, n=n_fft, dim=dim)

    @staticmethod
    def _func(x, fft_length, dim):
        """Functional entry point: build the weights on the fly and apply them."""
        n_fft = x.size(dim) if fft_length is None else fft_length
        weights = HilbertTransform._precompute(
            n_fft, dtype=x.dtype, device=x.device
        )
        return HilbertTransform._forward(x, weights, dim)

    @staticmethod
    def _precompute(fft_length, dtype=None, device=None):
        """Build the frequency-domain weights of length ``fft_length``.

        Bin 0 is kept (weight 1), bins ``1 .. (fft_length + 1) // 2 - 1`` are
        doubled, the middle bin is kept (weight 1) when ``fft_length`` is even,
        and all remaining bins are zeroed.
        """
        is_even = fft_length % 2 == 0
        half = (fft_length + 1) // 2

        # Built in double precision; the final cast is delegated to `to`
        # (project helper, presumably converts to `dtype` -- confirm there).
        h = torch.zeros(fft_length, dtype=torch.double, device=device)
        h[0] = 1
        h[1:half] = 2
        if is_even:
            h[half] = 1
        return to(h, dtype=dtype)
105 changes: 105 additions & 0 deletions diffsptk/modules/hilbert2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group #
# #
# Licensed under the Apache License, Version 2.0 (the "License"); #
# you may not use this file except in compliance with the License. #
# You may obtain a copy of the License at #
# #
# http://www.apache.org/licenses/LICENSE-2.0 #
# #
# Unless required by applicable law or agreed to in writing, software #
# distributed under the License is distributed on an "AS IS" BASIS, #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and #
# limitations under the License. #
# ------------------------------------------------------------------------ #

import torch
from torch import nn

from ..misc.utils import to
from .hilbert import HilbertTransform


class TwoDimensionalHilbertTransform(nn.Module):
    """2-D Hilbert transform module.

    Parameters
    ----------
    fft_length : int or list[int]
        Number of FFT bins. An int is used for both dimensions; otherwise the
        i-th entry is the number of FFT bins along ``dim[i]``.

    dim : list[int]
        Dimensions along which to take the Hilbert transform.

    """

    def __init__(self, fft_length, dim=(-2, -1)):
        super().__init__()

        assert len(dim) == 2

        self.dim = dim
        # The frequency-domain weights depend only on the FFT lengths, so they
        # are built once and stored as a 2-D buffer named "h".
        self.register_buffer("h", self._precompute(fft_length))

    def forward(self, x):
        """Compute analytic signal using the 2-D Hilbert transform.

        Parameters
        ----------
        x : Tensor [shape=(..., T1, T2, ...)]
            Input signal.

        Returns
        -------
        out : Tensor [shape=(..., T1, T2, ...)]
            Analytic signal, where real part is the input signal and imaginary
            part is the Hilbert transform of the input signal.

        Examples
        --------
        >>> x = diffsptk.nrand(3).unsqueeze(0)
        >>> x
        tensor([[ 1.1809, -0.2834, -0.4169,  0.3883]])
        >>> hilbert2 = diffsptk.TwoDimensionalHilbertTransform((1, 4))
        >>> z = hilbert2(x)
        >>> z.real
        tensor([[ 1.1809, -0.2834, -0.4169,  0.3883]])
        >>> z.imag
        tensor([[ 0.3358,  0.7989, -0.3358, -0.7989]])

        """
        return self._forward(x, self.h, self.dim)

    @staticmethod
    def _forward(x, h, dim):
        """Weight the 2-D spectrum of ``x`` over ``dim`` by ``h`` and invert it.

        ``h`` is a 2-D tensor whose axis 0 belongs to ``dim[0]`` and whose
        axis 1 belongs to ``dim[1]``.
        """
        # `h` is always 2-D, so its sizes must be read from its own axes.
        # Indexing it with `dim` (which refers to axes of `x`) fails for
        # positive dims such as (1, 2) and mispairs sizes for reversed dims.
        L = h.size(0), h.size(1)

        target_shape = [1] * x.dim()
        target_shape[dim[0]] = L[0]
        target_shape[dim[1]] = L[1]

        # When dim[0] lies after dim[1] in x, the memory order of the two
        # weight axes must be swapped before broadcasting.
        if dim[0] % x.dim() > dim[1] % x.dim():
            h = h.transpose(0, 1)
        # reshape (not view) because the transposed tensor is non-contiguous.
        h = h.reshape(target_shape)

        X = torch.fft.fft2(x, s=L, dim=dim)
        z = torch.fft.ifft2(X * h, s=L, dim=dim)
        return z

    @staticmethod
    def _func(x, fft_length, dim):
        """Functional entry point: build the weights on the fly and apply them."""
        if fft_length is None:
            fft_length = (x.size(dim[0]), x.size(dim[1]))
        h = TwoDimensionalHilbertTransform._precompute(
            fft_length, dtype=x.dtype, device=x.device
        )
        return TwoDimensionalHilbertTransform._forward(x, h, dim)

    @staticmethod
    def _precompute(fft_length, dtype=None, device=None):
        """Build the 2-D weights as the outer product of two 1-D weight vectors.

        The result has shape ``(fft_length[0], fft_length[1])``; an int
        ``fft_length`` is used for both axes.
        """
        if isinstance(fft_length, int):
            fft_length = (fft_length, fft_length)
        # Both factors are requested in double precision; the final cast is
        # delegated to `to` (project helper, presumably converts to `dtype`).
        h1 = HilbertTransform._precompute(
            fft_length[0], dtype=torch.double, device=device
        )
        h2 = HilbertTransform._precompute(
            fft_length[1], dtype=torch.double, device=device
        )
        h = h1.unsqueeze(1) * h2.unsqueeze(0)
        return to(h, dtype=dtype)
13 changes: 13 additions & 0 deletions docs/modules/hilbert.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
.. _hilbert:

hilbert
=======

.. autoclass:: diffsptk.HilbertTransform
:members:

.. autofunction:: diffsptk.functional.hilbert

.. seealso::

:ref:`hilbert2`
13 changes: 13 additions & 0 deletions docs/modules/hilbert2.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
.. _hilbert2:

hilbert2
========

.. autoclass:: diffsptk.TwoDimensionalHilbertTransform
:members:

.. autofunction:: diffsptk.functional.hilbert2

.. seealso::

:ref:`hilbert`
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ dependencies = [
"soundfile >= 0.10.2",
"torch >= 2.0.0",
"torchaudio >= 2.0.1",
"torchcrepe >= 0.0.21",
"torchcrepe >= 0.0.22",
"torchlpc >= 0.2.0",
"torchcomp >= 0.1.0",
"vector-quantize-pytorch >= 1.14.9",
Expand Down
51 changes: 51 additions & 0 deletions tests/test_hilbert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group #
# #
# Licensed under the Apache License, Version 2.0 (the "License"); #
# you may not use this file except in compliance with the License. #
# You may obtain a copy of the License at #
# #
# http://www.apache.org/licenses/LICENSE-2.0 #
# #
# Unless required by applicable law or agreed to in writing, software #
# distributed under the License is distributed on an "AS IS" BASIS, #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and #
# limitations under the License. #
# ------------------------------------------------------------------------ #

import pytest
from scipy.signal import hilbert as scipy_hilbert

import diffsptk
import tests.utils as U


@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("module", [False, True])
@pytest.mark.parametrize("L", [7, 8, None])
def test_compatibility(device, module, L, B=2):
    """Compare the Hilbert transform against ``scipy.signal.hilbert``.

    Both the module and the functional form are exercised, selected through
    ``U.choice``. The module form is not run with ``L=None``.
    """
    # The module constructor takes fft_length without a default, so only the
    # functional form is run with L=None.
    if L is None and module:
        return

    hilbert = U.choice(
        module,
        diffsptk.HilbertTransform,
        diffsptk.functional.hilbert,
        {"fft_length": L},
    )

    def func(x):
        # Reference implementation. `L` is looked up at call time, i.e. after
        # the fallback assignment below has run.
        return scipy_hilbert(x, N=L)

    # Signal length used when no explicit FFT length is given.
    if L is None:
        L = 8

    U.check_confidence(device, hilbert, func, [B, L])

    # Only the real part is passed on to the differentiability check.
    U.check_differentiability(device, [lambda x: x.real, hilbert], [B, L])
Loading