Generate: update basic llm tutorial (huggingface#26937)
gante authored and staghado committed Oct 24, 2023
1 parent 8fd2bf3 commit 4439b29
Showing 1 changed file with 67 additions and 20 deletions.
87 changes: 67 additions & 20 deletions docs/source/en/llm_tutorial.md
@@ -74,14 +74,13 @@ If you're interested in basic LLM usage, our high-level [`Pipeline`](pipeline_tu

</Tip>
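
As a point of reference before dropping down a level, here is a minimal sketch of that high-level interface. It is not part of the original tutorial: the checkpoint mirrors the one used below and the generation arguments are illustrative assumptions.

```py
>>> from transformers import pipeline

>>> # Minimal text-generation pipeline; checkpoint and generation arguments are illustrative
>>> pipe = pipeline("text-generation", model="mistralai/Mistral-7B-v0.1", device_map="auto")
>>> pipe("A list of colors: red, blue", max_new_tokens=10)[0]["generated_text"]  # prompt plus its continuation
```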

- <!-- TODO: update example to llama 2 (or a newer popular baseline) when it becomes ungated -->
First, you need to load the model.

```py
>>> from transformers import AutoModelForCausalLM

>>> model = AutoModelForCausalLM.from_pretrained(
... "openlm-research/open_llama_7b", device_map="auto", load_in_4bit=True
... "mistralai/Mistral-7B-v0.1", device_map="auto", load_in_4bit=True
... )
```

@@ -97,18 +96,31 @@ Next, you need to preprocess your text input with a [tokenizer](tokenizer_summar
```py
>>> from transformers import AutoTokenizer

- >>> tokenizer = AutoTokenizer.from_pretrained("openlm-research/open_llama_7b")
+ >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", padding_side="left")
>>> model_inputs = tokenizer(["A list of colors: red, blue"], return_tensors="pt").to("cuda")
```

The `model_inputs` variable holds the tokenized text input, as well as the attention mask. While [`~generation.GenerationMixin.generate`] does its best to infer the attention mask when it is not passed, we recommend passing it whenever possible for optimal results.

- Finally, call the [`~generation.GenerationMixin.generate`] method to return the generated tokens, which should be converted to text before printing.
+ After tokenizing the inputs, you can call the [`~generation.GenerationMixin.generate`] method to return the generated tokens. The generated tokens should then be converted to text before printing.

```py
>>> generated_ids = model.generate(**model_inputs)
>>> tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
- 'A list of colors: red, blue, green, yellow, black, white, and brown'
+ 'A list of colors: red, blue, green, yellow, orange, purple, pink,'
```
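
Tying back to the attention mask note above: the call in the previous block is equivalent to passing the two tensors by name. A small sketch, assuming the `model` and `model_inputs` defined earlier:

```py
>>> # model_inputs holds both "input_ids" and "attention_mask"; passing them explicitly
>>> # is equivalent to model.generate(**model_inputs)
>>> generated_ids = model.generate(
...     input_ids=model_inputs["input_ids"], attention_mask=model_inputs["attention_mask"]
... )
```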

Finally, you don't need to do it one sequence at a time! You can batch your inputs, which will greatly improve the throughput at a small latency and memory cost. All you need to do is to make sure you pad your inputs properly (more on that below).

```py
>>> tokenizer.pad_token = tokenizer.eos_token # Most LLMs don't have a pad token by default
>>> model_inputs = tokenizer(
... ["A list of colors: red, blue", "Portugal is"], return_tensors="pt", padding=True
... ).to("cuda")
>>> generated_ids = model.generate(**model_inputs)
>>> tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
['A list of colors: red, blue, green, yellow, orange, purple, pink,',
'Portugal is a country in southwestern Europe, on the Iber']
```

And that's it! In a few lines of code, you can harness the power of an LLM.
@@ -121,10 +133,10 @@ There are many [generation strategies](generation_strategies), and sometimes the
```py
>>> from transformers import AutoModelForCausalLM, AutoTokenizer

- >>> tokenizer = AutoTokenizer.from_pretrained("openlm-research/open_llama_7b")
- >>> tokenizer.pad_token = tokenizer.eos_token # Llama has no pad token by default
+ >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
+ >>> tokenizer.pad_token = tokenizer.eos_token # Most LLMs don't have a pad token by default
>>> model = AutoModelForCausalLM.from_pretrained(
- ... "openlm-research/open_llama_7b", device_map="auto", load_in_4bit=True
+ ... "mistralai/Mistral-7B-v0.1", device_map="auto", load_in_4bit=True
... )
```

@@ -154,7 +166,7 @@ By default, and unless specified in the [`~generation.GenerationConfig`] file, `
```py
>>> # Set seed for reproducibility -- you don't need this unless you want full reproducibility
>>> from transformers import set_seed
- >>> set_seed(0)
+ >>> set_seed(42)

>>> model_inputs = tokenizer(["I am a cat."], return_tensors="pt").to("cuda")

@@ -166,7 +178,7 @@ By default, and unless specified in the [`~generation.GenerationConfig`] file, `
>>> # With sampling, the output becomes more creative!
>>> generated_ids = model.generate(**model_inputs, do_sample=True)
>>> tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
- 'I am a cat.\nI just need to be. I am always.\nEvery time'
+ 'I am a cat. Specifically, I am an indoor-only cat. I'
```

### Wrong padding side
@@ -175,17 +187,17 @@ LLMs are [decoder-only](https://huggingface.co/learn/nlp-course/chapter1/6?fw=pt

```py
>>> # The tokenizer initialized above has right-padding active by default: the 1st sequence,
- >>> # which is shorter, has padding on the right side. Generation fails.
+ >>> # which is shorter, has padding on the right side. Generation fails to capture the logic.
>>> model_inputs = tokenizer(
... ["1, 2, 3", "A, B, C, D, E"], padding=True, return_tensors="pt"
... ).to("cuda")
>>> generated_ids = model.generate(**model_inputs)
- >>> tokenizer.batch_decode(generated_ids[0], skip_special_tokens=True)[0]
- ''
+ >>> tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+ '1, 2, 33333333333'

>>> # With left-padding, it works as expected!
- >>> tokenizer = AutoTokenizer.from_pretrained("openlm-research/open_llama_7b", padding_side="left")
- >>> tokenizer.pad_token = tokenizer.eos_token # Llama has no pad token by default
+ >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", padding_side="left")
+ >>> tokenizer.pad_token = tokenizer.eos_token # Most LLMs don't have a pad token by default
>>> model_inputs = tokenizer(
... ["1, 2, 3", "A, B, C, D, E"], padding=True, return_tensors="pt"
... ).to("cuda")
@@ -194,26 +206,61 @@ LLMs are [decoder-only](https://huggingface.co/learn/nlp-course/chapter1/6?fw=pt
'1, 2, 3, 4, 5, 6,'
```

- <!-- TODO: when the prompting guide is ready, mention the importance of setting the right prompt in this section -->
### Wrong prompt

Some models and tasks expect a certain input prompt format to work properly. When this format is not applied, you will get a silent performance degradation: the model kind of works, but not as well as if you were following the expected prompt. More information about prompting, including which models and tasks require a specific format, is available in this [guide](tasks/prompting). Let's see an example with a chat LLM, which makes use of [chat templating](chat_templating):

```python
>>> tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-alpha")
>>> model = AutoModelForCausalLM.from_pretrained(
... "HuggingFaceH4/zephyr-7b-alpha", device_map="auto", load_in_4bit=True
... )
>>> set_seed(0)
>>> prompt = """How many helicopters can a human eat in one sitting? Reply as a thug."""
>>> model_inputs = tokenizer([prompt], return_tensors="pt").to("cuda")
>>> input_length = model_inputs.input_ids.shape[1]
>>> generated_ids = model.generate(**model_inputs, max_new_tokens=20)
>>> print(tokenizer.batch_decode(generated_ids[:, input_length:], skip_special_tokens=True)[0])
"I'm not a thug, but i can tell you that a human cannot eat"
>>> # Oh no, it did not follow our instruction to reply as a thug! Let's see what happens when we write
>>> # a better prompt and use the right template for this model (through `tokenizer.apply_chat_template`)

>>> set_seed(0)
>>> messages = [
... {
... "role": "system",
... "content": "You are a friendly chatbot who always responds in the style of a thug",
... },
... {"role": "user", "content": "How many helicopters can a human eat in one sitting?"},
... ]
>>> model_inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to("cuda")
>>> input_length = model_inputs.shape[1]
>>> generated_ids = model.generate(model_inputs, do_sample=True, max_new_tokens=20)
>>> print(tokenizer.batch_decode(generated_ids[:, input_length:], skip_special_tokens=True)[0])
'None, you thug. How bout you try to focus on more useful questions?'
>>> # As we can see, it followed a proper thug style 😎
```

## Further resources

While the autoregressive generation process is relatively straightforward, making the most out of your LLM can be a challenging endeavor because there are many moving parts. Here are some next steps to help you dive deeper into LLM usage and understanding:

- <!-- TODO: complete with new guides -->
### Advanced generate usage

1. [Guide](generation_strategies) on how to control different generation methods, how to set up the generation configuration file, and how to stream the output (see the sketch after this list);
- 2. API reference on [`~generation.GenerationConfig`], [`~generation.GenerationMixin.generate`], and [generate-related classes](internal/generation_utils).
+ 2. [Guide](chat_templating) on the prompt template for chat LLMs;
+ 3. [Guide](tasks/prompting) on how to get the most out of prompt design;
+ 4. API reference on [`~generation.GenerationConfig`], [`~generation.GenerationMixin.generate`], and [generate-related classes](internal/generation_utils).
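
To give a flavor of item 1, here is a small sketch that reuses the `model`, `tokenizer`, and `model_inputs` from earlier in this tutorial; the specific configuration values are illustrative assumptions, not recommendations from the guide.

```py
>>> from transformers import GenerationConfig, TextStreamer

>>> # A reusable generation configuration with illustrative settings
>>> generation_config = GenerationConfig(max_new_tokens=50, do_sample=True, temperature=0.7)

>>> # Stream text to stdout as it is generated
>>> streamer = TextStreamer(tokenizer, skip_prompt=True)
>>> _ = model.generate(**model_inputs, generation_config=generation_config, streamer=streamer)
```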

### LLM leaderboards

1. [Open LLM Leaderboard](https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard), which focuses on the quality of the open-source models;
2. [Open LLM-Perf Leaderboard](https://huggingface.co/spaces/optimum/llm-perf-leaderboard), which focuses on LLM throughput.

- ### Latency and throughput
+ ### Latency, throughput and memory utilization

- 1. [Guide](main_classes/quantization) on dynamic quantization, which shows you how to drastically reduce your memory requirements.
+ 1. [Guide](llm_tutorial_optimization) on how to optimize LLMs for speed and memory;
+ 2. [Guide](main_classes/quantization) on quantization with bitsandbytes and autogptq, which shows you how to drastically reduce your memory requirements (see the sketch below).
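
For item 2, a minimal sketch of the bitsandbytes path, assuming the same Mistral checkpoint used throughout this tutorial; it is equivalent in spirit to the `load_in_4bit=True` shortcut shown earlier.

```py
>>> from transformers import AutoModelForCausalLM, BitsAndBytesConfig

>>> # 4-bit quantization via bitsandbytes; checkpoint choice is an assumption
>>> quantization_config = BitsAndBytesConfig(load_in_4bit=True)
>>> model = AutoModelForCausalLM.from_pretrained(
...     "mistralai/Mistral-7B-v0.1", device_map="auto", quantization_config=quantization_config
... )
```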

### Related libraries

