# Set random seed and configure logging

In [1]:
# Re-import edited project modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2
# Expose only GPU 0 to this kernel; this must happen before torch initialises CUDA.
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_VISIBLE_DEVICES=0


In [2]:
import logging
import os
import random
import re
import numpy as np
import torch

def set_seed(seed):
    """Seed every RNG the notebook touches and make cuDNN deterministic.

    Seeds Python's `random`, NumPy's legacy global RNG and torch (CPU, the
    current CUDA device and all CUDA devices), and disables cuDNN autotuning in
    favour of deterministic kernels.

    Args:
        seed: integer seed applied to every generator.
    """
    # Only affects child processes: the running interpreter's hash
    # randomization is fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers multi-GPU runs
    )
    for seeder in seeders:
        seeder(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def set_global_logging_level(level=logging.ERROR, prefixes=("",)):
    """
    Override logging levels of different modules based on their name as a prefix.
    It needs to be invoked after the modules have been loaded so that their loggers have been initialized.

    Args:
        level: desired level. Optional. Default is logging.ERROR
        prefixes: list of one or more str prefixes to match (e.g. ["transformers", "torch"]). Optional.
            Default is `[""]` to match all active loggers.
            The match is a case-sensitive, literal `module_name.startswith(prefix)`;
            characters such as '.', '+' or '(' have no special meaning.
            An empty sequence behaves like the default and matches every logger.
    """
    # A literal prefix test (not a regex) so that e.g. "a.b" does not also match "aXb".
    prefixes = tuple(prefixes) or ("",)
    # Snapshot the names: getLogger() may turn placeholders into real loggers while we iterate.
    for name in list(logging.root.manager.loggerDict):
        if name.startswith(prefixes):
            logging.getLogger(name).setLevel(level)

In [3]:
# Timestamped, module-tagged log lines at INFO level for the whole notebook.
logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s %(levelname)s %(name)s] %(message)s',
    datefmt='%m/%d %H:%M:%S',
)
# Keep the elasticsearch client loggers at WARNING and above.
set_global_logging_level(logging.WARNING, ["elasticsearch"])

set_seed(0)

# Load examples

In [4]:
import json

split = 'dev'
samples = []
# Rows are tab-separated: q_id, question and, for labelled splits, answer plus
# JSON-encoded supporting facts. The test split carries only q_id and question.
# Explicit UTF-8 so loading does not depend on the machine's locale.
with open(f"data/hotpot-{split}.tsv", encoding='utf-8') as f:
    for line in f:
        if split == 'test':
            q_id, question = line.strip().split('\t')
            samples.append((q_id, (question,)))
        else:
            q_id, question, answer, sp_facts = line.strip().split('\t')
            sp_facts = json.loads(sp_facts)
            samples.append((q_id, (question, answer, sp_facts)))
print(len(samples))

7405


In [5]:
if split == 'test':
    # No gold labels for the test split; later cells skip evaluation when split == 'test'.
    gold_qas_map = None
    gold_samples = None
    gold_state2action = None
else:
    gold_qas_map = dict(samples)
    # Gold files follow `split`. They were previously hardcoded to 'dev', which would
    # silently score any other split against dev labels.
    # NOTE(review): confirm the gold file naming for non-dev splits.
    with open(f'data/HotpotQA/hotpot_{split}_fullwiki_v1.json', encoding='utf-8') as f:
        gold_samples = json.load(f)
    print(len(gold_samples))
    gold_state2action = dict()
    with open(f'data/hotpot-step-{split}.strict.refined.jsonl', encoding='utf-8') as f:
        for line in f:
            example = json.loads(line)
            gold_state2action[example['_id']] = example['state2action']
    print(len(gold_state2action))

7405
7405


# Load sparse retriever and query generators

In [6]:
from retriever import SparseRetriever

# Sparse (BM25-style) paragraph index and the host serving it.
SPARSE_INDEX_NAME = 'enwiki-20171001-paragraph-5'
SPARSE_HOSTS = ['10.60.0.59:9200']
sparse_retriever = SparseRetriever(SPARSE_INDEX_NAME, SPARSE_HOSTS, max_retries=3, timeout=30)

[06/11 00:45:32 INFO transformers.file_utils] PyTorch version 1.4.0 available.
[06/11 00:45:33 INFO faiss.loader] Loading faiss with AVX2 support.
[06/11 00:45:33 INFO faiss.loader] Loading faiss.


In [7]:
from drqa.reader import Predictor
import warnings

# NOTE(review): this silences *all* Python warnings for the rest of the kernel session.
warnings.filterwarnings("ignore")


def _load_query_generator(model_path):
    """Build a DrQA Predictor for one hop's checkpoint and move its network to GPU 0."""
    predictor = Predictor(model=model_path, tokenizer=None,
                          embedding_file='data/glove.840B.300d.txt', num_workers=-1)
    predictor.cuda()
    predictor.model.network.to(torch.device('cuda:0'))
    return predictor


qg1 = _load_query_generator('ckpts/golden-retriever/hop1.mdl')
qg2 = _load_query_generator('ckpts/golden-retriever/hop2.mdl')
qg2.model.network

