## Natural Language Inference and the Dataset

In [9]:
import os
import re
import torch
from torch import nn
from d2l import torch as d2l

# Register the SNLI 1.0 archive under the key 'SNLI' as a (URL, hash) pair.
# NOTE(review): the 40-hex-digit string looks like a SHA-1 checksum of the
# zip file -- confirm against d2l's download helper.
d2l.DATA_HUB['SNLI'] = (
    'https://nlp.stanford.edu/projects/snli/snli_1.0.zip',
    '9fcde07509c7e87ec61c640c1b2753d9041758e4')

# Presumably downloads and unpacks the archive; the returned value is used
# below as the directory that holds the snli_1.0_*.txt files.
data_dir = d2l.download_extract('SNLI')

## Reading the Dataset

In [5]:
def read_snli(data_dir, is_train):
    """Parse one split of the SNLI dataset into premises, hypotheses, labels.

    The split file is tab-separated; its first line is skipped as a header.
    Column 0 holds the label name, columns 1 and 2 hold the premise and the
    hypothesis. Rows whose label is not one of ``entailment``,
    ``contradiction`` or ``neutral`` are dropped.

    Args:
        data_dir: Directory containing ``snli_1.0_train.txt`` and
            ``snli_1.0_test.txt``.
        is_train: Read the training file when truthy, else the test file.

    Returns:
        A tuple ``(premises, hypotheses, labels)`` of three lists of equal
        length. Premises and hypotheses are cleaned strings; labels are
        ints with entailment=0, contradiction=1, neutral=2.
    """
    def extract_text(s):
        # Remove information we will not use (the parentheses).
        s = re.sub('\\(', '', s)
        s = re.sub('\\)', '', s)
        # Replace two or more consecutive whitespace characters with a
        # single space.
        s = re.sub('\\s{2,}', ' ', s)
        return s.strip()

    label_set = {'entailment': 0, 'contradiction': 1, 'neutral': 2}
    file_name = os.path.join(data_dir, 'snli_1.0_train.txt'
                             if is_train else 'snli_1.0_test.txt')

    premises, hypotheses, labels = [], [], []
    # Explicit UTF-8 so decoding does not depend on the platform's default
    # locale encoding.
    with open(file_name, 'r', encoding='utf-8') as f:
        next(f, None)  # skip the header line (tolerates an empty file)
        # Stream the file and filter in a single pass instead of
        # materialising every row and scanning the list three times.
        for line in f:
            row = line.split('\t')
            if row[0] not in label_set:
                continue
            premises.append(extract_text(row[1]))
            hypotheses.append(extract_text(row[2]))
            labels.append(label_set[row[0]])
    return premises, hypotheses, labels

## Print the first 3 pairs

In [11]:
train_data = read_snli(data_dir, is_train=True)

# train_data is (premises, hypotheses, labels); take the first three entries
# of every column and walk them in lockstep.
for x0, x1, y in zip(*(column[:3] for column in train_data)):
    print('premise: ', x0)
    print('hypothesis: ', x1)
    print('label: ', y)

# Tokenize those same three premises so the next cell can display them.
premises_tokens = d2l.tokenize(train_data[0][:3])

premise:  A person on a horse jumps over a broken down airplane .
hypothesis:  A person is training his horse for a competition .
label:  2
premise:  A person on a horse jumps over a broken down airplane .
hypothesis:  A person is at a diner , ordering an omelette .
label:  1
premise:  A person on a horse jumps over a broken down airplane .
hypothesis:  A person is outdoors , on a horse .
label:  0


In [12]:
# Display the token lists produced for the first three training premises.
premises_tokens

[['A',
  'person',
  'on',
  'a',
  'horse',
  'jumps',
  'over',
  'a',
  'broken',
  'down',
  'airplane',
  '.'],
 ['A',
  'person',
  'on',
  'a',
  'horse',
  'jumps',
  'over',
  'a',
  'broken',
  'down',
  'airplane',
  '.'],
 ['A',
  'person',
  'on',
  'a',
  'horse',
  'jumps',
  'over',
  'a',
  'broken',
  'down',
  'airplane',
  '.']]

## Labels "entailment", "contradiction", and "neutral" are balanced

In [8]:
test_data = read_snli(data_dir, is_train=False)

