In [1]:
!pip install gdown --no-cache-dir -U



## Import libraries

In [1]:
import os
import re
import random
import time
from timeit import default_timer as timer
from typing import Iterable, List
from tqdm import tqdm
import h5py
import pathlib as pl
from pathlib import Path
import pandas as pd
import numpy as np
from d2l import torch as d2l
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchtext.data import get_tokenizer
from torch.nn.utils.rnn import pad_sequence
from torchtext.vocab import GloVe, vocab
from torchinfo import summary
%matplotlib inline

## Obtain the data

Get the data from google drive and extract it into the correct directory

In [None]:
!gdown "18d7-qbKjt2uS1ORdvVIr8LBrTqdZYaTI"
!tar xvjf "./data/C4_200M.hdf5-00001.3-of-00010.tar.bz2"

## Dataset Object

Create an HDF5 dataset object which will be fed into the DataLoader later

In [2]:
class Hdf5Dataset(Dataset):
    """Custom Dataset for loading entries from HDF5 databases"""

    def __init__(self, h5_path: str, transform = None, num_entries: int = None):
        """Initialization function. Obtains the raw data as bytes, num entries, and transform, if any
        
        :param h5_path: a string representing a valid path to the hdf5 file to read
        :param transform: a function that is used to process the raw input (optional)
        :param num_entries: number of entries in the dataset (optional)
        :returns: an object of the Hdf5Datset
        """
        self.h5f = h5py.File(h5_path, "r") # open HDF5 file for reading
        if num_entries:
            self.num_entries = num_entries
        else: # if num_entries not given, set len of dataset to the len of the lables in the hdf5 file
            self.num_entries = self.h5f["labels"].shape[0]
        self.transform = transform

    def __getitem__(self, index):
        """obtains both the input as well as the labels at the given index. index may
           be an integer or slice
        
        :param index: integer or slice to index dataset with
        :returns: a tuple containing either a single input,label pair or list of inputs, list of labes pair
        """
        if isinstance(index, slice):
            input = [entry.decode("utf-8") for entry in self.h5f["input"][index]]
            label = [entry.decode("utf-8") for entry in self.h5f["labels"][index]]
        elif isinstance(index, int):
            if index > self.num_entries:
                raise StopIteration
            input = self.h5f["input"][index].decode("utf-8")
            label = self.h5f["labels"][index].decode("utf-8")
        # if self.transform is not None:
        #     features = self.transform(input)
        return input, label

    def __len__(self):
        """gets number of entries"""
        return self.num_entries

In [3]:
# playing with hdf5 data importing in python
root_path = "./"
data_path = f"{root_path}data/"
c4_hdf5_train_filename = "C4_200M.hdf5-00000-of-00010"
c4_hdf5_train_filepath = f"{data_path}{c4_hdf5_train_filename}"
c4_dataset = Hdf5Dataset(c4_hdf5_train_filepath) # no transforms yet, no len given

In [4]:
len(c4_dataset)
c4_dataset[-1]

('Preheat oven to 356 degrees Fahrenheit. (180 & degrees celsius), I don’t use a macro.wave so I put the coconut oil in the oven to melt it while preheated. (lifehack :-)) Some extra coconut oil to green coconut oil to grease your pan,',
 'Preheat oven to 356 degrees Fahrenheit (180 degrees celsius). I don’t have a microwave so I put the coconut oil in the oven to melt it while it is preheating (lifehack :-)). Use some extra coconut oil to grease your pan.')

## Global Variables

includes input parameters, vocabulary parameters, filepaths

In [5]:
# def yield_tokens(data_iter: Iterable, index: int, src_lang_lbl: str, tgt_lang_lbl: str) -> List[str]:
#     """Helper function to yield a list of tokens from a data iterable"""
#     language_index = {src_lang_lbl: 0, tgt_lang_lbl: 1}
#     for data_sample in tqdm(data_iter):
#         if data_sample[index] and isinstance(data_sample[index], str):
#             yield token_transform(data_sample[index])

SRC_LANGUAGE = "incorrect"
TGT_LANGUAGE = "correct"

MAX_LENGTH = 512 # maximum length of an input sequence
VOCAB_SIZE = 20000 # size of the vocabulary, something to play around with
# debug nums
N_TRAIN_SAMPLES = 10000 # number of training samples to pull from the dataset
N_VAL_SAMPLES = int(0.1 * N_TRAIN_SAMPLES) # number of validation/test samples to pull from the dataset

# function-name placeholders for transforms
token_transform = get_tokenizer("basic_english")
vocab_transform = None

folder = "./data"
train_filename = "C4_200M.hdf5-00000-of-00010"
valid_filename = "C4_200M.hdf5-00001-of-00010"

embedding_path = "./glove.42B.300d.txt"

checkpoint_folder = "./checkpoints"

In [None]:
import gdown

gdown.download_folder(
    "https://drive.google.com/drive/folders/1FQ_jm765fgwcD5lLtjl6ef9k532hdADR",
    quiet=True,
)

# Vocabulary

Create special tokens then create pre-trained GloVe vocabulary

In [6]:
# Define special symbols and indices
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3

# Make sure the tokens are in order of their indices to properly insert them in vocab
special_symbols = ["<UNK>", "<PAD>", "<BOS>", "<EOS>"]

In [7]:
def pretrained_embs(name: str, dim: str, max_vectors: int = None):
    """get the pre-trained word embeddings for the vocabulary
    
    :param name: the name of the embedding, according to torchtext.vocab.GloVe
    :param dim: the desired dimensionality of the embedding
    :param max_vectors: maximum number of pre-trained vectors to load
    :returns: a tuple containing the str-to-int vocab dict and the embedded vectors of the vocab
    """
    glove_vectors = GloVe(name=name, dim=dim, max_vectors=max_vectors)
    glove_vocab = vocab(glove_vectors.stoi)
    pretrained_embeddings = glove_vectors.vectors
    glove_vocab.insert_token("<UNK>", UNK_IDX)
    pretrained_embeddings = torch.cat(
        (torch.mean(pretrained_embeddings, dim=0, keepdims=True), pretrained_embeddings)
    )
    glove_vocab.insert_token("<PAD>", PAD_IDX)
    pretrained_embeddings = torch.cat(
        (torch.zeros(1, pretrained_embeddings.shape[1]), pretrained_embeddings)
    )
    glove_vocab.insert_token("<BOS>", PAD_IDX)
    pretrained_embeddings = torch.cat(
        (torch.rand(1, pretrained_embeddings.shape[1]), pretrained_embeddings)
    )
    glove_vocab.insert_token("<EOS>", PAD_IDX)
    pretrained_embeddings = torch.cat(
        (torch.rand(1, pretrained_embeddings.shape[1]), pretrained_embeddings)
    )
    glove_vocab.set_default_index(UNK_IDX)
    return glove_vocab, pretrained_embeddings

In [72]:
vocab, embeddings = pretrained_embs("42B", "300", 20000)
print(type(vocab), type(embeddings))
torch.save(embeddings, "glove.42B.300d.20K.pth")

.vector_cache/glove.42B.300d.zip: 1.88GB [05:53, 5.32MB/s]                                                                                                                                                                                                                    
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉| 19999/20000 [00:01<00:00, 14166.85it/s]


<class 'torchtext.vocab.vocab.Vocab'> <class 'torch.Tensor'>


In [8]:
# Load vocabulary and pretrained embeddings
vocab_transform = torch.load("vocab/vocab_20K.pth")
embeddings = torch.load("glove.42B.300d.20K.pth")

---

## Collation

In [9]:
def sequential_transforms(*transforms):
    """Helper function to compose sequential operations
    
    param transforms: list of transformation functions to be applied
    :returns: the full transformation after composition of sequential ops
    """
    def func(txt_input):
        for transform in transforms:
            txt_input = transform(txt_input)
        return txt_input
    return func


def tensor_transform(token_ids: List[int]):
    """prepends and appends the given list of token ids with BOS id and EOS
    
    param token_ids: a list of integers representing the ids of the tokens in a sequence
    :returns: a pytorch tensor representation of the concatenated BOS id, ids list, and EOS id
    """
    return torch.cat(
        (torch.tensor([BOS_IDX]), torch.tensor(token_ids), torch.tensor([EOS_IDX]))
    )

# src and tgt language text transforms to convert raw strings into tensors indices
# with BOS and EOS prepended and appended resp.
text_transform = sequential_transforms(
    token_transform, vocab_transform, tensor_transform
)

