Skip to content

Commit

Permalink
Add faster version of MLPG and linalg utilities
Browse files Browse the repository at this point in the history
New function `UnitVarianceMLPG` can run on GPU/CPU. Fixes #4

The package now requires cython.
  • Loading branch information
r9y9 committed Aug 20, 2017
1 parent 6d5514e commit 1484238
Show file tree
Hide file tree
Showing 15 changed files with 705 additions and 74 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -1,3 +1,5 @@
nnmnkwii/functions/_impl/_mlpg.c
nnmnkwii/util/_linalg.c
examples
docs/references/generated

Expand Down
7 changes: 4 additions & 3 deletions docs/references/autograd.rst
Expand Up @@ -6,16 +6,14 @@ Autograd
Differentiable functions for PyTorch. This may be extended
to support other autograd frameworks.

Currently none of the functions has a CUDA implementation, but this should be
addressed later.

Functional interface
--------------------

.. autosummary::
:toctree: generated/

mlpg
unit_variance_mlpg
modspec

Function classes
Expand All @@ -24,5 +22,8 @@ Function classes
.. autoclass:: MLPG
:members:

.. autoclass:: UnitVarianceMLPG
:members:

.. autoclass:: ModSpec
:members:
14 changes: 14 additions & 0 deletions docs/references/functions.rst
Expand Up @@ -6,10 +6,24 @@ speech synthesis.

.. automodule:: nnmnkwii.functions

MLPG
----

.. autosummary::
:toctree: generated/

build_win_mats
mlpg
mlpg_grad
unit_variance_mlpg_matrix
reshape_means


Modulation spectrum
-------------------

.. autosummary::
:toctree: generated/

modspec
modphase
12 changes: 12 additions & 0 deletions docs/references/util.rst
Expand Up @@ -44,3 +44,15 @@ Example question file was taken from Merlin_.
example_question_file
example_file_data_sources_for_duration_model
example_file_data_sources_for_acoustic_model


Linear algebra
--------------

.. automodule:: nnmnkwii.util.linalg

.. autosummary::
:toctree: generated/

cholesky_inv
cholesky_inv_banded
5 changes: 3 additions & 2 deletions nnmnkwii/autograd/__init__.py
@@ -1,4 +1,5 @@
from __future__ import with_statement, print_function, absolute_import

from nnmnkwii.autograd._impl.mlpg import mlpg, MLPG
from nnmnkwii.autograd._impl.modspec import modspec, ModSpec
from ._impl.mlpg import mlpg, MLPG
from ._impl.mlpg import unit_variance_mlpg, UnitVarianceMLPG
from ._impl.modspec import modspec, ModSpec
117 changes: 90 additions & 27 deletions nnmnkwii/autograd/_impl/mlpg.py
Expand Up @@ -9,46 +9,37 @@


