# Getting Started

In this tutorial, you will learn how to:
- use the models in **tatk** to build a dialog agent.
- build a simulator to chat with the agent and evaluate the performance.
- try different module combinations.

Let's get started!

## Environment Setup
Run the command below once to install tatk. Then restart the notebook and skip this command on later runs.

In [2]:
# first install tatk and restart the notebook
! rm -rf tatk && git clone https://github.com/thu-coai/tatk.git && cd tatk && pip install -e .

Cloning into 'tatk'...
remote: Enumerating objects: 37, done.
remote: Counting objects: 100% (37/37), done.
remote: Compressing objects: 100% (32/32), done.
remote: Total 2569 (delta 14), reused 12 (delta 5), pack-reused 2532
Receiving objects: 100% (2569/2569), 19.15 MiB | 31.67 MiB/s, done.
Resolving deltas: 100% (1485/1485), done.
Obtaining file:///content/tatk
Collecting nltk>=3.4 (from tatk==0.0.1.dev20190822)
  Downloading https://files.pythonhosted.org/packages/f6/1d/d925cfb4f324ede997f6d47bea4d9babba51b49e87a767c170b77005889d/nltk-3.4.5.zip (1.5MB)
     |████████████████████████████████| 1.5MB 1.8MB/s 
Collecting tqdm>=4.30 (from tatk==0.0.1.dev20190822)
  Downloading https://files.pythonhosted.org/packages/a5/83/06029af22fe06b8a7be013aeae5e104b3ed26867e5d4ca91408b30aa602e/tqdm-4.34.0-py2.py3-none-any.whl (50kB)
     |████████████████████████████████| 51kB 19.1MB/s 
Collecting checksumdir>=1.1 (from tatk==0.0.1.dev20190822)
  Downloadin

## Build an agent

We use models adapted to the [MultiWOZ dataset](https://www.aclweb.org/anthology/D18-1547) to build our agent. This pipeline agent consists of NLU, DST, Policy, and NLG modules.

First, import some models:

In [0]:
import sys
import os
# Agent
from tatk.dialog_agent import PipelineAgent, BiSession
# common import: tatk.$module.$model.$dataset
from tatk.nlu.svm.multiwoz import SVMNLU
from tatk.nlu.bert.multiwoz import BERTNLU
from tatk.dst.rule.multiwoz import RuleDST
from tatk.policy.rule.multiwoz import Rule
from tatk.nlg.template.multiwoz import TemplateNLG
from tatk.evaluator.multiwoz_eval import MultiWozEvaluator
import random
import numpy as np
from pprint import pprint

Then, create the models and build an agent:

In [2]:
# SVM NLU trained on the user utterances of MultiWOZ
# see README.md under `tatk/tatk/nlu/svm/multiwoz` for more information
sys_nlu = SVMNLU('usr',model_file="https://tatk-data.s3-ap-northeast-1.amazonaws.com/svm_multiwoz_usr.zip")
# simple rule DST
sys_dst = RuleDST()
# rule policy
sys_policy = Rule(character='sys')
# template NLG
sys_nlg = TemplateNLG(is_user=False)
# assemble
sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg)


[<tatk.nlu.svm.Features.nbest object at 0x7f1f07eaa550>]
Load from model_file param


100%|██████████| 15903199/15903199 [00:02<00:00, 6519028.66B/s]


loading saved Classifier
loaded.


That's all! Let's chat with the agent using its `response` function:

In [3]:
sys_agent.response("I want to find a moderate hotel")

'We have 18 such places . Yes , i would suggest a and b guest house. The reference number is 00000000 .'

In [4]:
sys_agent.response("Which type of hotel is it ?")

'It is a guesthouse .'

In [5]:
sys_agent.response("OK , where is its address ?")

'Pool way, whitehill road, off newmarket road is the address.'

In [6]:
sys_agent.response("Thank you !")

'Okay ! glad i could help . enjoy your stay .'

In [7]:
sys_agent.response("Try to find me a Chinese restaurant in south area .")

'Yes , there are 3 available restaurants . I would suggest the good luck chinese food takeaway . The reference number is 00000003 . It is located in the south .'

In [8]:
sys_agent.response("Which kind of food it provides ?")

'They serve chinese food .'

In [9]:
sys_agent.response("Book a table for 5 , this Sunday .")

'Reference number is : 00000003 .'

## Build a Simulator to Chat with the Agent and Evaluate

In many one-to-one task-oriented dialog systems, a simulator is essential for training an RL agent. Our framework does not distinguish between user and system: all speakers are **agents**. The simulator is also an agent, with a specific policy inside for accomplishing the user goal.

We use the Agenda policy for the simulator. This policy requires dialog act input, so we set the DST argument of `PipelineAgent` to `None`; the `PipelineAgent` then passes the dialog acts directly to the policy. Refer to the `PipelineAgent` documentation for more details.
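
To make the pass-through behaviour concrete, here is a minimal sketch of a pipeline in which any stage set to `None` is skipped. `ToyPipelineAgent` and the `toy_*` functions are hypothetical stand-ins written for this tutorial; they are not tatk's `PipelineAgent` implementation, and the real class may differ in detail.

```python
from typing import Any, Callable, Optional

Stage = Optional[Callable[[Any], Any]]


class ToyPipelineAgent:
    """Hypothetical stand-in for a pipeline agent (not tatk's implementation)."""

    def __init__(self, nlu: Stage, dst: Stage, policy: Callable[[Any], Any], nlg: Stage):
        self.nlu, self.dst, self.policy, self.nlg = nlu, dst, policy, nlg

    def response(self, observation: Any) -> Any:
        # An optional stage is applied only when present;
        # a missing stage passes its input through unchanged.
        x = observation
        if self.nlu is not None:
            x = self.nlu(x)
        if self.dst is not None:
            x = self.dst(x)
        x = self.policy(x)
        if self.nlg is not None:
            x = self.nlg(x)
        return x


# Toy stages, all hypothetical.
def toy_nlu(text):
    return {"Hotel-Inform": [["Price", "moderate"]]} if "moderate" in text else {}


def toy_policy(dialog_act):
    if "Hotel-Inform" in dialog_act:
        return {"Hotel-Request": [["Area", "?"]]}
    return {"general-bye": []}


def toy_nlg(dialog_act):
    return " ".join(sorted(dialog_act))


# No DST: the NLU output (a dialog act) goes straight to the policy.
agent = ToyPipelineAgent(toy_nlu, None, toy_policy, toy_nlg)
reply = agent.response("I want a moderate hotel")
print(reply)
```

With both NLU and NLG set to `None`, the same toy agent would consume and produce dialog acts only.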

In [10]:
# BERT NLU trained on the system utterances of MultiWOZ
# see README.md under `tatk/tatk/nlu/bert/multiwoz` for more information
user_nlu = BERTNLU('sys',model_file="https://tatk-data.s3-ap-northeast-1.amazonaws.com/bert_multiwoz_sys.zip")
# no DST: dialog acts go directly to the policy
user_dst = None
# rule policy for the user side
user_policy = Rule(character='usr')
# template NLG
user_nlg = TemplateNLG(is_user=True)
# assemble
user_agent = PipelineAgent(user_nlu, user_dst, user_policy, user_nlg)

load train, size 8434
load val, size 999
load test, size 1000
loaded train, size 56750
loaded val, size 7365
loaded test, size 7372
dialog act num: 34
sentence label num: 64
tag num: 312


100%|██████████| 231508/231508 [00:00<00:00, 3913016.04B/s]


Load from model_file param


100%|██████████| 1076596/1076596 [00:00<00:00, 1125435.95B/s]


Load from /content/tatk/tatk/nlu/bert/multiwoz/output/sys/bestcheckpoint.tar
train step 29900


100%|██████████| 407873900/407873900 [00:06<00:00, 62204605.09B/s]


BERTNLU loaded
Loading goal model is done


Now we have a simulator and an agent. We will use `BiSession`, an existing simple one-to-one conversation controller; you can also define your own Session class for special needs.
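
If you do want to write your own session class, the sketch below shows the general shape of a one-to-one controller. Everything in it (`ToySession`, `ScriptedUser`, `ParrotSystem`, the `done` flag, and the placeholder reward) is hypothetical and only mirrors the four-value return of `next_turn` used later in this tutorial; it is not how `BiSession` is implemented.

```python
class ScriptedUser:
    """Hypothetical simulator: plays a fixed script, then says bye."""

    def __init__(self, script):
        self.script = list(script)
        self.done = False

    def response(self, sys_utterance):
        if self.script:
            return self.script.pop(0)
        self.done = True
        return "bye"


class ParrotSystem:
    """Hypothetical system agent that repeats what it heard."""

    def response(self, user_utterance):
        return "you said: " + user_utterance


class ToySession:
    """Hypothetical one-to-one controller: the user speaks first, then the system."""

    def __init__(self, sys_agent, user_agent):
        self.sys_agent = sys_agent
        self.user_agent = user_agent

    def next_turn(self, last_sys_response):
        user_response = self.user_agent.response(last_sys_response)
        session_over = self.user_agent.done
        sys_response = self.sys_agent.response(user_response)
        reward = 1.0 if session_over else 0.0  # placeholder reward, an assumption
        return sys_response, user_response, session_over, reward


sess = ToySession(ParrotSystem(), ScriptedUser(["find a hotel", "book it"]))
sys_response, turns = "", []
for _ in range(10):
    sys_response, user_response, session_over, reward = sess.next_turn(sys_response)
    turns.append((user_response, sys_response))
    if session_over:
        break
print(turns)
```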

We add `MultiWozEvaluator` to evaluate the performance. It uses the parsed dialog act input and the policy's output dialog act to calculate **inform F1**, **book rate**, and whether the task is a **success**.
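
As a rough illustration of what an inform precision/recall/F1 score measures, the sketch below compares the set of slots the user requested with the set of slots the system informed. This assumes the standard set-based definitions; the exact bookkeeping inside `MultiWozEvaluator` may differ, and the function name and example slots here are made up for illustration.

```python
def precision_recall_f1(requested, informed):
    """Set-based precision, recall and F1 (assumed definition, not tatk's code)."""
    requested, informed = set(requested), set(informed)
    true_positives = len(requested & informed)
    precision = true_positives / len(informed) if informed else 0.0
    recall = true_positives / len(requested) if requested else 0.0
    if precision + recall == 0:
        return precision, recall, 0.0
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


# Hypothetical example: the user asked for three slots, the system informed
# two of them plus one slot nobody asked for.
requested = {("attraction", "phone"), ("attraction", "entrance fee"), ("train", "duration")}
informed = {("attraction", "phone"), ("attraction", "entrance fee"), ("attraction", "address")}
precision, recall, f1 = precision_recall_f1(requested, informed)
print(round(precision, 3), round(recall, 3), round(f1, 3))
```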

In [0]:
evaluator = MultiWozEvaluator()
sess = BiSession(sys_agent=sys_agent, user_agent=user_agent, kb_query=None, evaluator=evaluator)

Let's make these two agents chat! The key is the `next_turn` method of the Session class.

In [12]:
random.seed(20190827)
np.random.seed(20190827)
sys_response = ''
sess.init_session()
print('init goal:')
pprint(sess.evaluator.goal)
print('-'*50)
for i in range(40):
    sys_response, user_response, session_over, reward = sess.next_turn(sys_response)
    print('user:', user_response)
    print('sys:', sys_response)
    print()
    if session_over is True:
        print('task success:', sess.evaluator.task_success())
        print('book rate:', sess.evaluator.book_rate())
        print('inform precision/recall/f1:', sess.evaluator.inform_F1())
        print('-'*50)
        print('final goal:')
        pprint(sess.evaluator.goal)
        print('='*100)
        break

init goal:
{'attraction': {'info': {'area': 'centre', 'type': 'college'},
                'reqt': {'entrance fee': '?', 'phone': '?'}},
 'train': {'book': {'people': '7'},
           'booked': '?',
           'info': {'day': 'wednesday',
                    'departure': 'cambridge',
                    'destination': 'london kings cross',
                    'leaveAt': '12:15'},
           'reqt': {'duration': '?'}}}
--------------------------------------------------
user: Do you have any college attractions. I am also looking for places to go in town . maybe something in the centre .
sys: There are 44 , anything in particular you are looking for ? I 'd recommend the fez club . would you like some information on it ?

user: Can you give me their phone number please ? Yes , what are the entrance fees ?
sys: Its entrance fee is ? . The phone number is 01223300085 .

user: I just need to know how much the entrance fee is .
sys: The park is ? .

user: Okay , are there any colleges in the c

`BiSession` allows two agents to chat at the dialog act level or at the natural language level, as long as each agent's output matches the other agent's input. Example configurations:

| usr input        | usr NLU | usr DST | usr Policy | usr NLG  | sys input        | sys NLU | sys DST | sys Policy | sys NLG  |
| ---------------- | ------- | ------- | ---------- | -------- | ---------------- | ------- | ------- | ---------- | -------- |
| Dialog act       | None    | None    | Rule       | None     | Dialog act       | None    | Rule    | Rule       | None     |
| Natural language | Bert    | None    | Rule       | None     | Dialog act       | None    | Rule    | Rule       | Template |
| Dialog act       | None    | None    | Rule       | Template | Natural language | SVM     | Rule    | Rule       | None     |
| Natural language | Bert    | None    | Rule       | Template | Natural language | SVM     | Rule    | Rule       | Template |
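
The consistency condition behind the table can be written as a small check. In this sketch, an agent reads natural language when it has an NLU module and dialog acts otherwise, and it emits natural language when it has an NLG module and dialog acts otherwise; only the NLU and NLG slots matter for this check. `AgentConfig` and `is_consistent` are hypothetical helpers, not part of tatk.

```python
from dataclasses import dataclass
from typing import Optional

NL, DA = "Natural language", "Dialog act"


@dataclass
class AgentConfig:
    """Hypothetical description of one agent; None means the module is absent."""
    nlu: Optional[str]
    nlg: Optional[str]

    @property
    def input_level(self) -> str:
        return NL if self.nlu is not None else DA

    @property
    def output_level(self) -> str:
        return NL if self.nlg is not None else DA


def is_consistent(usr: AgentConfig, sys_: AgentConfig) -> bool:
    # Each side must emit what the other side expects to read.
    return (usr.output_level == sys_.input_level
            and sys_.output_level == usr.input_level)


# Second row of the table: user reads text and emits dialog acts,
# system reads dialog acts and emits text.
second_row = is_consistent(AgentConfig(nlu="Bert", nlg=None),
                           AgentConfig(nlu=None, nlg="Template"))
# A mismatch: the user emits text, but the system has no NLU to read it.
mismatch = is_consistent(AgentConfig(nlu="Bert", nlg="Template"),
                         AgentConfig(nlu=None, nlg="Template"))
print(second_row, mismatch)  # True False
```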


We have tried the last configuration before. Let's try the second configuration.


In [0]:
user_agent = PipelineAgent(user_nlu, user_dst, user_policy, None)
sys_agent = PipelineAgent(None, sys_dst, sys_policy, sys_nlg)
evaluator = MultiWozEvaluator()
sess = BiSession(sys_agent=sys_agent, user_agent=user_agent, kb_query=None, evaluator=evaluator)

In [14]:
random.seed(20190827)
np.random.seed(20190827)
sys_response = ''
sess.init_session()
print('init goal:')
pprint(sess.evaluator.goal)
print('-'*50)
for i in range(40):
    sys_response, user_response, session_over, reward = sess.next_turn(sys_response)
    print('user:', user_response)
    print('sys:', sys_response)
    print()
    if session_over is True:
        print('task success:', sess.evaluator.task_success())
        print('book rate:', sess.evaluator.book_rate())
        print('inform precision/recall/f1:', sess.evaluator.inform_F1())
        print('-'*50)
        print('final goal:')
        pprint(sess.evaluator.goal)
        print('='*100)
        break

init goal:
{'attraction': {'info': {'area': 'centre', 'type': 'college'},
                'reqt': {'entrance fee': '?', 'phone': '?'}},
 'train': {'book': {'people': '7'},
           'booked': '?',
           'info': {'day': 'wednesday',
                    'departure': 'cambridge',
                    'destination': 'london kings cross',
                    'leaveAt': '12:15'},
           'reqt': {'duration': '?'}}}
--------------------------------------------------
user: {'Attraction-Inform': [['Type', 'college'], ['Area', 'centre']]}
sys: There are 13 , anything in particular you are looking for ? I 'd recommend emmanuel college . would you like some information on it ?

user: {'Attraction-Request': [['Phone', '?'], ['Fee', '?']]}
sys: The phone number is 01223334900 . Their entrance fee is free by our system currently .

user: {'Train-Inform': [['Depart', 'cambridge'], ['Day', 'wednesday'], ['Leave', '12:15'], ['Dest', 'london kings cross']]}
sys: Would you like me to book you on t

After removing the user NLG and the system NLU, the user's dialog acts reach the system directly, and the conversation is more efficient.

## Try Different Module Combinations

Pipeline agent modules can be combined flexibly. We support joint models such as [MDBT](https://www.aclweb.org/anthology/P18-2069) (NLU+DST) and [MDRG](https://pdfs.semanticscholar.org/47d0/1eb59cd37d16201fcae964bd1d2b49cfb55e.pdf) (Policy+NLG), as long as their input and output match the previous and next modules. We also support end-to-end models such as [Sequicity](https://www.comp.nus.edu.sg/~kanmy/papers/acl18-sequicity.pdf).

### MDBT
- NLU: None
- DST: MDBT
- Policy: Rule
- NLG: TemplateNLG

In [15]:
from tatk.dst.mdbt.multiwoz.mdbt import MultiWozMDBT
# MDBT is a joint NLU+DST model, so no separate NLU is needed
nlu = None
dst = MultiWozMDBT()
# rule policy
policy = Rule()
# template NLG
nlg = TemplateNLG(is_user=False)
# assemble
sys_agent = PipelineAgent(nlu, dst, policy, nlg)

W0822 07:47:49.780210 139773340759936 deprecation_wrapper.py:119] From /content/tatk/tatk/dst/mdbt/mdbt_util.py:62: The name tf.nn.rnn_cell.RNNCell is deprecated. Please use tf.compat.v1.nn.rnn_cell.RNNCell instead.

100%|██████████| 262980414/262980414 [00:20<00:00, 12691259.66B/s]


Configuring MDBT model...


W0822 07:48:57.885578 139773340759936 deprecation_wrapper.py:119] From /content/tatk/tatk/dst/mdbt/mdbt_util.py:223: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.

W0822 07:48:57.894057 139773340759936 deprecation_wrapper.py:119] From /content/tatk/tatk/dst/mdbt/mdbt_util.py:155: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.

W0822 07:48:57.903185 139773340759936 deprecation.py:323] From /content/tatk/tatk/dst/mdbt/mdbt_util.py:162: LSTMCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This class is equivalent as tf.keras.layers.LSTMCell, and will be replaced by that in Tensorflow 2.0.
W0822 07:48:57.905097 139773340759936 deprecation.py:323] From /content/tatk/tatk/dst/mdbt/mdbt_util.py:173: bidirectional_dynamic_rnn (from tensorflow.python.ops.rnn) is deprecated and will be removed in a future version.
Instructions 

Loading trained MDBT model from  /content/tatk/tatk/dst/mdbt/multiwoz/configs/models/model-1


### MDRG

- NLU: SVM
- DST: Rule
- Policy: MDRG
- NLG: None

In [16]:
from tatk.policy.mdrg.multiwoz.policy import MDRGWordPolicy
# svm nlu trained on usr sentence of multiwoz
# go to README.md under `tatk/tatk/nlu/svm/multiwoz` for more information 
nlu = SVMNLU('usr',model_file="https://tatk-data.s3-ap-northeast-1.amazonaws.com/svm_multiwoz_usr.zip")
# simple rule DST
dst = RuleDST()
# MDRG word-level policy (joint Policy+NLG)
policy = MDRGWordPolicy()
# no separate NLG is needed
nlg = None
# assemble
sys_agent = PipelineAgent(nlu, dst, policy, nlg)

Downloading from:  https://tatk-data.s3-ap-northeast-1.amazonaws.com/mdrg_model.zip


100%|██████████| 21577107/21577107 [00:02<00:00, 8956227.09B/s]


Extracting...
Downloading from:  https://tatk-data.s3-ap-northeast-1.amazonaws.com/mdrg_data.zip


100%|██████████| 47104409/47104409 [00:04<00:00, 9933537.88B/s] 


Extracting...
Downloading from:  https://tatk-data.s3-ap-northeast-1.amazonaws.com/mdrg_db.zip


100%|██████████| 183081/183081 [00:00<00:00, 379999.73B/s]


Extracting...
[<tatk.nlu.svm.Features.nbest object at 0x7f1ebccd5048>]
loading saved Classifier
loaded.
Model has 383900  parameters.
Loading parameters of iter 1 


### Sequicity

Sequicity inherits directly from the `Agent` interface.
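
The idea of an end-to-end model plugging in at the agent level, rather than as one pipeline stage, can be sketched with an abstract base class. `ToyAgent` and `KeywordEndToEnd` below are hypothetical; the method set of tatk's real `Agent` interface is an assumption here, apart from `response`, which is used throughout this tutorial.

```python
from abc import ABC, abstractmethod


class ToyAgent(ABC):
    """Hypothetical agent interface: anything that can answer an observation."""

    @abstractmethod
    def response(self, observation):
        """Return the agent's reply to the other speaker's last output."""

    @abstractmethod
    def init_session(self):
        """Reset per-dialog state (assumed method, for illustration)."""


class KeywordEndToEnd(ToyAgent):
    """Hypothetical end-to-end agent: maps text to text with no pipeline stages."""

    def __init__(self):
        self.history = []

    def init_session(self):
        self.history = []

    def response(self, observation):
        self.history.append(observation)
        if "restaurant" in observation.lower():
            return "What kind of food would you like ?"
        return "How can I help ?"


agent = KeywordEndToEnd()
agent.init_session()
reply = agent.response("Find me a Chinese restaurant")
print(reply)
```

Because such an agent exposes the same `response` method as a pipeline agent, a session controller can treat the two interchangeably.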

In [17]:
from tatk.e2e.sequicity.multiwoz import Sequicity
sequicity = Sequicity()
sys_agent = sequicity

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.
I0822 07:50:05.561878 139773340759936 connectionpool.py:813] Starting new HTTPS connection (1): tatk-data.s3-ap-northeast-1.amazonaws.com:443


down load data from https://tatk-data.s3-ap-northeast-1.amazonaws.com/sequicity_multiwoz_data.zip


I0822 07:50:06.239478 139773340759936 connectionpool.py:393] https://tatk-data.s3-ap-northeast-1.amazonaws.com:443 "HEAD /sequicity_multiwoz_data.zip HTTP/1.1" 200 0
I0822 07:50:06.246556 139773340759936 allennlp_file_utils.py:284] https://tatk-data.s3-ap-northeast-1.amazonaws.com/sequicity_multiwoz_data.zip not found in cache, downloading to /tmp/tmpy9ivjabr
I0822 07:50:06.249362 139773340759936 connectionpool.py:813] Starting new HTTPS connection (1): tatk-data.s3-ap-northeast-1.amazonaws.com:443
I0822 07:50:06.981320 139773340759936 connectionpool.py:393] https://tatk-data.s3-ap-northeast-1.amazonaws.com:443 "GET /sequicity_multiwoz_data.zip HTTP/1.1" 200 74337337
100%|██████████| 74337337/74337337 [00:06<00:00, 11831252.18B/s]
I0822 07:50:13.274135 139773340759936 allennlp_file_utils.py:297] copying /tmp/tmpy9ivjabr to cache at /root/.tatk/cache/d0ee72a6516ccdffeac0b9d62526d5131c1a94c692c081268d4e4768204cf330.70856ee662e89468b1e185825920d6525c1c98cc1ce3e8a9d1da2ffca852b12c
I0822 07

unzip to /content/tatk/tatk/e2e/sequicity/multiwoz/


I0822 07:50:15.328531 139773340759936 connectionpool.py:813] Starting new HTTPS connection (1): tatk-data.s3-ap-northeast-1.amazonaws.com:443


Load from model_file param
down load data from https://tatk-data.s3-ap-northeast-1.amazonaws.com/sequicity_multiwoz.zip


I0822 07:50:16.006468 139773340759936 connectionpool.py:393] https://tatk-data.s3-ap-northeast-1.amazonaws.com:443 "HEAD /sequicity_multiwoz.zip HTTP/1.1" 200 0
I0822 07:50:16.014307 139773340759936 allennlp_file_utils.py:284] https://tatk-data.s3-ap-northeast-1.amazonaws.com/sequicity_multiwoz.zip not found in cache, downloading to /tmp/tmpt_sqtc0j
I0822 07:50:16.017261 139773340759936 connectionpool.py:813] Starting new HTTPS connection (1): tatk-data.s3-ap-northeast-1.amazonaws.com:443
I0822 07:50:16.718674 139773340759936 connectionpool.py:393] https://tatk-data.s3-ap-northeast-1.amazonaws.com:443 "GET /sequicity_multiwoz.zip HTTP/1.1" 200 4864058
100%|██████████| 4864058/4864058 [00:01<00:00, 3359392.93B/s]
I0822 07:50:18.178032 139773340759936 allennlp_file_utils.py:297] copying /tmp/tmpt_sqtc0j to cache at /root/.tatk/cache/b3c77908764e923c0c55e5c271c33fc92a3f8c2987a226c72cca8d2760c364b0.7f13568232dd1506e94bd76727cac23d303abeb923a2cc7e8a0a3bb1776fb086
I0822 07:50:18.191380 13977

unzip to /content/tatk/tatk/e2e/sequicity/multiwoz/output
total trainable params: 1717950


  "num_layers={}".format(dropout, num_layers))
  torch.nn.init.orthogonal(hh[i:i+gru.hidden_size],gain=1)
