In [1]:
# Standard library
import random
import warnings

# Third-party
import numpy as np
import spacy
from tqdm import tqdm

# Small English spaCy pipeline; used later to score passage-to-passage similarity.
nlp = spacy.load('en_core_web_sm')

# Silence every warning category for the rest of the session.
# NOTE(review): this blanket filter also hides library warnings that may matter
# (e.g. spaCy's notice about similarity on pipelines without word vectors) —
# consider narrowing it.
warnings.filterwarnings('ignore')

# Data Loading

In [2]:
# `Data` is imported from a module named `utils` (presumably project-local —
# it is not part of the standard library).
from utils import Data
# root_path points one directory up; presumably that is where the dataset
# files listed in the log output live — confirm against utils.Data.
data_processor = Data(root_path="../")
# Reads the dataset files; the loaded tables are accessed through
# `data_processor.dataset` in the following cells.
data_processor.read_in_memory()

正在处理文件collection.sampled.tsv 读取文件的格式为('pid', 'passage')
正在处理文件train_sample_queries.tsv 读取文件的格式为('qid', 'query')
正在处理文件train_sample_passv2_qrels.tsv 读取文件的格式为('qid', 'mark', 'pid', 'rating')
正在处理文件val_2021_53_queries.tsv 读取文件的格式为('qid', 'query')
正在处理文件val_2021_passage_top100.txt 读取文件的格式为('qid', 'mark', 'pid', 'rank', 'score', 'sys_id')
正在处理文件val_2021.qrels.pass.final.txt 读取文件的格式为('qid', 'mark', 'pid', 'rating')
正在处理文件test_2022_76_queries.tsv 读取文件的格式为('qid', 'query')
正在处理文件test_2022_passage_top100.txt 读取文件的格式为('qid', 'mark', 'pid', 'rank', 'score', 'sys_id')
正在处理文件test_2022.qrels.pass.withDupes.txt 读取文件的格式为('qid', 'mark', 'pid', 'rating')


In [3]:
# Show the name of every table held in data_processor.dataset, one per line.
for dataset_name in data_processor.dataset.keys():
    print(dataset_name)

collection.sampled
train_sample_queries
train_sample_passv2_qrels
val_2021_53_queries
val_2021_passage_top100
val_2021.qrels.pass.final
test_2022_76_queries
test_2022_passage_top100
test_2022.qrels.pass.withDupes


In [4]:
# Collect the ids of all passages in the sampled collection into a set;
# the generation loop further down draws random negative candidates from it.
set_passage_id = set(data_processor.dataset['collection.sampled'].keys())
print("Total number of passages: ", len(set_passage_id))

Total number of passages:  126799


# DocT5query

In [5]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Local directory from which both the tokenizer and the model are loaded.
t5_path = "../model/docT5query/"

# Module-level globals: generate_queries() reads `tokenizer` and `dt5q` directly.
tokenizer = T5Tokenizer.from_pretrained(t5_path)
dt5q = T5ForConditionalGeneration.from_pretrained(t5_path)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [6]:
def generate_queries(passage: str, prefix="text2query", queries_num=2):
    """Sample synthetic queries for a passage with the loaded T5 model.

    The prompt is ``"<prefix>: <passage>"``. It is tokenized (truncated to
    512 tokens) and passed to the module-level ``dt5q`` model, which samples
    ``queries_num`` sequences with nucleus sampling (``top_p=0.95``).

    Args:
        passage: Text of the passage to generate queries for.
        prefix: Task prefix placed in front of the passage.
        queries_num: Number of sequences to sample.

    Returns:
        The decoded sequences (special tokens stripped), as returned by
        ``tokenizer.batch_decode``.

    Note:
        Relies on the module-level globals ``tokenizer`` and ``dt5q``.
        Sampling is stochastic, so repeated calls can return different queries.
    """
    prompt = f"{prefix}: {passage}"
    prompt_ids = tokenizer.encode(
        prompt,
        add_special_tokens=True,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    sampled_sequences = dt5q.generate(
        input_ids=prompt_ids,
        num_return_sequences=queries_num,
        do_sample=True,
        top_p=0.95,
        max_length=512,
    )
    return tokenizer.batch_decode(sampled_sequences, skip_special_tokens=True)

# Data Generation

In [7]:
# Maps a query id to its generated record; filled in by the generation loop below.
generation_dict = {}

In [8]:
# Build one synthetic-training record per training query:
#   * positive  – the first passage listed in the query's qrels entry,
#   * negatives – the random candidates least similar to the positive,
#   * queries   – extra queries written by generate_queries() for the positive.
N_CANDIDATES = 10  # random passages drawn per query as negative candidates
N_NEGATIVES = 4    # candidates kept as negatives (the least similar ones)
MAX_QUERIES = 1    # smoke-test limit (replaces a bare `break`); None = all queries
SEED = 42          # seeds the negative-candidate draw only, not the T5 sampling

random.seed(SEED)
# Built once instead of once per query. Sorted because iteration order of a
# set of strings changes between interpreter runs, which would defeat the seed.
passage_pool = sorted(set_passage_id)

processed = 0
for qid, v in tqdm(data_processor.dataset['train_sample_passv2_qrels'].items()):
    # The qrels entry is iterated by passage id; the first one is the positive.
    # NOTE(review): its rating is not checked — confirm that the training qrels
    # only contain relevant passages.
    positive_pid = next(iter(v))
    positive_passage = data_processor.dataset['collection.sampled'][positive_pid]['passage']

    # Draw one id more than needed, drop the positive if it was drawn, and trim.
    # This is equivalent to sampling N_CANDIDATES ids from the pool without the
    # positive, but avoids rebuilding a ~127k-element list on every iteration.
    drawn = random.sample(passage_pool, N_CANDIDATES + 1)
    sample_passage_id = [pid for pid in drawn if pid != positive_pid][:N_CANDIDATES]

    sampled_passages = [
        nlp(data_processor.dataset['collection.sampled'][pid]['passage'])
        for pid in sample_passage_id
    ]
    doc = nlp(positive_passage)
    # NOTE(review): en_core_web_sm ships without static word vectors, so this
    # similarity is only a rough signal — consider a pipeline with vectors.
    similarity = [
        doc.similarity(sampled_passage)
        for sampled_passage in sampled_passages
    ]

    # Keep the N_NEGATIVES candidates with the lowest similarity as negatives.
    negative_passages_id = [
        sample_passage_id[i]
        for i in np.argsort(similarity)[:N_NEGATIVES]
    ]

    # Generate queries from the positive passage (2 by default).
    generated_queries = generate_queries(positive_passage)

    generation_dict[qid] = {
        "positive_passage_id": [positive_pid],
        "negative_passages_id": negative_passages_id,
        "generated_queries": generated_queries,
    }

    processed += 1
    if MAX_QUERIES is not None and processed >= MAX_QUERIES:
        break

  0%|          | 0/20000 [00:01<?, ?it/s]


In [9]:
# Display the records produced so far for inspection.
generation_dict

{'1185869': {'positive_passage_id': ['msmarco_passage_08_840101254'],
  'negative_passages_id': ['msmarco_passage_38_192481906',
   'msmarco_passage_20_341886578',
   'msmarco_passage_32_869755311',
   'msmarco_passage_11_205128414'],
  'generated_queries': ['which major achievement of the early 20th century was the manhattan project?',
   'why was the manhattan project important?']}}