fix the OOM problem of synthesis and fix dependency of deep-voice3 (#55)
* fix dependency of deep-voice3

* fix OOM of synthesis

* add multi gpu support

* update readme

* remove version error

* add cpu limit

* fix synthesis

* revert the version commit
azraelkuan authored and r9y9 committed Apr 27, 2018
1 parent 4d5f68c commit 5b54777
Showing 6 changed files with 96 additions and 12 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -122,6 +122,8 @@ When this is done, you will see time-aligned extracted features (pairs of audio

### 2. Training

> Note: for multi-GPU training, make sure that `batch_size % num_gpu == 0`.
Usage:

```
…
```
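The README note above reflects how `DataParallel`-style scattering works: the batch dimension is split evenly across devices, so each GPU sees `batch_size // num_gpu` samples. A minimal sketch of the divisibility check (the variable names and values here are illustrative, not this repo's hparams):

```
import torch

num_gpu = torch.cuda.device_count()
batch_size = 8  # total batch size across all GPUs (illustrative value)

# data_parallel scatters dim 0 across devices, so each GPU receives
# batch_size // num_gpu samples; in this sketch an uneven split is an error.
if num_gpu > 1 and batch_size % num_gpu != 0:
    raise ValueError("batch_size ({}) must be divisible by num_gpu ({})"
                     .format(batch_size, num_gpu))
```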
1 change: 1 addition & 0 deletions synthesis.py
@@ -34,6 +34,7 @@
from hparams import hparams


torch.set_num_threads(4)
use_cuda = torch.cuda.is_available()


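The added `torch.set_num_threads(4)` is the "cpu limit" change: it caps PyTorch's intra-op CPU thread pool so a long synthesis run does not occupy every core of a shared machine. A quick way to confirm the effect:

```
import torch

torch.set_num_threads(4)        # cap intra-op CPU parallelism
print(torch.get_num_threads())  # -> 4
```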
6 changes: 4 additions & 2 deletions train.py
@@ -645,7 +645,9 @@ def __train_step(phase, epoch, global_step, global_test_step,
    # NOTE: softmax is handled in F.cross_entropy
    # y_hat: (B x C x T)

    y_hat = model(x, c=c, g=g, softmax=False)
    # multi-GPU support
    # make sure that batch_size % num_gpu == 0
    y_hat = torch.nn.parallel.data_parallel(model, (x, c, g, False))

    if is_mulaw_quantize(hparams.input_type):
        # we need 4d inputs for spatial cross entropy loss
@@ -742,7 +744,7 @@ def train_loop(model, data_loaders, optimizer, writer, checkpoint_dir=None):
            averaged_loss = running_loss / len(data_loader)
            writer.add_scalar("{} loss (per epoch)".format(phase),
                              averaged_loss, global_epoch)
            print("[{}] Loss: {}".format(phase, running_loss / len(data_loader)))
            print("Step {} [{}] Loss: {}".format(global_step, phase, running_loss / len(data_loader)))

        global_epoch += 1

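`torch.nn.parallel.data_parallel` is the functional form of `nn.DataParallel`: it scatters tensor inputs along dim 0, replicates the module onto each visible GPU, runs the replicas, and gathers the outputs back onto one device. A toy sketch of the call pattern used above (the module is a stand-in, not the WaveNet model):

```
import torch
from torch import nn


class Toy(nn.Module):
    def forward(self, x, softmax=False):
        return x * 2


model = Toy()
x = torch.randn(8, 16)  # batch of 8, which must split evenly across GPUs

if torch.cuda.device_count() > 1:
    model = model.cuda()
    # non-tensor args such as False are replicated to every device as-is
    y = torch.nn.parallel.data_parallel(model, (x.cuda(), False))
else:
    y = model(x, False)
```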
67 changes: 67 additions & 0 deletions wavenet_vocoder/conv.py
@@ -0,0 +1,67 @@
# coding: utf-8
import torch
from torch import nn
from torch.autograd import Variable
from torch.nn import functional as F


class Conv1d(nn.Conv1d):
    """Extended nn.Conv1d for incremental dilated convolutions
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.clear_buffer()
        self._linearized_weight = None
        self.register_backward_hook(self._clear_linearized_weight)

    def incremental_forward(self, input):
        # input: (B, T, C)
        if self.training:
            raise RuntimeError('incremental_forward only supports eval mode')

        # run forward pre hooks (e.g., weight norm)
        for hook in self._forward_pre_hooks.values():
            hook(self, input)

        # reshape weight
        weight = self._get_linearized_weight()
        kw = self.kernel_size[0]
        dilation = self.dilation[0]

        bsz = input.size(0)  # input: bsz x len x dim
        if kw > 1:
            input = input.data
            if self.input_buffer is None:
                self.input_buffer = input.new(bsz, kw + (kw - 1) * (dilation - 1), input.size(2))
                self.input_buffer.zero_()
            else:
                # shift buffer
                self.input_buffer[:, :-1, :] = self.input_buffer[:, 1:, :].clone()
            # append next input
            self.input_buffer[:, -1, :] = input[:, -1, :]
            with torch.no_grad():
                input = torch.autograd.Variable(self.input_buffer)
            if dilation > 1:
                input = input[:, 0::dilation, :].contiguous()
        output = F.linear(input.view(bsz, -1), weight, self.bias)
        return output.view(bsz, 1, -1)

    def clear_buffer(self):
        self.input_buffer = None

    def _get_linearized_weight(self):
        if self._linearized_weight is None:
            kw = self.kernel_size[0]
            # nn.Conv1d
            if self.weight.size() == (self.out_channels, self.in_channels, kw):
                weight = self.weight.transpose(1, 2).contiguous()
            else:
                # fairseq.modules.conv_tbc.ConvTBC
                weight = self.weight.transpose(2, 1).transpose(1, 0).contiguous()
            assert weight.size() == (self.out_channels, kw, self.in_channels)
            self._linearized_weight = weight.view(self.out_channels, -1)
        return self._linearized_weight

    def _clear_linearized_weight(self, *args):
        self._linearized_weight = None
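`incremental_forward` keeps a per-layer buffer of the last `kw + (kw - 1) * (dilation - 1)` input frames, shifts it by one frame per call, strides it by the dilation factor, and applies the flattened kernel as a single `F.linear`. A minimal sketch of driving it step by step (toy sizes, not values from this repo):

```
import torch
from wavenet_vocoder.conv import Conv1d

conv = Conv1d(in_channels=4, out_channels=8, kernel_size=2, dilation=2)
conv.eval()          # incremental_forward refuses training mode
conv.clear_buffer()  # reset buffered state before a new utterance

for t in range(5):
    frame = torch.randn(1, 1, 4)           # (B, T=1, C_in)
    out = conv.incremental_forward(frame)  # (B, 1, C_out)
```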
25 changes: 18 additions & 7 deletions wavenet_vocoder/modules.py
@@ -5,11 +5,26 @@
import numpy as np

import torch
from wavenet_vocoder import conv
from torch import nn
from torch.autograd import Variable
from torch.nn import functional as F


def Conv1d(in_channels, out_channels, kernel_size, dropout=0, std_mul=4.0, **kwargs):
    m = conv.Conv1d(in_channels, out_channels, kernel_size, **kwargs)
    std = math.sqrt((std_mul * (1.0 - dropout)) / (m.kernel_size[0] * in_channels))
    m.weight.data.normal_(mean=0, std=std)
    m.bias.data.zero_()
    return nn.utils.weight_norm(m)


def Embedding(num_embeddings, embedding_dim, padding_idx, std=0.01):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    m.weight.data.normal_(0, std)
    return m


def ConvTranspose2d(in_channels, out_channels, kernel_size,
                    weight_normalization=True, **kwargs):
    freq_axis_kernel_size = kernel_size[0]
@@ -26,13 +41,11 @@ def Conv1d1x1(in_channels, out_channels, bias=True, weight_normalization=True):
"""1-by-1 convolution layer
"""
if weight_normalization:
from deepvoice3_pytorch.modules import Conv1d
assert bias
return Conv1d(in_channels, out_channels, kernel_size=1, padding=0,
dilation=1, bias=bias, std_mul=1.0)
else:
from deepvoice3_pytorch.conv import Conv1d
return Conv1d(in_channels, out_channels, kernel_size=1, padding=0,
return conv.Conv1d(in_channels, out_channels, kernel_size=1, padding=0,
dilation=1, bias=bias)


@@ -85,14 +98,12 @@ def __init__(self, residual_channels, gate_channels, kernel_size,
        self.causal = causal

        if weight_normalization:
            from deepvoice3_pytorch.modules import Conv1d
            assert bias
            self.conv = Conv1d(residual_channels, gate_channels, kernel_size,
                               dropout=dropout, padding=padding, dilation=dilation,
                               padding=padding, dilation=dilation,
                               bias=bias, std_mul=1.0, *args, **kwargs)
        else:
            from deepvoice3_pytorch.conv import Conv1d
            self.conv = Conv1d(residual_channels, gate_channels, kernel_size,
            self.conv = conv.Conv1d(residual_channels, gate_channels, kernel_size,
                                    padding=padding, dilation=dilation,
                                    bias=bias, *args, **kwargs)

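The new local `Conv1d` factory is what removes the deep-voice3 dependency: it reproduces the deepvoice3_pytorch initialization in this package, drawing weights with std `sqrt(std_mul * (1 - dropout) / (kernel_size * in_channels))` (a fan-in-scaled scheme) before wrapping the layer in weight normalization. For example, with `std_mul=1.0`, no dropout, kernel size 1, and 256 input channels:

```
import math

# std = sqrt(std_mul * (1 - dropout) / (kernel_size * in_channels))
std = math.sqrt((1.0 * (1.0 - 0.0)) / (1 * 256))
print(std)  # 0.0625
```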
7 changes: 4 additions & 3 deletions wavenet_vocoder/wavenet.py
@@ -9,7 +9,7 @@
from torch.autograd import Variable
from torch.nn import functional as F

from deepvoice3_pytorch.modules import Embedding
from .modules import Embedding

from .modules import Conv1d1x1, ResidualConv1dGLU, ConvTranspose2d
from .mixture import sample_from_discretized_mix_logistic
@@ -317,12 +317,14 @@ def incremental_forward(self, initial_input=None, c=None, g=None,
        initial_input = initial_input.transpose(1, 2).contiguous()

        current_input = initial_input

        for t in tqdm(range(T)):
            if test_inputs is not None and t < test_inputs.size(1):
                current_input = test_inputs[:, t, :].unsqueeze(1)
            else:
                if t > 0:
                    current_input = outputs[-1]
            current_input = Variable(current_input)

            # Conditioning features for single time step
            ct = None if c is None else c[:, t, :].unsqueeze(1)
@@ -352,8 +354,7 @@
                        np.arange(self.out_channels), p=x.view(-1).data.cpu().numpy())
                    x.zero_()
                    x[:, sample] = 1.0
            outputs += [x]

            outputs += [x.data]
        # T x B x C
        outputs = torch.stack(outputs)
        # B x C x T
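The switch to `outputs += [x.data]` is the synthesis OOM fix: storing the raw tensor instead of the `Variable` drops the reference to each step's autograd graph, and the matching `current_input = Variable(current_input)` re-wrap above makes every step start from a fresh leaf, so graphs no longer accumulate across thousands of generated samples. A minimal illustration of the pattern (not code from this repo):

```
import torch
from torch.autograd import Variable

w = Variable(torch.randn(1), requires_grad=True)
x = Variable(torch.ones(1))
outputs = []
for t in range(1000):
    x = Variable(x.data)    # detach from the previous step's graph
    x = x * w               # this step's graph starts at a fresh leaf
    outputs.append(x.data)  # store values only, so the graph can be freed
```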

2 comments on commit 5b54777

@danshirron

I'm having trouble with multi-GPU support. eval_model fails with:

File "train.py", line 496, in eval_model
  length = input_lengths[idx].data.cpu().numpy()[0]
IndexError: too many indices for array

Also, can you elaborate on:

  1. "for multi-GPU training, make sure that batch_size % num_gpu == 0": what happens when num_gpu=4 and batch_size=2?
  2. Your remark about eval_model in the multi-GPU change ("because in eval mode, we only use one gpu to generate samples which will waste other resources"): is it just a resource issue?

@azraelkuan
Contributor Author


@danshirron
For the index error: I guess you are using PyTorch v0.4? You can change

length = input_lengths[idx].data.cpu().numpy()[0]

to

length = input_lengths[idx].data.cpu().numpy()

For question 1: the batch size is the total across all GPUs, so if batch size 2 works on one GPU, you can use batch size 8 on 4 GPUs.

For question 2: yes, with multiple GPUs it just wastes resources.
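For reference, the IndexError comes from a PyTorch 0.4 behaviour change: indexing a 1-D tensor now returns a 0-dim tensor, so `.numpy()` produces a 0-d array and the trailing `[0]` no longer applies. Dropping the `[0]` works as suggested; `.item()` is the version-agnostic spelling. A minimal illustration (toy values):

```
import torch

input_lengths = torch.tensor([100, 200, 300])
idx = 1
length = input_lengths[idx].item()  # -> 200 on 0.4 and later
```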
