
LazyEmbedding, an embedding layer with a dynamically sized vocabulary #55981

Open
PetrochukM opened this issue Apr 14, 2021 · 13 comments

Labels
feature A request for a proper, new feature. module: embedding module: nn Related to torch.nn triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@PetrochukM commented Apr 14, 2021

🚀 Feature

It'd be AMAZING to have a lazy embedding layer that grows to accommodate new tokens. So, for example, the interface would look like this:

import torch

hidden_size = 16
embed_token = torch.nn.LazyEmbedding(hidden_size)

embedding = embed_token(["a", "b", "c"])
assert embedding.shape == (3, hidden_size)
assert embed_token.vocab_size == 3

embedding = embed_token(["d", "e"])
assert embedding.shape == (2, hidden_size)
assert embed_token.vocab_size == 5

embed_token = embed_token.eval()
embedding = embed_token(["f"])  # ERROR: Token 'f' not found.

And, for example, here is a basic, inefficient implementation:

import typing

import torch

class LazyEmbedding(torch.nn.Module):
    def __init__(self, max_num_embeddings: int, *args, **kwargs):
        super().__init__()
        # Pre-allocate a fixed-size table; the vocabulary fills it up lazily.
        self.embedding = torch.nn.Embedding(num_embeddings=max_num_embeddings, *args, **kwargs)
        self.vocab: typing.Dict[str, int] = {}

    def forward(self, tokens: typing.List) -> torch.Tensor:
        indices = []
        for token in tokens:
            # New tokens are only added to the vocabulary during training.
            if token not in self.vocab and self.training:
                self.vocab[token] = len(self.vocab)
            indices.append(self.vocab[token])
        return self.embedding(torch.tensor(indices))


embedding = LazyEmbedding(10, embedding_dim=16)

tensor = embedding(["a", "b", "c"])
print(tensor.shape) # torch.Size([3, 16])
print(embedding.vocab)  # {'a': 0, 'b': 1, 'c': 2}

tensor = embedding(["d", "e", "f"])
print(tensor.shape) # torch.Size([3, 16])
print(embedding.vocab)  # {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, 'f': 5}

try:
    embedding.eval()(["g"])
except KeyError:
    print("KeyError")

There is also a bit of work that can be done with modularity. For example, PyTorch could be responsible for providing a LazyEmbedding that supports only torch.Tensor. torchtext could provide a wrapper that adds the vocabulary.
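As a rough illustration of that split, the string-to-index wrapper could be as small as the following sketch (the Vocab class and its interface here are hypothetical, just standing in for whatever torchtext would provide):

import typing

import torch

class Vocab:
    """Hypothetical string-to-index mapping that grows during training."""

    def __init__(self):
        self.stoi: typing.Dict[str, int] = {}

    def __call__(self, tokens: typing.List[str], train: bool = True) -> torch.Tensor:
        if train:
            for token in tokens:
                self.stoi.setdefault(token, len(self.stoi))
        return torch.tensor([self.stoi[token] for token in tokens])

# A core, index-based LazyEmbedding (proposed above, not an existing module) would
# then be composed with the wrapper:
#     embedding(vocab(["a", "b", "c"]))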

Motivation

This feature has been on my mind a lot because it'd dramatically simplify my NLP training pipelines.

Without a LazyEmbedding, I typically need to use torchnlp.encoders or torchtext.vocab. I need to initialize these objects by looping through my entire dataset in order to determine all the tokens I might need. Afterward, I need to use the vocabulary with my DataLoader in order to encode training examples. Lastly, I need to store this object in the related checkpoints so that other people can use the same encoding.

With a module like this, I wouldn't need to use torchnlp.encoders or torchtext.vocab.

Furthermore, a LazyEmbedding could also be extended with a tokenizer, padding tokens, eos tokens, unknown tokens, pre-trained word vectors, etc.

cc @albanD @mruberry @jbschlosser

@PetrochukM changed the title from LazyEmbedding to LazyEmbedding, an embedding layer with a dynamically sized vocabulary on Apr 14, 2021
@jbschlosser added the feature, module: embedding, module: nn, and triaged labels on Apr 14, 2021
@jbschlosser (Contributor)

Hey @PetrochukM, thanks for the cool suggestion!

There is also a bit of work that can be done with modularity. For example, PyTorch could be responsible for providing a LazyEmbedding that supports only torch.Tensor. torchtext could provide a wrapper that adds the vocabulary.

My initial thoughts are that LazyEmbedding wouldn't fit well in PyTorch core. The current nn.Embedding module is essentially a fixed-size table of embeddings addressable by contiguous indices. Consider a potential LazyEmbedding in the same spirit; some questions come to mind:

  • Would new embeddings be instantiated whenever a new index is seen? Example: if we start with 10 embeddings and try to access index 50, would new embeddings be instantiated for indices 10-49? I could see some use cases wanting sparse behavior (allowing for gaps), while others would require strict contiguity.
  • What should the growth strategy be for the table? The most obvious would be to expand as needed to accommodate new indices, with no extra space. I could see an exponential growth strategy being useful for performance when a lot of embeddings are being added to the table.
  • How will the new embeddings be initialized? We'd need to support arbitrary initialization, as that's possible now for nn.Embedding by e.g. calling the torch.nn.init.* functions on the Embedding's weight parameter.
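For reference, the way to apply an arbitrary initialization to nn.Embedding today, as the last point mentions, looks like this (normal_ is just one example of a torch.nn.init.* function):

import torch

embedding = torch.nn.Embedding(num_embeddings=100, embedding_dim=16)
# Re-initialize the weight table in place with any torch.nn.init.* scheme.
torch.nn.init.normal_(embedding.weight, mean=0.0, std=0.02)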

Given that there are multiple valid answers for each of the above, I think it's easier to maintain only nn.Embedding in PyTorch core as a performant, fixed-size embedding table. Imo functionality like that described for LazyEmbedding should be built on top of nn.Embedding in domain-specific libraries / user code, where the appropriate implementation details can be chosen depending on the use case.

That said, I'm open to being convinced, especially if there's a lot of interest for something like this :)

@PetrochukM (Author)

Thanks for considering my suggestion.

Here are my thoughts:

  • Great question. We could also add indirection, so that e.g. index 50 maps to physical slot 11 and there are no gaps (a tiny sketch of this follows the list). There are interesting strategies for optimizing performance in a sparse context.
  • Yup! I think both are valid and would work. This answer depends on how expensive it is to add a new embedding: the more expensive it is, the more I'd lean toward exponentially allocating memory.
  • I'd probably lean toward defining a per-embedding initialization strategy.
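As a tiny sketch of the indirection mentioned in the first point, a plain dict can map whatever index arrives to the next free physical slot, so the underlying table stays contiguous:

remap = {}

def physical_index(given_index: int) -> int:
    # The first unseen index gets slot 0, the next unseen index gets slot 1, etc.
    return remap.setdefault(given_index, len(remap))

assert physical_index(50) == 0  # index 50 lands in the first free slot
assert physical_index(7) == 1
assert physical_index(50) == 0  # repeat lookups are stable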

Re:

Given that there are multiple valid answers for each of the above, I think it's easier to maintain only nn.Embedding in PyTorch core as a performant, fixed-size embedding table. Imo functionality like that described for LazyEmbedding should be built on top of nn.Embedding in domain-specific libraries / user code, where the appropriate implementation details can be chosen depending on the use case.

With regards to this, I'm happy to include an implementation in an open-source library. The issue is that it's difficult to implement a LazyEmbedding efficiently without touching core components. For example, I'm not sure how to update DistributedDataParallel or optim.Adam after resizing the Embedding table weights. Is there an easy way to do so?

Let me know if you think this is possible to implement efficiently without PyTorch support; my perception that it isn't is why I posted this feature request.

Either way, I'll definitely open-source an inefficient implementation because it'll make NLP so much easier. I have an NLP student, and torch.nn.Embedding was overwhelming to her. It'll also make my life a lot easier. The number of vocabularies I need to maintain for text-to-speech is growing quickly 😅. I need a better abstraction with less boilerplate.

@PetrochukM (Author) commented Apr 15, 2021

(I pinged a couple of NLP chats to see if there is interest in something like this!)

@jspisak (Contributor) commented Apr 15, 2021

Should we discuss whether this is a better fit for torchtext?

@PetrochukM (Author) commented Apr 15, 2021

@jspisak I'd be happy to discuss. I don't think an efficient implementation is doable without making modifications to PyTorch itself, so I don't think torchtext could implement it easily.

Even so, we could implement a version of this in torchtext. I think that'd still be valuable.

@jessicapetrochuk

Love the idea! I have been using the nn.Embedding module but would really benefit from an implementation of LazyEmbedding.

@wangkuiyi (Contributor) commented Apr 26, 2021

+1 to this idea.

@jbschlosser reminded me of this issue when I asked for the same kind of embedding in https://gist.github.com/wangkuiyi/dd2e3794d11010f0cd562ed009664f90. Slightly different from @PetrochukM's sample, the gist doesn't allocate the underlying dense embedding table beforehand.

This fully lazy and dynamic embedding is useful not only for NLP but also for recommendation and ad systems.

TensorFlow added a distributed version of this dynamic embedding feature for recommendation systems in 2020: https://github.com/tensorflow/community/blob/a4542d13baaa64e81ec4689719fbe30abe89aee0/rfcs/20200424-sparse-domain-isolation.md

Baidu's Paddle has had this for advertising systems since 2018: https://github.com/PaddlePaddle/Paddle/projects/56.

@cpuhrsch (Contributor) commented Apr 30, 2021

I'm wondering what people think about the following design proposal:

  import torch

  class LazyEmbedding(torch.nn.Module):
      def __init__(self, embedding_dim):
          super().__init__()
          self.vocab = {}  # maps a given token index to its physical row in self.data
          self.data = torch.empty((1, embedding_dim))
          self.embedding_dim = embedding_dim

      def forward(self, tokens: torch.Tensor) -> torch.Tensor:
          # Convert to Python ints so the dictionary lookups hash by value,
          # not by tensor identity.
          tokens = tokens.tolist()
          for i in range(len(tokens)):
              if tokens[i] not in self.vocab and self.training:
                  # Double the table whenever it runs out of space.
                  while len(self.data) <= len(self.vocab):
                      self.data = torch.cat((self.data, self.data))
                  self.vocab[tokens[i]] = len(self.vocab)
                  new_entry = torch.empty((self.embedding_dim,))
                  torch.nn.init.normal_(new_entry)
                  self.data[self.vocab[tokens[i]]].copy_(new_entry)
          indices = [self.vocab[token] for token in tokens]
          indices = torch.tensor(indices)
          return torch.nn.functional.embedding(indices, self.data)

This uses the same kernels as nn.Embedding, but doubles the underlying embedding table whenever it runs out of space for a given index, and stores a table that maps a given token index to the actual underlying data. It also differs in that it accepts indices and returns Tensors, so it is actually a full-on replacement for nn.Embedding. To make this performant, the map from given index to physical index will need to be written efficiently, and we might want to reorder the embedding table based on some observed frequency of tokens on a call to eval() to fully match the inference performance of nn.Embedding. Other costs, such as doubling and reallocating the underlying data, initializing an entry that isn't available yet (repeated calls to torch.nn.init.normal_), and most importantly cache locality, are potentially minimal assuming the input follows a distribution such as Zipf's law and the program runs long enough. In particular, we might want to run some very precise benchmarks around cache locality.
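A quick usage sketch of the proposal above (the printed shapes assume the snippet as written, starting from its one-row table):

emb = LazyEmbedding(embedding_dim=8)
out = emb(torch.tensor([0, 7, 3]))  # three previously unseen token indices
print(out.shape)       # torch.Size([3, 8])
print(emb.vocab)       # {0: 0, 7: 1, 3: 2}, i.e. given index -> physical row
print(emb.data.shape)  # torch.Size([4, 8]) after the 1-row table doubled twice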

@parmeet commented May 4, 2021

I'm wondering what people think about the following design proposal [...]

I think overall I really like the idea. This is similar in spirit to std::vector, which grows dynamically as we push more objects into the structure.

One question I have, though, is about the forward API: why not support a list of strings directly? It seems LazyEmbedding would need to depend on some external structure to provide token indices, and that external structure would also need to support dynamic (lazy) semantics. That may be an overhead in terms of workflow and checkpointing compared to the alternative, where LazyEmbedding encapsulates both the vocabulary and the embeddings.

@cpuhrsch (Contributor) commented May 5, 2021

One question I have, though, is about the forward API: why not support a list of strings directly? [...]

The idea is to use a Vocab to map from string to int and then LazyEmbedding to map from int to a vector (same as nn.Embedding now).
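As a small sketch of that composition, using a plain dict where a torchtext Vocab would normally sit and the LazyEmbedding proposed above:

import torch

stoi = {}  # the Vocab's job: string -> int, growing as new tokens arrive

def encode(tokens):
    return torch.tensor([stoi.setdefault(token, len(stoi)) for token in tokens])

emb = LazyEmbedding(embedding_dim=8)  # the int -> vector half, as proposed above
vectors = emb(encode(["the", "cat", "the"]))
print(vectors.shape)  # torch.Size([3, 8])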

@jbschlosser (Contributor)

For example, I'm not sure how to update DistributedDataParallel or optim.Adam after resizing the Embedding table weights. Is there an easy way to do so?

Good point. AFAIK optimizers can't handle the parameters they're optimizing being resized; that would either result in new parameters the optimizer doesn't know about or mess with the internal optimizer state (more info here). There's optimizer.add_param_group() for dynamically adding a parameter group to an optimizer, but I don't think that would work if the full embedding table is stored in a single Parameter. Because of this, I don't think the LazyEmbedding proposed by @cpuhrsch is trainable.

An alternative I haven't fully thought through might be to maintain the table across multiple Parameters, calling optimizer.add_param_group() with the new entries whenever the table is resized. This may require a "table resized" hook, and the user would be responsible for calling optimizer.add_param_group() within that hook. It's a bit messy.

Not sure about DDP either..
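For what it's worth, here is a rough, untested sketch of the multi-Parameter idea above; ChunkedEmbedding, the chunk size, and the resize hook are all hypothetical names, not an existing API:

import torch
import torch.nn.functional as F

class ChunkedEmbedding(torch.nn.Module):
    # Hypothetical: the table lives in fixed-size Parameter chunks, so growing it
    # never resizes an existing Parameter and existing optimizer state stays valid.
    def __init__(self, embedding_dim: int, chunk_size: int = 64):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.chunk_size = chunk_size
        self.chunks = torch.nn.ParameterList(
            [torch.nn.Parameter(torch.randn(chunk_size, embedding_dim))])
        self._resize_hooks = []  # each hook is called with the newly added Parameter

    @property
    def num_embeddings(self) -> int:
        return len(self.chunks) * self.chunk_size

    def register_resize_hook(self, hook):
        self._resize_hooks.append(hook)

    def _grow(self):
        new_chunk = torch.nn.Parameter(torch.randn(self.chunk_size, self.embedding_dim))
        self.chunks.append(new_chunk)
        for hook in self._resize_hooks:
            hook(new_chunk)

    def forward(self, indices: torch.Tensor) -> torch.Tensor:
        while int(indices.max()) >= self.num_embeddings:
            self._grow()
        # Concatenating the chunks gives a normal dense table for F.embedding.
        return F.embedding(indices, torch.cat(list(self.chunks)))

# The "table resized" hook keeps the optimizer aware of newly created parameters.
emb = ChunkedEmbedding(embedding_dim=16, chunk_size=4)
opt = torch.optim.Adam(emb.parameters())
emb.register_resize_hook(lambda p: opt.add_param_group({"params": [p]}))
out = emb(torch.tensor([0, 5, 9]))  # grows the table to 12 rows; Adam sees the new chunks

How to make DDP aware of the newly created parameters is still an open question, as noted above.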

@PetrochukM (Author) commented May 7, 2021

Not sure about DDP either...

There are a couple of strategies I can think of:

  • (Slow) We'd need to synchronize the parameter table every forward pass, making sure every process knows about any new tokens that are created.
  • (Faster; Forward Pass Sync) We could have a hook in DDP; SyncBatchNorm has specialized handling via _passing_sync_batchnorm_handle. During the forward pass, we'd need to know if there are any new tokens. If there are, we'd update the parameter table for every process.
  • (Faster; Backward Pass Sync) There is an interesting idea using an "Unknown Token". In order to avoid synchronizing every forward pass, we'd fall back to an "Unknown Token" whenever there is a new token during the forward pass. Then, during the backward pass, along with all the other parameters, we'd update the embedding table. (I implemented this approach in my private codebase! A rough sketch follows at the end of this comment.)

I wonder how Baidu or TF handled this problem! And I'm not an expert in PyTorch's distributed toolchain, so I hope someone has better ideas!
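As a rough, single-process sketch of the third ("Unknown Token") strategy above; the class name and the exact moment the vocabulary is updated are just one possible arrangement, and the cross-process gathering of new tokens is omitted:

import typing

import torch

class UnknownTokenEmbedding(torch.nn.Module):
    # A brand-new token is embedded with the reserved "unknown" row on the step it
    # first appears and only joins the vocabulary afterwards, so the forward pass
    # never has to stop and synchronize the table.
    UNK_INDEX = 0

    def __init__(self, max_num_embeddings: int, embedding_dim: int):
        super().__init__()
        self.embedding = torch.nn.Embedding(max_num_embeddings, embedding_dim)
        self.vocab: typing.Dict[str, int] = {}  # row 0 is reserved for unknown tokens

    def forward(self, tokens: typing.List[str]) -> torch.Tensor:
        indices, new_tokens = [], []
        for token in tokens:
            if token in self.vocab:
                indices.append(self.vocab[token])
            else:
                indices.append(self.UNK_INDEX)
                new_tokens.append(token)
        if self.training:
            # In a DDP setting, new_tokens would first be gathered across ranks so
            # every process assigns the same rows; that step is omitted here.
            for token in new_tokens:
                self.vocab.setdefault(token, len(self.vocab) + 1)
        return self.embedding(torch.tensor(indices))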

@PetrochukM (Author)

@cpuhrsch I'm a big fan of A LOT of your ideas! Thanks for providing concrete and actionable ideas around design and performance!
