In [1]:
import json
from tqdm.notebook import tqdm
from datasets import load_dataset
from sklearn.model_selection import train_test_split

In [2]:
# Load the 'train' split of the "ms_marco" dataset, configuration 'v2.1'.
dataset = load_dataset("ms_marco", split='train', name='v2.1')

Found cached dataset ms_marco (/home/admin/.cache/huggingface/datasets/ms_marco/v2.1/2.1.0/b6a62715fa5219aea5275dd3556601004cd63945cb63e36e022f77bb3cbbca84)


In [3]:
# Number of records in the loaded split.
print(len(dataset))

808731


Уберу бесполезные параграфы — те, что не отмечены в `is_selected`.

In [4]:
def filter_passages(passages, is_selected):
    """Pair every passage with its selection flag and keep the flagged pairs.

    Args:
        passages: iterable of passages.
        is_selected: iterable of flags aligned with ``passages``; a pair is
            kept when its flag is truthy.

    Returns:
        A list of ``(passage, flag)`` tuples in the original order. Pairing
        stops at the end of the shorter input.
    """
    return [pair for pair in zip(passages, is_selected) if pair[1]]

In [5]:
# Counters for the substring statistic: n_all counts the records examined,
# n_substr counts those whose first answer occurs inside a selected passage.
n_all, n_substr = 0, 0

In [53]:
# Reset the counters in this cell so that re-running it does not double-count.
n_all, n_substr = 0, 0
for data in tqdm(dataset):
    passages = filter_passages(data['passages']['passage_text'], data['passages']['is_selected'])
    # Skip records with no selected passage; also skip an empty answer list,
    # which would otherwise raise IndexError on the lookup below.
    if not passages or not data['answers']:
        continue
    # Only the first answer is checked; lower-case it once for all passages.
    answer = data['answers'][0].lower()
    n_all += 1
    # filter_passages returns (text, flag) pairs, so passage[0] is the text.
    # The record is counted once if any selected passage contains the answer.
    if any(answer in passage[0].lower() for passage in passages):
        n_substr += 1

  0%|          | 0/808731 [00:00<?, ?it/s]

Доля записей, в которых первый ответ целиком (без учёта регистра) содержится в одном из выбранных параграфов

In [54]:
# Number of records counted by the statistics loop (those with at least one selected passage).
n_all

502939

In [55]:
# Fraction of the counted records whose first answer is a substring of a selected passage.
n_substr / n_all

0.541503045100897

**Алгоритм обработки датасета**

Убрать все неиспользуемые параграфы. Убрать все записи без используемых параграфов. Оставить лишь первый ответ. Убрать записи без ответов. Оставить только ответы, состоящие из 5–30 токенов. Оставить лишь записи, в которых входная последовательность (вопрос и параграфы) содержит не более 200 токенов.

In [5]:
import deeppavlov
from deeppavlov.models.preprocessors.torch_transformers_preprocessor import TorchTransformersGenerativeQAPreprocessor
import transformers
from transformers import AutoTokenizer

[nltk_data] Downloading package punkt to /home/admin/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to /home/admin/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package perluniprops to
[nltk_data]     /home/admin/nltk_data...
[nltk_data]   Package perluniprops is already up-to-date!
[nltk_data] Downloading package nonbreaking_prefixes to
[nltk_data]     /home/admin/nltk_data...
[nltk_data]   Package nonbreaking_prefixes is already up-to-date!


In [32]:
# Despite its name, this is a DeepPavlov generative-QA preprocessor object
# configured with "google/mt5-base", not a bare Hugging Face tokenizer.
# NOTE(review): answer_maxlength=-1 presumably disables answer truncation — confirm in the DeepPavlov docs.
tokenizer = TorchTransformersGenerativeQAPreprocessor("google/mt5-base", answer_maxlength=-1)

In [40]:
def form_dataset(dataset, max_input_tokens=200, min_answer_tokens=5, max_answer_tokens=30):
    """Build a filtered list of ``[question, passages, answer]`` records.

    For every record, only the passages flagged in ``is_selected`` are kept
    and only the first answer is used. Records with no selected passage or
    with an empty answer list are skipped. Each remaining record is passed
    through the module-level ``tokenizer`` and kept only if both length
    limits below hold.

    Args:
        dataset: iterable of records exposing ``'query'``, ``'answers'`` and
            ``'passages'`` (with ``'passage_text'`` and ``'is_selected'``).
        max_input_tokens: inclusive upper bound on ``len(tokens[0])``, the
            first output of ``tokenizer`` (presumably the encoded question
            plus passages — confirm against the preprocessor docs).
        min_answer_tokens: inclusive lower bound on ``len(ans[0])``, the
            third output of ``tokenizer`` (presumably the encoded answer).
        max_answer_tokens: inclusive upper bound on ``len(ans[0])``.

    Returns:
        A list of ``[question, passages, answer]`` lists, where ``passages``
        is the list of selected passage texts.
    """
    new_dataset = []
    for data in tqdm(dataset):
        selected = filter_passages(data['passages']['passage_text'], data['passages']['is_selected'])
        # An empty answer list would raise IndexError below, so skip it
        # together with records that have no selected passage.
        if not selected or not data['answers']:
            continue
        answer = data['answers'][0]
        # filter_passages returns (text, flag) pairs; keep only the texts.
        passages = [text for text, _flag in selected]
        question = data['query']

        # The preprocessor is called on a batch of one; the mask is not needed.
        tokens, _mask, ans = tokenizer([question], [passages], [answer])
        input_ok = len(tokens[0]) <= max_input_tokens
        answer_ok = min_answer_tokens <= len(ans[0]) <= max_answer_tokens
        if input_ok and answer_ok:
            new_dataset.append([question, passages, answer])
    return new_dataset

