In [1]:
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Seed every random number generator the notebook relies on (Python, NumPy,
# PyTorch) so that a fresh "Restart Kernel -> Run All" is repeatable.
SEED = 515
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Ask cuDNN for deterministic kernels and explicitly keep its auto-tuner off:
# per the PyTorch reproducibility notes, benchmarking may select a different
# algorithm from run to run, so the two flags are meant to be set together.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [2]:
import torchtext
from torchtext.data import Field, LabelField, BucketIterator

# Text field: spaCy tokenization, batch-first tensors (`batch_first=True`), and
# `include_lengths=True` so that each batch carries the sequence lengths
# alongside the token ids.
TEXT = Field(
    tokenize='spacy',
    batch_first=True,
    include_lengths=True,
)
LABEL = LabelField()

# Load TREC with coarse labels (`fine_grained=False`) under `data/`, then carve
# a validation set out of the training portion using the default split ratio.
full_train_data, test_data = torchtext.datasets.TREC.splits(
    TEXT, LABEL, root='data', fine_grained=False)
train_data, valid_data = full_train_data.split()

In [3]:
MAX_VOCAB_SIZE = 25000

# The text vocabulary is built from the training split only, with its size
# limited through `max_size` and pretrained GloVe vectors (cached under
# `vector_cache/`) attached. `unk_init` supplies the initializer for tokens
# that are absent from the pretrained vectors (per the torchtext docs).
TEXT.build_vocab(
    train_data,
    max_size=MAX_VOCAB_SIZE,
    vectors="glove.6B.100d",
    vectors_cache="vector_cache",
    unk_init=torch.Tensor.normal_,
)

# The label vocabulary is likewise built from the training split; print the
# label -> index mapping.
LABEL.build_vocab(train_data)
print(LABEL.vocab.stoi)

defaultdict(None, {'ENTY': 0, 'HUM': 1, 'DESC': 2, 'NUM': 3, 'LOC': 4, 'ABBR': 5})


In [4]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Numericalize one training example by hand to see what the field produces.
ex = train_data[0]
print(ex.text)
# `TEXT` was created with `include_lengths=True`, so `numericalize` is given a
# (token_lists, lengths) pair and returns two tensors. With
# `include_lengths=False` the call would instead be
# `TEXT.numericalize([ex.text], device=device)`.
single_example = ([ex.text], [len(ex.text)])
text, text_lens = TEXT.numericalize(single_example, device=device)
print(text)
print(text_lens)

['What', 'do', 'the', 'letters', 'D.C.', 'stand', 'for', 'in', 'Washington', ',', 'D.C.', '?']
tensor([[  4,  24,   3, 297, 552, 106,  18,   7, 286,  14, 552,   2]])
tensor([12])


In [5]:
BATCH_SIZE = 128

# Build one iterator per split, all sharing the same batch size and device.
# `sort_within_batch=True` additionally sorts the examples inside each batch.
split_iterators = BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size=BATCH_SIZE,
    sort_within_batch=True,
    device=device,
)
train_iterator, valid_iterator, test_iterator = split_iterators

## `BucketIterator.splits`
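By default, the iterator for the first dataset (`train_iterator`) is shuffled, while the iterators for the remaining datasets (`valid_iterator` and `test_iterator`) are not. The cells below illustrate this: two passes over `train_iterator` yield batches in different orders, whereas two passes over `valid_iterator` yield exactly the same sequence.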

In [6]:
# First pass over the training iterator: for each of the first ten batches,
# show the longest sequence length next to the lengths of the first and last
# examples in the batch. `zip` with `range(10)` stops after ten batches
# without requesting an eleventh.
for batch_idx, batch in zip(range(10), train_iterator):
    text, text_lens = batch.text
    print(text_lens.max(), text_lens[0], text_lens[-1])

tensor(11) tensor(11) tensor(11)
tensor(16) tensor(16) tensor(15)
tensor(6) tensor(6) tensor(5)
tensor(9) tensor(9) tensor(9)
tensor(20) tensor(20) tensor(17)
tensor(11) tensor(11) tensor(11)
tensor(10) tensor(10) tensor(10)
tensor(15) tensor(15) tensor(14)
tensor(10) tensor(10) tensor(10)
tensor(14) tensor(14) tensor(13)


In [7]:
# Second pass over the training iterator, printing the same three lengths
# (batch maximum, first example, last example) for the first ten batches so
# the batch order can be compared with the previous pass. `zip` with
# `range(10)` stops after ten batches without requesting an eleventh.
for batch_idx, batch in zip(range(10), train_iterator):
    text, text_lens = batch.text
    print(text_lens.max(), text_lens[0], text_lens[-1])

tensor(14) tensor(14) tensor(13)
tensor(7) tensor(7) tensor(6)
tensor(7) tensor(7) tensor(7)
tensor(11) tensor(11) tensor(10)
tensor(11) tensor(11) tensor(11)
tensor(6) tensor(6) tensor(5)
tensor(16) tensor(16) tensor(15)
tensor(9) tensor(9) tensor(9)
tensor(10) tensor(10) tensor(9)
tensor(12) tensor(12) tensor(12)


In [8]:
# First pass over the validation iterator: for each of the first ten batches,
# show the longest sequence length next to the lengths of the first and last
# examples in the batch. `zip` with `range(10)` stops after ten batches
# without requesting an eleventh.
for batch_idx, batch in zip(range(10), valid_iterator):
    text, text_lens = batch.text
    print(text_lens.max(), text_lens[0], text_lens[-1])

tensor(5) tensor(5) tensor(3)
tensor(7) tensor(7) tensor(6)
tensor(7) tensor(7) tensor(7)
tensor(8) tensor(8) tensor(7)
tensor(9) tensor(9) tensor(8)
tensor(9) tensor(9) tensor(9)
tensor(10) tensor(10) tensor(9)
tensor(11) tensor(11) tensor(10)
tensor(12) tensor(12) tensor(11)
tensor(13) tensor(13) tensor(12)


In [9]:
# Second pass over the validation iterator, printing the same three lengths
# (batch maximum, first example, last example) for the first ten batches so
# the batch order can be compared with the previous pass. `zip` with
# `range(10)` stops after ten batches without requesting an eleventh.
for batch_idx, batch in zip(range(10), valid_iterator):
    text, text_lens = batch.text
    print(text_lens.max(), text_lens[0], text_lens[-1])

tensor(5) tensor(5) tensor(3)
tensor(7) tensor(7) tensor(6)
tensor(7) tensor(7) tensor(7)
tensor(8) tensor(8) tensor(7)
tensor(9) tensor(9) tensor(8)
tensor(9) tensor(9) tensor(9)
tensor(10) tensor(10) tensor(9)
tensor(11) tensor(11) tensor(10)
tensor(12) tensor(12) tensor(11)
tensor(13) tensor(13) tensor(12)
