added Text Generation example #1473
Conversation
Looking good, very close - just needs a bit more polish
pull_request_template.md (Outdated)
```diff
@@ -9,9 +9,9 @@ Fixes #(issue)
 Please delete options that are not relevant.

 - [ ] Bug fix (non-breaking change which fixes an issue)
-- [ ] New feature (non-breaking change which adds functionality)
+- [x] New feature (non-breaking change which adds functionality)
```
Please revert the changes here; the template is automatically used when you open a new PR, so there's no need to update it in this PR.
Will do!
```diff
@@ -126,7 +129,7 @@ def preprocess(self, requests):
         max_length = self.setup_config["max_length"]
         logger.info("Received text: '%s'", input_text)
         # preprocessing text for sequence_classification and token_classification.
-        if self.setup_config["mode"] == "sequence_classification" or self.setup_config["mode"] == "token_classification":
+        if self.setup_config["mode"] == "sequence_classification" or self.setup_config["mode"] == "token_classification" or self.setup_config["mode"] == "text_generation":
```
nit: let's break this into multiple lines
Maybe something like `if self.setup_config["mode"] in {"sequence_classification", "token_classification", "text_generation"}:` would be more elegant?
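Putting the two suggestions together (set membership, split across lines), the condition could look something like this sketch, not the committed code:

```python
# Sketch: one membership test instead of chained equality checks.
if self.setup_config["mode"] in {
    "sequence_classification",
    "token_classification",
    "text_generation",
}:
    ...
```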
```python
# Sample a continuation for the prompt, then keep the decoded prompt
# plus only the newly generated portion.
outputs = self.model.generate(input, max_length=150, do_sample=True, top_p=0.95, top_k=60)
generated = self.tokenizer.decode(input) + self.tokenizer.decode(outputs[0])[prompt_length + 1 :]

inferences.append(generated)
```
Just trying to understand this snippet. Are batches completely independent, or are you chunking text from the same sentence? Should the batch size be equal to the max sequence length for this example to work?
Are the arguments right for a demo? Is the model too slow? Is the generated output readable? Is the output too long? Is the output deterministic for your input? (Important for the tests and README, so people know they ran things correctly.)
`outputs = self.model.generate(input, max_length=150, do_sample=True, top_p=0.95, top_k=60)`
Also could you try printing the expected output?
I'd suggest adding a single test similar to this one: https://github.com/pytorch/serve/blob/master/test/pytest/test_handler.py#L230, so it's easier for people to maintain this example without breaking it.
Let me know if you need any more help with this. You can put the mar file on any URL that makes it wget-able. I'm working on a cleaner, longer-term story for this in #1470.
I tried to keep the same approach/style as the other transformer examples, so for each text example (prompt) in the batch there is one instance of generated text (which is concatenated with the prompt, as in the official example). The max sequence length applies per text example; the output will have at most that length and is decoded back to text. The parameters chosen in this function are the ones used in the official Hugging Face example (with a smaller max length), because I figured a lot of people would like to see the same example served. Of course this could be improved in the future by adding more flexibility in parameter selection.
The model size is comparable to the other transformers (~500 MB), but I'm not sure about the speed.
The output is not deterministic, because we can't know in advance what the generated text will be. In the local tests I've done, I got a different output each time. I think the user can be sure the example executed correctly when some text is generated, regardless of its content.
Considering this, I'm not sure how I should write assertion tests for this example.
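One common pattern for testing non-deterministic generation is to assert on the shape of the response rather than its exact content. A minimal sketch in the spirit of the existing test_handler.py tests, assuming a locally running TorchServe with the example registered under a hypothetical name Textgeneration and a hypothetical sample input file:

```python
import requests

def test_text_generation_inference():
    # Hypothetical model name and input file; match them to the names
    # actually used when registering this example.
    url = "http://localhost:8080/predictions/Textgeneration"
    with open("sample_text.txt", "rb") as f:
        response = requests.post(url, data=f)
    assert response.status_code == 200
    # Output is sampled (do_sample=True), so only assert that some text came back.
    assert len(response.text) > 0
```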
Alright, if that's the case, please summarize the first paragraph as a comment and update the README, and we can merge this in.
```diff
@@ -4,7 +4,7 @@
 import json
 import torch
 from transformers import (AutoModelForSequenceClassification, AutoTokenizer, AutoModelForQuestionAnswering,
-                          AutoModelForTokenClassification, AutoConfig)
+                          AutoModelForTokenClassification, AutoModelForCausalLM, AutoConfig)
```
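For context, the new import is presumably wired into the handler's existing mode switch when the model is loaded; a sketch of that pattern (a fragment, not the exact committed code):

```python
# Sketch of the presumed branch inside the handler's initialize(),
# mirroring how the other modes load their Auto* classes.
if self.setup_config["mode"] == "text_generation":
    self.model = AutoModelForCausalLM.from_pretrained(model_dir)
```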
Make sure to update examples/Huggingface_Transformers/README.md as well.
Excellent, thank you for your contribution!
Description
Adds a text generation example in the style of the existing Huggingface Transformers examples.
Fixes #1386
Type of change
- [x] New feature (non-breaking change which adds functionality)