In [48]:
import os
import random

from parlai.core.worlds import create_task
from parlai.core.agents import create_agent
from parlai.utils.world_logging import WorldLogger
from parlai.utils.misc import TimeLogger, nice_report
from parlai.core.metrics import (
    aggregate_named_reports,
    aggregate_unnamed_reports,
    Metric,
)
import parlai.utils.logging as logging
from parlai.utils.distributed import (
    is_primary_worker,
    all_gather_list,
    is_distributed,
    get_rank,
)

from eval_model import prepare_tb_logger, get_n_parleys, _save_eval_stats

### Seq2Seq Model

In [60]:
# Index into `world.agents` of the classifier agent whose AUC statistics the
# report cell reads (only used when the world exposes more than one agent).
CLASSIFIER_AGENT = 1

SEED = 42
# NOTE(review): only Python's `random` is seeded here; torch/numpy RNGs are not.
random.seed(SEED)

# Paths default to the original author's machine. Set PARLAI_HOME (or the more
# specific PARLAI_DATAPATH / PARLAI_MODEL_FILE) to run this notebook elsewhere
# without editing it; with none set, the paths are exactly the original ones.
PARLAI_HOME = os.environ.get('PARLAI_HOME', '/home/dsi/yufli/ParlAI')
datapath = os.environ.get('PARLAI_DATAPATH', os.path.join(PARLAI_HOME, 'data'))
ckpt = os.environ.get(
    'PARLAI_MODEL_FILE',
    os.path.join(PARLAI_HOME, 'results', 'personachat', 'seq2seq'),
)
opt = {
    'model': 'seq2seq',
    'task': 'personachat', # dialogue dataset (blended_skill_talk, convai2, etc.)
    'datapath': datapath,
    'datatype': 'test', # dataset type (train, train:evalmode, train:ordered, train:stream, train:stream:ordered, valid, test, etc.) 
    'model_file': ckpt, # checkpoint path (opt, model, dict)
    'batchsize': 32,
    'metrics': 'ppl,rouge,bleu', # evaluation metrics
    'tensorboard_log': False,
    'world_logs': False,
    'num_examples': 1000, # number of examples to evaluate (<= 0 means all)
    'display_examples': True, # print every evaluated batch
}
agent = create_agent(opt)
print(agent.model)

22:09:52 | Using CUDA
22:09:52 | loading dictionary from /home/dsi/yufli/ParlAI/results/personachat/seq2seq.dict
22:09:52 | num words = 18745
22:09:52 | Total parameters: 7,745,209 (7,745,209 trainable)
22:09:52 | Loading existing model params from /home/dsi/yufli/ParlAI/results/personachat/seq2seq
Seq2seq(
  (decoder): RNNDecoder(
    (dropout): Dropout(p=0.1, inplace=False)
    (lt): Embedding(18745, 128, padding_idx=0)
    (rnn): LSTM(128, 128, num_layers=2, batch_first=True, dropout=0.1)
    (attention): AttentionLayer()
  )
  (encoder): RNNEncoder(
    (dropout): Dropout(p=0.1, inplace=False)
    (input_dropout): UnknownDropout()
    (lt): Embedding(18745, 128, padding_idx=0)
    (rnn): LSTM(128, 128, num_layers=2, batch_first=True, dropout=0.1)
  )
  (output): OutputLayer(
    (dropout): Dropout(p=0.1, inplace=False)
    (o2e): Identity()
  )
)


In [36]:
# Optional TensorBoard logging, controlled by opt['tensorboard_log'].
# `prepare_tb_logger` yields the logger (presumably falsy when disabled) plus a
# setting label; the step count `n_parleys` is only needed when a logger exists.
tb_logger, setting = prepare_tb_logger(opt)
if tb_logger:
    n_parleys = get_n_parleys(opt)

In [66]:
def get_task_world_logs(task, world_logs, is_multitask=False):
    """Return the world-log path to use for ``task``.

    Single-task runs use ``world_logs`` unchanged. Multi-task runs insert
    ``_<task>`` before the file extension so that each task writes its own file,
    e.g. ``logs.jsonl`` -> ``logs_convai2.jsonl``.
    """
    if is_multitask:
        stem, ext = os.path.splitext(world_logs)
        return f'{stem}_{task}{ext}'
    return world_logs


# Single-task evaluation: the task name comes straight from `opt`.
task = opt['task']
logging.info(f'Evaluating task {task} using datatype {opt.get("datatype")}.')

# Per-task options live in a shallow copy so the edits below never touch `opt`.
task_opt = opt.copy()
task_opt['task'] = task

# With several comma-separated tasks, each task gets its own world-log file.
if opt['world_logs']:
    task_opt['world_logs'] = get_task_world_logs(
        task,
        task_opt['world_logs'],
        is_multitask=(len(opt['task'].split(',')) > 1),
    )

if task_opt['world_logs']:
    world_logger = WorldLogger(task_opt)
else:
    world_logger = None
world = create_task(task_opt, agent)  # create world(s) for the task around our agent

# Periodic progress logging: a non-positive interval means "never log".
log_every_n_secs = opt.get('log_every_n_secs', -1)
log_every_n_secs = float('inf') if log_every_n_secs <= 0 else log_every_n_secs
log_time = TimeLogger()

# Example budget: a non-positive num_examples means "evaluate everything".
if opt['num_examples'] > 0:
    max_cnt = opt['num_examples']
