In [4]:
import torch
import pandas as pd
import numpy as np

# collate_fn

## 重新定义一个函数

In [17]:
from transformers import AutoTokenizer

# Shared tokenizer for the whole notebook, loaded from the "bert-base-chinese"
# checkpoint with its maximum sequence length configured as 15.
tokenizer = AutoTokenizer.from_pretrained(
    "bert-base-chinese",
    model_max_length=15,
)

In [7]:
# Toy sentence-pair data: one (text_a, text_b, label) tuple per row.
data = pd.DataFrame(
    [
        ("糖尿病该吃什么", "糖尿病患者饮食推荐", 1),
        ("为什么我有高血压", "高血压吃啥合适", 0),
        ("甲状腺该吃什么", "甲状腺病因", 0),
        ("为什么甲状腺激素身高", "甲状腺激素身高理由", 1),
    ],
    columns=["text_a", "text_b", "label"],
)

In [6]:
from torch.utils.data import Dataset, DataLoader

In [25]:
# Preview the row-oriented view of the frame: a list with one dict per row.
data.to_dict(orient="records")

[{'text_a': '糖尿病该吃什么', 'text_b': '糖尿病患者饮食推荐', 'label': 1},
 {'text_a': '为什么我有高血压', 'text_b': '高血压吃啥合适', 'label': 0},
 {'text_a': '甲状腺该吃什么', 'text_b': '甲状腺病因', 'label': 0},
 {'text_a': '为什么甲状腺激素身高', 'text_b': '甲状腺激素身高理由', 'label': 1}]

In [26]:
class PairDataset(Dataset):
    """Map-style dataset that serves one record dict per index.

    The DataFrame passed in is converted once into a list of per-row dicts
    (column name -> value); indexing returns those dicts unchanged.
    """

    def __init__(self, df):
        # Row-oriented dicts make __getitem__ a plain list lookup.
        records = df.to_dict(orient="records")
        self.data = records

    def __len__(self):
        """Return the number of stored records."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the record stored at position ``index``."""
        record = self.data[index]
        return record

In [44]:
def collate_fn_intra(batch):
    """Collate a list of sentence-pair records into padded token-id tensors.

    Uses the notebook-level ``tokenizer``.

    Args:
        batch: the list of items produced by the dataset's ``__getitem__``;
            each item is read through the keys ``'text_a'``, ``'text_b'``
            and ``'label'``.

    Returns:
        dict with three entries:
            ``'input_a'`` / ``'input_b'``: the tokenizer's ``input_ids`` for
            the two text columns, each padded within the batch;
            ``'label'``: ``torch.tensor`` built from the labels.
    """
    text_a = [item['text_a'] for item in batch]
    text_b = [item['text_b'] for item in batch]
    labels = [item['label'] for item in batch]

    # truncation=True makes the tokenizer enforce its configured maximum
    # length; without it, over-long texts would pass through at full length.
    batch_a = tokenizer(text_a, padding=True, truncation=True, return_tensors='pt')
    batch_b = tokenizer(text_b, padding=True, truncation=True, return_tensors='pt')

    # Only the token ids are forwarded; the two sides are padded separately,
    # so their widths may differ.
    return {
        'input_a': batch_a['input_ids'],
        'input_b': batch_b['input_ids'],
        'label': torch.tensor(labels),
    }

In [46]:
ds = PairDataset(data)

# Two records per batch, in dataset order; every batch is passed through
# collate_fn_intra before it reaches the loop below.
dl = DataLoader(
    ds,
    batch_size=2,
    shuffle=False,
    num_workers=0,
    drop_last=False,
    collate_fn=collate_fn_intra,
)

for batch in dl:
    print(batch)
    print('-' * 50)

{'input_a': tensor([[ 101, 5131, 2228, 4567, 6421, 1391,  784,  720,  102,    0],
        [ 101,  711,  784,  720, 2769, 3300, 7770, 6117, 1327,  102]]), 'input_b': tensor([[ 101, 5131, 2228, 4567, 2642, 5442, 7650, 7608, 2972, 5773,  102],
        [ 101, 7770, 6117, 1327, 1391, 1567, 1394, 6844,  102,    0,    0]]), 'label': tensor([1, 0])}
