In [1]:
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, TensorDataset
import time
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


### Information
- We will do a few preliminary exercises and also build a character level MLP language model.
- This model will be similar to the model we did in class, except that we will have characters as tokens, not words.
- You will need a conda environment for this, here is general information on this.
 - https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html
 - PyTorch: https://anaconda.org/pytorch/pytorch
 
In the code below, FILL-IN the code necessary in the hint string provided.

### Preliminary exercises
- Please fill in the cells below with the asked for data.

In [2]:
torch.manual_seed(1)

<torch._C.Generator at 0x111334150>

In [2]:
e = nn.Embedding(10,5)
v = e(torch.tensor(3))
l = e.weight

In [3]:
l

Parameter containing:
tensor([[ 0.8246, -0.8029, -0.8771,  0.9178, -1.1951],
        [-0.7006,  0.8931,  0.5070,  0.5189,  0.8453],
        [ 0.9370, -0.5061,  0.7507, -0.0704,  0.3276],
        [-0.3582,  0.6241, -0.0508, -1.0284, -0.6014],
        [ 0.5474,  0.1593,  0.6124,  0.1262,  0.9572],
        [ 0.0257,  1.0389, -2.3217,  0.1992,  0.4714],
        [ 0.1319, -1.6183, -0.2496,  1.1969,  1.4502],
        [ 1.3331, -0.1006, -2.0760,  1.1813, -0.6348],
        [ 0.1365, -2.6525, -0.0141, -2.1362,  0.2617],
        [-0.1473,  0.2094, -0.7971,  0.2595,  0.7638]], requires_grad=True)

In [6]:
l.weight

Linear(in_features=10, out_features=5, bias=False)

In [19]:
# Create an embedding layer for a vocabulary of size 10 and the word vectors are each of dimension 5.
e = nn.Embedding(10,5)

# Extract the embedding for the word whose token index is 3. What is the shape of this vector?
v = e(torch.tensor(3))

# Extract the weight matrix from the layer e.
# Create a linear layer (with no bias) of size 10 by 5 and set it's data to the embedding matrix.
l = nn.Linear(5,10, bias = False)
l.weight = e.weight

# Insert inside of the assert below some sort of equality check between l.weight and e.weight; it should pass to true.
# Hint: look up torch.all() and torch.eq()
assert(torch.eq(e.weight, l.weight).all())

In [20]:
# Create a batch of size 2 with entries [0, 1, 2] and [2, 3, 4] in the data batch.
x = torch.tensor([[0, 1, 2], [2, 3, 4]])

In [24]:
# What is the dimesion of this batch ran through the embeding layer?
assert(e(x).shape == torch.Size([2,3,5]))

### Constants and configs used below.

In [55]:
DEVICE = "cpu"
LR = 4.0
BATCH_SIZE = 16
NUM_EPOCHS = 5
MARKER = '.'
# N-gram level; P(w_t | w_{t-1}, ..., w_{t-n+1}).
# We use 3 words to predict the next word.
n = 4
# Hidden layer dimension.
h = 20
# Word embedding dimension.
m = 20

### Get the dataset and the tokenizer.

In [56]:
class CharDataset(Dataset):
    def __init__(self, words, chars):
        self.words = words
        self.chars = chars
        # Inverse dictionaries mapping char tokens to unique ids and the reverse.
        # Tokens in this case are the unique chars we passed in above.
        # Each token should be mappend to a unique integer and MARKER should have token 0.
        # For example, stoi should be like {'.' -> 0, 'a' -> 1, 'b' -> 2} if I pass in chars = '.ab'.
        dic_stoi, dic_itos = {}, {}
        for ele in chars:
            count = 0
            dic_stoi[ele] = count
            dic_itos[count] = ele
            count += 1
        self.stoi = dic_stoi
        self.itos = dic_itos # Inverse mapping.

    def __len__(self):
        # Number of words.
        return len(self.words)

    def contains(self, word):
        # Check if word is in self.words and return True/False if it is, is not.
        return True if word in self.words else False

    def get_vocab_size(self):
        # Return the vocabulary size.
        return len(self.chars)

    def encode(self, word):
        # Express this word as a list of int ids. For example, maybe ".abc" -> [0, 1, 2, 3].
        # This assumes 'a' -> 1, etc.
        result = []
        for char in word:
            result.append(self.stoi[char])
        return result
    
    def decode(self, tokens):
        # For a set of tokens, return back the string.
        # For example, maybe [1, 1, 2] -> "aac"
        result = []
        for tok in tokens:
            result.append(self.dic_itos[tok])
        return result

    def __getitem__(self, idx):
        # This is used so we can loop over the data.
        word = self.words[idx]
        return self.encode(word)

