Add beam search generation w/ Flashlight Text
joecummings committed Dec 27, 2022
1 parent b699de2 commit 9da074e
Showing 3 changed files with 334 additions and 18 deletions.
129 changes: 121 additions & 8 deletions notebooks/hf_with_torchtext_gen.ipynb
@@ -16,7 +16,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/homebrew/Caskroom/miniforge/base/envs/torchtext39/lib/python3.9/site-packages/tqdm-4.64.0-py3.9.egg/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
"/data/home/jrcummings/miniconda/envs/torchtext/lib/python3.9/site-packages/tqdm-4.64.1-py3.9.egg/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
@@ -39,14 +39,14 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/homebrew/Caskroom/miniforge/base/envs/torchtext39/lib/python3.9/site-packages/transformers/models/t5/tokenization_t5.py:164: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.\n",
"/data/home/jrcummings/miniconda/envs/torchtext/lib/python3.9/site-packages/transformers/models/t5/tokenization_t5.py:163: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.\n",
"For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.\n",
"- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.\n",
"- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.\n",
@@ -74,7 +74,55 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['a dog is good for you. studies have shown that dog ownership is good for your overall health and well-being.']\n"
]
}
],
"source": [
"# Testing HuggingFace's T5 w/ Beam Search\n",
"tokens = generative_hf_t5.generate(test_sequence_tk, max_len=100, pad_idx=t5.config.pad_token_id, num_beams=5, beam_size_token=t5.config.vocab_size)\n",
"print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['a dog is good for you. studies have shown that dog ownership is good for your overall health and well-being.'] 9.786320924758911\n",
"['studies have shown that owning a dog is good for you. studies have shown that owning a dog is good for you.'] 1.3000121116638184\n"
]
}
],
"source": [
"# Testing Decoding Speed HuggingFace's T5 w/ TorchText Beam Search vs. HuggingFace Beam Search\n",
"import time\n",
"\n",
"start = time.time()\n",
"tokens = generative_hf_t5.generate(test_sequence_tk, max_len=100, pad_idx=t5.config.pad_token_id, num_beams=5, beam_size_token=t5.config.vocab_size)\n",
"end = time.time()\n",
"print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True), end - start)\n",
"\n",
"start = time.time()\n",
"tokens = t5.generate(test_sequence_tk, max_length=100, num_beams=5, do_sample=False)\n",
"end = time.time()\n",
"print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True), end - start)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
@@ -99,7 +147,54 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['Nearly. PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions.']\n"
]
}
],
"source": [
"tokens = generative_hf_bart.generate(test_sequence_tk, max_len=20, pad_idx=bart.config.pad_token_id, num_beams=5, beam_size_token=bart.config.vocab_size)\n",
"print(bart_tokenizer.batch_decode(tokens, skip_special_tokens=True))\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs. The blackouts are expected to last through at least midday tomorrow. to be affected by the shutoffs which were expected to last through at least midday tomorrow. to be affected by the shutoffs which were expected to last through at least midday tomorrow. to be affected by the'] 58.09997892379761\n",
"['PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs. The blackouts were expected to last through at least midday tomorrow.'] 2.456479787826538\n"
]
}
],
"source": [
"# Testing Decoding Speed HuggingFace's BART w/ TorchText Beam Search vs. HuggingFace Beam Search\n",
"import time\n",
"\n",
"start = time.time()\n",
"tokens = generative_hf_bart.generate(test_sequence_tk, max_len=100, pad_idx=t5.config.pad_token_id, num_beams=5, eos_score=1.0, beam_size_token=t5.config.vocab_size)\n",
"end = time.time()\n",
"print(bart_tokenizer.batch_decode(tokens, skip_special_tokens=True), end - start)\n",
"\n",
"start = time.time()\n",
"tokens = bart.generate(test_sequence_tk, max_length=100, num_beams=5, do_sample=False)\n",
"end = time.time()\n",
"print(bart_tokenizer.batch_decode(tokens, skip_special_tokens=True), end - start)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
@@ -119,11 +214,29 @@
"tokens = generative_hf_gpt2.generate(test_sequence_tk, max_len=20, pad_idx=gpt2.config.pad_token_id)\n",
"print(gpt2_tokenizer.batch_decode(tokens, skip_special_tokens=True))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['I enjoy walking with my cute dog,\" says Kelli Williams-Petersen. The dog loves it so much, that when she']\n"
]
}
],
"source": [
"tokens = generative_hf_gpt2.generate(test_sequence_tk, max_len=20, pad_idx=gpt2.config.pad_token_id, num_beams=5, beam_size_token=gpt2.config.vocab_size)\n",
"print(gpt2_tokenizer.batch_decode(tokens, skip_special_tokens=True))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.13 ('torchtext39')",
"display_name": "torchtext",
"language": "python",
"name": "python3"
},
Expand All @@ -137,12 +250,12 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
"version": "3.9.15"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "63c8862cb56f124e3ee7674b73de745eeb216416a9b24f78d1fcb7c775bff1b7"
"hash": "1851d106532ddfc6fbd983b9ae95397243fcc3930d811046c990ea169e960650"
}
}
},
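
The notebook's new cells call generate on wrapper objects (generative_hf_t5, generative_hf_bart, generative_hf_gpt2) whose construction sits in the collapsed hunks above. Below is a minimal sketch of the assumed setup for the T5 case: the GenerationUtil import path is inferred from the test file that follows, while the constructor flags and the example input are assumptions not shown in this diff.

import time

from torchtext.prototype.generate import GenerationUtil  # import path assumed from the repo layout
from transformers import T5ForConditionalGeneration, T5Tokenizer

t5 = T5ForConditionalGeneration.from_pretrained("t5-base")
t5_tokenizer = T5Tokenizer.from_pretrained("t5-base")

# Constructor flags are an assumption about how a HuggingFace model is wrapped.
generative_hf_t5 = GenerationUtil(t5, is_encoder_decoder=True, is_huggingface_model=True)

# Illustrative input; the notebook's actual test sequence sits in a collapsed cell.
test_sequence = ["summarize: studies have shown that owning a dog is good for you"]
test_sequence_tk = t5_tokenizer(test_sequence, return_tensors="pt").input_ids

# Passing num_beams > 1 selects the Flashlight Text beam search decoder;
# the notebook sets beam_size_token to the model's vocabulary size.
start = time.time()
tokens = generative_hf_t5.generate(
    test_sequence_tk,
    max_len=100,
    pad_idx=t5.config.pad_token_id,
    num_beams=5,
    beam_size_token=t5.config.vocab_size,
)
print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True), time.time() - start)

Judging by the timings printed in the notebook outputs above, the prototype beam search path is still much slower than HuggingFace's own generate for both T5 (roughly 9.8 s vs. 1.3 s) and BART (roughly 58.1 s vs. 2.5 s).
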
21 changes: 20 additions & 1 deletion test/torchtext_unittest/prototype/test_generate.py
@@ -50,4 +50,23 @@ def test_generate_errors_with_incorrect_beams(self) -> None:
def test_warns_when_no_max_len_provided(self, mock) -> None:
generation_model = GenerationUtil(self.model)
generation_model.generate(self.inputs)
mock.assert_called_with("`max_len` was not specified. Defaulting to 100 tokens.")
mock.assert_called_with("`max_len` was not specified. Defaulting to 256 tokens.")

def test_beam_search(self) -> None:
generation_model = GenerationUtil(self.model)

tokens = generation_model.generate(self.inputs, num_beams=3, max_len=30, beam_size_token=self.model.config.vocab_size)

generated_text = self.transform.decode(tokens.tolist())

expected_generated_text = [
'kate mccartney: a dog is good for you . she says studies have shown that dog ownership is good for',
'Das ist gut.',
'acceptable',
'4.0',
'a tornado ripped through a swath of a lake in st. louis . a s'
]

self.assertEqual(generated_text, expected_generated_text)
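
The updated warning string in test_warns_when_no_max_len_provided reflects a new default: calling generate without an explicit max_len now falls back to 256 tokens instead of 100. Below is a hedged sketch of exercising that path outside the test harness, assuming the warning is raised through Python's standard warnings module (the mocked assertion in the test suggests this but does not confirm it); the wrapper construction repeats the assumptions from the earlier sketch.

import warnings

from torchtext.prototype.generate import GenerationUtil  # import path assumed
from transformers import T5ForConditionalGeneration, T5Tokenizer

t5 = T5ForConditionalGeneration.from_pretrained("t5-base")
tokenizer = T5Tokenizer.from_pretrained("t5-base")
inputs = tokenizer(["summarize: owning a dog is good for you"], return_tensors="pt").input_ids

# Constructor flags are an assumption, as in the sketch above.
generation_model = GenerationUtil(t5, is_encoder_decoder=True, is_huggingface_model=True)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    tokens = generation_model.generate(inputs)  # no max_len, so the default applies
    print(caught[-1].message)  # `max_len` was not specified. Defaulting to 256 tokens.
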