# function to collate data samples into batch tesors
def collate_fn(batch):
    src_batch, tgt_batch = [], []
    for src_sample, tgt_sample in batch:
        src_batch.append(text_transform(src_sample.rstrip("\n")))
        tgt_batch.append(text_transform(tgt_sample.rstrip("\n")))
    src_batch = pad_sequence(src_batch, padding_value=PAD_IDX)
    tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX)
    # src_batch = torch.permute(src_batch, (1, 0))
    # tgt_batch = torch.permute(tgt_batch, (1, 0))
    return src_batch, tgt_batch

In [10]:
text = """According to all known laws of aviation, there
          is no way a bee should be able to fly."""
tokenized_input = token_transform(text)
print("tokenized input:\n", tokenized_input)

encoded_input = vocab_transform(tokenized_input)
print("encoded input:\n", encoded_input)

print("transformed input:\n", text_transform(text))

tokenized input:
 ['according', 'to', 'all', 'known', 'laws', 'of', 'aviation', ',', 'there', 'is', 'no', 'way', 'a', 'bee', 'should', 'be', 'able', 'to', 'fly', '.']
encoded input:
 [306, 8, 42, 529, 1888, 9, 4717, 5, 70, 13, 81, 138, 10, 6663, 129, 28, 315, 8, 2702, 4]
transformed input:
 tensor([   2,  306,    8,   42,  529, 1888,    9, 4717,    5,   70,   13,   81,
         138,   10, 6663,  129,   28,  315,    8, 2702,    4,    3])


## Unknown words

In [11]:
text = "lmk where ur at"
tokenized_input = token_transform(text)
print(tokenized_input)

encoded_input = vocab_transform(tokenized_input)
print(encoded_input)

print("transformed input:\n", text_transform(text))

['lmk', 'where', 'ur', 'at']
[0, 116, 11025, 23]
transformed input:
 tensor([    2,     0,   116, 11025,    23,     3])


## RNN Network

### Masking

We will definitely need a padding mask, and if we use self-attention (maybe later) we will also need a subsequent mask

In [12]:
def generate_square_subsequent_mask(sz, device='cpu'):
    """Create a mask that ignores future inputs

    :param sz: the size the input sequence to which the mask will be applied
    :returns: a tensor with shape (sz, sz)
    """
    mask = (torch.triu(torch.ones((sz, sz), device=device)) == 1).transpose(0, 1)
    mask = (
        mask.float()
        .masked_fill(mask == 0, float("-inf"))
        .masked_fill(mask == 1, float(0.0))
    )
    return mask


def create_padding_mask(src):
    """creates a padding mask which ignores PAD tokens when applied to an input

    :param src: the input sequence to compute the mask for
    :returns: a tensor with shape (src_seq_len, src_seq_len)
    """
    src_seq_len = src.shape[0]
    src_padding_mask = (src == PAD_IDX).transpose(0, 1)
    return src_padding_mask

In [13]:
example_size = 10
subseq_mask = generate_square_subsequent_mask(example_size)
subseq_mask

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [14]:
example_text = "lmk where ur at"
example_text_id_seq = text_transform(example_text.rstrip("\n"))

example_max_len = 50
num_pad = example_max_len - len(example_text_id_seq)
padded_example_text_id_seq = torch.cat((example_text_id_seq, torch.tensor([PAD_IDX] * num_pad)))
padded_example_text_id_seq = padded_example_text_id_seq.view(-1,1) # turn the tensor into a column vec
example_padding_mask = create_padding_mask(padded_example_text_id_seq)
print("example text ids before padding:\n", example_text_id_seq)
print()
print("example text ids after padding:\n", padded_example_text_id_seq.T)
print()
print("example text padding mask:\n", example_padding_mask)

example text ids before padding:
 tensor([    2,     0,   116, 11025,    23,     3])

example text ids after padding:
 tensor([[    2,     0,   116, 11025,    23,     3,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1]])

example text padding mask:
 tensor([[False, False, False, False, False, False,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True]])


### Model-Specific Parameters

In [15]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(DEVICE)

teacher_forcing_ratio = 0.5
torch.manual_seed(0)

