# Dataset 

仿照 `Trasformers` 库实现数据处理流程，具体包含：

原始数据 是用 `list` 或 `dict` 组织的:

```text
List{
    {'input': '天气很好', 'label': 1},
    {'input': '你好, 'label': 0},
    ...
}
```
或
```text
Dict{
    'input': ['天气很好', '你好', ...],
    'label': [1, 0, ...],
}
```

Dataset 的目标 是为了能够获取 batch token-id 序列：

1. 一批数据里的 token 序列是不等长的
2. 需要截断数据或填充 `<PAD>`
3. 考虑 shuffle 时，动态 padding
4. 在训练过程中，获取批量数据的 `input_ids`, `attention_masks` 和 `label`
5. Transformer 训练较为特殊的一点是：`label` 也是模型的输入，而不仅仅是用于 求 Loss。
6. `Dict` 版本数据存储数据会更优化, `List`版本 会重复存储 多次 `input` 和 `label` 字段

## config

In [17]:
src_len = 6
trg_len = 12
src_vocab_size = 26
trg_vocab_size = 26

PAD_TOKEN_ID = 0
SOS_TOKEN_ID = 1
EOS_TOKEN_ID = 2
UNK_TOKEN_ID = 3

IGNORE_INDEX = -100

## 等长 Dataset

In [5]:
import torch
from torch.utils.data import Dataset
torch.manual_seed(42)

class Seq2SeqDataset(Dataset):
    def __init__(self, X, Y):
        self.X = X
        self.Y = Y
    
    def __len__(self):
        return self.X.shape[0]
    
    def __getitem__(self, idx):
        return {'input_ids':self.X[idx,:], 'label': self.Y[idx,:]}

def create_dataset(N, src_len, trg_len, src_vocab_size, trg_vocab_size):
    X = torch.randint(3, src_vocab_size, (N, src_len), dtype = torch.long) # id 0,1,2,3 为 special token
    Y = torch.randint(3, trg_vocab_size, (N, trg_len), dtype = torch.long) 
    X[:, 0] = SOS_TOKEN_ID # <SOS>
    X[:, src_len - 1] = EOS_TOKEN_ID # <PAD> or <EOS>
    Y[:, 0] = SOS_TOKEN_ID # <SOS>
    Y[:, trg_len - 1] = EOS_TOKEN_ID # <PAD> or <EOS>
    dataset = Seq2SeqDataset(X, Y)
    return dataset

train_dataset = create_dataset(N = 100, 
              src_len = src_len, 
              trg_len = trg_len,  
              src_vocab_size = src_vocab_size, 
              trg_vocab_size = trg_vocab_size)

In [6]:
# 在处理过程中并没有 padding
print('获取 getitem 数据')
print(train_dataset.__getitem__(0)['input_ids'])
print(train_dataset.__getitem__(0)['label'])

# `[]` 运算符获取
print('获取运算符 [] 数据')
print(train_dataset[0]['input_ids'])
print(train_dataset[0]['label'])

# 批量获取 数据
print('获取批量数据')
print(train_dataset[0:3]['input_ids'])
print(train_dataset[0:3]['label'])

获取 getitem 数据
tensor([ 1, 17,  3, 24, 19,  2])
tensor([ 1, 20, 16,  6, 12, 15, 24, 20, 14,  9, 11,  2])
获取运算符 [] 数据
tensor([ 1, 17,  3, 24, 19,  2])
tensor([ 1, 20, 16,  6, 12, 15, 24, 20, 14,  9, 11,  2])
获取批量数据
tensor([[ 1, 17,  3, 24, 19,  2],
        [ 1, 14,  4, 15, 24,  2],
        [ 1, 25,  7, 12, 21,  2]])
tensor([[ 1, 20, 16,  6, 12, 15, 24, 20, 14,  9, 11,  2],
        [ 1,  8, 24, 18,  3, 23, 19, 10, 11,  7, 14,  2],
        [ 1, 19, 22, 16, 13, 24, 22, 19, 15, 13,  4,  2]])


## 变长序列 Dataset 

In [7]:
import torch
import random
from torch.utils.data import Dataset


class Seq2SeqTransformersDataset(Dataset):
    def __init__(self, data):
        self.data = data
    
    def __len__(self):
        return len(self.data['input_ids'])
    
    def __getitem__(self, idx):
        return {'input_ids': self.data['input_ids'][idx], 
                'labels': self.data['labels'][idx]}

