Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions diffsptk/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,13 @@ def entropy(p, out_format="nat"):
return nn.Entropy._func(p, out_format=out_format)


def excite(p, frame_period=80, voiced_region="pulse", unvoiced_region="gauss"):
def excite(
p,
frame_period=80,
voiced_region="pulse",
unvoiced_region="gauss",
polarity="auto",
):
"""Generate a simple excitation signal.

Parameters
Expand All @@ -523,12 +529,16 @@ def excite(p, frame_period=80, voiced_region="pulse", unvoiced_region="gauss"):
frame_period : int >= 1
Frame period in samples, :math:`P`.

voiced_region : ['pulse', 'sinusoidal', 'sawtooth']
voiced_region : ['pulse', 'sinusoidal', 'sawtooth', 'inverted-sawtooth', 'triangle',
'square']
Value on voiced region.

unvoiced_region : ['gauss', 'zeros']
Value on unvoiced region.

polarity : ['auto', 'unipolar', 'bipolar']
Polarity.

Returns
-------
out : Tensor [shape=(..., NxP)]
Expand All @@ -540,6 +550,7 @@ def excite(p, frame_period=80, voiced_region="pulse", unvoiced_region="gauss"):
frame_period=frame_period,
voiced_region=voiced_region,
unvoiced_region=unvoiced_region,
polarity=polarity,
)


Expand Down
69 changes: 60 additions & 9 deletions diffsptk/modules/excite.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,24 +32,43 @@ class ExcitationGeneration(nn.Module):
frame_period : int >= 1
Frame period in samples, :math:`P`.

voiced_region : ['pulse', 'sinusoidal', 'sawtooth']
voiced_region : ['pulse', 'sinusoidal', 'sawtooth', 'inverted-sawtooth', 'triangle',
'square']
Value on voiced region.

unvoiced_region : ['gauss', 'zeros']
Value on unvoiced region.

polarity : ['auto', 'unipolar', 'bipolar']
Polarity.

"""

def __init__(self, frame_period, voiced_region="pulse", unvoiced_region="gauss"):
def __init__(
self,
frame_period,
voiced_region="pulse",
unvoiced_region="gauss",
polarity="auto",
):
super().__init__()

assert 1 <= frame_period
assert voiced_region in ("pulse", "sinusoidal", "sawtooth")
assert voiced_region in (
"pulse",
"sinusoidal",
"sawtooth",
"inverted-sawtooth",
"triangle",
"square",
)
assert unvoiced_region in ("gauss", "zeros")
assert polarity in ("auto", "unipolar", "bipolar")

self.frame_period = frame_period
self.voiced_region = voiced_region
self.unvoiced_region = unvoiced_region
self.polarity = polarity

def forward(self, p):
"""Generate a simple excitation signal.
Expand All @@ -74,11 +93,15 @@ def forward(self, p):

"""
return self._forward(
p, self.frame_period, self.voiced_region, self.unvoiced_region
p,
self.frame_period,
self.voiced_region,
self.unvoiced_region,
self.polarity,
)

@staticmethod
def _forward(p, frame_period, voiced_region, unvoiced_region):
def _forward(p, frame_period, voiced_region, unvoiced_region, polarity):
# Make mask represents voiced region.
base_mask = torch.clip(p, min=0, max=1)
mask = torch.ne(base_mask, UNVOICED_SYMBOL)
Expand Down Expand Up @@ -106,16 +129,44 @@ def _forward(p, frame_period, voiced_region, unvoiced_region):
phase = (s - bias).to(p.dtype)

# Generate excitation signal using phase.
if polarity == "auto":
unipolar = voiced_region == "pulse"
else:
unipolar = polarity == "unipolar"
e = torch.zeros_like(p)
if voiced_region == "pulse":
r = torch.ceil(phase)
r = F.pad(r, (1, 0))
pulse_pos = torch.ge(torch.diff(r), 1)
e = torch.zeros_like(p)
e[pulse_pos] = torch.sqrt(p[pulse_pos])
if unipolar:
e[pulse_pos] = torch.sqrt(p[pulse_pos])
else:
raise RuntimeError
elif voiced_region == "sinusoidal":
e = torch.sin(TWO_PI * phase)
if unipolar:
e[mask] = 0.5 * (1 - torch.cos(TWO_PI * phase[mask]))
else:
e[mask] = torch.sin(TWO_PI * phase[mask])
elif voiced_region == "sawtooth":
e = torch.fmod(phase, 2) - 1
if unipolar:
e[mask] = torch.fmod(phase[mask], 1)
else:
e[mask] = 2 * torch.fmod(phase[mask], 1) - 1
elif voiced_region == "inverted-sawtooth":
if unipolar:
e[mask] = 1 - torch.fmod(phase[mask], 1)
else:
e[mask] = 1 - 2 * torch.fmod(phase[mask], 1)
elif voiced_region == "triangle":
if unipolar:
e[mask] = torch.abs(2 * torch.fmod(phase[mask] + 0.5, 1) - 1)
else:
e[mask] = 2 * torch.abs(2 * torch.fmod(phase[mask] + 1.75, 1) - 1) - 1
elif voiced_region == "square":
if unipolar:
e[mask] = torch.le(torch.fmod(phase[mask], 1), 0.5).to(e.dtype)
else:
e[mask] = 2 * torch.le(torch.fmod(phase[mask], 1), 0.5).to(e.dtype) - 1
else:
raise ValueError(f"voiced_region {voiced_region} is not supported.")

Expand Down
15 changes: 11 additions & 4 deletions tests/test_excite.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,21 @@ def compute_error(infile):
U.call("rm -f excite.tmp?", get=False)


@pytest.mark.parametrize("voiced_region", ["pulse", "sinusoidal", "sawtooth"])
def test_waveform(voiced_region, P=80, verbose=False):
@pytest.mark.parametrize(
"voiced_region",
["pulse", "sinusoidal", "sawtooth", "inverted-sawtooth", "triangle", "square"],
)
@pytest.mark.parametrize("polarity", ["unipolar", "bipolar"])
def test_waveform(voiced_region, polarity, P=80, verbose=False):
if voiced_region == "pulse" and polarity == "bipolar":
return

excite = diffsptk.ExcitationGeneration(
P, voiced_region=voiced_region, unvoiced_region="zeros"
P, voiced_region=voiced_region, unvoiced_region="zeros", polarity=polarity
)
pitch = torch.from_numpy(
U.call("x2x +sd tools/SPTK/asset/data.short | " f"pitch -s 16 -p {P} -o 0 -a 2")
)
e = excite(pitch)
if verbose:
sf.write(f"excite_{voiced_region}.wav", e, 16000)
sf.write(f"excite_{voiced_region}_{polarity}.wav", e, 16000)