In [64]:
'..'*3

'......'

In [65]:
def create_datasets(window, input_file = 'names.txt'):
    """
    This takes a file of words and separates all the words.
    It then gets all the characters present in the universe of words and then ouputs the statistics. 
    """
    with open(input_file, 'r') as f:
        data = f.read()
    # Split the file by new lines. You should get a list of names.
    words = data.split('\n')
    words = [word.replace(' ', '') for word in words] # This gets rid of any trailing and starting white spaces.
    words = [word for word in words if word] # Filter out all the empty words.

    
    chars = sorted(list(set([char for word in words for char in word]))) # This gets the universe of all characters.
    
    # Will force chars to have MARKER having index 0.
    chars= [MARKER] + chars
    
    # Pad each word with a context window of size n-1.
    # Why? a word like "abc" should becomes "..abc.." if the window is size 3.
    # This is some we can get pair of (x, y) data like this: ".." -> "a", ".a" -> "b", "ab" -> "c", "bc" -> ".", "c." -> "."
    # I.e. this allows us to know that "a" is a start character.
    # So you should get something like ["ab", "c"] -> ["..ab..", "..c.."], for example.
    words = [('.'*(window-1))+word+('.'*(window-1)) for word in words]
            
    print(f"The number of examples in the dataset: {len(words)}")
    print(f"The number of unique characters in the vocabulary: {len(chars)}")
    print(f"The vocabulary we have is: {''.join(chars)}")

    # Partition the input data into a training, validation, and the test set.
    out_of_sample_set_size = min(2000, int(len(words) * 0.1)) # We use 10% of the training set, or up to 2000 examples.
    test_set_size = 1500
    
    # First, get a random permutation of randomly permute of size len(words).
    # Then, convert this to a list. 
    # This index list is used below to get the train, validation, and test sets.
    rp = torch.randperm(len(words)).tolist()
    
    # Get train, validation, and test set.
    train_words = [words[i] for i in rp[:-out_of_sample_set_size]]
    validation_words = [words[i] for i in rp[-out_of_sample_set_size:-test_set_size]]
    test_words = [words[i] for i in rp[-test_set_size:]]    
    
    print(f"We've split up the dataset into {len(train_words)}, {len(validation_words)}, {len(test_words)} training, validation, and test examples")

    # But the data in the data set objects.
    train_dataset = CharDataset(train_words, chars)
    validation_dataset = CharDataset(validation_words, chars)
    test_dataset = CharDataset(test_words, chars)

    return train_dataset, validation_dataset, test_dataset

In [66]:
train_dataset, validation_dataset, test_dataset = create_datasets(n)

The number of examples in the dataset: 32033
The number of unique characters in the vocabulary: 27
The vocabulary we have is: .abcdefghijklmnopqrstuvwxyz
We've split up the dataset into 30033, 500, 1500 training, validation, and test examples


## Explore the data

In [67]:
# Get the first word in "train_dataset"
train_dataset.words[0]

'...eimy...'

In [68]:
# Get the stoi map of train_dataset. How many keys does it have?
print(len(train_dataset.stoi))
print(train_dataset.get_vocab_size())


27
27


### Get the dataloader

In [69]:
def create_dataloader(dataset, window):
    x_list = []
    y_list = []
    # For ech word.
    for i, word in enumerate(dataset):
        # Grab a context of size window and window-1 characters will be in x, 1 will be in y.
        for j, _ in enumerate(word):
            # If there is no widow of size window left, break.
            if j + window > len(word) - 1:
                break
            word_window = word[j:j+window]
            x, y = word_window[:window-1], word_window[-1]
            x_list.append(x)
            y_list.append(y)
            
    return DataLoader(
        TensorDataset(torch.tensor(x_list), torch.tensor(y_list)),
        BATCH_SIZE,
        shuffle=True
    )

In [70]:
train_dataloader = create_dataloader(train_dataset, n)
validation_dataloader = create_dataloader(validation_dataset, n)
test_dataloader = create_dataloader(test_dataset, n)

### Set up the model
- Identical to lecture. Please look over that!

