
Update the encode function in BART hub_interface: to add an extra option for not always adding OOV tokens into vocabulary #3905

Open · wants to merge 1 commit into main
Conversation

@datquocnguyen commented Sep 22, 2021

Regarding line 62 of the encode function in the BART hub_interface: in many cases (e.g. when using a monolingual vocabulary reduced from an existing multilingual one), an OOV token should be mapped to the <unk> index rather than always being added to the vocabulary as a new token type.

Current code: https://github.com/pytorch/fairseq/blob/main/fairseq/models/bart/hub_interface.py

    def encode(
        self, sentence: str, *addl_sentences, no_separator=True
    ) -> torch.LongTensor:
        tokens = self.bpe.encode(sentence)
        if len(tokens.split(" ")) > min(self.max_positions) - 2:
            tokens = " ".join(tokens.split(" ")[: min(self.max_positions) - 2])
        bpe_sentence = "<s> " + tokens + " </s>"
        for s in addl_sentences:
            bpe_sentence += " </s>" if not no_separator else ""
            bpe_sentence += " " + self.bpe.encode(s) + " </s>"
        tokens = self.task.source_dictionary.encode_line(
            bpe_sentence, append_eos=False
        )  # add_if_not_exist defaults to True, so OOV tokens are always added as new types
        return tokens.long()

Suggested change (see https://github.com/datquocnguyen/fairseq/blob/main/fairseq/models/bart/hub_interface.py):

    def encode(
        self, 
        sentence: str, 
        *addl_sentences, 
        no_separator=True,
        add_if_not_exist=True,  # extra option: if False, map OOV tokens to <unk>
    ) -> torch.LongTensor:
        tokens = self.bpe.encode(sentence)
        if len(tokens.split(" ")) > min(self.max_positions) - 2:
            tokens = " ".join(tokens.split(" ")[: min(self.max_positions) - 2])
        bpe_sentence = "<s> " + tokens + " </s>"
        for s in addl_sentences:
            bpe_sentence += " </s>" if not no_separator else ""
            bpe_sentence += " " + self.bpe.encode(s) + " </s>"
        tokens = self.task.source_dictionary.encode_line(
            bpe_sentence, append_eos=False, add_if_not_exist=add_if_not_exist
        )
        return tokens.long()
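The difference comes down to the add_if_not_exist flag of Dictionary.encode_line. A minimal, self-contained sketch of the two behaviours (illustrative only, not part of this change; it exercises fairseq.data.Dictionary directly instead of a BART checkpoint):

    from fairseq.data import Dictionary

    # Toy dictionary; it starts with the special symbols <s>, <pad>, </s>, <unk>.
    d = Dictionary()
    d.add_symbol("hello")

    # Default behaviour (add_if_not_exist=True): the unseen token is appended
    # to the dictionary, so the vocabulary grows.
    size_before = len(d)
    d.encode_line("hello 😀", append_eos=False)
    print(len(d) > size_before)  # True: "😀" was added as a new type

    # Proposed option (add_if_not_exist=False): the unseen token is mapped to
    # the <unk> index and the vocabulary size stays fixed.
    d2 = Dictionary()
    d2.add_symbol("hello")
    ids = d2.encode_line("hello 😀", append_eos=False, add_if_not_exist=False)
    print(ids[-1].item() == d2.unk())  # True: OOV token mapped to <unk>
    print(len(d2))                     # unchanged: the OOV token was not added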

With this suggested change, encoding in the case mentioned above becomes: fairseq_model.encode(sentence, add_if_not_exist=False)

For mBART and the like, encoding still adds extra token types to the vocabulary (e.g. when training for new languages), exactly as before: fairseq_model.encode(sentence)
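A minimal usage sketch (assuming the patched encode() above; the checkpoint directory and file name below are placeholders for your own model):

    from fairseq.models.bart import BARTModel

    # Placeholder paths: point these at a BART checkpoint with, e.g., a reduced
    # monolingual vocabulary.
    bart = BARTModel.from_pretrained(
        "checkpoints/bart_reduced_vocab",
        checkpoint_file="model.pt",
    )

    # Map unseen subword types to <unk> instead of growing the source dictionary.
    tokens = bart.encode("A sentence with an emoji 😀", add_if_not_exist=False)

    # Default behaviour is unchanged (OOV types are still added), e.g. for mBART
    # when training for new languages.
    tokens_default = bart.encode("A sentence with an emoji 😀")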

This provides an extra option to convert OOV tokens to <unk> rather than always adding them to the dictionary.
@datquocnguyen changed the title from "Update BART hub_interface to add an extra option for handling OOV tokens (e.g. emojis)" to "Update the encode function in BART hub_interface: to add an extra option for not always adding OOV tokens into vocabulary" on Sep 22, 2021

stale bot commented Mar 2, 2022

This pull request has been automatically marked as stale. If this pull request is still relevant, please leave any comment (for example, "bump"), and we'll keep it open. We are sorry that we haven't been able to prioritize reviewing it yet. Your contribution is very much appreciated.

stale bot added the stale label on Mar 2, 2022
@datquocnguyen (Author)

bump

stale bot removed the stale label on Mar 3, 2022