
Beginning of generation utils and necessary refactors of T5 Model #2011

Merged
joecummings merged 6 commits into pytorch:main from greedy-generation on Dec 29, 2022

Conversation

@joecummings (Contributor) commented on Dec 19, 2022

Context

We aim to add generation utils that support a number of encoder/decoder and decoder-based models. To do so, we also have to rework our current encoder/decoder model, T5.

Changes

  1. Separated encoder and decoder logic into self-contained nn.Modules.
    1a. Moved dropout layers and norms into T5Encoder and T5Decoder.
    1b. Passed token_embeddings to the encoder when constructed through T5Model, so the encoder can now take in either tokenized text or embedded text.
    1c. Added get_encoder and get_decoder getter functions (not torchscriptable at the moment).
    1d. Updated type annotations to allow for padding_masks and encoder_outputs.
    1e. Changed the T5Encoder and T5Decoder return types to dictionaries.

  2. Added a GenerationUtils class and a greedy_search generation technique.
    2a. Added a deprecation warning to T5Wrapper until beam_search is added.

  3. Froze the configs to avoid mutating the model unnecessarily.

Testing

  • One notebook showing parity with HuggingFace's T5 model
  • One notebook showing that HuggingFace models can be used with the GenerationUtil
  • Integration tests for the new endpoints and additional functions

Notes

  • T5Wrapper is no longer Torchscriptable
  • Was not able to guarantee parity between our greedy search and HuggingFace's despite having nearly identical implementations
  • As part of a discussion with @forresti w.r.t. onboarding T5 V1.1 and T5-FLAN, we might need to refactor to include the Gated-GeLU activation. This will come in a follow-up PR, but documenting it here.

@@ -0,0 +1,83 @@
{
Contributor Author:

I know we verify completeness often w/ internal notebooks - I thought for those that show parity with HuggingFace or external libraries, we could put those notebooks in the actual repo. Seems like a better way to keep track rather than some Bento notebooks w/ scattered ownership.

Contributor:

Can you upload the notebook to Github gist and provide a link in the PR so it's easier to review the contents?

Contributor Author:

Sure, but for a quick fix you can click the expand dots at the top right of this file and select "View file", which will give you a notebook view.



@dataclass
@dataclass(frozen=True)
Contributor Author:

Freezing this as we probably don't want people to be able to overwrite configs and still try to use the model - much more likely to run into bugs that way.

Contributor:

What if someone wants to experiment with a smaller model or modified architecture? Are there distilled or smaller T5 models out there? We don't freeze other configs so I am not sure I agree with this

Contributor Author:

Freezing the config won't make it impossible to try a smaller model or modified architecture. It just means that once they instantiate the config and pass the config to the model, they won't be able to modify it.

Example:

config = T5Config(encoder_only=True)
t5_model = T5Model(config=config)
t5_model.config.encoder_only = False  # Currently allowed; with freezing config, this would throw an error

Contributor:

Just to follow up here, in the example you just showed, would it affect the model behavior if users did end up changing the config after instantiating the model? IIUC the config is only used during model instantiation anyways. That being said I don't see any issues with freezing the config.

Contributor Author:

It wouldn't affect the model behavior, but it would throw an error saying "Config cannot be modified", which I think is what we want. It would be considered undefined behavior if someone e.g. instantiated a model without a decoder and then went back and changed the config to say that it did have a decoder.
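For concreteness, a minimal sketch of the behavior being described, with a stripped-down stand-in for T5Config (the real config has many more fields). Frozen dataclasses raise FrozenInstanceError on attribute assignment:

from dataclasses import FrozenInstanceError, dataclass

@dataclass(frozen=True)
class T5Config:
    encoder_only: bool = False

config = T5Config(encoder_only=True)
try:
    config.encoder_only = False
except FrozenInstanceError:
    print("Config cannot be modified")  # frozen dataclasses raise on any attribute assignment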

def forward(
self,
encoder_tokens: Tensor,
encoder_tokens: Optional[Tensor] = None,
Contributor Author:

Don't have to include encoder_tokens if encoder_outputs are already provided.
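A minimal sketch of how the now-optional inputs might be validated (illustrative only; the exact check in the PR's T5Model.forward may differ, and check_encoder_inputs is a hypothetical helper name):

from typing import Dict, Optional

from torch import Tensor

def check_encoder_inputs(
    encoder_tokens: Optional[Tensor] = None,
    encoder_outputs: Optional[Dict[str, Tensor]] = None,
) -> None:
    # At least one of the two must be supplied: raw tokens to run through the
    # encoder, or pre-computed encoder outputs to reuse during decoding.
    if encoder_tokens is None and encoder_outputs is None:
        raise ValueError("Either encoder_tokens or encoder_outputs must be provided.")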


def forward(
self,
tgt: Tensor,
tgt: Optional[Tensor] = None,
Contributor Author:

Similar to the forward of the entire model: if the target is already embedded, there's no need to include the raw tokenized tgt.


if self.is_encoder_decoder:
encoder = self.model.get_encoder()
model_kwargs["encoder_outputs"] = encoder(inputs)
Contributor Author:

These should be the necessary args for the forward method of whatever model is being used in decoding.

from torch import nn


class GenerationUtil:
Contributor Author:

Does this whole class have to be torchscriptable, as well??

Contributor Author:

That would make it extremely difficult to incorporate other models.

Contributor:

Does this whole class have to be torchscriptable, as well??

If we expect that this util will be used in Predictor during inference time, then yes it does. Can you explain what makes it difficult to make this torchscriptable?

As a first step, we can always implement this without torchscriptability support for customers to experiment with. And if there's enough demand to make it torchscriptable then we can come back and add this support.



self, batch_size: int, device: Optional[torch.device] = None, **model_kwargs
):
if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
return model_kwargs.pop("decoder_input_ids")
Contributor:

why pass around model_kwargs dict instead of just having an optional decoder_input_ids param?

Contributor:

+1
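For illustration, a sketch of the alternative the reviewers are suggesting (the function and argument names are assumptions; the merged code keeps **model_kwargs):

from typing import Optional

import torch

def prepare_decoder_ids_for_generation(
    batch_size: int,
    pad_idx: int = 0,
    device: Optional[torch.device] = None,
    decoder_input_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # If the caller already supplied decoder input ids, use them directly.
    if decoder_input_ids is not None:
        return decoder_input_ids
    # Otherwise start every sequence in the batch from a single pad/start token.
    return torch.full((batch_size, 1), pad_idx, dtype=torch.long, device=device)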

device: Optional[torch.device] = None,
dtype=None,
) -> None:
super().__init__()

self.token_embeddings = token_embeddings
Contributor:

Can you add a description of this input argument to the docstring above?

Comment on lines +65 to +88
unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_idx).long())

Contributor:

Can we also add an explanation for this line? Having a hard time following the logic. Alternatively let's add a couple of lines to the docstring of this method explaining the approach.
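For readers following along, a self-contained illustration of what this line does (values are made up): unfinished_sequences holds a 1 for every sequence in the batch that has not yet produced EOS, and multiplying by (next_tokens != eos_idx) flips a sequence's flag to 0 on the step it emits EOS.

import torch

eos_idx, pad_idx = 1, 0
unfinished_sequences = torch.ones(3, dtype=torch.long)  # 1 = still generating, 0 = finished
next_tokens = torch.tensor([5, 1, 7])                   # the middle sequence just emitted EOS

unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_idx).long())
print(unfinished_sequences)  # tensor([1, 0, 1])

# A typical companion step (as in HuggingFace's implementation) then overwrites the
# sampled token with pad for sequences that are already finished:
next_tokens = next_tokens * unfinished_sequences + pad_idx * (1 - unfinished_sequences)
print(next_tokens)  # tensor([5, 0, 7])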



Comment on lines +835 to +836
self.norm = T5LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
Contributor:

Why did we decide to move this from the model to the encoder/decoder?

Contributor Author:

Keeps the entire encoder forward method self-contained.
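A toy illustration of what "self-contained" means here, with nn.LayerNorm standing in for T5LayerNorm and plain Linear layers standing in for the real encoder blocks (illustrative only; the actual encoder also handles attention, relative position bias, etc.):

import torch
from torch import nn

class ToyEncoder(nn.Module):
    # The final norm and dropout now live inside the encoder, so forward returns
    # fully post-processed hidden states instead of relying on the parent T5Model
    # to apply them afterwards.
    def __init__(self, d_model: int = 8, num_layers: int = 2, dropout: float = 0.1) -> None:
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(num_layers)])
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return self.dropout(self.norm(x))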

@joecummings (Contributor Author):
@Nayef211 Can I get some 👀 on this again when you have a chance?

@joecummings (Contributor Author):
@atalman @osalpekar Is this failing integration test related to Nova migration? The process seems to be killed with no helpful error and the integration tests pass on my local machine.

@atalman (Contributor) commented on Dec 28, 2022:

@atalman @osalpekar Is this failing integration test related to Nova migration? The process seems to be killed with no helpful error and the integration tests pass on my local machine.

@joecummings looks like integration tests are running out of memory, code 137: 3796110184

@joecummings (Contributor Author):
@atalman @osalpekar Is this failing integration test related to Nova migration? The process seems to be killed with no helpful error and the integration tests pass on my local machine.

@joecummings looks like integration tests are running out of memory, code 137: 3796110184

Silly follow-up question, but how would I go about allocating more memory for these integration tests?

@pbontrager left a comment:

Great to have sampling as part of the library!


return input_ids

def beam_search(self, input_ids: torch.Tensor, num_beams: int, max_len: Optional[int]) -> torch.Tensor:


If we put all the sampling methods in this class (beam_search, greedy), is that very extensible for the user? Or should these be separate classes that inherit from GenerationUtil or a general Sampler class?

Contributor Author:

Idk about inheriting from GenerationUtil, but as a standalone class, this makes sense.
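A sketch of the standalone-class direction being discussed (class names are hypothetical and not part of this PR):

import torch
from torch import nn

class Sampler:
    # Holds the wrapped model plus any helpers shared by all decoding strategies.
    def __init__(self, model: nn.Module, is_encoder_decoder: bool = True) -> None:
        self.model = model
        self.is_encoder_decoder = is_encoder_decoder

    def generate(self, input_ids: torch.Tensor, max_len: int) -> torch.Tensor:
        raise NotImplementedError

class GreedySearch(Sampler):
    def generate(self, input_ids: torch.Tensor, max_len: int) -> torch.Tensor:
        raise NotImplementedError  # would argmax over next-token logits at each step

class BeamSearch(Sampler):
    def __init__(self, model: nn.Module, num_beams: int, is_encoder_decoder: bool = True) -> None:
        super().__init__(model, is_encoder_decoder)
        self.num_beams = num_beams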

@@ -176,7 +176,8 @@ def build_model_from_huggingface_ckpt(

t5_model_state_dict = {
"token_embeddings.weight": hf_weights["shared.weight"],
"norm1.weight": hf_weights["encoder.final_layer_norm.weight"],
"encoder.token_embeddings.weight": hf_weights["shared.weight"],


Is there any way to bundle generation parameters with a model so the user doesn't have to know the correct sampling defaults for a given model?

@@ -47,6 +48,8 @@ def __init__(
strict (bool): Passed to :func: `torch.nn.Module.load_state_dict` method. (Default: `False`)
dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. (Default: `None`)
"""
warnings.warn("`T5Wrapper` is being deprecated. Please use new `GenerationUtils`.", category=DeprecationWarning)


GenerationUtils, if made into an nn.Module, could be treated as a generic wrapper for any LLM. This might be easier for the user but would break from the HuggingFace design. It would allow for generation parameters to be saved with the model.
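A rough sketch of that idea (hypothetical; this is not how the PR implements GenerationUtils): an nn.Module wrapper that carries its generation defaults, so they can be serialized together with the model object rather than living in the caller's code.

import torch
from torch import nn

class GeneratorWrapper(nn.Module):
    def __init__(self, model: nn.Module, max_len: int = 128, eos_idx: int = 1, pad_idx: int = 0) -> None:
        super().__init__()
        self.model = model
        self.max_len = max_len
        self.eos_idx = eos_idx
        self.pad_idx = pad_idx

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Would dispatch to greedy/beam search using the stored defaults; returned
        # unchanged here to keep the sketch minimal.
        return input_ids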

@joecummings joecummings merged commit a933cbe into pytorch:main Dec 29, 2022
@joecummings joecummings deleted the greedy-generation branch December 29, 2022 21:25