Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BahdanauMonoAttention cannot work well #8

Closed
zhbbupt opened this issue Jan 9, 2018 · 2 comments
Closed

BahdanauMonoAttention cannot work well #8

zhbbupt opened this issue Jan 9, 2018 · 2 comments

Comments

@zhbbupt
Copy link

zhbbupt commented Jan 9, 2018

I am following the monotonic attention mechanism described here: https://arxiv.org/pdf/1704.00784.pdf.

In TensorFlow, it works well. (Source code here: https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py.)

But in PyTorch, it does not work. Here is my source code. Could you take a look, please?

def safe_cumprod(x, exclusive=False, max_value=1):
    """Cumulative product along dim 1, computed in log space.

    The input is clamped to ``[tiny, max_value]`` (``tiny`` being the smallest
    positive normal float32) before taking the logarithm, so zeros in ``x`` do
    not produce ``log(0) = -inf``.

    exclusive=True: cumprod(x) = [1, x1, x1*x2, x1*x2*x3, ...]
    exclusive=False: cumprod(x) = [x1, x1*x2, x1*x2*x3, ...]

    Args:
        x (torch.Tensor): shape of [batch, input_dim]
        exclusive (bool): if truthy, shift the result right by one position and
            start each row with 1 (the output keeps the shape of ``x``).
        max_value (float): upper clip value applied to ``x``.

    Returns:
        torch.Tensor: cumulative product with the same shape as ``x``.
    """
    tiny = float(np.finfo(np.float32).tiny)
    clip_x = torch.clamp(x, tiny, max_value)
    cumprod_x = torch.exp(torch.cumsum(torch.log(clip_x), dim=1))
    # Truthiness instead of ``is True`` so flags such as ``1`` or
    # ``np.bool_(True)`` also select the exclusive variant.
    if not exclusive:
        return cumprod_x
    # Prepend a column of ones and drop the last column.
    return F.pad(cumprod_x, (1, 0, 0, 0), value=1)[:, :-1]


class BahdanauAttention(nn.Module):
    """Additive (Bahdanau-style) attention scorer.

    ``forward`` produces unnormalised scores ``v . tanh(W q + processed_memory)``
    for every memory position; ``_alignment_probability`` turns those scores
    into attention weights with a softmax over the time axis.
    """

    def __init__(self, dim):
        """Create the query projection and the scoring vector ``v``.

        Args:
            dim: feature size shared by the query and the processed memory.
        """
        super(BahdanauAttention, self).__init__()
        self.query_layer = nn.Linear(dim, dim, bias=False)
        self.tanh = nn.Tanh()
        self.v = Parameter(torch.Tensor(1, dim))
        self.reset_parameters()

    def reset_parameters(self):
        """Draw ``v`` uniformly from ``[-bound, bound]`` based on its average fan."""
        rows, cols = self.v.size()
        mean_fan = (rows + cols) / 2.
        bound = math.sqrt(3.0 * (1 / max(1., mean_fan)))
        self.v.data.uniform_(-bound, bound)

    def _alignment_probability(self, score, previous_alignment=None):
        """Normalise scores over dim 1; ``previous_alignment`` is ignored here."""
        return F.softmax(score, dim=1)

    def forward(self, query, processed_memory):
        """Score every memory position against the query.

        Args:
            query: (batch, 1, dim) or (batch, dim)
            processed_memory: (batch, max_time, dim)

        Returns:
            Unnormalised scores of shape (batch, max_time).
        """
        # A 2-D query gets a time axis so it broadcasts against the memory.
        projected = self.query_layer(
            query.unsqueeze(1) if query.dim() == 2 else query)

        # (batch, max_time, dim) -> (batch, max_time, 1) -> (batch, max_time)
        energies = self.tanh(projected + processed_memory)
        return F.linear(energies, self.v).squeeze(-1)