In [33]:
# One of the first Neural language models!
class CharacterNeuralLanguageModel(nn.Module):
    def __init__(self, V, m, h, n):
        super(CharacterNeuralLanguageModel, self).__init__()
        
        # Vocabulary size.
        self.V = " FILL_IN "
        
        # Embedding dimension, per word.
        self.m = " FILL_IN "
        
        # Hidden dimension.
        self.h = " FILL_IN "
        
        # N in "N-gram"
        self.n = " FILL_IN "
        
        # Can you change all this stuff to use nn.Linear?
        # Ca also use nn.Parameter(torch.zeros(V, m)) for self.C but then we need one-hot and this is slow.
        self.C = " FILL_IN " 
        self.H = " FILL_IN "
        self.W = " FILL_IN "
        self.U = " FILL_IN "
        
        self.b = " FILL_IN ")
        self.d = " FILL_IN "
        
        self.init_weights()

    def init_weights(self):
        # Intitialize C, H, W, U in a nice way. Use xavier initialization for the weights.
        # On a first run, just pass.
        " FILL_IN "
        pass # Replace this pass with something else.
        
    def forward(self, x):
        
        # x is of dimenson N = batch size X n-1
        
        # N X (n-1) X m 
        x = " FILL_IN "
        
        # N
        N = " FILL_IN "
        
        # N X (n-1) * m
        x = " FILL_IN "
    
        # N X V
        y = self.b + torch.matmul(x, self.W) + torch.matmul(nn.Tanh()(self.d + torch.matmul(x, self.H)), self.U)
        
        return y

### Set up the model.

In [34]:
# Identical to lecture.
criterion = " FILL_IN "
model = CharacterNeuralLanguageModel(
    train_dataset.get_vocab_size(), m, h, n
).to(DEVICE)
optimizer = " FILL_IN "
scheduler = " FILL_IN "

In [None]:
# How many parameters does the neural network have?
# Hint: look up model.named_parameters and the method "nelement" on a tensor.
# See also the XOR notebook where we count the gradients that are 0.
# There, we loop over the parameters.
number_parameters = " FILL_IN "

### Train the model.

In [35]:
def calculate_perplexity(total_loss, total_batches):
    return " FILL_IN "

In [44]:
def train(dataloader, model, optimizer, criterion, epoch):
    model.train()
    total_loss, total_batches = 0.0, 0.0
    log_interval = 500

    for idx, (x, y) in tqdm(enumerate(dataloader)):
        optimizer.zero_grad()
        
        logits = model(x)
                        
        # Get the loss.
        loss = " FILL_IN "

        # Do back propagation.
        loss.backward()
                        
        # Clip the gradients so they don't explode. Look at how this is done in lecture.
        " FILL_IN "
        
        # Do an optimization step.
        " FILL_IN "
        
        total_loss += loss.item()
        total_batches += 1
                
        if idx % log_interval == 0 and idx > 0:
            perplexity = calculate_perplexity(total_loss,  total_batches)
            print(
                "| epoch {:3d} "
                "| {:5d}/{:5d} batches "
                "| perplexity {:8.3f} "
                "| loss {:8.3f} "
                .format(
                    epoch,
                    idx,
                    len(dataloader),
                    perplexity,
                    total_loss / total_batches,
                )
            )
            total_loss, total_batches = 0.0, 0

In [45]:
def evaluate(dataloader, model, criterion):
    model.eval()
    total_loss, total_batches = 0.0, 0

    with torch.no_grad():
        for idx, (x, y) in enumerate(dataloader):
            logits = model(x)
            total_loss += criterion(input=logits, target=y.squeeze(-1)).item()
            total_batches += 1
    return total_loss / total_batches, calculate_perplexity(total_loss,  total_batches)

In [46]:
for epoch in range(1, NUM_EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_dataloader, model, optimizer, criterion, epoch)
    loss_val, perplexity_val = " FILL_IN "
    scheduler.step()
    print("-" * 59)
    print(
        "| end of epoch {:3d} "
        "| time: {:5.2f}s "
        "| valid perplexity {:8.3f} "
        "| valid loss {:8.3f}".format(
            epoch,
            time.time() - epoch_start_time,
            perplexity_val,
            loss_val
        )
    )
    print("-" * 59)

print("Checking the results of test dataset.")
loss_test, perplexity_test = " FILL_IN "
print("test perplexity {:8.3f} | test loss {:8.3f} ".format(perplexity_test, loss_test))

1020it [00:00, 2664.65it/s]

| epoch   1 |   500/15255 batches | perplexity   10.395 | loss    2.341 
| epoch   1 |  1000/15255 batches | perplexity    9.218 | loss    2.221 


1863it [00:00, 2766.71it/s]

