Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions diffsptk/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -938,6 +938,23 @@ def interpolate(x, period=1, start=0, dim=-1):
return nn.Interpolation._func(x, period=period, start=start, dim=dim)


def ipnorm(y):
    """Undo cepstrum power normalization.

    Functional counterpart of ``MelCepstrumInversePowerNormalization``; the
    call is forwarded unchanged to that module's ``_func``.

    Parameters
    ----------
    y : Tensor [shape=(..., M+2)]
        Power-normalized cepstrum.

    Returns
    -------
    out : Tensor [shape=(..., M+1)]
        Output cepstrum.

    """
    inverse_normalize = nn.MelCepstrumInversePowerNormalization._func
    return inverse_normalize(y)


def istft(
y,
*,
Expand Down Expand Up @@ -1672,6 +1689,29 @@ def phase(b=None, a=None, *, fft_length=512, unwrap=False):
return nn.Phase._func(b, a, fft_length=fft_length, unwrap=unwrap)


def pnorm(x, alpha=0, ir_length=128):
    """Apply power normalization to a cepstrum.

    Functional counterpart of ``MelCepstrumPowerNormalization``; the
    arguments are forwarded by keyword to that module's ``_func``.

    Parameters
    ----------
    x : Tensor [shape=(..., M+1)]
        Input cepstrum.

    alpha : float in (-1, 1)
        Frequency warping factor, :math:`\\alpha`.

    ir_length : int >= 1
        Length of impulse response.

    Returns
    -------
    out : Tensor [shape=(..., M+2)]
        Power-normalized cepstrum.

    """
    normalize = nn.MelCepstrumPowerNormalization._func
    return normalize(x, alpha=alpha, ir_length=ir_length)


def pol_root(x, real=False):
"""Compute polynomial coefficients from roots.

Expand Down
2 changes: 2 additions & 0 deletions diffsptk/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from .imglsadf import PseudoInverseMGLSADigitalFilter as IMLSA
from .imsvq import InverseMultiStageVectorQuantization
from .interpolate import Interpolation
from .ipnorm import MelCepstrumInversePowerNormalization
from .ipqmf import InversePseudoQuadratureMirrorFilterBanks
from .ipqmf import InversePseudoQuadratureMirrorFilterBanks as IPQMF
from .istft import InverseShortTimeFourierTransform
Expand Down Expand Up @@ -99,6 +100,7 @@
from .pitch import Pitch
from .plp import PerceptualLinearPredictiveCoefficientsAnalysis
from .plp import PerceptualLinearPredictiveCoefficientsAnalysis as PLP
from .pnorm import MelCepstrumPowerNormalization
from .pol_root import RootsToPolynomial
from .poledf import AllPoleDigitalFilter
from .pqmf import PseudoQuadratureMirrorFilterBanks
Expand Down
71 changes: 71 additions & 0 deletions diffsptk/modules/ipnorm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group #
# #
# Licensed under the Apache License, Version 2.0 (the "License"); #
# you may not use this file except in compliance with the License. #
# You may obtain a copy of the License at #
# #
# http://www.apache.org/licenses/LICENSE-2.0 #
# #
# Unless required by applicable law or agreed to in writing, software #
# distributed under the License is distributed on an "AS IS" BASIS, #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and #
# limitations under the License. #
# ------------------------------------------------------------------------ #

import torch
from torch import nn

from ..misc.utils import check_size


class MelCepstrumInversePowerNormalization(nn.Module):
    """See `this page <https://sp-nitech.github.io/sptk/latest/main/ipnorm.html>`_
    for details.

    The input carries a leading power term followed by the normalized
    cepstrum; the module folds half of the power term back into the 0th
    cepstral coefficient and drops the power term from the output.

    Parameters
    ----------
    cep_order : int >= 0
        Order of cepstrum, :math:`M`.

    """

    def __init__(self, cep_order):
        super().__init__()

        # Fail fast: a negative order can never match a valid input, and
        # would otherwise only surface later as an obscure split-size error.
        assert 0 <= cep_order

        self.cep_order = cep_order

    def forward(self, y):
        """Perform cepstrum inverse power normalization.

        Parameters
        ----------
        y : Tensor [shape=(..., M+2)]
            Power-normalized cepstrum.

        Returns
        -------
        out : Tensor [shape=(..., M+1)]
            Output cepstrum.

        Examples
        --------
        >>> x = diffsptk.ramp(1, 4)
        >>> pnorm = diffsptk.MelCepstrumPowerNormalization(3, alpha=0.1)
        >>> ipnorm = diffsptk.MelCepstrumInversePowerNormalization(3)
        >>> y = ipnorm(pnorm(x))
        >>> y
        tensor([1., 2., 3., 4.])

        """
        # NOTE(review): check_size is presumably what rejects a last
        # dimension other than M+2 -- confirm against misc.utils.
        check_size(y.size(-1), self.cep_order + 2, "dimension of cepstrum")
        return self._forward(y)

    @staticmethod
    def _forward(y):
        # Layout of the last axis: [P, y(0), y(1), ..., y(M)].
        P, y1, y2 = torch.split(y, [1, 1, y.size(-1) - 2], dim=-1)
        # Restore the 0th coefficient by adding back half of the power term.
        x = torch.cat((0.5 * P + y1, y2), dim=-1)
        return x

    # Stateless entry point used by the functional interface; it skips the
    # size check because no cepstrum order is known without an instance.
    _func = _forward
6 changes: 3 additions & 3 deletions diffsptk/modules/mcpf.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(self, cep_order, alpha=0, beta=0, onset=2, ir_length=128):

assert 0 <= onset

self.mc2en = nn.Sequential(
self.mc2pow = nn.Sequential(
FrequencyTransform(cep_order, ir_length - 1, -alpha),
CepstrumToAutocorrelation(ir_length - 1, 0, ir_length),
)
Expand Down Expand Up @@ -92,10 +92,10 @@ def forward(self, mc):

"""
mc1 = mc
e1 = self.mc2en(mc1)
e1 = self.mc2pow(mc1)

mc2 = mc * self.weight
e2 = self.mc2en(mc2)
e2 = self.mc2pow(mc2)

b2 = self.mc2b(mc2)
b2[..., :1] += 0.5 * torch.log(e1 / e2)
Expand Down
87 changes: 87 additions & 0 deletions diffsptk/modules/pnorm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group #
# #
# Licensed under the Apache License, Version 2.0 (the "License"); #
# you may not use this file except in compliance with the License. #
# You may obtain a copy of the License at #
# #
# http://www.apache.org/licenses/LICENSE-2.0 #
# #
# Unless required by applicable law or agreed to in writing, software #
# distributed under the License is distributed on an "AS IS" BASIS, #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and #
# limitations under the License. #
# ------------------------------------------------------------------------ #

import torch
from torch import nn

from .c2acr import CepstrumToAutocorrelation
from .freqt import FrequencyTransform


class MelCepstrumPowerNormalization(nn.Module):
    """See `this page <https://sp-nitech.github.io/sptk/latest/main/pnorm.html>`_
    for details.

    Parameters
    ----------
    cep_order : int >= 0
        Order of cepstrum, :math:`M`.

    alpha : float in (-1, 1)
        Frequency warping factor, :math:`\\alpha`.

    ir_length : int >= 1
        Length of impulse response.

    """

    def __init__(self, cep_order, alpha=0, ir_length=128):
        super().__init__()

        # Two-stage pipeline: frequency transform with the negated warping
        # factor, then cepstrum-to-autocorrelation keeping only lag 0.
        unwarp = FrequencyTransform(cep_order, ir_length - 1, -alpha)
        to_acr = CepstrumToAutocorrelation(ir_length - 1, 0, ir_length)
        self.mc2pow = nn.Sequential(unwarp, to_acr)

    def forward(self, x):
        """Perform cepstrum power normalization.

        Parameters
        ----------
        x : Tensor [shape=(..., M+1)]
            Input cepstrum.

        Returns
        -------
        out : Tensor [shape=(..., M+2)]
            Power-normalized cepstrum.

        Examples
        --------
        >>> x = diffsptk.ramp(1, 4)
        >>> pnorm = diffsptk.MelCepstrumPowerNormalization(3, alpha=0.1)
        >>> y = pnorm(x)
        >>> y
        tensor([ 8.2942, -7.2942, 2.0000, 3.0000, 4.0000])

        """
        # NOTE(review): in the example above the second element equals
        # x[0] - y[0], while the code below emits x[0] - 0.5 * y[0];
        # the printed values should be re-checked against the implementation.
        return self._forward(x, self.mc2pow)

    @staticmethod
    def _forward(x, mc2pow):
        # Separate the 0th coefficient from the remaining ones.
        head, tail = torch.split(x, [1, x.size(-1) - 1], dim=-1)
        log_power = torch.log(mc2pow(x))
        # Output layout on the last axis: [P, x(0) - P / 2, x(1), ..., x(M)].
        pieces = (log_power, head - 0.5 * log_power, tail)
        return torch.cat(pieces, dim=-1)

    @staticmethod
    def _func(x, alpha, ir_length):
        # Stateless equivalent of the nn.Sequential built in __init__.
        def power_of(mc):
            unwarped = FrequencyTransform._func(mc, ir_length - 1, -alpha)
            return CepstrumToAutocorrelation._func(unwarped, 0, ir_length)

        return MelCepstrumPowerNormalization._forward(x, power_of)
13 changes: 13 additions & 0 deletions docs/modules/ipnorm.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
.. _ipnorm:

ipnorm
======

.. autoclass:: diffsptk.MelCepstrumInversePowerNormalization
:members:

.. autofunction:: diffsptk.functional.ipnorm

.. seealso::

:ref:`pnorm`
13 changes: 13 additions & 0 deletions docs/modules/pnorm.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
.. _pnorm:

pnorm
=====

.. autoclass:: diffsptk.MelCepstrumPowerNormalization
:members:

.. autofunction:: diffsptk.functional.pnorm

.. seealso::

:ref:`ipnorm`
44 changes: 44 additions & 0 deletions tests/test_ipnorm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group #
# #
# Licensed under the Apache License, Version 2.0 (the "License"); #
# you may not use this file except in compliance with the License. #
# You may obtain a copy of the License at #
# #
# http://www.apache.org/licenses/LICENSE-2.0 #
# #
# Unless required by applicable law or agreed to in writing, software #
# distributed under the License is distributed on an "AS IS" BASIS, #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and #
# limitations under the License. #
# ------------------------------------------------------------------------ #

import pytest

import diffsptk
import tests.utils as U


@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("module", [False, True])
def test_compatibility(device, module, M=4, B=2):
    """Check ipnorm against the ``ipnorm`` command and its differentiability.

    Runs for every combination of device ("cpu", "cuda") and interface
    flag ``module``. The callable under test is obtained from ``U.choice``
    given the module class, the functional form, and the constructor
    arguments (presumably selected by ``module`` -- see tests.utils).
    """
    in_dim = M + 2  # power term + (M+1) cepstral coefficients
    out_dim = M + 1

    ipnorm = U.choice(
        module,
        diffsptk.MelCepstrumInversePowerNormalization,
        diffsptk.functional.ipnorm,
        {"cep_order": M},
    )

    source_cmd = f"nrand -l {B*(M+2)}"
    target_cmd = f"ipnorm -m {M}"
    U.check_compatibility(
        device,
        ipnorm,
        [],
        source_cmd,
        target_cmd,
        [],
        dx=in_dim,
        dy=out_dim,
    )

    U.check_differentiability(device, ipnorm, [B, in_dim])
45 changes: 45 additions & 0 deletions tests/test_pnorm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group #
# #
# Licensed under the Apache License, Version 2.0 (the "License"); #
# you may not use this file except in compliance with the License. #
# You may obtain a copy of the License at #
# #
# http://www.apache.org/licenses/LICENSE-2.0 #
# #
# Unless required by applicable law or agreed to in writing, software #
# distributed under the License is distributed on an "AS IS" BASIS, #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and #
# limitations under the License. #
# ------------------------------------------------------------------------ #

import pytest

import diffsptk
import tests.utils as U


@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("module", [False, True])
def test_compatibility(device, module, M=9, alpha=0.1, L=64, B=2):
    """Check pnorm against the ``pnorm`` command and its differentiability.

    Runs for every combination of device ("cpu", "cuda") and interface
    flag ``module``. The callable under test is obtained from ``U.choice``
    given the module class, the functional form, and two argument dicts
    (presumably selected by ``module`` -- see tests.utils).
    """
    in_dim = M + 1
    out_dim = M + 2  # power term is prepended to the cepstrum

    pnorm = U.choice(
        module,
        diffsptk.MelCepstrumPowerNormalization,
        diffsptk.functional.pnorm,
        {"cep_order": M},
        {"alpha": alpha, "ir_length": L},
    )

    source_cmd = f"nrand -l {B*(M+1)}"
    target_cmd = f"pnorm -m {M} -a {alpha} -l {L}"
    U.check_compatibility(
        device,
        pnorm,
        [],
        source_cmd,
        target_cmd,
        [],
        dx=in_dim,
        dy=out_dim,
    )

    U.check_differentiability(device, pnorm, [B, in_dim])