<a href="https://colab.research.google.com/github/rickqiu/trax_chatbots/blob/main/chatbots.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Chatbots with Pretrained Reformer Model

This notebook demonstrates conversations between two chatbots, generated by a pretrained Reformer model. A conversation can cover one or more of these domains: attraction, hospital, hotel, police, taxi, and train.

Credit goes to Lukasz Kaiser and the Trax open-source team. Useful resources:

- [MultiWOZ](https://arxiv.org/abs/1810.00278) dataset
- [Reformer](https://arxiv.org/abs/2001.04451) paper
- [Trax](https://github.com/google/trax) GitHub repository
- [Natural Language Processing Specialization](https://www.deeplearning.ai/natural-language-processing-specialization/)



In [1]:
!pip install -U https://storage.googleapis.com/jax-releases/cuda101/jaxlib-0.1.55-cp36-none-manylinux2010_x86_64.whl
!pip install -q -U jax
!pip install -q -U trax

Collecting jaxlib==0.1.55
  Downloading https://storage.googleapis.com/jax-releases/cuda101/jaxlib-0.1.55-cp36-none-manylinux2010_x86_64.whl (144.8MB)
Installing collected packages: jaxlib
  Found existing installation: jaxlib 0.1.55
    Uninstalling jaxlib-0.1.55:
      Successfully uninstalled jaxlib-0.1.55
Successfully installed jaxlib-0.1.55
  Building wheel for jax (setup.py) ... done

In [2]:
# import packages
import numpy as np
import trax   
from trax import layers as tl
!pip list | grep trax

trax                          1.3.5                


In [3]:
# clone the repository to get the vocabulary files and the split pretrained model
!git clone https://github.com/rickqiu/trax_chatbots.git

Cloning into 'trax_chatbots'...
remote: Enumerating objects: 20, done.
remote: Counting objects: 100% (20/20), done.
remote: Compressing objects: 100% (20/20), done.
remote: Total 37 (delta 10), reused 0 (delta 0), pack-reused 17
Unpacking objects: 100% (37/37), done.
Checking out files: 100% (11/11), done.


In [4]:
%cd trax_chatbots
!tar -xzvf vocabs.tar.gz
!cat model_splits/* > chatbot_model1.pkl.gz
%cd ..

/content/trax_chatbots
vocabs/
vocabs/en_32k.sentencepiece
vocabs/en_32k.sentencepiece.vocab
vocabs/en_32k.subword
vocabs/en_8k.subword
/content
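
The `cat model_splits/* > chatbot_model1.pkl.gz` step above rebuilds one file from its parts. Where a shell is not available, the same idea can be sketched in Python. This is a minimal sketch, assuming the parts are named so that sorted order matches the order in which they were split (which is roughly what shell glob expansion relies on). The helper `join_parts` and the part names in the demonstration are hypothetical, not taken from the repository.

```python
import shutil
import tempfile
from pathlib import Path


def join_parts(parts_dir, output_path):
    """Concatenate every file in parts_dir, in sorted name order, into output_path."""
    parts = sorted(p for p in Path(parts_dir).iterdir() if p.is_file())
    with open(output_path, "wb") as out:
        for part in parts:
            with open(part, "rb") as src:
                # Stream each part so large files are never fully loaded into memory.
                shutil.copyfileobj(src, out)
    return [p.name for p in parts]


# Demonstration on throwaway files; these part names are made up for illustration.
with tempfile.TemporaryDirectory() as tmp:
    parts_dir = Path(tmp) / "model_splits"
    parts_dir.mkdir()
    (parts_dir / "part_aa").write_bytes(b"hello ")
    (parts_dir / "part_ab").write_bytes(b"world")
    joined = Path(tmp) / "joined.bin"
    order = join_parts(parts_dir, joined)
    data = joined.read_bytes()
```

For the notebook itself, the call would look like `join_parts('model_splits', 'chatbot_model1.pkl.gz')`, run from inside the cloned `trax_chatbots` directory.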


In [5]:
# define attention for fast inference
def attention(*args, **kwargs):
    kwargs['predict_mem_len'] = 120
    kwargs['predict_drop_len'] = 120
    return tl.SelfAttention(*args, **kwargs)

# define the model
model = trax.models.reformer.ReformerLM( 
        # set vocab size
        vocab_size=33000,
        # set number of layers
        n_layers=6,
        # set mode
        mode='predict',
        # set attention type
        attention_type=attention
    )

# define an input signature
shape11 = trax.shapes.ShapeDtype((1, 1), dtype=np.int32)

# initialize from file
model.init_from_file('trax_chatbots/chatbot_model1.pkl.gz',
                     weights_only=True, input_signature=shape11)

# save the starting state
STARTING_STATE = model.state

In [6]:
# Reference: Course 4 of https://www.deeplearning.ai/natural-language-processing-specialization/
# vocabulary file directory
VOCAB_DIR = './trax_chatbots/vocabs'

# vocabulary filename
VOCAB_FILE = 'en_32k.subword'

def tokenize(sentence, vocab_file, vocab_dir):
    return list(trax.data.tokenize(iter([sentence]), vocab_file=vocab_file, vocab_dir=vocab_dir))[0]


def detokenize(tokens, vocab_file, vocab_dir):
    return trax.data.detokenize(tokens, vocab_file=vocab_file, vocab_dir=vocab_dir)


def generate_output(ReformerLM, start_sentence, vocab_file, vocab_dir, temperature):
    """
    Args:
        ReformerLM: the pretrained Reformer language model (in 'predict' mode)
        start_sentence (string): starting sentence of the conversation
        vocab_file (string): vocabulary filename
        vocab_dir (string): directory of the vocabulary file
        temperature (float): parameter for sampling ranging from 0.0 to 1.0.
            0.0: same as argmax, always pick the most probable token
            1.0: sampling from the distribution (can sometimes say random things)

    Returns:
        generator: yields the next symbol generated by the model
    """
    
    # Create input tokens using the tokenize function
    input_tokens =  tokenize(start_sentence, vocab_file, vocab_dir)
    
    # Add a batch dimension: convert the array from shape (n,) to (1, n),
    # because the model expects batched input
    input_tokens_with_batch = np.expand_dims(input_tokens, axis=0)
    
    # call the autoregressive_sample_stream function from trax
    output_gen = trax.supervised.decoding.autoregressive_sample_stream( 
        # model
        ReformerLM,
        # inputs will be the tokens with batch dimension
        inputs=input_tokens_with_batch,
        # temperature
        temperature=temperature
    )
    
    return output_gen


def generate_dialogue(ReformerLM, model_state, start_sentence, vocab_file, vocab_dir, max_len, temperature):
    """
    Args:
        ReformerLM: the pretrained Reformer language model (in 'predict' mode)
        model_state: initial state of the model before decoding (as saved from model.state)
        start_sentence (string): starting sentence of the conversation
        vocab_file (string): vocabulary filename
        vocab_dir (string): directory of the vocabulary file
        max_len (int): maximum number of tokens to generate 
        temperature (float): parameter for sampling ranging from 0.0 to 1.0.
            0.0: same as argmax, always pick the most probable token
            1.0: sampling from the distribution (can sometimes say random things)

    Returns:
        None: the dialogue is printed turn by turn as it is generated
    """  
    
    # define the delimiters we used during training
    delimiter_1 = 'Person 1: ' 
    delimiter_2 = 'Person 2: '
    
    # initialize detokenized output
    sentence = ''
    
    # token counter
    counter = 0
    
    # output tokens. we insert a ': ' for formatting
    result = [tokenize(': ', vocab_file=vocab_file, vocab_dir=vocab_dir)]
    
    # reset the model state when starting a new dialogue
    ReformerLM.state = model_state
    
    # calls the output generator implemented earlier
    output = generate_output(ReformerLM, start_sentence, vocab_file=vocab_file, vocab_dir=vocab_dir, temperature=temperature)
    
    # print the starting sentence
    print(start_sentence.split(delimiter_2)[0].strip())
    
    # loop below yields the next tokens until max_len is reached. the if-elif is just for prettifying the output.
    for o in output:
        
        result.append(o)
        
        sentence = detokenize(np.concatenate(result, axis=0), vocab_file=vocab_file, vocab_dir=vocab_dir)
        
        if sentence.endswith(delimiter_1):
            sentence = sentence.split(delimiter_1)[0]
            print(f'{delimiter_2}{sentence}')
            sentence = ''
            result.clear()
        
        elif sentence.endswith(delimiter_2):
            sentence = sentence.split(delimiter_2)[0]
            print(f'{delimiter_1}{sentence}')
            sentence = ''
            result.clear()

        counter += 1
        
        if counter > max_len:
            break    
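
The `temperature` argument described in the docstrings above can be illustrated with a small NumPy sketch. It uses the Gumbel-max trick with the noise scaled by the temperature, so 0.0 reduces to argmax and 1.0 samples from the distribution. This is an illustration under that assumption, not a copy of Trax's sampling code. The function `sample_with_temperature` and the toy probabilities are hypothetical.

```python
import numpy as np


def sample_with_temperature(log_probs, temperature, rng):
    """Pick a token index from log-probabilities using temperature-scaled Gumbel noise."""
    # Uniform noise, clipped away from 0 and 1 so both logarithms stay finite.
    u = rng.uniform(low=1e-6, high=1.0 - 1e-6, size=log_probs.shape)
    gumbel = -np.log(-np.log(u))
    # temperature=0.0 removes the noise (pure argmax);
    # temperature=1.0 draws from the distribution itself.
    return int(np.argmax(log_probs + temperature * gumbel, axis=-1))


rng = np.random.default_rng(0)
log_probs = np.log(np.array([0.7, 0.2, 0.1]))  # toy next-token distribution

greedy = sample_with_temperature(log_probs, 0.0, rng)
draws = [sample_with_temperature(log_probs, 1.0, rng) for _ in range(2000)]
frequencies = np.bincount(draws, minlength=3) / len(draws)
```

The dialogues in this notebook use `temperature=0.2`, which keeps the output close to the most probable token while still allowing some variation.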

In [7]:
sample_sentence = ' Person 1: Are there theatres in town? Person 2: '
generate_dialogue(ReformerLM=model, model_state=STARTING_STATE, start_sentence=sample_sentence, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR, max_len=120, temperature=0.2)

Person 1: Are there theatres in town?
Person 2: : There are 4 theatres in town. Do you have a preference? 
Person 1: Not really, but I would like the most affordable and the cambridge corn exchange. 
Person 2: I would recommend the Mumford Theatre. Would you like more information on it? 
Person 1: Yes, could I get the postcode and phone number please? 
Person 2: The phone number is 08451962320 and the postcode is cb11pt. 
Person 2: Thank bybybybyby


In [8]:
sample_sentence = ' Person 1: Is there a hospital nearby? Person 2: '
generate_dialogue(ReformerLM=model, model_state=STARTING_STATE, start_sentence=sample_sentence, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR, max_len=120, temperature=0.2)

Person 1: Is there a hospital nearby?
Person 2: : Addensbrookes Hospital is located at Hills Rd, Cambridge, Postcode CB20QQ. Do you need a particular department? 
Person 1: No, but I do need the phone number, please. 
Person 2: The phone number is 01223245151. Do you need the main phone number? 
Person 1: No, that's all I need. Thank you! 
Person 2: You're welcome. Have a nice day.
Person 1: Thank you find me one called the gandhi. 


In [9]:
sample_sentence = ' Person 1: Can you book a taxi? Person 2: '
generate_dialogue(ReformerLM=model, model_state=STARTING_STATE, start_sentence=sample_sentence, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR, max_len=120, temperature=0.2)

Person 1: Can you book a taxi?
Person 2: : I sure can. Where would you like to be picked up? 
Person 1: I'm going to be from the hamilton lodge restaurant. 
Person 2: What time would you like to be picked up? 
Person 1: I need to be picked up by 23:45. 
Person 2: Booking completed!
Booked car type	:	grey volkswagen
Contact number	:	07180084574
 
Person 2: Thank bybybybybyby
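
The turn formatting that `generate_dialogue` applies to the dialogues above can be sketched on plain strings instead of model tokens. The sketch accumulates streamed text and, whenever the buffer ends with one speaker's delimiter, attributes the text before it to the other speaker. The function `split_turns` and the sample chunks are hypothetical stand-ins for the detokenized model stream.

```python
DELIMITER_1 = "Person 1: "
DELIMITER_2 = "Person 2: "


def split_turns(chunks):
    """Group streamed text chunks into speaker-labelled turns."""
    turns, buffer = [], ""
    for chunk in chunks:
        buffer += chunk
        if buffer.endswith(DELIMITER_1):
            # Person 1 is about to speak, so the buffered text belongs to Person 2.
            turns.append(DELIMITER_2 + buffer.split(DELIMITER_1)[0])
            buffer = ""
        elif buffer.endswith(DELIMITER_2):
            # Person 2 is about to speak, so the buffered text belongs to Person 1.
            turns.append(DELIMITER_1 + buffer.split(DELIMITER_2)[0])
            buffer = ""
    # Text not yet followed by a delimiter is returned separately.
    return turns, buffer


chunks = ["I sure ", "can. ", "Person ", "1: ", "Thanks! ", "Person 2: ", "Bye"]
turns, leftover = split_turns(chunks)
```

In this sketch, text that is never followed by a delimiter ends up in `leftover` rather than in `turns`.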