else:
    max_cnt = float('inf')
logging.info(f'Total number of examples to evaluate: {max_cnt}.')
cnt = 0
total_cnt = world.num_examples()
logging.info(f"Sum of each subworld's number of examples: {total_cnt}.")

if is_distributed():
    logging.warning('Progress bar is approximate in distributed mode.')

22:35:24 | Evaluating task personachat using datatype test.
22:35:24 | creating task(s): personachat
22:35:24 | loading fbdialog data: /home/dsi/yufli/ParlAI/data/Persona-Chat/personachat/test_self_original.txt
22:35:24 | Total number of examples to evaluate: 1000.
22:35:24 | Sum of each subworld's number of examples: 7512.


In [61]:
# Run evaluation, printing each batch's dialogs when opt['display_examples'] is set.

# Reset the counter here (not only in the setup cell) so this cell can be
# re-run after the report cell calls `world.reset()`. Otherwise `cnt` is
# already >= max_cnt and the loop silently evaluates nothing.
cnt = 0
while not world.epoch_done() and cnt < max_cnt:
    # Counts whole batches, so up to batchsize - 1 examples beyond max_cnt may
    # be evaluated (here: 1000 requested -> 1024 evaluated with batchsize 32).
    cnt += opt.get('batchsize', 1)
    world.parley()
    if world_logger is not None:
        world_logger.log(world)
    if opt['display_examples']:
        print(world.display() + '\n~~')
    if log_time.time() > log_every_n_secs:
        report = world.report()
        text, report = log_time.log(
            report.get('exs', 0), min(max_cnt, total_cnt), report
        )
        logging.info(text)

[--batchsize 32--]
[batch world 0:]
[0;34m[personachat]:[0;0m [1mi am great . i just got back from the club .[0;0m
[0;34m[eval_labels]:[0;0m [1;94mthis is my favorite time of the year season wise[0;0m
   [0;34m[Seq2Seq]:[0;0m [1;94mi am a good at the beach . i am a good at the night .[0;0m
[batch world 1:]
[0;34m[personachat]:[0;0m [1mthat is a great thing honor your dad with your presence[0;0m
[0;34m[eval_labels]:[0;0m [1;94msure , i pick him up for church every sunday with my ford pickup[0;0m
   [0;34m[Seq2Seq]:[0;0m [1;94mi love to go to the beach . i am a good at the beach .[0;0m
[batch world 2:]
[0;34m[personachat]:[0;0m [1mnice , i just got a advertising job myself . do you like your job ?[0;0m
[0;34m[eval_labels]:[0;0m [1;94mnice . yes i do . caring for people is the joy of my life .[0;0m
   [0;34m[Seq2Seq]:[0;0m [1;94mi do not have any pets . i am a good at the night .[0;0m
[batch world 3:]
[0;34m[personachat]:[0;0m [1mi am going for a hor

In [64]:
# Report metrics: optionally dump world logs, aggregate the report across
# workers, attach classifier AUCs when present, reset the world, then print
# (and, if enabled, push to TensorBoard) the final report.

if world_logger is not None:
    # Dump world acts to file
    world_logger.reset()  # add final acts to logs
    if is_distributed():
        # Suffix the file with the worker rank so each rank writes its own file.
        rank = get_rank()
        base_outfile, extension = os.path.splitext(task_opt['world_logs'])
        outfile = base_outfile + f'_{rank}' + extension
    else:
        outfile = task_opt['world_logs']
    # `opt` defined above has no 'save_format' key; fall back to ParlAI's
    # default format instead of raising KeyError.
    world_logger.write(
        outfile, world, file_format=opt.get('save_format', 'conversations')
    )

report = aggregate_unnamed_reports(all_gather_list(world.report()))

if isinstance(world.agents, list) and len(world.agents) > 1:
    classifier_agent = world.agents[CLASSIFIER_AGENT]
    if hasattr(classifier_agent, 'calc_auc') and classifier_agent.calc_auc:
        for class_indices, curr_auc in zip(
            classifier_agent.auc_class_indices, classifier_agent.aucs
        ):
            report[f'AUC_{classifier_agent.class_list[class_indices]}'] = curr_auc
        classifier_agent.reset_auc()
        # For safety measures; guarded because the main agent may not track AUC.
        if hasattr(agent, 'reset_auc'):
            agent.reset_auc()
world.reset()

logging.report(f"Report for {opt['task']}:\n{nice_report(report)}")

# The TensorBoard cell prepares `tb_logger`, `setting` and `n_parleys`;
# without this step they were computed but the metrics were never logged.
if tb_logger:
    tb_logger.log_metrics(setting, n_parleys, report)
    tb_logger.flush()

22:31:30 | [1mReport for personachat:
    accuracy  bleu-1  bleu-2  bleu-3  bleu-4  clen  ctpb  ctps  ctrunc  ctrunclen  exps  exs    f1  gen_n_toks  gpu_mem  llen  \
           0   .1293  .05192  .01252 .002221 147.1  4706  5083       0          0 34.56 1024 .1540       15.85   .00508 12.97   
    loss  lr  ltpb  ltps  ltrunc  ltrunclen   ppl  precision  recall  token_acc  token_em  total_train_updates  tpb  tps  
   4.358   1 414.9 448.1       0          0 78.14      .1694   .1494      .3172         0                41075 5121 5531[0m
