> <p><small><small>This Notebook is made available subject to the licence and terms set out in the <a href = "http://www.github.com/google-deepmind/ai-foundations">AI Research Foundations Github README file</a>.

<img src="https://storage.googleapis.com/dm-educational/assets/ai_foundations/GDM-Labs-banner-image-C1-white-bg.png">

# Lab: Compare N-Gram Models and Transformer Language Models

<a href='https://colab.research.google.com/github/google-deepmind/ai-foundations/blob/master/course_1/gdm_lab_1_3_compare_n_gram_models_and_transformer_language_models.ipynb' target='_parent'><img src='https://colab.research.google.com/assets/colab-badge.svg' alt='Open In Colab'/></a>

Compare the generations of n-gram and transformer language models.

30 minutes.

## Overview

So far, you have encountered two methods to estimate the probability distribution over the next token given a prompt. In the first lab, you manually assigned probabilities to lists of candidate tokens, and in the second lab, you used n-gram counts to build a language model.

As you have seen in the previous labs, neither of these methods is ideal. Assigning probabilities manually, on the one hand, would be too time-intensive in practice and it would be impossible to list all possible prompts. The n-gram models, on the other hand, often produced generations that did not make sense due to their short context window, or the model failed to generate a sequence of tokens at all due to the sparsity in the dataset.

In this lab, you will experiment with a more advanced language model based on the **transformer architecture**. The transformer architecture is an example of a neural network model, a class of sophisticated machine learning models that can learn very complex patterns from data. Transformers provide the foundation for modern large language models. These models are much better at producing coherent responses to arbitrary prompts than n-gram models.

You will explore this yourself by comparing generations using your n-gram model to generations from a transformer model.

### What you will learn:

By the end of this lab, you will understand:
* How the probability distributions predicted by n-gram models and transformer models differ.
* How the generations based on these probability distributions differ.

### Tasks

In this lab, you will not have to write any new code but instead you will interact with two language models: the Gemma-1B transformer model and a trigram model.


**In this lab, you will**:
* Load the transformer model Gemma-1B and the trigram language model from the previous lab.
* Observe how the probability distribution over the next token varies for the two models.
* Explore how the generations of the two models differ.


## How to use Google Colaboratory (Colab)

Google Colaboratory (also known as Google Colab) is a platform that allows you to run Python code in your browser. The code is written in **cells** that are executed on a remote server.

To run a cell, hover over a cell and click on the `run` button to its left. The run button is the circle with the triangle (▶). Alternatively, you can also click on a cell and use the keyboard combination Ctrl+Return (or ⌘+Return if you are using a Mac).

To try this out, run the following cell. This should print today's day of the week below it.

In [None]:
from datetime import datetime

print(f"Today is {datetime.today():%A}.")

Note that the *order in which you run the cells matters*. When you are working through a lab, make sure to always run *all* cells in order, otherwise the code might not work. If you take a break while working on a lab, Colab may disconnect you, in which case you have to execute all cells again before continuing your work. To make this easier, you can select the cell you are currently working on and then choose **Runtime** → **Run before** from the menu above (or use the keyboard combination Ctrl/⌘ + F8). This will re-execute all cells before the current one.

## Using Colab with a GPU


A **GPU** is a special type of hardware that can significantly speed up some types of computations of machine learning models. Several of the activities in this lab will also run a lot faster if you run them on a GPU.

Follow these steps to run the activities in this lab on a GPU:

1.  In the top menu bar, click on **Runtime**.
2.  Select **Change runtime type** from the dropdown menu.
3.  In the pop-up window under **Hardware Accelerator**, select **GPU** (usually listed as `T4 GPU`).
4.  Click **Save**.

Your Colab session will now restart with GPU access.

Note that access to GPUs is limited and at times, you may not be able to run this lab on a GPU. All activities will still work but they will run slower and you will have to wait longer for some of the cells to finish running.


## Imports

In this lab, you will primarily interact with the `ai_foundations` package, which has been specifically developed for this course. In the background, this package uses the [`gemma`](https://github.com/google-deepmind/gemma) package to load and prompt the Gemma-1B model and the [`plotly`](https://plotly.com/python/) package for creating visualizations.


In [1]:
%%capture
!pip install orbax-checkpoint==0.11.21 "jax[cuda12]==0.6.2"
!pip install "git+https://github.com/google-deepmind/ai-foundations.git@main"

# Packages used.
import os # For setting a variable needed to load the model onto the GPU.
import pandas as pd # For loading the Africa Galore dataset.

# Functions for clearing outputs and formatting.
from IPython.display import clear_output, display, HTML

# Functions for generating texts with a language model, visualizing probability
# distributions, and loading an n-gram model.
from ai_foundations import generation
from ai_foundations import visualizations
from ai_foundations.ngram import model as ngram_model

# Set the full GPU memory usage for JAX.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"

## Load the models

In preparation for the comparisons between the n-gram model and the transformer model, the following cell loads the Africa Galore dataset and initializes a trigram model whose probabilities are estimated from the n-gram counts in that dataset. It also loads the Gemma-1B model.

<br />

------
> **ℹ️ Info: The Gemma-1B model**
>
>The transformer model you will interact with in this lab is the _Gemma-1B_ model that has been developed and trained by Google [1]. You will learn more about what it means to train a model in later parts of this course, but essentially, **training** is the process of teaching a model a specific task using a dataset. In the case of a language model, the task is to predict the next token based on a prompt. When you estimated the probabilities of the n-gram models using the counts in a corpus, you also trained the model. The output of the training process is the set of **parameters** of the model. These parameters guide the model to perform whatever task it was trained to do. In the case of the n-gram language model, the model parameters were the conditional probabilities. In the case of transformer models, the parameters are an (often very large) collection of numbers that determine the model behavior. A single one of these numbers does not mean anything on its own, but in combination, these numbers capture many patterns about language. In the case of Gemma-1B, there are around 1 billion such numbers, which gives this model its name.
------

<br />
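To make the idea of n-gram parameters concrete, here is a minimal sketch of how conditional probabilities could be estimated from counts. The toy corpus and the helper `estimate_trigram_probabilities` are hypothetical and only for illustration; they are not the `ai_foundations` implementation and do not use the Africa Galore dataset.

```python
from collections import Counter, defaultdict

# Hypothetical toy corpus (NOT the Africa Galore dataset).
toy_corpus = [
    "she went looking for food",
    "he went looking for water",
    "they went looking for food",
]


def estimate_trigram_probabilities(texts):
    """Estimates P(next token | previous two tokens) from trigram counts."""
    counts = defaultdict(Counter)
    for text in texts:
        tokens = text.split(" ")
        for i in range(len(tokens) - 2):
            context = (tokens[i], tokens[i + 1])
            counts[context][tokens[i + 2]] += 1

    # Normalize the counts for each context so that they sum to 1.
    probabilities = {}
    for context, next_counts in counts.items():
        total = sum(next_counts.values())
        probabilities[context] = {
            token: count / total for token, count in next_counts.items()
        }
    return probabilities


toy_probabilities = estimate_trigram_probabilities(toy_corpus)
print(toy_probabilities[("looking", "for")])
```

In this toy corpus, "food" follows "looking for" in two of the three sentences, so it receives a probability of 2/3 and "water" receives 1/3. These numbers are the "parameters" of the toy model.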

Run the following cell to load the two models. Note that loading the Gemma model may take up to a minute.

In [None]:
# Load the Africa Galore dataset.
africa_galore = pd.read_json(
    "https://storage.googleapis.com/dm-educational/assets/ai_foundations/africa_galore.json"
)
dataset = africa_galore["description"]
print(f"Loaded Africa Galore dataset with {len(dataset)} paragraphs.\n")

# Load a trigram model whose probabilities have been estimated using the
# Africa Galore dataset.
trigram_model = ngram_model.NGramModel(dataset, 3)
print("Loaded trigram model.\n")

print("Loading Gemma-1B model...")
gemma_model = generation.load_gemma()
print("Loaded Gemma-1B model.")

## Comparing model outputs

Now that the models have been loaded, you can compare their generations. Models are usually evaluated against many criteria, and which criteria are deemed most relevant depends on the task you are trying to solve. For example, if you are developing a chatbot that aims to provide information, then it is important that the model provides accurate information. If, on the other hand, you are developing a model for more creative tasks, then it may be more important that the generations are very diverse and not repetitive.

In this lab, focus on the following evaluation criteria:
1. **Fluency**: Does it read naturally? Grammatical mistakes, for example, would lower the fluency. Similarly, even if sentences are grammatical, if they go on and on, they may be difficult to comprehend.
2. **Coherence**: Does it make logical sense and stay on topic? As language models are predicting one token at a time, the end of a generation may be about a different topic than its beginning.
3. **Relevance**: Does it fit the context or prompt? A model might generate a response composed of random-looking tokens that don't constitute a proper answer.
4. **Bias**: Does the output promote inequalities? Language models are trained on human-written data that likely include biases and promote stereotypes. You may observe very stereotypical outputs that could promote inequalities in the generations of a model.   

### Predict the next token

As a first investigation, generate a single token for the prompt "Jide was hungry so she went looking for." Then vary the prompt and see how the predictions change. You can generate the token by entering a prompt and running the cell below.

Evaluate whether the predicted token always makes sense in the context. Also note whether both models are able to predict the next token for arbitrary prompts. You can even enter a prompt in another language if you speak one and see how the model responds.

Note that the tokens in many transformer models may be only a character or a part of a word. As such, the generations may sometimes end with tokens that are only the beginning of a word.

In [None]:
# @title Compute the next token for a prompt

prompt = "Jide was hungry so she went looking for"  # @param {type: "string"}

output_text_transformer, _, _ = (
    generation.prompt_transformer_model(
        prompt, max_new_tokens=1, loaded_model=gemma_model
    )
)

clear_output()
print(f"Generation by Gemma-1B:\n{output_text_transformer}\n\n")

output_text_ngram = trigram_model.generate(1, prompt)
print(f"Generation by trigram model:\n{output_text_ngram}")

### Visualize the probability distribution over the predicted next token

To get a better idea of what the model will be likely to generate, it can be useful to visualize the probability distribution over the next token.

Run the following cell to plot the probability distributions over the next token for the prompt below. Each bar of the plots represents a different token, and its height corresponds to the probability assigned to that token by the model. The taller the bar, the more likely the model would choose that token for generating a sequence.

Note that in order to make the plots more compact, they only show the probabilities of the 30 tokens with the highest probabilities. The transformer model assigns a non-zero probability to many more tokens, so the probabilities shown in this plot will likely not sum to 1.
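The following minimal sketch illustrates why the probabilities of only the top tokens do not sum to 1. The logits and the six-token vocabulary are made up for illustration, and converting logits to probabilities with a softmax is a common approach that is assumed here; it is not necessarily how `visualizations.plot_next_token` works internally.

```python
import math


def softmax(logits):
    """Converts a list of logits into probabilities that sum to 1."""
    max_logit = max(logits)  # Subtract the maximum for numerical stability.
    exps = [math.exp(x - max_logit) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits for a made-up six-token vocabulary.
toy_logits = {
    "food": 4.0,
    "a": 3.0,
    "something": 2.5,
    "water": 1.0,
    "help": 0.5,
    "work": 0.0,
}

probs = dict(zip(toy_logits, softmax(list(toy_logits.values()))))

# Keep only the top_k most probable tokens, as a compact plot would.
top_k = 3
top_tokens = sorted(probs.items(), key=lambda item: item[1], reverse=True)[:top_k]
top_mass = sum(p for _, p in top_tokens)

print(f"Sum over all tokens: {sum(probs.values()):.3f}")
print(f"Sum over the top {top_k} tokens: {top_mass:.3f}")
```

Because the remaining tokens still carry some probability, the sum over the top tokens is strictly smaller than 1.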

Run the cell and examine the distribution over the tokens below. Would including all these tokens result in fluent texts? Does the distribution of probabilities across tokens make sense? Repeat this process for several different prompts to get a sense of how the distributions by the trigram model compare to the distributions by Gemma-1B.

In [None]:
# @title Visualize the probability distributions

prompt = "Jide was hungry so she went looking for"  # @param {type: "string"}

output_text_transformer, next_token_logits, tokenizer = (
    generation.prompt_transformer_model(
        prompt, max_new_tokens=1, loaded_model=gemma_model
    )
)

display(HTML("<h3>Gemma-1B</h3>"))

# Visualize the Gemma-1B probabilities.
visualizations.plot_next_token(
    next_token_logits,
    prompt=prompt,
    tokenizer=tokenizer
)

display(HTML("<h3>Trigram model</h3>"))

# Visualize the trigram probabilities.
context_ngram = tuple(prompt.split(" ")[-2:])
if context_ngram in trigram_model.probabilities:
    visualizations.plot_next_token(
        trigram_model.probabilities[context_ngram], prompt=prompt
    )
else:
    print(
        "The trigram model does not make any predictions for the prompt"
        f" \"{prompt}\" since the bigram \"{' '.join(context_ngram)}\""
        f" is not part of the dataset."
    )

When you run the cell above, the model generates a probability distribution for the next token. Some tokens will have higher probabilities than others, meaning they are more likely to be chosen as the next token.

Here are a few likely observations:

1. The Gemma model is able to assign probabilities to more tokens than the trigram model (which fails to assign probabilities to the next token for many prompts).
2. The most probable token will usually be a common word that fits the context of the sentence (e.g., "food" after the prompt "Jide was hungry so she went looking for").
3. The model might suggest words that seem plausible but do not carry a lot of information like "a" or "something".
4. You might notice some tokens have low probabilities, meaning the model considers them less likely to fit but does not completely rule them out, like "work", "help", or "Banku".


### Investigate the context-sensitivity of the two models

What happens to the probability distribution if the context is changed? Generate a next token prediction for the prompt "Jide was thirsty so she went looking for" and consider both the generation and the distribution over the next token. Then, compare the distributions to the distributions of the original prompt "Jide was hungry so she went looking for." For which model do the distributions change more?

In [None]:
# @title Predict the next token and visualize the distributions

prompt = "Jide was thirsty so she went looking for"  # @param {type: "string"}

output_text_transformer, next_token_logits, tokenizer = (
    generation.prompt_transformer_model(
        prompt, max_new_tokens=1, loaded_model=gemma_model
    )
)

clear_output()

print(f"Generation by Gemma-1B:\n{output_text_transformer}\n\n")

# Generate a single token with the trigram model.
output_text_ngram = trigram_model.generate(1, prompt)
print(f"Generation by trigram model:\n{output_text_ngram}")

display(HTML("<h3>Gemma-1B</h3>"))

# Visualize the Gemma-1B probabilities.
visualizations.plot_next_token(next_token_logits, prompt=prompt, tokenizer=tokenizer)

display(HTML("<h3>Trigram model</h3>"))

# Visualize the trigram probabilities.
context_ngram = tuple(prompt.split(" ")[-2:])
if context_ngram in trigram_model.probabilities:
    visualizations.plot_next_token(
        trigram_model.probabilities[context_ngram], prompt=prompt
    )
else:
    print(
        "The trigram model does not make any predictions for the prompt"
        f" \"{prompt}\" since the bigram \"{' '.join(context_ngram)}\""
        f" is not part of the dataset."
    )

#### What did you observe?

When running the transformer model with prompts like `"Jide was thirsty so she went looking for"`, you might notice certain patterns in the predicted next tokens. For instance, you may see drink-related words like "water" suggested more often. This is because the transformer model is **context-sensitive** and associates tokens related to hunger with tokens like "food", and tokens related to thirst with tokens like "water" based on the entire prompt.

The distribution over the next token as predicted by the trigram model, on the other hand, did not change at all. This is because this model's predictions are only based on the last bigram "looking for" which is the same across the two prompts.

This highlights a key shortcoming of n-gram models: They have very short context windows and are unable to take information into account that does not appear at the very end of the context. Transformer models, on the other hand, usually have a context window of hundreds or thousands of tokens and can therefore provide much more context-sensitive answers.
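As a small illustration of this point, the sketch below extracts the context a trigram model would condition on for both prompts, using the same whitespace split as the visualization cells in this lab. The helper name `trigram_context` is hypothetical.

```python
def trigram_context(prompt):
    """Returns the last two whitespace-separated tokens of the prompt."""
    return tuple(prompt.split(" ")[-2:])


hungry_prompt = "Jide was hungry so she went looking for"
thirsty_prompt = "Jide was thirsty so she went looking for"

hungry_context = trigram_context(hungry_prompt)
thirsty_context = trigram_context(thirsty_prompt)

print(hungry_context)   # ('looking', 'for')
print(thirsty_context)  # ('looking', 'for')
print(hungry_context == thirsty_context)  # True
```

Since both prompts reduce to the same context, the words "hungry" and "thirsty" never reach the trigram model, and its next-token distribution is identical for the two prompts.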

### Generate sequences

In the previous activities, you have focused on predicting only one token. However, usually, you will use language models to predict entire sequences. In this activity, you will compare the outputs of the Gemma-1B model to the outputs of the trigram model when predicting longer sequences.


Change `num_tokens_to_generate` to set the number of tokens to generate so that the generations are longer sequences. Generate continuations for several prompts. Then compare the generations of the two models along the four evaluation criteria mentioned previously:
1. Fluency
2. Coherence
3. Relevance
4. Bias

In [None]:
# @title Generate sequences
prompt = "Jide was hungry so she went looking for"  # @param {type: "string"}

num_tokens_to_generate = 50  # @param {type: "number"}

(output_text_transformer, next_token_logits, tokenizer) = (
    generation.prompt_transformer_model(
        prompt, max_new_tokens=num_tokens_to_generate, loaded_model=gemma_model
    )
)

clear_output()

print(f"Generation by Gemma-1B:\n{output_text_transformer}\n\n")

output_text_ngram = trigram_model.generate(num_tokens_to_generate, prompt)
print(f"Generation by trigram model:\n{output_text_ngram}")

#### What did you observe?

You likely made some of the following observations:

1. **Fluency**: The Gemma-1B model tends to be much more fluent than the trigram model and generates texts that follow the rules of English. The trigram model tends to generate sequences where short phrases are fluent but globally there may be many mistakes. This is again caused by the small context window of the n-gram model.
2. **Coherence**: While not always perfect, the Gemma-1B model also produces more coherent generations than the trigram model. While you may sometimes observe generations by both models that do not make any sense or go off topic, this tends to be much less of a problem with transformer models such as Gemma.
3. **Relevance**: The trigram model is rarely able to produce relevant responses. This becomes particularly pronounced when you use a question as a prompt. The Gemma-1B model again performs much better against this criterion.
4. **Bias**: Language models tend to reproduce the biases in the data that they were trained on. For example, if you compare several generations of the models for the prompts "The nurse went to university in Ethiopia." and "The doctor went to university in Ethiopia.", you will likely observe that the Gemma-1B model continues more often with female pronouns such as "she" or "her" when talking about a nurse, and more often with male pronouns such as "he" and "him" when talking about a doctor. This is because there tend to be many more texts about male doctors and female nurses than the other way round. When models are trained on generally available texts, they likely also learn such undesirable patterns. The trigram model tends to suffer from similar biases, but since it rarely generates coherent responses, it cannot really be used in practice (as you may have noticed, it generates continuations about coffee for these two prompts).








#### Diversity of generations

You likely noticed that the output of both models tends to change often when you run one of the cells above, even with the same prompt. As discussed in the previous labs, this is because the model uses a probability distribution to pick the next token, which introduces a level of stochasticity (randomness) into the prediction. As mentioned before, this variability helps the model generate more diverse and creative outputs. It allows users to regenerate a different response if they are not satisfied with the initial one.


### Controlling the model output

Sometimes it may be desirable to make the output **deterministic** and to always choose the token with the highest probability. This is known as **greedy sampling**.  

The following cell provides you with the option to switch between random sampling and greedy sampling. Depending on whether you set `sampling_mode` to `greedy` or `random`, you should get a deterministic or a random output respectively.

Run the following cell multiple times with `sampling_mode` set to `random` and then with `sampling_mode` set to `greedy`. When it is set to `random`, you should get different outputs most of the time; when it is set to `greedy`, re-running the cell should always lead to the same output.
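To illustrate the difference between the two modes, here is a minimal sketch that picks a token from a made-up next-token distribution. The distribution and the helper `pick_token` are hypothetical; this is not how `generation.prompt_transformer_model` or `trigram_model.generate` are implemented.

```python
import random

# Hypothetical next-token distribution (made up for illustration).
toy_distribution = {"water": 0.5, "a": 0.3, "something": 0.2}


def pick_token(distribution, sampling_mode, rng):
    """Picks the next token either greedily or by random sampling."""
    if sampling_mode == "greedy":
        # Always choose the token with the highest probability.
        return max(distribution, key=distribution.get)
    if sampling_mode == "random":
        # Draw one token in proportion to its probability.
        tokens = list(distribution)
        weights = list(distribution.values())
        return rng.choices(tokens, weights=weights, k=1)[0]
    raise ValueError(f"Unknown sampling mode: {sampling_mode}")


rng = random.Random(0)  # Seeded only to make this sketch reproducible.
greedy_picks = [pick_token(toy_distribution, "greedy", rng) for _ in range(10)]
random_picks = [pick_token(toy_distribution, "random", rng) for _ in range(10)]

print("greedy:", greedy_picks)
print("random:", random_picks)
```

The greedy picks are always "water" because it has the highest probability, whereas the random picks are drawn in proportion to the probabilities and can differ from draw to draw.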


In [None]:
# @title Random vs. deterministic generations
prompt = "Jide was thirsty so she went looking for"  # @param {type: "string"}

num_tokens_to_generate = 50  # @param {type: "number"}

sampling_mode = "random"  # @param {type: "string", values:["random", "greedy"]}


(output_text_transformer, next_token_logits, tokenizer) = (
    generation.prompt_transformer_model(
        prompt,
        max_new_tokens=num_tokens_to_generate,
        loaded_model=gemma_model,
        sampling_mode=sampling_mode,
    )
)
clear_output()

print(f"Generation by Gemma-1B:\n{output_text_transformer}\n\n")

output_text_ngram = trigram_model.generate(
    num_tokens_to_generate, prompt, sampling_mode=sampling_mode
)
print(f"Generation by trigram model:\n{output_text_ngram}")

#### Balancing creativity and consistency

Sampling from a probability distribution allows the model to explore a range of possible next tokens, fostering creativity and generating varied outputs. This approach contrasts with always picking the token with the highest probability, which focuses on the most likely next token, as you have experienced above.

Different applications require different settings for this balance. For creative tasks such as generating stories, sampling from the probability distribution is ideal. This is because it allows the model to explore various possibilities and produce more imaginative results.

If accuracy, consistency, and reliability are important for your use case, it is better to always choose the token with the highest probability. There are also methods that allow for a balance between these two approaches. You will learn more about these in later courses.

## Takeaways

You have now directly compared the generations of a trigram model and a transformer model and have observed many differences. These comparisons highlighted contrasts in terms of fluency, coherence, and relevance between the two models. While the n-gram model often generated word salads or failed to generate a continuation at all, the transformer model generally generated quite reasonable responses (though sometimes they may not have been entirely perfect either).

Note that this comparison was stacked against the n-gram model. That is because the difference between the trigram model and the Gemma-1B model, which were both trained on the Africa Galore dataset, is not only one of implementation. The Gemma-1B model has also been trained on a very large dataset. In comparison, the trigram model has only been trained on the paragraphs in the Africa Galore dataset. That being said, even if you had trained the n-gram model on as much data as the Gemma-1B model, the transformer model would have still performed much better.

There are two primary reasons for this:
- Transformers have much larger context windows and can therefore consider the information of tokens that are further away from the token to be generated. N-gram models, on the other hand, only have a context window of $n-1$ tokens. So in the case of the trigram model, the model only considered the last two tokens for making predictions.
- Transformers are based on neural networks that can learn **sophisticated** and **abstract** patterns. As you will learn in later courses, neural networks can learn, for example, that *food* and *snack* have related meanings. This allows the model to abstract away from specific words and learn more general patterns about language, which in turn allows it to generate more diverse and more coherent responses.

## Summary

This is the end of the **Compare N-Gram Models and Transformer Language Models** lab.

In this lab, you:

- Experienced what generations of transformer models look like and how they compare to the generations of n-gram models.

- Tried different prompts and observed how the model predictions and their probabilities changed (or did not change) based on the context.

- Visualized the probability distributions over the next token to gain a deeper understanding of the model behavior when randomly sampling the next token.

- Compared the models' abilities in generating longer sequences of text and explored how you can make the generations deterministic.


## References

[1] Kamath et al. (Gemma Team). 2025. Gemma 3 Technical Report. Google DeepMind, London. arXiv:2503.19786. Retrieved from https://arxiv.org/pdf/2503.19786.