
Include references to teacher_forcing_ratio and padding_idx in seq2seq_translation_tutorial #2870

Open

wants to merge 6 commits into base: main

Conversation

brcolli

@brcolli brcolli commented May 15, 2024

Fixes #2840

Description

I have added details regarding teacher_forcing_ratio, as well as padding_idx. I agree that padding_idx needs to be set to 0 to properly handle the padding.

I do think having batch processing is still meaningful, even if the sentences are short (max 10 words in the tutorial).

I didn't want to include too much discussion about the pros/cons of batch processing and padding, as I feel that might be out of scope.
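For readers skimming this thread, here is a minimal sketch of how the two settings discussed above fit together; it is not the tutorial's exact code, and the sizes and the 0.5 ratio are placeholders for illustration.

import random
import torch.nn as nn

# illustrative sizes only
vocab_size, hidden_size = 12, 8

# padding_idx=0: index 0 maps to a zero vector whose embedding is never updated
embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=0)

# matching ignore_index so padded target positions do not contribute to the loss
criterion = nn.NLLLoss(ignore_index=0)

# teacher_forcing_ratio: with this probability, feed the ground-truth token to the
# decoder at the next step instead of the decoder's own prediction
teacher_forcing_ratio = 0.5
use_teacher_forcing = random.random() < teacher_forcing_ratio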

Checklist

  • The issue that is being fixed is referenced in the description (see above "Fixes #ISSUE_NUMBER")
  • Only one issue is addressed in this pull request
  • Labels from the issue that this PR is fixing are added to this pull request
  • No unnecessary issues are included in this pull request

cc @albanD


pytorch-bot bot commented May 15, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/tutorials/2870

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@svekars
Contributor

svekars commented May 15, 2024

CC: @spro

@brcolli
Author

brcolli commented May 16, 2024

Hello! Possibly a silly question, but I see my PR failed from a spell check issue. It looks like it's complaining about AttnDecoderRNN.forward_step(). Is this something I add as an ignore (regex?) in the .pyspelling.yml file or something?

@gavril0

gavril0 commented May 17, 2024

Hi, I am still interested in this issue as I am still working on porting this tutorial to R torch. Using batch training
with variable-length sequences brings many subtle issues which I think should be discussed in a tutorial that
implements such a case. In case this is useful, I am happy to share my current implementation with some comments.

For the encoder, it is necessary to pad and pack the sentences to take advantage of torch's built-in capacity
to deal with variable-length sequences. The padding value must be indicated at the embedding stage,
and packing is necessary for the RNN modules to correctly process sentences of variable length within the batch.

s2s_encoder <- nn_module(
 "s2s_encoder",  
 initialize = function(input_size, embedding_size, hidden_size, padding=NULL) {
   self$padding <- padding
   # note: the padding value cannot be 0 for embeddings in R torch (it must be a valid R index)
   self$embedding <- nn_embedding(input_size, embedding_size, padding_idx = padding)
   self$gru <- nn_gru(embedding_size, hidden_size, batch_first=TRUE)
 },
 # the encoder input is a list of tensors that code each input sentence in the batch
 forward = function(list_tensors) {
    # length of sentences
    input_len <- lengths(list_tensors)     
     # 2d tensor (batch_size, max_len) where max_len = max(input_len) and shorter sentences are padded
    padded <- nn_utils_rnn_pad_sequence(list_tensors, batch_first = TRUE, padding_value = self$padding)
     # 3d tensor (batch_size, seq_len, embedding_size); shorter sequences are padded with 0
    embedded <- self$embedding$forward(padded)   
    # pack embedded tensor
    packed <- nn_utils_rnn_pack_padded_sequence(embedded, torch_tensor(input_len), 
       batch_first=TRUE, enforce_sorted=FALSE)     
     # RNN with one layer; the packed input corresponds to (batch_size, max_len, embedding_size).
     # By default, the initial value of the hidden state is zero.
     out <- self$gru(packed)
     # unpack the history of the GRU (last layer) hidden states into a padded 3d tensor
    padded <- nn_utils_rnn_pad_packed_sequence(out[[1]],  batch_first = TRUE, padding_value = 0)     
    list(H=torch_transpose(padded[[1]], dim0=2, dim1=3) , lengths=as.numeric(padded[[2]]), hidden=out[[2]])
 }
)

For the decoder, it is not possible to take advantage of padding and packing in the same way
because the input to the decoder at each step depends on the hidden state at the previous
step, which requires implementing the decoding process step by step. More precisely,
at each time step, it is necessary to

  1. update the decoder RNN hidden state
  2. use all encoder hidden states and the current decoder hidden state to compute attention weights
  3. use the attention weights to define a new current context
  4. use the new context to find the best output; the best output and the new context are used as input for the next step

My implementation loosely follows the tutorial.

In my implementation, the attention modules take the encoder hidden states (H) and the current decoder
hidden state (si) to compute the attention scores.

# dot attention
attn_dot_module <- nn_module(
  initialize = function() {
    # this module does not have any parameter
  },
  # compute attention weights using dot product:
  #         eij = si ^T hj for j=1,...,seq_len
  # where
  # + si: decoder hidden state  (batch_size, 1, hidden_size)
  # + H: history of encoder hidden states  (batch_size, hidden_size, max_len) and 
  #      max_len is the maximum length of input sequences
  forward = function(si, H) {
    # (B, 1, seq_len) = (B, 1, hidden_size)*(B, hidden_size, seq_len) 
    torch_matmul(si, H) 
  }
)

# general attention
attn_general_module <- nn_module(
  initialize = function(hidden_size) {
    self$Wa <- nn_parameter( torch_randn(hidden_size, hidden_size) )
  },
  # compute attention weights using generalized dot product:
  #         eij = si ^T Wa hj for j=1,...,seq_len
  # where Wa is a matrix of parameters (hidden_size, hidden_size)
  forward = function (si, H) {
    #              (si Wa)       H            si      Wa     H
    # (B, 1, m) = (B, 1, H)  (B, H, m) =  (B, 1, H) (H, H) (B, H, m) 
    torch_matmul(si, torch_matmul(self$Wa, H)) 
  }
)

# concat attention
attn_concat_module <- nn_module(
  "attn_concat_module",
  initialize = function(hidden_size) {
      self$Wa <- nn_parameter( torch_randn(hidden_size, hidden_size) )
      self$Ua <- nn_parameter( torch_randn(hidden_size, hidden_size) )
      self$Va <- nn_parameter( torch_randn(1, hidden_size) )
  },
  forward = function (si, H) {
    # Ua*H:           (h, h)*(b, h, m) => (b, h, m) 
    # Wa*si:          (h, h)*(b, h, 1) => (b, h, 1) 
    # Wa*si + Ua*H :  (b, h, 1)+(b, h, m) (broadcasting)=> (b, h, m)+(b, h, m) => (b, h, m) 
    # Va^T*tanh(...): (1, h)*(b,h,m) => (b, 1, m)
    torch_matmul(self$Va,
      torch_tanh( torch_add(
        torch_matmul(self$Wa, torch_transpose(si,2,3)),
        torch_matmul(self$Ua, H)))) 
  }
)

To compute the attention weights, the function masked_softmax masks the positions beyond the
end of the shorter input sequences in the batch so that the scores are normalized correctly.

#  Compute attention weights by normalizing scores computed by attention module
# + scores: scores computed by attention module (batch_size, 1, max_len)
# + lengths: lengths of input sentences
masked_softmax <- function(scores, lengths) {
  d <- scores$size()
  attn_weights <- scores$exp() 
  # mask extra attention scores 
  mask <- array(matrix(1:d[3], d[1], d[3], byrow=TRUE) <=  matrix(lengths, d[1], d[3]), d)
  attn_weights[!mask] <- 0 # set to 0 elements that are longer than the sequence(s).
  # compute the attention weights by dividing each score by the sum
  # note: the sum is broadcast before performing the elementwise division
  attn_weights$div(attn_weights$sum(dim=3)$unsqueeze(3))
}

The following decoder module performs a single decoding step:

# This module implements a single step of the decoding process
s2s_attn_decoder <- nn_module(
  "s2s_attn_decoder",
  # Initialization of attention decoder instance
  # - attn_module: instance of an attention module.
  # - output_size: size of the target vocabulary including special tokens
  # - embedding_size: size of the embedding
  # - hidden_size: size of the RNN hidden state
  initialize= function(attn_module, embedding_size, hidden_size, output_size) {
    self$attn <- attn_module
    self$embedding <-  nn_embedding(output_size, embedding_size)
    self$gru <- nn_gru(embedding_size+hidden_size, hidden_size, batch_first=TRUE)
    self$output <- nn_linear(hidden_size, output_size)
  },
  # Perform a single step of the decoding process
  # - input: integers representing words/tokens (batch_size, 1). The integer values must be
  #          between 1 and vocab_size inclusive.
  # - context: context (batch_size, 1, hidden_size)
  # - hidden: initial value of the RNN hidden state (1, batch_size, hidden_size)  
  # - H: history of encoder hidden states (batch_size, hidden_size, max_input_len).  
  # - lengths: lengths of the input sentences in the batch
  # All arguments must be tensors.
  forward = function(input, context, hidden, H, lengths) {
    #
    # 1. update decoder RNN hidden state
    #
    # embed input
    embedded <- self$embedding(input) |> nnf_relu()      # (batch_size, 1, embedding_size)
    # the input of the RNN is the concatenation of the embedded token and context 
    # note: for concatenation along the third dimension, the two first dimensions must 
    #       have the same number of elements
    dec_input  <- torch_cat(list(embedded, context), dim=3)  # (batch_size, 1, embedding_size + hidden_size)
    # update hidden state
    dec_out <- self$gru(dec_input, hx=hidden) 
    # the output is the updated hidden state
    # note: since there is a single time step and a single layer, the two outputs are equivalent
    #       except for the order of the dimensions
    si <- dec_out[[1]]       # (batch_size, 1, hidden_size) history of last-layer hidden states
    hidden <- dec_out[[2]]   # (1, batch_size, hidden_size) final value of the hidden state
    #
    # 2. compute attention and a new context
    #
    # attention weights
    scores <- self$attn$forward(si, H)               # (batch_size, 1, seq_len)
    aij <- masked_softmax(scores, lengths)
    # update the context by computing a weighted average of the encoder hidden states
    # (batch_size, 1, max_len) (batch_size, max_len, hidden_size) => (batch_size, 1, hidden_size)
    ci <- torch_matmul(aij, torch_transpose(H, dim0=2, dim1=3))
    #
    # 3. identify word with highest probability 
    #
    # use a linear transformation to map the new context into the output space, whose dimension is
    # the size of the target vocabulary, and apply a log-softmax so that the values correspond to
    # a (log) probability distribution
    prob <- self$output$forward(ci) |>  nnf_log_softmax(dim=3)  # (batch_size, 1, output_size)
    best_token <- torch_topk(prob, k=1, dim=3)[[2]]$squeeze(3)  # (batch_size, 1)
    #
    # output
    list(prob=prob, best_token=best_token, attn_weights=aij, context=ci, hidden=hidden)
  }
)

Note that the input to the decoder at each step is the concatenation of the previously generated word
with a context that is updated by the attention mechanism at each decoder step
(as in Bahdanau's model and in Luong's input-feeding approach).

A final module puts everything together and implements the full encoding and decoding process.
The module implements a forward function that generates/translates a batch of sentences (forward),
which makes sense to use once the network is trained, and a loss function that computes the loss
during batch training:

  • In the forward function, the generation of new words must stop when the special EOS
    token is generated. It is also necessary to limit the number of decoding steps in case
    the decoder does not generate the EOS token.
  • In the loss function, the number of decoding steps during training is limited by the lengths
    of the target sentences (both the outputs and the targets are needed to compute the loss). Also,
    when using teacher forcing, the target sentences are used to define the inputs of the decoder.

The implementation is a bit tricky because the number of sentences to be processed in the
batch decreases during the decoding process, either because an EOS token has been
generated or because the target sentence is shorter.

s2s_attn_module <- nn_module(
  "s2s_attn, module",
  initialize = function(encoder, decoder, special_tokens, optim_fun, learning_rate) {
    self$special_tokens <- special_tokens
    self$encoder <- encoder
    self$decoder <- decoder
    self$enc_optim <- optim_fun(self$encoder$parameters, lr=learning_rate)
    self$dec_optim <- optim_fun(self$decoder$parameters, lr=learning_rate)
  },  
  # Translate a batch of input sentences into the target language.
  # + input: list of sentences (1d tensors of token indices)
  # + max_output_len: maximum length of the output sentences
  forward = function(input, max_output_len) {
    batch_size <- length(input)
    input_len <- lengths(input) # lengths of input sentences
    # vector to store output sentence
    output <- matrix(0, batch_size, max_output_len)
    # array to store the attention weights 
    attn_weights <- array(0, c(batch_size, max(input_len), max_output_len)) 
    # 1. Encoding 
    enc_out <- self$encoder$forward(input) 
    # 2. Decoding  
    # initial input and state of the decoder  
    sos_token <- torch_tensor(matrix(self$special_tokens["<SOS>"], batch_size, 1), dtype=torch_long())
    state <- list(best_token=sos_token,                        # (batch_size, 1) 
      context = torch_transpose(enc_out$hidden,dim0=1,dim1=2), # (batch_size, 1, hidden_size)
      hidden = enc_out$hidden)                                 # (1, batch_size, hidden_size)
    # flag completed output sentences
    b_index <- 1:batch_size        # index of unterminated sentences  in the batch 
    index <- rep(TRUE, batch_size) # sub-index 
    for (i in 1:max_output_len) { # i <- 1
      # decoder step
      # the batch dimension decreases when sentences are terminated. The `index` vector
      # selects the hidden states and output of unterminated sentences from previous step for 
      # the next step. The `b_index` numerical vector contains the positions of the unterminated 
      # sentences in the original batch.          
      state <- self$decoder$forward(
                      input  = state$best_token[index,,drop=FALSE], 
                      context= state$context[index,,,drop=FALSE], 
                      hidden = state$hidden[,index,,drop=FALSE],
                      H = enc_out$H[b_index,,,drop=FALSE],
                      lengths=enc_out$lengths[b_index])
      # store attention weights (batch_size, 1, seq_length) into 3d array (batch_size, seq_length, max_length)
      attn_weights[b_index,,i] <- as.numeric(state$attn_weights[,1,])
      # store output (best token)
      output[b_index,i] <- as.numeric(state$best_token)
      # check which sentences in the batch are not yet terminated (no EOS generated)
      index <- output[b_index,i]!=self$special_tokens["<EOS>"]
      b_index <- b_index[index]  # keep unterminated sentences
      # exit from loop if all sentences are terminated
      if(length(b_index)==0) break
    }
    # return output
    list(output=output[,1:i,drop=FALSE], attn_weights=attn_weights[,,1:i,drop=FALSE])
  },
  # compute the loss
  loss = function(input, target, loss_reduction="mean", teacher_forcing=FALSE) {
    if(length(input)!=length(target)) stop("input and target must have the same number of sentences")
    batch_size <- length(input)
    target_len <- lengths(target)          # lengths of target sentences
    max_target_len <- max(target_len)      # maximum length of target sentences
    # initial loss value 
    total_loss <- torch_tensor(0) 
    losses <- if(loss_reduction=="none") {
        matrix(0, batch_size, max_target_len) # for debugging
      } else NULL
    # transform list of target sentences in a padded 2d array (batch_size, seq_len)
    # padding value (by default -100, see nnf_nll_loss)
    padding <- if(!is.null(self$special_tokens["<PAD>"])) self$special_tokens["<PAD>"] else -100
    target_padded <- nn_utils_rnn_pad_sequence(target, batch_first = TRUE, padding_value = padding)
    # vector to store output sentence
    output <- matrix(0, batch_size, max_target_len)
    # array to store the attention weights 
    attn_weights <- array(0, c(batch_size, max(lengths(input)), max_target_len)) 
    #
    # 1. Encoding 
    #
    enc_out <- self$encoder$forward(input) 
    #
    # 2. Decoding  
    #
    # initial input and state of the decoder  
    sos_token <- torch_tensor(matrix(self$special_tokens["<SOS>"], batch_size, 1), dtype=torch_long())
    state <- list(best_token=sos_token,                        # (batch_size, 1) 
      context = torch_transpose(enc_out$hidden,dim0=1,dim1=2), # (batch_size, 1, hidden_size)
      hidden = enc_out$hidden)                                 # (1, batch_size, hidden_size)
    # flag completed output sentences
    b_index <- 1:batch_size        # index of unterminated sentences  in the batch 
    index <- rep(TRUE, batch_size) # sub-index 
    for (i in 1:max_target_len) { # i <- 4
      # decoder step
      state <- self$decoder$forward(
                      input  = state$best_token[index,,drop=FALSE], 
                      context= state$context[index,,,drop=FALSE], 
                      hidden = state$hidden[,index,,drop=FALSE],
                      H = enc_out$H[b_index,,,drop=FALSE],
                      lengths=enc_out$lengths[b_index])
      # store attention weights (batch_size, 1, seq_length) into 3d array (batch_size, seq_length, max_length)
      attn_weights[b_index,,i] <- as.numeric(state$attn_weights[,1,])
      # store output (best token)
      output[b_index,i] <- as.numeric(state$best_token)
      # compute loss
      #   `input` must have the format (batch_size, output_size)
      #   `target` must have the format (batch_size)
      # ignore_index is not needed because only unterminated sentences are processed
      loss <- nnf_nll_loss(input=state$prob$squeeze(2), target=target_padded[b_index,i]$view(length(b_index)), 
          reduction=loss_reduction) # ignore_index = padding
      if(loss_reduction=="none") {  
        # compute and store the loss for each word of every sentence in the batch separately (for debugging)
        losses[b_index,i] <- as.numeric(loss)  
        # total loss 
        total_loss <- total_loss + torch_sum(loss)
      } else {
        total_loss <- total_loss + loss
      }
      # use word in target sentence as input
      if(teacher_forcing) state$best_token <- target_padded[b_index,i,drop=FALSE]
      # keep unterminated sentences, i.e. sentences that are longer than the target sentence
      index <- target_len[b_index]>i
      b_index <- b_index[index]       # keep unterminated sentences
    }
    # return output
    list(output=output, attn_weights=attn_weights, losses=losses, total_loss=total_loss)
  },
  # Training function
  # + input: list of tensors with the tokens representing the input sequences
  # + target: list of tensors with the tokens representing the target sequences
  # + teacher_forcing_ratio: probability of using teacher forcing for the batch
  #                          during training (defaults to 0).
  train = function(input, target, teacher_forcing_ratio=0) {
    # zero gradients
    self$enc_optim$zero_grad()
    self$dec_optim$zero_grad()
    # compute loss
    forcing <- rbinom(1, 1, prob=teacher_forcing_ratio)==1
    val <- self$loss(input=input, target=target, teacher_forcing=forcing, loss_reduction="mean")
    # compute gradient
    val$total_loss$backward()
    # update parameters
    self$enc_optim$step()
    self$dec_optim$step()
    # 
    list(output=val$output, loss=as.numeric(val$total_loss))
  }
)

In this implementation of the decoder, I use the two indices (b_index) and (index) to track the sentences
in the batch that still need to be processed. During training, the loss is computed only for these sentences.
Although padding is used to define a 2d tensor (target_padded) that contains all the target sentences
in the batch, the decoder does not actually process any padded value because it only processes the sentences
that still need to be processed in the batch.

The full model can be tested as follows:

# vocabulary   
vocab <- list(input=1:6, target=1:6, special_tokens=c("<PAD>"=1, "<EOS>"=2, "<SOS>"=3))

# dimensions
input_size <- length(vocab$input)
output_size <- length(vocab$target)
embedding_size <- 3
hidden_size <- 4

# instantiate modules
attn <- attn_concat_module(hidden_size)
enc <- s2s_encoder(input_size, embedding_size, hidden_size, padding=vocab$special_tokens["<PAD>"])
dec <- s2s_attn_decoder(attn, embedding_size, hidden_size, output_size)
s2s <- s2s_attn_module(enc, dec, vocab$special_tokens, optim_fun=optim_adam, learning_rate=0.001)

# input and target sentences
sentences <- list(
  input = list(c(6,5,4,3,2), c(6,5,4,2)),
  target= list(c(4,6,5,2),   c(6,5,4,3,2))
)
batch_size <- length(sentences$input)

# transform R vector into tensor
input <- lapply(sentences$input, torch_tensor, dtype=torch_long())
target <- lapply(sentences$target, torch_tensor, dtype=torch_long())

# translate a batch of sentences
out <- s2s$forward(input, max_output_len=6)
str(out)

# compute loss for a batch of sentences
out <- s2s$loss(input, target, loss_reduction="none") 
str(out)

There is still an issue when computing the gradient in the train function (there is an error about a variable
being modified in place, which could be related to the R port of torch), but I think that it is important to address
these issues in one way or another in the tutorial. Some of the comments might also not reflect the current code.

I would be interested to see a more efficient implementation of the decoding process that takes
advantage of padding or packing, but I don't see how it can be done: the input of the RNN
depends on the hidden state, so the decoding requires an explicit loop, unlike in the encoder
where the input can be defined beforehand.

@svekars
Contributor

svekars commented May 24, 2024

@spro do the changes look good to you? Please leave a comment.

@svekars added the core label (Tutorials of any level of difficulty related to the core pytorch functionality) on May 24, 2024
@brcolli
Author

brcolli commented May 27, 2024

@gavril0 To your last point about optimizing the decoder step of an RNN, I think that would make an interesting tutorial! It could be discussed in terms of beam search; I'm not aware of other optimization techniques.

I would argue this could be its own tutorial, perhaps under advanced. If others disagree and would like this included in this tutorial, I'd be happy to add it.
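For reference, here is a rough, generic sketch of the beam-search idea mentioned above. It is not tied to the tutorial's decoder: beam_search and toy_step are made-up names, and the toy step function just returns a random distribution to show the call pattern.

import torch

def beam_search(step_fn, sos_token, eos_token, beam_width=3, max_len=10):
    # each hypothesis is (token list, cumulative log-probability)
    beams = [([sos_token], 0.0)]
    finished = []
    for _ in range(max_len):
        candidates = []
        for tokens, score in beams:
            log_probs = step_fn(tokens)                  # (vocab_size,) log-probabilities
            top = torch.topk(log_probs, beam_width)
            for lp, tok in zip(top.values.tolist(), top.indices.tolist()):
                candidates.append((tokens + [tok], score + lp))
        # keep the best beam_width partial hypotheses
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for tokens, score in candidates[:beam_width]:
            (finished if tokens[-1] == eos_token else beams).append((tokens, score))
        if not beams:            # every surviving hypothesis has emitted EOS
            break
    return max(finished + beams, key=lambda c: c[1])

def toy_step(tokens):
    # stand-in for a decoder step: a random log-probability distribution over 6 tokens
    return torch.log_softmax(torch.randn(6), dim=0)

print(beam_search(toy_step, sos_token=0, eos_token=2, beam_width=2, max_len=5))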

brcolli and others added 3 commits May 27, 2024 10:39
Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
@gavril0

gavril0 commented May 27, 2024

@brcolli Thanks for the work and the pointer toward beam search. I am still refining my implementation with indices and waiting for an issue/bug in the R torch port to be fixed before making it public on GitHub.

I looked again at the Python code, and my understanding is that the train_epoch() function takes the decoder output and the target sentences and computes the loss by comparing the two padded tensors, without taking into account the fact that the loss is not defined when the output is longer than the target sentence. As a result, the loss signal might become noisy toward the end of the sentences, which might slow down learning (independently of the issue of optimizing the implementation). A possible solution might be to use ignore_index when computing the loss to mask outputs longer than the target sentence, but I don't think that the Python code does this.
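To illustrate that masking idea (with invented shapes, not the tutorial's train_epoch()): if the target tensor is padded with 0 and the criterion is built with ignore_index=0, the padded positions contribute nothing to the loss.

import torch
import torch.nn as nn

vocab_size, max_len = 6, 5
# fake decoder output for 2 sentences: log-probabilities over the vocabulary
log_probs = torch.log_softmax(torch.randn(2, max_len, vocab_size), dim=-1)
target = torch.tensor([[4, 5, 2, 0, 0],    # shorter target, padded with 0 after <EOS>
                       [3, 4, 5, 1, 2]])

criterion = nn.NLLLoss(ignore_index=0)
loss = criterion(log_probs.view(-1, vocab_size), target.view(-1))  # padded positions are skipped
print(loss)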

For info, I had an error about a variable being modified in place when computing the gradient because I assigned the value 0 to some elements of the attention weights matrix in the masked_softmax() function (see above). The current implementation uses attn_weights$masked_fill(mask, 0) instead:

# Normalize attention scores: 
# + scores (batch_size, 1, max_input_len)
# + lengths: vector with input sentence lengths
masked_softmax <- function(scores, lengths) {
  d <- scores$size()             # (batch_size, 1, max_len)
  # identify elements after the ends of the shorter sentences
  mask <- array(matrix(1:d[3], d[1], d[3], byrow=TRUE) >  matrix(lengths, d[1], d[3]), d)
  mask <-torch_tensor(mask, dtype=torch_bool())
  # use the mask to zero out scores after the end of the shorter sequences
  attn_weights <- scores$exp() 
  attn_weights <- attn_weights$masked_fill(mask, 0) 
  # normalization 
  attn_weights$div(attn_weights$sum(dim=3)$unsqueeze(3))
}

@brcolli
Author

brcolli commented Jun 7, 2024

Sorry for taking so long to respond.
@gavril0, are you thinking of adding ignore_index to the NLLLoss() criterion?
criterion = nn.NLLLoss(ignore_index=0)
using 0 to match the padding_idx? That way I believe it should ignore the padding values.

@gavril0

gavril0 commented Jun 10, 2024

@brcolli This is what I was thinking. One also needs to make sure that, when the outputs are longer than the target sentences, the extra target positions are padded with zeros during training, because I don't think that the computed loss is correct otherwise. I wanted to compare the gradients computed in this manner to those computed using indices, but I have yet to do it.
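As a quick sanity check of that point (again with invented shapes rather than the tutorial's code): with ignore_index set to the padding value, the mean loss over a zero-padded target equals the mean loss over the unpadded prefix, so the extra decoder steps do not distort it.

import torch
import torch.nn as nn

vocab_size = 6
log_probs = torch.log_softmax(torch.randn(5, vocab_size), dim=-1)  # 5 decoder steps
target_prefix = torch.tensor([4, 5, 2])                            # real target (3 tokens)
target_padded = torch.tensor([4, 5, 2, 0, 0])                      # padded to 5 steps with 0

criterion = nn.NLLLoss(ignore_index=0)
print(criterion(log_probs[:3], target_prefix))   # loss over the real tokens only
print(criterion(log_probs, target_padded))       # same value: padded steps are ignored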

@spro

spro commented Jun 10, 2024

@spro do the changes look good to you? Please leave a comment.

No need for my approval but yes!

Labels: cla signed, core

5 participants