--------------------------------------------------
{'input_a': tensor([[ 101, 4508, 4307, 5593, 6421, 1391,  784,  720,  102,    0,    0,    0],
        [ 101,  711,  784,  720, 4508, 4307, 5593, 4080, 5162, 6716, 7770,  102]]), 'input_b': tensor([[ 101, 4508, 4307, 5593, 4567, 1728,  102,    0,    0,    0,    0],
        [ 101, 4508, 4307, 5593, 4080, 5162, 6716, 7770, 4415, 4507,  102]]), 'label': tensor([0, 1])}
--------------------------------------------------


## 利用lambda

直接把原始文本传给模型，分词（tokenizer）这一步留给模型内部去做

In [94]:
ds = PairDataset(data)

# Identity collate function: each batch stays the raw list of record dicts.
dl = DataLoader(
    ds,
    batch_size=2,
    shuffle=False,
    num_workers=0,
    drop_last=False,
    collate_fn=lambda samples: samples,
)

for batch in dl:
    print(batch)
    print('-' * 50)

[{'text_a': '糖尿病该吃什么', 'text_b': '糖尿病患者饮食推荐', 'label': 1}, {'text_a': '为什么我有高血压', 'text_b': '高血压吃啥合适', 'label': 0}]
--------------------------------------------------
[{'text_a': '甲状腺该吃什么', 'text_b': '甲状腺病因', 'label': 0}, {'text_a': '为什么甲状腺激素身高', 'text_b': '甲状腺激素身高理由', 'label': 1}]
--------------------------------------------------


In [95]:
def input_text(batch):
    """Split a batch of records into three parallel lists of raw fields.

    Args:
        batch: collection of mappings, each read through the keys
            ``'text_a'``, ``'text_b'`` and ``'label'``.

    Returns:
        Tuple ``(text_a, text_b, labels)`` of three lists, each holding one
        field in batch order.
    """
    fields = ('text_a', 'text_b', 'label')
    # One full pass over the batch per field, in the order listed above.
    text_a, text_b, labels = ([record[key] for record in batch] for key in fields)
    return text_a, text_b, labels

In [96]:
ds = PairDataset(data)

# The lambda hands each raw batch to input_text, so the loop receives a
# (text_a, text_b, labels) tuple of plain Python lists.
dl = DataLoader(
    ds,
    batch_size=2,
    shuffle=False,
    num_workers=0,
    drop_last=False,
    collate_fn=lambda samples: input_text(samples),
)

for batch in dl:
    print(batch)
    print('-' * 50)

(['糖尿病该吃什么', '为什么我有高血压'], ['糖尿病患者饮食推荐', '高血压吃啥合适'], [1, 0])
--------------------------------------------------
(['甲状腺该吃什么', '为什么甲状腺激素身高'], ['甲状腺病因', '甲状腺激素身高理由'], [0, 1])
--------------------------------------------------


In [92]:
# 可以写处理函数传一些其他参数进去

def collate_fn_intra(batch, tokenizer):
    """Collate sentence-pair records into padded id tensors with a given tokenizer.

    Note: this shares its name with the one-argument ``collate_fn_intra``
    defined earlier in the notebook; running this cell replaces that
    definition.

    Args:
        batch: the list of items produced by the dataset's ``__getitem__``;
            each item is read through the keys ``'text_a'``, ``'text_b'``
            and ``'label'``.
        tokenizer: callable invoked as
            ``tokenizer(texts, padding=True, truncation=True, return_tensors='pt')``
            whose result is indexed with ``'input_ids'``.

    Returns:
        dict with ``'input_a'`` / ``'input_b'`` (the ``input_ids`` of each
        text column) and ``'label'`` (``torch.tensor`` built from the labels).
    """
    # Trace of the raw batch, kept so the cell output shows what DataLoader
    # hands to the collate function.
    print(batch)
    print('*'*20)

    text_a = [item['text_a'] for item in batch]
    text_b = [item['text_b'] for item in batch]
    labels = [item['label'] for item in batch]

    # truncation=True makes the tokenizer enforce its configured maximum
    # length; without it, over-long texts would pass through at full length.
    batch_a = tokenizer(text_a, padding=True, truncation=True, return_tensors='pt')
    batch_b = tokenizer(text_b, padding=True, truncation=True, return_tensors='pt')

    # Only the token ids are forwarded; the two sides are encoded separately.
    return {
        'input_a': batch_a['input_ids'],
        'input_b': batch_b['input_ids'],
        'label': torch.tensor(labels),
    }

