<a href="https://colab.research.google.com/github/rickqiu/AMLR/blob/master/Chatbots.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Chatbots with Pretrained Reformer Model
@author: Rick Qiu

This notebook demonstrates conversations between two chatbots, generated by a pretrained Reformer model. A conversation can cover one or more of the following domains: attraction, hospital, hotel, police, taxi, and train.

Credit goes to Lukasz Kaiser and the Trax open-source team. Useful resources:

- [MultiWOZ](https://arxiv.org/abs/1810.00278) dataset
- [Reformer](https://arxiv.org/abs/2001.04451) paper
- [Trax](https://github.com/google/trax) GitHub repository
- [Natural Language Processing Specialization](https://www.deeplearning.ai/natural-language-processing-specialization/)



## 1. Setup

In [None]:
!pip install -q -U https://storage.googleapis.com/jax-releases/cuda101/jaxlib-0.1.55-cp36-none-manylinux2010_x86_64.whl
!pip install -q -U jax
!pip install -q -U trax


In [None]:
# import packages
import numpy as np
import trax   
from trax import layers as tl
!pip list | grep trax

trax                          1.3.5                


In [None]:
# clone to get vocabs and the pretrained model parts from repository
!git clone https://github.com/rickqiu/trax_chatbots.git

Cloning into 'trax_chatbots'...
remote: Enumerating objects: 32, done.
remote: Counting objects: 100% (32/32), done.
remote: Compressing objects: 100% (32/32), done.
remote: Total 49 (delta 18), reused 0 (delta 0), pack-reused 17
Unpacking objects: 100% (49/49), done.
Checking out files: 100% (11/11), done.


In [None]:
%cd trax_chatbots
!tar -xzvf vocabs.tar.gz
!cat model_splits/* > chatbot_model1.pkl.gz
%cd ..

/content/trax_chatbots
vocabs/
vocabs/en_32k.sentencepiece
vocabs/en_32k.sentencepiece.vocab
vocabs/en_32k.subword
vocabs/en_8k.subword
/content
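
The `cat model_splits/* > chatbot_model1.pkl.gz` step reassembles the checkpoint from its parts. The shell expands `*` in sorted name order, which keeps the bytes in sequence. A minimal Python equivalent is sketched below. The helper name `join_parts` is hypothetical, and the sketch assumes the part filenames sort in the order in which the file was split.

```python
import glob
import os
import shutil


def join_parts(parts_dir, output_path):
    """Concatenate every file in parts_dir, in sorted name order, into output_path."""
    part_paths = sorted(glob.glob(os.path.join(parts_dir, '*')))
    with open(output_path, 'wb') as out_file:
        for part_path in part_paths:
            # stream each part so large checkpoints are not loaded into memory at once
            with open(part_path, 'rb') as part_file:
                shutil.copyfileobj(part_file, out_file)
    return part_paths
```

Calling `join_parts('trax_chatbots/model_splits', 'trax_chatbots/chatbot_model1.pkl.gz')` should produce the same file as the shell command, provided the output is written outside the parts directory.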


## 2. Modelling

In [None]:
# define attention for fast inference
def attention(*args, **kwargs):
    kwargs['predict_mem_len'] = 150
    kwargs['predict_drop_len'] = 120
    return tl.SelfAttention(*args, **kwargs)

# define the model
model = trax.models.reformer.ReformerLM( 
        vocab_size=33000,
        n_layers=6,
        mode='predict',
        attention_type=attention
    )
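
The two keyword arguments set in `attention` bound how much context the attention layers keep during step-by-step decoding. Going by the Trax documentation for `SelfAttention`, `predict_mem_len` is the number of input positions remembered in a cache, and `predict_drop_len` is how many of them are dropped once the cache fills up. The sketch below is a simplified plain-Python simulation of that bookkeeping, not the Trax implementation. The function `simulate_cache` is hypothetical, and dropping the oldest positions first is an assumption.

```python
def simulate_cache(n_tokens, mem_len=150, drop_len=120):
    """Return the cached positions after feeding n_tokens one at a time.

    Simplified model (an assumption, not Trax code): when the cache is
    full, the oldest drop_len positions are forgotten before the new
    position is stored.
    """
    cache = []
    for position in range(n_tokens):
        if len(cache) >= mem_len:
            cache = cache[drop_len:]
        cache.append(position)
    return cache
```

Under this model the cache never holds more than 150 positions with the values used here, so a long dialogue gradually loses its earliest turns.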

In [None]:
# display the model
print(str(model))

Serial[
  ShiftRight(1)
  Embedding_33000_512
  Dropout
  PositionalEncoding
  Dup_out2
  ReversibleSerial_in2_out2[
    ReversibleHalfResidual_in2_out2[
      Serial[
        LayerNorm
      ]
      SelfAttention
    ]
    ReversibleSwap_in2_out2
    ReversibleHalfResidual_in2_out2[
      Serial[
        LayerNorm
        Dense_2048
        Dropout
        FastGelu
        Dense_512
        Dropout
      ]
    ]
    ReversibleSwap_in2_out2
    ReversibleHalfResidual_in2_out2[
      Serial[
        LayerNorm
      ]
      SelfAttention
    ]
    ReversibleSwap_in2_out2
    ReversibleHalfResidual_in2_out2[
      Serial[
        LayerNorm
        Dense_2048
        Dropout
        FastGelu
        Dense_512
        Dropout
      ]
    ]
    ReversibleSwap_in2_out2
    ReversibleHalfResidual_in2_out2[
      Serial[
        LayerNorm
      ]
      SelfAttention
    ]
    ReversibleSwap_in2_out2
    ReversibleHalfResidual_in2_out2[
      Serial[
        LayerNorm
        Dense_2048
        

## 3. Prediction

In [None]:
# define an input signature
shape11 = trax.shapes.ShapeDtype((1, 1), dtype=np.int32)

# initialize from file
model.init_from_file('trax_chatbots/chatbot_model1.pkl.gz',
                     weights_only=True, input_signature=shape11)

# save the starting state
STARTING_STATE = model.state

In [None]:
# Reference: Course 4 of https://www.deeplearning.ai/natural-language-processing-specialization/
# vocabulary file directory
VOCAB_DIR = './trax_chatbots/vocabs'

# vocabulary filename
VOCAB_FILE = 'en_32k.subword'

def tokenize(sentence, vocab_file, vocab_dir):
    return list(trax.data.tokenize(iter([sentence]), vocab_file=vocab_file, vocab_dir=vocab_dir))[0]


def detokenize(tokens, vocab_file, vocab_dir):
    return trax.data.detokenize(tokens, vocab_file=vocab_file, vocab_dir=vocab_dir)


def generate_output(model, start_sentence, vocab_file, vocab_dir, temperature):
    """
    Args:
        model:  the pretrained Reformer language model
        start_sentence (string): starting sentence of the conversation
        vocab_file (string): vocabulary filename
        vocab_dir (string): directory of the vocabulary file
        temperature (float): parameter for sampling ranging from 0.0 to 1.0.
            0.0: same as argmax, always pick the most probable token
            1.0: sampling from the distribution (can sometimes say random things)

    Returns:
        generator: yields the next symbol generated by the model
    """
    
    input_tokens =  tokenize(start_sentence, vocab_file, vocab_dir)
    
    # add batch dimension to array
    input_tokens_with_batch = np.expand_dims(input_tokens, axis=0)
    
    # call the autoregressive_sample_stream function from trax
    output_gen = trax.supervised.decoding.autoregressive_sample_stream( 
        model,
        inputs=input_tokens_with_batch,
        temperature=temperature
    )
    
    return output_gen


def generate_dialogue(model, model_state, start_sentence, vocab_file, vocab_dir, max_len, temperature):
    """
    Args:
        model:  the pretrained Reformer language model
        model_state: initial state of the model (as read from model.state) to restore before decoding
        start_sentence (string): starting sentence of the conversation
        vocab_file (string): vocabulary filename
        vocab_dir (string): directory of the vocabulary file
        max_len (int): maximum number of tokens to generate 
        temperature (float): parameter for sampling ranging from 0.0 to 1.0.
            0.0: same as argmax, always pick the most probable token
            1.0: sampling from the distribution (can sometimes say random things)

    Returns:
        None: the dialogue is printed turn by turn
    """  
    
    # define the delimiters we used during training
    delimiter_1 = 'Person 1: ' 
    delimiter_2 = 'Person 2: '
    
    # initialize detokenized output
    sentence = ''
    
    # token counter
    counter = 0

    # turns
    turns = 0
    
    # output tokens
    result = []
    
    # reset the model state when starting a new dialogue
    model.state = model_state
    
    # call the output generator defined earlier, using the vocabulary passed to this function
    output = generate_output(model, start_sentence, vocab_file=vocab_file, vocab_dir=vocab_dir, temperature=temperature)
    
    # print the starting sentence
    print(start_sentence.split(delimiter_2)[0].strip())

    # the starting sentence counts as the first turn
    turns = 1
    
    # consume tokens from the generator; the if-elif prints each finished turn with its speaker label
    for o in output:
        
        result.append(o)
        
        sentence = detokenize(np.concatenate(result, axis=0), vocab_file=vocab_file, vocab_dir=vocab_dir)
        
        if sentence.endswith(delimiter_1):
            sentence = sentence.split(delimiter_1)[0]
            print(f'{delimiter_2}{sentence}')
            sentence = ''
            result.clear()
            turns += 1
        
        elif sentence.endswith(delimiter_2):
            sentence = sentence.split(delimiter_2)[0]
            print(f'{delimiter_1}{sentence}')
            sentence = ''
            result.clear()
            turns +=1

        counter += 1

        # stop after more than max_len tokens, once the last completed turn is a Person 2 reply
        if counter > max_len and turns % 2 == 0:
            break
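
The `temperature` argument described in the docstrings can be illustrated with a small NumPy sketch. It uses the Gumbel-max trick with the noise scaled by the temperature, so 0.0 reduces to argmax and 1.0 samples from the model's distribution. This shows the idea under that assumption and is not necessarily what `autoregressive_sample_stream` does internally. The name `sample_with_temperature` and the toy probabilities are hypothetical.

```python
import numpy as np


def sample_with_temperature(log_probs, temperature, rng):
    """Pick a token id from log-probabilities using temperature-scaled Gumbel noise."""
    uniform = rng.uniform(low=1e-6, high=1.0 - 1e-6, size=log_probs.shape)
    gumbel_noise = -np.log(-np.log(uniform))
    return int(np.argmax(log_probs + temperature * gumbel_noise, axis=-1))


rng = np.random.default_rng(0)
log_probs = np.log(np.array([0.1, 0.7, 0.2]))  # toy distribution over three tokens

# temperature 0.0 ignores the noise, so the most probable token is always chosen
greedy = [sample_with_temperature(log_probs, 0.0, rng) for _ in range(5)]

# temperature 1.0 samples in proportion to the probabilities
sampled = [sample_with_temperature(log_probs, 1.0, rng) for _ in range(2000)]
```

A low value such as the 0.2 used in this notebook sits close to the greedy end, which keeps the replies coherent while still allowing some variation.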

## 4. Outputting results

In [None]:
sample_sentence = ' Person 1: Are there theatres in town? Person 2: '
generate_dialogue(model, model_state=STARTING_STATE, start_sentence=sample_sentence, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR, max_len=120, temperature=0.2)

Person 1: Are there theatres in town?
Person 2: There are 4 theatres in town. Two are in the centre of town and one in the south. Which would you prefer? 
Person 1: I would prefer the south please. 
Person 2: There are 4 theatres in the south. The Junction, wandlebury country park, and wandlebury country park. 
Person 1: That sounds good. Can I get the phone number, postcode, and address? 
Person 2: The address is wandlebury ring, gog magog hills, babraham, their phone number is 01223243830. 
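
The turn-by-turn layout of the transcript above comes from watching the decoded text for the `Person 1: ` and `Person 2: ` delimiters: when the text ends with one speaker's delimiter, everything before it is the other speaker's finished turn. Below is a self-contained sketch of that logic, with a hand-written list of text pieces standing in for model output. The function `split_turns` and the sample pieces are hypothetical.

```python
def split_turns(pieces):
    """Group a stream of text pieces into (speaker, utterance) turns."""
    delimiter_1 = 'Person 1: '
    delimiter_2 = 'Person 2: '
    buffer = ''
    turns = []
    for piece in pieces:
        buffer += piece
        if buffer.endswith(delimiter_1):
            # Person 1 is about to speak, so the buffered text was Person 2's turn
            turns.append(('Person 2', buffer[:-len(delimiter_1)].strip()))
            buffer = ''
        elif buffer.endswith(delimiter_2):
            # Person 2 is about to speak, so the buffered text was Person 1's turn
            turns.append(('Person 1', buffer[:-len(delimiter_2)].strip()))
            buffer = ''
    return turns


# a delimiter may arrive split across several pieces, as it can with subword tokens
pieces = ['I sure ', 'can. ', 'Person', ' 1', ': ', 'Thank ', 'you. ', 'Person 2: ']
turns = split_turns(pieces)
```

Text that has not yet been followed by a delimiter stays in the buffer, which is why an unfinished final utterance never appears in the printed dialogue.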


In [None]:
sample_sentence = ' Person 1: Is there a hospital nearby? Person 2: '
generate_dialogue(model, model_state=STARTING_STATE, start_sentence=sample_sentence, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR, max_len=120, temperature=0.2)

Person 1: Is there a hospital nearby?
Person 2: Addensbrookes Hospital is located at Hills Rd, Cambridge, postcode CB20QQ. Do you need a particular department? 
Person 1: No, but I do need the phone number, please. 
Person 2: Their main phone number is 01223245151. Is there anything else I can help you with? 
Person 1: No, that's all. Thanks! 
Person 2: Thank you for using our services.Goodbye.
Person 1: Thank you. Goodbye. 
Person 2: Thank you for using our services.Goodbye.


In [None]:
sample_sentence = ' Person 1: Can you book a taxi? Person 2: '
generate_dialogue(model, model_state=STARTING_STATE, start_sentence=sample_sentence, vocab_file=VOCAB_FILE, vocab_dir=VOCAB_DIR, max_len=120, temperature=0.2)

Person 1: Can you book a taxi?
Person 2: I sure can. When would you like to leave? 
Person 1: I need to leave after 16:45. 
Person 2: I'm sorry, I have no listings for that time. Would you like to try a different time? 
Person 1: No, I need to leave after 17:00. 
Person 2: I have booked you a grey bmw with contact number 07394368786. 
Person 1: Thank you for your help. 
Person 2: You're welcome. Have a great day.Bye.
