-
Notifications
You must be signed in to change notification settings - Fork 500
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix the oom problem of synthesis and fix dependecy of deep-voice3 #55
Merged
Merged
Changes from 7 commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
c2c9d53
fix dependecy of deep-voice3
azraelkuan 5f6d5a3
fix oom of synthesis
azraelkuan dce2e99
add multi gpu support
azraelkuan 720d331
update readme
azraelkuan fae89e9
remove version error
azraelkuan b2df851
add cpu limit
azraelkuan ea08686
fix synthesis
azraelkuan 847d674
revert the version commit
azraelkuan File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,6 +34,7 @@ | |
from hparams import hparams | ||
|
||
|
||
torch.set_num_threads(4) | ||
use_cuda = torch.cuda.is_available() | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,5 @@ | ||
# coding: utf-8 | ||
from __future__ import with_statement, print_function, absolute_import | ||
|
||
from .version import __version__ | ||
|
||
from .wavenet import receptive_field_size, WaveNet |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
# coding: utf-8 | ||
import torch | ||
from torch import nn | ||
from torch.autograd import Variable | ||
from torch.nn import functional as F | ||
|
||
|
||
class Conv1d(nn.Conv1d): | ||
"""Extended nn.Conv1d for incremental dilated convolutions | ||
""" | ||
|
||
def __init__(self, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
self.clear_buffer() | ||
self._linearized_weight = None | ||
self.register_backward_hook(self._clear_linearized_weight) | ||
|
||
def incremental_forward(self, input): | ||
# input: (B, T, C) | ||
if self.training: | ||
raise RuntimeError('incremental_forward only supports eval mode') | ||
|
||
# run forward pre hooks (e.g., weight norm) | ||
for hook in self._forward_pre_hooks.values(): | ||
hook(self, input) | ||
|
||
# reshape weight | ||
weight = self._get_linearized_weight() | ||
kw = self.kernel_size[0] | ||
dilation = self.dilation[0] | ||
|
||
bsz = input.size(0) # input: bsz x len x dim | ||
if kw > 1: | ||
input = input.data | ||
if self.input_buffer is None: | ||
self.input_buffer = input.new(bsz, kw + (kw - 1) * (dilation - 1), input.size(2)) | ||
self.input_buffer.zero_() | ||
else: | ||
# shift buffer | ||
self.input_buffer[:, :-1, :] = self.input_buffer[:, 1:, :].clone() | ||
# append next input | ||
self.input_buffer[:, -1, :] = input[:, -1, :] | ||
with torch.no_grad(): | ||
input = torch.autograd.Variable(self.input_buffer) | ||
if dilation > 1: | ||
input = input[:, 0::dilation, :].contiguous() | ||
output = F.linear(input.view(bsz, -1), weight, self.bias) | ||
return output.view(bsz, 1, -1) | ||
|
||
def clear_buffer(self): | ||
self.input_buffer = None | ||
|
||
def _get_linearized_weight(self): | ||
if self._linearized_weight is None: | ||
kw = self.kernel_size[0] | ||
# nn.Conv1d | ||
if self.weight.size() == (self.out_channels, self.in_channels, kw): | ||
weight = self.weight.transpose(1, 2).contiguous() | ||
else: | ||
# fairseq.modules.conv_tbc.ConvTBC | ||
weight = self.weight.transpose(2, 1).transpose(1, 0).contiguous() | ||
assert weight.size() == (self.out_channels, kw, self.in_channels) | ||
self._linearized_weight = weight.view(self.out_channels, -1) | ||
return self._linearized_weight | ||
|
||
def _clear_linearized_weight(self, *args): | ||
self._linearized_weight = None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you revert the change?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok