In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

from torchtext.datasets import WikiText2
from torchtext.vocab import build_vocab_from_iterator
import torchtext.data.utils as ttdutils


from text_dataset import TextDataset


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
train_iter = WikiText2(root="../data", split="train")
# slice to 30 lines
train_iter = list(train_iter)[:2000]
tokenizer = ttdutils.get_tokenizer("basic_english")
vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])
len(vocab)

10053

In [3]:
data = [torch.LongTensor([vocab(tokenizer(item))]) for item in train_iter]
data = tuple(filter(lambda x: x.numel() > 0, data))
data = torch.cat(data, dim=1).squeeze(0)
data.shape

torch.Size([117249])

In [17]:
def collate_fn(batch):
    return torch.stack(batch, dim=1)

dataset = TextDataset(data, seq_len=32)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=0, collate_fn=collate_fn)


In [18]:
for i, x in enumerate(dataloader):
    if i == 0:
        for xpart in x.T:
            print(xpart.shape)
            print(" ".join(vocab.lookup_tokens(xpart.tolist())))
        # print(vocab.lookup_tokens(x.squeeze(1).tolist()))
        break

torch.Size([32])
homes in the human world . there they inhabited the cult images , the statues that depicted deities and allowed humans to interact with them in temple rituals . this movement between
torch.Size([32])
1970s that this blend of technologies started to mature , resulting in the birth of the microlight movement . another milestone in the development of ga was the 1964 introduction of the
torch.Size([32])
building . the new museum ' s goal is to educate and inform visitors about the military history of arkansas , preserve the tower building , honor servicemen and <unk> of the
torch.Size([32])
operator must satisfy the caa that the physical conditions at the aerodrome , and its environs , are acceptable the scale of equipment , and facilities provided , are adequate for the
torch.Size([32])
seemed to satisfy the committee as nothing more was done for the time , and when a gold dollar was proposed again in 1846 , mckay ' s committee recommended against it
torch.Size([32])
obsc

In [29]:
from tqdm import tqdm

counts = {}
for i, x in enumerate(tqdm(dataloader)):
    for e in x:
        for w in e:
            wtext = vocab.lookup_token(w.item())
            if wtext in counts:
                counts[wtext] += 1
            else:
                counts[wtext] = 1


100%|██████████| 916/916 [00:21<00:00, 42.38it/s]


In [32]:
# print the 10 most common words
print(sorted(counts.items(), key=lambda x: x[1], reverse=True)[:10])

# print the total count
print(sum(counts.values()))


[('the', 241704), (',', 202023), ('.', 151435), ('of', 107923), ('<unk>', 101322), ('and', 95472), ('in', 82716), ('to', 70331), ('a', 64928), ('=', 50406)]
3750944


In [6]:
embed_dim = 768
emb = nn.Embedding(len(vocab), embed_dim)
model = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=embed_dim, nhead=12, dim_feedforward=3072, dropout=0.1, activation='gelu'), num_layers=6)
decoder = nn.Linear(embed_dim, len(vocab))

In [7]:
out = emb(x)
out.shape

eps = model(out)
eps.shape

decoded = decoder(eps)

In [8]:
logits = F.log_softmax(decoded, dim=-1)
logits_permuted = logits.permute(0, 2, 1)
indices = torch.argmax(logits, dim=-1)
print("logits shape:", logits.shape)
print("x shape:", x.shape)



loss = F.cross_entropy(logits_permuted, x)
print(loss)

print("\n\nIndices to text:")
print(indices.shape)
tokens = [vocab.lookup_tokens(i.tolist()) for i in indices.T]
# print(tokens)
sentences = [" ".join(token) for token in tokens]
print(sentences[0])
# vocab.lookup_tokens(indices)


logits shape: torch.Size([32, 2, 10053])
x shape: torch.Size([32, 2])
tensor(9.4813, grad_fn=<NllLoss2DBackward0>)


Indices to text:
torch.Size([32, 2])
walters missiles deemed matters self deemed uniform deemed takes angie millimeters wisniewski deemed airworthiness reworked functional mastering yeah elected were ministers taxon inside punk blade reworked encourage predominant generic rejected deemed assessed