MAX_LENGTH = 512
EMB_SIZE = 300
HIDDEN_SIZE = 512
BATCH_SIZE = 16
NUM_ENCODER_LAYERS = 1
NUM_DECODER_LAYERS = 1

learning_rate = 0.001

cuda


### The Model

The RNN we employ contains encoder block(s) and a decoder block(s) which themselves are comprised of GRU units.\
We Apply an embedding layer before the encoder/decoder blocks, and then we apply a single dropout layer for the\
encoding layer, and a fully connected output layer, followed by drop-out and softmax in the decoding layer.

#### TODO:

- check that the embeddings are not trainable
- double-check the architecture of the encoder
- double check the architecture of the decoder, particularly the fc outputs

In [16]:
def init_seq2seq(module):
    """Initialize weights for seq2seq"""
    if type(module) == nn.Linear:
        nn.init.xavier_uniform_(module.weight)
    if type(module) == nn.GRU:
        for param in module._flat_weights_names:
            if "weight" in param:
                nn.init.xavier_uniform_(module._parameters[param])

class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout=0, embedding_weights=None):
        super().__init__()
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        self.emb_dim = emb_dim
        self.embedding = nn.Embedding(input_dim, emb_dim, padding_idx=1)
        if embedding_weights is not None:
            self.embedding.weight = torch.nn.Parameter(
                torch.from_numpy(embedding_weights)
            )
        self.gru = nn.GRU(emb_dim, hid_dim, num_layers=n_layers, batch_first=False, dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        self.apply(init_seq2seq)

    def forward(self, src):
        # src shape: (seq_len, batch_size)
        embedded = self.dropout(self.embedding(src.type(torch.int64)))
        # embedded shape: (seq_len, batch_size, emb_dim)
        outputs, state = self.gru(embedded)
        # outputs shape: (seq_len, batch_size, hid_dim * n_dirs)
        # state shape: (n_layers * n_dirs, batch_size, hid_dim)
        return outputs, state

class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, hid_dim, n_layers, dropout=0, embedding_weights=None,
                 teacher_forcing=0.5):
        super().__init__()
        self.hid_dim = hid_dim
        self.output_dim = output_dim
        self.emb_dim = emb_dim
        self.n_layers = n_layers
        self.teacher_forcing = teacher_forcing
        self.embedding = nn.Embedding(output_dim, emb_dim, padding_idx=1)
        if embedding_weights is not None:
            self.embedding.weight = torch.nn.Parameter(
                torch.from_numpy(embedding_weights)
            )
        # only have one layer, so dropout not actually added
        self.gru = nn.GRU(emb_dim + hid_dim, hid_dim, num_layers=n_layers, batch_first=False, dropout=dropout)
        self.fc_out = nn.Linear(hid_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        self.softmax = nn.LogSoftmax(dim=1)
        self.apply(init_seq2seq)
        
    def forward(self, input, context, hidden):
        # input shape: (batch_size)
        # hidden shape: (n_layers * n_dirs, batch_size, hid_dim)
        input = input.unsqueeze(0)
        # input shape: (1, batch_size)
        embedded = self.dropout(self.embedding(input.type(torch.int64)))
        # embedded shape: (1, batch_size, emb_dim)
        emb_and_con = torch.cat((embedded, context), dim=-1)
        output, hidden = self.gru(emb_and_con, hidden)
        # output shape: (seq_len, batch_size, hid_dim * n_dirs)
        # hidden shape: (n_layers * n_dirs, batch_size, hid_dim)

        # will always have seq_len == n_dr == 1 in decoder (unless use self-attention
        prediction = self.softmax(self.fc_out(output.squeeze(0)))
        # prediction shape: (batch_size, output_dim)
        return prediction, hidden

In [17]:
# testing out the encoder and decoder
vocab_size, embed_size, num_hiddens, num_layers = 10, 8, 16, 1
batch_size, num_t_steps = 4, 9
encoder = Encoder(vocab_size, embed_size, num_hiddens, num_layers)
example_src = torch.zeros(num_t_steps, batch_size)
example_trg = torch.zeros(batch_size)
enc_outputs, enc_hidden = encoder(example_src)
d2l.check_shape(enc_outputs, (num_t_steps, batch_size, num_hiddens))
d2l.check_shape(enc_hidden, (num_layers, batch_size, num_hiddens))

decoder = Decoder(vocab_size, embed_size, num_hiddens, num_layers)
dec_output, dec_hidden = decoder(example_trg, enc_hidden, enc_hidden)
d2l.check_shape(dec_output, (batch_size, vocab_size)) # output for a single t step!
d2l.check_shape(dec_hidden, (num_layers, batch_size, num_hiddens))

In [18]:
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        assert (
            encoder.hid_dim == decoder.hid_dim
        ), "Hidden dimensions of encoder and decoder must be equal!"
        assert (
            encoder.n_layers == decoder.n_layers
        ), "Number of layers of encoder and decoder must be equal!"

    def forward(self, src, trg, teacher_forcing=0.5):
        # src shape: (num_t_steps, batch_size)
        # trg shape: (num_t_steps, batch_size)
        # teacher_forcing_ratio is probability to use teacher forcing
        # e.g. if teacher_forcing_ratio is 0.75 we use ground-truth inputs 75% of the time
        # last hidden state of the encoder is the context
        trg_len, batch_size = trg.shape
        trg_vocab_size = self.decoder.output_dim
        # tensor to store outputs
        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)
        # obtain the hidden state from the encoder...
        _, enc_hidden = self.encoder(src)
        # ...which also servers as the input context and hidden state to the decoder
        context = enc_hidden
        dec_hidden = enc_hidden
        # get the first batch-worth's tokens from the target (ie batch_size number of <BOS> tokens)
        dec_input = trg[0, :]
        for t in np.arange(1, trg_len+1):
            dec_output, dec_hidden = self.decoder(dec_input, context, dec_hidden)
            # dec_output is softmax regression over vocab ids, so grab id that maximizes prob
            outputs[t-1] = dec_output # give outputs the probs: this is what softmax is trained on!
            top_token_id = dec_output.argmax(1) # <- does this give (batch_size, vocab_sz) tensor??
            if t < trg_len:
                dec_input = trg[t] if np.random.rand() < teacher_forcing else top_token_id
        return outputs

