In [1]:
# Reload imported modules automatically before each cell runs, so edits to project code take effect without a kernel restart.
%load_ext autoreload
%autoreload 2
# Expose only physical GPUs 2 and 3 to this process; inside the notebook they are addressed as cuda:0 and cuda:1.
# Must run before torch initialises CUDA (comments are kept off the magic line so they are not part of the value).
%env CUDA_VISIBLE_DEVICES=2,3

env: CUDA_VISIBLE_DEVICES=2,3


In [2]:
import logging
import os
import random
import re
import numpy as np
import torch

def set_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and every CUDA device) and force deterministic cuDNN.

    Args:
        seed: integer seed applied to every generator.

    Note: writing PYTHONHASHSEED into os.environ only affects subprocesses started afterwards;
    the running interpreter's str-hash seed was fixed when it started.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers the multi-GPU case
    )
    for seeder in seeders:
        seeder(seed)
    # Give up cuDNN autotuning in exchange for reproducible kernel selection.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    
def set_global_logging_level(level=logging.ERROR, prefixes=("",)):
    """
    Override logging levels of different modules based on their name as a prefix.
    It needs to be invoked after the modules have been loaded so that their loggers have been initialized.

    Args:
        level: desired level. Optional. Default is logging.ERROR
        prefixes: a single str prefix, or an iterable of str prefixes to match
            (e.g. ["transformers", "torch"]). Optional.
            Default is `("",)` to match all active loggers.
            The match is a case-sensitive, literal `module_name.startswith(prefix)`:
            characters such as "." are not treated as regex wildcards.
            An empty iterable matches no loggers.
    """
    if isinstance(prefixes, str):
        # A bare string would otherwise be iterated character by character.
        prefixes = (prefixes,)
    prefixes = tuple(prefixes)
    # Snapshot the names first: logging.getLogger() can replace/insert entries in loggerDict.
    for name in list(logging.root.manager.loggerDict):
        if name.startswith(prefixes):
            logging.getLogger(name).setLevel(level)

In [3]:
# Root logger at INFO with a single console (stderr) handler; handlers left over from an
# earlier run of this cell are removed first so records are not printed twice.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
for stale_handler in list(logger.handlers):
    logger.removeHandler(stale_handler)
console = logging.StreamHandler()
logger.addHandler(console)

set_seed(0)

## Load examples

In [40]:
import json

split = 'train'
# One entry per TSV line: (q_id, (question, answer, sp_facts)), where sp_facts is the decoded JSON column.
samples = []
# HotpotQA titles/questions contain non-ASCII text: decode as UTF-8 explicitly rather than
# relying on the platform's locale default.
with open(f"data/hotpot-{split}.tsv", encoding='utf-8') as f:
    for line in f:
        if not line.strip():
            # Tolerate blank (e.g. trailing) lines instead of failing the 4-field unpack.
            continue
        q_id, question, answer, sp_facts = line.strip().split('\t')
        sp_facts = json.loads(sp_facts)
        samples.append((q_id, (question, answer, sp_facts)))
print(len(samples))
questions = [sample[1][0] for sample in samples]  # (N,)
# Drop a single trailing '?'; these strings are what the dense retriever is queried/cached with.
questions = [q[:-1] if q.endswith('?') else q for q in questions]
print(len(questions))

90447
90447


## Load sparse retriever and query generator

In [5]:
from retriever import SparseRetriever

# Elasticsearch-backed sparse (BM25) retriever; index name and host are deployment-specific.
ES_INDEX = 'enwiki-20171001-paragraph-5'
ES_HOSTS = ['10.60.0.59:9200']
sparse_retriever = SparseRetriever(ES_INDEX, ES_HOSTS, timeout=60)

PyTorch version 1.2.0 available.
Loading faiss with AVX2 support.
Loading faiss.


In [6]:
# %env CLASSPATH=corenlp/*
from drqa.reader import Predictor
import warnings

warnings.filterwarnings("ignore")

GLOVE_PATH = 'data/glove.840B.300d.txt'


def _load_query_generator(ckpt_path):
    """Build a DrQA Predictor from `ckpt_path` with the shared GloVe embeddings and call .cuda() on it."""
    generator = Predictor(model=ckpt_path, tokenizer=None, embedding_file=GLOVE_PATH, num_workers=-1)
    generator.cuda()
    return generator