[06/11 00:45:34 INFO drqa.reader.predictor] Initializing model...
[06/11 00:45:34 INFO drqa.reader.model] Loading model ckpts/golden-retriever/hop1.mdl
[06/11 00:45:34 INFO drqa.reader.predictor] Expanding dictionary...
[06/11 00:46:22 INFO drqa.reader.model] Adding 2115762 new words to dictionary...
[06/11 00:46:25 INFO drqa.reader.model] New vocab size: 2195963
[06/11 00:46:35 INFO drqa.reader.model] Loading pre-trained embeddings for 2115762 words from data/glove.840B.300d.txt
[06/11 00:50:36 INFO drqa.reader.model] Loaded 2115762 embeddings (100.00%)
[06/11 00:50:37 INFO drqa.reader.predictor] Initializing tokenizer...
[06/11 00:50:50 INFO drqa.reader.predictor] Initializing model...
[06/11 00:50:50 INFO drqa.reader.model] Loading model ckpts/golden-retriever/hop2.mdl
[06/11 00:50:52 INFO drqa.reader.predictor] Expanding dictionary...
[06/11 00:51:39 INFO drqa.reader.model] Adding 1892506 new words to dictionary...
[06/11 00:51:41 INFO drqa.reader.model] New vocab size: 2195963
[06

RnnDocReader(
  (embedding): Embedding(2195963, 300, padding_idx=0)
  (qemb_match): SeqAttnMatch(
    (linear): Linear(in_features=300, out_features=300, bias=True)
  )
  (doc_rnn): StackedBRNN(
    (rnns): ModuleList(
      (0): LSTM(662, 128, bidirectional=True)
      (1): LSTM(256, 128, bidirectional=True)
      (2): LSTM(256, 128, bidirectional=True)
    )
  )
  (question_rnn): StackedBRNN(
    (rnns): ModuleList(
      (0): LSTM(300, 128, bidirectional=True)
      (1): LSTM(256, 128, bidirectional=True)
      (2): LSTM(256, 128, bidirectional=True)
    )
  )
  (self_attn): LinearSeqAttn(
    (linear): Linear(in_features=768, out_features=1, bias=True)
  )
  (start_attn): BilinearSeqAttn(
    (linear): Linear(in_features=768, out_features=768, bias=True)
  )
  (end_attn): BilinearSeqAttn(
    (linear): Linear(in_features=768, out_features=768, bias=True)
  )
)

# Load dense indexer and dense encoder

In [8]:
import faiss

# OpenMP threads used by FAISS while the index is loaded (the agent cell later drops this to 1).
FAISS_OMP_THREADS = 16
faiss.omp_set_num_threads(FAISS_OMP_THREADS)

In [9]:
from argparse import Namespace

# Settings for the dense question encoder and its FAISS index.
args = Namespace(
    model_name="roberta-base",
    model_path="ckpts/mdr/q_encoder.pt",
    # A prefix ending in 'hnsw' makes the indexer cell load the HNSW index instead of the flat one.
    index_prefix_path="data/index/mdr/hotpot-paragraph-q-strict",
    index_buffer_size=50000,
    max_q_len=70,
    max_q_sp_len=350,
)

In [10]:
from transformers import AutoConfig, AutoTokenizer
from mdr.retrieval.models.retriever import RobertaCtxEncoder
from utils.model_utils import load_state

# Config and tokenizer come from the base model name; weights from the fine-tuned checkpoint.
bert_config = AutoConfig.from_pretrained(args.model_name)
dense_tokenizer = AutoTokenizer.from_pretrained(args.model_name)
dense_encoder = load_state(RobertaCtxEncoder(bert_config, args), args.model_path, exact=False)
dense_encoder.to(torch.device('cuda:0'))
dense_encoder.eval()

[06/11 00:55:46 INFO transformers.configuration_utils] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json from cache at /home/shenhuawei/.cache/torch/transformers/e1a2a406b5a05063c31f4dfdee7608986ba7c6393f7f79db5e69dcd197208534.117c81977c5979de8c088352e74ec6e70f5c66096c28b61d3c50101609b39690
[06/11 00:55:46 INFO transformers.configuration_utils] Model config RobertaConfig {
  "architectures": [
    "RobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "eos_token_id": 2,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "roberta",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "type_vocab_size": 1,
  "vocab_size": 50265
}

[06/11 00:55:48 INFO transformers.configuration_utils] 

RobertaCtxEncoder(
  (encoder): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [11]:
from dense_indexers import DenseHNSWFlatIndexer, DenseFlatIndexer

# Index vectors have the encoder's hidden size.
vector_size = bert_config.hidden_size
# A prefix ending in 'hnsw' selects the HNSW index; anything else uses the flat index.
indexer_cls = DenseHNSWFlatIndexer if args.index_prefix_path.endswith('hnsw') else DenseFlatIndexer
dense_indexer = indexer_cls(vector_size, args.index_buffer_size)
dense_indexer.deserialize_from(args.index_prefix_path)

[06/11 00:56:06 INFO dense_indexers] Loading index from data/index/mdr/hotpot-paragraph-q-strict
[06/11 00:56:40 INFO dense_indexers] Loaded index of type <class 'faiss.swigfaiss.IndexFlat'> and size 5232080


In [12]:
from retriever import DenseRetriever

# Combine the loaded index with the question encoder and its tokenizer.
dense_retriever = DenseRetriever(dense_indexer, dense_encoder, dense_tokenizer)

# Load corpus

In [13]:
from utils.data_utils import load_corpus

corpus, title2id = load_corpus('data/corpus/hotpot-paragraph-5.tsv', for_hotpot=True, require_hyperlinks=True)
# Sanity check: passage count and title-index size (both match the dense index size in the recorded run).
for loaded in (corpus, title2id):
    print(len(loaded))

[06/11 01:01:42 INFO utils.data_utils] Loaded 5232080 passages from data/corpus/hotpot-paragraph-5.tsv


5232080
5232080


# Load union model

In [14]:
from models.union_model import UnionModel

UNION_ENCODER_NAME = 'google/electra-base-discriminator'
# [unused0]..[unused3] are registered as additional special tokens of the fast tokenizer.
tokenizer = AutoTokenizer.from_pretrained(
    UNION_ENCODER_NAME,
    use_fast=True,
    additional_special_tokens=[f'[unused{i}]' for i in range(4)],
)
union_model = UnionModel(UNION_ENCODER_NAME, max_ans_len=64)

[06/11 01:01:43 INFO transformers.configuration_utils] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/google/electra-base-discriminator/config.json from cache at /home/shenhuawei/.cache/torch/transformers/9236d197566a7f1be2b2151f5afcc5a8e17f31e1e23c52f3cdf2340019986e78.88ba6e8e7d5a7936e86d6f2551fe19c236dc57c24da163907cd0544e9933f6ee
[06/11 01:01:43 INFO transformers.configuration_utils] Model config ElectraConfig {
  "architectures": [
    "ElectraForPreTraining"
  ],
  "attention_probs_dropout_prob": 0.1,
  "embedding_size": 768,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "electra",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "summary_activation": "gelu",
  "summary_last_dropout": 0.1,
  "summary_type": "first",
  "summary_use_proj": true,
  "type_vocab_

In [15]:
device = torch.device("cuda:0")
# The run directory name encodes its training hyper-parameters.
# NOTE(review): the '*' is presumably a literal character in the directory name — confirm.
init_checkpoint = 'ckpts/td5-exp1-ila.4_electra-base-discriminator_DP0.5_HN2_M2_D2_adamW_SP0.5_B32_LR2.0e-05_WU0.1_E30_S42_04202303_pld-sb0-wo*-cmd10/checkpoint_68000.pt'

union_model = load_state(union_model, init_checkpoint)
union_model.to(device)
union_model.eval()

UnionModel(
  (encoder): ElectraModel(
    (embeddings): ElectraEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affin

# Env: WikiWorld

In [16]:
import os

import redis
from tqdm.auto import tqdm
from env.core import Environment

# Redis connection settings come from the environment; credentials must never be committed.
# NOTE(review): the password previously hardcoded here remains in version history and should be rotated.
REDIS_HOST = os.environ.get('REDIS_HOST', '10.60.1.79')
REDIS_PORT = int(os.environ.get('REDIS_PORT', '6379'))
REDIS_PASSWORD = os.environ.get('REDIS_PASSWORD')
if REDIS_PASSWORD is None:
    raise RuntimeError("Set the REDIS_PASSWORD environment variable before running this cell.")

bm25_redis = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=3, password=REDIS_PASSWORD, decode_responses=True)
mdr_redis = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=4, password=REDIS_PASSWORD, decode_responses=True)
env = Environment(corpus, title2id, sparse_retriever, dense_retriever, bm25_redis, mdr_redis,
                  for_hotpot=True, strict=True, max_ret_size=1000)

# QA agent

In [25]:
import copy
import os
from agent import Agent
from hotpot_evaluate_plus import evaluate, pretty_metrics

# Drop FAISS to a single OpenMP thread for the rollout (16 were used while loading the index).
faiss.omp_set_num_threads(1)

# Credentials come from the environment, never from the notebook source.
# NOTE(review): the password previously hardcoded here remains in version history and should be rotated.
_redis_password = os.environ.get('REDIS_PASSWORD')
if _redis_password is None:
    raise RuntimeError("Set the REDIS_PASSWORD environment variable before running this cell.")
query_redis = redis.Redis(host=os.environ.get('REDIS_HOST', '10.60.1.79'),
                          port=int(os.environ.get('REDIS_PORT', '6379')),
                          db=2, password=_redis_password, decode_responses=True)
agent = Agent(tokenizer, union_model, qg1, qg2, device, env, query_redis,
              actions=("ANSWER", "BM25", "MDR", "LINK"), action_mask=(1, 1, 1, 1),
              memory_size=2, max_seq_len=512, max_q_len=96, max_obs_len=256, strict=True,
              gold_qas_map=gold_qas_map, oracle_belief=False,
              oracle_state2action=None)  # set to gold_state2action to follow oracle actions

In [26]:
# Fresh rollout state: every loaded question starts open, with no observation yet.
predictions = {key: dict() for key in ("answer", "norm_answer", "sp", "_sp", "spp")}
n_observed = -1
env.reset()
agent.reset()
agent.eval()
q_ids = [q_id for q_id, _ in samples]
questions = [fields[0] for _, fields in samples]
observations = [None] * len(q_ids)

In [27]:
def force_answer():
    """Report metrics as if every still-open question were forced to answer now.

    Calls `agent.force_answer` on the current `q_ids`/`questions` and merges its
    results into a deep copy of `predictions`, so the real predictions are left
    untouched. Then prints the evaluation table. Does nothing on the test split,
    which has no gold labels.
    """
    if split == 'test':
        return
    answers, _answers, sp_passages, sp_facts, _sp_facts = agent.force_answer(q_ids, questions)
    forced = {'answer': answers, 'norm_answer': _answers, 'sp': sp_facts,
              '_sp': _sp_facts, 'spp': sp_passages}
    merged = copy.deepcopy(predictions)
    for key, values in forced.items():
        merged[key].update(values)
    print(pretty_metrics(evaluate(merged, gold_samples)))

In [28]:
MAX_OBS = 10
set_global_logging_level(logging.WARNING, ["elasticsearch", "dpr"])

# Step the agent until every question is answered or MAX_OBS observations are spent.
# Answered questions are recorded in `predictions` and drop out of the batch.
while observations and n_observed < MAX_OBS:
    commands = agent.act(q_ids, questions, observations, review=(n_observed + 1) % 100 == 0)
    still_open = []  # (q_id, question, observation) for questions that keep searching
    for q_id, question in tqdm(zip(q_ids, questions), desc='exe', total=len(q_ids)):
        cmd = commands[q_id]
        if cmd[0] != 'ANSWER':
            # NOTE(review): earlier marker "xxx observed memory" — presumably weighing
            # agent.observed[q_id] instead of agent.memory[q_id] as the exclusion set.
            obs = env.step(cmd, q_id, exclusion=agent.memory[q_id])
            still_open.append((q_id, question, obs))
            continue
        predictions['answer'][q_id] = agent.answer[q_id]
        votes = agent.all_answer[q_id]
        # Highest-scoring normalised answer, or 'noanswer' when none was proposed.
        predictions['norm_answer'][q_id] = max(votes.keys(), key=votes.__getitem__) if len(votes) > 0 else 'noanswer'
        predictions['sp'][q_id] = list(agent.sp_facts[q_id].keys())
        predictions['_sp'][q_id] = list(agent.all_sp_facts[q_id].keys())
        predictions['spp'][q_id] = agent.sp_passages[q_id]
    q_ids = [q_id for q_id, _, _ in still_open]
    questions = [question for _, question, _ in still_open]
    observations = [obs for _, _, obs in still_open]
    n_observed += 1
    print(f"{n_observed}: {len(observations)} remained")
    if n_observed > 0 and split != 'test':
        print(pretty_metrics(evaluate(predictions, gold_samples)))
        # Mean number of observed items per question so far.
        print(sum(map(len, agent.observed.values())) / len(agent.observed))
force_answer()

gen_sparse_query:   0%|          | 0/7405 [00:00<?, ?it/s]

act:   0%|          | 0/232 [00:00<?, ?it/s]

exe:   0%|          | 0/7405 [00:00<?, ?it/s]

0: 7405 remained


gen_sparse_query:   0%|          | 0/7405 [00:00<?, ?it/s]

act:   0%|          | 0/232 [00:00<?, ?it/s]

[06/11 01:05:59 INFO agent] 5a77152355429966f1a36c2e: miss SP0 2995506_0 (What Are Little Boys Made Of?) for 0.3236 <= 0.4820 when 232 tokens
[06/11 01:06:00 INFO agent] 5ae305df5542992decbdcdc3: miss SP0 34639380_0 (2012 Louisville Cardinals football team) for 0.3529 <= 0.3849 when 345 tokens
[06/11 01:06:01 INFO agent] 5ab28019554299722f9b4d51: miss SP0 30863419_0 (Son of al Qaeda) for 0.0113 <= 0.8037 when 250 tokens
[06/11 01:06:01 INFO agent] 5a7b46e55542995eb53be8e5: miss SP0 44920366_0 (Spaceware Sky Vision II) for 0.0470 <= 0.4297 when 296 tokens
[06/11 01:06:02 INFO agent] 5ae628f55542995703ce8b32: miss SP0 27362380_0 (Six Violin Sonatas, Op. 5 (Vivaldi)) for 0.1364 <= 0.4814 when 226 tokens
[06/11 01:06:04 INFO agent] 5a7b374c5542995eb53be8cc: miss SP0 4491426_0 (Buttered cat paradox) for 0.0377 <= 0.6577 when 362 tokens
[06/11 01:06:05 INFO agent] 5a8242c155429940e5e1a81d: miss SP0 6513540_0 (Thoughts on Machiavelli) for 0.2711 <= 0.3119 when 233 tokens
[06/11 01:06:07 INFO 

exe:   0%|          | 0/7405 [00:00<?, ?it/s]

1: 7405 remained
|                  |   EM |   F1 | Prec | Recall |
|------------------|------|------|------|--------|
|           Answer | 0.00 | 0.00 | 0.00 |   0.00 |
|      Norm answer | 0.00 | 0.00 | 0.00 |   0.00 |
| Support sentence | 0.00 | 0.00 | 0.00 |   0.00 |
| Support sentence | 0.00 | 0.00 | 0.00 |   0.00 |
|  Support passage | 0.00 | 0.00 | 0.00 |   0.00 |
|            Joint | 0.00 | 0.00 | 0.00 |   0.00 |
0.999594868332208


gen_sparse_query:   0%|          | 0/7405 [00:00<?, ?it/s]

act:   0%|          | 0/232 [00:00<?, ?it/s]

[06/11 01:07:46 INFO agent] 5a75e05c55429976ec32bc5f: augment SP ['3434750_0'] into memory
[06/11 01:07:47 INFO agent] 5a713ea95542994082a3e6e4: miss SP1 23602935_0 (Indigenous peoples of Florida) for 0.0059 <= 0.5659 when 379 tokens
[06/11 01:07:47 INFO agent] 5a80b3a9554299485f5986cc: miss SP1 137582_0 (McLean, Virginia) for 0.1781 <= 0.6269 when 379 tokens
[06/11 01:07:47 INFO agent] 5a80b3a9554299485f5986cc: augment SP ['137582_0'] into memory
[06/11 01:07:48 INFO agent] 5a8739a05542994775f607ab: miss SP0 32700345_0 (New York's 1st State Senate district) for 0.3087 <= 0.3827 when 439 tokens
[06/11 01:07:49 INFO agent] 5ab28019554299722f9b4d51: miss SP0 30863419_0 (Son of al Qaeda) for 0.0113 <= 0.8037 when 419 tokens
[06/11 01:07:49 INFO agent] 5a7ed2c655429930675135e5: miss SP1 39367266_0 (Brooklyn Nine-Nine) for 0.0264 <= 0.5958 when 419 tokens
[06/11 01:07:49 INFO agent] 5a7a3a945542996a35c17147: miss SP1 4568_0 (Bank of China Tower (Hong Kong)) for 0.0439 <= 0.6429 when 381 tok

exe:   0%|          | 0/7405 [00:00<?, ?it/s]

2: 2261 remained
|                  |    EM |    F1 |  Prec | Recall |
|------------------|-------|-------|-------|--------|
|           Answer | 46.33 | 55.62 | 57.47 |  56.60 |
|      Norm answer | 46.12 | 55.39 | 57.28 |  56.32 |
| Support sentence | 43.32 | 60.34 | 62.01 |  60.04 |
| Support sentence | 42.28 | 60.01 | 62.60 |  59.29 |
|  Support passage | 64.32 | 66.58 | 66.58 |  66.58 |
|            Joint | 31.55 | 49.95 | 52.92 |  50.49 |
1.999594868332208


gen_sparse_query:   0%|          | 0/2261 [00:00<?, ?it/s]

act:   0%|          | 0/71 [00:00<?, ?it/s]

[06/11 01:09:54 INFO agent] 5a83880e554299123d8c214e: remove SP1 3888242_0 (Suicide (1977 album)) from memory for 0.9925 -> 0.1992 <= 0.6130 when 334 tokens
[06/11 01:09:54 INFO agent] 5a83880e554299123d8c214e: augment SP ['3888242_0'] into memory
[06/11 01:09:54 INFO agent] 5a7a230e5542996a35c170ee: augment SP ['9056492_0'] into memory
[06/11 01:09:55 INFO agent] 5a7d19d85542995ed0d165e8: remove SP0 79371_0 (Southeastern Conference) from memory for 0.5753 -> 0.4733 <= 0.5049 when 384 tokens
[06/11 01:09:55 INFO agent] 5ae7edee554299540e5a56ad: remove SP2 269223_0 (Tenerife) from memory for 0.9413 -> 0.5695 <= 0.7304 when 372 tokens
[06/11 01:09:55 INFO agent] 5ae7edee554299540e5a56ad: augment SP ['269223_0'] into memory
[06/11 01:09:55 INFO agent] 5ab5b6c1554299637185c5e5: miss SP1 3139987_0 (Hellogoodbye) for 0.0624 <= 0.6184 when 372 tokens
[06/11 01:09:56 INFO agent] 5a8b78775542997f31a41d3d: miss SP1 5004696_0 (Chimaji Appa) for 0.0085 <= 0.6596 when 383 tokens
[06/11 01:09:56 INF

exe:   0%|          | 0/2261 [00:00<?, ?it/s]

3: 1462 remained
|                  |    EM |    F1 |  Prec | Recall |
|------------------|-------|-------|-------|--------|
|           Answer | 53.30 | 64.03 | 66.19 |  65.19 |
|      Norm answer | 52.91 | 63.60 | 65.81 |  64.70 |
| Support sentence | 48.70 | 69.04 | 71.11 |  68.62 |
| Support sentence | 47.32 | 68.62 | 71.81 |  67.72 |
|  Support passage | 73.61 | 76.49 | 76.49 |  76.49 |
|            Joint | 35.41 | 57.10 | 60.67 |  57.70 |
2.3049291019581366


gen_sparse_query:   0%|          | 0/1462 [00:00<?, ?it/s]

act:   0%|          | 0/46 [00:00<?, ?it/s]

[06/11 01:10:38 INFO agent] 5a85fb085542994775f606de: augment SP ['142457_0'] into memory
[06/11 01:10:38 INFO agent] 5a713ea95542994082a3e6e4: augment SP ['23602935_0'] into memory
[06/11 01:10:39 INFO agent] 5ab31864554299233954ff06: miss SP1 17438452_0 (Apratim Majumdar) for 0.5462 <= 0.5699 when 512 tokens
[06/11 01:10:40 INFO agent] 5abbc58455429931dba14502: miss SP1 3226774_0 (The Grudge 2) for 0.1018 <= 0.6753 when 512 tokens
[06/11 01:10:40 INFO agent] 5abbc58455429931dba14502: augment SP ['3226774_0'] into memory
[06/11 01:10:40 INFO agent] 5ab438395542990594ba9bb9: augment SP ['25857222_0'] into memory
[06/11 01:10:41 INFO agent] 5a901f735542990a98493591: miss SP1 32569_0 (Vitellius) for 0.1505 <= 0.5339 when 452 tokens
[06/11 01:10:41 INFO agent] 5adfcb5455429906c02daa48: remove SP1 33529086_0 (Death in Paradise (TV series)) from memory for 0.8310 -> 0.0630 <= 0.5293 when 452 tokens
[06/11 01:10:41 INFO agent] 5a8347d65542996488c2e3f6: miss SP0 464883_0 (Miracle (2004 film))

exe:   0%|          | 0/1462 [00:00<?, ?it/s]

4: 1158 remained
|                  |    EM |    F1 |  Prec | Recall |
|------------------|-------|-------|-------|--------|
|           Answer | 55.96 | 67.14 | 69.41 |  68.32 |
|      Norm answer | 55.53 | 66.66 | 68.98 |  67.78 |
| Support sentence | 50.48 | 72.28 | 74.46 |  71.86 |
| Support sentence | 48.97 | 71.84 | 75.21 |  70.92 |
|  Support passage | 77.03 | 80.19 | 80.19 |  80.19 |
|            Joint | 36.77 | 59.73 | 63.46 |  60.36 |
2.5022282241728564


gen_sparse_query:   0%|          | 0/1158 [00:00<?, ?it/s]

act:   0%|          | 0/37 [00:00<?, ?it/s]

[06/11 01:11:10 INFO agent] 5a7ed2c655429930675135e5: miss SP2 39367266_0 (Brooklyn Nine-Nine) for 0.1678 <= 0.7170 when 512 tokens
[06/11 01:11:10 INFO agent] 5ab31864554299233954ff06: miss SP1 577443_0 (Sarod) for 0.1339 <= 0.5065 when 512 tokens
[06/11 01:11:11 INFO agent] 5a7319e755429901807daf86: augment SP ['1447560_0', '25871295_0'] into memory
[06/11 01:11:11 INFO agent] 5a901f735542990a98493591: miss SP1 32569_0 (Vitellius) for 0.1505 <= 0.5339 when 360 tokens
[06/11 01:11:12 INFO agent] 5ae5d7b755429929b08079cd: miss SP1 21709_0 (Nation of Islam) for 0.2072 <= 0.5283 when 359 tokens
[06/11 01:11:12 INFO agent] 5ac156d05542994ab5c67ce9: remove SP2 13072534_0 (Lin-Manuel Miranda) from memory for 0.8224 -> 0.1082 <= 0.6409 when 431 tokens
[06/11 01:11:12 INFO agent] 5a84807d554299123d8c2278: remove SP0 593041_0 (Hugh Dancy) from memory for 0.8494 -> 0.4384 <= 0.5426 when 431 tokens
[06/11 01:11:12 INFO agent] 5a7f85505542995d8a8dde95: miss SP1 11940266_0 (Heading Out to the High

exe:   0%|          | 0/1158 [00:00<?, ?it/s]

5: 1011 remained
|                  |    EM |    F1 |  Prec | Recall |
|------------------|-------|-------|-------|--------|
|           Answer | 57.29 | 68.67 | 70.99 |  69.86 |
|      Norm answer | 56.81 | 68.16 | 70.53 |  69.28 |
| Support sentence | 51.24 | 73.73 | 76.01 |  73.27 |
| Support sentence | 49.70 | 73.30 | 76.77 |  72.34 |
|  Support passage | 78.56 | 81.89 | 81.89 |  81.89 |
|            Joint | 37.37 | 60.97 | 64.83 |  61.58 |
2.6584740040513166


gen_sparse_query:   0%|          | 0/1011 [00:00<?, ?it/s]

act:   0%|          | 0/32 [00:00<?, ?it/s]

[06/11 01:11:34 INFO agent] 5a85eed75542996432c5713b: augment SP ['31753716_0'] into memory
[06/11 01:11:36 INFO agent] 5ab30d1155429976abd1bc3d: miss SP1 144415_0 (The Santa Clause) for 0.0927 <= 0.6480 when 512 tokens
[06/11 01:11:36 INFO agent] 5ab30d1155429976abd1bc3d: augment SP ['144415_0'] into memory
[06/11 01:11:36 INFO agent] 5ab1d983554299340b52540a: augment SP ['28348349_0'] into memory
[06/11 01:11:36 INFO agent] 5a88bd1d554299206df2b357: miss SP1 27652115_0 (Jackson Avery) for 0.3634 <= 0.5829 when 512 tokens
[06/11 01:11:36 INFO agent] 5a88bd1d554299206df2b357: augment SP ['27652115_0'] into memory
[06/11 01:11:37 INFO agent] 5ab8903555429916710eb08e: augment SP ['17739236_0'] into memory
[06/11 01:11:38 INFO agent] 5a7da6cb5542990b8f5039ff: miss SP1 7867644_0 (Kevin Alejandro) for 0.1530 <= 0.5368 when 369 tokens
[06/11 01:11:38 INFO agent] 5a7da6cb5542990b8f5039ff: augment SP ['7867644_0'] into memory
[06/11 01:11:38 INFO agent] 5a8d1f015542994ba4e3dc08: remove SP0 197

exe:   0%|          | 0/1011 [00:00<?, ?it/s]

6: 906 remained
|                  |    EM |    F1 |  Prec | Recall |
|------------------|-------|-------|-------|--------|
|           Answer | 58.11 | 69.67 | 72.03 |  70.87 |
|      Norm answer | 57.60 | 69.12 | 71.52 |  70.25 |
| Support sentence | 51.84 | 74.79 | 77.15 |  74.30 |
| Support sentence | 50.26 | 74.34 | 77.93 |  73.34 |
|  Support passage | 79.66 | 83.11 | 83.11 |  83.11 |
|            Joint | 37.84 | 61.81 | 65.77 |  62.40 |
2.7950033760972315


gen_sparse_query:   0%|          | 0/906 [00:00<?, ?it/s]

act:   0%|          | 0/29 [00:00<?, ?it/s]

[06/11 01:11:59 INFO agent] 5abb11c15542992ccd8e7ef8: augment SP ['69795_0'] into memory
[06/11 01:11:59 INFO agent] 5ab42b24554299753aec5a43: miss SP0 33397730_0 (Operation Lighthouse (1937)) for 0.0165 <= 0.5622 when 384 tokens
[06/11 01:11:59 INFO agent] 5ae75a5a5542991bbc9761f3: remove SP1 38871163_0 (Kriti Sanon) from memory for 0.8602 -> 0.6024 <= 0.6694 when 384 tokens
[06/11 01:12:01 INFO agent] 5ae22f4e5542996483e6492f: augment SP ['3381857_0'] into memory
[06/11 01:12:01 INFO agent] 5ab8c3c65542991b5579effd: miss SP0 1394553_0 (Tift Merritt) for 0.2972 <= 0.3349 when 443 tokens
[06/11 01:12:01 INFO agent] 5ac559bd5542993e66e82328: miss SP1 3488963_0 (Horse the Band) for 0.0013 <= 0.5418 when 443 tokens
[06/11 01:12:01 INFO agent] 5a712c4b5542994082a3e622: miss SP1 48768921_0 (Cook's Landing Place, Town of Seventeen Seventy) for 0.0818 <= 0.6295 when 424 tokens
[06/11 01:12:01 INFO agent] 5ab8338155429919ba4e2260: miss SP0 4540384_0 (Ian Brayshaw) for 0.0104 <= 0.6276 when 424

exe:   0%|          | 0/906 [00:00<?, ?it/s]

7: 832 remained
|                  |    EM |    F1 |  Prec | Recall |
|------------------|-------|-------|-------|--------|
|           Answer | 58.62 | 70.35 | 72.75 |  71.54 |
|      Norm answer | 58.11 | 69.80 | 72.25 |  70.93 |
| Support sentence | 52.17 | 75.45 | 77.86 |  74.93 |
| Support sentence | 50.55 | 74.99 | 78.65 |  73.96 |
|  Support passage | 80.31 | 83.89 | 83.89 |  83.89 |
|            Joint | 38.07 | 62.33 | 66.38 |  62.89 |
2.9173531397704253


gen_sparse_query:   0%|          | 0/832 [00:00<?, ?it/s]

act:   0%|          | 0/26 [00:00<?, ?it/s]

[06/11 01:12:18 INFO agent] 5a84f7255542991dd0999e33: augment SP ['27423_0'] into memory
[06/11 01:12:18 INFO agent] 5adfcb5455429906c02daa48: miss SP1 33529086_0 (Death in Paradise (TV series)) for 0.5296 <= 0.7263 when 348 tokens
[06/11 01:12:19 INFO agent] 5ac5275755429924173fb617: miss SP0 46340990_0 (Marc Schiller) for 0.0165 <= 0.7458 when 349 tokens
[06/11 01:12:19 INFO agent] 5abb73425542996cc5e49ff5: miss SP1 294538_0 (Detroit Metropolitan Airport) for 0.0107 <= 0.5931 when 345 tokens
[06/11 01:12:19 INFO agent] 5ae20cd25542997283cd2376: miss SP1 57539_0 (Ulysses (spacecraft)) for 0.0114 <= 0.5522 when 304 tokens
[06/11 01:12:19 INFO agent] 5ae20cd25542997283cd2376: augment SP ['57539_0'] into memory
[06/11 01:12:21 INFO agent] 5ab56f4b554299637185c599: miss SP1 37644013_0 (Republican Party presidential primaries, 2016) for 0.0143 <= 0.5704 when 391 tokens
[06/11 01:12:21 INFO agent] 5ac508205542994611c8b32e: remove SP0 30863956_0 (M1919 Browning machine gun) from memory for 0

exe:   0%|          | 0/832 [00:00<?, ?it/s]

8: 788 remained
|                  |    EM |    F1 |  Prec | Recall |
|------------------|-------|-------|-------|--------|
|           Answer | 58.89 | 70.71 | 73.13 |  71.91 |
|      Norm answer | 58.37 | 70.15 | 72.62 |  71.28 |
| Support sentence | 52.40 | 75.89 | 78.31 |  75.37 |
| Support sentence | 50.74 | 75.42 | 79.12 |  74.38 |
|  Support passage | 80.72 | 84.38 | 84.38 |  84.38 |
|            Joint | 38.22 | 62.62 | 66.69 |  63.20 |
3.0297096556380825


gen_sparse_query:   0%|          | 0/788 [00:00<?, ?it/s]

act:   0%|          | 0/25 [00:00<?, ?it/s]

[06/11 01:12:36 INFO agent] 5ae1f596554299234fd04372: augment SP ['24942471_0'] into memory
[06/11 01:12:36 INFO agent] 5ae6b6065542991bbc976168: remove SP0 23636141_0 (Universal Soldier (franchise)) from memory for 0.9061 -> 0.2042 <= 0.3807 when 295 tokens
[06/11 01:12:36 INFO agent] 5ac00f795542996f0d89cb16: miss SP0 230839_0 (Mickey's Christmas Carol) for 0.3208 <= 0.4180 when 295 tokens
[06/11 01:12:36 INFO agent] 5a8a9bc15542996c9b8d5f36: miss SP1 529355_0 (A Chorus Line) for 0.0063 <= 0.6348 when 295 tokens
[06/11 01:12:37 INFO agent] 5ae75a5a5542991bbc9761f3: remove SP1 38871163_0 (Kriti Sanon) from memory for 0.7051 -> 0.1247 <= 0.6918 when 333 tokens
[06/11 01:12:38 INFO agent] 5adfa451554299025d62a319: miss SP1 39793552_0 (Alleyne v. United States) for 0.0060 <= 0.5705 when 391 tokens
[06/11 01:12:38 INFO agent] 5ab4147a5542996a3a969f1e: miss SP0 5608488_0 (Stephen Curry) for 0.1386 <= 0.5064 when 391 tokens
[06/11 01:12:38 INFO agent] 5ab8338155429919ba4e2260: miss SP0 4540

exe:   0%|          | 0/788 [00:00<?, ?it/s]

9: 748 remained
|                  |    EM |    F1 |  Prec | Recall |
|------------------|-------|-------|-------|--------|
|           Answer | 59.20 | 71.07 | 73.49 |  72.27 |
|      Norm answer | 58.68 | 70.51 | 72.99 |  71.65 |
| Support sentence | 52.59 | 76.29 | 78.73 |  75.77 |
| Support sentence | 50.90 | 75.82 | 79.54 |  74.77 |
|  Support passage | 81.11 | 84.83 | 84.83 |  84.83 |
|            Joint | 38.34 | 62.91 | 66.99 |  63.48 |
3.135989196488859


gen_sparse_query:   0%|          | 0/748 [00:00<?, ?it/s]

act:   0%|          | 0/24 [00:00<?, ?it/s]

[06/11 01:12:52 INFO agent] 5a7d19d85542995ed0d165e8: augment SP ['35001974_0'] into memory
[06/11 01:12:53 INFO agent] 5a7363ec5542991f29ee2dd7: miss SP1 109495_0 (Key West, Florida) for 0.0404 <= 0.5634 when 381 tokens
[06/11 01:12:54 INFO agent] 5ae75a5a5542991bbc9761f3: miss SP1 35121149_0 (Raabta (song)) for 0.2161 <= 0.6567 when 366 tokens
[06/11 01:12:55 INFO agent] 5ae5e762554299546bf82faf: miss SP1 2110323_0 (Rihanna) for 0.4111 <= 0.6750 when 309 tokens
[06/11 01:12:55 INFO agent] 5ae5e762554299546bf82faf: augment SP ['2110323_0'] into memory
[06/11 01:12:56 INFO agent] 5ac17b485542994ab5c67d65: miss SP0 23786794_0 (St Wilfrid's Church, Alford) for 0.0555 <= 0.6298 when 302 tokens
[06/11 01:12:57 INFO agent] 5a80b9ae5542992bc0c4a7eb: augment SP ['14900848_0'] into memory
[06/11 01:12:58 INFO agent] 5a8760e75542996e4f3087af: miss SP0 442343_0 (Whatever Happened to... Robot Jones?) for 0.2115 <= 0.3878 when 384 tokens
[06/11 01:12:59 INFO agent] 5a78d9c455429970f5fffda7: miss S

exe:   0%|          | 0/748 [00:00<?, ?it/s]

10: 721 remained
|                  |    EM |    F1 |  Prec | Recall |
|------------------|-------|-------|-------|--------|
|           Answer | 59.43 | 71.34 | 73.77 |  72.54 |
|      Norm answer | 58.87 | 70.75 | 73.24 |  71.89 |
| Support sentence | 52.69 | 76.55 | 79.01 |  76.03 |
| Support sentence | 50.99 | 76.08 | 79.81 |  75.03 |
|  Support passage | 81.36 | 85.13 | 85.13 |  85.13 |
|            Joint | 38.43 | 63.13 | 67.24 |  63.70 |
3.236866981769075
|                  |    EM |    F1 |  Prec | Recall |
|------------------|-------|-------|-------|--------|
|           Answer | 60.90 | 73.44 | 75.94 |  74.74 |
|      Norm answer | 60.34 | 72.87 | 75.45 |  74.09 |
| Support sentence | 53.10 | 79.06 | 81.74 |  78.43 |
| Support sentence | 51.21 | 78.80 | 83.72 |  77.25 |
|  Support passage | 82.16 | 88.05 | 88.05 |  88.05 |
|            Joint | 38.65 | 64.17 | 68.38 |  64.74 |


In [None]:
# Deliberate stop: raising here halts "Run All" so the step-budget cells below
# only execute when run by hand.
assert False

In [None]:
# Resume the act -> execute loop until every question is answered or MAX_OBS steps have run.
# Loop state (agent, env, q_ids, questions, observations, predictions, n_observed, ...) comes
# from earlier cells, so re-running this cell continues from the current n_observed.
MAX_OBS = 20
set_global_logging_level(logging.WARNING, ["elasticsearch", "dpr"])

while len(observations) > 0 and n_observed < MAX_OBS:
    commands = agent.act(q_ids, questions, observations, review=(n_observed + 1) % 100 == 0)
    _q_ids, _questions, _observations = [], [], []
    for q_id, question in tqdm(zip(q_ids, questions), desc='exe', total=len(q_ids)):
        cmd = commands[q_id]
        if cmd[0] == 'ANSWER':
            # Question finished: copy the agent's final outputs into `predictions`
            # and do not carry it into the next step.
            all_answers = agent.all_answer[q_id]
            predictions['answer'][q_id] = agent.answer[q_id]
            # Candidate with the largest value in agent.all_answer[q_id], or 'noanswer' if empty.
            predictions['norm_answer'][q_id] = max(all_answers, key=all_answers.get) if all_answers else 'noanswer'
            predictions['sp'][q_id] = list(agent.sp_facts[q_id].keys())
            predictions['_sp'][q_id] = list(agent.all_sp_facts[q_id].keys())
            predictions['spp'][q_id] = agent.sp_passages[q_id]
        else:
            # Still searching: execute the command; exclusion=agent.memory[q_id] presumably keeps
            # already-memorized items out of the result (alternative tried: agent.observed[q_id]).
            obs = env.step(cmd, q_id, exclusion=agent.memory[q_id])
            _q_ids.append(q_id)
            _questions.append(question)
            _observations.append(obs)
    q_ids, questions, observations = _q_ids, _questions, _observations
    n_observed += 1
    print(f"{n_observed}: {len(observations)} remained")
    if n_observed % 5 == 0 and split != 'test':
        print(pretty_metrics(evaluate(predictions, gold_samples)))
        # Mean number of agent.observed entries per question; guarded against an empty dict.
        print(sum(len(x) for x in agent.observed.values()) / max(len(agent.observed), 1))
# force_answer is defined in an earlier cell; presumably it finalizes predictions for
# questions still open when the step budget ran out.
force_answer()

In [None]:
# Resume the act -> execute loop until every question is answered or MAX_OBS steps have run.
# Loop state (agent, env, q_ids, questions, observations, predictions, n_observed, ...) comes
# from earlier cells, so re-running this cell continues from the current n_observed.
MAX_OBS = 30
set_global_logging_level(logging.WARNING, ["elasticsearch", "dpr"])

while len(observations) > 0 and n_observed < MAX_OBS:
    commands = agent.act(q_ids, questions, observations, review=(n_observed + 1) % 100 == 0)
    _q_ids, _questions, _observations = [], [], []
    for q_id, question in tqdm(zip(q_ids, questions), desc='exe', total=len(q_ids)):
        cmd = commands[q_id]
        if cmd[0] == 'ANSWER':
            # Question finished: copy the agent's final outputs into `predictions`
            # and do not carry it into the next step.
            all_answers = agent.all_answer[q_id]
            predictions['answer'][q_id] = agent.answer[q_id]
            # Candidate with the largest value in agent.all_answer[q_id], or 'noanswer' if empty.
            predictions['norm_answer'][q_id] = max(all_answers, key=all_answers.get) if all_answers else 'noanswer'
            predictions['sp'][q_id] = list(agent.sp_facts[q_id].keys())
            predictions['_sp'][q_id] = list(agent.all_sp_facts[q_id].keys())
            predictions['spp'][q_id] = agent.sp_passages[q_id]
        else:
            # Still searching: execute the command; exclusion=agent.memory[q_id] presumably keeps
            # already-memorized items out of the result (alternative tried: agent.observed[q_id]).
            obs = env.step(cmd, q_id, exclusion=agent.memory[q_id])
            _q_ids.append(q_id)
            _questions.append(question)
            _observations.append(obs)
    q_ids, questions, observations = _q_ids, _questions, _observations
    n_observed += 1
    print(f"{n_observed}: {len(observations)} remained")
    if n_observed % 5 == 0 and split != 'test':
        print(pretty_metrics(evaluate(predictions, gold_samples)))
        # Mean number of agent.observed entries per question; guarded against an empty dict.
        print(sum(len(x) for x in agent.observed.values()) / max(len(agent.observed), 1))
# force_answer is defined in an earlier cell; presumably it finalizes predictions for
# questions still open when the step budget ran out.
force_answer()

In [None]:
# Resume the act -> execute loop until every question is answered or MAX_OBS steps have run.
# Loop state (agent, env, q_ids, questions, observations, predictions, n_observed, ...) comes
# from earlier cells, so re-running this cell continues from the current n_observed.
MAX_OBS = 50
set_global_logging_level(logging.WARNING, ["elasticsearch", "dpr"])

while len(observations) > 0 and n_observed < MAX_OBS:
    commands = agent.act(q_ids, questions, observations, review=(n_observed + 1) % 100 == 0)
    _q_ids, _questions, _observations = [], [], []
    for q_id, question in tqdm(zip(q_ids, questions), desc='exe', total=len(q_ids)):
        cmd = commands[q_id]
        if cmd[0] == 'ANSWER':
            # Question finished: copy the agent's final outputs into `predictions`
            # and do not carry it into the next step.
            all_answers = agent.all_answer[q_id]
            predictions['answer'][q_id] = agent.answer[q_id]
            # Candidate with the largest value in agent.all_answer[q_id], or 'noanswer' if empty.
            predictions['norm_answer'][q_id] = max(all_answers, key=all_answers.get) if all_answers else 'noanswer'
            predictions['sp'][q_id] = list(agent.sp_facts[q_id].keys())
            predictions['_sp'][q_id] = list(agent.all_sp_facts[q_id].keys())
            predictions['spp'][q_id] = agent.sp_passages[q_id]
        else:
            # Still searching: execute the command; exclusion=agent.memory[q_id] presumably keeps
            # already-memorized items out of the result (alternative tried: agent.observed[q_id]).
            obs = env.step(cmd, q_id, exclusion=agent.memory[q_id])
            _q_ids.append(q_id)
            _questions.append(question)
            _observations.append(obs)
    q_ids, questions, observations = _q_ids, _questions, _observations
    n_observed += 1
    print(f"{n_observed}: {len(observations)} remained")
    if n_observed % 10 == 0 and split != 'test':
        print(pretty_metrics(evaluate(predictions, gold_samples)))
        # Mean number of agent.observed entries per question; guarded against an empty dict.
        print(sum(len(x) for x in agent.observed.values()) / max(len(agent.observed), 1))
# force_answer is defined in an earlier cell; presumably it finalizes predictions for
# questions still open when the step budget ran out.
force_answer()

In [None]:
# Resume the act -> execute loop until every question is answered or MAX_OBS steps have run.
# Loop state (agent, env, q_ids, questions, observations, predictions, n_observed, ...) comes
# from earlier cells, so re-running this cell continues from the current n_observed.
MAX_OBS = 100
set_global_logging_level(logging.WARNING, ["elasticsearch", "dpr"])

while len(observations) > 0 and n_observed < MAX_OBS:
    commands = agent.act(q_ids, questions, observations, review=(n_observed + 1) % 100 == 0)
    _q_ids, _questions, _observations = [], [], []
    for q_id, question in tqdm(zip(q_ids, questions), desc='exe', total=len(q_ids)):
        cmd = commands[q_id]
        if cmd[0] == 'ANSWER':
            # Question finished: copy the agent's final outputs into `predictions`
            # and do not carry it into the next step.
            all_answers = agent.all_answer[q_id]
            predictions['answer'][q_id] = agent.answer[q_id]
            # Candidate with the largest value in agent.all_answer[q_id], or 'noanswer' if empty.
            predictions['norm_answer'][q_id] = max(all_answers, key=all_answers.get) if all_answers else 'noanswer'
            predictions['sp'][q_id] = list(agent.sp_facts[q_id].keys())
            predictions['_sp'][q_id] = list(agent.all_sp_facts[q_id].keys())
            predictions['spp'][q_id] = agent.sp_passages[q_id]
        else:
            # Still searching: execute the command; exclusion=agent.memory[q_id] presumably keeps
            # already-memorized items out of the result (alternative tried: agent.observed[q_id]).
            obs = env.step(cmd, q_id, exclusion=agent.memory[q_id])
            _q_ids.append(q_id)
            _questions.append(question)
            _observations.append(obs)
    q_ids, questions, observations = _q_ids, _questions, _observations
    n_observed += 1
    print(f"{n_observed}: {len(observations)} remained")
    if n_observed % 10 == 0 and split != 'test':
        print(pretty_metrics(evaluate(predictions, gold_samples)))
        # Mean number of agent.observed entries per question; guarded against an empty dict.
        print(sum(len(x) for x in agent.observed.values()) / max(len(agent.observed), 1))
# force_answer is defined in an earlier cell; presumably it finalizes predictions for
# questions still open when the step budget ran out.
force_answer()

In [None]:
# Resume the act -> execute loop for up to MAX_OBS total steps, tracking steps with a single
# outer progress bar (agent.act's own bar is disabled). Loop state comes from earlier cells,
# so re-running this cell continues from the current n_observed.
MAX_OBS = 300
set_global_logging_level(logging.WARNING, ["elasticsearch", "dpr"])

# `with` closes the bar even if the loop raises or is interrupted; the total is clamped so a
# re-run after MAX_OBS has already been reached does not create a bar with a negative total.
with tqdm(total=max(MAX_OBS - n_observed, 0)) as pbar:
    while len(observations) > 0 and n_observed < MAX_OBS:
        pbar.set_description(f"act - {len(observations)}")
        commands = agent.act(q_ids, questions, observations, review=(n_observed + 1) % 100 == 0, disable_tqdm=True)
        pbar.set_description(f"exe - {len(observations)}")
        _q_ids, _questions, _observations = [], [], []
        for q_id, question in zip(q_ids, questions):
            cmd = commands[q_id]
            if cmd[0] == 'ANSWER':
                # Question finished: copy the agent's final outputs into `predictions`.
                all_answers = agent.all_answer[q_id]
                predictions['answer'][q_id] = agent.answer[q_id]
                # Candidate with the largest value in agent.all_answer[q_id], or 'noanswer' if empty.
                predictions['norm_answer'][q_id] = max(all_answers, key=all_answers.get) if all_answers else 'noanswer'
                predictions['sp'][q_id] = list(agent.sp_facts[q_id].keys())
                predictions['_sp'][q_id] = list(agent.all_sp_facts[q_id].keys())
                predictions['spp'][q_id] = agent.sp_passages[q_id]
            else:
                # Still searching (alternative exclusion tried: agent.observed[q_id]).
                obs = env.step(cmd, q_id, exclusion=agent.memory[q_id])
                _q_ids.append(q_id)
                _questions.append(question)
                _observations.append(obs)
        q_ids, questions, observations = _q_ids, _questions, _observations
        n_observed += 1
        pbar.update(1)
        if n_observed % 50 == 0 and split != 'test':
            print(f"{n_observed}: {len(observations)} remained")
            print(pretty_metrics(evaluate(predictions, gold_samples)))
            # Mean number of agent.observed entries per question; guarded against an empty dict.
            print(sum(len(x) for x in agent.observed.values()) / max(len(agent.observed), 1))
# force_answer is defined in an earlier cell; presumably it finalizes predictions for
# questions still open when the step budget ran out.
force_answer()

In [None]:
# Resume the act -> execute loop for up to MAX_OBS total steps, tracking steps with a single
# outer progress bar (agent.act's own bar is disabled). Loop state comes from earlier cells,
# so re-running this cell continues from the current n_observed.
MAX_OBS = 500
set_global_logging_level(logging.WARNING, ["elasticsearch", "dpr"])

# `with` closes the bar even if the loop raises or is interrupted; the total is clamped so a
# re-run after MAX_OBS has already been reached does not create a bar with a negative total.
with tqdm(total=max(MAX_OBS - n_observed, 0)) as pbar:
    while len(observations) > 0 and n_observed < MAX_OBS:
        pbar.set_description(f"act - {len(observations)}")
        commands = agent.act(q_ids, questions, observations, review=(n_observed + 1) % 100 == 0, disable_tqdm=True)
        pbar.set_description(f"exe - {len(observations)}")
        _q_ids, _questions, _observations = [], [], []
        for q_id, question in zip(q_ids, questions):
            cmd = commands[q_id]
            if cmd[0] == 'ANSWER':
                # Question finished: copy the agent's final outputs into `predictions`.
                all_answers = agent.all_answer[q_id]
                predictions['answer'][q_id] = agent.answer[q_id]
                # Candidate with the largest value in agent.all_answer[q_id], or 'noanswer' if empty.
                predictions['norm_answer'][q_id] = max(all_answers, key=all_answers.get) if all_answers else 'noanswer'
                predictions['sp'][q_id] = list(agent.sp_facts[q_id].keys())
                predictions['_sp'][q_id] = list(agent.all_sp_facts[q_id].keys())
                predictions['spp'][q_id] = agent.sp_passages[q_id]
            else:
                # Still searching (alternative exclusion tried: agent.observed[q_id]).
                obs = env.step(cmd, q_id, exclusion=agent.memory[q_id])
                _q_ids.append(q_id)
                _questions.append(question)
                _observations.append(obs)
        q_ids, questions, observations = _q_ids, _questions, _observations
        n_observed += 1
        pbar.update(1)
        if n_observed % 50 == 0 and split != 'test':
            print(f"{n_observed}: {len(observations)} remained")
            print(pretty_metrics(evaluate(predictions, gold_samples)))
            # Mean number of agent.observed entries per question; guarded against an empty dict.
            print(sum(len(x) for x in agent.observed.values()) / max(len(agent.observed), 1))
# force_answer is defined in an earlier cell; presumably it finalizes predictions for
# questions still open when the step budget ran out.
force_answer()

In [None]:
# Resume the act -> execute loop for up to MAX_OBS total steps, tracking steps with a single
# outer progress bar (agent.act's own bar is disabled). Loop state comes from earlier cells,
# so re-running this cell continues from the current n_observed.
MAX_OBS = 700
set_global_logging_level(logging.WARNING, ["elasticsearch", "dpr"])

# `with` closes the bar even if the loop raises or is interrupted; the total is clamped so a
# re-run after MAX_OBS has already been reached does not create a bar with a negative total.
with tqdm(total=max(MAX_OBS - n_observed, 0)) as pbar:
    while len(observations) > 0 and n_observed < MAX_OBS:
        pbar.set_description(f"act - {len(observations)}")
        commands = agent.act(q_ids, questions, observations, review=(n_observed + 1) % 100 == 0, disable_tqdm=True)
        pbar.set_description(f"exe - {len(observations)}")
        _q_ids, _questions, _observations = [], [], []
        for q_id, question in zip(q_ids, questions):
            cmd = commands[q_id]
            if cmd[0] == 'ANSWER':
                # Question finished: copy the agent's final outputs into `predictions`.
                all_answers = agent.all_answer[q_id]
                predictions['answer'][q_id] = agent.answer[q_id]
                # Candidate with the largest value in agent.all_answer[q_id], or 'noanswer' if empty.
                predictions['norm_answer'][q_id] = max(all_answers, key=all_answers.get) if all_answers else 'noanswer'
                predictions['sp'][q_id] = list(agent.sp_facts[q_id].keys())
                predictions['_sp'][q_id] = list(agent.all_sp_facts[q_id].keys())
                predictions['spp'][q_id] = agent.sp_passages[q_id]
            else:
                # Still searching (alternative exclusion tried: agent.observed[q_id]).
                obs = env.step(cmd, q_id, exclusion=agent.memory[q_id])
                _q_ids.append(q_id)
                _questions.append(question)
                _observations.append(obs)
        q_ids, questions, observations = _q_ids, _questions, _observations
        n_observed += 1
        pbar.update(1)
        if n_observed % 50 == 0 and split != 'test':
            print(f"{n_observed}: {len(observations)} remained")
            print(pretty_metrics(evaluate(predictions, gold_samples)))
            # Mean number of agent.observed entries per question; guarded against an empty dict.
            print(sum(len(x) for x in agent.observed.values()) / max(len(agent.observed), 1))
# force_answer is defined in an earlier cell; presumably it finalizes predictions for
# questions still open when the step budget ran out.
force_answer()

In [None]:
# Resume the act -> execute loop for up to MAX_OBS total steps, tracking steps with a single
# outer progress bar (agent.act's own bar is disabled). Loop state comes from earlier cells,
# so re-running this cell continues from the current n_observed.
MAX_OBS = 1000
set_global_logging_level(logging.WARNING, ["elasticsearch", "dpr"])

# `with` closes the bar even if the loop raises or is interrupted; the total is clamped so a
# re-run after MAX_OBS has already been reached does not create a bar with a negative total.
with tqdm(total=max(MAX_OBS - n_observed, 0)) as pbar:
    while len(observations) > 0 and n_observed < MAX_OBS:
        pbar.set_description(f"act - {len(observations)}")
        commands = agent.act(q_ids, questions, observations, review=(n_observed + 1) % 100 == 0, disable_tqdm=True)
        pbar.set_description(f"exe - {len(observations)}")
        _q_ids, _questions, _observations = [], [], []
        for q_id, question in zip(q_ids, questions):
            cmd = commands[q_id]
            if cmd[0] == 'ANSWER':
                # Question finished: copy the agent's final outputs into `predictions`.
                all_answers = agent.all_answer[q_id]
                predictions['answer'][q_id] = agent.answer[q_id]
                # Candidate with the largest value in agent.all_answer[q_id], or 'noanswer' if empty.
                predictions['norm_answer'][q_id] = max(all_answers, key=all_answers.get) if all_answers else 'noanswer'
                predictions['sp'][q_id] = list(agent.sp_facts[q_id].keys())
                predictions['_sp'][q_id] = list(agent.all_sp_facts[q_id].keys())
                predictions['spp'][q_id] = agent.sp_passages[q_id]
            else:
                # Still searching (alternative exclusion tried: agent.observed[q_id]).
                obs = env.step(cmd, q_id, exclusion=agent.memory[q_id])
                _q_ids.append(q_id)
                _questions.append(question)
                _observations.append(obs)
        q_ids, questions, observations = _q_ids, _questions, _observations
        n_observed += 1
        pbar.update(1)
        if n_observed % 50 == 0 and split != 'test':
            print(f"{n_observed}: {len(observations)} remained")
            print(pretty_metrics(evaluate(predictions, gold_samples)))
            # Mean number of agent.observed entries per question; guarded against an empty dict.
            print(sum(len(x) for x in agent.observed.values()) / max(len(agent.observed), 1))
# force_answer is defined in an earlier cell; presumably it finalizes predictions for
# questions still open when the step budget ran out.
force_answer()

In [None]:
# Deliberate stop: raising here halts "Run All" so the larger step-budget cells below
# only execute when run by hand.
assert False

In [None]:
# Resume the act -> execute loop for up to MAX_OBS total steps, tracking steps with a single
# outer progress bar. Loop state comes from earlier cells, so re-running this cell continues
# from the current n_observed.
MAX_OBS = 1500
set_global_logging_level(logging.WARNING, ["elasticsearch", "dpr"])

# `with` closes the bar even if the loop raises or is interrupted; the total is clamped so a
# re-run after MAX_OBS has already been reached does not create a bar with a negative total.
with tqdm(total=max(MAX_OBS - n_observed, 0)) as pbar:
    while len(observations) > 0 and n_observed < MAX_OBS:
        pbar.set_description(f"act - {len(observations)}")
        # disable_tqdm=True (as in the 300-1000 cells) so agent.act does not print a nested
        # progress bar under the outer one on every step.
        commands = agent.act(q_ids, questions, observations, review=(n_observed + 1) % 100 == 0, disable_tqdm=True)
        pbar.set_description(f"exe - {len(observations)}")
        _q_ids, _questions, _observations = [], [], []
        for q_id, question in zip(q_ids, questions):
            cmd = commands[q_id]
            if cmd[0] == 'ANSWER':
                # Question finished: copy the agent's final outputs into `predictions`.
                all_answers = agent.all_answer[q_id]
                predictions['answer'][q_id] = agent.answer[q_id]
                # Candidate with the largest value in agent.all_answer[q_id], or 'noanswer' if empty.
                predictions['norm_answer'][q_id] = max(all_answers, key=all_answers.get) if all_answers else 'noanswer'
                predictions['sp'][q_id] = list(agent.sp_facts[q_id].keys())
                predictions['_sp'][q_id] = list(agent.all_sp_facts[q_id].keys())
                predictions['spp'][q_id] = agent.sp_passages[q_id]
            else:
                # Still searching (alternative exclusion tried: agent.observed[q_id]).
                obs = env.step(cmd, q_id, exclusion=agent.memory[q_id])
                _q_ids.append(q_id)
                _questions.append(question)
                _observations.append(obs)
        q_ids, questions, observations = _q_ids, _questions, _observations
        n_observed += 1
        pbar.update(1)
        if n_observed % 100 == 0 and split != 'test':
            print(f"{n_observed}: {len(observations)} remained")
            print(pretty_metrics(evaluate(predictions, gold_samples)))
            # Mean number of agent.observed entries per question; guarded against an empty dict.
            print(sum(len(x) for x in agent.observed.values()) / max(len(agent.observed), 1))
# force_answer is defined in an earlier cell; presumably it finalizes predictions for
# questions still open when the step budget ran out.
force_answer()

In [None]:
# Resume the act -> execute loop for up to MAX_OBS total steps, tracking steps with a single
# outer progress bar. Loop state comes from earlier cells, so re-running this cell continues
# from the current n_observed.
MAX_OBS = 2000
set_global_logging_level(logging.WARNING, ["elasticsearch", "dpr"])

# `with` closes the bar even if the loop raises or is interrupted; the total is clamped so a
# re-run after MAX_OBS has already been reached does not create a bar with a negative total.
with tqdm(total=max(MAX_OBS - n_observed, 0)) as pbar:
    while len(observations) > 0 and n_observed < MAX_OBS:
        pbar.set_description(f"act - {len(observations)}")
        # disable_tqdm=True (as in the 300-1000 cells) so agent.act does not print a nested
        # progress bar under the outer one on every step.
        commands = agent.act(q_ids, questions, observations, review=(n_observed + 1) % 100 == 0, disable_tqdm=True)
        pbar.set_description(f"exe - {len(observations)}")
        _q_ids, _questions, _observations = [], [], []
        for q_id, question in zip(q_ids, questions):
            cmd = commands[q_id]
            if cmd[0] == 'ANSWER':
                # Question finished: copy the agent's final outputs into `predictions`.
                all_answers = agent.all_answer[q_id]
                predictions['answer'][q_id] = agent.answer[q_id]
                # Candidate with the largest value in agent.all_answer[q_id], or 'noanswer' if empty.
                predictions['norm_answer'][q_id] = max(all_answers, key=all_answers.get) if all_answers else 'noanswer'
                predictions['sp'][q_id] = list(agent.sp_facts[q_id].keys())
                predictions['_sp'][q_id] = list(agent.all_sp_facts[q_id].keys())
                predictions['spp'][q_id] = agent.sp_passages[q_id]
            else:
                # Still searching (alternative exclusion tried: agent.observed[q_id]).
                obs = env.step(cmd, q_id, exclusion=agent.memory[q_id])
                _q_ids.append(q_id)
                _questions.append(question)
                _observations.append(obs)
        q_ids, questions, observations = _q_ids, _questions, _observations
        n_observed += 1
        pbar.update(1)
        if n_observed % 100 == 0 and split != 'test':
            print(f"{n_observed}: {len(observations)} remained")
            print(pretty_metrics(evaluate(predictions, gold_samples)))
            # Mean number of agent.observed entries per question; guarded against an empty dict.
            print(sum(len(x) for x in agent.observed.values()) / max(len(agent.observed), 1))
# force_answer is defined in an earlier cell; presumably it finalizes predictions for
# questions still open when the step budget ran out.
force_answer()

# Analyze

## Error types

In [None]:
# Size of every case set recorded by the agent, keyed by case name.
print(dict(zip(agent.cases, map(len, agent.cases.values()))))

### **ans none and early ans**

In [None]:
# Overlap of the 'ans_none' and 'early_ans' case sets: |ans_none|, |ans_none & early_ans|, |early_ans|.
ans_none_ids, early_ans_ids = agent.cases['ans_none'], agent.cases['early_ans']
print(len(ans_none_ids), len(ans_none_ids & early_ans_ids), len(early_ans_ids))

In [None]:
# Split the 'ans_none' case set against 'good_spp', then 'good_ans'. Each line prints:
# count outside the reference set, count inside it, case-set total.
case_ids = agent.cases['ans_none']
for ref_key in ('good_spp', 'good_ans'):
    ref_ids = agent.cases[ref_key]
    print(len(case_ids - ref_ids), len(case_ids & ref_ids), len(case_ids))

In [None]:
# Split the 'early_ans' case set against 'good_spp', then 'good_ans'. Each line prints:
# count outside the reference set, count inside it, case-set total.
case_ids = agent.cases['early_ans']
for ref_key in ('good_spp', 'good_ans'):
    ref_ids = agent.cases[ref_key]
    print(len(case_ids - ref_ids), len(case_ids & ref_ids), len(case_ids))

### false final sp sentences

In [None]:
# Split the 'bad_sp' case set against 'good_spp', then 'good_ans'. Each line prints:
# count outside the reference set, count inside it, case-set total.
case_ids = agent.cases['bad_sp']
for ref_key in ('good_spp', 'good_ans'):
    ref_ids = agent.cases[ref_key]
    print(len(case_ids - ref_ids), len(case_ids & ref_ids), len(case_ids))

### **false final sp passages**

In [None]:
# Question ids present in agent.observed that are not in the 'good_spp' case set.
bad_spp_cases = set(agent.observed).difference(agent.cases['good_spp'])

In [None]:
# Split bad_spp_cases against 'good_ans': count outside it, count inside it, total.
ref_ids = agent.cases['good_ans']
print(len(bad_spp_cases - ref_ids), len(bad_spp_cases & ref_ids), len(bad_spp_cases))

In [None]:
from collections import Counter

# For each question in bad_spp_cases, count the keys shared by agent.pretty_memory(q_id) and
# gold_qas_map[q_id][2] (presumably the gold supporting passages — confirm), then tally the counts.
rec_spp_nums = []
for bad_q_id in bad_spp_cases:
    memory_keys = set(agent.pretty_memory(bad_q_id).keys())
    gold_keys = set(gold_qas_map[bad_q_id][2].keys())
    rec_spp_nums.append(len(memory_keys & gold_keys))

print(Counter(rec_spp_nums))

### **add distractor into memory**

In [None]:
# Split the 'add_dis' case set against 'good_spp', then 'good_ans'. Each line prints:
# count outside the reference set, count inside it, case-set total.
case_ids = agent.cases['add_dis']
for ref_key in ('good_spp', 'good_ans'):
    ref_ids = agent.cases[ref_key]
    print(len(case_ids - ref_ids), len(case_ids & ref_ids), len(case_ids))

### **miss sp passage**

In [None]:
# Split the 'miss_sp' case set against 'good_spp', then 'good_ans'. Each line prints:
# count outside the reference set, count inside it, case-set total.
case_ids = agent.cases['miss_sp']
for ref_key in ('good_spp', 'good_ans'):
    ref_ids = agent.cases[ref_key]
    print(len(case_ids - ref_ids), len(case_ids & ref_ids), len(case_ids))

### miss 2nd sp passage

In [None]:
# Split the 'miss_sp2' case set against 'good_spp', then 'good_ans'. Each line prints:
# count outside the reference set, count inside it, case-set total.
case_ids = agent.cases['miss_sp2']
for ref_key in ('good_spp', 'good_ans'):
    ref_ids = agent.cases[ref_key]
    print(len(case_ids - ref_ids), len(case_ids & ref_ids), len(case_ids))

### remove sp

In [None]:
# Split the 'rm_sp' case set against 'good_spp', then 'good_ans'. Each line prints:
# count outside the reference set, count inside it, case-set total.
case_ids = agent.cases['rm_sp']
for ref_key in ('good_spp', 'good_ans'):
    ref_ids = agent.cases[ref_key]
    print(len(case_ids - ref_ids), len(case_ids & ref_ids), len(case_ids))

In [None]:
# Break the 'rm_sp' case set down by membership in 'add_dis', then each half by 'good_spp'.
# Named intermediates replace reliance on `-` binding tighter than `&`.
rm_sp_ids = agent.cases['rm_sp']
add_dis_ids = agent.cases['add_dis']
good_spp_ids = agent.cases['good_spp']
rm_with_dis = rm_sp_ids & add_dis_ids
rm_without_dis = rm_sp_ids - add_dis_ids
print(len(rm_with_dis), len(rm_without_dis), len(rm_sp_ids))
print(len(rm_with_dis & good_spp_ids), len(rm_with_dis - good_spp_ids))
print(len(rm_without_dis & good_spp_ids), len(rm_without_dis - good_spp_ids))

### **false dense query expansion**

In [None]:
# Split the 'false_expansion' case set against 'good_spp', then 'good_ans'. Each line prints:
# count outside the reference set, count inside it, case-set total.
case_ids = agent.cases['false_expansion']
for ref_key in ('good_spp', 'good_ans'):
    ref_ids = agent.cases[ref_key]
    print(len(case_ids - ref_ids), len(case_ids & ref_ids), len(case_ids))

In [None]:
# Venn-style counts for 'false_expansion' (F) vs 'add_dis' (A):
# |F - A|, |F|, |F & A|, |A|, |A - F|.
false_exp_ids, add_dis_ids = agent.cases['false_expansion'], agent.cases['add_dis']
print(*map(len, (
    false_exp_ids - add_dis_ids,
    false_exp_ids,
    false_exp_ids & add_dis_ids,
    add_dis_ids,
    add_dis_ids - false_exp_ids,
)))

## final evidence set size

In [None]:
from collections import defaultdict

# Bucket question ids by len(agent.memory[q_id]) (the final evidence set size). For each size, in
# ascending order, print the bucket size, then: outside/inside 'good_spp', outside/inside 'good_ans'.
msz2cases = defaultdict(set)
for q_id, mem in agent.memory.items():
    msz2cases[len(mem)].add(q_id)

for msz, cases in sorted(msz2cases.items()):
    good_spp_ids = agent.cases['good_spp']
    good_ans_ids = agent.cases['good_ans']
    print(f"{msz}: {len(cases)}")
    print(len(cases - good_spp_ids), len(cases & good_spp_ids), len(cases - good_ans_ids), len(cases & good_ans_ids))
    print()

## Case study

In [None]:
# Print every non-empty value of agent.exhausted_cmds.
for exhausted in (v for v in agent.exhausted_cmds.values() if len(v) > 0):
    print(exhausted)

In [None]:
# Answered questions whose final memory holds at least 3 entries: show the first one found,
# then the total count.
big_memory = [(k, v) for k, v in agent.memory.items() if len(v) >= 3 and k in predictions['answer']]
if big_memory:
    print(*big_memory[0])
cnt = len(big_memory)
print(cnt)

In [None]:
# Question under study. Other ids previously inspected:
# 5ab3abe9554299753aec597e 5a7363ec5542991f29ee2dd7 5ae54bbd5542990ba0bbb262 5adf29f05542993344016c09 5a764c0b55429976ec32bd89
q_id = '5a7149595542994082a3e76f'

# For answer / sp / spp: print the final prediction (if one was recorded), then the agent's
# fuller view. Lambdas defer each agent lookup so calls happen in the original order.
for pred_key, show_agent_view in (
    ('answer', lambda: agent.all_answer[q_id]),
    ('sp', lambda: agent.all_sp_facts[q_id]),
    ('spp', lambda: agent.pretty_memory(q_id)),
):
    if q_id in predictions[pred_key]:
        print(predictions[pred_key][q_id])
    print(show_agent_view())

In [None]:
# Display the gold record stored for the selected q_id.
gold_qas_map[q_id]

In [None]:
# Print the agent's formatted behavior summary for the selected q_id.
print(agent.pretty_behavior(q_id))

In [None]:
# Compact trace for q_id. Line 0 ('Q') shows agent.commands[q_id][0]; each later line shows the
# step's observation score, the title(s) from agent.ids2titles, and the command at that step.
cmd_history = agent.commands[q_id]
print(f"*{0:<3d} Q\t{agent.pretty_cmd(cmd_history[0])}")
# agent.proposals stays in the zip because zip stops at its shortest input, which fixes the step count.
steps = zip(agent.ids2titles(agent.observed[q_id]), agent.obs_scores[q_id], cmd_history[1:], agent.proposals[q_id][1:])
for i, (obs, obs_score, cmd, proposals) in enumerate(steps):
    print(f"*{i + 1:<3d} {obs_score:.3f} {obs}\t{agent.pretty_cmd(cmd)}")

In [None]:
# Detailed trace for q_id: each step's command from agent.commands, followed by that step's
# entries from agent.proposals, each unpacked as (cmd, conf, prob).
def _print_proposals(step_proposals):
    """Print one line per (cmd, conf, prob) proposal as 'prob  conf  cmd'.

    A separate name for the proposal's command avoids the original inner loop rebinding the
    outer `cmd` variable. The spacing is also unified so every step formats proposals the same way.
    """
    for prop_cmd, conf, prob in step_proposals:
        print(f"{prob:.3f}  {conf:.2f}  {agent.pretty_cmd(prop_cmd)}")


print(f"*{0:<3d} Q    {agent.pretty_cmd(agent.commands[q_id][0])}")
_print_proposals(agent.proposals[q_id][0])
for i, (obs, obs_score, cmd, proposals) in enumerate(zip(agent.ids2titles(agent.observed[q_id]), agent.obs_scores[q_id], agent.commands[q_id][1:], agent.proposals[q_id][1:])):
    print(f"*{i + 1:<3d} {obs_score:.3f} -- {obs}    {agent.pretty_cmd(cmd)}")
    _print_proposals(proposals)

In [None]:
# Raw agent.clicked entry for the selected q_id (shown as the cell's output).
agent.clicked[q_id]

In [None]:
# Raw agent.observed entry for the selected q_id (shown as the cell's output).
agent.observed[q_id]

In [None]:
# Raw agent.obs_scores entry for the selected q_id (shown as the cell's output).
agent.obs_scores[q_id]