# **Tutorial on Response Generation with SpeechBrain**


This tutorial will guide you through the process of **fine-tuning the pretrained GPT2 model** that is available in the HuggingFace Transformers library for response generation.


**What is a dialogue system?**

In previous labs, we have implemented machine translation, which is used to read the source language (input) and generate the desired language (output). Similarly, in a dialogue system, we will implement a model to generate a response given a context. This is also known as Natural Language Generation (NLG).

<img src="https://miro.medium.com/v2/resize:fit:1400/format:webp/1*yef6QgRpT1ktP6BpHelPmA.png" alt="drawing" width="700" align="center"/>


**Transformers for Language Modeling**

As we’ve seen in the previous labs, the original transformer model is made up of an encoder and decoder – each is a stack of what we can call transformer blocks. A lot of the subsequent research works try to focus only on either the encoder or decoder, and use just one stack of transformer blocks – stacking them up as high as practically possible and feeding them massive amounts of training text.

<img src="https://jalammar.github.io/images/gpt2/gpt-2-transformer-xl-bert-3.png" alt="drawing" width="700" align="center"/>

How high can we stack up these blocks? It turns out that’s one of the main distinguishing factors between the different GPT2 model sizes:

<img src="https://jalammar.github.io/images/gpt2/gpt2-sizes-hyperparameters-3.png" alt="drawing" width="700" align="center"/>





**Architectures of interest for this tutorial**

We will consider the smallest pre-trained GPT-2 model, which is available on the HuggingFace Hub as `gpt2`.

GPT-2 is a family of large transformer-based language models trained on a dataset of 8 million web pages. The largest variant has 1.5 billion parameters, while the smallest one, which we use in this tutorial, has about 124 million. GPT-2 is trained with a simple objective: **predict the next word, given all of the previous words within some text**. The diversity of the dataset causes this simple goal to contain naturally occurring demonstrations of many tasks across diverse domains. The largest GPT-2 is a direct scale-up of GPT, with more than 10X the parameters and trained on more than 10X the amount of data. Please refer to the official paper for more details: [Language Models are Unsupervised Multitask Learners](https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf).
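
To make the next-word objective concrete, here is a minimal sketch in plain Python. It uses whitespace splitting as a stand-in for GPT-2's real byte-pair tokenizer (an assumption made purely for illustration), and the helper name `next_token_pairs` is hypothetical.

```python
from typing import List, Tuple


def next_token_pairs(tokens: List[str]) -> List[Tuple[List[str], str]]:
    """Return (context, target) pairs for next-token prediction.

    For every position i >= 1, the model sees tokens[:i] and has to
    predict tokens[i].
    """
    return [(tokens[:i], tokens[i]) for i in range(1, len(tokens))]


# Toy example: whitespace "tokenization" stands in for the real tokenizer.
toy_tokens = "i need a taxi".split()
pairs = next_token_pairs(toy_tokens)
for context, target in pairs:
    print(f"{' '.join(context):<12} -> {target}")
```

During training, all of these pairs are scored in a single forward pass, which is why a sequence of N tokens yields N - 1 prediction targets.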


You can find some helpful resources here:

*   [The Illustrated GPT-2](https://jalammar.github.io/illustrated-gpt2/)
*   [How to build a State-of-the-Art Conversational AI with Transfer Learning](https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313)
*   [Fun Article about Fine-Tuning for Superhero Descriptions](https://towardsdatascience.com/unleashing-the-power-of-gpt-how-to-fine-tune-your-model-da35c90766c4#:~:text=By%20fine-tuning%20GPT-3%2C%20creating%20a%20highly%20customized%20and,code%20and%20without%20assuming%20prior%20knowledge%20about%20GPT-3.)

*  [GPT vs Bert](https://medium.com/@10shubhamkedar10/gpt-vs-bert-12d108956260)




**With this tutorial, you will learn how to:**

1. Instantiate a pretrained GPT2.
2. Fine-tune GPT2 on MultiWOZ with SpeechBrain for the response generation task.



Let's first install all the needed packages:

In [1]:
%%capture
# !git clone https://github.com/speechbrain/speechbrain.git
!pip install speechbrain
!pip install sacrebleu

## **1. Generate Texts  with GPT-2**







In [2]:
from transformers import pipeline, set_seed
generator = pipeline('text-generation', model='gpt2')
set_seed(42)
generator("Hello, I am omer,", max_length=30, num_return_sequences=5)


[{'generated_text': 'Hello, I am omer, for you be my true friends, and my wife and children, and my mother and father, and my sister-'},
 {'generated_text': 'Hello, I am omer, mister.\n\nHm? What did you say, mister?\n\nNervous. Is'},
 {'generated_text': 'Hello, I am omer, or I am not there, or I am not there. When I die I give up the world: the only'},
 {'generated_text': 'Hello, I am omer, you are me. Please tell me all about this! Here I am, with you. I will tell you all'},
 {'generated_text': 'Hello, I am omer, and am glad you are taking my question as you do, because I know it is something the English have a special'}]

Here, we can explore the model with:

In [3]:
print(generator.model)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)
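
As a sanity check on the model size, the shapes in the printed summary can be turned into a parameter count. The sketch below is plain arithmetic under two assumptions that are not visible in the printout: every `Conv1D` layer carries a bias vector, and `lm_head` shares its weight matrix with `wte`. The helper names are hypothetical.

```python
# Shapes read off the printed GPT2LMHeadModel summary.
vocab_size, n_positions, d_model, n_layers = 50257, 1024, 768, 12


def linear_params(n_in: int, n_out: int, bias: bool = True) -> int:
    """Parameters of a dense projection (weight matrix plus optional bias)."""
    return n_in * n_out + (n_out if bias else 0)


def layer_norm_params(dim: int) -> int:
    """LayerNorm with elementwise_affine=True has a scale and a shift vector."""
    return 2 * dim


embeddings = vocab_size * d_model + n_positions * d_model  # wte + wpe

per_block = (
    layer_norm_params(d_model)             # ln_1
    + linear_params(d_model, 3 * d_model)  # attn.c_attn (nf=2304)
    + linear_params(d_model, d_model)      # attn.c_proj
    + layer_norm_params(d_model)           # ln_2
    + linear_params(d_model, 4 * d_model)  # mlp.c_fc (nf=3072)
    + linear_params(4 * d_model, d_model)  # mlp.c_proj
)

# Assumption: lm_head is tied to wte, so it adds no parameters of its own.
total_params = embeddings + n_layers * per_block + layer_norm_params(d_model)
print(f"{total_params:,} parameters")
```

Under these assumptions the count comes out at roughly 124 million, which matches the smallest GPT-2 checkpoint rather than the 1.5-billion-parameter one.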



## **2. Pretrained GPT-2 and Fine-tuning**
So far, we have only seen how to use pre-trained GPT-2 to continue our own sentences.
As we have learned in previous labs, the suggested way to use SpeechBrain is to plug the pre-trained model directly into your pipeline and fine-tune it while training your final model.


Remember that in the "Week 7: Pretrained and fine-tune" lab, Wav2vec2 was offered as a **lobe** in SpeechBrain. Its implementation can be found in `speechbrain.lobes.models.huggingface_transformers.wav2vec2.py`. We need a similar interface for GPT. GPT can then simply be added as a block in your hyperparameter file:

For GPT-2:
```yaml
GPT2: !new:speechbrain.lobes.models.huggingface_transformers.gpt.GPT
    source: !ref <gpt_hub>
    freeze: !ref <freeze_gptmodel>
    save_path: !ref <gpt_folder>
    max_new_tokens: !ref <max_new_tokens>
    num_beams: !ref <num_beams>
    top_k: !ref  <top_k>
    top_p: !ref <top_p>
```

- *freeze* enables you to fine-tune (False) or freeze (True) the neural parameters while training your final model.


The GPT model is just a neural network that can be applied to your input data and jointly trained with the downstream task of interest. The GPT interface has already been implemented in SpeechBrain. Its implementation can be found in `speechbrain.lobes.models.huggingface_transformers.gpt.py`.

```python
"""This lobe enables the integration of huggingface pretrained GPT2LMHeadModel model.

Transformer from HuggingFace needs to be installed:
https://huggingface.co/transformers/installation.html

Authors
 * Pooneh Mousavi 2023
 * Simone Alghisi 2023
"""

import logging
import torch

from speechbrain.lobes.models.huggingface_transformers.huggingface import (
    HFTransformersInterface,
)


logger = logging.getLogger(__name__)


class GPT(HFTransformersInterface):
    """This lobe enables the integration of HuggingFace pretrained GPT model.
     Source paper whisper:
        https://life-extension.github.io/2020/05/27/GPT%E6%8A%80%E6%9C%AF%E5%88%9D%E6%8E%A2/language-models.pdf
    Transformer from HuggingFace needs to be installed:
        https://huggingface.co/transformers/installation.html

    The model can be finetuned. It will download automatically the model from
    HuggingFace or use a local path.

    Arguments
    ---------
    source : str
        HuggingFace hub name: e.g "gpt2"
    save_path : str
        Path (dir) of the downloaded model.
    freeze : bool (default: False)
        If True, the model is frozen. If False, the model will be trained
        alongside with the rest of the pipeline.
    max_new_tokens : int
        Maximum count of new tokens allowed.
    min_length : int
        Minimum count of input tokens
    top_k : int
        Top results count to keep
    top_p : float
        Proportion of top results to keep
    num_beams : int
        Number of decoder beams
    eos_token_id : int
        Index of end-of-sentence token.
    early_stopping : int
        Whether to stop training early.

    Example
    -------
    >>> model_hub = "gpt2"
    >>> save_path = "savedir"
    >>> model = GPT(model_hub, save_path)
    >>> tokens = torch.tensor([[1, 1]])
    >>> tokens_type = torch.tensor([[1, 1]])
    >>> attention_mask = torch.tensor([[1, 1]])
    >>> outputs = model(tokens, tokens_type, attention_mask)
    """

    def __init__(
        self,
        source,
        save_path,
        freeze=False,
        max_new_tokens=200,
        min_length=1,
        top_k=45,
        top_p=0.9,
        num_beams=8,
        eos_token_id=50258,
        early_stopping=True,
    ) -> None:
        super().__init__(
            source=source, save_path=save_path, freeze=freeze, with_lm_head=True
        )
        self.max_new_tokens = max_new_tokens
        self.min_length = min_length
        self.top_k = top_k
        self.top_p = top_p
        self.num_beams = num_beams
        self.early_stopping = early_stopping
        self.eos_token_id = eos_token_id

        self.load_tokenizer(source=source, pad_token=None, use_fast=False)

        if self.freeze:
            logger.warning("huggingface_GPT - GPT  is frozen.")
            self.model.train()  # we keep it to train to have dropout and LN computed adequately
            for param in self.model.parameters():
                param.requires_grad = False

    def forward(
        self,
        input_ids: torch.Tensor,
        token_type_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ):
        """Takes an input a history of conversation and returns its corresponding reply.

        Arguments
        ---------
        input_ids : torch.Tensor
            A batch of input-id to transform to features.
        token_type_ids : torch.Tensor
            Token Type(Speaker) for each token in input_ids.
        attention_mask : torch.Tensor
            A batch of attention_mask.

        Returns
        -------
        output : torch.Tensor
            Reply to conversation
        """
        with torch.set_grad_enabled(not self.freeze):
            output = self.model.forward(
                input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
            )
        return output

    def generate(
        self,
        input_ids: torch.Tensor,
        token_type_ids,
        attention_mask: torch.Tensor,
        decoder_type="greedy",
    ):
        """Takes an input a history of conversation and returns its corresponding reply.

        Arguments
        ---------
        input_ids : torch.Tensor
            A batch of input-id which are dialogue context tokens
        token_type_ids : torch.Tensor
        attention_mask : torch.Tensor
            A batch of attention_mask.
        decoder_type : str
            It shows strategy for autoregressive decoding either beam search or greedy.

        Returns
        -------
        hyp : torch.Tensor
            Conversation reply.
        """

        with torch.no_grad():
            if decoder_type == "beam":
                # beam decoding based on the input_ids which are dialogue context tokens (here only history)
                hyp = self.model.generate(
                    input_ids=input_ids,
                    token_type_ids=token_type_ids,
                    attention_mask=attention_mask,
                    do_sample=True,
                    max_new_tokens=self.max_new_tokens,
                    min_length=self.min_length,
                    top_k=self.top_k,
                    top_p=self.top_p,
                    num_beams=self.num_beams,
                    num_return_sequences=1,
                    eos_token_id=self.eos_token_id,
                    early_stopping=self.early_stopping,
                )
            else:
                # greedy decoding based on the input_ids which are dialogue context tokens (here only history)
                hyp = self.model.generate(
                    input_ids,
                    token_type_ids=token_type_ids,
                    max_new_tokens=self.max_new_tokens,
                    eos_token_id=self.eos_token_id,
                    attention_mask=attention_mask,
                )
        return hyp
```
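
The `generate` method above forwards `top_k` and `top_p` to HuggingFace's `generate`, where they only take effect when sampling is enabled. To show what these two knobs do, here is a simplified sketch over a toy next-token distribution. It is not HuggingFace's implementation: the function name `top_k_top_p_filter` and the toy probabilities are assumptions for illustration.

```python
from typing import Dict


def top_k_top_p_filter(
    probs: Dict[str, float], top_k: int, top_p: float
) -> Dict[str, float]:
    """Restrict a distribution to the tokens that sampling may pick.

    First keep the top_k most likely tokens, then keep the smallest prefix
    of them whose cumulative probability reaches top_p, and renormalize.
    """
    ranked = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:top_k]
    mass = sum(p for _, p in ranked)
    ranked = [(tok, p / mass) for tok, p in ranked]  # renormalize after top-k

    kept, cumulative = [], 0.0
    for tok, p in ranked:
        kept.append((tok, p))
        cumulative += p
        if cumulative >= top_p:  # nucleus is complete
            break

    mass = sum(p for _, p in kept)
    return {tok: p / mass for tok, p in kept}


toy_probs = {"taxi": 0.5, "train": 0.3, "hotel": 0.15, "museum": 0.05}
filtered = top_k_top_p_filter(toy_probs, top_k=3, top_p=0.8)
print(filtered)
```

Here `top_k=3` discards "museum", and `top_p=0.8` then discards "hotel", so the next token would be sampled from "taxi" and "train" only. Lower values make generation more conservative, while higher values allow more diverse replies.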








The inputs of the GPT2 model are `input_ids`, `token_type_ids`, and an `attention_mask`. The `input_ids` are the concatenation of all tokenized sentences in the history, each preceded by the speaker token of the person who uttered it. The `token_type_ids` have the same length as the `input_ids` and indicate the speaker of each token.
For example:
```
history: ['Hi how are you', 'I am fine and you', 'I am good']
input_ids : <speaker_1> Hi how are you <speaker_2> I am fine and you <speaker_1> I am good
token_type_ids : [[<speaker_1>,<speaker_1>,<speaker_1>,<speaker_1>,<speaker_1>],
                  [<speaker_2>,<speaker_2>,<speaker_2>,<speaker_2>,<speaker_2>,<speaker_2>],
                  [<speaker_1>,<speaker_1>,<speaker_1>,<speaker_1>]]

```

**Note:** This is just an illustrative example. The real input contains token IDs instead of words.
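
The construction described above can be sketched in a few lines of plain Python. This is a toy version under stated assumptions: whitespace splitting replaces the real GPT-2 tokenizer, strings replace token IDs, and the speaker token itself is typed with its own speaker so that both sequences end up with the same length. The function name `build_inputs` is hypothetical.

```python
from typing import List, Tuple

SPEAKER_1, SPEAKER_2 = "<speaker_1>", "<speaker_2>"


def build_inputs(history: List[str]) -> Tuple[List[str], List[str]]:
    """Flatten a dialogue history into input tokens and token types.

    Turns at even positions are attributed to speaker 1 and turns at odd
    positions to speaker 2. Each turn is prefixed with its speaker token,
    and every token of the turn (prefix included) gets that speaker as type.
    """
    input_tokens: List[str] = []
    token_types: List[str] = []
    for i, turn in enumerate(history):
        speaker = SPEAKER_1 if i % 2 == 0 else SPEAKER_2
        turn_tokens = [speaker] + turn.split()  # toy whitespace tokenizer
        input_tokens.extend(turn_tokens)
        token_types.extend([speaker] * len(turn_tokens))
    return input_tokens, token_types


history = ["Hi how are you", "I am fine and you", "I am good"]
input_tokens, token_types = build_inputs(history)
print(" ".join(input_tokens))
print(len(input_tokens), len(token_types))
```

Both lists have one element per input position, which is what allows the model to add a speaker embedding to every token embedding.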

## **3. Fine-tuning GPT-2 on Response Generation (with MultiWOZ)**
Now we will discuss how to fine-tune the GPT-2 model for response generation. To achieve this, we will use a smaller version of the MultiWOZ 2.1 dataset. The Multi-Domain Wizard-of-Oz dataset (MultiWOZ) is a fully-labeled collection of human-human written conversations spanning multiple domains and topics. Instead of using all the data, we will set a parameter that controls how much of the data is sampled. For this experiment, we will sample just 1000, 100, and 200 entries for the training, validation, and test sets, respectively.


### **Step 1: Prepare your data**
The goal of data preparation is to create the data manifest files.
These files tell SpeechBrain where to find the dialogue history and the system reply. They are text files written in the popular CSV and JSON formats.

#### **Data manifest files**
Let's look at what a data manifest file in JSON format looks like:


```json
{
    "SNG01919.json_1": {
        "history": [
            "i need a taxi from the missing sock and i need to get to my destination by 08:30 . can you help ?"
        ],
        "reply": "i can help you with that . where are you going ?",
        "length": 145
    },
    "SNG01919.json_3": {
        "history": [
            "i need a taxi from the missing sock and i need to get to my destination by 08:30 . can you help ?",
            "i can help you with that . where are you going ?",
            "i am going to el shaddai"
        ],
        "reply": "okay your booking is complete . be on the lookout for a white volkswagen",
        "length": 128.33333333333334
    },
    "SNG01919.json_5": {
        "history": [
            "i need a taxi from the missing sock and i need to get to my destination by 08:30 . can you help ?",
            "i can help you with that . where are you going ?",
            "i am going to el shaddai",
            "okay your booking is complete . be on the lookout for a white volkswagen",
            "i will also need the contact number please ."
        ],
        "reply": "their contact number is 07053289961 . do you need any further assistance ?",
        "length": 131
    }
}
```
As you can see, we have a hierarchical structure in which the first key is a **unique identifier** made of the dialogue name followed by the turn number.

You can name the entries as you prefer. However, the names of these entries must match what the experiment script (e.g., train.py) expects. We will elaborate more on this later.
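
As a quick illustration of this structure, the sketch below parses the first two entries of the manifest shown above with the standard `json` module and splits each key back into a dialogue name and a turn number. Embedding the manifest as a string and the variable names used here are choices made for the example only.

```python
import json

# First two entries of the sample manifest, embedded as a string.
manifest_json = """
{
    "SNG01919.json_1": {
        "history": [
            "i need a taxi from the missing sock and i need to get to my destination by 08:30 . can you help ?"
        ],
        "reply": "i can help you with that . where are you going ?",
        "length": 145
    },
    "SNG01919.json_3": {
        "history": [
            "i need a taxi from the missing sock and i need to get to my destination by 08:30 . can you help ?",
            "i can help you with that . where are you going ?",
            "i am going to el shaddai"
        ],
        "reply": "okay your booking is complete . be on the lookout for a white volkswagen",
        "length": 128.33333333333334
    }
}
"""

manifest = json.loads(manifest_json)

parsed = []
for entry_id, entry in manifest.items():
    # The key is "<dialogue name>_<turn number>"; split on the last "_".
    dialogue_name, turn_number = entry_id.rsplit("_", 1)
    parsed.append((dialogue_name, int(turn_number), len(entry["history"])))
    print(dialogue_name, turn_number, "->", entry["reply"])
```

Splitting on the last underscore keeps dialogue names that themselves contain underscores intact.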


#### **Preparation Script**
Every dataset is formatted in a different way. The script that parses your own dataset and creates the JSON or the CSV files is something that you are supposed to write. Most of the time, this is very straightforward.

For the MultiWOZ dataset, for instance, we wrote a data preparation script called `multiwoz_prepare.py`.
The function automatically downloads the data. We then search for all the dialogues and split them into turns. Our goal is to train a model that produces reasonable system responses, so we use the system turns (replies uttered by the system) as gold labels. They become the `reply` field in the manifest files. For each reply, we also extract the `history`, which is the list of sentences that precede that response. Each sentence is uttered either by the system or by the user: counting from 0, entries at even positions in the history come from the user and entries at odd positions come from the system. As the dialogue progresses, the history grows, since it contains all the previous turns plus the new one.

The `length` field is an average over the lengths of the sentences in the entry. In the sample manifest above, its values match the mean number of characters per history sentence plus the number of characters in the reply. This field is only used by the data loader to sort the data and avoid unnecessary padding. It is not the actual input length, which depends on the tokenizer and the token types we use.
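
The following sketch shows how (history, reply) entries of this kind can be derived from a list of alternating turns. It is not the code of `multiwoz_prepare.py`: the function name `build_entries` is hypothetical, and the `length` formula is an assumption inferred from the sample manifest shown earlier, where the values equal the mean character length of the history sentences plus the character length of the reply.

```python
from statistics import mean
from typing import Any, Dict, List


def build_entries(dialogue_id: str, turns: List[str]) -> Dict[str, Dict[str, Any]]:
    """Create one (history, reply) entry per system turn.

    Turns alternate user, system, user, ... so system turns sit at odd
    positions; the history of a system turn is everything before it.
    """
    entries: Dict[str, Dict[str, Any]] = {}
    for turn_idx in range(1, len(turns), 2):
        history, reply = turns[:turn_idx], turns[turn_idx]
        entries[f"{dialogue_id}_{turn_idx}"] = {
            "history": history,
            "reply": reply,
            # Assumption inferred from the sample manifest: mean character
            # length of the history sentences plus the reply's characters.
            "length": mean(len(s) for s in history) + len(reply),
        }
    return entries


turns = [
    "i need a taxi from the missing sock and i need to get to my destination by 08:30 . can you help ?",
    "i can help you with that . where are you going ?",
    "i am going to el shaddai",
    "okay your booking is complete . be on the lookout for a white volkswagen",
]
entries = build_entries("SNG01919.json", turns)
for key, entry in entries.items():
    print(key, len(entry["history"]), entry["length"])
```

With these four turns the sketch produces the keys `SNG01919.json_1` and `SNG01919.json_3`, and the history of the second entry contains the whole first exchange.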

You can use this script as a good base for your custom preparation on your target dataset. As you can see, we create three separate data manifest files to manage the training, validation, and test phases.



It is good practice to make your data as clean as possible before feeding it to the model. To prepare the text data for model building, we perform text preprocessing. Some of the preprocessing steps are:

*   Removing punctuation and special characters such as `. , ! $ ( ) * % @`
*   Removing URLs
*   Lower/upper casing
*   Mapping abbreviations and short forms to their full forms (e.g. "it's" to "it is")

We will apply these preprocessing steps using `multiwoz_prepare.py` and the `mapping.pair` file.
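
To give a feeling for the mapping step before looking at the full script, here is a minimal sketch of whole-token replacement driven by (from, to) pairs. The helper `apply_replacements` and the three toy pairs are assumptions for illustration, and the real preprocessing performs several more normalization steps.

```python
from typing import List, Tuple


def apply_replacements(text: str, pairs: List[Tuple[str, str]]) -> str:
    """Replace whole, space-delimited tokens according to (from, to) pairs."""
    text = f" {text.lower()} "  # pad so the first and last words can match
    for source, target in pairs:
        text = text.replace(f" {source} ", f" {target} ")
    return " ".join(text.split())  # collapse whitespace and trim


toy_pairs = [("it's", "it is"), ("don't", "do not"), ("two", "2")]
cleaned = apply_replacements("It's cheap and I don't need two rooms", toy_pairs)
print(cleaned)
```

Padding the patterns with spaces prevents a pair such as `("two", "2")` from rewriting the inside of a longer word like "network".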

In [4]:
%%file mapping.pair
it's	it is
don't	do not
doesn't	does not
didn't	did not
you'd	you would
you're	you are
you'll	you will
i'm	i am
they're	they are
that's	that is
what's	what is
couldn't	could not
i've	i have
we've	we have
can't	cannot
i'd	i would
i'd	i would
aren't	are not
isn't	is not
wasn't	was not
weren't	were not
won't	will not
there's	there is
there're	there are
. .	.
restaurants	restaurant -s
hotels	hotel -s
laptops	laptop -s
cheaper	cheap -er
dinners	dinner -s
lunches	lunch -s
breakfasts	breakfast -s
expensively	expensive -ly
moderately	moderate -ly
cheaply	cheap -ly
prices	price -s
places	place -s
venues	venue -s
ranges	range -s
meals	meal -s
locations	location -s
areas	area -s
policies	policy -s
children	child -s
kids	kid -s
kidfriendly	kid friendly
cards	card -s
upmarket	expensive
inpricey	cheap
inches	inch -s
uses	use -s
dimensions	dimension -s
driverange	drive range
includes	include -s
computers	computer -s
machines	machine -s
families	family -s
ratings	rating -s
constraints	constraint -s
pricerange	price range
batteryrating	battery rating
requirements	requirement -s
drives	drive -s
specifications	specification -s
weightrange	weight range
harddrive	hard drive
batterylife	battery life
businesses	business -s
hours	hour -s
one	1
two	2
three	3
four	4
five	5
six	6
seven	7
eight	8
nine	9
ten	10
eleven	11
twelve	12
anywhere	any where
good bye	goodbye


Writing mapping.pair


In [5]:
from itertools import product
from pathlib import Path
from statistics import mean
from typing import Any, Dict, List, Optional, Set, Tuple
import json
import logging
import os
import random
import re
import shutil


from tqdm import tqdm

from speechbrain.utils.data_utils import download_file

logger = logging.getLogger(__name__)
MULTIWOZ_21_DATASET_URL = (
    "https://github.com/budzianowski/multiwoz/raw/master/data/MultiWOZ_2.1.zip"
)

"""
Trade script used for tokenization purposes.

The original one can be found at:
https://github.com/jasonwu0731/trade-dst/blob/master/create_data.py
"""

def prepare_mwoz_21(
    output_folder: str,
    data_folder: str,
    override: bool,
    replacements_path: str,
    tr_random_dialogues: Optional[int] = None,
    dev_random_dialogues: Optional[int] = None,
    te_random_dialogues: Optional[int] = None,
    seed: int = 42,
) -> None:
    # set seed
    random.seed(seed)

    dataset_folder = os.path.join(data_folder, "MultiWOZ_21")
    if not os.path.isdir(dataset_folder):
        download_mwoz_21(data_folder)
    else:
        logger.info(f"{dataset_folder} exists, skipping.")

    tr_split, dev_split, te_split = get_splits(dataset_folder)

    data_path = os.path.join(dataset_folder, "data.json")
    build_dialogue_dataset(
        data_path,
        logger,
        tr_split,
        "train.json",
        output_folder,
        override,
        replacements_path,
        tr_random_dialogues,
    )

    build_dialogue_dataset(
        data_path,
        logger,
        dev_split,
        "valid.json",
        output_folder,
        override,
        replacements_path,
        dev_random_dialogues,
    )

    build_dialogue_dataset(
        data_path,
        logger,
        te_split,
        "test.json",
        output_folder,
        override,
        replacements_path,
        te_random_dialogues,
    )


def insertSpace(token, text):
    sidx = 0
    while True:
        sidx = text.find(token, sidx)
        if sidx == -1:
            break
        if (
            sidx + 1 < len(text)
            and re.match("[0-9]", text[sidx - 1])
            and re.match("[0-9]", text[sidx + 1])
        ):
            sidx += 1
            continue
        if text[sidx - 1] != " ":
            text = text[:sidx] + " " + text[sidx:]
            sidx += 1
        if sidx + len(token) < len(text) and text[sidx + len(token)] != " ":
            text = text[: sidx + 1] + " " + text[sidx + 1 :]
        sidx += 1
    return text


def normalize(text, replacements):
    # lower case every word
    text = text.lower()

    # replace white spaces in front and end
    text = re.sub(r"^\s*|\s*$", "", text)

    # hotel domain pfb30
    text = re.sub(r"b&b", "bed and breakfast", text)
    text = re.sub(r"b and b", "bed and breakfast", text)

    # weird unicode bug
    text = re.sub("(\u2018|\u2019)", "'", text)

    # replace st.
    text = text.replace(";", ",")
    text = re.sub(r"$\/", "", text)
    text = text.replace("/", " and ")

    # replace other special characters
    text = text.replace("-", " ")
    text = re.sub(r'["\<>@\(\)]', "", text)  # remove

    # insert white space before and after tokens:
    for token in ["?", ".", ",", "!"]:
        text = insertSpace(token, text)

    # insert white space for 's
    text = insertSpace("'s", text)

    # replace it's, doesn't, you'd ... etc
    text = re.sub(r"^'", "", text)
    text = re.sub(r"'$", "", text)
    text = re.sub(r"'\s", " ", text)
    text = re.sub(r"\s'", " ", text)
    for fromx, tox in replacements:
        text = " " + text + " "
        text = text.replace(fromx, tox)[1:-1]

    # remove multiple spaces
    text = re.sub(" +", " ", text)

    # concatenate numbers
    tokens = text.split()
    i = 1
    while i < len(tokens):
        if re.match(r"^\d+$", tokens[i]) and re.match(r"\d+$", tokens[i - 1]):
            tokens[i - 1] += tokens[i]
            del tokens[i]
        else:
            i += 1
    text = " ".join(tokens)
    return text


def get_replacements(
    replacements_path: str = "trade/utils/mapping.pair",
) -> List[Tuple[str, str]]:
    """
    Get the replacements from a given file. Used by trade preprocessing.

    Arguments
    ---------
    replacements_path: str
        File containing from, to pairs, one per line.

    Returns
    -------
    replacements: List of replacements, i.e. pairs of str
        Pairs of elements used to substitute the first element with the second.
    """
    replacements = []
    with open(replacements_path, "r") as fin:
        for line in fin.readlines():
            tok_from, tok_to = line.replace("\n", "").split("\t")
            replacements.append((" " + tok_from + " ", " " + tok_to + " "))
    return replacements


TOKEN_EXCEPTIONS = {"childs": "children", "businesss": "businesses", "inchs": "inches"}

PATTERN_EXCEPTIONS = {"breakfasts": "b&bs"}


def invert_trade_subtokenization(
    original_seq: str,
    trade_seq: str,
    token_exceptions: Dict[str, str] = TOKEN_EXCEPTIONS,
    pattern_exceptions: Dict[str, str] = PATTERN_EXCEPTIONS,
    subtoken_special_chrs: List[str] = [" -", " _"],
) -> str:
    """
    Invert all trade subtokenizations in a string given the original sequence.

    Arguments
    ---------
    original_seq: str
        The original sequence.
    trade_seq: str
        The sequence that has been pre-processed by trade.
    token_exceptions: dict, keys are str, values are str
        A dictionary to map merged token to their correct counterpart. E.g.
        child -s is merged into childs, but the correct token is children.
    pattern_exceptions: dict, keys are str, values are str
        A dictionary to map patterns to their correct counterpart. E.g.
        after the pre-processing "b&bs" is mapped to "bed and breakfast -s",
        making the search of breakfasts impossible if not handled by such
        exceptions.
    subtoken_special_chrs: list of str
        List containing the special characters that are used for subtokens.

    Returns
    -------
    corrected_seq: str
        The sequence corrected, i.e. subtokens replaced by tokens.
    """
    regex = "|".join(subtoken_special_chrs)
    subtoken_pieces = re.split(regex, trade_seq, maxsplit=1)
    search_after: int = 0
    while len(subtoken_pieces) > 1:
        # example: 'the wind is moderate -ly strong'
        # split: ['the wind is moderate ', 'ly strong']
        # split[0]: 'the wind is moderate' --> split on whitespace ['the', 'wind', 'is', 'moderate']
        left_side = subtoken_pieces[0].split()
        subtoken_left = left_side[-1]
        # split[1]: 'ly strong' --> split on whitespace ['ly', 'strong']
        right_side = subtoken_pieces[1].split()
        subtoken_right = right_side[0]
        # try merging the subtoken parts to form a token, i.e. moderate + ly
        token = "".join([subtoken_left, subtoken_right])

        if token in token_exceptions:
            # if you match an exception, replace the token with the exception
            token = token_exceptions[token]

        # assume there are no tokens on left and right side of the subtokens' pieces
        left_token = None  # if token is at the beginning
        right_token = None  # if token is at the end
        # try looking for them
        if len(left_side) > 1:
            left_token = left_side[-2]
        if len(right_side) > 1:
            right_token = right_side[1]

        # start from a complete match, and progressively remove left and right
        # tokens to counter TRADE preprocessing of some tokens
        # The order is
        # 1. True, True
        # 2. True, False
        # 3. False, True
        # 4. False, False
        # basically, at the end you try looking only for the merged token
        pattern: str = ""
        idx: int = -1
        for use_left, use_right in product((True, False), (True, False)):
            pattern = token
            if (left_token is not None) and use_left:
                pattern = " ".join([left_token, pattern])
            if right_token is not None and use_right:
                pattern = " ".join([pattern, right_token])

            # check if the pattern is in the exceptions
            if pattern in pattern_exceptions:
                pattern = pattern_exceptions[pattern]
            # Search the pattern
            idx = original_seq[search_after:].lower().find(pattern)
            if idx > -1:
                break

        error: str = f"""
            Pattern search failed in the following case:
            PATTERN =  \t{pattern}
            LEFT SIDE = \t{left_side}
            RIGHT SIDE = \t{right_side}
            ORIG SEQ = \t{original_seq[search_after:]}

            This may be due to further TRADE pre-processing, or not correct merging operation.
            To solve this, add a special rule for the token that breaks the code either as a
            token_exception or a pattern_exception.
        """

        assert idx > -1, error
        # move the index to avoid perfect matches with the same token
        # TODO is probably better to move it of len(left_token + token) or
        # len(token) depending on the match
        search_after += idx + 1
        # reconstruct the sentence with the matched pattern
        trade_seq = " ".join([*left_side[:-1], token, *right_side[1:]])

        # try splitting the sentence again and repeat the process
        subtoken_pieces = re.split(regex, trade_seq, maxsplit=1)
    # Good, no subtokens found: return trade seq
    return trade_seq


def get_json_object(data_path: str) -> dict:
    """
    A function to read a json object and return the python
    dictionary associated to it.

    Arguments
    ---------
    data_path: str
        Path to a json file.

    Returns
    -------
    loaded_json: dict
        A loaded json object.
    """
    with open(data_path, "r") as data_file:
        data = json.load(data_file)

    return data


def load_dialogues(
    data_path: str,
    data_split: List[str],
    replacements: List[Tuple[str, str]],
) -> List[List[Dict[str, Any]]]:
    """
    Load dialogues from data_path, apply trade pre-processing, revert the
    subtokenization, and create a dictionary containing the dialogue id,
    the turn id, and the corrected sequence.

    Arguments
    ---------
    data_path: str
        Path to the json file containing the data.
    data_split: list of str
        List of string containing MultiWOZ 2.1 keys of the dialogues
        associated to a certain split (train, dev, test).
    replacements: list of (str, str) tuples
        (from, to) pairs applied during the trade normalization.

    Returns
    -------
    dialogues: list of list of dict, keys are str, values could be anything
        List of dialogues. Each dialogue is a list of turns. Each turn is a
        dict containing dialogue_idx, turn_idx, and the corrected sequence.
    """

    def get_preprocessed_seq(
        original_seq: str, replacements: List[Tuple[str, str]]
    ) -> str:
        # apply trade normalization
        trade_seq = normalize(original_seq, replacements)
        # merge back subtokens
        sequence = invert_trade_subtokenization(original_seq, trade_seq)
        return sequence

    dialogues: List[List[Dict[str, Any]]] = []

    data = get_json_object(data_path)

    for dialogue_idx in tqdm(data_split, desc="Load Dialogues"):
        dial: List[Dict[str, Any]] = []
        original_dialogue: dict = data[dialogue_idx]
        turns: dict = original_dialogue["log"]
        for i, turn in enumerate(turns):
            sequence = get_preprocessed_seq(turn["text"], replacements)
            to_save = {
                "sequence": sequence,
                "turn_idx": i,
                "dialogue_idx": dialogue_idx,
            }
            dial.append(to_save)
        dialogues.append(dial)
    return dialogues


def create_entry_key(turn: Dict[str, Any]) -> str:
    """
    Creates the entry key for a given entry by considering dialogue id
    and turn id for the given turn.

    Arguments
    ---------
    turn: dict, keys are str, values could be anything
        A dict containing the dialogue id, the turn id, and the sequence.

    Returns
    -------
    key: str
        The key for the given turn.
    """
    dialogue_idx = turn["dialogue_idx"]
    turn_idx = turn["turn_idx"]
    return f"{dialogue_idx}_{turn_idx}"


def create_dialogue_dataset(
    dialogues: List[List[Dict[str, Any]]]
) -> Dict[str, Dict[str, Any]]:
    """
    Creates a dialogue dataset starting from a set of dialogues. Each
    entry of the dataset contains the dialogue history and the system
    reply in response to that.

    Arguments
    ---------
    dialogues: list of list of dict, keys are str, values could be anything
        List of dialogues. Each dialogue is a list of turns. Each turn is a
        dict containing dialogue_idx, turn_idx, and the corrected sequence.

    Returns
    -------
    dataset: Dict[str, Dict[str, Any]]
        Dataset, keys are str, values are dictionaries containing the
        dialogue history, the system reply, and the length.
    """

    def create_dialogue_dataset_entry(
        turn: Dict[str, Any], history: List[str]
    ) -> Optional[Dict[str, Any]]:
        """
        Creates an entry if the current turn id is odd. An entry is
        composed of the history, which contains the previous turns
        of the current dialogue, and the reply of the system.

        Arguments
        ---------
        turn: dict, keys are str, values could be anything
            A dict containing the dialogue id, the turn id, and the sequence.
        history: list of str
            Previous turns of the current dialogue. It is updated in place
            with the current turn.

        Returns
        -------
        entry: optional dict, keys are str, values could be anything
            Entry of the dialogue dataset, or None for user turns. It is a
            dict containing the history of the dialogue, i.e. a list of turns,
            the reply of the system, i.e. a turn, and the length.
        """

        turn_idx = turn["turn_idx"]
        entry: Optional[Dict[str, Any]] = None
        if turn_idx % 2 == 0:
            # user turn, simply append it to the history
            user_seq: str = turn["sequence"]
            history.append(user_seq)
        elif turn_idx % 2 == 1:
            # system turn, create the dataset entry, and then append it to the history
            system_seq: str = turn["sequence"]
            history_mean_length = mean([len(seq) for seq in history])
            entry = {
                "history": history.copy(),
                "reply": system_seq,
                "length": history_mean_length + len(system_seq),
            }
            history.append(system_seq)
        return entry

    dataset: Dict[str, Dict[str, Any]] = {}
    for dialogue in tqdm(dialogues, desc="Creating dataset"):
        history: List[str] = []
        for turn in dialogue:
            # custom function to create a dataset entry
            dataset_entry = create_dialogue_dataset_entry(turn, history)
            # custom function to create a dataset key
            key = create_entry_key(turn)
            if dataset_entry is not None:
                dataset[key] = dataset_entry
    return dataset


def save_dialogue_dataset(
    dataset: Dict[str, Dict[str, Any]], file_name: str, dst_folder: str = "."
) -> None:
    """
    Saves the dialogue dataset at dst_folder/file_name as a json file.

    Arguments
    ---------
    dataset: Dict[str, Dict[str, Any]]
        Dataset, keys are str, values are dictionaries containing the
        dialogue history, the system reply, and the mean length.
    file_name: str
        Name of the file where the dataset will be saved.
    dst_folder: str
        Path to the directory where the dataset will be saved. If it
        does not exist, it is created.
    """
    os.makedirs(dst_folder, exist_ok=True)
    dataset_path = os.path.join(dst_folder, file_name)
    with open(dataset_path, "w") as f:
        json.dump(dataset, f, indent=4)


def encode_dialogue_dataset(
    file_name: str,
    dst_folder: str,
    data_path: str,
    data_split: List[str],
    override: bool,
    logger: logging.Logger,
    replacements_path: str = "utils/mapping.pair",
    random_dialogues: Optional[int] = None,
) -> None:
    """
    Wrapper function that loads processed data stored at
    dst_folder/file_name. If they are not available, it processes the
    original data and then saves them at dst_folder/file_name.

    Arguments
    ---------
    file_name: str
        Name of the file where the dataset will be saved.
    dst_folder: str
        Path to the directory where the dataset will be saved. If it
        does not exist, it is created.
    data_path: str
        Path to the data pre-processed by trade.
    data_split: list of str
        List of strings containing MultiWOZ 2.1 keys of the dialogues
        associated to a certain split (train, dev, test).
    override: bool
        Whether to overwrite the data stored at dst_folder/file_name.
    logger: logging.Logger instance
        Logger to report the processing steps carried out in the current
        execution.
    replacements_path: str
        Path to TRADE file containing (from, to) pairs, one per line.
    random_dialogues: int
        Number of dialogues to randomly sample from the current data.
    """
    dataset_path = os.path.join(dst_folder, file_name)
    if os.path.isfile(dataset_path) and (not override):
        logger.info(f"Dataset already created at {dataset_path}")
    else:
        replacements = get_replacements(replacements_path)
        logger.info(f"Extract dialogues from {data_path}")
        # custom loading function to return the important elements of a dialogue
        dialogues = load_dialogues(data_path, data_split, replacements)
        if random_dialogues:
            dialogues = random.sample(dialogues, min(random_dialogues, len(dialogues)))
        logger.info("Create dataset")
        dataset = create_dialogue_dataset(dialogues)
        logger.info(f"Save dataset in {dataset_path}")
        save_dialogue_dataset(dataset, file_name, dst_folder)


def build_dialogue_dataset(
    data_path: str,
    logger: logging.Logger,
    data_split: List[str],
    file_name: str = "train.json",
    dst_folder: str = ".",
    override: bool = False,
    replacements_path: str = "utils/mapping.pair",
    random_dialogues: Optional[int] = None,
) -> None:
    """
    Builds the dialogue dataset for the corresponding data_path and saves
    it at dst_folder/file_name.

    Arguments
    ---------
    data_path: str
        Path to the data pre-processed by trade.
    logger: logging.Logger instance
        Logger to report the processing steps carried out in the current
        execution.
    data_split: list of str
        List of strings containing MultiWOZ 2.1 keys of the dialogues
        associated to a certain split (train, dev, test).
    file_name: str
        Name of the file where the dataset will be saved.
    dst_folder: str
        Path to the directory where the dataset will be saved. If it
        does not exist, it is created.
    override: bool
        Whether to overwrite the data stored at dst_folder/file_name.
    replacements_path: str
        Path to TRADE file containing (from, to) pairs, one per line.
    random_dialogues: int
        Number of dialogues to randomly sample from the current data.
    """
    logger.info(f"Prepare {file_name}")
    encode_dialogue_dataset(
        file_name,
        dst_folder,
        data_path,
        data_split,
        override,
        logger,
        replacements_path,
        random_dialogues=random_dialogues,
    )


def download_mwoz_21(destination):
    """Download dataset repo, unpack it, and remove unnecessary elements.

    Arguments
    ---------
    destination : str
        Place to put dataset.
    """
    mwoz_21_archive = os.path.join(destination, "MultiWOZ_21.zip")
    download_file(MULTIWOZ_21_DATASET_URL, mwoz_21_archive)
    shutil.unpack_archive(mwoz_21_archive, destination)
    shutil.rmtree(os.path.join(destination, "__MACOSX"))

    mwoz_21 = os.path.join(destination, "MultiWOZ_21")
    os.makedirs(mwoz_21, exist_ok=True)

    mwoz_21_repo = os.path.join(destination, "MultiWOZ_2.1")
    for relevant_file in ["data.json", "valListFile.txt", "testListFile.txt"]:
        shutil.move(
            os.path.join(mwoz_21_repo, relevant_file),
            os.path.join(mwoz_21, relevant_file),
        )

    shutil.rmtree(mwoz_21_repo)


def get_splits(dataset_folder) -> Tuple[List[str], List[str], List[str]]:
    """Split the MultiWOZ 2.1 dialogue keys into train, dev, and test lists."""
    mwoz_21_dialogues = get_json_object(os.path.join(dataset_folder, "data.json"))
    dialogues_keys: Set[str] = set(mwoz_21_dialogues.keys())
    tr_split: List[str] = []
    with open(os.path.join(dataset_folder, "valListFile.txt")) as f:
        dev_split: List[str] = [key.strip() for key in f]
    with open(os.path.join(dataset_folder, "testListFile.txt")) as f:
        te_split: List[str] = [key.strip() for key in f]

    # dialogues that are in neither the dev nor the test list are used for training
    held_out: Set[str] = set(dev_split) | set(te_split)
    for key in dialogues_keys:
        if key not in held_out:
            tr_split.append(key)

    return tr_split, dev_split, te_split




For this tutorial, we use only 1000, 100, and 200 dialogues for training, validation, and test, respectively.
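
The subsampling itself is a plain random draw without replacement, capped at the number of available dialogues, as in `encode_dialogue_dataset` above. Here is a minimal sketch of that step with made-up dialogue ids; the `subsample` helper, the seed, and the ids are illustrative assumptions and not part of the recipe.

```python
import random


def subsample(items, n, seed=1995):
    """Draw at most n items without replacement (hypothetical helper)."""
    rng = random.Random(seed)  # local generator, so the global seed is untouched
    return rng.sample(items, min(n, len(items)))


# Made-up dialogue ids standing in for MultiWOZ keys.
dialogue_ids = [f"dialogue_{i}.json" for i in range(10)]

small = subsample(dialogue_ids, 3)
everything = subsample(dialogue_ids, 50)  # asking for too many returns all of them

print(len(small), len(everything))
```

Capping the sample size with `min` keeps the call safe when a split holds fewer dialogues than requested.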

In [6]:
prepare_mwoz_21("data_dir", "data", False, "mapping.pair", 1000, 100, 200)


Downloading https://github.com/budzianowski/multiwoz/raw/master/data/MultiWOZ_2.1.zip to data/MultiWOZ_21.zip


MultiWOZ_2.1.zip: 20.2MB [00:00, 21.4MB/s]                            
Load Dialogues: 100%|██████████| 8438/8438 [00:07<00:00, 1136.97it/s]
Creating dataset: 100%|██████████| 1000/1000 [00:00<00:00, 15219.75it/s]
Load Dialogues: 100%|██████████| 1000/1000 [00:00<00:00, 1004.09it/s]
Creating dataset: 100%|██████████| 100/100 [00:00<00:00, 13125.25it/s]
Load Dialogues: 100%|██████████| 1000/1000 [00:00<00:00, 1030.78it/s]
Creating dataset: 100%|██████████| 200/200 [00:00<00:00, 13622.96it/s]
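
Each prepared JSON file maps a `<dialogue id>_<turn id>` key to a history/reply pair, with one entry per system turn. The sketch below rebuilds that structure for a made-up four-turn dialogue. It is a simplified stand-in for `create_dialogue_dataset`: the `build_entries` helper, the `toy.json` id, and the utterances are invented for illustration.

```python
from statistics import mean


def build_entries(dialogue_id, turns):
    """Turn an alternating user/system dialogue into history/reply entries."""
    entries = {}
    history = []
    for turn_idx, text in enumerate(turns):
        if turn_idx % 2 == 1:
            # Odd turns are system replies: pair them with the history so far.
            entries[f"{dialogue_id}_{turn_idx}"] = {
                "history": list(history),
                "reply": text,
                # Same length heuristic as the recipe: mean history length
                # (in characters) plus the reply length.
                "length": mean(len(t) for t in history) + len(text),
            }
        # Every turn, user or system, becomes context for later replies.
        history.append(text)
    return entries


toy_turns = [
    "i need a cheap hotel",         # user   (turn 0)
    "which area do you prefer ?",   # system (turn 1)
    "the north please",             # user   (turn 2)
    "i found two options for you",  # system (turn 3)
]

for key, entry in build_entries("toy.json", toy_turns).items():
    print(key, len(entry["history"]), "->", entry["reply"])
```

Only system turns produce entries, so a dialogue with four turns yields two training examples, the second one seeing three turns of context.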


### **Step 2: Tokenizer**
GPT-2 comes with its own tokenizer, GPT2Tokenizer, which becomes accessible once the SpeechBrain GPT interface is instantiated.

```
tokenizer = hparams["gpt_model"].tokenizer
```
We need to add special tokens to the tokenizer to identify which speaker is talking (system or user).


```
def add_special_tokens_(model, tokenizer, attr_to_special_token) -> None:
    orig_num_tokens = len(tokenizer.encoder)
    num_added_tokens = tokenizer.add_special_tokens(
        attr_to_special_token  # type: ignore
    )  # doesn't add if they are already there
    if num_added_tokens > 0:
        model.resize_token_embeddings(
            new_num_tokens=orig_num_tokens + num_added_tokens
        )
```
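
To see what the speaker tokens are for, the sketch below lays out one dialogue the way the training script later does: a `BOS` marker, each turn prefixed by its speaker token, and a parallel list of token types. It works on whitespace-split words instead of GPT-2 token ids, and `build_input` is a hypothetical helper, so treat it as an illustration of the layout rather than the recipe's code.

```python
BOS, SYSTEM, USER = "BOS", "SPK_1", "SPK_2"


def build_input(history, reply):
    """Lay out history and reply as one sequence with per-token speaker types.

    Words stand in for token ids; user turns sit at even history positions.
    """
    tokens, types = [BOS], [SYSTEM]  # BOS is typed as system, as in the recipe
    for i, turn in enumerate(history):
        speaker = USER if i % 2 == 0 else SYSTEM
        words = [speaker] + turn.split()
        tokens += words
        types += [speaker] * len(words)
    # The reply is always spoken by the system.
    reply_words = [SYSTEM] + reply.split()
    tokens += reply_words
    types += [SYSTEM] * len(reply_words)
    return tokens, types


tokens, types = build_input(
    history=["hi how are you ?", "fine , thanks", "any cheap hotel ?"],
    reply="yes , two of them",
)
for token, token_type in zip(tokens, types):
    print(f"{token_type:>5}  {token}")
```

Because the speaker markers are new vocabulary items, the embedding matrix has to grow by the number of added tokens, which is what `resize_token_embeddings` takes care of.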








### **Step 3: Train a Model**

Since we are performing a language modeling task, we will fine-tune GPT2LMHeadModel. For fine-tuning, it is enough to load the pre-trained version of the model and use its default forward method. It is important to specify the correct HuggingFace model name, which for this model is gpt_hub: gpt2.

The hyperparameter file for our model is the following:

In [7]:
%%file hparams_gpt2.yaml
# ########################################
# Model: GPT2LMHeadModel +  NLL
# Authors:
    # Pooneh Mousavi 2023
    # Simone Alghisi 2023
# ########################################

# Seed needs to be set at top of yaml, before objects with parameters are made
seed: 1995
__set_seed: !apply:torch.manual_seed [!ref <seed>]
# Dataset will be downloaded to `data_folder`
data_folder: !PLACEHOLDER
output_folder: !ref /content/results/train_with_gpt2/<seed>
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt
bleu_4_test_file: !ref <output_folder>/bleu_4_test.txt
bleu_4_valid_file: !ref <output_folder>/bleu_4_valid.txt

# HuggingFace hub name of the gpt2 model
gpt_hub: gpt2
gpt_folder: !ref <save_folder>/gpt_checkpoint

# Path where data manifest files will be stored
train_annotation: !ref <data_folder>/train.json
valid_annotation: !ref <data_folder>/valid.json
test_annotation: !ref <data_folder>/test.json

skip_prep: False

# The train logger writes training statistics to a file, as well as stdout.
train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>

# Special tokens
bos_token: "BOS"
eos_token: "EOS"

system_token: "SPK_1"
user_token: "SPK_2"

special_tokens: [
    !ref <bos_token>,
    !ref <eos_token>,
    !ref <system_token>,
    !ref <user_token>
]

attr_to_special_tokens:
    "bos_token": !ref <bos_token>
    "eos_token": !ref <eos_token>
    "additional_special_tokens": [!ref <system_token>, !ref <user_token>]

# history_window, i.e. how many user-system exchanges to consider as context.
max_history: 5

ignore_index: -100
label_smoothing: 0

####################### Training Parameters ####################################
number_of_epochs: 4
batch_size: 8
test_batch_size: 4
lr: 1.97125e-4

#freeze GPT model
freeze_gptmodel: False
num_beams: 3
max_new_tokens: 50
top_k: 45
top_p: 0.9


train_dataloader_options:
    batch_size: !ref <batch_size>
    shuffle: True
    num_workers: 2
    drop_last: False

test_dataloader_options:
    batch_size: !ref <test_batch_size>
    shuffle: True
    num_workers: 2
    drop_last: True

# Masks
padding_mask: !name:speechbrain.lobes.models.transformer.Transformer.get_key_padding_mask

# gpt model
gpt_model: !new:speechbrain.lobes.models.huggingface_transformers.gpt.GPT
    source: !ref <gpt_hub>
    freeze: !ref <freeze_gptmodel>
    save_path: !ref <gpt_folder>
    max_new_tokens: !ref <max_new_tokens>
    num_beams: !ref <num_beams>
    top_k: !ref  <top_k>
    top_p: !ref <top_p>

epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>

modules:
    gpt_model: !ref <gpt_model>

model: !new:torch.nn.ModuleList
    - [!ref <gpt_model>]


ce_loss: !new:torch.nn.CrossEntropyLoss
    ignore_index: !ref <ignore_index>
    label_smoothing: !ref <label_smoothing>

opt_class: !name:torch.optim.AdamW
    lr: !ref <lr>


lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr>
    improvement_threshold: 0.0025
    annealing_factor: 0.9
    patient: 0

checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        gpt_model: !ref <gpt_model>
        lr_annealing_output: !ref <lr_annealing>
        counter: !ref <epoch_counter>


bleu_4_computer: !name:speechbrain.utils.bleu.BLEUStats
    max_ngram_order: 4

bleu_2_computer: !name:speechbrain.utils.bleu.BLEUStats
    max_ngram_order: 2

Writing hparams_gpt2.yaml
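
Two entries in this file work together: `ignore_index: -100` tells the cross-entropy loss to skip positions labelled `-100`, so only reply tokens contribute, and the training script reports perplexity as the exponential of that average loss. The pure-Python sketch below reproduces the arithmetic on invented probabilities. `masked_nll` is a hypothetical helper, not a SpeechBrain or PyTorch function.

```python
import math

IGNORE_INDEX = -100


def masked_nll(target_probs, labels):
    """Average negative log-likelihood over positions whose label is not ignored.

    target_probs[i] is the probability the model assigns to the correct
    token at position i (made-up numbers here).
    """
    kept = [
        -math.log(p)
        for p, label in zip(target_probs, labels)
        if label != IGNORE_INDEX
    ]
    return sum(kept) / len(kept)


# Three history positions are masked out; three reply positions are scored.
labels = [IGNORE_INDEX, IGNORE_INDEX, IGNORE_INDEX, 11, 42, 7]
target_probs = [0.9, 0.1, 0.5, 0.5, 0.25, 0.5]

loss = masked_nll(target_probs, labels)
perplexity = math.exp(loss)
print(f"loss={loss:.4f}  perplexity={perplexity:.4f}")
```

Changing the probabilities at the masked positions leaves the loss untouched, which is the point of masking the dialogue history.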


The training script follows the standard SpeechBrain structure, and you should be able to identify the common operations needed to train a neural model:


In [8]:
%%file train.py
"""
Recipe for training a gpt_based response generation model with MultiWOZ.
The system employs GPT2 (https://life-extension.github.io/2020/05/27/GPT%E6%8A%80%E6%9C%AF%E5%88%9D%E6%8E%A2/language-models.pdf).
This recipe takes the GPT2LMHeadModel to fine-tune for the response generation task on the NLL.

To run this recipe, do the following:
> python train.py hparams_gpt2.yaml --data_folder=<path to the prepared data>

Authors
 * Pooneh Mousavi 2023
 * Simone Alghisi 2023
"""


import sys
import speechbrain as sb
import torch
from itertools import chain
from hyperpyyaml import load_hyperpyyaml
from speechbrain.utils.distributed import run_on_main
import math
from speechbrain.dataio.batch import PaddedBatch


class ResGenBrain(sb.Brain):
    def compute_forward(self, batch, stage):
        """Computation pipeline based on a gpt decoder."""
        # Get required data from batch
        batch = batch.to(self.device)
        input_ids, _ = batch.input_ids
        token_type_ids, _ = batch.token_type_ids

        # Forward Pass
        padding_mask = ~self.hparams.padding_mask(
            input_ids, pad_idx=tokenizer.unk_token_id
        )
        outputs = self.modules.gpt_model(
            input_ids, token_type_ids, padding_mask
        ).logits

        return outputs

    def compute_objectives(self, predictions, batch, stage):
        """Computes the NLL-loss using reply as label."""
        # Get required data from batch
        batch = batch.to(self.device)
        ids = batch.id
        lm_labels, labels_lens = batch.lm_labels
        history_bos, history_lens = batch.history_bos
        reply_eos, reply_lens = batch.reply_eos
        history_token_type, _ = batch.history_token_type

        loss = self.hparams.ce_loss(
            predictions.flatten(end_dim=-2), lm_labels.flatten()
        )

        if stage == sb.Stage.VALID:
            # hyps = None
            # current_epoch = self.hparams.epoch_counter.current
            # if current_epoch % self.hparams.valid_search_interval == 0:
            # history_bos = torch.LongTensor([hparams["bos_index"]] + (history_bos))
            padding_mask = ~self.hparams.padding_mask(
                history_bos, pad_idx=tokenizer.unk_token_id
            )
            hyps = self.modules.gpt_model.generate(
                history_bos.detach(),
                history_token_type.detach(),
                padding_mask.detach(),
            )
        elif stage == sb.Stage.TEST:
            padding_mask = ~self.hparams.padding_mask(
                history_bos, pad_idx=tokenizer.unk_token_id
            )
            hyps = self.modules.gpt_model.generate(
                history_bos.detach(),
                history_token_type.detach(),
                padding_mask.detach(),
                "beam",
            )

        if stage != sb.Stage.TRAIN:
            reply_truncated = [
                reply_eos[i][
                    : int(reply_lens[i].item() * reply_eos.shape[1] - 1)
                ].detach()
                for i in range(reply_eos.shape[0])
            ]
            predicted_words = tokenizer.batch_decode(
                hyps[:, history_bos.shape[1] :],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )
            target_words = tokenizer.batch_decode(
                reply_truncated,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )
            self.bleu_4_metric.append(ids, predicted_words, target_words)
            self.bleu_2_metric.append(ids, predicted_words, target_words)
            self.hyps.extend(predicted_words)
            self.references.extend(target_words)

        return loss

    def on_stage_start(self, stage, epoch):
        """Gets called at the beginning of each stage"""
        if stage != sb.Stage.TRAIN:
            self.bleu_4_metric = self.hparams.bleu_4_computer()
            self.bleu_2_metric = self.hparams.bleu_2_computer()
            self.hyps = []
            self.references = []

    def on_stage_end(self, stage, stage_loss, epoch):
        """Gets called at the end of a stage.

        Arguments
        ---------
        stage : sb.Stage
            One of sb.Stage.TRAIN, sb.Stage.VALID, sb.Stage.TEST
        stage_loss : float
            The average loss for all of the data processed in this stage.
        epoch : int
            The currently-starting epoch. This is passed
            `None` during the test stage.
        """

        # Store the train loss until the validation stage.
        stage_stats = {"loss": stage_loss}
        stage_stats["PPL"] = math.exp(stage_loss)
        if stage == sb.Stage.TRAIN:
            self.train_stats = stage_stats
        else:
            stage_stats["BLEU_4"] = self.bleu_4_metric.summarize("BLEU")
            stage_stats["BLEU_2"] = self.bleu_2_metric.summarize("BLEU")
        # Perform end-of-iteration things, like annealing, logging, etc.
        if stage == sb.Stage.VALID:
            # Update learning rate
            old_lr, new_lr = self.hparams.lr_annealing(epoch)
            sb.nnet.schedulers.update_learning_rate(self.optimizer, new_lr)

            # The train_logger writes a summary to stdout and to the logfile.

            self.hparams.train_logger.log_stats(
                stats_meta={"epoch": epoch, "lr": old_lr},
                train_stats=self.train_stats,
                valid_stats=stage_stats,
            )
            # Save the current checkpoint and delete previous checkpoints.
            self.checkpointer.save_and_keep_only(
                meta={"PPL": stage_stats["PPL"]},
                min_keys=["PPL"],
            )
            if epoch == hparams["number_of_epochs"] - 1:
                with open(self.hparams.bleu_4_valid_file, "w") as w:
                    self.bleu_4_metric.write_stats(w)
                    for i in range(len(self.hyps)):
                        w.write("target: " + str(self.references[i]) + "\n")
                        w.write("predicted:" + str(self.hyps[i]) + "\n")
                        w.write(
                            "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
                        )

        # We also write statistics about test data to stdout and to the logfile.
        elif stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stage_stats,
            )
            with open(self.hparams.bleu_4_test_file, "w") as w:
                self.bleu_4_metric.write_stats(w)
                for i in range(len(self.hyps)):
                    w.write("target: " + str(self.references[i]) + "\n")
                    w.write("predicted:" + str(self.hyps[i]) + "\n")
                    w.write(
                        "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
                    )

    def init_optimizers(self):
        "Initializes the model optimizer"
        self.optimizer = self.hparams.opt_class(self.hparams.model.parameters())

        if self.checkpointer is not None:
            self.checkpointer.add_recoverable("optimizer", self.optimizer)

        self.optimizers_dict = {
            "optimizer": self.optimizer,
        }


def add_special_tokens_(model, tokenizer, attr_to_special_token) -> None:
    orig_num_tokens = len(tokenizer.encoder)
    num_added_tokens = tokenizer.add_special_tokens(
        attr_to_special_token  # type: ignore
    )  # doesn't add if they are already there
    if num_added_tokens > 0:
        model.resize_token_embeddings(
            new_num_tokens=orig_num_tokens + num_added_tokens
        )


def dataio_prep(hparams, tokenizer):
    """This function prepares the datasets to be used in the brain class.
    It also defines the data processing pipeline through user-defined
    functions. We expect `prepare_mwoz_21` to have been called before
    this, so that the `train.json`, `valid.json`, and `test.json` manifest
    files are available.

    Arguments
    ---------
    hparams : dict
        This dictionary is loaded from the `hparams_gpt2.yaml` file, and it
        includes all the hyperparameters needed for dataset construction and
        loading.
    tokenizer : tokenizer
        Object for converting text to tokens.

    Returns
    -------
    datasets : dict
        Contains three keys, "train", "valid", and "test", that correspond
        to the appropriate DynamicItemDataset object.
    """

    # convert special tokens to their ids
    bos, eos, system, user = tokenizer.convert_tokens_to_ids(
        hparams["special_tokens"]
    )
    # history_window, i.e. how many user-system exchanges to consider as context (+1 to consider at least the last user turn)
    history_window = 2 * hparams["max_history"] + 1

    #  Define history pipeline:
    @sb.utils.data_pipeline.takes("history")
    @sb.utils.data_pipeline.provides(
        "history",
        "history_tokens_lists",
        "history_ids",
        "history_bos",
        "history_token_type",
    )
    def history_pipeline(history):
        yield history

        # encode each turn of the history
        history_tokens_lists = [tokenizer.encode(turn) for turn in history]
        yield history_tokens_lists

        # add speaker tokens to the history turns (user is even, system is odd)
        # BEFORE:  [Hi how are you?], [I'm fine, thanks]
        # AFTER:   [SPK_2 Hi how are you?], [SPK_1 I'm fine, thanks]
        history_input_lists = [
            [user if i % 2 == 0 else system] + encoded_turn
            for i, encoded_turn in enumerate(history_tokens_lists)
        ]

        history_ids = history_input_lists[-history_window:]
        # concatenate every token into a single list
        # list(chain(*[[1, 2], [3, 4], [5]]))
        # >>> [1, 2, 3, 4, 5]
        history_ids = torch.LongTensor(list(chain(*history_ids)))
        # without bos for lm_labels
        yield history_ids

        # create bos version for the input
        history_bos = torch.cat((torch.tensor([bos]), history_ids))
        yield history_bos

        # create a mapping that associates each token in the input to a speaker
        # INPUT: [SPK_2 Hi    how   are   you? ], [SPK_1 I'm   fine, thanks]
        # TYPE:  [SPK_2 SPK_2 SPK_2 SPK_2 SPK_2], [SPK_1 SPK_1 SPK_1 SPK_1 ]
        history_token_type_lists = [
            [user if i % 2 == 0 else system] * len(encoded_turn)
            for i, encoded_turn in enumerate(history_input_lists)
        ]
        history_token_type = torch.LongTensor(
            list(
                chain(
                    *([[system]] + history_token_type_lists[-history_window:])
                )
            )
        )

        yield history_token_type

    #  Define reply pipeline:
    @sb.utils.data_pipeline.takes("reply")
    @sb.utils.data_pipeline.provides(
        "reply",
        "reply_tokens_list",
        "reply_ids",
        "reply_eos",
        "reply_token_type",
    )
    def reply_pipeline(reply):
        yield reply

        reply_tokens_list = tokenizer.encode(reply)
        yield reply_tokens_list

        # specify that the system will say the reply
        reply_input_list = [system] + reply_tokens_list
        reply_ids = torch.LongTensor(reply_input_list)
        yield reply_ids

        # create eos version of the reply for lm_labels
        reply_eos = torch.cat((reply_ids, torch.tensor([eos])))
        yield reply_eos

        # specify the speaker for each token in the reply
        reply_token_type = torch.LongTensor([system] * len(reply_input_list))
        yield reply_token_type

    # Define input_and_token_type_pipeline
    @sb.utils.data_pipeline.takes(
        "history_ids",
        "history_bos",
        "history_token_type",
        "reply_ids",
        "reply_eos",
        "reply_token_type",
    )
    @sb.utils.data_pipeline.provides("input_ids", "token_type_ids", "lm_labels")
    def input_and_token_type_pipeline(
        history_ids,
        history_bos,
        history_token_type,
        reply_ids,
        reply_eos,
        reply_token_type,
    ):
        # put history and reply together
        # N.B. input_sequence = history_bos + reply_ids, we don't have eos in the input
        input_ids = torch.cat((history_bos, reply_ids), -1)
        yield input_ids

        token_type_ids = torch.cat((history_token_type, reply_token_type), -1)
        yield token_type_ids

        # create the language model label (ground truth) for the current input
        # -100 is a special value that is ignored during the loss computation
        # the idea is to mask everything except the reply (without the speaker token)
        # N.B. history_ids has no bos, so one extra ignored position is added for it
        lm_labels = (
            [hparams["ignore_index"]] * history_ids.shape[0]
            + [hparams["ignore_index"]]
            + reply_eos[1:].tolist()
        )
        lm_labels = torch.LongTensor(lm_labels)

        yield lm_labels

    # Define datasets. We also connect the dataset with the data processing
    # functions defined above.
    datasets = {}
    data_info = {
        "train": hparams["train_annotation"],
        "valid": hparams["valid_annotation"],
        "test": hparams["test_annotation"],
    }
    for dataset in data_info:
        datasets[dataset] = sb.dataio.dataset.DynamicItemDataset.from_json(
            json_path=data_info[dataset],
            replacements={"data_root": hparams["data_folder"]},
            dynamic_items=[
                reply_pipeline,
                history_pipeline,
                input_and_token_type_pipeline,
            ],
            output_keys=[
                "id",
                "input_ids",
                "token_type_ids",
                "history_bos",
                "reply_eos",
                "history_token_type",
                "reply_token_type",
                "lm_labels",
            ],
        )

    return datasets


# RECIPE BEGINS!
if __name__ == "__main__":
    # Reading command line arguments.
    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])

    # Initialize ddp (useful only for multi-GPU DDP training).
    sb.utils.distributed.ddp_init_group(run_opts)

    # Load hyperparameters file with command-line overrides.
    with open(hparams_file) as fin:
        hparams = load_hyperpyyaml(fin, overrides)

    # Create experiment directory
    sb.create_experiment_directory(
        experiment_directory=hparams["output_folder"],
        hyperparams_to_save=hparams_file,
        overrides=overrides,
    )


    # Load tokenizer
    tokenizer = hparams["gpt_model"].tokenizer

    # Move the pretrained GPT model to the target device
    hparams["gpt_model"] = hparams["gpt_model"].to(device=run_opts["device"])

    # Add special tokens to the tokenizer and resize model embedding
    add_special_tokens_(
        hparams["gpt_model"].model, tokenizer, hparams["attr_to_special_tokens"]
    )

    class CustomPaddedBatch(PaddedBatch):
        """PaddedBatch with custom padding values.

        See the documentation of `speechbrain.dataio.batch.PaddedBatch`.

        """

        def __init__(self, examples, *args, **kwargs):
            _, _, system, _ = tokenizer.convert_tokens_to_ids(
                hparams["special_tokens"]
            )
            for k in [
                "input_ids",
                "history_bos",
                "lm_labels",
                "token_type_ids",
                "history_token_type",
            ]:
                max_len = max([len(x[k]) for x in examples])
                pad_value = 0
                if k in [
                    "input_ids",
                    "history_bos",
                    "token_type_ids",
                    "history_token_type",
                ]:
                    pad_value = tokenizer.unk_token_id
                elif k == "lm_labels":
                    pad_value = hparams["ignore_index"]
                for example in examples:
                    x = example[k]
                    if k in ["history_bos", "history_token_type"]:
                        x = torch.cat(
                            (example[k], torch.LongTensor([system])), -1
                        )
                        example[k] = torch.nn.functional.pad(
                            x, [max_len - len(x), 0], value=pad_value
                        )
                    else:
                        example[k] = torch.nn.functional.pad(
                            x, [0, max_len - len(x)], value=pad_value
                        )
            super().__init__(examples, *args, **kwargs)

    hparams["train_dataloader_options"]["collate_fn"] = CustomPaddedBatch
    hparams["test_dataloader_options"]["collate_fn"] = CustomPaddedBatch

    # Create dataset objects "train", "valid", and "test".
    datasets = dataio_prep(hparams, tokenizer)

    # Initialize the Brain object to prepare for training.
    res_gen_brain = ResGenBrain(
        modules=hparams["modules"],
        opt_class=hparams["opt_class"],
        hparams=hparams,
        run_opts=run_opts,
        checkpointer=hparams["checkpointer"],
    )

    # Load pretrained parameters if a pretrainer is defined in the hyperparameters
    if "pretrainer" in hparams.keys():
        run_on_main(hparams["pretrainer"].collect_files)
        hparams["pretrainer"].load_collected(res_gen_brain.device)

    # The `fit()` method iterates the training loop, calling the methods
    # necessary to update the parameters of the model. Since all objects
    # with changing state are managed by the Checkpointer, training can be
    # stopped at any point, and will be resumed on next call.
    res_gen_brain.fit(
        epoch_counter=res_gen_brain.hparams.epoch_counter,
        train_set=datasets["train"],
        valid_set=datasets["valid"],
        train_loader_kwargs=hparams["train_dataloader_options"],
        valid_loader_kwargs=hparams["test_dataloader_options"],
    )

    # Load the best checkpoint (lowest PPL) and evaluate on the test set
    test_stats = res_gen_brain.evaluate(
        test_set=datasets["test"],
        min_key="PPL",
        test_loader_kwargs=hparams["test_dataloader_options"],
    )

Writing train.py


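Pay attention to how the training data is made suitable for teacher forcing: each dialogue history is concatenated with its reply. We also need a custom collate class (`CustomPaddedBatch`) because the padding depends on the field. `history_bos` and `history_token_type` are padded on the left, while `input_ids`, `token_type_ids` and `lm_labels` are padded on the right. `lm_labels` is padded with `ignore_index`, and the other fields are padded with the tokenizer's unknown-token id.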
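The two ideas in the paragraph above can be illustrated with plain Python lists. This is only a minimal sketch, not the recipe's implementation: the helper names `build_example`, `pad_left` and `pad_right` are hypothetical, the token ids are made up, and `IGNORE_INDEX = -100` is an assumed value for `hparams["ignore_index"]`. It also assumes that history positions in the labels are masked so that only reply tokens contribute to the loss.

```python
IGNORE_INDEX = -100  # assumption: stands in for hparams["ignore_index"]


def build_example(history_ids, reply_ids, ignore_index=IGNORE_INDEX):
    """Concatenate history and reply; mask the history part of the labels."""
    input_ids = history_ids + reply_ids
    # Only reply positions carry a real label (assumed masking scheme).
    lm_labels = [ignore_index] * len(history_ids) + reply_ids
    return input_ids, lm_labels


def pad_right(seq, max_len, pad_value):
    """Append pad_value until the sequence has max_len items."""
    return seq + [pad_value] * (max_len - len(seq))


def pad_left(seq, max_len, pad_value):
    """Prepend pad_value until the sequence has max_len items."""
    return [pad_value] * (max_len - len(seq)) + seq


# Made-up token ids for one history and one reply.
input_ids, lm_labels = build_example([11, 12, 13], [21, 22])
print(input_ids)                               # history followed by reply
print(lm_labels)                               # history masked, reply kept
print(pad_right(lm_labels, 7, IGNORE_INDEX))   # padded labels are ignored too
print(pad_left([11, 12, 13], 5, 0))            # history stays next to the reply start
```

Left-padding the history keeps its last token at the end of the sequence, which is where generation continues from at inference time.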

Now we train the model for 3 epochs, which takes around 20-30 minutes.

In [9]:
!rm -rf results
!python train.py hparams_gpt2.yaml --data_folder='/content/data_dir' --device='cuda:0' --number_of_epochs=3 --batch_size=8

INFO:speechbrain.utils.quirks:Applied quirks (see `speechbrain.utils.quirks`): [disable_jit_profiling, allow_tf32]
INFO:speechbrain.utils.quirks:Excluded quirks specified by the `SB_DISABLE_QUIRKS` environment (comma-separated list): []
INFO:numexpr.utils:NumExpr defaulting to 12 threads.
2025-04-04 00:54:20.185199: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1743728060.208032    1719 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1743728060.214666    1719 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
config.json: 100% 665/665 [00:00<00:00, 4.42MB/s]
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to 


After training the model for 3 epochs on just 1000 data points, we achieved a test loss of 1.73, a perplexity (PPL) of 5.62, a BLEU_4 of 2.85e-03, and a BLEU_2 of 4.43e-02. The full recipe in SpeechBrain, which also covers fine-tuning other LLMs such as Llama 2, can be found [here](https://github.com/speechbrain/speechbrain/tree/develop/recipes/MultiWOZ/response_generation).
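The loss and perplexity figures above are related. Assuming the loss is the mean per-token negative log-likelihood in nats and that perplexity is its exponential (the usual definition, not verified against this recipe), a quick check looks like this. The helper name `perplexity` is hypothetical.

```python
import math


def perplexity(mean_nll: float) -> float:
    """Perplexity for a mean per-token negative log-likelihood (natural log)."""
    return math.exp(mean_nll)


reported_loss = 1.73  # test loss quoted above, rounded to two decimals
print(f"{perplexity(reported_loss):.2f}")  # prints 5.64
```

The small gap to the reported 5.62 is consistent with rounding: any loss between 1.725 and 1.735 rounds to 1.73 and corresponds to a perplexity between roughly 5.61 and 5.67.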


### **Step 4: Inference**

At this point, we can use the trained response generator. For this type of model, SpeechBrain provides inference classes (take a look [here](https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/inference/text.py)), such as `GPTResponseGenerator`, that make inference easier. For instance, we can build a chatbot with a pre-trained GPT model hosted in our HuggingFace repository in only a few lines of code.

In [11]:
from speechbrain.inference.text import GPTResponseGenerator

# Download the fine-tuned model from the HuggingFace hub into `savedir`.
res_gen_model = GPTResponseGenerator.from_hparams(
    source="speechbrain/MultiWOZ-GPT-Response_Generation",
    savedir="pretrained_models/MultiWOZ-GPT-Response_Generation",
    pymodule_file="custom.py",
)

print("Hi,How could I help you today?")
# Chat for at most 10 turns: read a user turn, then print the model's reply.
for _ in range(10):
    turn = input()
    response = res_gen_model.generate_response(turn)
    print(response)


Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


Hi,How could I help you today?
good


Setting `pad_token_id` to `eos_token_id`:50258 for open-end generation.


is there a certain price range you are looking for?
100


Setting `pad_token_id` to `eos_token_id`:50258 for open-end generation.


there are 33 restaurants in the city that serve cheap food. do you have a preference for the area of town?
no


Setting `pad_token_id` to `eos_token_id`:50258 for open-end generation.


there are 3 cheap restaurants in the city that serve cheap food: the dojo noodle bar, the j restaurant, and the golden house.
okk


Setting `pad_token_id` to `eos_token_id`:50258 for open-end generation.


the dojo noodle bar is located at 40210 millers yard city centre. would you like to book a table?
no


Setting `pad_token_id` to `eos_token_id`:50258 for open-end generation.


ok, is there anything else i can help you with?
cool


Setting `pad_token_id` to `eos_token_id`:50258 for open-end generation.


goodbye.
ok


Setting `pad_token_id` to `eos_token_id`:50258 for open-end generation.


goodbye.
ok


Setting `pad_token_id` to `eos_token_id`:50258 for open-end generation.


goodbye.


KeyboardInterrupt: Interrupted by user
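The session above ends with a `KeyboardInterrupt` because the loop keeps waiting for input until ten turns have been entered. One possible alternative, shown here only as a sketch, is a loop that stops on an empty line or a stop word. The function `chat`, its parameters, and the stand-in `echo_model` are hypothetical names introduced for illustration. The stand-in replaces `res_gen_model.generate_response` so that the example runs without downloading a model.

```python
from itertools import islice
from typing import Callable, Iterable, List


def chat(
    generate_response: Callable[[str], str],
    user_turns: Iterable[str],
    max_turns: int = 10,
    stop_words=("quit", "exit"),
) -> List[str]:
    """Collect replies until a stop word, an empty turn, or max_turns."""
    replies = []
    # islice caps the number of turns without reading an extra one.
    for turn in islice(user_turns, max_turns):
        text = turn.strip()
        if not text or text.lower() in stop_words:
            break
        replies.append(generate_response(text))
    return replies


def echo_model(turn: str) -> str:
    """Stand-in for a real response generator (illustration only)."""
    return f"you said: {turn}"


print(chat(echo_model, ["hello", "cheap food", "quit", "ignored"]))
```

With the real model, `res_gen_model.generate_response` could be passed as the first argument and `iter(input, "")` as the turns, which reads lines until the user submits an empty one.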