<a href="https://colab.research.google.com/github/rickqiu/trax_chatbots/blob/main/reformer_chatbots.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Chatbot with Reformer Model
@author: Rick Qiu

This notebook demonstrates a chatbot with a Reformer model pre-trained on the MultiWoz dataset. The chatbot is built with Google Trax, which is a low-code and high-speed deep learning library.

Although the chatbot cannot compete with Alexa, Siri, Cortana, or Meena, it holds reasonably good conversations in the attraction, hotel, taxi, train, hospital, and police domains.

This notebook shows how to create an AI chatbot with a few lines of code, and it gives you the chance to inspect the model and chat with the bot.

### References

- [MultiWoz](https://arxiv.org/abs/1810.00278) dataset
- [Reformer](https://arxiv.org/abs/2001.04451) paper
- [Trax](https://github.com/google/trax) code repository


## 1. Setup

In [None]:
!pip install -q -U https://storage.googleapis.com/jax-releases/cuda101/jaxlib-0.1.55-cp36-none-manylinux2010_x86_64.whl
!pip install -q -U jax
!pip install -q -U trax


In [None]:
!pip list | grep trax

trax                          1.3.5                


In [None]:
# clone to get vocabs and the pretrained model parts from repository
!git clone https://github.com/rickqiu/trax_chatbots.git

Cloning into 'trax_chatbots'...
remote: Enumerating objects: 75, done.
remote: Counting objects: 100% (75/75), done.
remote: Compressing objects: 100% (75/75), done.
remote: Total 92 (delta 45), reused 0 (delta 0), pack-reused 17
Unpacking objects: 100% (92/92), done.
Checking out files: 100% (11/11), done.


In [None]:
%cd trax_chatbots
!tar -xzvf vocabs.tar.gz
!cat model_splits/* > chatbot_model1.pkl.gz
%cd ..

/content/trax_chatbots
vocabs/
vocabs/en_32k.sentencepiece
vocabs/en_32k.sentencepiece.vocab
vocabs/en_32k.subword
vocabs/en_8k.subword
/content
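The `cat model_splits/* > chatbot_model1.pkl.gz` step above reassembles the model file from its parts. As a rough Python equivalent, the sketch below joins every file in a directory in sorted name order. The helper `join_parts` and the demo file names are hypothetical, and Python's `sorted` may order names differently from the shell's glob expansion when names contain unusual characters.

```python
import glob
import os
import shutil
import tempfile


def join_parts(parts_dir, out_path):
    """Concatenate every file in parts_dir, in sorted name order, into out_path."""
    part_paths = sorted(glob.glob(os.path.join(parts_dir, '*')))
    with open(out_path, 'wb') as out_file:
        for part_path in part_paths:
            with open(part_path, 'rb') as part_file:
                shutil.copyfileobj(part_file, out_file)
    return part_paths


# Demo on throwaway files (hypothetical names, not the repository's real parts).
demo_dir = tempfile.mkdtemp()
parts_dir = os.path.join(demo_dir, 'model_splits')
os.makedirs(parts_dir)
for name, payload in [('part_ab', b'world'), ('part_aa', b'hello ')]:
    with open(os.path.join(parts_dir, name), 'wb') as f:
        f.write(payload)

joined_path = os.path.join(demo_dir, 'joined.bin')
used_parts = join_parts(parts_dir, joined_path)
with open(joined_path, 'rb') as f:
    joined_bytes = f.read()
print(joined_bytes)
```

The parts were written out of order on purpose; sorting by name restores the intended sequence.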


In [None]:
# import libraries
import numpy as np
import trax   
from trax import layers as tl

## 2. Modelling

In [None]:
# define attention for fast inference
def attention(*args, **kwargs):
    kwargs['predict_mem_len'] = 120
    kwargs['predict_drop_len'] = 120
    return tl.SelfAttention(*args, **kwargs)

# define the model
model = trax.models.reformer.ReformerLM( 
        vocab_size=33000,
        n_layers=6,
        mode='predict', # default 'train' for model training
        attention_type=attention
    )
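The `attention` wrapper above pins the two cache-length settings no matter which arguments the caller passes. The sketch below shows the same pattern with a hypothetical stand-in constructor, `fake_self_attention`, so it runs without Trax; it only illustrates the keyword override and says nothing about what the real layer does with those values.

```python
def fake_self_attention(n_heads=2, predict_mem_len=None, predict_drop_len=None):
    """Hypothetical stand-in for an attention-layer constructor; returns its config."""
    return {
        'n_heads': n_heads,
        'predict_mem_len': predict_mem_len,
        'predict_drop_len': predict_drop_len,
    }


def fixed_cache_attention(*args, **kwargs):
    # Overwrite both cache settings regardless of what the caller passed.
    kwargs['predict_mem_len'] = 120
    kwargs['predict_drop_len'] = 120
    return fake_self_attention(*args, **kwargs)


# The caller's predict_mem_len=999 is replaced; n_heads passes through untouched.
layer_config = fixed_cache_attention(n_heads=8, predict_mem_len=999)
print(layer_config)
```

Unlike `functools.partial`, which lets a caller override the preset keywords, this wrapper always wins.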

In [None]:
# display the Reformer model
print(str(model))

Serial[
  ShiftRight(1)
  Embedding_33000_512
  Dropout
  PositionalEncoding
  Dup_out2
  ReversibleSerial_in2_out2[
    ReversibleHalfResidual_in2_out2[
      Serial[
        LayerNorm
      ]
      SelfAttention
    ]
    ReversibleSwap_in2_out2
    ReversibleHalfResidual_in2_out2[
      Serial[
        LayerNorm
        Dense_2048
        Dropout
        FastGelu
        Dense_512
        Dropout
      ]
    ]
    ReversibleSwap_in2_out2
    ReversibleHalfResidual_in2_out2[
      Serial[
        LayerNorm
      ]
      SelfAttention
    ]
    ReversibleSwap_in2_out2
    ReversibleHalfResidual_in2_out2[
      Serial[
        LayerNorm
        Dense_2048
        Dropout
        FastGelu
        Dense_512
        Dropout
      ]
    ]
    ReversibleSwap_in2_out2
    ReversibleHalfResidual_in2_out2[
      Serial[
        LayerNorm
      ]
      SelfAttention
    ]
    ReversibleSwap_in2_out2
    ReversibleHalfResidual_in2_out2[
      Serial[
        LayerNorm
        Dense_2048
        

## 3. Using the pre-trained model

In [None]:
# define an input signature
shape11 = trax.shapes.ShapeDtype((1, 1), dtype=np.int32)

# initialize the model from file
model.init_from_file('trax_chatbots/chatbot_model1.pkl.gz',
                     weights_only=True, input_signature=shape11)

# save the starting state
STARTING_STATE = model.state
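Saving the starting state matters because a model in `predict` mode appears to carry decoding state from one call to the next, so a new dialogue needs a clean slate. The toy class below is a hypothetical stand-in, not Trax code: it caches the tokens it has seen, and restoring the saved state makes the next call behave like the first one.

```python
import copy


class ToyDecoder:
    """Hypothetical stand-in for a predict-mode model that caches what it has seen."""

    def __init__(self):
        self.state = {'seen_tokens': []}

    def step(self, token):
        # Each call extends the cache, so later calls depend on earlier ones.
        self.state['seen_tokens'].append(token)
        return len(self.state['seen_tokens'])


decoder = ToyDecoder()
# Deep copy because this toy state is a mutable list.
starting_state = copy.deepcopy(decoder.state)

decoder.step(11)
position_without_reset = decoder.step(12)  # the cache now holds two tokens

decoder.state = copy.deepcopy(starting_state)  # start a fresh dialogue
position_after_reset = decoder.step(13)

print(position_without_reset, position_after_reset)
```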

## 4. Code for making predictions

In [None]:
# https://www.deeplearning.ai/natural-language-processing-specialization/
# vocabulary file directory
VOCAB_DIR = './trax_chatbots/vocabs'

# vocabulary filename
VOCAB_FILE = 'en_32k.subword'

def tokenize(sentence, vocab_file, vocab_dir):
    return list(trax.data.tokenize(iter([sentence]), vocab_file=vocab_file, vocab_dir=vocab_dir))[0]


def detokenize(tokens, vocab_file, vocab_dir):
    return trax.data.detokenize(tokens, vocab_file=vocab_file, vocab_dir=vocab_dir)


def generate_output(model, start_sentence, vocab_file, vocab_dir, temperature):
    """
    Args:
        model:  the pre-trained Reformer language model
        start_sentence (string): starting sentence of the conversation
        vocab_file (string): vocabulary filename
        vocab_dir (string): directory of the vocabulary file
        temperature (float): parameter for sampling ranging from 0.0 to 1.0.
            0.0: same as argmax, always pick the most probable token
            1.0: sampling from the distribution (can sometimes say random things)

    Returns:
        generator: yields the next symbol generated by the model
    """
    
    input_tokens =  tokenize(start_sentence, vocab_file, vocab_dir)
    
    # add batch dimension to array
    input_tokens_with_batch = np.expand_dims(input_tokens, axis=0)
    
    # call the autoregressive_sample_stream function from trax
    output_gen = trax.supervised.decoding.autoregressive_sample_stream( 
        model,
        inputs=input_tokens_with_batch,
        temperature=temperature
    )
    
    return output_gen


def generate_sentence(model, model_state, start_sentence, vocab_file, vocab_dir, temperature):
    """
    Args:
        model:  the pre-trained Reformer language model
        model_state: initial state of the model before decoding
        start_sentence (string): starting sentence of the conversation
        vocab_file (string): vocabulary filename
        vocab_dir (string): directory of the vocabulary file
        temperature (float): parameter for sampling ranging from 0.0 to 1.0.
            0.0: same as argmax, always pick the most probable token
            1.0: sampling from the distribution (can sometimes say random things)

    Returns:
        None: the generated reply is printed with its speaker prefix
    """

    # define the delimiters used during training
    delimiter_1 = 'Person 1: '
    delimiter_2 = 'Person 2: '

    # initialize detokenized output
    sentence = ''

    # output tokens
    result = []

    # reset the model state when starting a new dialogue
    model.state = model_state

    # call the output generator defined earlier, using the vocabulary passed in
    output = generate_output(model, start_sentence, vocab_file=vocab_file, vocab_dir=vocab_dir, temperature=temperature)

    # print the starting sentence
    #print(start_sentence.split(delimiter_2)[0].strip())

    # the loop below decodes token by token until the text ends with a speaker
    # delimiter; the if-elif only prettifies the output
    for o in output:

        result.append(o)

        sentence = detokenize(np.concatenate(result, axis=0), vocab_file=vocab_file, vocab_dir=vocab_dir)

        if sentence.endswith(delimiter_1):
            sentence = sentence.split(delimiter_1)[0]
            print(f'{delimiter_2}{sentence}')
            break
        elif sentence.endswith(delimiter_2):
            sentence = sentence.split(delimiter_2)[0]
            print(f'{delimiter_1}{sentence}')
            break
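The stopping rule in `generate_sentence` can be looked at in isolation: keep appending decoded text until it ends with a speaker delimiter, then cut the delimiter off. The sketch below uses a hypothetical helper, `read_reply`, fed by a fake stream of text pieces instead of model tokens, and it slices the delimiter off the end rather than calling `split` as the notebook does.

```python
def read_reply(piece_stream, delimiters=('Person 1: ', 'Person 2: ')):
    """Accumulate text pieces until the text ends with a speaker delimiter."""
    text = ''
    for piece in piece_stream:
        text += piece
        for delimiter in delimiters:
            if text.endswith(delimiter):
                # Drop the trailing delimiter and surrounding whitespace.
                return text[:-len(delimiter)].strip()
    # The stream ended without a delimiter; return whatever was collected.
    return text.strip()


# Fake stream standing in for detokenized model output.
fake_pieces = iter(['There ', 'are 4 ', 'theatres.', ' Person', ' 1: ', 'ignored'])
reply = read_reply(fake_pieces)
print(reply)
```

Because the function returns as soon as the delimiter appears, the last piece of the fake stream is never consumed, which mirrors the `break` in the notebook's loop.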

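The `temperature` argument described in the docstrings can be illustrated with the Gumbel-max trick: add temperature-scaled Gumbel noise to each log-probability and take the argmax. This is a minimal standard-library sketch of one common way to get that behaviour, and it is not claimed to match the code path Trax uses. The function name and the toy distribution are made up for illustration.

```python
import math
import random


def sample_with_temperature(log_probs, temperature, rng=random):
    """Pick an index from log-probabilities using temperature-scaled Gumbel noise."""
    best_index, best_score = 0, -math.inf
    for index, log_prob in enumerate(log_probs):
        # Clamp the uniform draw so both logarithms stay finite.
        u = min(max(rng.random(), 1e-6), 1.0 - 1e-6)
        gumbel_noise = -math.log(-math.log(u))
        score = log_prob + temperature * gumbel_noise
        if score > best_score:
            best_index, best_score = index, score
    return best_index


# Toy next-token distribution over three tokens.
toy_log_probs = [math.log(0.1), math.log(0.7), math.log(0.2)]

# temperature 0.0: the noise vanishes, so the most probable index always wins.
greedy_picks = {sample_with_temperature(toy_log_probs, 0.0) for _ in range(20)}

# temperature 1.0: indices are drawn roughly in proportion to their probability.
seeded_rng = random.Random(0)
sampled_counts = [0, 0, 0]
for _ in range(2000):
    sampled_counts[sample_with_temperature(toy_log_probs, 1.0, seeded_rng)] += 1

print(greedy_picks, sampled_counts)
```

A low value such as the `0.2` used in the chat loop sits between these two extremes: mostly the top choice, with occasional variation.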
## 5. Chatting with the bot

Run the following cell to start a conversation. Enter 'q' to exit.

In [None]:
sample_sentence = ''
in_msg = ''

while True:
    in_msg = input('Person 1: ')
    if in_msg == 'q':
        break

    sample_sentence += ' Person 1: ' + in_msg + ' Person 2: '
    generate_sentence(model, model_state=STARTING_STATE,
                      start_sentence=sample_sentence,
                      vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR,
                      temperature=0.2)

Person 1: Are there theatres in town?
Person 2: There are 4 theatres in town. Do you have a preference? 
Person 1: No, I don't care. Which one would you recommend?
Person 2: I would recommend the Mumford Theatre. It is located in the east at Anglia Ruskin Enterprise, east road. Would you like their phone number? 
Person 1: Yes, could I get the postcode and phone number?
Person 2: The phone number is 01223332360 and the postcode is cb58as. 
Person 1: Thank you very much.
Person 2: The phone number is 01223332320 and the postcode is cb58as. 
Person 1: Bye.
Person 2: You're welcome. Have a nice day.Bye.
Person 1: q
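The chat loop grows its prompt by appending each user message followed by an open `Person 2: ` slot; the bot's replies are printed but not added back to the prompt. The sketch below shows a variation that also records the replies, using a hypothetical helper `build_prompt`. Whether this improves the conversation has not been tested here, and it would require `generate_sentence` to return its text, which the notebook's version does not do.

```python
def build_prompt(turns):
    """Join (speaker_number, text) turns and end with an open 'Person 2: ' slot."""
    prompt = ''
    for speaker, text in turns:
        prompt += f' Person {speaker}: {text}'
    return prompt + ' Person 2: '


# One user turn gives the same string the notebook's loop builds on its first pass.
first_prompt = build_prompt([(1, 'Hi')])

# Later turns can include the bot's earlier reply as context.
history = [
    (1, 'Are there theatres in town?'),
    (2, 'There are 4 theatres in town.'),
    (1, 'Which one would you recommend?'),
]
second_prompt = build_prompt(history)
print(repr(first_prompt))
print(repr(second_prompt))
```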