class MLPG(Function):
"""MLPG as an autograd function ``f : (T, D) -> (T, static_dim)``.
"""Generic MLPG as an autograd function.
``f : (T, D) -> (T, static_dim)``.
This is meant to be used for Minimum Generation Error (MGE) training for
speech synthesis and voice conversion. See [1]_ for details.
speech synthesis and voice conversion. See [1]_ and [2]_ for details.
It relies on :func:`nnmnkwii.functions.mlpg` and
:func:`nnmnkwii.functions.mlpg_grad` for forward and backward computation,
respectively.
.. [1] Wu, Zhizheng, and Simon King. "Minimum trajectory error training
for deep neural networks, combined with stacked bottleneck features."
INTERSPEECH. 2015.
Let :math:`d` be the index of static features and :math:`l` be the index
of windows. The gradients :math:`g_{d,l}` can then be computed by:
.. math::
g_{d,l} = (\sum_{l} W_{l}^{T}P_{d,l}W_{l})^{-1} W_{l}^{T}P_{d,l}
where :math:`W_{l}` is a banded window matrix and :math:`P_{d,l}` is a
diagonal precision matrix.
Assuming the covariance matrices are diagonal, MLPG can be performed
efficiently, dimension by dimension.
Let :math:`o_{d}` be the ``T``-dimensional back-propagated gradients; the
resulting gradients :math:`g'_{d,l}` to be propagated are
computed as follows:
.. math::
g'_{d,l} = o_{d}^{T} g_{d,l}
.. [2] Xie, Feng-Long, et al. "Sequence error (SE) minimization training of
neural network for voice conversion." Fifteenth Annual Conference of the
International Speech Communication Association. 2014.
Attributes:
variance_frames (torch.FloatTensor): Variances same as in
:func:`nnmnkwii.functions.mlpg`.
windows (list): same as in :func:`nnmnkwii.functions.mlpg`.
TODO:
CUDA implementation
Warnings:
The function is generic but cannot run on CUDA. For a faster
differentiable MLPG, see :obj:`UnitVarianceMLPG`.
See also:
:func:`nnmnkwii.functions.mlpg`, :func:`nnmnkwii.functions.mlpg_grad`.
:func:`nnmnkwii.autograd.mlpg`,
:func:`nnmnkwii.functions.mlpg`,
:func:`nnmnkwii.functions.mlpg_grad`.
"""

def __init__(self, variance_frames, windows):
Expand Down Expand Up @@ -86,6 +77,61 @@ def backward(self, grad_output):
return torch.from_numpy(grads_numpy).clone()


class UnitVarianceMLPG(Function):
    r"""Special case of MLPG assuming data is normalized to have unit variance.

    ``f : (T*num_windows, static_dim) -> (T, static_dim)``.

    The function is theoretically a special case of :obj:`MLPG`. It assumes
    the input data is normalized to have unit variance for each dimension.
    The unit-variance property greatly simplifies the backward computation
    of MLPG.

    Let :math:`\mu` be the input mean sequence
    (``num_windows*T x static_dim``) and :math:`W` be a window matrix
    (``num_windows*T x T``). MLPG can then be written as follows:

    .. math::

        y = R \mu

    where

    .. math::

        R = (W^{T} W)^{-1} W^{T}

    Note that static + dynamic features are often represented as a
    (``T x static_dim*num_windows``) matrix, but this function assumes the
    input has shape (``num_windows*T x static_dim``).

    To avoid duplicate computations in forward and backward, the function
    takes ``R`` at construction time. The matrix ``R`` can be computed by
    :func:`nnmnkwii.functions.unit_variance_mlpg_matrix`.

    Args:
        R: Unit-variance MLPG matrix of shape (``T x num_windows*T``). This
            should be created with
            :func:`nnmnkwii.functions.unit_variance_mlpg_matrix`.

    Attributes:
        R: Unit-variance MLPG matrix (``T x num_windows*T``).

    See also:
        :func:`nnmnkwii.autograd.unit_variance_mlpg`.
    """

    def __init__(self, R):
        super(UnitVarianceMLPG, self).__init__()
        self.R = R

    def forward(self, means):
        """Return the generated trajectory ``R @ means``."""
        return torch.mm(self.R, means)

    def backward(self, grad_output):
        """Return the gradient w.r.t. ``means``.

        The forward pass is the linear map ``y = R @ means``, so the
        incoming gradient is propagated by multiplying with ``R^T``.
        """
        return torch.mm(self.R.transpose(0, 1), grad_output)


def mlpg(mean_frames, variance_frames, windows):
"""Maximum Liklihood Paramter Generation (MLPG).
Expand All @@ -103,9 +149,26 @@ def mlpg(mean_frames, variance_frames, windows):
windows (list): A sequence of window specification
See also:
:func:`nnmnkwii.functions.mlpg`
:obj:`nnmnkwii.autograd.MLPG`, :func:`nnmnkwii.functions.mlpg`
"""
T, D = mean_frames.size()
assert mean_frames.size() == variance_frames.size()
return MLPG(variance_frames, windows)(mean_frames)

def unit_variance_mlpg(means, R):
    """Run MLPG on unit-variance data using a precomputed MLPG matrix.

    Functional wrapper around :obj:`nnmnkwii.autograd.UnitVarianceMLPG`.

    Args:
        means (torch.autograd.Variable): Means, of shape
            (``T*num_windows x static_dim``). See
            :func:`nnmnkwii.functions.reshape_means` to reshape means from
            (``T x D``) to (``T*num_windows x static_dim``).
        R (torch.FloatTensor): MLPG matrix.

    Returns:
        The result of applying ``UnitVarianceMLPG(R)`` to ``means``.

    See also:
        :obj:`nnmnkwii.autograd.UnitVarianceMLPG`,
        :func:`nnmnkwii.functions.unit_variance_mlpg_matrix`,
        :func:`nnmnkwii.functions.reshape_means`.
    """
    mlpg_function = UnitVarianceMLPG(R)
    return mlpg_function(means)
5 changes: 3 additions & 2 deletions nnmnkwii/functions/__init__.py
@@ -1,5 +1,6 @@
# coding: utf-8
from __future__ import division, print_function, absolute_import

from nnmnkwii.functions._impl.mlpg import mlpg, mlpg_grad
from nnmnkwii.functions._impl.modspec import modspec, modphase
from ._impl.mlpg import mlpg, mlpg_grad, build_win_mats, full_window_mat
from ._impl.mlpg import unit_variance_mlpg_matrix, reshape_means
from ._impl.modspec import modspec, modphase
33 changes: 33 additions & 0 deletions nnmnkwii/functions/_impl/_mlpg.pyx
@@ -0,0 +1,33 @@
# coding: utf-8
# cython: wraparound = False
# cython: boundscheck = False

import numpy as np
cimport numpy as np


def full_window_mat(win_mats, int T):
    """Expand banded window matrices into one dense, vertically stacked matrix.

    Each window occupies its own block of ``T`` consecutive rows, in the
    order the windows appear in ``win_mats``.

    Args:
        win_mats: Sequence of banded window matrices. Each element is read
            through its ``u``, ``l``, ``transposed`` and ``data`` attributes
            (presumably ``bandmat.BandMat`` objects -- TODO confirm).
        T (int): Number of rows reserved for each window, and the number of
            columns of the result.

    Returns:
        numpy.ndarray: Dense ``float64`` array of shape
        ``(T * len(win_mats), T)``.
    """
    cdef np.ndarray[np.float64_t, ndim = 2] dense
    cdef long num_cols
    cdef long diag
    cdef long win_index
    cdef unsigned long band_row
    cdef unsigned long col
    cdef long transposed

    dense = np.zeros((T * len(win_mats), T))

    for win_index, win_mat in enumerate(win_mats):
        transposed = win_mat.transposed
        upper = win_mat.u
        lower = win_mat.l
        banded = win_mat.data
        num_cols = banded.shape[1]
        # First dense row of the block that belongs to this window.
        first_row = win_index * T
        # Walk the stored diagonals, from the top super-diagonal (-upper)
        # down to the bottom sub-diagonal (+lower).
        for diag in range(-upper, lower + 1):
            # Row of the rectangular band storage holding this diagonal.
            if transposed:
                band_row = lower - diag
            else:
                band_row = upper + diag
            # Only visit columns whose target row (col + diag) is >= 0 and
            # < num_cols.
            for col in range(max(0, -diag), max(0, num_cols + min(0, -diag))):
                dense[first_row + col + diag, col] = banded[band_row, col]

    return dense

0 comments on commit 1484238

Please sign in to comment.