In [93]:
ds = PairDataset(data)

# The lambda binds the extra ``tokenizer`` argument, so DataLoader can still
# call the collate function with the batch alone.
dl = DataLoader(
    ds,
    batch_size=2,
    shuffle=False,
    num_workers=0,
    drop_last=False,
    collate_fn=lambda samples: collate_fn_intra(samples, tokenizer),
)

for batch in dl:
    print(batch)
    print('-' * 50)

[{'text_a': '糖尿病该吃什么', 'text_b': '糖尿病患者饮食推荐', 'label': 1}, {'text_a': '为什么我有高血压', 'text_b': '高血压吃啥合适', 'label': 0}]
********************
{'input_a': tensor([[ 101, 5131, 2228, 4567, 6421, 1391,  784,  720,  102,    0],
        [ 101,  711,  784,  720, 2769, 3300, 7770, 6117, 1327,  102]]), 'input_b': tensor([[ 101, 5131, 2228, 4567, 2642, 5442, 7650, 7608, 2972, 5773,  102],
        [ 101, 7770, 6117, 1327, 1391, 1567, 1394, 6844,  102,    0,    0]]), 'label': tensor([1, 0])}
--------------------------------------------------
[{'text_a': '甲状腺该吃什么', 'text_b': '甲状腺病因', 'label': 0}, {'text_a': '为什么甲状腺激素身高', 'text_b': '甲状腺激素身高理由', 'label': 1}]
********************
{'input_a': tensor([[ 101, 4508, 4307, 5593, 6421, 1391,  784,  720,  102,    0,    0,    0],
        [ 101,  711,  784,  720, 4508, 4307, 5593, 4080, 5162, 6716, 7770,  102]]), 'input_b': tensor([[ 101, 4508, 4307, 5593, 4567, 1728,  102,    0,    0,    0,    0],
        [ 101, 4508, 4307, 5593, 4080, 5162, 6716, 7770, 4415, 4507,  102]]), 'label': tensor([0, 1])}
--------------------------------------------------

# Dataset 与 IterableDataset 的区别

在 matchzoo-py 的源码中看到了 `IterableDataset`，这里整理它与普通（map-style）`Dataset` 的区别

- 使用场景：总有些数据不是能一次性读入内存中的，需要iterable对象批量读取

- `__iter__` 方法的作用是让对象可以用 for...in... 循环遍历；`__getitem__` 方法则让对象可以通过 index 索引的方式访问实例中的元素。

- `IterableDataset` 主要通过实现 `__iter__` 方法来获取数据。注意：当使用多进程加载数据（`num_workers > 0`）时，每个 worker 进程都会持有一份数据集对象的副本，如果不做切分就会产出重复数据；为了避免数据重复，官方给出[两种解决方式](https://pytorch.org/docs/stable/data.html#iterable-style-datasets)


In [88]:
from torch.utils.data import IterableDataset

class IterPairDataset(IterableDataset):
    """Iterable-style dataset that yields the integers ``0 .. nums - 1``.

    In single-process loading every iterator covers the whole range. When
    the DataLoader uses worker processes, each worker iterates its own copy
    of the dataset, so the range is split into disjoint contiguous slices
    (one per worker) to avoid yielding every value once per worker.

    Args:
        nums: exclusive upper bound of the integers to yield.
    """

    def __init__(self, nums):
        self.nums = nums

    def __iter__(self):
        # Imported locally so this block needs no new notebook-level import.
        from torch.utils.data import get_worker_info

        worker = get_worker_info()
        if worker is None:
            # Main process (num_workers=0): serve the full range.
            start, stop = 0, self.nums
        else:
            # Ceil division so that the slices together cover every index.
            per_worker = -(-self.nums // worker.num_workers)
            start = worker.id * per_worker
            stop = min(start + per_worker, self.nums)
        yield from range(start, stop)

In [91]:
iter_ds = IterPairDataset(6)

# No collate_fn is passed here, so DataLoader's default collation groups the
# yielded integers two at a time.
dl = DataLoader(
    iter_ds,
    batch_size=2,
    num_workers=0,
)

for batch in dl:
    print(batch)
    print('-' * 50)

tensor([0, 1])
--------------------------------------------------
tensor([2, 3])
--------------------------------------------------
tensor([4, 5])
--------------------------------------------------