| epoch   1 |  1500/15255 batches | perplexity    8.665 | loss    2.159 
| epoch   1 |  2000/15255 batches | perplexity    8.565 | loss    2.148 


2995it [00:01, 2814.72it/s]

| epoch   1 |  2500/15255 batches | perplexity    8.649 | loss    2.157 
| epoch   1 |  3000/15255 batches | perplexity    8.368 | loss    2.124 


3847it [00:01, 2832.80it/s]

| epoch   1 |  3500/15255 batches | perplexity    8.517 | loss    2.142 
| epoch   1 |  4000/15255 batches | perplexity    8.269 | loss    2.112 


4989it [00:01, 2848.82it/s]

| epoch   1 |  4500/15255 batches | perplexity    8.159 | loss    2.099 
| epoch   1 |  5000/15255 batches | perplexity    8.112 | loss    2.093 


5843it [00:02, 2837.26it/s]

| epoch   1 |  5500/15255 batches | perplexity    8.444 | loss    2.133 
| epoch   1 |  6000/15255 batches | perplexity    8.106 | loss    2.093 


6980it [00:02, 2815.81it/s]

| epoch   1 |  6500/15255 batches | perplexity    8.148 | loss    2.098 
| epoch   1 |  7000/15255 batches | perplexity    8.361 | loss    2.124 


7825it [00:02, 2801.48it/s]

| epoch   1 |  7500/15255 batches | perplexity    8.250 | loss    2.110 
| epoch   1 |  8000/15255 batches | perplexity    8.428 | loss    2.132 


8955it [00:03, 2817.10it/s]

| epoch   1 |  8500/15255 batches | perplexity    8.339 | loss    2.121 
| epoch   1 |  9000/15255 batches | perplexity    8.346 | loss    2.122 


9807it [00:03, 2826.06it/s]

| epoch   1 |  9500/15255 batches | perplexity    8.264 | loss    2.112 
| epoch   1 | 10000/15255 batches | perplexity    8.040 | loss    2.084 


10955it [00:03, 2849.26it/s]

| epoch   1 | 10500/15255 batches | perplexity    7.991 | loss    2.078 
| epoch   1 | 11000/15255 batches | perplexity    8.127 | loss    2.095 


11811it [00:04, 2836.52it/s]

| epoch   1 | 11500/15255 batches | perplexity    8.098 | loss    2.092 
| epoch   1 | 12000/15255 batches | perplexity    8.351 | loss    2.122 


12943it [00:04, 2773.03it/s]

| epoch   1 | 12500/15255 batches | perplexity    8.200 | loss    2.104 
| epoch   1 | 13000/15255 batches | perplexity    8.236 | loss    2.109 


13778it [00:04, 2755.20it/s]

| epoch   1 | 13500/15255 batches | perplexity    8.186 | loss    2.102 
| epoch   1 | 14000/15255 batches | perplexity    8.155 | loss    2.099 


14906it [00:05, 2807.07it/s]

| epoch   1 | 14500/15255 batches | perplexity    8.170 | loss    2.101 
| epoch   1 | 15000/15255 batches | perplexity    8.241 | loss    2.109 


15255it [00:05, 2790.19it/s]


-----------------------------------------------------------
| end of epoch   1 | time:  5.52s | valid perplexity    8.537 | valid loss    2.144
-----------------------------------------------------------


274it [00:00, 2732.65it/s]

| epoch   2 |   500/15255 batches | perplexity    7.641 | loss    2.034 


835it [00:00, 2782.61it/s]

| epoch   2 |  1000/15255 batches | perplexity    7.850 | loss    2.060 


1399it [00:00, 2804.91it/s]

| epoch   2 |  1500/15255 batches | perplexity    7.600 | loss    2.028 


1964it [00:00, 2811.95it/s]

| epoch   2 |  2000/15255 batches | perplexity    7.671 | loss    2.037 


2530it [00:00, 2818.70it/s]

| epoch   2 |  2500/15255 batches | perplexity    7.594 | loss    2.027 



2812it [00:01, 2814.70it/s]

| epoch   2 |  3000/15255 batches | perplexity    7.421 | loss    2.004 


3380it [00:01, 2801.98it/s]

| epoch   2 |  3500/15255 batches | perplexity    7.258 | loss    1.982 


3947it [00:01, 2798.70it/s]

| epoch   2 |  4000/15255 batches | perplexity    7.622 | loss    2.031 


4507it [00:01, 2785.55it/s]