In [19]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [23]:
# attn = Attention(HIDDEN_SIZE, HIDDEN_SIZE)

encoder1 = Encoder(
    VOCAB_SIZE, EMB_SIZE, HIDDEN_SIZE, NUM_ENCODER_LAYERS, dropout=0, embedding_weights=np.array(embeddings)
)
decoder1 = Decoder(
    VOCAB_SIZE, EMB_SIZE, HIDDEN_SIZE, NUM_DECODER_LAYERS, dropout=0.1, embedding_weights=np.array(embeddings)
)

loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

model = Seq2Seq(encoder1, decoder1, DEVICE)
model = nn.DataParallel(model, device_ids=[0,1])
model.to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.98), eps=1e-9)

#summary(model, input_size=[(MAX_LENGTH, 1), (MAX_LENGTH, 1)], device='cuda')
print(f"The model has {count_parameters(model):,} trainable parameters")

The model has 25,549,440 trainable parameters


In [21]:
def train(model, iterator, optimizer, loss_fn, clip):
    model.train() # set model into training mode so apply dropout, etc.
    epoch_loss = 0
    for src, trg in tqdm(iterator):
        optimizer.zero_grad()
        src = src.to(DEVICE)
        trg = trg.to(DEVICE)
        output = model(src, trg)
        #trg = [trg len, batch size]
        #output = [trg len, batch size, output dim]
        output_dim = output.shape[-1]
        output = output[1:].view(-1, output_dim)
        trg = trg[1:].view(-1)
        #trg = [(trg len - 1) * batch size]
        #output = [(trg len - 1) * batch size, output dim]
        loss = loss_fn(output, trg)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()
    return epoch_loss / len(iterator)

def evaluate(model, iterator, criterion):
    model.eval() # set model into inference mode, so DONT apply dropout
    epoch_loss = 0
    with torch.no_grad():
        for src, trg in tqdm(iterator):
            src = src.to(DEVICE)
            trg = trg.to(DEVICE)
            output = model(src, trg, 0) #turn off teacher forcing
            #trg = [trg len, batch size]
            #output = [trg len, batch size, output dim]
            output_dim = output.shape[-1]
            output = output[1:].view(-1, output_dim)
            trg = trg[1:].view(-1)
            #trg = [(trg len - 1) * batch size]
            #output = [(trg len - 1) * batch size]
            loss = criterion(output, trg)
            epoch_loss += loss.item()
    return epoch_loss / len(iterator)