class BahdanauMonoAttention(BahdanauAttention):
    """Bahdanau attention with monotonic alignments.

    Follows "Online and Linear-Time Attention by Enforcing Monotonic
    Alignments" (https://arxiv.org/pdf/1704.00784.pdf): scores receive a
    learned scalar bias, are squashed with a sigmoid into per-position
    "choose" probabilities, and are combined with the previous step's
    alignment to give the expected monotonic attention distribution.

    NOTE(review): the pre-sigmoid noise used in the paper to push the choose
    probabilities towards 0/1 during training is not applied here.
    """

    def __init__(self, dim):
        """Build the base scorer and add the scalar score bias.

        Args:
            dim: feature size shared by the query and the processed memory.
        """
        super(BahdanauMonoAttention, self).__init__(dim)
        # Created already zeroed. The base constructor has run
        # ``reset_parameters`` once by now, so ``v`` is initialised too.
        self.score_bias = Parameter(torch.zeros(1))

    def reset_parameters(self):
        """Re-initialise ``v`` (via the base class) and zero the score bias."""
        # The base implementation must run, otherwise ``v`` keeps the
        # uninitialised memory it was allocated with.
        super(BahdanauMonoAttention, self).reset_parameters()
        # The base ``__init__`` calls this method before ``score_bias`` has
        # been created, so tolerate its absence instead of raising
        # AttributeError during construction.
        score_bias = getattr(self, "score_bias", None)
        if score_bias is not None:
            score_bias.data.zero_()

    def forward(self, query, processed_memory):
        """Return the base attention scores shifted by the learned bias.

        Args:
            query: (batch, 1, dim) or (batch, dim)
            processed_memory: (batch, max_time, dim)

        Returns:
            Scores of shape (batch, max_time).
        """
        return super(BahdanauMonoAttention, self).forward(query, processed_memory) + self.score_bias

    def _alignment_probability(self, score, previous_alignment=None):
        """Compute the expected monotonic attention distribution.

        _mono_score, https://arxiv.org/pdf/1704.00784.pdf

        Args:
            score (): shape of [batch, encoder_length]
            previous_alignment (): shape of [batch, encoder_length]. When
                ``None``, attention is assumed to start on the first memory
                entry (a one-hot row at index 0).

        Returns:
            Attention weights of shape [batch, encoder_length]. Rows need not
            sum to 1.
        """
        if previous_alignment is None:
            # First decoding step: all probability mass on position 0.
            previous_alignment = torch.zeros_like(score)
            previous_alignment[:, 0] = 1.0

        # Probability of stopping at each memory position.
        p_choose_i = torch.sigmoid(score)
        # Exclusive cumulative product of (1 - p): chance of having skipped
        # every earlier position.
        cumprod_1mp_choose_i = safe_cumprod(1 - p_choose_i, exclusive=True, max_value=1)
        # The clamp keeps the division finite when the cumulative product
        # underflows towards zero.
        attention = p_choose_i * cumprod_1mp_choose_i * torch.cumsum(
            previous_alignment / torch.clamp(cumprod_1mp_choose_i, 1e-10, 1.), dim=1)
        return attention



def get_mask_from_lengths(memory, memory_lengths):
    """Build a padding mask from per-sequence lengths.

    Args:
        memory: (batch, max_time, dim)
        memory_lengths: array like, one valid length per batch entry

    Returns:
        Boolean tensor of shape (batch, max_time) on ``memory``'s device that
        is ``True`` at padded positions (time index >= length) and ``False``
        at valid ones, ready to be passed to ``masked_fill``.

    Raises:
        ValueError: if the number of lengths differs from the batch size.
    """
    batch_size, max_time = memory.size(0), memory.size(1)
    lengths = torch.as_tensor(
        memory_lengths, dtype=torch.long, device=memory.device).reshape(-1)
    if lengths.numel() != batch_size:
        raise ValueError(
            "expected %d memory lengths, got %d" % (batch_size, lengths.numel()))

    # Compare positions with lengths directly so the result is a real boolean
    # mask. Inverting a uint8 mask with ``~`` is a bitwise NOT (254/255), which
    # marks every position as padded.
    positions = torch.arange(max_time, device=memory.device)
    return positions.unsqueeze(0) >= lengths.unsqueeze(1)