| epoch   2 |  4500/15255 batches | perplexity    7.404 | loss    2.002 



4786it [00:01, 2774.04it/s]

| epoch   2 |  5000/15255 batches | perplexity    7.516 | loss    2.017 


5342it [00:01, 2766.06it/s]

| epoch   2 |  5500/15255 batches | perplexity    7.473 | loss    2.011 


5912it [00:02, 2808.32it/s]

| epoch   2 |  6000/15255 batches | perplexity    7.408 | loss    2.003 


6485it [00:02, 2837.76it/s]

| epoch   2 |  6500/15255 batches | perplexity    7.566 | loss    2.024 


7062it [00:02, 2859.66it/s]

| epoch   2 |  7000/15255 batches | perplexity    7.434 | loss    2.006 



7355it [00:02, 2879.75it/s]

| epoch   2 |  7500/15255 batches | perplexity    7.548 | loss    2.021 


7931it [00:02, 2868.77it/s]

| epoch   2 |  8000/15255 batches | perplexity    7.632 | loss    2.032 


8503it [00:03, 2838.10it/s]

| epoch   2 |  8500/15255 batches | perplexity    7.479 | loss    2.012 



8787it [00:03, 2826.63it/s]

| epoch   2 |  9000/15255 batches | perplexity    7.308 | loss    1.989 


9351it [00:03, 2791.36it/s]

| epoch   2 |  9500/15255 batches | perplexity    7.484 | loss    2.013 


9924it [00:03, 2829.92it/s]

| epoch   2 | 10000/15255 batches | perplexity    7.436 | loss    2.006 


10491it [00:03, 2800.87it/s]

| epoch   2 | 10500/15255 batches | perplexity    7.548 | loss    2.021 


11062it [00:03, 2831.43it/s]

| epoch   2 | 11000/15255 batches | perplexity    7.397 | loss    2.001 



11346it [00:04, 2826.94it/s]

| epoch   2 | 11500/15255 batches | perplexity    7.615 | loss    2.030 


11910it [00:04, 2790.04it/s]

| epoch   2 | 12000/15255 batches | perplexity    7.558 | loss    2.023 


12469it [00:04, 2780.77it/s]

| epoch   2 | 12500/15255 batches | perplexity    7.567 | loss    2.024 


13028it [00:04, 2774.59it/s]

| epoch   2 | 13000/15255 batches | perplexity    7.391 | loss    2.000 



13306it [00:04, 2775.32it/s]

| epoch   2 | 13500/15255 batches | perplexity    7.469 | loss    2.011 


13866it [00:04, 2776.71it/s]

| epoch   2 | 14000/15255 batches | perplexity    7.566 | loss    2.024 


14422it [00:05, 2727.17it/s]

| epoch   2 | 14500/15255 batches | perplexity    7.554 | loss    2.022 


14968it [00:05, 2715.58it/s]

| epoch   2 | 15000/15255 batches | perplexity    7.385 | loss    1.999 


15255it [00:05, 2792.02it/s]


-----------------------------------------------------------
| end of epoch   2 | time:  5.49s | valid perplexity    7.613 | valid loss    2.030
-----------------------------------------------------------


261it [00:00, 2609.50it/s]

| epoch   3 |   500/15255 batches | perplexity    7.343 | loss    1.994 


815it [00:00, 2740.51it/s]

| epoch   3 |  1000/15255 batches | perplexity    7.377 | loss    1.998 


1388it [00:00, 2813.76it/s]

| epoch   3 |  1500/15255 batches | perplexity    7.490 | loss    2.014 


1969it [00:00, 2863.16it/s]

| epoch   3 |  2000/15255 batches | perplexity    7.335 | loss    1.993 


2540it [00:00, 2819.87it/s]

| epoch   3 |  2500/15255 batches | perplexity    7.630 | loss    2.032 



2823it [00:01, 2806.40it/s]

| epoch   3 |  3000/15255 batches | perplexity    7.422 | loss    2.004 


3384it [00:01, 2788.39it/s]

| epoch   3 |  3500/15255 batches | perplexity    7.604 | loss    2.029 


3944it [00:01, 2792.05it/s]

| epoch   3 |  4000/15255 batches | perplexity    7.365 | loss    1.997 


4506it [00:01, 2791.01it/s]

| epoch   3 |  4500/15255 batches | perplexity    7.390 | loss    2.000 



4786it [00:01, 2761.92it/s]

| epoch   3 |  5000/15255 batches | perplexity    7.365 | loss    1.997 


5360it [00:01, 2815.78it/s]