In [22]:
current_time = lambda: time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())
print(current_time())

2023-08-11-13:19:50


In [23]:
NUM_EPOCHS = 10
CLIP = 1 # gradient clipping
RESUME = False

train_iter = Hdf5Dataset(
    pl.Path(folder) / train_filename, num_entries=N_TRAIN_SAMPLES)
train_dataloader = DataLoader(
    train_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)
val_iter = Hdf5Dataset(pl.Path(folder) / valid_filename, num_entries=N_VAL_SAMPLES)
val_dataloader = DataLoader(
    val_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

# make sure folder exists
pl.Path("checkpoints").mkdir(parents=True, exist_ok=True)

if RESUME:
    checkpoint = torch.load(
        pl.Path("checkpoints") /
        f"model-epoch_{NUM_EPOCHS-1}-{current_time()}.pt"
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    epoch = checkpoint["epoch"]

for epoch in range(1, NUM_EPOCHS + 1):
    start_time = timer()
    print(
        f"\033[92mEpoch {epoch} of {NUM_EPOCHS} - time: {current_time()}\033[0m")
    print(f"\033[92mTraining...\033[0m")
    train_loss = train(model, train_dataloader, optimizer, loss_fn, CLIP)
    end_time = timer()
    print(f"\033[92mValidating...\033[0m")
    val_loss = evaluate(model, val_dataloader, loss_fn)
    print(
        (
            f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "
            f"Epoch time = {(end_time - start_time):.3f}s"
        )
    )
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": val_loss,
        },
        pl.Path("checkpoints") /
        f"model-epoch_{NUM_EPOCHS-1}-{current_time()}.pt",
    )