class AttentionWrapper(nn.Module):
    """Run one step of an RNN cell followed by attention over a memory.

    Args:
        rnn_cell: cell called as ``rnn_cell(cell_input, cell_state)``; its
            return value is used directly as the attention query.
        attention_mechanism: module returning (batch, max_time) scores from
            ``(query, processed_memory)`` and exposing
            ``_alignment_probability(score, previous_alignment)``.
        score_mask_value: value written into the scores at masked positions.
    """

    def __init__(self, rnn_cell, attention_mechanism,
                 score_mask_value=-float("inf")):
        super(AttentionWrapper, self).__init__()
        self.rnn_cell = rnn_cell
        self.attention_mechanism = attention_mechanism
        self.score_mask_value = score_mask_value

    def forward(self, query, attention, cell_state, memory, previous_alignment=None,
                processed_memory=None, mask=None, memory_lengths=None):
        """Advance the cell one step and attend over ``memory``.

        Args:
            query: current input, concatenated with ``attention`` on the last axis.
            attention: attention context from the previous step.
            cell_state: previous state of ``rnn_cell``.
            memory: (batch, max_time, dim) values the context is read from.
            previous_alignment: alignment of the previous step, forwarded to
                the attention mechanism's ``_alignment_probability``.
            processed_memory: memory as seen by the scorer; defaults to ``memory``.
            mask: mask viewable as (batch, max_time); truthy entries are
                overwritten with ``score_mask_value`` before normalisation.
            memory_lengths: used to build ``mask`` when ``mask`` is not given.

        Returns:
            Tuple ``(cell_output, attention, alignment)`` where ``attention``
            is the new (batch, dim) context and ``alignment`` the
            (batch, max_time) attention weights.
        """
        if processed_memory is None:
            processed_memory = memory
        if memory_lengths is not None and mask is None:
            mask = get_mask_from_lengths(memory, memory_lengths)

        # Concat input query and previous attention context
        cell_input = torch.cat((query, attention), -1)

        # Feed it to RNN
        cell_output = self.rnn_cell(cell_input, cell_state)

        # Alignment
        # (batch, max_time)
        alignment = self.attention_mechanism(cell_output, processed_memory)

        if mask is not None:
            mask = mask.view(query.size(0), -1)
            # Out-of-place fill: mutating ``alignment.data`` in place would
            # hide the masking from autograd and overwrite the tensor the
            # attention mechanism returned.
            alignment = alignment.masked_fill(mask, self.score_mask_value)

        # Normalize attention weight
        alignment = self.attention_mechanism._alignment_probability(alignment, previous_alignment)

        # Attention context vector
        # (batch, 1, dim) -> (batch, dim)
        attention = torch.bmm(alignment.unsqueeze(1), memory).squeeze(1)

        return cell_output, attention, alignment