# For each split, count how many examples carry label id 0, 1 and 2.
# data[2] is the list of integer labels returned by read_snli, so it can be
# counted directly without copying it first.
for data in [train_data, test_data]:
    print([data[2].count(i) for i in range(3)])

[183416, 183187, 182764]
[3368, 3237, 3219]


## Build Dataset

In [16]:
class SNLIDataset(torch.utils.data.Dataset):
    """Dataset of fixed-length premise/hypothesis index tensors with labels.

    ``dataset`` is indexed as ``dataset[0]`` (premises), ``dataset[1]``
    (hypotheses) and ``dataset[2]`` (integer labels). Each sentence is
    tokenized, mapped to vocabulary indices and truncated or padded to
    ``num_steps`` entries.
    """

    def __init__(self, dataset, num_steps, vocab=None):
        """Tokenize both sentence columns and convert them to tensors.

        Args:
            dataset: Sequence whose items 0, 1, 2 are premises, hypotheses
                and labels.
            num_steps: Length every index sequence is truncated/padded to.
            vocab: Vocabulary to reuse. When ``None`` a new one is built from
                this dataset's premise and hypothesis tokens.
        """
        self.num_steps = num_steps
        premise_tokens = d2l.tokenize(dataset[0])
        hypothesis_tokens = d2l.tokenize(dataset[1])
        self.vocab = vocab if vocab is not None else d2l.Vocab(
            premise_tokens + hypothesis_tokens,
            min_freq=5,
            reserved_tokens=['<pad>'])
        self.premises = self._pad(premise_tokens)
        self.hypotheses = self._pad(hypothesis_tokens)
        self.labels = torch.tensor(dataset[2])
        print('Read ' + str(len(self.premises)) + ' examples')

    def _pad(self, lines):
        """Turn token lists into one tensor of ``num_steps``-long index rows."""
        pad_index = self.vocab['<pad>']
        index_rows = [
            d2l.truncate_pad(self.vocab[tokens], self.num_steps, pad_index)
            for tokens in lines
        ]
        return torch.tensor(index_rows)

    def __getitem__(self, idx):
        """Return ``((premise, hypothesis), label)`` for position ``idx``."""
        sentence_pair = (self.premises[idx], self.hypotheses[idx])
        return sentence_pair, self.labels[idx]

    def __len__(self):
        """Number of examples, i.e. the number of premise rows."""
        return len(self.premises)

## Putting All Things Together

In [18]:
def load_data_snli(batch_size, num_steps=50):
    """Download SNLI and build training and test data loaders.

    Args:
        batch_size: Number of examples per mini-batch for both loaders.
        num_steps: Length every premise/hypothesis index sequence is
            truncated or padded to.

    Returns:
        A tuple ``(train_iter, test_iter, vocab)``: the shuffled training
        loader, the unshuffled test loader, and the vocabulary built from
        the training split.
    """
    num_workers = d2l.get_dataloader_workers()
    data_dir = d2l.download_extract('SNLI')
    train_data = read_snli(data_dir, is_train=True)
    test_data = read_snli(data_dir, is_train=False)

    train_set = SNLIDataset(train_data, num_steps)
    # Reuse the training vocabulary for the test split. Without it the test
    # set builds its own vocabulary, so its token indices would not match
    # the training indices or the vocabulary returned to the caller.
    test_set = SNLIDataset(test_data, num_steps, train_set.vocab)

    train_iter = torch.utils.data.DataLoader(
        train_set, batch_size, shuffle=True, num_workers=num_workers)
    test_iter = torch.utils.data.DataLoader(
        test_set, batch_size, shuffle=False, num_workers=num_workers)
    return train_iter, test_iter, train_set.vocab


# Build both loaders with 128 examples per batch and sequences of length 50.
train_iter, test_iter, vocab = load_data_snli(batch_size=128, num_steps=50)

Read 549367 examples
Read 9824 examples


In [22]:
# Peek at a single mini-batch to sanity-check what the training loader yields.
# X is the (premises, hypotheses) pair and Y the labels, mirroring
# SNLIDataset.__getitem__.
for X, Y in train_iter:
    print(X[0].shape)  # premises
    print(X[1].shape)  # hypotheses
    print(Y.shape)     # labels
    print(Y[0:5])      # first five label ids of this batch
    break  # one batch is enough for inspection

torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128])
tensor([1, 0, 1, 2, 2])
