fix the OOM problem of synthesis and fix dependency of deep-voice3 #55

Merged 8 commits on Apr 27, 2018
2 changes: 2 additions & 0 deletions README.md
@@ -122,6 +122,8 @@ When this is done, you will see time-aligned extracted features (pairs of audio

### 2. Training

+> Note: for multi-GPU training, make sure that batch_size % num_gpu == 0.
+
Usage:

```
…
```
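The note above is the contract behind the multi-GPU change in train.py below: data_parallel splits each batch along dim 0 across GPUs. A minimal sketch of the kind of up-front check it implies (the helper name is hypothetical, not part of this repo):

```python
import torch

def check_multi_gpu_batch(batch_size):
    # Hypothetical guard illustrating the README note: data_parallel
    # splits the batch across GPUs, so an uneven split can break
    # downstream shape assumptions.
    num_gpu = max(1, torch.cuda.device_count())
    if batch_size % num_gpu != 0:
        raise ValueError(
            "batch_size ({}) must be divisible by num_gpu ({})".format(
                batch_size, num_gpu))

check_multi_gpu_batch(16)  # passes for 1, 2, 4, 8, or 16 GPUs
```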
1 change: 1 addition & 0 deletions synthesis.py
@@ -34,6 +34,7 @@
from hparams import hparams


+torch.set_num_threads(4)
use_cuda = torch.cuda.is_available()


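For what it's worth, here is how I read this one-line synthesis fix: each worker in PyTorch's intra-op CPU thread pool keeps its own scratch buffers, so capping the pool caps peak memory during long autoregressive generation. A standalone sketch of the same pattern:

```python
import torch

# Cap intra-op parallelism before any tensor work; fewer OpenMP/MKL
# workers means fewer per-thread scratch buffers and a lower peak
# memory footprint on long CPU synthesis runs.
torch.set_num_threads(4)
print("CPU threads:", torch.get_num_threads())
```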
6 changes: 4 additions & 2 deletions train.py
@@ -645,7 +645,9 @@ def __train_step(phase, epoch, global_step, global_test_step,
    # NOTE: softmax is handled in F.cross_entropy
    # y_hat: (B x C x T)

-    y_hat = model(x, c=c, g=g, softmax=False)
+    # multi-GPU support
+    # make sure that batch_size is divisible by the number of GPUs
+    y_hat = torch.nn.parallel.data_parallel(model, (x, c, g, False))

    if is_mulaw_quantize(hparams.input_type):
        # we need 4d inputs for spatial cross entropy loss
@@ -742,7 +744,7 @@ def train_loop(model, data_loaders, optimizer, writer, checkpoint_dir=None):
        averaged_loss = running_loss / len(data_loader)
        writer.add_scalar("{} loss (per epoch)".format(phase),
                          averaged_loss, global_epoch)
-        print("[{}] Loss: {}".format(phase, running_loss / len(data_loader)))
+        print("Step {} [{}] Loss: {}".format(global_step, phase, running_loss / len(data_loader)))

        global_epoch += 1

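For context, a rough sketch of what the swapped-in call does (the toy module below is a placeholder, not the WaveNet model): torch.nn.parallel.data_parallel replicates the module on each visible GPU, scatters tensor inputs along dim 0, runs the replicas in parallel, and gathers outputs on the first device; non-tensor arguments like the softmax=False flag are passed through to every replica as-is.

```python
import torch

model = torch.nn.Linear(8, 8)  # stand-in for the real model
x = torch.randn(16, 8)         # batch of 16 splits evenly across 1/2/4/8 GPUs

if torch.cuda.is_available():
    model = model.cuda()
    # Replicate the module, scatter x along dim 0, run the replicas
    # in parallel, and gather the results on device 0.
    y = torch.nn.parallel.data_parallel(model, (x.cuda(),))
else:
    y = model(x)  # single-device fallback
```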
67 changes: 67 additions & 0 deletions wavenet_vocoder/conv.py
@@ -0,0 +1,67 @@
# coding: utf-8
import torch
from torch import nn
from torch.autograd import Variable
from torch.nn import functional as F


class Conv1d(nn.Conv1d):
    """Extended nn.Conv1d for incremental dilated convolutions
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.clear_buffer()
        self._linearized_weight = None
        self.register_backward_hook(self._clear_linearized_weight)

    def incremental_forward(self, input):
        # input: (B, T, C)
        if self.training:
            raise RuntimeError('incremental_forward only supports eval mode')

        # run forward pre hooks (e.g., weight norm)
        for hook in self._forward_pre_hooks.values():
            hook(self, input)

        # reshape weight
        weight = self._get_linearized_weight()
        kw = self.kernel_size[0]
        dilation = self.dilation[0]

        bsz = input.size(0)  # input: bsz x len x dim
        if kw > 1:
            input = input.data
            if self.input_buffer is None:
                self.input_buffer = input.new(bsz, kw + (kw - 1) * (dilation - 1), input.size(2))
                self.input_buffer.zero_()
            else:
                # shift buffer
                self.input_buffer[:, :-1, :] = self.input_buffer[:, 1:, :].clone()
            # append next input
            self.input_buffer[:, -1, :] = input[:, -1, :]
            with torch.no_grad():
                input = torch.autograd.Variable(self.input_buffer)
            if dilation > 1:
                input = input[:, 0::dilation, :].contiguous()
        output = F.linear(input.view(bsz, -1), weight, self.bias)
        return output.view(bsz, 1, -1)

    def clear_buffer(self):
        self.input_buffer = None

    def _get_linearized_weight(self):
        if self._linearized_weight is None:
            kw = self.kernel_size[0]
            # nn.Conv1d
            if self.weight.size() == (self.out_channels, self.in_channels, kw):
                weight = self.weight.transpose(1, 2).contiguous()
            else:
                # fairseq.modules.conv_tbc.ConvTBC
                weight = self.weight.transpose(2, 1).transpose(1, 0).contiguous()
            assert weight.size() == (self.out_channels, kw, self.in_channels)
            self._linearized_weight = weight.view(self.out_channels, -1)
        return self._linearized_weight

    def _clear_linearized_weight(self, *args):
        self._linearized_weight = None
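To make the buffering concrete, a small usage sketch (shapes are illustrative, and the import path assumes the module added in this PR): in eval mode the layer consumes one (B, 1, C) frame per call, shifts it into an internal buffer sized to cover the dilated receptive field, and applies the linearized weight as a single matmul.

```python
import torch
from wavenet_vocoder.conv import Conv1d

conv = Conv1d(in_channels=2, out_channels=4, kernel_size=3, dilation=2)
conv.eval()  # incremental_forward refuses to run in training mode

for t in range(10):
    frame = torch.randn(1, 1, 2)           # (B, T=1, C_in), channels last
    out = conv.incremental_forward(frame)  # one step, no recomputation
    assert out.size() == (1, 1, 4)

conv.clear_buffer()  # reset the cache before synthesizing a new utterance
```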
25 changes: 18 additions & 7 deletions wavenet_vocoder/modules.py
@@ -5,11 +5,26 @@
import numpy as np

import torch
+from wavenet_vocoder import conv
from torch import nn
from torch.autograd import Variable
from torch.nn import functional as F


+def Conv1d(in_channels, out_channels, kernel_size, dropout=0, std_mul=4.0, **kwargs):
+    m = conv.Conv1d(in_channels, out_channels, kernel_size, **kwargs)
+    std = math.sqrt((std_mul * (1.0 - dropout)) / (m.kernel_size[0] * in_channels))
+    m.weight.data.normal_(mean=0, std=std)
+    m.bias.data.zero_()
+    return nn.utils.weight_norm(m)
+
+
+def Embedding(num_embeddings, embedding_dim, padding_idx, std=0.01):
+    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
+    m.weight.data.normal_(0, std)
+    return m


def ConvTranspose2d(in_channels, out_channels, kernel_size,
                    weight_normalization=True, **kwargs):
    freq_axis_kernel_size = kernel_size[0]
@@ -26,13 +41,11 @@ def Conv1d1x1(in_channels, out_channels, bias=True, weight_normalization=True):
"""1-by-1 convolution layer
"""
if weight_normalization:
from deepvoice3_pytorch.modules import Conv1d
assert bias
return Conv1d(in_channels, out_channels, kernel_size=1, padding=0,
dilation=1, bias=bias, std_mul=1.0)
else:
from deepvoice3_pytorch.conv import Conv1d
return Conv1d(in_channels, out_channels, kernel_size=1, padding=0,
return conv.Conv1d(in_channels, out_channels, kernel_size=1, padding=0,
dilation=1, bias=bias)


@@ -85,14 +98,12 @@ def __init__(self, residual_channels, gate_channels, kernel_size,
        self.causal = causal

        if weight_normalization:
-            from deepvoice3_pytorch.modules import Conv1d
            assert bias
            self.conv = Conv1d(residual_channels, gate_channels, kernel_size,
-                               dropout=dropout, padding=padding, dilation=dilation,
+                               padding=padding, dilation=dilation,
                               bias=bias, std_mul=1.0, *args, **kwargs)
        else:
-            from deepvoice3_pytorch.conv import Conv1d
-            self.conv = Conv1d(residual_channels, gate_channels, kernel_size,
+            self.conv = conv.Conv1d(residual_channels, gate_channels, kernel_size,
                               padding=padding, dilation=dilation,
                               bias=bias, *args, **kwargs)

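A quick worked check of the initialization in the new Conv1d factory above, assuming the DeepVoice3-style scheme it mirrors: the weight std shrinks with fan-in so that activation variance stays roughly constant under weight normalization.

```python
import math

# std = sqrt(std_mul * (1 - dropout) / (kernel_size * in_channels))
std_mul, dropout, kernel_size, in_channels = 4.0, 0.0, 3, 64
std = math.sqrt((std_mul * (1.0 - dropout)) / (kernel_size * in_channels))
print(round(std, 4))  # 0.1443 -- larger fan-in gives a smaller init std
```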
7 changes: 4 additions & 3 deletions wavenet_vocoder/wavenet.py
@@ -9,7 +9,7 @@
from torch.autograd import Variable
from torch.nn import functional as F

-from deepvoice3_pytorch.modules import Embedding
+from .modules import Embedding

from .modules import Conv1d1x1, ResidualConv1dGLU, ConvTranspose2d
from .mixture import sample_from_discretized_mix_logistic
@@ -317,12 +317,14 @@ def incremental_forward(self, initial_input=None, c=None, g=None,
            initial_input = initial_input.transpose(1, 2).contiguous()

        current_input = initial_input
+
        for t in tqdm(range(T)):
            if test_inputs is not None and t < test_inputs.size(1):
                current_input = test_inputs[:, t, :].unsqueeze(1)
            else:
                if t > 0:
                    current_input = outputs[-1]
+            current_input = Variable(current_input)

            # Conditioning features for single time step
            ct = None if c is None else c[:, t, :].unsqueeze(1)
@@ -352,8 +354,7 @@
                    np.arange(self.out_channels), p=x.view(-1).data.cpu().numpy())
                x.zero_()
                x[:, sample] = 1.0
-            outputs += [x]
-
+            outputs += [x.data]
        # T x B x C
        outputs = torch.stack(outputs)
        # B x C x T
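As I read it, this last hunk is the core of the synthesis OOM fix: appending x.data (and re-wrapping current_input in a fresh Variable each step) detaches each generated frame from the autograd graph, so memory stays flat instead of growing with every sample. A toy sketch of the same pattern in current PyTorch terms, where detach() plays the role of .data:

```python
import torch

step = torch.nn.Linear(4, 4)   # stand-in for one decoding step
current = torch.zeros(1, 4)
outputs = []

for t in range(1000):          # one iteration per generated sample
    current = step(current)
    # Without detach(), every stored frame would keep its entire graph
    # history alive, and memory would grow linearly with t.
    current = current.detach()
    outputs.append(current)

wave = torch.stack(outputs)    # (T, B, C)
```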