<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Objective" data-toc-modified-id="Objective-1">Objective</a></span></li><li><span><a href="#Data-Models" data-toc-modified-id="Data-Models-2">Data Models</a></span></li></ul></div>

In [3]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

------------

# Objective

The goal of this notebook is to train a language model from scratch on `wikitext-2`, which you can find [here](https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/). Our focus will be on getting the pre-processing and training loops working in the traditional, non-federated setting.  

This notebook borrows heavily from [this](https://pytorch.org/tutorials/advanced/dynamic_quantization_tutorial.html) PyTorch tutorial, which is absolutely outstanding.

--------

# Data Models

In [18]:
class Dictionary:
    """Base class for encoding a vocabulary."""
    
    def __init__(self):
        self.word2idx = {}
        self.idx2word = []
        
    def __len__(self):
        return len(self.idx2word)
        
    def add_word(self, word):
        """Add a new word to the dictionary."""
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1
                        
    def get_index(self, word):
        """Return the index of a word."""
        return self.word2idx[word]
            
    def add_words(self, words):
        """Add a list of words to the dictionary."""
        for word in words:
            self.add_word(word)
    
    def get_indexes(self, words):
        """Return the indexes of a list of words."""
        return [self.get_index(word) for word in words]
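As a quick sanity check of the encoding scheme, the sketch below reproduces the `Dictionary` logic with a plain `dict` and `list` so that it runs on its own. The helper names `build_vocab`, `encode`, and `decode` are hypothetical and not part of the notebook. Indexes are assigned in first-seen order, so a repeated word always maps to the same integer, and `idx2word` inverts the mapping.

```python
def build_vocab(words):
    """Assign each unseen word the next free index, in first-seen order."""
    word2idx, idx2word = {}, []
    for word in words:
        if word not in word2idx:
            idx2word.append(word)
            word2idx[word] = len(idx2word) - 1
    return word2idx, idx2word


def encode(word2idx, words):
    """Map words to indexes; raises KeyError for out-of-vocabulary words."""
    return [word2idx[word] for word in words]


def decode(idx2word, idxs):
    """Map indexes back to words."""
    return [idx2word[i] for i in idxs]


tokens = "the cat sat on the mat".split()
word2idx, idx2word = build_vocab(tokens)
ids = encode(word2idx, tokens)
print(ids)                    # [0, 1, 2, 3, 0, 4]
print(decode(idx2word, ids))  # ['the', 'cat', 'sat', 'on', 'the', 'mat']
```

Like `Dictionary.get_index`, `encode` raises a `KeyError` for words it has never seen, which is why words are added before they are looked up.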

In [19]:
class Corpus:
    """Base class for modeling a corpus of text."""

    def __init__(self, dirpath):
        """Initialise a corpus given a dir with train, valid, and test .txt files."""
        self.dictionary = Dictionary()
        # All three splits share one dictionary, so indexes are consistent across them.
        self.train = self.vectorize(os.path.join(dirpath, "train.txt"))
        self.valid = self.vectorize(os.path.join(dirpath, "valid.txt"))
        self.test = self.vectorize(os.path.join(dirpath, "test.txt"))

    def vectorize(self, fpath):
        """Return a tensor of indexes encoding the words in a file."""
        idxs = []
        with open(fpath, "r", encoding="utf8") as f:
            for line in f:
                # Split the line (not the file object) and mark the end of the line.
                # list.append() returns None, so concatenate the lists instead.
                words = line.split() + ["<EOS>"]
                self.dictionary.add_words(words)
                idxs.extend(self.dictionary.get_indexes(words))
        return torch.LongTensor(idxs)
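To see what `vectorize` produces without touching the WikiText-2 files, here is a minimal sketch that applies the same per-line logic to an in-memory list of strings. `vectorize_lines` is a hypothetical helper, and it keeps the vocabulary in a plain `dict` rather than a `Dictionary`. Each line is split on whitespace, `<EOS>` is appended, and all lines are concatenated into one flat 1-D tensor of `torch.long` indexes. Passing the same vocabulary to a second call mimics how the train, valid, and test splits share one dictionary.

```python
import torch


def vectorize_lines(lines, word2idx):
    """Encode lines as one flat tensor of indexes, adding <EOS> after each line.

    Unseen words are added to ``word2idx`` in first-seen order, mirroring
    Dictionary.add_words followed by Dictionary.get_indexes.
    """
    idxs = []
    for line in lines:
        words = line.split() + ["<EOS>"]  # list concatenation, not .append()
        for word in words:
            if word not in word2idx:
                word2idx[word] = len(word2idx)
        idxs.extend(word2idx[word] for word in words)
    return torch.tensor(idxs, dtype=torch.long)


vocab = {}
train_ids = vectorize_lines(["the cat sat", "the dog sat"], vocab)
print(train_ids)  # tensor([0, 1, 2, 3, 0, 4, 2, 3])

valid_ids = vectorize_lines(["a cat"], vocab)  # reuses the same vocabulary
print(valid_ids)  # tensor([5, 1, 3])
```

Because line boundaries survive only as `<EOS>` tokens, the model sees each split as a single continuous stream of words.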