In [1]:
import math
import os
import random
import torch
from d2l import torch as d2l

In [5]:
# Register the PTB archive (download URL and expected hash) with d2l's downloader.
d2l.DATA_HUB['ptb'] = (
    d2l.DATA_URL + 'ptb.zip',
    '319d85e578af0cdc590547f26231e4e31cdf1e42',
)
def read_ptb():
    """Load the PTB training split into a list of tokenized lines.

    Downloads/extracts the 'ptb' archive via d2l, reads ``ptb.train.txt`` and
    splits it on '\\n'; each line is then split on whitespace.

    Returns:
        list[list[str]]: one token list per line. If the file ends with a
        newline, the last entry is an empty list.
    """
    data_dir = d2l.download_extract('ptb')
    train_path = os.path.join(data_dir, 'ptb.train.txt')
    with open(train_path) as fin:
        lines = fin.read().split('\n')
    return list(map(str.split, lines))

# Tokenized training sentences: one list of tokens per line of ptb.train.txt.
sentences = read_ptb()

In [6]:
print("sentences数：", len(sentences), sep="")

sentences数：42069


In [9]:
# Build the vocabulary for the corpus; every word seen fewer than 10 times
# is replaced by the '<unk>' token.
vocab = d2l.Vocab(sentences, min_freq=10)
print('vocab size:' + str(len(vocab)))

vocab size:6719


下采样：数据集中的每个词都会以一定概率被丢弃，词频越高的词被丢弃的概率越大。

In [14]:
def subsample(sentences, vocab, threshold=1e-4):
    """Subsample high-frequency words (word2vec-style).

    '<unk>' tokens (those for which ``vocab[token] == vocab.unk``) are removed
    first. Each remaining occurrence of token w is then kept with probability
    ``min(1, sqrt(threshold / count(w) * total_count))``, so the more frequent
    a word is, the more likely its occurrences are discarded.

    Args:
        sentences: list of token lists.
        vocab: object supporting ``vocab[token]`` and exposing ``unk``.
        threshold: subsampling constant t; the default 1e-4 reproduces the
            previous hard-coded behaviour.

    Returns:
        (subsampled, counter): the subsampled sentences, and the token counts
        computed after '<unk>' removal but before subsampling.
    """
    # Exclude the unknown token '<unk>'.
    sentences = [[token for token in line if vocab[token] != vocab.unk]
                 for line in sentences]
    counter = d2l.count_corpus(sentences)
    num_tokens = sum(counter.values())

    # The keep probability depends only on the token type, so compute it once
    # per type rather than once per occurrence (same expression, same floats).
    keep_prob = {token: math.sqrt(threshold / count * num_tokens)
                 for token, count in counter.items()}

    # One uniform draw per occurrence, in corpus order; uniform(0, 1) < 1, so
    # probabilities >= 1 always keep the token.
    return ([[token for token in line if random.uniform(0, 1) < keep_prob[token]]
             for line in sentences], counter)

# Drop '<unk>' tokens and randomly discard frequent words; `counter` holds the
# token counts after '<unk>' removal but before the random subsampling.
subsampled, counter = subsample(sentences, vocab)

下采样后，将词元映射到它们在词表中的索引

In [24]:
# Map every subsampled sentence to its list of vocabulary indices.
corpus = [vocab[tokens] for tokens in subsampled]

提取中心词和上下文词

In [38]:
def get_centers_and_contexts(corpus, max_window_size):
    """Return the center words and context words for the skip-gram model.

    For every position of every sentence with at least two tokens, a window
    radius is drawn with ``random.randint(1, max_window_size)``; the tokens
    within that radius, excluding the center itself, form the center's context.

    Args:
        corpus: list of sentences, each a list of tokens (indices).
        max_window_size: largest window radius; ``random.randint`` raises
            ValueError if it is below 1.

    Returns:
        (centers, contexts): a flat list of center tokens and a parallel list
        holding each center's context tokens.
    """
    centers, contexts = [], []
    for tokens in corpus:
        length = len(tokens)
        # A sentence needs at least two tokens to form a center-context pair.
        if length < 2:
            continue
        centers.extend(tokens)
        for pos in range(length):
            radius = random.randint(1, max_window_size)
            lo = max(0, pos - radius)
            hi = min(length, pos + 1 + radius)
            # Everything inside the window except the center itself.
            contexts.append([tokens[j] for j in range(lo, hi) if j != pos])
    return centers, contexts

In [39]:
all_centers, all_contexts = get_centers_and_contexts(corpus, max_window_size=5)
# Each context word of each center contributes one (center, context) pair.
print(f'center-context pairs:{sum(len(ctx) for ctx in all_contexts)}')