def create_transformers_dataset(N, src_len, trg_len, src_vocab_size, trg_vocab_size):

    raw_data = {'input_ids':[], 'labels':[]}
    for i in range(N):
        tmp_src_len = random.randint(3, src_len)
        tmp_trg_len = random.randint(3, trg_len)
        x = torch.randint(3, src_vocab_size, (1, tmp_src_len), dtype = torch.long)
        y = torch.randint(3, trg_vocab_size, (1, tmp_trg_len), dtype = torch.long)
        
        x[:, 0] = SOS_TOKEN_ID # <SOS>
        x[:, tmp_src_len - 1] = EOS_TOKEN_ID # <PAD> or <EOS>
        y[:, 0] = SOS_TOKEN_ID # <SOS>
        y[:, tmp_trg_len - 1] = EOS_TOKEN_ID # <PAD> or <EOS>
        
        raw_data['input_ids'].append(x)
        raw_data['labels'].append(y)
    
    dataset = Seq2SeqTransformersDataset(raw_data)
    return dataset

src_len = 10
trg_len = 20
src_vocab_size = 26
trg_vocab_size = 26

train_dataset = create_transformers_dataset(N = 100, 
              src_len = src_len, 
              trg_len = trg_len,  
              src_vocab_size = src_vocab_size, 
              trg_vocab_size = trg_vocab_size)

In [8]:
# 在处理过程中并没有 padding
print('获取 getitem 数据')
print(train_dataset.__getitem__(0))

# `[]` 运算符获取
print('获取运算符 [] 数据')
print(train_dataset[0]['input_ids'])
print(train_dataset[0]['labels'])

# 批量获取 数据
print('获取批量数据')
print(train_dataset[0:3]['input_ids'])
print(train_dataset[0:3]['labels']) 

# src, trg 长度都不等长

获取 getitem 数据
{'input_ids': tensor([[ 1, 24, 16, 13, 12, 11,  2]]), 'labels': tensor([[ 1,  5, 19, 16,  2]])}
获取运算符 [] 数据
tensor([[ 1, 24, 16, 13, 12, 11,  2]])
tensor([[ 1,  5, 19, 16,  2]])
获取批量数据
[tensor([[ 1, 24, 16, 13, 12, 11,  2]]), tensor([[ 1, 22, 11, 17, 20,  6,  2]]), tensor([[ 1, 11,  8, 18,  7,  2]])]
[tensor([[ 1,  5, 19, 16,  2]]), tensor([[ 1, 22, 16, 11,  5, 15, 20, 24,  8, 10, 20,  2]]), tensor([[ 1,  8,  3, 11,  6, 18, 22, 24,  2]])]


## Dataloader


在训练过程中，从完整的数据集里，随机 shuffle dataset

In [9]:
from torch.utils.data import DataLoader
train_dataloader = DataLoader(train_dataset, 
                    batch_size = 1 ,       # 当调整 batch_size 就报错
                    # collate_fn=collate_fn,
                    pin_memory = True,
                    shuffle = True)

for batch in train_dataloader:
    print(batch)
    print(batch['input_ids'].shape)
    print(batch['labels'].shape)
    break

{'input_ids': tensor([[[ 1,  8, 12,  6, 22,  3, 18, 11, 10,  2]]]), 'labels': tensor([[[ 1, 15, 12,  3, 13, 10,  7, 11, 11, 21, 20, 15, 16,  7, 25,  6, 18,
          24,  2]]])}
torch.Size([1, 1, 10])
torch.Size([1, 1, 19])


## DataCollate

将批数据 padding 处理成等长度，有策略

- 'max_len' : 预设最大长度
- 'longest' : 批量数据最大长度

In [10]:
def collate_fn(batch_data):
    # print(batch_data)
    return batch_data

train_dataloader = DataLoader(train_dataset, 
                    batch_size = 2 ,       # 当调整 batch_size 就报错
                    collate_fn=collate_fn,
                    pin_memory = True,
                    shuffle = True)
    
for batch in train_dataloader:
    print(batch)
    break

[{'input_ids': tensor([[ 1, 13,  2]]), 'labels': tensor([[ 1, 12, 17, 19, 12, 14, 24,  8, 11,  2]])}, {'input_ids': tensor([[1, 7, 2]]), 'labels': tensor([[ 1,  5,  6,  8, 19, 23,  3,  5, 20,  8, 16, 13, 25, 19,  6,  2]])}]


In [11]:
def paddding_collate_fn(batch_data):

    input_lens = []
    label_lens = []
    bs = len(batch_data)
    for data in batch_data:
        input_lens.append( data['input_ids'].shape[1] )
        label_lens.append( data['labels'].shape[1] )

    max_input_len = torch.max(torch.tensor(input_lens, dtype = torch.long))
    max_label_len = torch.max(torch.tensor(label_lens, dtype = torch.long))

    input_ids = torch.ones(bs, max_input_len, dtype = torch.long) * PAD_TOKEN_ID
    input_attention_masks = torch.zeros(bs, max_input_len, dtype = torch.long) 
    label_ids = torch.ones(bs, max_label_len, dtype = torch.long) * PAD_TOKEN_ID
    label_attention_masks = torch.zeros(bs, max_label_len, dtype = torch.long) 

    for i in range(bs):
        input_ids[i, :input_lens[i]] = batch_data[i]['input_ids'][0, :input_lens[i]]
        input_attention_masks[i, :input_lens[i]] = 1
        
        label_ids[i, :label_lens[i]] = batch_data[i]['labels'][0, :label_lens[i]]
        label_attention_masks[i, :label_lens[i]] = 1

    return {
        'input_ids': input_ids,
        'input_attention_mask': input_attention_masks,
        'label_ids': label_ids,
        'label_attention_mask': label_attention_masks,
    }
    

