In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import transformers
from torch.utils.data import DataLoader
from tokenizers import Tokenizer
from datasets import load_from_disk

In [2]:
# Load pre-trained tokenizer and tokenized datasets:
tokenizer = Tokenizer.from_file("serialized_tokenizer")
train_ds, val_ds, test_ds = load_from_disk("tokenized_train"), load_from_disk("tokenized_val"), load_from_disk("tokenized_test")
train_ds.set_format(type="pt", columns=["ids", "attention_mask"])
val_ds.set_format(type="pt", columns=["ids", "attention_mask"])
test_ds.set_format(type="pt", columns=["ids", "attention_mask"])

train_ids = train_ds["ids"]
val_ids = val_ds["ids"]
test_ids = test_ds["ids"]

In [3]:
VOCAB_SIZE = tokenizer.get_vocab_size()

def prep_batches(dataset, batch_size, seq_len):
    num_batches = len(dataset) // batch_size
    inputs = dataset[:num_batches * batch_size]
    targets = torch.zeros_like(inputs)
    for i in range(0, len(inputs)):
        targets[i][:-1] = inputs[i][1:] # skip first token
        # targets[i][-1] = dataset[i][0] # as first token is always [CLS], no reason to append to the end.
    inputs = inputs.view((num_batches, -1, seq_len))
    targets = targets.view((num_batches, -1, seq_len))
    return inputs, targets

def one_hot_encode(idx, vocab_size):
    one_hot = np.zeros(vocab_size)
    one_hot[idx] = 1
    return one_hot

def one_hot_encode_seq(sequence, vocab_size):
    encoding = torch.tensor([one_hot_encode(token, vocab_size) for token in sequence])
    #encoding = encoding.view(encoding.shape[0], encoding.shape[1], 1)
    return encoding

In [4]:
SEQ_LEN = 256
EMBED_DIM = 64
HIDDEN_DIM = 64
N_LAYERS = 2
BATCH_SIZE = 64
EPOCHS = 50
DROPOUT_RATE = 0.5

In [5]:
class Seq(nn.Module):
    def __init__(self, input_dim, embed_dim, hidden_dim, n_layers, dropout_rate):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers

        self.embedding = nn.Embedding(input_dim, embed_dim) #input_dim == vocab_size (one-hot encoding)
        self.lstm = nn.LSTM(
            input_size = embed_dim,
            hidden_size = hidden_dim,
            num_layers = n_layers,
            bias = True, # default
            batch_first = True,
            dropout = dropout_rate,
            bidirectional = False # default
        )
        self.dropout = nn.Dropout(dropout_rate)
        self.fc = nn.Linear(hidden_dim, input_dim)
    def forward(self, x, h, c, teacher_forcing = False):
        # x: [batch, seq len] # Just seq len?
        e = self.dropout(self.embedding(x))
        # e: [batch, seq len, emb]
        o, (h, c) = self.lstm(e,(h,c))
        # o: [batch, seq len, hidden dim], (h, c): [n layers, batch, hidden dim]
        p = self.fc(o)
        # p: [batch, seq len] (try to guess an entire sequence or just next token?)
        return p, h, c

In [6]:
net = Seq(VOCAB_SIZE, EMBED_DIM, HIDDEN_DIM, N_LAYERS, DROPOUT_RATE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.05, momentum=0, weight_decay=0)

In [7]:
train_batches = prep_batches(train_ids, BATCH_SIZE, SEQ_LEN)
valid_batches = prep_batches(val_ids, BATCH_SIZE, SEQ_LEN)
test_batches  = prep_batches(test_ids, BATCH_SIZE, SEQ_LEN)

In [8]:
h = torch.zeros((N_LAYERS, BATCH_SIZE, HIDDEN_DIM))
c = torch.zeros_like(h)
p, h, c = net(train_batches[0][0], h, c)

In [9]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net.to(device)
iteration = 0
for e in range(EPOCHS):
    h = torch.zeros((N_LAYERS, BATCH_SIZE, HIDDEN_DIM))
    c = torch.zeros_like(h)

    net.train()
    for i in range(len(train_batches[0])):
        iteration += 1

        # zero gradients
        optimizer.zero_grad()

        # data to device
        x = torch.tensor(train_batches[0][i]).to(device)
        y = torch.tensor(train_batches[1][i]).to(device)

        lgts, h, c = net(x, h, c)
        loss = criterion(lgts, y)
        h.detach()
        c.detach()

        loss_val = loss.item()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(net.parameters(), 5)
        optimizer.step()