In [41]:
# Filter the split currently bound to `dataset`; form_dataset calls the tokenizer once per candidate record.
new_dataset = form_dataset(dataset)

  0%|          | 0/808731 [00:00<?, ?it/s]

In [42]:
# Number of records that passed the filters.
print(len(new_dataset))

274119


In [43]:
# Persist the filtered records as JSON; ensure_ascii=False keeps non-ASCII text unescaped.
with open('../../datasets/msmarco/train_preprocessed.json', 'w', encoding='utf-8') as train_file:
    json.dump(new_dataset, train_file, ensure_ascii=False)

In [50]:
# Load the 'validation' split. NOTE: this rebinds `dataset`, so the split
# loaded earlier is no longer reachable under that name.
dataset = load_dataset("ms_marco", split='validation', name='v2.1')

Found cached dataset ms_marco (/home/admin/.cache/huggingface/datasets/ms_marco/v2.1/2.1.0/b6a62715fa5219aea5275dd3556601004cd63945cb63e36e022f77bb3cbbca84)


In [51]:
# Apply the same filtering to the split currently bound to `dataset`
# and store the result as the dev file.
new_dataset = form_dataset(dataset)
with open('../../datasets/msmarco/dev_preprocessed.json', 'w', encoding='utf-8') as dev_file:
    json.dump(new_dataset, dev_file, ensure_ascii=False)

  0%|          | 0/101093 [00:00<?, ?it/s]

In [57]:
# Reload both preprocessed files from disk.
with open('../../datasets/msmarco/train_preprocessed.json', 'r', encoding='utf-8') as train_file:
    train = json.load(train_file)
with open('../../datasets/msmarco/dev_preprocessed.json', 'r', encoding='utf-8') as dev_file:
    dev = json.load(dev_file)

# Split the dev records in half into validation and test parts;
# the fixed random_state makes the split repeatable.
valid, test = train_test_split(dev, test_size=0.5, random_state=2022)
new_dataset = {"train": train, "valid": valid, "test": test}

# NOTE(review): a later cell writes to this same path with a different record layout.
with open('../../datasets/msmarco/ms_marco_preprocessed.json', 'w', encoding='utf-8') as out_file:
    json.dump(new_dataset, out_file, ensure_ascii=False)

In [58]:
# Regroup every record from [question, contexts, answer] into
# [[question, contexts], answer], keeping the same split keys.
nd = {
    split: [[[question, contexts], answer] for question, contexts, answer in records]
    for split, records in new_dataset.items()
}

In [59]:
# Write the regrouped records; this replaces whatever was stored at this path before.
with open('../../datasets/msmarco/ms_marco_preprocessed.json', 'w', encoding='utf-8') as out_file:
    json.dump(nd, out_file, ensure_ascii=False)

In [62]:
# Print every validation record that carries more than one context passage
# (d[0][1] is the passage list inside [[question, contexts], answer]).
for d in (rec for rec in nd['valid'] if len(rec[0][1]) > 1):
    print(d)

[['who won hr derby', ['Sign Up for the ASG Newsletter. Todd Frazier hit a homer in bonus time to beat Joc Pederson, 15-14, becoming the first to win on his home field since 1990. There are three different player modes in the free update of the Home Run Derby 15 game.', 'Giancarlo Stanton #27 of the Miami Marlins has the longest home runs and is has the best odds in Las Vegas to win the 2016 T-Mobile Home Run Derby (Photo by Rich Schultz/Getty Images) So, the T-Mobile Home Run Derby from San Diego at Petco Park airs tonight at 8 p.m. ET/5 PT on Walt Disney-owned ESPN and simulcast on MLB.com.']], 'Todd Frazier won the Home Run Derby 15 and Giancarlo Stanton won 2016 T-Mobile Home Run Derby.']
[["north america's longest rivers", ['The longest river in North America is the Missouri River, actually just a couple of kilometers longer than the Mississippi, which it joins at St. Louis, Missouri. Next is the Yukon river in Northwest Canada and Alaska, then the Rio Grande which forms part of t