# qg1 is queried with the question alone; qg2 with the question plus the first supporting passage.
qg1 = _load_query_generator('ckpts/golden-retriever/hop1.mdl')
qg2 = _load_query_generator('ckpts/golden-retriever/hop2.mdl')

Initializing model...
Loading model ckpts/golden-retriever/hop1.mdl
Expanding dictionary...
Adding 2115762 new words to dictionary...
New vocab size: 2195963
Loading pre-trained embeddings for 2115762 words from data/glove.840B.300d.txt
WARN: Duplicate embedding found for ����������������������������������������������������������������������������������������������������������������������������������������������������������������������������������������������������������������������������������������������������������������������������������������������������������������������������������������������
WARN: Duplicate embedding found for Kṛṣṇa
WARN: Duplicate embedding found for Decisión
WARN: Duplicate embedding found for decisión
WARN: Duplicate embedding found for Viṣṇu
WARN: Duplicate embedding found for Justiça
WARN: Duplicate embedding found for Acórdão
WARN: Duplicate embedding found for ̲̅̅
WARN: Duplicate embedding found for Câmara
WARN: Duplicate embedding found for 

## Load dense retriever and query encoder

In [7]:
import faiss
# Allow faiss up to 16 OpenMP threads while loading; a later cell lowers this to 1.
FAISS_NUM_THREADS = 16
faiss.omp_set_num_threads(FAISS_NUM_THREADS)

In [8]:
from argparse import Namespace

# Dense-retrieval configuration shared by the encoder, the index and the search calls below.
args = Namespace(
    model_name="roberta-base",
    model_path="ckpts/mdr/q_encoder.pt",
    index_prefix_path="data/index/mdr/hotpot-paragraph-q-strict.hnsw",
    index_buffer_size=50000,
    max_q_len=70,       # max length passed when searching with the question alone
    max_q_sp_len=350,   # max length passed when searching with question + first supporting passage
)

In [9]:
from transformers import AutoConfig, AutoTokenizer
from mdr.retrieval.models.retriever import RobertaCtxEncoder
from utils.model_utils import load_state

bert_config = AutoConfig.from_pretrained(args.model_name)
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
# Build the encoder from the config, then load the weights stored at args.model_path (non-strict match).
query_encoder = load_state(RobertaCtxEncoder(bert_config, args), args.model_path, exact=False)
# With CUDA_VISIBLE_DEVICES=2,3, 'cuda:1' is physical GPU 3.
device = torch.device('cuda:1')
query_encoder.to(device)
query_encoder.eval()

loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json from cache at /home/yuxiaoming/.cache/torch/transformers/e1a2a406b5a05063c31f4dfdee7608986ba7c6393f7f79db5e69dcd197208534.117c81977c5979de8c088352e74ec6e70f5c66096c28b61d3c50101609b39690
Model config RobertaConfig {
  "architectures": [
    "RobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "eos_token_id": 2,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "roberta",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "type_vocab_size": 1,
  "vocab_size": 50265
}

loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json from cache at /home/yuxiaoming/.cache/torch/transformers/e1a2

RobertaCtxEncoder(
  (encoder): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [10]:
from dense_indexers import DenseHNSWFlatIndexer, DenseFlatIndexer

# The index stores vectors as wide as the encoder's hidden size (768 for roberta-base).
vector_size = bert_config.hidden_size
# NOTE(review): index_buffer_size presumably batches vector additions — confirm in dense_indexers.
dense_indexer = DenseHNSWFlatIndexer(vector_size, args.index_buffer_size)
# Load the prebuilt HNSW index (a faiss IndexHNSWFlat, per the output below) from args.index_prefix_path.
dense_indexer.deserialize_from(args.index_prefix_path)

Loading index from data/index/mdr/hotpot-paragraph-q-strict.hnsw
Loaded index of type <class 'faiss.swigfaiss.IndexHNSWFlat'> and size 5232080


In [11]:
from retriever import DenseRetriever

# Couples the faiss index with the query encoder and tokenizer; .search(query, k, max_len)[0]
# is used later to obtain ranked passage ids.
dense_retriever = DenseRetriever(dense_indexer, query_encoder, tokenizer)

## Load corpus

In [12]:
from utils.data_utils import load_corpus

# corpus: passage id -> passage dict; title2id: (unescaped) title -> passage id.
corpus, title2id = load_corpus('data/corpus/hotpot-paragraph-5.tsv', for_hotpot=True, require_hyperlinks=True)
for mapping in (corpus, title2id):
    print(len(mapping))

Loaded 5232080 passages from data/corpus/hotpot-paragraph-5.tsv


5232080
5232080


## Generate greedy step transition data

In [13]:
# From here on only warnings and errors are logged, including those from the elasticsearch client.
logger.setLevel(logging.WARNING)
set_global_logging_level(logging.WARNING, ["elasticsearch"])

In [14]:
import redis

from getpass import getpass

# Connection settings shared by the three caches below. The password is no longer committed in the
# notebook: set REDIS_PASSWORD in the environment, or enter it when prompted.
# NOTE(review): the previously hardcoded password remains in version history — rotate it.
REDIS_HOST, REDIS_PORT = '10.60.1.79', 6379
redis_password = os.environ.get('REDIS_PASSWORD') or getpass('Redis password: ')


def _redis_db(db):
    """Return a client for database `db` on the shared Redis server, decoding replies to str."""
    return redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=db, password=redis_password, decode_responses=True)


query_redis = _redis_db(2)  # generated query text, keyed by question / observation
bm25_redis = _redis_db(3)   # BM25 ranked passage-id lists, keyed by query
mdr_redis = _redis_db(4)    # MDR ranked passage-id lists, keyed by question (optionally + SP1 title)

In [41]:
from html import unescape
from tqdm.auto import tqdm

from utils.data_utils import get_valid_links

faiss.omp_set_num_threads(1)

is_strict = True
OBS_SIZE = 1     # passages observed per retrieval step (converts a rank into a step count)
RET_SIZE = 1000  # retrieval depth that is searched and cached per query
hotpot_filter = {"term": {"for_hotpot": True}}
STRATEGIES = ["BM25", "BM25+Link", "MDR", "MDR+Link"]
NOT_FOUND = 2 * RET_SIZE  # rank recorded for an SP that does not appear in the top RET_SIZE


def _cached_hits(cache, key, search, *search_args):
    """Return the ranked passage ids cached in Redis list `key`, (re)filling it with `search(*search_args)` on a miss.

    A cached list shorter than RET_SIZE is trusted only if it ends with the 'EOL' sentinel,
    which marks that the retriever itself returned fewer than RET_SIZE hits. The sentinel is
    stripped from the returned list.
    """
    if not cache.exists(key) or (cache.llen(key) < RET_SIZE and cache.lindex(key, -1) != 'EOL'):
        hits = list(search(*search_args))
        cache.delete(key)
        cache.rpush(key, *(hits + ['EOL'] if len(hits) < RET_SIZE else hits))
    cached = cache.lrange(key, 0, -1)
    if cached and cached[-1] == 'EOL':
        cached = cached[:-1]
    return cached


def _bm25_search(query):
    """Ids of the top RET_SIZE HotpotQA passages returned by the sparse retriever for `query`."""
    return [hit['_id'] for hit in sparse_retriever.search(query, RET_SIZE, filter_dic=hotpot_filter, n_retrieval=RET_SIZE * 2)]


def _mdr_search(query, max_len):
    """Ranked passage ids from the dense (MDR) retriever for `query`, with max query length `max_len`."""
    return dense_retriever.search(query, max(RET_SIZE, 1000), max_len)[0]


def _scan_sp_ranks(hits, sp_ids, norm_sp_titles, ranks, link_ranks, hard_negs):
    """Update, in place, the best rank of each SP retrieved directly (`ranks`) or reached via a
    hyperlink from a retrieved passage (`link_ranks`), and collect hard negatives."""
    for p_idx, p_id in enumerate(hits[:RET_SIZE]):
        hyperlinks = get_valid_links(corpus[p_id], is_strict, title2id)
        linked_sp_titles = norm_sp_titles & set(hyperlinks.keys())
        if p_id in sp_ids:
            ranks[p_id] = min(p_idx, ranks[p_id])
            # passages ranked around an SP are confusable with it
            hard_negs.update(hits[max(0, p_idx - 5):p_idx + 6])
        elif linked_sp_titles:
            hard_negs.add(p_id)
        for sp_title in linked_sp_titles:
            sp_id = title2id[sp_title]
            link_ranks[sp_id] = min(p_idx, link_ranks[sp_id])
        # stop once every SP has been reached both directly and via a link
        if max(list(ranks.values()) + list(link_ranks.values())) <= p_idx:
            break
    hard_negs.update(hits[:10])


def _scan_sp2_ranks(hits, sp2_id, norm_sp2_title, hard_negs):
    """Return (direct rank, link rank) of SP2 in `hits` (NOT_FOUND if absent), collecting hard negatives."""
    rank = link_rank = NOT_FOUND
    for p_idx, p_id in enumerate(hits[:RET_SIZE]):
        hyperlinks = get_valid_links(corpus[p_id], is_strict, title2id)
        if p_id == sp2_id:
            rank = min(p_idx, rank)
            hard_negs.update(hits[max(0, p_idx - 5):p_idx + 6])
        elif norm_sp2_title in hyperlinks.keys():
            hard_negs.add(p_id)
            link_rank = min(p_idx, link_rank)
        if max(rank, link_rank) <= p_idx:
            break
    hard_negs.update(hits[:10])
    return rank, link_rank


examples = []
step_data_path = f"data/hotpot-step-{split}{'.strict' if is_strict else ''}.jsonl"
# `with` closes/flushes the file even if an assert or a Redis error aborts the run; UTF-8 is
# required because json.dumps(ensure_ascii=False) emits raw non-ASCII text.
with open(step_data_path, 'w', encoding='utf-8') as step_data_file:
    for q_idx, (q_id, (question, answer, sp_facts)) in enumerate(tqdm(samples)):
        norm_sp_titles = set(unescape(t) for t in sp_facts.keys())
        # The greedy two-hop labelling below needs exactly two distinct SP passages.
        if len(norm_sp_titles) != 2:
            print(f"expected 2 supporting facts, got {len(norm_sp_titles)}: {q_id}")
            continue
        # Sorted for reproducible output: str hashing (hence set order) is randomised per process.
        sp_ids = [title2id[t] for t in sorted(norm_sp_titles)]

        hard_negs = set()
        state2action = dict()

        # ==================== initial ====================
        sp_ranks = {strategy: {sp_id: NOT_FOUND for sp_id in sp_ids} for strategy in STRATEGIES}

        # BM25 over the hop-1 generated query
        if not query_redis.exists(question):
            query_redis.set(question, qg1.predict(question, question)[0][0].strip())
        q1 = query_redis.get(question)
        bm25_hits = _cached_hits(bm25_redis, q1, _bm25_search, q1)
        _scan_sp_ranks(bm25_hits, sp_ids, norm_sp_titles, sp_ranks['BM25'], sp_ranks['BM25+Link'], hard_negs)

        # MDR over the question; a cache miss is now filled by search instead of raising IndexError
        qk = questions[q_idx]
        mdr_hits = _cached_hits(mdr_redis, qk, _mdr_search, qk, args.max_q_len)
        assert len(mdr_hits) > 0
        _scan_sp_ranks(mdr_hits, sp_ids, norm_sp_titles, sp_ranks['MDR'], sp_ranks['MDR+Link'], hard_negs)

        # get the greedy initial action
        if min(min(ranks.values()) for ranks in sp_ranks.values()) >= RET_SIZE:
            print(f"Unable recall SP1 through the first {RET_SIZE} retrieval results: {q_id}")
            initial_action = "ANSWER"
        else:
            # compare sorted rank lists lexicographically; following a link costs one extra step
            initial_action = min(['BM25', 'MDR'], key=lambda k: sorted(list(sp_ranks[k].values()) + [x + 1 for x in sp_ranks[f'{k}+Link'].values()]))
        state2action['initial'] = {"query": q1, "action": initial_action, "sp_ranks": sp_ranks}

        # ==================== partial ====================
        for sp1_id, sp2_id in (sp_ids, sp_ids[::-1]):
            sp1, sp2 = corpus[sp1_id], corpus[sp2_id]
            norm_sp1_title = unescape(sp1['title'])
            norm_sp2_title = unescape(sp2['title'])
            # SP1 text from the start of its first sentence span to the end of its last
            sp1_body = sp1['text'][sp1['sentence_spans'][0][0]:sp1['sentence_spans'][-1][1]]

            # BM25 over the hop-2 query generated from question + SP1
            obs = ' '.join([question, f"<t> {sp1['title']} </t> {sp1_body}"])
            if not query_redis.exists(obs):
                query_redis.set(obs, qg2.predict(obs, question)[0][0].strip())
            q2 = query_redis.get(obs)
            bm25_hits = _cached_hits(bm25_redis, q2, _bm25_search, q2)
            bm25_rank, bm25_link_rank = _scan_sp2_ranks(bm25_hits, sp2_id, norm_sp2_title, hard_negs)

            # MDR over (question, SP1 text), falling back to SP1's title when its text span is empty
            qk = f"{questions[q_idx]}\t+++\t{norm_sp1_title}"
            expanded_query = (questions[q_idx], sp1_body if sp1_body else sp1['title'])
            mdr_hits = _cached_hits(mdr_redis, qk, _mdr_search, expanded_query, args.max_q_sp_len)
            assert len(mdr_hits) > 0
            mdr_rank, mdr_link_rank = _scan_sp2_ranks(mdr_hits, sp2_id, norm_sp2_title, hard_negs)

            sp2_ranks = {"BM25": bm25_rank, "BM25+Link": bm25_link_rank, "MDR": mdr_rank, "MDR+Link": mdr_link_rank}

            # get the greedy subsequent action
            if norm_sp2_title in set(unescape(t) for t in get_valid_links(sp1, is_strict, title2id).keys()):
                action = "LINK"
            elif min(sp2_ranks.values()) >= RET_SIZE:
                print(f"{q_id}: Unable recall SP2({repr(norm_sp2_title)}) through the first {RET_SIZE} retrieval(Q + {repr(norm_sp1_title)}) results")
                action = "ANSWER"
            else:
                # steps needed to reach SP2 per strategy; following a link costs one extra step
                sp2_steps = {strategy: (rank + OBS_SIZE) // OBS_SIZE + (1 if strategy.endswith('+Link') else 0)
                             for strategy, rank in sp2_ranks.items()}
                action = min(['BM25', 'MDR'], key=lambda k: sorted([sp2_steps[k], sp2_steps[f'{k}+Link']]))
            state2action[norm_sp1_title] = {"query": q2, "action": action, "sp2_ranks": sp2_ranks}

        # add SPs' out-neighbors to hard negatives, then make sure no SP is labelled a negative
        for sp_id in sp_ids:
            for out_title in get_valid_links(corpus[sp_id], is_strict, title2id):
                hard_negs.add(title2id[out_title])
        hard_negs -= set(sp_ids)

        example = {
            "_id": q_id,
            "question": question,
            "answer": answer,
            "sp_facts": {unescape(t): sents for t, sents in sp_facts.items()},
            # in- and out- neighbors of SPs, top ranked passages; sorted for reproducible output
            "hard_negs": sorted(hard_negs, key=str),
            "state2action": state2action,
        }
        examples.append(example)
        step_data_file.write(json.dumps(example, ensure_ascii=False) + '\n')

HBox(children=(FloatProgress(value=0.0, max=90447.0), HTML(value='')))

5adf8d5b5542995534e8c7e1: Unable recall SP2('2016 Marrakesh ePrix') through the first 1000 retrieval(Q + 'Sébastien Buemi') results
5abbbd0f55429931dba144d5: Unable recall SP2('Hot Feet') through the first 1000 retrieval(Q + 'Maurice Hines') results
5ab1ce395542997061209579: Unable recall SP2('The Adventures of Brer Rabbit') through the first 1000 retrieval(Q + "Br'er Rabbit") results
5ac2c958554299657fa2905b: Unable recall SP2('Be Right There') through the first 1000 retrieval(Q + 'The Warriors Gate') results
5ae1130a5542997b2ef7d0d3: Unable recall SP2('John Wallace Crawford') through the first 1000 retrieval(Q + 'American Horse (elder)') results
5a80271a5542992e7d278dfa: Unable recall SP2("Minnesota's 10th congressional district") through the first 1000 retrieval(Q + "Minnesota's 3rd congressional district") results
5a77732b55429967ab105183: Unable recall SP2('Le piccole storie') through the first 1000 retrieval(Q + 'Julius Caesar (play)') results
5a879c8e5542994846c1cdb3: Unable rec