ValueError: Expected target size (64, 12800), got torch.Size([64, 256])

In [189]:
print(train_batches)

(tensor([[[    1,     3,     0,  ...,     0,     0,     0],
         [    1,  4209,  5310,  ...,     0,     0,     0],
         [    1,     3,     0,  ...,     0,     0,     0],
         ...,
         [    1,     3,     0,  ...,     0,     0,     0],
         [    1,  4228,  4827,  ...,  4176,  4165,     3],
         [    1,  4165,  4821,  ...,     0,     0,     0]],

        [[    1,  6130,  4891,  ...,     0,     0,     0],
         [    1,  4165,  4255,  ...,     0,     0,     0],
         [    1,  4383,  6765,  ...,     0,     0,     0],
         ...,
         [    1,  4209,  4209,  ...,     0,     0,     0],
         [    1,     3,     0,  ...,     0,     0,     0],
         [    1,  4165,  5628,  ...,     0,     0,     0]],

        [[    1,  4165,  9834,  ...,     0,     0,     0],
         [    1,  4184,  1272,  ...,     0,     0,     0],
         [    1,  4184,  1272,  ...,     0,     0,     0],
         ...,
         [    1,  4209,  4209,  ...,     0,     0,     0],
         

In [208]:
train_batches[0][0][3].numpy()

array([    1,  5488,  7651,   231,  4613,  5310,  6223,  7016,  1272,
          24,  4427,  4324,  9385,  4382,  4174,  8540,  4569,  4270,
        5920,  4427,  1272,  2371,  1968,  1406,  1439,  1432,  1434,
        1483,  1445,  1477,  1482,  1435,    24,  4169,  5081,  4176,
        5310,  6223,  7016,  4189,  4165,  5196,  5947,  1272,    24,
        4268,  4169,  8844,  7407,  4197,  4234,  5310,  6223,  7016,
        8540,  4569,  7378,  6474,  5303,  4169,  4264,  4162,  8943,
        4373,  5455,  4237,  6053,  5329,  4583,  5847,  4261,  4263,
        6568,  4195,  6364,  4176, 10166,  4228,  4165,  9071,  4970,
        4550,  4176,  4967,  4184,  5276,  1272,  5534,  4184,  5303,
        4169,  4255,  4264,  4165,  5159,  4583,  4184,  4165,  5310,
        6223,  7016,  4781,  4176,  6349,  4192,  4165,  5053,  4178,
        8232,  4189,  8943,  4373,  4195,  5374,  4237,  4514,  8629,
        4234,  4401, 10340,  4571,  4169,  4165,  5226,  6306,  9878,
        4197,  4165,

In [209]:
train_batches[1][0][3].numpy()

array([ 5488,  7651,   231,  4613,  5310,  6223,  7016,  1272,    24,
        4427,  4324,  9385,  4382,  4174,  8540,  4569,  4270,  5920,
        4427,  1272,  2371,  1968,  1406,  1439,  1432,  1434,  1483,
        1445,  1477,  1482,  1435,    24,  4169,  5081,  4176,  5310,
        6223,  7016,  4189,  4165,  5196,  5947,  1272,    24,  4268,
        4169,  8844,  7407,  4197,  4234,  5310,  6223,  7016,  8540,
        4569,  7378,  6474,  5303,  4169,  4264,  4162,  8943,  4373,
        5455,  4237,  6053,  5329,  4583,  5847,  4261,  4263,  6568,
        4195,  6364,  4176, 10166,  4228,  4165,  9071,  4970,  4550,
        4176,  4967,  4184,  5276,  1272,  5534,  4184,  5303,  4169,
        4255,  4264,  4165,  5159,  4583,  4184,  4165,  5310,  6223,
        7016,  4781,  4176,  6349,  4192,  4165,  5053,  4178,  8232,
        4189,  8943,  4373,  4195,  5374,  4237,  4514,  8629,  4234,
        4401, 10340,  4571,  4169,  4165,  5226,  6306,  9878,  4197,
        4165,  4394,