Fix causal block
This commit fixes several issues:
 1. Remove unnecessary convolutions inside the residual block
 2. Apply mask type 'A' on the very first layer (see the mask sketch below)
 3. Fix the channel masking mess
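For fixes 2 and 3, the PixelCNN convention is that mask type 'A' hides the centre pixel from its own prediction and is used only in the very first layer, while type 'B' keeps the centre tap and is used everywhere else. Below is a minimal sketch of the spatial part of such a mask; causal_mask is a hypothetical helper for illustration, not code from this repo, and the repo's MaskedConvolution2D additionally masks colour-channel groups at the centre tap (the channel masking of fix 3), which this sketch leaves out.

import numpy as np

def causal_mask(kh, kw, mask_type='B'):
    # Hypothetical illustration, not the repo's implementation.
    # Zero out kernel taps that would let a pixel see "future" pixels
    # in raster-scan order; type 'A' also hides the centre pixel itself.
    mask = np.ones((kh, kw), dtype=np.float32)
    yc, xc = kh // 2, kw // 2
    mask[yc, xc + 1:] = 0.0   # taps to the right of the centre
    mask[yc + 1:, :] = 0.0    # taps in the rows below the centre
    if mask_type == 'A':      # first layer only
        mask[yc, xc] = 0.0    # the centre tap: the pixel being predicted
    return mask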
rampage644 committed Feb 26, 2017
1 parent 94b962b commit 0ef9c97
Showing 1 changed file with 27 additions and 17 deletions.
wavenet/models.py (27 additions, 17 deletions)
@@ -81,32 +81,42 @@ def __call__(self, x):
 class ResidualBlock(chainer.Chain):
     def __init__(self, in_channels, out_channels, filter_size, mask='B', nobias=False):
         super(ResidualBlock, self).__init__(
-            vertical_conv=CroppedConvolution(
-                in_channels, 2 * out_channels, ksize=[filter_size//2+1, filter_size],
+            vertical_conv_t=CroppedConvolution(
+                in_channels, out_channels, ksize=[filter_size//2+1, filter_size],
                 pad=[filter_size//2+1, filter_size//2]),
-            v_to_h_conv=MaskedConvolution2D(2 * out_channels, 2 * out_channels, 1, mask=mask),
-            vertical_gate_conv=L.Convolution2D(2*out_channels, 2*out_channels, 1),
-            horizontal_conv=CroppedConvolution(
-                in_channels, 2 * out_channels, ksize=[1, filter_size//2+1],
-                pad=[0, filter_size//2+1]),
-            horizontal_gate_conv=L.Convolution2D(2*out_channels, 2*out_channels, 1),
+            vertical_conv_s=CroppedConvolution(
+                in_channels, out_channels, ksize=[filter_size//2+1, filter_size],
+                pad=[filter_size//2+1, filter_size//2]),
+            v_to_h_conv_t=L.Convolution2D(out_channels, out_channels, 1),
+            v_to_h_conv_s=L.Convolution2D(out_channels, out_channels, 1),
+
+            horizontal_conv_t=MaskedConvolution2D(
+                in_channels, out_channels, ksize=[1, filter_size],
+                pad=[0, filter_size // 2], mask=mask),
+            horizontal_conv_s=MaskedConvolution2D(
+                in_channels, out_channels, ksize=[1, filter_size],
+                pad=[0, filter_size // 2], mask=mask),
+
             horizontal_output=MaskedConvolution2D(out_channels, out_channels, 1, mask=mask),
-            label=L.EmbedID(10, 2*out_channels)
+            label=L.EmbedID(10, out_channels)
         )
 
     def __call__(self, v, h, label):
-        v = self.vertical_conv(v)
-        to_vertical = self.v_to_h_conv(v)
+        v_t = self.vertical_conv_t(v)
+        v_s = self.vertical_conv_s(v)
+        to_vertical_t = self.v_to_h_conv_t(v_t)
+        to_vertical_s = self.v_to_h_conv_s(v_s)
 
-        v_gate = self.vertical_gate_conv(v)
-        # label bias is addede to both vertical and horizontal conv
+        # v_gate = self.vertical_gate_conv(v)
+        # label bias is added to both vertical and horizontal conv
         # here we take only shape as it should be the same
-        label = F.broadcast_to(F.expand_dims(F.expand_dims(self.label(label), -1), -1), v_gate.shape)
-        v_t, v_s = F.split_axis(v_gate + label, 2, axis=1)
+        label = F.broadcast_to(F.expand_dims(F.expand_dims(self.label(label), -1), -1), v_t.shape)
+        v_t, v_s = v_t + label, v_s + label
         v = F.tanh(v_t) * F.sigmoid(v_s)
 
-        h_ = self.horizontal_conv(h)
-        h_t, h_s = F.split_axis(self.horizontal_gate_conv(h_ + to_vertical) + label, 2, axis=1)
+        h_t = self.horizontal_conv_t(h)
+        h_s = self.horizontal_conv_s(h)
+        h_t, h_s = h_t + to_vertical_t + label, h_s + to_vertical_s + label
         h = self.horizontal_output(F.tanh(h_t) * F.sigmoid(h_s))
 
         return v, h
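To sanity-check the reworked block, here is a quick smoke test, a sketch only: it assumes the Chainer 1.x API this repo targets, that ResidualBlock can be imported from wavenet.models, and that the cropped and masked convolutions preserve spatial shape; the 28x28 shapes are illustrative. After this commit, each stack gets its own tanh/sigmoid pair, the gated activation tanh(.) * sigmoid(.) of the conditional PixelCNN, with the label embedding added as a bias to both stacks.

import numpy as np
import chainer
from wavenet.models import ResidualBlock

# Dummy vertical and horizontal stacks plus a class label;
# EmbedID(10, ...) expects an integer id in [0, 10).
block = ResidualBlock(in_channels=32, out_channels=32, filter_size=3)
v = chainer.Variable(np.zeros((1, 32, 28, 28), dtype=np.float32))  # vertical stack
h = chainer.Variable(np.zeros((1, 32, 28, 28), dtype=np.float32))  # horizontal stack
label = chainer.Variable(np.array([7], dtype=np.int32))            # conditioning class
v_out, h_out = block(v, h, label)  # both stacks should keep shape (1, 32, 28, 28)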