| epoch   3 |  5500/15255 batches | perplexity    7.296 | loss    1.987 


5930it [00:02, 2832.26it/s]

| epoch   3 |  6000/15255 batches | perplexity    7.397 | loss    2.001 


6502it [00:02, 2842.52it/s]

| epoch   3 |  6500/15255 batches | perplexity    7.433 | loss    2.006 



6787it [00:02, 2824.64it/s]

| epoch   3 |  7000/15255 batches | perplexity    7.494 | loss    2.014 


7361it [00:02, 2844.46it/s]

| epoch   3 |  7500/15255 batches | perplexity    7.447 | loss    2.008 


7934it [00:02, 2846.70it/s]

| epoch   3 |  8000/15255 batches | perplexity    7.312 | loss    1.990 


8507it [00:03, 2849.75it/s]

| epoch   3 |  8500/15255 batches | perplexity    7.344 | loss    1.994 



8792it [00:03, 2847.72it/s]

| epoch   3 |  9000/15255 batches | perplexity    7.540 | loss    2.020 


9369it [00:03, 2851.70it/s]

| epoch   3 |  9500/15255 batches | perplexity    7.666 | loss    2.037 


9945it [00:03, 2856.80it/s]

| epoch   3 | 10000/15255 batches | perplexity    7.343 | loss    1.994 


10515it [00:03, 2836.27it/s]

| epoch   3 | 10500/15255 batches | perplexity    7.440 | loss    2.007 



10801it [00:03, 2841.84it/s]

| epoch   3 | 11000/15255 batches | perplexity    7.353 | loss    1.995 


11380it [00:04, 2866.38it/s]

| epoch   3 | 11500/15255 batches | perplexity    7.160 | loss    1.969 


11953it [00:04, 2854.96it/s]

| epoch   3 | 12000/15255 batches | perplexity    7.197 | loss    1.974 


12530it [00:04, 2861.95it/s]

| epoch   3 | 12500/15255 batches | perplexity    7.317 | loss    1.990 



12817it [00:04, 2806.87it/s]

| epoch   3 | 13000/15255 batches | perplexity    7.448 | loss    2.008 


13387it [00:04, 2820.13it/s]

| epoch   3 | 13500/15255 batches | perplexity    7.413 | loss    2.003 


13955it [00:04, 2810.67it/s]

| epoch   3 | 14000/15255 batches | perplexity    7.544 | loss    2.021 


14520it [00:05, 2813.85it/s]

| epoch   3 | 14500/15255 batches | perplexity    7.338 | loss    1.993 



14802it [00:05, 2810.08it/s]

| epoch   3 | 15000/15255 batches | perplexity    7.338 | loss    1.993 


15255it [00:05, 2818.28it/s]


-----------------------------------------------------------
| end of epoch   3 | time:  5.44s | valid perplexity    7.552 | valid loss    2.022
-----------------------------------------------------------


277it [00:00, 2763.62it/s]

| epoch   4 |   500/15255 batches | perplexity    7.177 | loss    1.971 


828it [00:00, 2698.84it/s]

| epoch   4 |  1000/15255 batches | perplexity    7.355 | loss    1.995 


1395it [00:00, 2785.37it/s]

| epoch   4 |  1500/15255 batches | perplexity    7.410 | loss    2.003 


1969it [00:00, 2828.88it/s]

| epoch   4 |  2000/15255 batches | perplexity    7.530 | loss    2.019 


2535it [00:00, 2797.40it/s]

| epoch   4 |  2500/15255 batches | perplexity    7.309 | loss    1.989 



2822it [00:01, 2817.60it/s]

| epoch   4 |  3000/15255 batches | perplexity    7.410 | loss    2.003 


3384it [00:01, 2793.16it/s]

| epoch   4 |  3500/15255 batches | perplexity    7.353 | loss    1.995 


3951it [00:01, 2812.34it/s]

| epoch   4 |  4000/15255 batches | perplexity    7.434 | loss    2.006 


4514it [00:01, 2767.43it/s]

| epoch   4 |  4500/15255 batches | perplexity    7.403 | loss    2.002 



4799it [00:01, 2790.65it/s]

| epoch   4 |  5000/15255 batches | perplexity    7.564 | loss    2.023 


5369it [00:01, 2811.25it/s]

| epoch   4 |  5500/15255 batches | perplexity    7.329 | loss    1.992 


5939it [00:02, 2830.86it/s]

| epoch   4 |  6000/15255 batches | perplexity    7.497 | loss    2.014 


