In [4]:
import numpy as np

from torchtext import data
from torchtext import datasets
from torchtext.vocab import Vectors, GloVe

import torch

from sklearn.manifold import TSNE

# we'll use the bokeh library to create beautiful plots
# *_notebook functions are needed for correct use in jupyter
from bokeh.plotting import figure, ColumnDataSource
from bokeh.models import HoverTool
from bokeh.io import output_notebook, show, push_notebook
# output_notebook()

In [40]:
# ---- Data processing ----

# TODO: improve using 'spacy'
def tokenize(x):
    """Return the list of tokens produced by calling ``x.split()`` with no arguments.

    For a ``str`` this splits on runs of whitespace and drops empty tokens,
    so an empty or all-whitespace string yields an empty list.
    """
    tokens = x.split()
    return tokens

# Preprocessing pipeline: sequential text, tokenised by `tokenize`, lower-cased
TEXT = data.Field(
    sequential=True,
    tokenize=tokenize,
    lower=True,
)

# Penn Treebank train / validation / test splits.
# Loading the dataset only requires passing in the field, nothing else.
train, valid, test = datasets.PennTreebank.splits(TEXT)

# Build the vocabulary from the training split and attach GloVe 6B vectors (dim '300').
# This is slow (the GloVe archive is fetched into .vector_cache), but per the
# original author's note, skipping the vocab leads to errors.
TEXT.build_vocab(
    train,
    max_size=None,
    vectors=[GloVe(name='6B', dim='300')],
)

# One BPTT iterator per split, all sharing the same settings
train_iter, valid_iter, test_iter = data.BPTTIterator.splits(
    (train, valid, test),
    device=torch.device('cpu'),
    repeat=False,
    bptt_len=30,  # the sequence length is specified here
    batch_size=32,
)


.vector_cache/glove.6B.zip: 862MB [45:52, 313kB/s]                                 
100%|█████████▉| 399622/400000 [01:56<00:00, 4775.58it/s]

In [42]:
# Pull the first batch from a fresh pass over the training iterator and print its `.text`
first_batch = next(iter(train_iter))
print(first_batch.text)

tensor([[ 9971,    38,  2438,    11,   233,   540,    44,   168,  1453,   115,
            30,   478,     0,   128,     9,     9,   313,  1596,  1815,     2,
            67,    42,    89,    16,     2,     4,  7133,   331,   573,  8516,
          2392,    32],
        [ 9972,    34,    55,     2,    71,  1305,   115,     5,  8596,    90,
            29,    14,   482,   270,     7,  1036,  1642,     0,     8,  3713,
           111,     2,    20,    15,   508,     4,   788,   322,     2,   564,
          2168,     0],
        [ 9973,   853,  2156,   498,     3,    20,    29,  8246,    60,    87,
          2112,  1292,   582,     5,     0,    28,     5,     7,     2,   211,
           922,  1156,     0,    34,    75,     6,   133,   113,   617,     3,
             8,    44],
        [ 9975,  7536,     5,   513,    30,  3679,   176,    96,  1570,    26,
           117,     6,    17,  3517,    12,  2634,  1064,    38,   764,   133,
            26,   213,     9,   185,  2620,     4,   172,  