center-context pairs:1500778


使用负采样进行近似训练：对每一对中心词和上下文词，采样K个不来自其上下文窗口的噪声词

In [41]:
# Negative sampling is not re-implemented in this notebook; the next cell uses d2l.get_negatives.

In [52]:
# For every (center word, context word) pair, randomly draw K (here 5) noise words.
# A "pair" means the center word together with one word from its context window.
all_negatives = d2l.get_negatives(all_contexts, vocab, counter, 5) 

小批量加载训练实例

In [70]:
def batchify(data):
    """Collate skip-gram examples with negative sampling into a padded minibatch.

    Args:
        data: sequence of ``(center, context, negatives)`` triples; ``context``
            and ``negatives`` are lists that are concatenated with ``+``.

    Returns:
        Four tensors:
        - centers, reshaped to ``(batch, 1)``;
        - contexts_negatives: context words followed by noise words, right-padded
          with 0 up to the longest ``len(context) + len(negatives)`` in the batch;
        - masks: 1 for real entries, 0 for padding;
        - labels: 1 for context words, 0 for noise words and padding.
    """
    width = max(len(ctx) + len(neg) for _, ctx, neg in data)
    center_ids, padded, valid, is_context = [], [], [], []
    for center, ctx, neg in data:
        used = len(ctx) + len(neg)
        pad = width - used
        center_ids.append(center)
        padded.append(ctx + neg + [0] * pad)
        valid.append([1] * used + [0] * pad)
        is_context.append([1] * len(ctx) + [0] * (width - len(ctx)))
    return (torch.tensor(center_ids).reshape((-1, 1)),
            torch.tensor(padded),
            torch.tensor(valid),
            torch.tensor(is_context))

最后，定义读取PTB数据集并返回数据迭代器和词表的load_data_ptb函数

In [79]:
def load_data_ptb(batch_size, max_window_size, num_niose_words):
    """Download PTB and return a skip-gram data iterator together with its vocabulary.

    Steps: read the training split, build a vocabulary (min_freq=10), subsample
    frequent words, map tokens to indices, extract center/context pairs, draw
    noise words with d2l.get_negatives, and wrap everything in a shuffled
    DataLoader whose batches are collated by ``batchify``.

    Args:
        batch_size: number of (center, context, negatives) examples per batch.
        max_window_size: upper bound of the random context-window radius.
        num_niose_words: number of noise words passed to d2l.get_negatives
            (name keeps its original spelling so keyword callers still work).

    Returns:
        (data_iter, vocab)
    """
    workers = d2l.get_dataloader_workers()
    tokenized = read_ptb()
    vocab = d2l.Vocab(tokenized, min_freq=10)
    kept, counter = subsample(tokenized, vocab)
    corpus = [vocab[tokens] for tokens in kept]
    all_centers, all_contexts = get_centers_and_contexts(corpus, max_window_size)
    all_negatives = d2l.get_negatives(all_contexts, vocab, counter,
                                      num_niose_words)

    class PTBDataset(torch.utils.data.Dataset):
        """Index-aligned view over centers, their contexts and noise words."""

        def __init__(self, centers, contexts, negatives):
            assert len(centers) == len(contexts) == len(negatives)
            self.centers = centers
            self.contexts = contexts
            self.negatives = negatives

        def __len__(self):
            return len(self.centers)

        def __getitem__(self, index):
            return self.centers[index], self.contexts[index], self.negatives[index]

    # NOTE(review): PTBDataset is local to this function. With num_workers > 0 and
    # the 'spawn' start method (default on Windows), DataLoader workers must pickle
    # the dataset, and instances of function-local classes cannot be pickled —
    # confirm on your platform.
    dataset = PTBDataset(all_centers, all_contexts, all_negatives)
    data_iter = torch.utils.data.DataLoader(dataset, batch_size, shuffle=True,
                                            collate_fn=batchify,
                                            num_workers=workers)
    return data_iter, vocab
    

打印一个小批量

In [80]:
# Batch size 512, maximum context-window radius 5, 5 noise words.
data_iter, vocab = load_data_ptb(batch_size=512, max_window_size=5, num_niose_words=5)
names = ['centers', 'contexts_negatives', 'masks', 'labels']
for batch in data_iter:
    for name, data in zip(names, batch):
        print(f'{name} shape: {data.shape}')
    break  # only the first minibatch is inspected

centers shape: torch.Size([512, 1])
contexts_negatives shape: torch.Size([512, 60])
masks shape: torch.Size([512, 60])
labels shape: torch.Size([512, 60])
