In [18]:
import itertools

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

from torchtext.datasets import WikiText2
from torchtext.vocab import build_vocab_from_iterator
import torchtext.data.utils as ttdutils

from text_dataset import TextDataset


In [54]:
# Load the WikiText-2 training split and build a vocabulary from a prefix of it.
N_TRAIN_LINES = 2000  # number of raw lines of the train split kept for this experiment

train_iter = WikiText2(root="../data", split="train")
# Keep only the first N_TRAIN_LINES lines. islice stops pulling lines once the
# prefix is read instead of materialising the whole split first; list() is still
# required because train_iter is iterated again when the token-id tensor is built.
train_iter = list(itertools.islice(train_iter, N_TRAIN_LINES))
tokenizer = ttdutils.get_tokenizer("basic_english")
vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=["<unk>"])
# Out-of-vocabulary tokens are looked up as the "<unk>" id.
vocab.set_default_index(vocab["<unk>"])
len(vocab)

10053

In [55]:
# Flatten the kept lines into one 1-D int64 tensor of token ids, in line order.
# Lines that tokenize to nothing contribute no ids. If every line is empty the
# result is an empty tensor of shape (0,), rather than the torch.cat error that
# concatenating zero per-line tensors would raise.
data = torch.tensor(
    [token_id for line in train_iter for token_id in vocab(tokenizer(line))],
    dtype=torch.long,
)
data.shape

torch.Size([117249])

In [31]:
def collate_fn(batch):
    """Stack a list of dataset samples along a new dimension 1.

    For 1-D samples of length seq_len (the batch preview below shows
    length-32 token tensors), the result has shape (seq_len, len(batch)):
    one column per sample. That sequence-first layout is what
    nn.TransformerEncoder expects with its default batch_first=False.
    """
    samples = tuple(batch)
    stacked = torch.stack(samples, dim=1)
    return stacked

# Fixed-length windows over the token stream, batched by collate_fn into
# (seq_len, batch) tensors.
SEQ_LEN = 32
BATCH_SIZE = 2
SEED = 0  # seeds the shuffle order so the batch preview is reproducible across runs

dataset = TextDataset(data, seq_len=SEQ_LEN)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    collate_fn=collate_fn,
    generator=torch.Generator().manual_seed(SEED),
)


In [35]:
# Preview one shuffled batch. x is (seq_len, batch) because collate_fn stacks
# samples along dim 1, so iterating x.T yields one token sequence per sample.
# `x` is reused as the model input by the forward-pass cells below.
# An empty loader raises StopIteration here instead of leaving a stale `x`.
x = next(iter(dataloader))
for sample in x.T:
    print(sample.shape)
    print(vocab.lookup_tokens(sample.tolist()))

torch.Size([32])
['land', 'super', 'mario', 'land', '3', 'began', 'the', 'wario', 'franchise', '.', 'after', '19', 'years', ',', 'the', '2011', 'title', 'super', 'mario', '3d', 'land', 'for', 'the', 'nintendo', '3ds', 'became', 'mario', "'", 's', 'first', 'game', 'in']
torch.Size([32])
['the', 'species', '.', '=', '=', 'conservation', '=', '=', '<unk>', 'records', 'indicate', 'that', 'in', 'pre', '@-@', 'polynesian', 'times', ',', 'the', 'kakapo', 'was', 'new', 'zealand', "'", 's', 'third', 'most', 'common', 'bird', 'and', 'it', 'was']


In [23]:
# Embedding -> Transformer encoder stack -> linear projection to vocabulary scores.
# NOTE(review): nothing in this cell adds positional information, and the
# forward pass feeds emb(x) straight into the encoder, so token order is not
# visible to the model — confirm that is intended.
embed_dim = 768
emb = nn.Embedding(len(vocab), embed_dim)  # token id -> embed_dim-dim vector
model = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(
        d_model=embed_dim,
        nhead=12,
        dim_feedforward=3072,
        dropout=0.1,
        activation='gelu',
        batch_first=False,  # the default, spelled out: inputs are (seq_len, batch, embed_dim)
    ),
    num_layers=6,
)
decoder = nn.Linear(embed_dim, len(vocab))  # encoder features -> one score per vocab entry

In [24]:
# Forward pass of one batch. The modules are in training mode (the nn.Module
# default), so dropout is active and autograd records the graph.
out = emb(x)  # (seq_len, batch) ids -> (seq_len, batch, embed_dim)
assert out.shape == (*x.shape, embed_dim), out.shape

eps = model(out)  # encoder output, same shape as its input
assert eps.shape == out.shape, eps.shape

decoded = decoder(eps)  # (seq_len, batch, len(vocab)) unnormalised scores
decoded.shape

In [25]:
# Score the model on the batch and decode its argmax predictions to text.
# NOTE(review): the target is the input `x` itself (no shift, no attention
# mask), so this measures reconstruction of visible tokens rather than
# next-token prediction — confirm that is the intended objective.
log_probs = F.log_softmax(decoded, dim=-1)  # (seq_len, batch, vocab) log-probabilities
indices = log_probs.argmax(dim=-1)          # (seq_len, batch) most likely token ids
print("log_probs shape:", log_probs.shape)
print("x shape:", x.shape)

# nll_loss is the partner of log_softmax; F.cross_entropy would apply a second
# log_softmax on top. The class dimension must be dim 1, so the input is
# permuted to (seq_len, vocab, batch) to match the (seq_len, batch) target.
loss = F.nll_loss(log_probs.permute(0, 2, 1), x)
print(loss)

print("\n\nIndices to text:")
print(indices.shape)
# indices.T is (batch, seq_len): one predicted token list per batch element.
tokens = [vocab.lookup_tokens(row.tolist()) for row in indices.T]
sentences = [" ".join(row_tokens) for row_tokens in tokens]
print(sentences[0])
# vocab.lookup_tokens(indices)


logits shape: torch.Size([32, 1, 28782])
x shape: torch.Size([32, 1])
tensor(10.3937, grad_fn=<NllLoss2DBackward0>)


Indices to text:
torch.Size([32, 1])
sculpture infants gallian corpse elinor swaziland disbanded birmingham meng potatoes preventing emily rotational squirrels isabella watershed thickly birmingham situ clues bout bulldogs receivership isabella 864 presumed desmond microlight infants gallian corpse sensitive