[92mEpoch 1 of 10 - time: 2023-08-11-13:19:51[0m
[92mTraining...[0m


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 625/625 [01:17<00:00,  8.05it/s]


[92mValidating...[0m


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:04<00:00, 14.86it/s]


Epoch: 1, Train loss: 6.223, Val loss: 6.675, Epoch time = 77.602s
[92mEpoch 2 of 10 - time: 2023-08-11-13:21:13[0m
[92mTraining...[0m


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 625/625 [01:15<00:00,  8.30it/s]


[92mValidating...[0m


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:04<00:00, 14.88it/s]


Epoch: 2, Train loss: 4.630, Val loss: 6.658, Epoch time = 75.260s
[92mEpoch 3 of 10 - time: 2023-08-11-13:22:33[0m
[92mTraining...[0m


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 625/625 [01:15<00:00,  8.31it/s]


[92mValidating...[0m


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:04<00:00, 14.90it/s]


Epoch: 3, Train loss: 3.856, Val loss: 6.568, Epoch time = 75.237s
[92mEpoch 4 of 10 - time: 2023-08-11-13:23:52[0m
[92mTraining...[0m


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 625/625 [01:15<00:00,  8.31it/s]


[92mValidating...[0m


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:04<00:00, 14.94it/s]


Epoch: 4, Train loss: 3.468, Val loss: 6.565, Epoch time = 75.180s
[92mEpoch 5 of 10 - time: 2023-08-11-13:25:12[0m
[92mTraining...[0m


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 625/625 [01:15<00:00,  8.33it/s]


[92mValidating...[0m


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:04<00:00, 14.97it/s]


Epoch: 5, Train loss: 3.200, Val loss: 6.542, Epoch time = 75.016s
[92mEpoch 6 of 10 - time: 2023-08-11-13:26:32[0m
[92mTraining...[0m


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 625/625 [01:15<00:00,  8.32it/s]


[92mValidating...[0m


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:04<00:00, 14.65it/s]


Epoch: 6, Train loss: 2.989, Val loss: 6.487, Epoch time = 75.083s
[92mEpoch 7 of 10 - time: 2023-08-11-13:27:51[0m
[92mTraining...[0m


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 625/625 [01:15<00:00,  8.33it/s]


[92mValidating...[0m


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:04<00:00, 14.83it/s]


Epoch: 7, Train loss: 2.871, Val loss: 6.480, Epoch time = 75.058s
[92mEpoch 8 of 10 - time: 2023-08-11-13:29:11[0m
[92mTraining...[0m


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 625/625 [01:15<00:00,  8.33it/s]


[92mValidating...[0m


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:04<00:00, 14.85it/s]


Epoch: 8, Train loss: 2.702, Val loss: 6.482, Epoch time = 75.077s
[92mEpoch 9 of 10 - time: 2023-08-11-13:30:31[0m
[92mTraining...[0m


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 625/625 [01:14<00:00,  8.35it/s]


[92mValidating...[0m


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:04<00:00, 14.93it/s]


Epoch: 9, Train loss: 2.617, Val loss: 6.443, Epoch time = 74.849s
[92mEpoch 10 of 10 - time: 2023-08-11-13:31:50[0m
[92mTraining...[0m


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 625/625 [01:14<00:00,  8.34it/s]


[92mValidating...[0m


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:04<00:00, 14.91it/s]


Epoch: 10, Train loss: 2.511, Val loss: 6.428, Epoch time = 74.956s


In [51]:
# function to generate output sequence using greedy algorithm
def correct_sentence_vectorized(src_tensor, model, max_len=50):
    assert isinstance(src_tensor, torch.Tensor)
    if isinstance(model, nn.DataParallel):
        model = model.module # get the accessor if the model was run in wrapped dataparallel obj
    model.eval()
    src_tensor = src_tensor.unsqueeze(1).to(DEVICE)
    # get length of input sentence
    src_len = src_tensor.shape[0]
    trg_vocab_size = model.decoder.output_dim
    # tensor to store decoder outputs
    outputs = torch.zeros(max_len, 1, trg_vocab_size).to(DEVICE)
    # last hidden state of the encoder is the context
    with torch.no_grad():
        _, enc_hidden = model.encoder(src_tensor)
    # context also used as the initial hidden state of the decoder
    context = enc_hidden
    dec_hidden = enc_hidden
    # first input to the decoder is the <sos> tokens
    dec_input = src_tensor[0]
    # enc_src = [batch_sz, src_len, hid_dim]
    # Even though some examples might have been completed by producing a <eos> token
    # we still need to feed them through the model because other are not yet finished
    # and all examples act as a batch. Once every single sentence prediction encounters
    # <eos> token, then we can stop predicting.
    for t in range(1, max_len+1):
        # insert input token embedding, previous hidden state and the context state
        # receive output tensor (predictions) and new hidden state
        dec_output, dec_hidden = model.decoder(dec_input, context, dec_hidden)
        # place predictions in a tensor holding predictions for each token
        outputs[t-1] = dec_output
        # get the highest predicted token from our predictions
        top_token_id = dec_output.argmax(1)
        dec_input = top_token_id

    pred_sentence = []
    for i in range(0, len(outputs)):
        topv, topi = outputs[i, :, :].topk(1)
        pred_sentence.append(vocab_transform.vocab.itos_[topi])
        if topi == EOS_IDX:
            break
    return " ".join(pred_sentence)

In [52]:
latest_checkpoint = sorted(Path("checkpoints").glob("*.pt"), key=os.path.getmtime)[-1]
print(latest_checkpoint)
checkpoint = torch.load(latest_checkpoint)
model.load_state_dict(checkpoint["model_state_dict"])

model.eval()

# Pick one in 18M examples
val_iter = Hdf5Dataset(pl.Path(folder) / valid_filename, num_entries=5)
val_iter = iter(val_iter)
src, trg = next(val_iter)
#src, trg = np.random.choice(val_iter)

print('input: "', src, '"')
print('target: "', trg, '"')

src = text_transform(src)

print(f"\033[91mModel's prediction: \033[0m", end="")
print(correct_sentence_vectorized(src, model))

checkpoints/model-epoch_9-2023-08-11-13:33:09.pt
input: " I think I'm goign to have to inform your posts a few times in order to gain even a small appreciation on how it all fits togethed but it looks very interesting! "
target: " I think I'm going to have to re-read your posts a few times in order to gain even a small understanding of how it all fits together but it certainly looks very interesting! "
[91mModel's prediction: [0mfever i am my my <UNK> <UNK> <UNK> in in my my i i i it it it it it it it it it it it it it it it it it it it it it it it it it it it it it it it it it it it