6510it [00:02, 2822.88it/s]

| epoch   4 |  6500/15255 batches | perplexity    7.303 | loss    1.988 



6798it [00:02, 2838.67it/s]

| epoch   4 |  7000/15255 batches | perplexity    7.371 | loss    1.998 


7372it [00:02, 2802.92it/s]

| epoch   4 |  7500/15255 batches | perplexity    7.250 | loss    1.981 


7945it [00:02, 2821.86it/s]

| epoch   4 |  8000/15255 batches | perplexity    7.332 | loss    1.992 


8508it [00:03, 2785.03it/s]

| epoch   4 |  8500/15255 batches | perplexity    7.504 | loss    2.015 



8787it [00:03, 2779.57it/s]

| epoch   4 |  9000/15255 batches | perplexity    7.362 | loss    1.996 


9360it [00:03, 2823.59it/s]

| epoch   4 |  9500/15255 batches | perplexity    7.432 | loss    2.006 


9940it [00:03, 2861.10it/s]

| epoch   4 | 10000/15255 batches | perplexity    7.439 | loss    2.007 


10522it [00:03, 2883.68it/s]

| epoch   4 | 10500/15255 batches | perplexity    7.410 | loss    2.003 



10811it [00:03, 2868.49it/s]

| epoch   4 | 11000/15255 batches | perplexity    7.365 | loss    1.997 


11377it [00:04, 2769.17it/s]

| epoch   4 | 11500/15255 batches | perplexity    7.168 | loss    1.970 


11931it [00:04, 2745.04it/s]

| epoch   4 | 12000/15255 batches | perplexity    7.516 | loss    2.017 


12481it [00:04, 2736.44it/s]

| epoch   4 | 12500/15255 batches | perplexity    7.639 | loss    2.033 


13029it [00:04, 2716.92it/s]

| epoch   4 | 13000/15255 batches | perplexity    7.412 | loss    2.003 



13314it [00:04, 2754.35it/s]

| epoch   4 | 13500/15255 batches | perplexity    7.516 | loss    2.017 


13874it [00:04, 2747.42it/s]

| epoch   4 | 14000/15255 batches | perplexity    7.161 | loss    1.969 


14438it [00:05, 2786.14it/s]

| epoch   4 | 14500/15255 batches | perplexity    7.408 | loss    2.003 


15010it [00:05, 2825.56it/s]

| epoch   4 | 15000/15255 batches | perplexity    7.468 | loss    2.011 


15255it [00:05, 2793.21it/s]


-----------------------------------------------------------
| end of epoch   4 | time:  5.49s | valid perplexity    7.567 | valid loss    2.024
-----------------------------------------------------------


270it [00:00, 2696.09it/s]

| epoch   5 |   500/15255 batches | perplexity    7.320 | loss    1.991 


828it [00:00, 2763.34it/s]

| epoch   5 |  1000/15255 batches | perplexity    7.157 | loss    1.968 


1388it [00:00, 2779.93it/s]

| epoch   5 |  1500/15255 batches | perplexity    7.243 | loss    1.980 


1946it [00:00, 2778.81it/s]

| epoch   5 |  2000/15255 batches | perplexity    7.380 | loss    1.999 


2502it [00:00, 2769.97it/s]

| epoch   5 |  2500/15255 batches | perplexity    7.417 | loss    2.004 



2780it [00:01, 2741.99it/s]

| epoch   5 |  3000/15255 batches | perplexity    7.711 | loss    2.043 


3337it [00:01, 2762.80it/s]

| epoch   5 |  3500/15255 batches | perplexity    7.396 | loss    2.001 


3895it [00:01, 2772.09it/s]

| epoch   5 |  4000/15255 batches | perplexity    7.320 | loss    1.991 


4453it [00:01, 2770.31it/s]

| epoch   5 |  4500/15255 batches | perplexity    7.375 | loss    1.998 


5018it [00:01, 2795.49it/s]

| epoch   5 |  5000/15255 batches | perplexity    7.387 | loss    2.000 



5307it [00:01, 2823.22it/s]

| epoch   5 |  5500/15255 batches | perplexity    7.373 | loss    1.998 


5877it [00:02, 2831.14it/s]

| epoch   5 |  6000/15255 batches | perplexity    7.279 | loss    1.985 


6449it [00:02, 2822.75it/s]

| epoch   5 |  6500/15255 batches | perplexity    7.585 | loss    2.026 


7022it [00:02, 2845.55it/s]