class Decoder(nn.Module):
    """Autoregressive attention decoder producing ``r`` frames per step.

    Each step runs: prenet -> attention GRU cell (with Bahdanau or monotonic
    Bahdanau attention over the encoder outputs) -> projection -> a stack of
    residual GRU cells -> linear projection to ``in_dim * r`` outputs.

    The layer sizes are hard-coded: encoder outputs must have 256 features.
    NOTE(review): the prenet is presumably expected to emit 128 features (the
    attention cell takes 256 + 128 inputs) -- confirm against ``Prenet``.

    Args:
        in_dim: feature size of a single output frame.
        r: number of frames predicted per decoder step.
        use_mono: if truthy use ``BahdanauMonoAttention``, otherwise
            ``BahdanauAttention``.
    """

    def __init__(self, in_dim, r, use_mono=True):
        super(Decoder, self).__init__()
        self.in_dim = in_dim
        self.r = r
        self.prenet = Prenet(in_dim, sizes=[256, 128])
        # (prenet_out + attention context) -> output
        # Truthiness rather than ``is True`` so e.g. ``use_mono=1`` is honoured.
        if use_mono:
            attention_mechanism = BahdanauMonoAttention(256)
        else:
            attention_mechanism = BahdanauAttention(256)
        self.attention_rnn = AttentionWrapper(
            nn.GRUCell(256 + 128, 256),
            attention_mechanism
        )
        self.memory_layer = nn.Linear(256, 256, bias=False)
        self.project_to_decoder_in = nn.Linear(512, 256)

        self.decoder_rnns = nn.ModuleList(
            [nn.GRUCell(256, 256) for _ in range(2)])

        self.proj_to_mel = nn.Linear(256, in_dim * r)
        # Upper bound on the number of steps taken in greedy decoding.
        self.max_decoder_steps = 200

    @staticmethod
    def _zeros(reference, *sizes):
        """Return a zero-filled Variable of shape ``sizes`` with ``reference``'s tensor type."""
        return Variable(reference.data.new(*sizes).zero_())

    def forward(self, encoder_outputs, inputs=None, memory_lengths=None):
        """
        Decoder forward step.

        If decoder inputs are not given (e.g., at testing time), as noted in
        Tacotron paper, greedy decoding is adapted.

        Args:
            encoder_outputs: Encoder outputs. (B, T_encoder, dim)
            inputs: Decoder inputs. i.e., mel-spectrogram. If None (at eval-time),
              decoder outputs are used as decoder inputs.
            memory_lengths: Encoder output (memory) lengths. If not None, used for
              attention masking.

        Returns:
            Tuple ``(outputs, alignments)``: ``outputs`` is
            (B, T_decoder, in_dim * r) and ``alignments`` is
            (B, T_decoder, T_encoder), where T_decoder is the number of
            decoder steps taken.
        """
        B = encoder_outputs.size(0)
        T_encoder = encoder_outputs.size(1)

        processed_memory = self.memory_layer(encoder_outputs)
        if memory_lengths is not None:
            mask = get_mask_from_lengths(processed_memory, memory_lengths)
        else:
            mask = None

        # Run greedy decoding if inputs is None
        greedy = inputs is None

        T_decoder = None
        if not greedy:
            # Grouping multiple frames if necessary
            if inputs.size(-1) == self.in_dim:
                inputs = inputs.view(B, inputs.size(1) // self.r, -1)
            assert inputs.size(-1) == self.in_dim * self.r
            T_decoder = inputs.size(1)
            # Time first (T_decoder, B, in_dim * r)
            inputs = inputs.transpose(0, 1)

        # go frames
        current_input = self._zeros(encoder_outputs, B, self.in_dim)

        # Init decoder states
        attention_rnn_hidden = self._zeros(encoder_outputs, B, 256)
        decoder_rnn_hiddens = [
            self._zeros(encoder_outputs, B, 256) for _ in self.decoder_rnns]
        current_attention = self._zeros(encoder_outputs, B, 256)

        # Attention starts on the first encoder position.
        previous_alignment = self._zeros(encoder_outputs, B, T_encoder)
        previous_alignment[:, 0] = 1.0

        outputs = []
        alignments = []

        t = 0
        while True:
            if t > 0:
                current_input = outputs[-1] if greedy else inputs[t - 1]
                # Only the last frame of the previous group is fed back.
                current_input = current_input[:, -self.in_dim:]
            # Prenet
            current_input = self.prenet(current_input)

            # Attention RNN
            attention_rnn_hidden, current_attention, alignment = self.attention_rnn(
                current_input, current_attention, attention_rnn_hidden,
                encoder_outputs, previous_alignment=previous_alignment,
                processed_memory=processed_memory, mask=mask)
            previous_alignment = alignment

            # Concat RNN output and attention context vector
            decoder_input = self.project_to_decoder_in(
                torch.cat((attention_rnn_hidden, current_attention), -1))

            # Pass through the decoder RNNs
            for idx, decoder_rnn in enumerate(self.decoder_rnns):
                decoder_rnn_hiddens[idx] = decoder_rnn(
                    decoder_input, decoder_rnn_hiddens[idx])
                # Residual connection
                decoder_input = decoder_rnn_hiddens[idx] + decoder_input

            output = self.proj_to_mel(decoder_input)

            outputs.append(output)
            alignments.append(alignment)

            t += 1

            if greedy:
                if t > 1 and is_end_of_frames(output):
                    break
                elif t > self.max_decoder_steps:
                    print("Warning! doesn't seems to be converged")
                    break
            else:
                if t >= T_decoder:
                    break

        assert greedy or len(outputs) == T_decoder

        # Back to batch first
        alignments = torch.stack(alignments).transpose(0, 1)
        outputs = torch.stack(outputs).transpose(0, 1).contiguous()

        return outputs, alignments

@r9y9

@zhbbupt zhbbupt mentioned this issue Jan 9, 2018
@stale
Copy link

stale bot commented May 30, 2019

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@stale stale bot added the wontfix label May 30, 2019
@r9y9 r9y9 added discussion and removed wontfix labels May 30, 2019
@stale
Copy link

stale bot commented Jul 29, 2019

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@stale stale bot added the wontfix label Jul 29, 2019
@stale stale bot closed this as completed Aug 5, 2019
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

No branches or pull requests

2 participants