Beam search w/ Flashlight Text #2017

Draft · wants to merge 4 commits into main
53 changes: 53 additions & 0 deletions benchmark/benchmark_generation_utils.py
@@ -0,0 +1,53 @@
import time
from functools import partial

from torch.utils.data import DataLoader
from torcheval.metrics.functional import word_error_rate
from torchtext.datasets import Multi30k
from torchtext.models import T5_BASE_GENERATION
from torchtext.prototype.generate import GenerationUtils

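# Build a batched Multi30k English-to-German test pipeline, prefixing each source sentence with the T5 task prompt.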
multi_batch_size = 16
language_pair = ("en", "de")
multi_datapipe = Multi30k(split="test", language_pair=language_pair)
task = "translate English to German"


def apply_prefix(task, x):
return f"{task}: {x[0]}", x[1]


multi_datapipe = multi_datapipe.map(partial(apply_prefix, task))
multi_datapipe = multi_datapipe.batch(multi_batch_size)
multi_datapipe = multi_datapipe.rows2columnar(["english", "german"])
multi_dataloader = DataLoader(multi_datapipe, batch_size=None)


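# Generate translations for one batch with beam search and report word error rate (WER) against the references.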
def benchmark_beam_search_wer():
model = T5_BASE_GENERATION.get_model()
transform = T5_BASE_GENERATION.transform()

seq_generator = GenerationUtils(model)

batch = next(iter(multi_dataloader))
input_text = batch["english"]
target = batch["german"]
beam_size = 8

model_input = transform(input_text)
model_output = seq_generator.generate(
model_input,
num_beams=beam_size,
beam_threshold=1000,
vocab_size=model.config.vocab_size,
eos_score=-1.0,
eos_idx=1,
pad_idx=0,
)
output_text = transform.decode(model_output.tolist())

print(word_error_rate(output_text, target))


if __name__ == "__main__":
benchmark_beam_search_wer()
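A companion latency sketch (not part of this diff) could reuse the same pipeline to compare greedy and beam decoding wall-clock time. It leans only on GenerationUtils.generate keywords already exercised in this PR: num_beams=1 takes the greedy path as in the integration tests, while vocab_size and max_length follow the beam-search tests. benchmark_decoding_latency is a hypothetical helper appended to the module above, so the imports and datapipe already defined there resolve.

def benchmark_decoding_latency():
    # Hypothetical companion benchmark; assumes the generate() keywords used elsewhere in this PR.
    model = T5_BASE_GENERATION.get_model()
    transform = T5_BASE_GENERATION.transform()
    seq_generator = GenerationUtils(model)

    batch = next(iter(multi_dataloader))
    model_input = transform(batch["english"])

    for num_beams in (1, 8):
        kwargs = {"num_beams": num_beams, "max_length": 128}
        if num_beams > 1:
            # The beam-search path also takes the vocabulary size, per the integration tests.
            kwargs["vocab_size"] = model.config.vocab_size
        start = time.time()
        seq_generator.generate(model_input, **kwargs)
        print(f"num_beams={num_beams}: {time.time() - start:.2f}s")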
129 changes: 121 additions & 8 deletions notebooks/hf_with_torchtext_gen.ipynb
@@ -16,7 +16,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/homebrew/Caskroom/miniforge/base/envs/torchtext39/lib/python3.9/site-packages/tqdm-4.64.0-py3.9.egg/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
"/data/home/jrcummings/miniconda/envs/torchtext/lib/python3.9/site-packages/tqdm-4.64.1-py3.9.egg/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
@@ -39,14 +39,14 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/homebrew/Caskroom/miniforge/base/envs/torchtext39/lib/python3.9/site-packages/transformers/models/t5/tokenization_t5.py:164: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.\n",
"/data/home/jrcummings/miniconda/envs/torchtext/lib/python3.9/site-packages/transformers/models/t5/tokenization_t5.py:163: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.\n",
"For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.\n",
"- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.\n",
"- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.\n",
@@ -74,7 +74,55 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['a dog is good for you. studies have shown that dog ownership is good for your overall health and well-being.']\n"
]
}
],
"source": [
"# Testing HuggingFace's T5 w/ Beam Search\n",
"tokens = generative_hf_t5.generate(test_sequence_tk, max_len=100, pad_idx=t5.config.pad_token_id, num_beams=5, beam_size_token=t5.config.vocab_size)\n",
"print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['a dog is good for you. studies have shown that dog ownership is good for your overall health and well-being.'] 9.786320924758911\n",
"['studies have shown that owning a dog is good for you. studies have shown that owning a dog is good for you.'] 1.3000121116638184\n"
]
}
],
"source": [
"# Testing Decoding Speed HuggingFace's T5 w/ TorchText Beam Search vs. HuggingFace Beam Search\n",
"import time\n",
"\n",
"start = time.time()\n",
"tokens = generative_hf_t5.generate(test_sequence_tk, max_len=100, pad_idx=t5.config.pad_token_id, num_beams=5, beam_size_token=t5.config.vocab_size)\n",
"end = time.time()\n",
"print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True), end - start)\n",
"\n",
"start = time.time()\n",
"tokens = t5.generate(test_sequence_tk, max_length=100, num_beams=5, do_sample=False)\n",
"end = time.time()\n",
"print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True), end - start)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
@@ -99,7 +147,54 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['Nearly. PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions.']\n"
]
}
],
"source": [
"tokens = generative_hf_bart.generate(test_sequence_tk, max_len=20, pad_idx=bart.config.pad_token_id, num_beams=5, beam_size_token=bart.config.vocab_size)\n",
"print(bart_tokenizer.batch_decode(tokens, skip_special_tokens=True))\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs. The blackouts are expected to last through at least midday tomorrow. to be affected by the shutoffs which were expected to last through at least midday tomorrow. to be affected by the shutoffs which were expected to last through at least midday tomorrow. to be affected by the'] 58.09997892379761\n",
"['PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs. The blackouts were expected to last through at least midday tomorrow.'] 2.456479787826538\n"
]
}
],
"source": [
"# Testing Decoding Speed HuggingFace's BART w/ TorchText Beam Search vs. HuggingFace Beam Search\n",
"import time\n",
"\n",
"start = time.time()\n",
"tokens = generative_hf_bart.generate(test_sequence_tk, max_len=100, pad_idx=t5.config.pad_token_id, num_beams=5, eos_score=1.0, beam_size_token=t5.config.vocab_size)\n",
"end = time.time()\n",
"print(bart_tokenizer.batch_decode(tokens, skip_special_tokens=True), end - start)\n",
"\n",
"start = time.time()\n",
"tokens = bart.generate(test_sequence_tk, max_length=100, num_beams=5, do_sample=False)\n",
"end = time.time()\n",
"print(bart_tokenizer.batch_decode(tokens, skip_special_tokens=True), end - start)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
@@ -119,11 +214,29 @@
"tokens = generative_hf_gpt2.generate(test_sequence_tk, max_len=20, pad_idx=gpt2.config.pad_token_id)\n",
"print(gpt2_tokenizer.batch_decode(tokens, skip_special_tokens=True))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['I enjoy walking with my cute dog,\" says Kelli Williams-Petersen. The dog loves it so much, that when she']\n"
]
}
],
"source": [
"tokens = generative_hf_gpt2.generate(test_sequence_tk, max_len=20, pad_idx=gpt2.config.pad_token_id, num_beams=5, beam_size_token=gpt2.config.vocab_size)\n",
"print(gpt2_tokenizer.batch_decode(tokens, skip_special_tokens=True))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.13 ('torchtext39')",
"display_name": "torchtext",
"language": "python",
"name": "python3"
},
@@ -137,12 +250,12 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
"version": "3.9.15"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "63c8862cb56f124e3ee7674b73de745eeb216416a9b24f78d1fcb7c775bff1b7"
"hash": "1851d106532ddfc6fbd983b9ae95397243fcc3930d811046c990ea169e960650"
}
}
},
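The notebook's earlier, unchanged cells that construct generative_hf_t5, generative_hf_bart, and generative_hf_gpt2 are collapsed in this diff. For orientation, the T5 setup implied by the calls above looks roughly like the sketch below; the is_encoder_decoder flag is an assumption carried over from the integration tests, and the input sentence is illustrative, not the notebook's actual test_sequence.

from transformers import T5ForConditionalGeneration, T5Tokenizer
from torchtext.prototype.generate import GenerationUtils

t5 = T5ForConditionalGeneration.from_pretrained("t5-base")
t5_tokenizer = T5Tokenizer.from_pretrained("t5-base")
# Assumed wrapping; mirrors GenerationUtils(model, is_encoder_decoder=True) in the tests below.
generative_hf_t5 = GenerationUtils(t5, is_encoder_decoder=True)

# Illustrative input; the notebook's real test_sequence is not shown in this diff.
test_sequence_tk = t5_tokenizer(
    ["summarize: studies have shown that owning a dog is good for you"],
    return_tensors="pt",
).input_ids
tokens = generative_hf_t5.generate(
    test_sequence_tk,
    max_len=100,
    pad_idx=t5.config.pad_token_id,
    num_beams=5,
    beam_size_token=t5.config.vocab_size,
)
print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True))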
88 changes: 86 additions & 2 deletions test/integration_tests/test_generate.py
@@ -28,7 +28,7 @@ def setUp(self) -> None:
def test_greedy_generate_with_t5(self) -> None:
generation_model = GenerationUtils(self.model)

tokens = generation_model.generate(self.inputs, num_beams=1, max_length=30)
tokens = generation_model.generate(self.inputs, num_beams=1)
generated_text = self.transform.decode(tokens.tolist())

expected_generated_text = [
@@ -41,13 +41,97 @@ def test_greedy_generate_with_t5(self) -> None:

self.assertEqual(generated_text, expected_generated_text)

def test_beam_search_generate_t5(self) -> None:
generation_model = GenerationUtils(self.model)

tokens = generation_model.generate(
self.inputs, num_beams=3, vocab_size=self.model.config.vocab_size, max_length=30
)
generated_text = self.transform.decode(tokens.tolist())

expected_generated_text = [
"kate mccartney: a dog is good for you . she says studies have shown that dog ownership is good for",
"Das ist gut.",
"acceptable",
"4.0",
"a tornado ripped through a swath of a lake in southeastern michigan . a spokesman",
]

self.assertEqual(generated_text, expected_generated_text)

def test_beam_search_generate_t5_small_batch_size(self) -> None:
generation_model = GenerationUtils(self.model)

tokens = generation_model.generate(
self.inputs, num_beams=3, vocab_size=self.model.config.vocab_size, max_length=30, max_inference_batch_size=3
)
generated_text = self.transform.decode(tokens.tolist())

expected_generated_text = [
"kate mccartney: a dog is good for you . she says studies have shown that dog ownership is good for",
"Das ist gut.",
"acceptable",
"4.0",
"a tornado ripped through a swath of a lake in southeastern michigan . a spokesman",
]

self.assertEqual(generated_text, expected_generated_text)

def test_beam_search_generate_t5_with_small_beam_threshold(self) -> None:
generation_model = GenerationUtils(self.model)

tokens = generation_model.generate(
self.inputs, num_beams=3, vocab_size=self.model.config.vocab_size, max_length=30, beam_threshold=5
)
generated_text = self.transform.decode(tokens.tolist())

expected_text = [
"kate mccartney: a dog is good for you . kate mccartney: dogs",
"Das ist gut.",
"acceptable",
"4.0",
"a tornado ripped through a swath of a lake in southeastern mississippi, causing",
]

self.assertEqual(generated_text, expected_text)

def test_beam_search_generate_t5_large_num_beams(self) -> None:
generation_model = GenerationUtils(self.model)

tokens = generation_model.generate(
self.inputs, num_beams=25, vocab_size=self.model.config.vocab_size, max_length=30
)
generated_text = self.transform.decode(tokens.tolist())

expected_text = [
"aaron carroll, aaron jones, aaron jones and aaron jones",
"Das ist gut.",
"acceptable",
"4.0",
"a blizzard and power outages have prompted a blizzard and power outages, a spokesman says",
]

self.assertEqual(generated_text, expected_text)

def test_beam_search_generate_t5_large_num_beams_eos_score(self) -> None:
generation_model = GenerationUtils(self.model)

tokens = generation_model.generate(
self.inputs, num_beams=25, vocab_size=self.model.config.vocab_size, max_length=30, eos_score=10.0
)
generated_text = self.transform.decode(tokens.tolist())

expected_text = ["", "Das ist gut.", "acceptable", "4.0", ""]

self.assertEqual(generated_text, expected_text)

def test_generate_errors_with_incorrect_beams(self) -> None:
generation_model = GenerationUtils(self.model, is_encoder_decoder=True)

with self.assertRaises(ValueError):
generation_model.generate(self.inputs, num_beams=0)

@patch("logging.Logger.warning")
@patch("warnings.warn")
def test_warns_when_no_max_len_provided(self, mock) -> None:
generation_model = GenerationUtils(self.model)
generation_model.generate(self.inputs)
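Among the tests above, test_generate_errors_with_incorrect_beams guards the num_beams validation. A minimal standalone repro, sketched under the assumption that the torchtext T5 bundle setup from the benchmark file carries over:

from torchtext.models import T5_BASE_GENERATION
from torchtext.prototype.generate import GenerationUtils

model = T5_BASE_GENERATION.get_model()
transform = T5_BASE_GENERATION.transform()
generation_model = GenerationUtils(model, is_encoder_decoder=True)

inputs = transform(["translate English to German: A dog."])
try:
    generation_model.generate(inputs, num_beams=0)
except ValueError as err:
    # test_generate_errors_with_incorrect_beams asserts this ValueError is raised.
    print(f"rejected as expected: {err}")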