train_dataloader = DataLoader(train_dataset, 
                    batch_size = 4 ,       # 当调整 batch_size 就报错
                    collate_fn=paddding_collate_fn,
                    pin_memory = True,
                    shuffle = True)
    
for batch in train_dataloader:
    print(batch)
    break

{'input_ids': tensor([[ 1, 10, 22,  3,  4,  9,  2,  0,  0,  0],
        [ 1,  5, 10, 15,  6,  7,  2,  0,  0,  0],
        [ 1, 14,  5, 22, 11, 22, 16, 16, 21,  2],
        [ 1, 20, 15, 18, 14, 23, 24,  2,  0,  0]]), 'input_attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]), 'label_ids': tensor([[ 1, 14,  3, 10, 14,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 1, 14,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 1, 17, 12,  4,  6,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 1,  8,  7, 25, 11, 15,  7, 12,  7, 19, 18,  4, 23,  3,  9,  2]]), 'label_attention_mask': tensor([[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}


## 训练细节

In [12]:
for batch in train_dataloader:
    # 训练输入为
    decoder_input_ids = batch['label_ids'][:, :-1]
    decoder_attention_mask = batch['label_attention_mask'][:, :-1]

    
    label_for_loss = batch['label_ids'][:, 1:]
    label_for_loss[ torch.where(label_for_loss == PAD_TOKEN_ID )  ] = IGNORE_INDEX

    print(label_for_loss)
    # loss_fn(logits, label_for_loss)
    break

tensor([[   7,    5,   18,    4,   22,   25,    5,    5,    3,   25,   13,   15,
           16,    2, -100, -100],
        [  10,    8,   18,   12,   23,   13,    5,    3,   20,    4,   12,   19,
           17,    9,    5,    2],
        [   6,   20,   14,   24,    3,   17,   10,   24,   20,    9,   14,    4,
           17,    2, -100, -100],
        [   6,   18,   16,   24,   25,    2, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100]])


## Attention Mask

在 `paddding_collate_fn` 里，实际上存储了 mask 序列而不是矩阵，好处在于：

1. 不占用过多显存
2. 可以在计算 attention 前，即算即释放

以下实现基于序列 mask 的 cross-attention-mask, 贴近 `Transformers` 库的实现

In [13]:
def get_src_trg_mask(src_mask, trg_mask):
    """
    trg 为检索, dim = 0
    """
    bs, src_len = src_mask.shape
    bs, trg_len = trg_mask.shape
    
    mask = torch.zeros(bs, trg_len, src_len)

    for i in range(bs):
        mask[i, :, :] = torch.outer(trg_mask[i,:], src_mask[i,:])
    return mask

src_mask = torch.tensor([[1, 1, 0, 0, 0],[1, 1, 1, 0, 0]])
trg_mask = torch.tensor([[1, 1, 0,],[1, 1, 0]])

mask = get_src_trg_mask(src_mask, trg_mask)
print(mask)

tensor([[[1., 1., 0., 0., 0.],
         [1., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[1., 1., 1., 0., 0.],
         [1., 1., 1., 0., 0.],
         [0., 0., 0., 0., 0.]]])


In [15]:
import math


for batch in train_dataloader:
    encoder_attention_mask = batch['input_attention_mask']
    decoder_attention_mask = batch['label_attention_mask'][:, :-1]

    # mask
    mask = get_src_trg_mask(encoder_attention_mask, decoder_attention_mask)
    print(mask.shape)
    # print(mask)

    bs, src_len = encoder_attention_mask.shape
    bs, trg_len = decoder_attention_mask.shape

    # multi-head attention score
    heads = 8
    dim = 64
    K = torch.randn( bs, heads, src_len, dim)
    Q = torch.randn( bs, heads, trg_len, dim)

    S = Q @ K.transpose(2,3) / math.sqrt(dim)
    print(S.shape)

    multi_head_mask = mask.unsqueeze(dim = 1) # 扩展头维度，实现头并行 mask
    S_mask = S * multi_head_mask
    print(S_mask.shape)
    
    break 

torch.Size([4, 18, 10])
torch.Size([4, 8, 18, 10])
torch.Size([4, 8, 18, 10])
