# 处理流程
- 将原始数据集随机打乱，得到**乱序数据集**；
- 可视化乱序数据集的上下文长度分布；
- 保存乱序数据集。

In [1]:
import os

# Use the Hugging Face mirror.
# This must be set BEFORE `transformers` / `datasets` are imported:
# huggingface_hub reads HF_ENDPOINT when it is first imported, so assigning it
# after those imports does not change the endpoint used by this process.
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

from abc import abstractmethod
import gc
import logging
import random
from tqdm import tqdm
from transformers import StoppingCriteria
import numpy as np
from collections import defaultdict
import torch
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM
from transformers import StoppingCriteriaList

# Load environment variables from .env file
from dotenv import load_dotenv
load_dotenv()

import argparse
import pickle

import logging
import os
import json
import hashlib
import datasets

from collections import Counter
import matplotlib.pyplot as plt
from datasets.arrow_dataset import Dataset

In [2]:
# Load the dataset
from datasets import load_dataset
dataset = load_dataset("google-research-datasets/natural_questions", "default")
print(dataset)

# Flatten every split into {example id -> {id, question, document, answers}}.
data_dict = defaultdict(dict)
for split in dataset:
    data_dict[split] = defaultdict(dict)
for split, ds in dataset.items():
    records = data_dict[split]
    for d in tqdm(ds, desc=split):
        # Indexing the defaultdict creates the record on first access.
        record = records[d['id']]
        record['id'] = d['id']
        record['question'] = d['question']['text']
        record['document'] = d['document']['html']
        # Each short_answers entry carries a list of strings under 'text';
        # collect all of them into one flat list per example.
        answers = []
        for sa in d['annotations']['short_answers']:
            answers += sa['text']
        record['answers'] = answers
    print(f"Split: {split}, Number of examples: {len(records)}")

Resolving data files:   0%|          | 0/287 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/287 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/235 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'document', 'question', 'long_answer_candidates', 'annotations'],
        num_rows: 307373
    })
    validation: Dataset({
        features: ['id', 'document', 'question', 'long_answer_candidates', 'annotations'],
        num_rows: 7830
    })
})


train:  62%|██████▏   | 189562/307373 [51:02<31:43, 61.90it/s]  


KeyboardInterrupt: 

In [None]:
# Show the first (id, record) pair of the train split as a sanity check.
first_id, first_record = list(data_dict['train'].items())[0]
print(first_id)
print(first_record)

In [None]:
# Give every record a fresh list holding the same answer strings.
# Only the records themselves are needed, so iterate over `.values()`; this
# also avoids binding a notebook-global loop variable named `id`, which would
# shadow the builtin `id()` in every later cell.
for sub_dict in data_dict.values():
    for d in sub_dict.values():
        d['answers'] = list(d['answers'])

In [None]:
# Count answer lengths (in characters).
# `answer_lengths` keeps the totals over all splits, while each plot is drawn
# from that split's own counter. Plotting the running total would make every
# plot after the first include the counts of the splits before it, even though
# the title names a single split.
answer_lengths = Counter()
for split, sub_dict in data_dict.items():
    split_lengths = Counter(
        len(a) for d in sub_dict.values() for a in d['answers']
    )
    answer_lengths.update(split_lengths)

    # Visualize the answer lengths of this split
    plt.bar(list(split_lengths.keys()), list(split_lengths.values()))
    plt.xlabel('Answer Length')
    plt.ylabel('Count')
    plt.title(f'{split} Answer Length Distribution')
    plt.show()

In [None]:
def filter_short_answers(example):
    """Return True when the example has at least one short-answer string.

    ``example["annotations"]["short_answers"]`` is a list of entries, and
    each entry keeps its answer strings in a list under ``"text"``. An entry
    can be present with an empty ``"text"`` list, so the number of entries
    says nothing about whether an answer exists; the texts themselves have to
    be checked.
    """
    short_answers = example["annotations"]["short_answers"]
    return any(len(entry["text"]) > 0 for entry in short_answers)
# Keep only the examples accepted by filter_short_answers, then save the
# filtered dataset to disk.
dataset = dataset.filter(filter_short_answers)
dataset.save_to_disk("/home/song/dataset/song/nq_short")

In [None]:
# Display the field names of the first train example.
dataset['train'][0].keys()

In [None]:
# Display the question text of the first train example.
dataset['train'][0]['question']['text']

In [None]:
# Check that the byte offsets of the first short answer locate its text inside
# the document HTML.
d = dataset['train'][0]
t = d['document']['html']
first_answer = d['annotations']['short_answers'][0]
start_byte = first_answer['start_byte'][0]
end_byte = first_answer['end_byte'][0]
print(start_byte, end_byte)
# start_byte / end_byte are byte offsets, but `t` is a str indexed by
# character, so slicing `t` directly drifts as soon as the document contains a
# multi-byte character. Slice the encoded bytes instead.
# NOTE(review): assumes the offsets refer to the UTF-8 encoding of the HTML -
# confirm against the dataset documentation.
print(t.encode('utf-8')[start_byte:end_byte].decode('utf-8', errors='replace'))
print(first_answer['text'])

In [None]:
# Conclusion recorded by the author: every example has exactly one entry,
# i.e. len(d['annotations']['short_answers']) == 1.
# NOTE(review): that is the number of entries; the answer strings themselves
# live in each entry's 'text' list, which is what the loop below measures.

total = len(dataset['train'])
from tqdm import tqdm
short_answers_count = []  # number of answer strings in each short_answers entry
short_answers_len = []  # character length of every answer string
for d in tqdm(dataset['train']):
    for sa in d['annotations']['short_answers']:
        texts = sa['text']
        short_answers_count.append(len(texts))
        short_answers_len.extend(len(text) for text in texts)

# Histogram of the per-entry answer counts first, then of the answer text
# lengths (same order as the two lists above).
for values in (short_answers_count, short_answers_len):
    plt.hist(values, bins=20)
    plt.show()

In [None]:
# Display the complete first train example.
dataset['train'][0]

In [None]:
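# TODO: the code below has not been adapted to natural_questions yet. It still
# reads 'context' and 'answers' fields (and saves under a "squad" path), while
# the dataset loaded above only provides id / document / question /
# long_answer_candidates / annotations.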

In [None]:
def reformat(x):
    """Return a flat copy of example *x* with only the fields used downstream.

    Keeps 'id', 'question' and 'context' as they are and replaces the nested
    'answers' mapping by its 'text' value. Any other field of *x* is dropped.
    """
    return {
        'id': x['id'],
        'question': x['question'],
        'context': x['context'],
        'answers': x['answers']['text'],
    }
# Drop the examples whose answer text list is empty and flatten the rest.
train_dataset = [
    reformat(example)
    for example in dataset["train"]
    if example['answers']['text']
]
validation_dataset = [
    reformat(example)
    for example in dataset["validation"]
    if example['answers']['text']
]
# Report both sizes first, then one sample from each list.
for rows in (train_dataset, validation_dataset):
    print(len(rows))
for rows in (train_dataset, validation_dataset):
    print(rows[0])

In [None]:
# Seed the RNG so the shuffles below are reproducible
random.seed(42)
# Shuffle both original lists in place (train first, then validation)
for rows in (train_dataset, validation_dataset):
    random.shuffle(rows)

# Cut the shuffled original train list 9:1 into the new train / validation lists
total_size = len(train_dataset)
train_size = int(total_size * 0.9)
new_train_dataset, new_validation_dataset = (
    train_dataset[:train_size],
    train_dataset[train_size:],
)

# The original validation list serves as the new test list
new_test_dataset = validation_dataset

In [None]:
# Add an irrelevant context to every example
def add_irrelevant_context(dataset):
    """Attach to each example the context of a different example.

    Every element of *dataset* (a list of dicts with 'id' and 'context') gets
    two new keys, written in place:

    * 'irrelevant_id'      - an id taken from *dataset* that differs from the
      example's own id;
    * 'irrelevant_context' - the context stored for that id.

    The assignment is a random permutation of the ids, drawn with the
    module-level ``random`` generator and redrawn until no position keeps its
    own id.

    Returns:
        The same *dataset* object, after mutation.

    Raises:
        ValueError: if no such permutation exists (a single example, or one
            id filling more than half of the positions). Without this check
            the redraw loop would never terminate.
    """
    # Ids in dataset order
    id_list = [d['id'] for d in dataset]

    # A permutation with no id at its original position exists only if the
    # most frequent id can be placed entirely on positions of other ids.
    if id_list and 2 * max(Counter(id_list).values()) > len(id_list):
        raise ValueError(
            "cannot assign an irrelevant context: every example needs a "
            "different example id to be paired with"
        )

    # id -> context lookup
    id_context_dict = {d['id']: d['context'] for d in dataset}

    while True:
        # Draw a new ordering of the ids
        shuffled_id_list = id_list.copy()
        random.shuffle(shuffled_id_list)
        # Accept it only if no position kept its own id
        if all(new != old for new, old in zip(shuffled_id_list, id_list)):
            break

    # Add the new fields to the examples in place
    for d, irrelevant_id in zip(dataset, shuffled_id_list):
        d['irrelevant_id'] = irrelevant_id
        d['irrelevant_context'] = id_context_dict[irrelevant_id]

    return dataset

# add_irrelevant_context mutates the list it receives and returns that same
# list, so the three names keep pointing at the updated data. The order
# (train, validation, test) is kept so the random draws happen in the same
# sequence.
for split_rows in (new_train_dataset, new_validation_dataset, new_test_dataset):
    add_irrelevant_context(split_rows)
    print(split_rows)

In [None]:
# Visualize the distribution of context lengths
def plot_context_length(dataset):
    """Show a 100-bin histogram of len(example['context']) over *dataset*."""
    context_lengths = list(map(len, (example['context'] for example in dataset)))
    plt.hist(context_lengths, bins=100, edgecolor='black')
    plt.title('Context Length Distribution in Filtered Dataset')
    plt.xlabel('Context Length')
    plt.ylabel('Count')
    plt.show()

for rows in (train_dataset, validation_dataset):
    plot_context_length(rows)

In [None]:
# Reuse new_train_dataset / new_validation_dataset / new_test_dataset exactly
# as they were built by the earlier shuffle-and-split cell.
# Re-seeding, re-shuffling and re-splitting here would move examples between
# train and validation AFTER add_irrelevant_context() paired them within their
# split, so a saved example could carry an 'irrelevant_context' that belongs
# to an example of a different split.

# Combine the three splits into one DatasetDict
dataset = datasets.DatasetDict({
    'train': datasets.Dataset.from_list(new_train_dataset),
    'validation': datasets.Dataset.from_list(new_validation_dataset),
    'test': datasets.Dataset.from_list(new_test_dataset),
})
print(dataset)

# Save the new dataset
# NOTE(review): the target directory is still named "squad" - confirm this is
# the intended location for this dataset.
dataset.save_to_disk("/Users/song/datasets/song/squad")

In [None]:
# Display the final DatasetDict.
dataset