| epoch   5 |  7000/15255 batches | perplexity    7.280 | loss    1.985 



7307it [00:02, 2824.48it/s]

| epoch   5 |  7500/15255 batches | perplexity    7.514 | loss    2.017 


7877it [00:02, 2818.34it/s]

| epoch   5 |  8000/15255 batches | perplexity    7.403 | loss    2.002 


8445it [00:03, 2829.45it/s]

| epoch   5 |  8500/15255 batches | perplexity    7.434 | loss    2.006 


9011it [00:03, 2824.95it/s]

| epoch   5 |  9000/15255 batches | perplexity    7.324 | loss    1.991 



9299it [00:03, 2837.35it/s]

| epoch   5 |  9500/15255 batches | perplexity    7.408 | loss    2.003 


9866it [00:03, 2825.08it/s]

| epoch   5 | 10000/15255 batches | perplexity    7.193 | loss    1.973 


10433it [00:03, 2787.58it/s]

| epoch   5 | 10500/15255 batches | perplexity    7.396 | loss    2.001 


11003it [00:03, 2812.48it/s]

| epoch   5 | 11000/15255 batches | perplexity    7.469 | loss    2.011 



11288it [00:04, 2823.51it/s]

| epoch   5 | 11500/15255 batches | perplexity    7.499 | loss    2.015 


11857it [00:04, 2828.97it/s]

| epoch   5 | 12000/15255 batches | perplexity    7.459 | loss    2.009 


12425it [00:04, 2828.33it/s]

| epoch   5 | 12500/15255 batches | perplexity    7.261 | loss    1.982 


12996it [00:04, 2839.17it/s]

| epoch   5 | 13000/15255 batches | perplexity    7.382 | loss    1.999 


13577it [00:04, 2870.13it/s]

| epoch   5 | 13500/15255 batches | perplexity    7.552 | loss    2.022 



13869it [00:04, 2882.54it/s]

| epoch   5 | 14000/15255 batches | perplexity    7.383 | loss    1.999 


14453it [00:05, 2900.96it/s]

| epoch   5 | 14500/15255 batches | perplexity    7.210 | loss    1.975 


15035it [00:05, 2897.91it/s]

| epoch   5 | 15000/15255 batches | perplexity    7.519 | loss    2.017 


15255it [00:05, 2814.68it/s]


-----------------------------------------------------------
| end of epoch   5 | time:  5.45s | valid perplexity    7.565 | valid loss    2.024
-----------------------------------------------------------
Checking the results of test dataset.
test perplexity    7.464 | test loss    2.010 


Hint: For the above, you should see your loss around 2.0 and going down. Similarly to perplexity which should be aroud 7 to 8.

## Generate some text.

In [59]:
def generate_word(model, dataset, window):
    generated_word = []
    # Set the context to a window-1 length array having just the MARKER character's token_id.
    context = " FILL_IN "
    
    while True:
        logits = model(torch.tensor(context).view(1, -1))
        
        # Get the probabilities from the logits.
        # Hint: softmax!
        probs = " FILL_IN "
        
        # Get 1 sample from a multinomial having the above probabilities.
        token_id = torch.multinomial(" FILL_IN ").item()
        
        # Append the token_id to the generated word.
        " FILL_IN "
        
        # Move the context over 1, drop the first (oldest) token and apped the new one above.
        # The size of the resulting context should be the same.
        # For exaple, if it was "[0, 1, 2]" and you generated 4, it should now be [1, 2, 4].
        context = " FILL_IN "
        
        if token_id == 0:
            # If you generate token_id = 0, i.e. '.', break out.
            " FILL_IN "
    # Return and decode the generated word to a string.        
    return ''.join(dataset.decode(generated_word))

In [61]:
torch.manual_seed(1)
for _ in range(50):
    print(generate_word(model, train_dataset, n))

aha.
ele.
lia.
aldi.
jarorsse.
dez.
bemartti.
rielci.
revy.
madlais.
hoanda.
dacelia.
kalaliey.
chis.
mayas.
tya.
jon.
ama.
tze.
karies.
jos.
ahkl.
bamaka.
anyaamiush.
kazerher.
jami.
nnek.
maremellen.
toquyla.
nzygu.
enyl.
yanstram.
ahazoriexsunya.
sermontiroonn.
eifiah.
rosi.
rouivan.
ynn.
ahdityn.
jassavoli.
wun.
jayvarante.
nor.
ilyn.
marri.
allifare.
kalyi.
daslenshanna.
daniellaenimanaililah.
cyle.
