UnitVarianceMLPG can now support reshaped/unreshaped means
and also support 3d tensors
r9y9 committed Aug 21, 2017
1 parent 6a59d9a commit ef5bd1c
Showing 3 changed files with 179 additions and 44 deletions.
114 changes: 83 additions & 31 deletions nnmnkwii/autograd/_impl/mlpg.py
@@ -28,7 +28,7 @@ class MLPG(Function):
         International Speech Communication Association. 2014.
 
     Attributes:
-        variance_frames (torch.FloatTensor): Variances same as in
+        variances (torch.FloatTensor): Variances same as in
            :func:`nnmnkwii.functions.mlpg`.
        windows (list): same as in :func:`nnmnkwii.functions.mlpg`.
@@ -42,36 +42,36 @@ class MLPG(Function):
        :func:`nnmnkwii.functions.mlpg_grad`.
     """
 
-    def __init__(self, variance_frames, windows):
+    def __init__(self, variances, windows):
         super(MLPG, self).__init__()
         self.windows = windows
-        self.variance_frames = variance_frames
+        self.variances = variances
 
-    def forward(self, mean_frames):
-        assert mean_frames.dim() == 2  # we cannot do MLPG on minibatch
-        variance_frames = self.variance_frames
-        self.save_for_backward(mean_frames)
+    def forward(self, means):
+        assert means.dim() == 2  # we cannot do MLPG on minibatch
+        variances = self.variances
+        self.save_for_backward(means)
 
-        T, D = mean_frames.size()
-        assert mean_frames.size() == variance_frames.size()
+        T, D = means.size()
+        assert means.size() == variances.size()
 
-        mean_frames_np = mean_frames.numpy()
-        variance_frames_np = variance_frames.numpy()
-        y = F.mlpg(mean_frames_np, variance_frames_np, self.windows)
+        means_np = means.numpy()
+        variances_np = variances.numpy()
+        y = F.mlpg(means_np, variances_np, self.windows)
         y = torch.from_numpy(y.astype(np.float32))
         return y
 
     def backward(self, grad_output):
-        mean_frames, = self.saved_tensors
-        variance_frames = self.variance_frames
+        means, = self.saved_tensors
+        variances = self.variances
 
-        T, D = mean_frames.size()
+        T, D = means.size()
 
         grad_output_numpy = grad_output.numpy()
-        mean_frames_numpy = mean_frames.numpy()
-        variance_frames_numpy = variance_frames.numpy()
+        means_numpy = means.numpy()
+        variances_numpy = variances.numpy()
         grads_numpy = F.mlpg_grad(
-            mean_frames_numpy, variance_frames_numpy, self.windows,
+            means_numpy, variances_numpy, self.windows,
             grad_output_numpy)
 
         return torch.from_numpy(grads_numpy).clone()
@@ -80,7 +80,8 @@ def backward(self, grad_output):
 
 class UnitVarianceMLPG(Function):
     """Special case of MLPG assuming data is normalized to have unit variance.
-    ``f : (T*num_windows, static_dim) -> (T, static_dim)``.
+    ``f : (T x D) -> (T, static_dim)`` or
+    ``f : (T*num_windows, static_dim) -> (T, static_dim)``.
 
     The function is theoretically a special case of :obj:`MLPG`. The function
     assumes input data is normalized to have unit variance for each dimension.
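
For reference, with unit variances the general MLPG solution y_hat = (W^T Sigma^{-1} W)^{-1} W^T Sigma^{-1} mu collapses to a fixed linear map, y_hat = (W^T W)^{-1} W^T mu = R mu, where W stacks the window matrices. This is a standard MLPG identity rather than something spelled out in this diff; it is why forward and backward below each reduce to a single matrix product with a precomputed R.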
@@ -124,45 +124,96 @@ class UnitVarianceMLPG(Function):
     def __init__(self, R):
         super(UnitVarianceMLPG, self).__init__()
         self.R = R
+        self.num_windows = R.shape[-1] // R.shape[0]
 
     def forward(self, means):
-        return torch.mm(self.R, means)
+        # TODO: remove this
+        self.save_for_backward(means)
+        T = self.R.shape[0]
+        dim = means.dim()
+
+        # Add batch axis if necessary
+        if dim == 2:
+            T_, D = means.shape
+            B = 1
+            means = means.view(B, T_, D)
+        else:
+            B, T_, D = means.shape
+
+        # Check if means has proper shape
+        reshaped = not (T == T_)
+        if not reshaped:
+            static_dim = means.shape[-1] // self.num_windows
+            reshaped_means = means.contiguous().view(
+                B, T, self.num_windows, -1).transpose(
+                    1, 2).contiguous().view(B, -1, static_dim)
+        else:
+            static_dim = means.shape[-1]
+            reshaped_means = means
+
+        out = torch.matmul(self.R, reshaped_means)
+        if dim == 2:
+            return out.view(-1, static_dim)
+
+        return out
 
     def backward(self, grad_output):
-        return torch.mm(self.R.transpose(0,1), grad_output)
+        means, = self.saved_tensors
+        T = self.R.shape[0]
+        dim = means.dim()
+
+        # Add batch axis if necessary
+        if dim == 2:
+            T_, D = means.shape
+            B = 1
+            grad_output = grad_output.view(B, T, -1)
+        else:
+            B, T_, D = means.shape
+
+        grad = torch.matmul(self.R.transpose(0, 1), grad_output)
+
+        reshaped = not (T == T_)
+        if not reshaped:
+            grad = grad.view(B, self.num_windows, T, -1).transpose(
+                1, 2).contiguous().view(B, T, D)
+
+        if dim == 2:
+            return grad.view(-1, D)
+
+        return grad
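
The new forward pass accepts either layout by checking the time axis against R: an unreshaped (B, T, num_windows*static_dim) input is rearranged into the (B, T*num_windows, static_dim) layout that R multiplies against. A standalone sketch of that rearrangement, with made-up toy shapes:

    import torch

    B, T, num_windows, static_dim = 1, 5, 3, 2
    means = torch.rand(B, T, num_windows * static_dim)

    # Mirror the view/transpose chain in forward(): group rows by window,
    # so all T frames of window 0 come first, then window 1, and so on.
    reshaped = means.contiguous().view(
        B, T, num_windows, static_dim).transpose(1, 2).contiguous().view(
        B, T * num_windows, static_dim)

    assert reshaped.shape == (B, T * num_windows, static_dim)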


-def mlpg(mean_frames, variance_frames, windows):
+def mlpg(means, variances, windows):
     """Maximum Likelihood Parameter Generation (MLPG).
 
     The parameters are almost the same as :func:`nnmnkwii.functions.mlpg` expects.
     The differences are:
 
-    - The function assumes ``mean_frames`` as :obj:`torch.autograd.Variable`
+    - The function assumes ``means`` as :obj:`torch.autograd.Variable`
       instead of :obj:`numpy.ndarray`.
-    - The function assumes ``variance_frames`` as :obj:`torch.FloatTensor`
+    - The function assumes ``variances`` as :obj:`torch.FloatTensor`
       instead of :obj:`numpy.ndarray`.
 
     Args:
-        mean_frames (torch.autograd.Variable): Means
-        variance_frames (torch.FloatTensor): Variances
+        means (torch.autograd.Variable): Means
+        variances (torch.FloatTensor): Variances
         windows (list): A sequence of window specification
 
     See also:
        :obj:`nnmnkwii.autograd.MLPG`, :func:`nnmnkwii.functions.mlpg`
     """
-    T, D = mean_frames.size()
-    if variance_frames.dim() == 1 and variance_frames.shape[0] == D:
-        variance_frames = variance_frames.expand(T, D)
-    assert mean_frames.size() == variance_frames.size()
-    return MLPG(variance_frames, windows)(mean_frames)
+    T, D = means.size()
+    if variances.dim() == 1 and variances.shape[0] == D:
+        variances = variances.expand(T, D)
+    assert means.size() == variances.size()
+    return MLPG(variances, windows)(means)
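
A minimal usage sketch of the renamed wrapper. It assumes the pre-0.4, Variable-based torch API this commit targets, and the (left, right, coeffs) window triples used in the tests below:

    import numpy as np
    import torch
    from torch.autograd import Variable
    from nnmnkwii import autograd as AF

    windows = [
        (0, 0, np.array([1.0])),             # static
        (1, 1, np.array([-0.5, 0.0, 0.5])),  # delta
    ]
    T, static_dim = 10, 2
    D = static_dim * len(windows)

    means = Variable(torch.rand(T, D), requires_grad=True)
    variances = torch.ones(D)  # a 1D variance vector is expanded to (T, D)

    y_hat = AF.mlpg(means, variances, windows)  # -> (T, static_dim)
    y_hat.sum().backward()                      # gradients flow back to means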

 def unit_variance_mlpg(R, means):
     """Special case of MLPG assuming data is normalized to have unit variance.
 
     Args:
-        means (torch.autograd.Variable): Means, of shape
+        means (torch.autograd.Variable): Means, of shape (``T x D``) or
             (``T*num_windows x static_dim``). See
             :func:`nnmnkwii.functions.reshape_means` to reshape means from
             (``T x D``) to (``T*num_windows x static_dim``).
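
A sketch of the two input layouts unit_variance_mlpg now accepts, building R once via nnmnkwii.functions.unit_variance_mlpg_matrix as the tests below do (shapes are illustrative):

    import numpy as np
    import torch
    from torch.autograd import Variable
    from nnmnkwii import autograd as AF
    from nnmnkwii import functions as F

    windows = [(0, 0, np.array([1.0])),
               (1, 1, np.array([-0.5, 0.0, 0.5]))]
    T, static_dim = 5, 2

    R = torch.from_numpy(
        F.unit_variance_mlpg_matrix(windows, T).astype(np.float32))

    # Unreshaped layout: (T, num_windows * static_dim)
    means = Variable(torch.rand(T, static_dim * len(windows)),
                     requires_grad=True)
    y1 = AF.unit_variance_mlpg(R, means)

    # Reshaped layout: (T * num_windows, static_dim)
    reshaped = Variable(torch.from_numpy(
        F.reshape_means(means.data.numpy(), static_dim)),
        requires_grad=True)
    y2 = AF.unit_variance_mlpg(R, reshaped)  # same (T, static_dim) output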
7 changes: 4 additions & 3 deletions perf/autograd_mlpg_perf.py
@@ -75,9 +75,10 @@ def benchmark_mlpg(static_dim=59, T=100, batch_size=10, use_cuda=True):
         R = R.cuda()
     for _ in range(batch_size):
         if use_cuda:
-            reshaped_means = reshaped_means.cpu()
-            reshaped_means = reshaped_means.cuda()
-        y_hat = AF.unit_variance_mlpg(R, reshaped_means)
+            means = means.cpu()
+            means = means.cuda()
+
+        y_hat = AF.unit_variance_mlpg(R, means)
         L = criterion(y_hat, y)
         assert np.allclose(y_hat.cpu().data.numpy(), y.cpu().data.numpy(),
                            atol=1e-5)
102 changes: 92 additions & 10 deletions tests/test_autograd.py
@@ -12,6 +12,7 @@
 import numpy as np
 from warnings import warn
 
+
 def _get_windows_set():
     windows_set = [
         # Static
@@ -32,35 +32,43 @@ def _get_windows_set():
     ]
     return windows_set
 
+
 def test_functional_mlpg():
     static_dim = 2
-    T = 10
+    T = 5
 
     for windows in _get_windows_set():
         torch.manual_seed(1234)
         means = torch.rand(T, static_dim * len(windows))
         reshaped_means = torch.from_numpy(F.reshape_means(means.numpy(), static_dim))
         variances = torch.ones(static_dim * len(windows))
 
         y = F.mlpg(means.numpy(), variances.numpy(), windows)
         y = Variable(torch.from_numpy(y), requires_grad=False)
 
         means = Variable(means, requires_grad=True)
         reshaped_means = Variable(reshaped_means, requires_grad=True)
 
         # mlpg
         y_hat = AF.mlpg(means, variances, windows)
         assert np.allclose(y.data.numpy(), y_hat.data.numpy())
 
         # Test backward pass
         nn.MSELoss()(y_hat, y).backward()
 
         # unit_variance_mlpg
         R = torch.from_numpy(F.unit_variance_mlpg_matrix(windows, T))
-        y_hat = AF.unit_variance_mlpg(R, reshaped_means)
+        y_hat = AF.unit_variance_mlpg(R, means)
         assert np.allclose(y.data.numpy(), y_hat.data.numpy())
 
         nn.MSELoss()(y_hat, y).backward()
 
-def test_unit_variance_mlpg():
+        # Test 3D tensor inputs
+        y_hat = AF.unit_variance_mlpg(R, means.view(1, -1, means.size(-1)))
+        assert np.allclose(
+            y.data.numpy(), y_hat.data.view(-1, static_dim).numpy())
+
+        nn.MSELoss()(y_hat.view(-1, static_dim), y).backward()
+
+
+def test_unit_variance_mlpg_gradcheck():
     static_dim = 2
     T = 10

@@ -71,9 +80,6 @@ def test_unit_variance_mlpg():
                          requires_grad=True)
 
         # Input for UnitVarianceMLPG
-        # Equivalent:
-        # reshaped_means = means.view(
-        #     T, len(windows), -1).transpose(0, 1).contiguous().view(-1, static_dim)
         reshaped_means = F.reshape_means(means.data.clone().numpy(), static_dim)
         reshaped_means = Variable(torch.from_numpy(reshaped_means),
                                   requires_grad=True)
@@ -82,20 +88,96 @@
         R = F.unit_variance_mlpg_matrix(windows, T).astype(np.float32)
         R = torch.from_numpy(R)
 
-        y = UnitVarianceMLPG(R)(reshaped_means)
+        # UnitVarianceMLPG can take input with both means and reshaped_means
+        y1 = UnitVarianceMLPG(R)(means)
+        y2 = UnitVarianceMLPG(R)(reshaped_means)
 
         # Unit variances
         variances = torch.ones(static_dim * len(windows)
                                ).expand(T, static_dim * len(windows))
         y_hat = MLPG(variances, windows)(means)
 
         # Make sure UnitVarianceMLPG and MLPG can get same result
         # if we use unit variances
-        assert np.allclose(y.data.numpy(), y_hat.data.numpy())
+        for y in [y1, y2]:
+            assert np.allclose(y.data.numpy(), y_hat.data.numpy())
 
         # Grad check
         inputs = (reshaped_means,)
         assert gradcheck(UnitVarianceMLPG(R),
                          inputs, eps=1e-3, atol=1e-3)
 
+        inputs = (means,)
+        assert gradcheck(UnitVarianceMLPG(R),
+                         inputs, eps=1e-3, atol=1e-3)
+
+
+def test_minibatch_unit_variance_mlpg_gradcheck():
+    static_dim = 2
+    T = 5
+
+    for windows in _get_windows_set():
+        batch_size = 5
+        torch.manual_seed(1234)
+
+        # Prepare inputs
+        means = torch.rand(T, static_dim * len(windows))
+        means_expanded = means.expand(
+            batch_size, means.shape[0], means.shape[1])
+        reshaped_means = torch.from_numpy(
+            F.reshape_means(means.numpy(), static_dim))
+        reshaped_means_expanded = reshaped_means.expand(
+            batch_size, reshaped_means.shape[0], reshaped_means.shape[1])
+
+        # Target
+        y = F.mlpg(means.numpy(), np.ones(static_dim * len(windows)), windows)
+        y = Variable(torch.from_numpy(y), requires_grad=False)
+        y_expanded = y.expand(batch_size, y.size(0), y.size(1))
+
+        # Pack into variables
+        means = Variable(means, requires_grad=True)
+        means_expanded = Variable(means_expanded, requires_grad=True)
+        reshaped_means = Variable(reshaped_means, requires_grad=True)
+        reshaped_means_expanded = Variable(
+            reshaped_means_expanded, requires_grad=True)
+
+        # Case 1: 2d with reshaped means
+        R = torch.from_numpy(F.unit_variance_mlpg_matrix(windows, T))
+        y_hat1 = AF.unit_variance_mlpg(R, reshaped_means)
+
+        # Case 2: 3d with reshaped means
+        y_hat2 = AF.unit_variance_mlpg(R, reshaped_means_expanded)
+        for i in range(batch_size):
+            assert np.allclose(y_hat1.data.numpy(), y_hat2[i].data.numpy())
+
+        nn.MSELoss()(y_hat1, y).backward()
+        nn.MSELoss()(y_hat2, y_expanded).backward()
+
+        # Check grad consistency
+        for i in range(batch_size):
+            grad1 = reshaped_means.grad.data.numpy()
+            grad2 = reshaped_means_expanded.grad[i].data.numpy()
+            assert np.allclose(grad1, grad2)
+
+        # Case 3: 2d with non-reshaped input
+        y_hat3 = AF.unit_variance_mlpg(R, means)
+
+        # Case 4: 3d with non-reshaped input
+        y_hat4 = AF.unit_variance_mlpg(R, means_expanded)
+
+        for i in range(batch_size):
+            assert np.allclose(y_hat1.data.numpy(), y_hat3.data.numpy())
+            assert np.allclose(y_hat3.data.numpy(), y_hat4[i].data.numpy())
+
+        nn.MSELoss()(y_hat3, y).backward()
+        nn.MSELoss()(y_hat4, y_expanded).backward()
+
+        # Check grad consistency
+        for i in range(batch_size):
+            grad1 = means.grad.data.numpy()
+            grad2 = means_expanded.grad[i].data.numpy()
+            assert np.allclose(grad1, grad2)


 def test_mlpg_gradcheck():
     # MLPG is performed dimension by dimension, so static_dim 1 is enough;
     # 2 is just in case.
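
The batched cases in these tests rest on torch.matmul broadcasting a 2D R against a 3D batch of means. A small shape-contract sketch with toy sizes:

    import torch

    B, T, num_windows, static_dim = 4, 5, 3, 2
    R = torch.rand(T, T * num_windows)              # 2D MLPG matrix
    x = torch.rand(B, T * num_windows, static_dim)  # batched reshaped means

    out = torch.matmul(R, x)  # R broadcasts over the batch dimension
    assert out.shape == (B, T, static_dim)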
