Misc fixes
r9y9 committed Aug 28, 2017
1 parent 779b96a commit 7a0a991
Showing 10 changed files with 158 additions and 159 deletions.
30 changes: 15 additions & 15 deletions nnmnkwii/autograd/_impl/mlpg.py
@@ -1,7 +1,7 @@
# coding: utf-8
from __future__ import with_statement, print_function, absolute_import

-from nnmnkwii import functions as F
+from nnmnkwii import paramgen as G

from torch.autograd import Function
import torch
@@ -16,8 +16,8 @@ class MLPG(Function):
This is meant to be used for Minimum Generation Error (MGE) training for
speech synthesis and voice conversion. See [1]_ and [2]_ for details.
-It relies on :func:`nnmnkwii.functions.mlpg` and
-:func:`nnmnkwii.functions.mlpg_grad` for forward and backward computation,
+It relies on :func:`nnmnkwii.paramgen.mlpg` and
+:func:`nnmnkwii.paramgen.mlpg_grad` for forward and backward computation,
respectively.
.. [1] Wu, Zhizheng, and Simon King. "Minimum trajectory error training
@@ -29,17 +29,17 @@ class MLPG(Function):
Attributes:
variances (torch.FloatTensor): Variances same as in
-:func:`nnmnkwii.functions.mlpg`.
-windows (list): same as in :func:`nnmnkwii.functions.mlpg`.
+:func:`nnmnkwii.paramgen.mlpg`.
+windows (list): same as in :func:`nnmnkwii.paramgen.mlpg`.
Warnings:
The function is generic but cannot run on CUDA. For faster
differentiable MLPG, see :obj:`UnitVarianceMLPG`.
See also:
:func:`nnmnkwii.autograd.mlpg`,
-:func:`nnmnkwii.functions.mlpg`,
-:func:`nnmnkwii.functions.mlpg_grad`.
+:func:`nnmnkwii.paramgen.mlpg`,
+:func:`nnmnkwii.paramgen.mlpg_grad`.
"""

def __init__(self, variances, windows):
@@ -57,7 +57,7 @@ def forward(self, means):

means_np = means.numpy()
variances_np = variances.numpy()
-y = F.mlpg(means_np, variances_np, self.windows)
+y = G.mlpg(means_np, variances_np, self.windows)
y = torch.from_numpy(y.astype(np.float32))
return y

@@ -70,7 +70,7 @@ def backward(self, grad_output):
grad_output_numpy = grad_output.numpy()
means_numpy = means.numpy()
variances_numpy = variances.numpy()
-grads_numpy = F.mlpg_grad(
+grads_numpy = G.mlpg_grad(
means_numpy, variances_numpy, self.windows,
grad_output_numpy)

@@ -104,12 +104,12 @@ class UnitVarianceMLPG(Function):
To avoid duplicate computations in forward and backward, the function
takes ``R`` at construction time. The matrix ``R`` can be computed by
-:func:`nnmnkwii.functions.unit_variance_mlpg_matrix`.
+:func:`nnmnkwii.paramgen.unit_variance_mlpg_matrix`.
Args:
R: Unit-variance MLPG matrix of shape (``T x num_windows*T``). This
should be created with
-:func:`nnmnkwii.functions.unit_variance_mlpg_matrix`.
+:func:`nnmnkwii.paramgen.unit_variance_mlpg_matrix`.
Attributes:
@@ -184,7 +184,7 @@ def backward(self, grad_output):
def mlpg(means, variances, windows):
"""Maximum Liklihood Paramter Generation (MLPG).
-The parameters are almost same as :func:`nnmnkwii.functions.mlpg` expects.
+The parameters are almost the same as :func:`nnmnkwii.paramgen.mlpg` expects.
The differences are:
- The function assumes ``means`` as :obj:`torch.autograd.Variable`
@@ -198,7 +198,7 @@ def mlpg(means, variances, windows):
windows (list): A sequence of window specifications
See also:
-:obj:`nnmnkwii.autograd.MLPG`, :func:`nnmnkwii.functions.mlpg`
+:obj:`nnmnkwii.autograd.MLPG`, :func:`nnmnkwii.paramgen.mlpg`
"""
T, D = means.size()
@@ -214,13 +214,13 @@ def unit_variance_mlpg(R, means):
Args:
means (torch.autograd.Variable): Means, of shape (``T x D``) or
(``T*num_windows x static_dim``). See
-:func:`nnmnkwii.functions.reshape_means` to reshape means from
+:func:`nnmnkwii.paramgen.reshape_means` to reshape means from
(``T x D``) to (``T*num_windows x static_dim``).
R (torch.FloatTensor): MLPG matrix.
See also:
:obj:`nnmnkwii.autograd.UnitVarianceMLPG`,
-:func:`nnmnkwii.functions.unit_variance_mlpg_matrix`,
+:func:`nnmnkwii.paramgen.unit_variance_mlpg_matrix`,
:func:`reshape_means`.
"""
return UnitVarianceMLPG(R)(means)
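
To illustrate the renamed entry points, here is a minimal usage sketch of the functional interface, adapted from tests/test_autograd.py further below; the shapes and window set are illustrative, not prescribed by this commit:

# Minimal sketch of the renamed functional interface, adapted from
# tests/test_autograd.py; shapes and the window set are illustrative.
import numpy as np
import torch
from torch.autograd import Variable
from nnmnkwii import paramgen as G
from nnmnkwii import autograd as AF

windows = [
    (0, 0, np.array([1.0])),             # static
    (1, 1, np.array([-0.5, 0.0, 0.5])),  # delta
]
T, static_dim = 10, 2
means = Variable(torch.rand(T, static_dim * len(windows)), requires_grad=True)
variances = torch.ones(static_dim * len(windows))

# Generic (CPU-only) MLPG with explicit variances
y_hat = AF.mlpg(means, variances, windows)

# Unit-variance variant: precompute R once and reuse it
R = torch.from_numpy(G.unit_variance_mlpg_matrix(windows, T))
y_hat2 = AF.unit_variance_mlpg(R, means)
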
7 changes: 2 additions & 5 deletions nnmnkwii/autograd/_impl/modspec.py
@@ -1,6 +1,6 @@
from __future__ import with_statement, print_function, absolute_import

-from nnmnkwii import functions as F
+from nnmnkwii.preprocessing.modspec import modspec as _modspec

from torch.autograd import Function
import torch
@@ -24,7 +24,7 @@ def forward(self, y):
self.save_for_backward(y)

y_np = y.numpy()
-ms = torch.from_numpy(F.modspec(y_np, n=self.n, norm=self.norm))
+ms = torch.from_numpy(_modspec(y_np, n=self.n, norm=self.norm))

return ms

@@ -72,8 +72,5 @@ def modspec(y, n=2048, norm=None):
n (int): DFT length.
norm (bool): Normalize DFT output or not. See :obj:`numpy.fft.fft`.
-See also:
-:func:`nnmnkwii.functions.modspec`
"""
return ModSpec(n=n, norm=norm)(y)
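
A short, hedged sketch of the differentiable modulation spectrum wrapper this file keeps; the trajectory length, dimensionality, and DFT size are arbitrary choices for illustration:

# Hedged sketch of the differentiable modulation spectrum wrapper defined
# above; trajectory length, dimensionality, and DFT size are arbitrary.
import torch
from torch.autograd import Variable
from nnmnkwii import autograd as AF

T, D = 16, 3                         # parameter trajectory: T frames, D dims
y = Variable(torch.rand(T, D), requires_grad=True)
ms = AF.modspec(y, n=32, norm=None)  # modulation spectrum via DFT
ms.sum().backward()                  # gradients flow back to y
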
20 changes: 10 additions & 10 deletions nnmnkwii/paramgen/_mlpg.py
@@ -33,14 +33,14 @@ def build_win_mats(windows, T):
banded representation.
Examples:
->>> from nnmnkwii import functions as F
+>>> from nnmnkwii import paramgen as G
>>> import numpy as np
>>> windows = [
... (0, 0, np.array([1.0])), # static
... (1, 1, np.array([-0.5, 0.0, 0.5])), # delta
... (1, 1, np.array([1.0, -2.0, 1.0])), # delta-delta
... ]
->>> win_mats = F.build_win_mats(windows, 3)
+>>> win_mats = G.build_win_mats(windows, 3)
"""
win_mats = []
for l, u, win_coeff in windows:
@@ -149,7 +149,7 @@ def mlpg(mean_frames, variance_frames, windows):
Generated static features over time
Examples:
->>> from nnmnkwii import functions as F
+>>> from nnmnkwii import paramgen as G
>>> windows = [
... (0, 0, np.array([1.0])), # static
... (1, 1, np.array([-0.5, 0.0, 0.5])), # delta
@@ -158,7 +158,7 @@ def mlpg(mean_frames, variance_frames, windows):
>>> T, static_dim = 10, 24
>>> mean_frames = np.random.rand(T, static_dim * len(windows))
>>> variance_frames = np.random.rand(T, static_dim * len(windows))
->>> static_features = F.mlpg(mean_frames, variance_frames, windows)
+>>> static_features = G.mlpg(mean_frames, variance_frames, windows)
>>> assert static_features.shape == (T, static_dim)
See also:
@@ -198,7 +198,7 @@ def mlpg(mean_frames, variance_frames, windows):
def mlpg_grad(mean_frames, variance_frames, windows, grad_output):
"""MLPG gradient computation
-Parameters are same as :func:`nnmnkwii.functions.mlpg` except for
+Parameters are the same as :func:`nnmnkwii.paramgen.mlpg` except for
``grad_output``. See the function documentation for what the parameters mean.
Let :math:`d` be the index of static features, :math:`l` be the index
@@ -311,17 +311,17 @@ def unit_variance_mlpg_matrix(windows, T):
See also:
:func:`nnmnkwii.autograd.UnitVarianceMLPG`,
-:func:`nnmnkwii.functions.mlpg`.
+:func:`nnmnkwii.paramgen.mlpg`.
Examples:
->>> from nnmnkwii import functions as F
+>>> from nnmnkwii import paramgen as G
>>> import numpy as np
>>> windows = [
... (0, 0, np.array([1.0])),
... (1, 1, np.array([-0.5, 0.0, 0.5])),
... (1, 1, np.array([1.0, -2.0, 1.0])),
... ]
->>> F.unit_variance_mlpg_matrix(windows, 3)
+>>> G.unit_variance_mlpg_matrix(windows, 3)
array([[ 2.73835927e-01, 1.95121944e-01, 9.20177400e-02,
9.75609720e-02, -9.09090936e-02, -9.75609720e-02,
-3.52549881e-01, -2.43902430e-02, 1.10864742e-02],
@@ -357,7 +357,7 @@ def reshape_means(means, static_dim):
No-op if already reshaped.
Examples:
->>> from nnmnkwii import functions as F
+>>> from nnmnkwii import paramgen as G
>>> import numpy as np
>>> T, static_dim = 2, 2
>>> windows = [
@@ -366,7 +366,7 @@ def reshape_means(means, static_dim):
... (1, 1, np.array([1.0, -2.0, 1.0])), # delta-delta
... ]
>>> means = np.random.rand(T, static_dim * len(windows))
->>> reshaped_means = F.reshape_means(means, static_dim)
+>>> reshaped_means = G.reshape_means(means, static_dim)
>>> assert reshaped_means.shape == (T*len(windows), static_dim)
"""
T, D = means.shape
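
Tying the renamed NumPy-level API together, a sketch that follows the docstring examples above; sizes are illustrative:

# End-to-end sketch of the NumPy-level API under the new module name,
# following the docstring examples above; sizes are illustrative.
import numpy as np
from nnmnkwii import paramgen as G

windows = [
    (0, 0, np.array([1.0])),             # static
    (1, 1, np.array([-0.5, 0.0, 0.5])),  # delta
    (1, 1, np.array([1.0, -2.0, 1.0])),  # delta-delta
]
T, static_dim = 10, 24
mean_frames = np.random.rand(T, static_dim * len(windows))
variance_frames = np.random.rand(T, static_dim * len(windows))

static_features = G.mlpg(mean_frames, variance_frames, windows)
assert static_features.shape == (T, static_dim)

# Unit-variance path: R is (T, num_windows*T); means are reshaped first
R = G.unit_variance_mlpg_matrix(windows, T)
reshaped = G.reshape_means(mean_frames, static_dim)
assert reshaped.shape == (T * len(windows), static_dim)
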
2 changes: 1 addition & 1 deletion nnmnkwii/preprocessing/generic.py
@@ -37,7 +37,7 @@ def delta_features(x, windows):
Args:
x (numpy.ndarray): Input static features, of shape (``T x D``).
-y (list): List of windows. See :func:`nnmnkwii.functions.mlpg` for what
+y (list): List of windows. See :func:`nnmnkwii.paramgen.mlpg` for what
the delta window means.
Returns:
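
For context, a small sketch of how delta_features pairs with the MLPG window specification; the import path is assumed from the file location (nnmnkwii/preprocessing/generic.py) and sizes are illustrative:

# Sketch: stacking static/delta features with the same window specification
# that nnmnkwii.paramgen.mlpg consumes. Import path assumed from the file
# location above; sizes are illustrative.
import numpy as np
from nnmnkwii.preprocessing import delta_features

windows = [
    (0, 0, np.array([1.0])),             # static
    (1, 1, np.array([-0.5, 0.0, 0.5])),  # delta
]
T, D = 10, 24
x = np.random.rand(T, D)
y = delta_features(x, windows)
assert y.shape == (T, D * len(windows))
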
8 changes: 4 additions & 4 deletions perf/autograd_mlpg_perf.py
@@ -1,6 +1,6 @@
from __future__ import division, print_function, absolute_import

-from nnmnkwii import functions as F
+from nnmnkwii import paramgen as G
from nnmnkwii import autograd as AF
from torch.autograd import Variable
import torch
@@ -45,10 +45,10 @@ def benchmark_mlpg(static_dim=59, T=100, batch_size=10, use_cuda=True):
torch.manual_seed(1234)
means = np.random.rand(T, static_dim * len(windows)).astype(np.float32)
variances = np.ones(static_dim * len(windows))
-reshaped_means = F.reshape_means(means, static_dim)
+reshaped_means = G.reshape_means(means, static_dim)

# Pseudo target
-y = F.mlpg(means, variances, windows).astype(np.float32)
+y = G.mlpg(means, variances, windows).astype(np.float32)

# Pack into variables
means = Variable(torch.from_numpy(means), requires_grad=True)
Expand All @@ -70,7 +70,7 @@ def benchmark_mlpg(static_dim=59, T=100, batch_size=10, use_cuda=True):
since = time.time()
if use_cuda:
y = y.cuda()
-R = F.unit_variance_mlpg_matrix(windows, T)
+R = G.unit_variance_mlpg_matrix(windows, T)
R = torch.from_numpy(R)
# Assuming minibatches are zero-padded, we only need to create the MLPG matrix
# per-minibatch, not per-utterance.
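
The per-minibatch comment above implies the following pattern; a hedged sketch assuming equal-length, zero-padded utterances, with sizes mirroring benchmark_mlpg's defaults:

# Hedged sketch of the per-minibatch R reuse noted in the comment above:
# with zero-padded batches of a fixed length T, one MLPG matrix serves
# every utterance in the batch. Sizes mirror benchmark_mlpg's defaults.
import numpy as np
import torch
from torch.autograd import Variable
from nnmnkwii import paramgen as G
from nnmnkwii import autograd as AF

windows = [
    (0, 0, np.array([1.0])),
    (1, 1, np.array([-0.5, 0.0, 0.5])),
]
T, static_dim, batch_size = 100, 59, 10
R = torch.from_numpy(G.unit_variance_mlpg_matrix(windows, T))  # built once

means = torch.rand(batch_size, T, static_dim * len(windows))
means = Variable(means, requires_grad=True)
y_hat = AF.unit_variance_mlpg(R, means)  # applied per utterance in the batch
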
16 changes: 8 additions & 8 deletions tests/test_autograd.py
@@ -2,7 +2,7 @@

from nnmnkwii.autograd._impl.mlpg import MLPG, UnitVarianceMLPG
from nnmnkwii.autograd._impl.modspec import ModSpec
-from nnmnkwii import functions as F
+from nnmnkwii import paramgen as G
from nnmnkwii import autograd as AF

from torch.autograd import gradcheck
@@ -43,7 +43,7 @@ def test_functional_mlpg():
means = torch.rand(T, static_dim * len(windows))
variances = torch.ones(static_dim * len(windows))

-y = F.mlpg(means.numpy(), variances.numpy(), windows)
+y = G.mlpg(means.numpy(), variances.numpy(), windows)
y = Variable(torch.from_numpy(y), requires_grad=False)

means = Variable(means, requires_grad=True)
@@ -56,7 +56,7 @@ def test_functional_mlpg():
nn.MSELoss()(y_hat, y).backward()

# unit_variance_mlpg
-R = torch.from_numpy(F.unit_variance_mlpg_matrix(windows, T))
+R = torch.from_numpy(G.unit_variance_mlpg_matrix(windows, T))
y_hat = AF.unit_variance_mlpg(R, means)
assert np.allclose(y.data.numpy(), y_hat.data.numpy())

@@ -81,13 +81,13 @@ def test_unit_variance_mlpg_gradcheck():
requires_grad=True)

# Input for UnitVarianceMLPG
-reshaped_means = F.reshape_means(
+reshaped_means = G.reshape_means(
means.data.clone().numpy(), static_dim)
reshaped_means = Variable(torch.from_numpy(reshaped_means),
requires_grad=True)

# Compute MLPG matrix
-R = F.unit_variance_mlpg_matrix(windows, T).astype(np.float32)
+R = G.unit_variance_mlpg_matrix(windows, T).astype(np.float32)
R = torch.from_numpy(R)

# UnitVarianceMLPG can take input with both means and reshaped_means
@@ -127,12 +127,12 @@ def test_minibatch_unit_variance_mlpg_gradcheck():
means_expanded = means.expand(
batch_size, means.shape[0], means.shape[1])
reshaped_means = torch.from_numpy(
-F.reshape_means(means.numpy(), static_dim))
+G.reshape_means(means.numpy(), static_dim))
reshaped_means_expanded = reshaped_means.expand(
batch_size, reshaped_means.shape[0], reshaped_means.shape[1])

# Target
-y = F.mlpg(means.numpy(), np.ones(static_dim * len(windows)), windows)
+y = G.mlpg(means.numpy(), np.ones(static_dim * len(windows)), windows)
y = Variable(torch.from_numpy(y), requires_grad=False)
y_expanded = y.expand(batch_size, y.size(0), y.size(1))

@@ -144,7 +144,7 @@ def test_minibatch_unit_variance_mlpg_gradcheck():
reshaped_means_expanded, requires_grad=True)

# Case 1: 2d with reshaped means
-R = torch.from_numpy(F.unit_variance_mlpg_matrix(windows, T))
+R = torch.from_numpy(G.unit_variance_mlpg_matrix(windows, T))
y_hat1 = AF.unit_variance_mlpg(R, reshaped_means)

# Case 2: 3d with reshaped means
