# Generating Text in Chatbots

In [6]:
import sys
import subprocess
import pkg_resources

# Find out which packages are missing. The stdlib `importlib.metadata` API is used
# instead of the deprecated `pkg_resources`, which is unavailable in environments
# that do not have setuptools installed.
from importlib import metadata

required_packages = {'torch', 'transformers'}


def _is_installed(dist_name):
    """Return True if a distribution named `dist_name` is installed for this interpreter."""
    try:
        metadata.distribution(dist_name)
    except metadata.PackageNotFoundError:
        return False
    return True


missing_packages = {name for name in required_packages if not _is_installed(name)}

# If there are missing packages, install them into the interpreter running this kernel.
if missing_packages:
    print('Installing the following packages: ' + str(missing_packages))
    python = sys.executable
    # Sorted so the pip command line is deterministic (set iteration order is not).
    subprocess.check_call([python, '-m', 'pip', 'install', *sorted(missing_packages)], stdout=subprocess.DEVNULL)

Installing the following packages: {'transformers'}


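<ins>Note</ins>: Windows users should enable Developer Mode on their device, as described in [Enable your device for development](https://learn.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development). The Hugging Face download cache uses symbolic links, which Windows only permits in Developer Mode (or when running as administrator).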

## Perplexity

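In the code that follows, we measure the perplexity of the _gpt2_ model on three datasets. Perplexity is the exponential of the average negative log-likelihood that the model assigns to each token, given the tokens before it:

$$\mathrm{PPL}(X) = \exp\left(-\frac{1}{N}\sum_{i=1}^{N} \log p_\theta(x_i \mid x_{<i})\right)$$

Lower values mean the model predicts the text better.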

In [3]:
import torch 
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

# Prefer the GPU when CUDA is available; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Pretrained GPT-2 checkpoint: language-model head (moved to `device`) and its fast tokenizer.
model_name = "gpt2"
model = GPT2LMHeadModel.from_pretrained(model_name)
model = model.to(device)
tokenizer = GPT2TokenizerFast.from_pretrained(model_name)

Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/548M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

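GPT-2 can attend to at most 1,024 tokens at once, so the perplexity of a longer text is computed with a sliding window. Each window scores `stride` new tokens while conditioning on the tokens that precede them. The negative log-likelihoods of all scored tokens are then averaged and exponentiated.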

In [4]:
# `tqdm.auto` renders a notebook widget instead of printing ANSI-escaped text bars.
from tqdm.auto import tqdm

# The model can attend to at most `n_positions` tokens, so that is the window size.
max_len = model.config.n_positions
# Number of new tokens scored per window. Every window after the first conditions
# on up to `max_len - stride` preceding tokens (512 for GPT-2's 1,024-token window).
stride = 512


# Calculate the perplexity of the model.
def calc_perplexity(encodings):
    """Return the perplexity of `model` on a tokenized text, using a sliding window.

    The text is processed in chunks of `stride` tokens. Each chunk is scored while
    conditioning on up to `max_len - stride` preceding tokens; those context tokens
    are masked with -100 so that no token is scored twice.

    Args:
        encodings: tokenizer output whose `input_ids` tensor is sliced as
            `[:, start:end]`, with the sequence along dimension 1 (a single
            sequence is assumed - TODO confirm before passing batched input).

    Returns:
        The perplexity as a Python float: the exponential of the mean negative
        log-likelihood over every token the model actually predicted.

    Raises:
        ValueError: if `stride` exceeds `max_len`, or the text has fewer than two tokens.
    """
    if stride > max_len:
        # Otherwise the tokens between consecutive windows would never be scored.
        raise ValueError(f"stride ({stride}) must not exceed max_len ({max_len})")

    input_ids = encodings.input_ids
    seq_len = input_ids.size(1)
    if seq_len < 2:
        # With fewer than two tokens there is nothing for the model to predict.
        raise ValueError("Need at least two tokens to compute perplexity.")

    nll_sums = []  # summed negative log-likelihood of each window
    n_scored = 0   # total number of predicted tokens across all windows

    # Read the data using a sliding window for the context.
    for i in tqdm(range(0, seq_len, stride)):
        start_pos = max(i + stride - max_len, 0)
        end_pos = min(i + stride, seq_len)
        trg_len = end_pos - i  # new tokens in this window (fewer on the last one)
        inp_ids = input_ids[:, start_pos:end_pos].to(device)
        trg_ids = inp_ids.clone()
        # Mask the context tokens so only the new tokens contribute to the loss.
        trg_ids[:, :-trg_len] = -100

        # The model shifts labels left by one internally, so the first label of the
        # window is never predicted; count only the labels that are really scored.
        n_pred = int((trg_ids[:, 1:] != -100).sum())
        if n_pred == 0:
            continue

        with torch.no_grad():
            out = model(inp_ids, labels=trg_ids)
        # `out.loss` is the mean NLL over the scored labels; undo the mean.
        nll_sums.append(out.loss * n_pred)
        n_scored += n_pred

    return torch.exp(torch.stack(nll_sums).sum() / n_scored).item()

In [7]:
!pip install datasets

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting datasets
  Downloading datasets-2.11.0-py3-none-any.whl (468 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m468.7/468.7 kB[0m [31m8.2 MB/s[0m eta [36m0:00:00[0m
Collecting xxhash
  Downloading xxhash-3.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (212 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m212.2/212.2 kB[0m [31m17.1 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.7,>=0.3.0
  Downloading dill-0.3.6-py3-none-any.whl (110 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m110.5/110.5 kB[0m [31m10.3 MB/s[0m eta [36m0:00:00[0m
Collecting aiohttp
  Downloading aiohttp-3.8.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m29.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting multiprocess
  Downlo

Now we evaluate the model on three diverse datasets: WikiText-2, Tiny Shakespeare, and a small IMDB sample.

<ins>Warning</ins>: This evaluation takes several minutes on a GPU and can take more than an hour on a CPU.

In [None]:
from datasets import load_dataset

# (dataset path, config name, label used in the report) for each evaluation corpus.
EVAL_DATASETS = [
    ("wikitext", "wikitext-2-raw-v1", "wikitext"),
    ("tiny_shakespeare", "default", "tiny_shakespeare"),
    ("iamholmes/tiny-imdb", "iamholmes--tiny-imdb", "tiny-imdb"),
]

for path, config, label in EVAL_DATASETS:
    # Load the test split and join its rows into one long text, which
    # `calc_perplexity` then scores with its sliding window.
    testset = load_dataset(path, config, split="test")
    encodings = tokenizer("\n\n".join(testset["text"]), return_tensors="pt")
    print("The perplexity of %s on the %s dataset: %.2f" % (model_name, label, calc_perplexity(encodings)))

Downloading builder script:   0%|          | 0.00/8.48k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/6.84k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/9.25k [00:00<?, ?B/s]

Downloading and preparing dataset wikitext/wikitext-2-raw-v1 to /root/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126...


Downloading data:   0%|          | 0.00/4.72M [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

Dataset wikitext downloaded and prepared to /root/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126. Subsequent calls will reuse this data.


Token indices sequence length is longer than the specified maximum sequence length for this model (287644 > 1024). Running this sequence through the model will result in indexing errors



  0%|          | 0/562 [00:00<?, ?it/s][A[A[A


  0%|          | 1/562 [00:07<1:13:57,  7.91s/it][A[A[A


  0%|          | 2/562 [00:20<1:39:55, 10.71s/it][A[A[A


  1%|          | 3/562 [00:27<1:22:01,  8.80s/it][A[A[A


  1%|          | 4/562 [00:33<1:12:57,  7.85s/it][A[A[A


  1%|          | 5/562 [00:40<1:11:02,  7.65s/it][A[A[A


  1%|          | 6/562 [00:47<1:08:16,  7.37s/it][A[A[A


  1%|          | 7/562 [00:55<1:08:42,  7.43s/it][A[A[A


  1%|▏         | 8/562 [01:01<1:06:14,  7.17s/it][A[A[A


  2%|▏         | 9/562 [01:09<1:09:01,  7.49s/it][A[A[A


  2%|▏         | 10/562 [01:16<1:06:26,  7.22s/it][A[A[A


  2%|▏         | 11/562 [01:24<1:07:51,  7.39s/it][A[A[A


  2%|▏         | 12/562 [01:30<1:05:12,  7.11s/it][A[A[A


  2%|▏         | 13/562 [

## What we have learned …

| |
| --- |
| **Performance metrics**<ul><li>perplexity</li></ul> |
| |