In [35]:
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, TensorDataset
import time
from tqdm import tqdm

### Information
- We will do a few preliminary exercises and also build a character level MLP language model.
- This model will be similar to the model we did in class, except that we will have characters as tokens, not words.
- You will need a conda environment for this, here is general information on this.
 - https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html
 - PyTorch: https://anaconda.org/pytorch/pytorch
 
In the code below, FILL-IN the code necessary in the hint string provided.

### Preliminary exercises
- Please fill in the cells below with the asked for data.

In [36]:
torch.manual_seed(1)

<torch._C.Generator at 0x7f8ff742acf0>

In [37]:
# Create an embedding layer for a vocabulary of size 10 and the word vectors are each of dimension 5.
e = nn.Embedding(10,5)

# Extract the embedding for the word whose token index is 3. What is the shape of this vector?
v = e(torch.tensor(3))

# Extract the weight matrix from the layer e.
# Create a linear layer (with no bias) of size 10 by 5 and set it's data to the embedding matrix.
l = nn.Linear(5,10, bias = False)
l.weight = e.weight

# Insert inside of the assert below some sort of equality check between l.weight and e.weight; it should pass to true.
# Hint: look up torch.all() and torch.eq()
assert(torch.eq(e.weight, l.weight).all())

In [38]:
# Create a batch of size 2 with entries [0, 1, 2] and [2, 3, 4] in the data batch.
x = torch.tensor([[0, 1, 2], [2, 3, 4]])

In [39]:
# What is the dimesion of this batch ran through the embeding layer?
assert(e(x).shape == torch.Size([2,3,5]))

### Constants and configs used below.

In [40]:
DEVICE = "cpu"
LR = 4.0
BATCH_SIZE = 16
NUM_EPOCHS = 5
MARKER = '.'
# N-gram level; P(w_t | w_{t-1}, ..., w_{t-n+1}).
# We use 3 words to predict the next word.
n = 4
# Hidden layer dimension.
h = 20
# Word embedding dimension.
m = 20

### Get the dataset and the tokenizer.

In [41]:
class CharDataset(Dataset):
    def __init__(self, words, chars):
        self.words = words
        self.chars = chars
        # Inverse dictionaries mapping char tokens to unique ids and the reverse.
        # Tokens in this case are the unique chars we passed in above.
        # Each token should be mappend to a unique integer and MARKER should have token 0.
        # For example, stoi should be like {'.' -> 0, 'a' -> 1, 'b' -> 2} if I pass in chars = '.ab'.
        dic_stoi, dic_itos = {}, {}
        count = 0
        for ele in chars:
            dic_stoi[ele] = count
            dic_itos[count] = ele
            count += 1
        self.stoi = dic_stoi
        self.itos = dic_itos # Inverse mapping.

    def __len__(self):
        # Number of words.
        return len(self.words)

    def contains(self, word):
        # Check if word is in self.words and return True/False if it is, is not.
        return True if word in self.words else False

    def get_vocab_size(self):
        # Return the vocabulary size.
        return len(self.chars)

    def encode(self, word):
        # Express this word as a list of int ids. For example, maybe ".abc" -> [0, 1, 2, 3].
        # This assumes 'a' -> 1, etc.
        result = []
        for char in word:
            result.append(self.stoi[char])
        return result
    
    def decode(self, tokens):
        # For a set of tokens, return back the string.
        # For example, maybe [1, 1, 2] -> "aac"
        result = []
        for tok in tokens:
            result.append(self.dic_itos[tok])
        return result

    def __getitem__(self, idx):
        # This is used so we can loop over the data.
        word = self.words[idx]
        return self.encode(word)

In [42]:
def create_datasets(window, input_file = 'names.txt'):
    """
    This takes a file of words and separates all the words.
    It then gets all the characters present in the universe of words and then ouputs the statistics. 
    """
    with open(input_file, 'r') as f:
        data = f.read()
    # Split the file by new lines. You should get a list of names.
    words = data.split('\n')
    words = [word.replace(' ', '') for word in words] # This gets rid of any trailing and starting white spaces.
    words = [word for word in words if word] # Filter out all the empty words.

    
    chars = sorted(list(set([char for word in words for char in word]))) # This gets the universe of all characters.
    
    # Will force chars to have MARKER having index 0.
    chars= [MARKER] + chars
    
    # Pad each word with a context window of size n-1.
    # Why? a word like "abc" should becomes "..abc.." if the window is size 3.
    # This is some we can get pair of (x, y) data like this: ".." -> "a", ".a" -> "b", "ab" -> "c", "bc" -> ".", "c." -> "."
    # I.e. this allows us to know that "a" is a start character.
    # So you should get something like ["ab", "c"] -> ["..ab..", "..c.."], for example.
    words = [('.'*(window-1))+word+('.'*(window-1)) for word in words]
            
    print(f"The number of examples in the dataset: {len(words)}")
    print(f"The number of unique characters in the vocabulary: {len(chars)}")
    print(f"The vocabulary we have is: {''.join(chars)}")

    # Partition the input data into a training, validation, and the test set.
    out_of_sample_set_size = min(2000, int(len(words) * 0.1)) # We use 10% of the training set, or up to 2000 examples.
    test_set_size = 1500
    
    # First, get a random permutation of randomly permute of size len(words).
    # Then, convert this to a list. 
    # This index list is used below to get the train, validation, and test sets.
    rp = torch.randperm(len(words)).tolist()
    
    # Get train, validation, and test set.
    train_words = [words[i] for i in rp[:-out_of_sample_set_size]]
    validation_words = [words[i] for i in rp[-out_of_sample_set_size:-test_set_size]]
    test_words = [words[i] for i in rp[-test_set_size:]]    
    
    print(f"We've split up the dataset into {len(train_words)}, {len(validation_words)}, {len(test_words)} training, validation, and test examples")

    # But the data in the data set objects.
    train_dataset = CharDataset(train_words, chars)
    validation_dataset = CharDataset(validation_words, chars)
    test_dataset = CharDataset(test_words, chars)

    return train_dataset, validation_dataset, test_dataset

In [43]:
train_dataset, validation_dataset, test_dataset = create_datasets(n)

The number of examples in the dataset: 32033
The number of unique characters in the vocabulary: 27
The vocabulary we have is: .abcdefghijklmnopqrstuvwxyz
We've split up the dataset into 30033, 500, 1500 training, validation, and test examples


## Explore the data

In [44]:
# Get the first word in "train_dataset"
train_dataset.words[0]

'...niyam...'

In [45]:
# Get the stoi map of train_dataset. How many keys does it have?
print(len(train_dataset.stoi))
print(train_dataset.get_vocab_size())


27
27


### Get the dataloader

In [46]:
def create_dataloader(dataset, window):
    x_list = []
    y_list = []
    # For ech word.
    for i, word in enumerate(dataset):
        # Grab a context of size window and window-1 characters will be in x, 1 will be in y.
        for j, _ in enumerate(word):
            # If there is no widow of size window left, break.
            if j + window > len(word) - 1:
                break
            word_window = word[j:j+window]
            x, y = word_window[:window-1], word_window[-1]
            x_list.append(x)
            y_list.append(y)
            
    return DataLoader(
        TensorDataset(torch.tensor(x_list), torch.tensor(y_list)),
        BATCH_SIZE,
        shuffle=True
    )

In [47]:
train_dataloader = create_dataloader(train_dataset, n)
validation_dataloader = create_dataloader(validation_dataset, n)
test_dataloader = create_dataloader(test_dataset, n)

### Set up the model
- Identical to lecture. Please look over that!

In [48]:
# One of the first Neural language models!
class CharacterNeuralLanguageModel(nn.Module):
    def __init__(self, V, m, h, n):
        super(CharacterNeuralLanguageModel, self).__init__()
        
        # Vocabulary size.
        self.V = V
        
        # Embedding dimension, per word.
        self.m = m
        
        # Hidden dimension.
        self.h = h
        
        # N in "N-gram"
        self.n = n
        
        # Can you change all this stuff to use nn.Linear?
        # Ca also use nn.Parameter(torch.zeros(V, m)) for self.C but then we need one-hot and this is slow.
        self.C = nn.Embedding(V, m)
        self.H = nn.Parameter(torch.zeros((n-1) * m, h))
        self.W = nn.Parameter(torch.zeros((n-1) * m, V))
        self.U = nn.Parameter(torch.zeros(h, V))
        
        self.b = torch.nn.Parameter(torch.ones(V))
        self.d = torch.nn.Parameter(torch.ones(h))
        
        self.init_weights()

    def init_weights(self):
        # Intitialize C, H, W, U in a nice way. Use xavier initialization for the weights.
        # On a first run, just pass.
        with torch.no_grad():
            torch.nn.init.xavier_uniform_(self.C.weight)
            torch.nn.init.xavier_uniform_(self.H)
            torch.nn.init.xavier_uniform_(self.W)
            torch.nn.init.xavier_uniform_(self.U)
        
        
    def forward(self, x):
        
        # x is of dimenson N = batch size X n-1
        
        # N X (n-1) X m 
        x = self.C(x)
        
        # N
        N = x.shape[0]
        
        # N X (n-1) * m
        x = x.view(N, -1)
    
        # N X V
        y = self.b + torch.matmul(x, self.W) + torch.matmul(nn.Tanh()(self.d + torch.matmul(x, self.H)), self.U)
        
        return y

### Set up the model.

In [49]:
# Identical to lecture.
criterion = torch.nn.CrossEntropyLoss().to(DEVICE)
model = CharacterNeuralLanguageModel(
    train_dataset.get_vocab_size(), m, h, n).to(DEVICE)
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)

In [50]:
# How many parameters does the neural network have?
# Hint: look up model.named_parameters and the method "nelement" on a tensor.
# See also the XOR notebook where we count the gradients that are 0.
# There, we loop over the parameters.
number_parameters = 0
for name, param in model.named_parameters():
    number_parameters += 1
    print(name, param.shape, param.requires_grad)
print("number_parameters is {}".format(number_parameters))

H torch.Size([60, 20]) True
W torch.Size([60, 27]) True
U torch.Size([20, 27]) True
b torch.Size([27]) True
d torch.Size([20]) True
C.weight torch.Size([27, 20]) True
number_parameters is 6


### Train the model.

In [51]:
def calculate_perplexity(total_loss, total_batches):
    return torch.exp(torch.tensor(total_loss / total_batches)).item()

In [55]:
def train(dataloader, model, optimizer, criterion, epoch):
    model.train()
    total_loss, total_batches = 0.0, 0.0
    log_interval = 500

    for idx, (x, y) in tqdm(enumerate(dataloader)):
        optimizer.zero_grad()
        
        logits = model(x)
                        
        # Get the loss.
        loss = criterion(input=logits, target=y.view(-1))

        # Do back propagation.
        loss.backward()
                        
        # Clip the gradients so they don't explode. Look at how this is done in lecture.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        
        # Do an optimization step.
        optimizer.step()
        total_loss += loss.item()
        total_batches += 1
                
        if idx % log_interval == 0 and idx > 0:
            perplexity = calculate_perplexity(total_loss,  total_batches)
            print(
                "| epoch {:3d} "
                "| {:5d}/{:5d} batches "
                "| perplexity {:8.3f} "
                "| loss {:8.3f} "
                .format(
                    epoch,
                    idx,
                    len(dataloader),
                    perplexity,
                    total_loss / total_batches,
                )
            )
            total_loss, total_batches = 0.0, 0

In [56]:
def evaluate(dataloader, model, criterion):
    model.eval()
    total_loss, total_batches = 0.0, 0

    with torch.no_grad():
        for idx, (x, y) in enumerate(dataloader):
            logits = model(x)
            total_loss += criterion(input=logits, target=y.squeeze(-1)).item()
            total_batches += 1
    return total_loss / total_batches, calculate_perplexity(total_loss,  total_batches)

In [57]:
for epoch in range(1, NUM_EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_dataloader, model, optimizer, criterion, epoch)
    loss_val, perplexity_val = evaluate(validation_dataloader, model, criterion)
    scheduler.step()
    print("-" * 59)
    print(
        "| end of epoch {:3d} "
        "| time: {:5.2f}s "
        "| valid perplexity {:8.3f} "
        "| valid loss {:8.3f}".format(
            epoch,
            time.time() - epoch_start_time,
            perplexity_val,
            loss_val
        )
    )
    print("-" * 59)

print("Checking the results of test dataset.")
loss_test, perplexity_test = evaluate(test_dataloader, model, criterion)
print("test perplexity {:8.3f} | test loss {:8.3f} ".format(perplexity_test, loss_test))

682it [00:00, 1096.81it/s]

| epoch   1 |   500/15247 batches | perplexity    7.833 | loss    2.058 


1187it [00:01, 1237.41it/s]

| epoch   1 |  1000/15247 batches | perplexity    7.943 | loss    2.072 


1699it [00:01, 1267.36it/s]

| epoch   1 |  1500/15247 batches | perplexity    7.940 | loss    2.072 


2209it [00:01, 1263.39it/s]

| epoch   1 |  2000/15247 batches | perplexity    7.733 | loss    2.045 


2724it [00:02, 1277.43it/s]

| epoch   1 |  2500/15247 batches | perplexity    7.856 | loss    2.061 


3238it [00:02, 1265.80it/s]

| epoch   1 |  3000/15247 batches | perplexity    7.835 | loss    2.059 


3751it [00:03, 1262.74it/s]

| epoch   1 |  3500/15247 batches | perplexity    8.007 | loss    2.080 


4134it [00:03, 1266.75it/s]

| epoch   1 |  4000/15247 batches | perplexity    7.969 | loss    2.076 


4647it [00:03, 1277.40it/s]

| epoch   1 |  4500/15247 batches | perplexity    7.799 | loss    2.054 


5163it [00:04, 1265.31it/s]

| epoch   1 |  5000/15247 batches | perplexity    7.635 | loss    2.033 


5680it [00:04, 1273.59it/s]

| epoch   1 |  5500/15247 batches | perplexity    7.874 | loss    2.064 


6198it [00:04, 1274.97it/s]

| epoch   1 |  6000/15247 batches | perplexity    7.728 | loss    2.045 


6692it [00:05, 1110.48it/s]

| epoch   1 |  6500/15247 batches | perplexity    7.882 | loss    2.065 


7184it [00:05, 1169.31it/s]

| epoch   1 |  7000/15247 batches | perplexity    8.006 | loss    2.080 


7673it [00:06, 1151.00it/s]

| epoch   1 |  7500/15247 batches | perplexity    7.642 | loss    2.034 


8189it [00:06, 1258.41it/s]

| epoch   1 |  8000/15247 batches | perplexity    7.688 | loss    2.040 


8697it [00:07, 1250.20it/s]

| epoch   1 |  8500/15247 batches | perplexity    7.954 | loss    2.074 


9214it [00:07, 1279.68it/s]

| epoch   1 |  9000/15247 batches | perplexity    7.771 | loss    2.050 


9730it [00:07, 1281.00it/s]

| epoch   1 |  9500/15247 batches | perplexity    7.897 | loss    2.066 


10248it [00:08, 1284.25it/s]

| epoch   1 | 10000/15247 batches | perplexity    7.753 | loss    2.048 


10629it [00:08, 1231.91it/s]

| epoch   1 | 10500/15247 batches | perplexity    7.898 | loss    2.067 


11147it [00:09, 1277.35it/s]

| epoch   1 | 11000/15247 batches | perplexity    7.916 | loss    2.069 


11664it [00:09, 1277.91it/s]

| epoch   1 | 11500/15247 batches | perplexity    7.453 | loss    2.009 


12175it [00:09, 1259.96it/s]

| epoch   1 | 12000/15247 batches | perplexity    7.767 | loss    2.050 


12674it [00:10, 1228.98it/s]

| epoch   1 | 12500/15247 batches | perplexity    7.947 | loss    2.073 


13166it [00:10, 1214.81it/s]

| epoch   1 | 13000/15247 batches | perplexity    7.785 | loss    2.052 


13649it [00:11, 1185.17it/s]

| epoch   1 | 13500/15247 batches | perplexity    7.787 | loss    2.053 


14125it [00:11, 1171.26it/s]

| epoch   1 | 14000/15247 batches | perplexity    7.740 | loss    2.046 


14719it [00:12, 1168.71it/s]

| epoch   1 | 14500/15247 batches | perplexity    7.747 | loss    2.047 


15189it [00:12, 1159.95it/s]

| epoch   1 | 15000/15247 batches | perplexity    7.680 | loss    2.039 


15247it [00:12, 1216.03it/s]


-----------------------------------------------------------
| end of epoch   1 | time: 12.62s | valid perplexity    7.886 | valid loss    2.065
-----------------------------------------------------------


638it [00:00, 1068.79it/s]

| epoch   2 |   500/15247 batches | perplexity    7.289 | loss    1.986 


1229it [00:01, 1165.99it/s]

| epoch   2 |  1000/15247 batches | perplexity    7.144 | loss    1.966 


1700it [00:01, 1168.08it/s]

| epoch   2 |  1500/15247 batches | perplexity    7.085 | loss    1.958 


2171it [00:01, 1158.75it/s]

| epoch   2 |  2000/15247 batches | perplexity    7.127 | loss    1.964 


2648it [00:02, 1186.69it/s]

| epoch   2 |  2500/15247 batches | perplexity    7.249 | loss    1.981 


3135it [00:02, 1199.80it/s]

| epoch   2 |  3000/15247 batches | perplexity    7.037 | loss    1.951 


3629it [00:03, 1215.08it/s]

| epoch   2 |  3500/15247 batches | perplexity    7.076 | loss    1.957 


4133it [00:03, 1248.67it/s]

| epoch   2 |  4000/15247 batches | perplexity    7.250 | loss    1.981 


4642it [00:03, 1224.62it/s]

| epoch   2 |  4500/15247 batches | perplexity    7.192 | loss    1.973 


5135it [00:04, 1219.11it/s]

| epoch   2 |  5000/15247 batches | perplexity    7.168 | loss    1.970 


5642it [00:04, 1258.29it/s]

| epoch   2 |  5500/15247 batches | perplexity    6.978 | loss    1.943 


6152it [00:05, 1254.93it/s]

| epoch   2 |  6000/15247 batches | perplexity    6.885 | loss    1.929 


6651it [00:05, 1193.99it/s]

| epoch   2 |  6500/15247 batches | perplexity    7.074 | loss    1.956 


7120it [00:06, 1118.39it/s]

| epoch   2 |  7000/15247 batches | perplexity    7.059 | loss    1.954 


7712it [00:06, 1163.25it/s]

| epoch   2 |  7500/15247 batches | perplexity    6.958 | loss    1.940 


8183it [00:06, 1155.84it/s]

| epoch   2 |  8000/15247 batches | perplexity    7.041 | loss    1.952 


8655it [00:07, 1168.08it/s]

| epoch   2 |  8500/15247 batches | perplexity    7.228 | loss    1.978 


9126it [00:07, 1153.33it/s]

| epoch   2 |  9000/15247 batches | perplexity    7.278 | loss    1.985 


9716it [00:08, 1163.33it/s]

| epoch   2 |  9500/15247 batches | perplexity    6.977 | loss    1.943 


10193it [00:08, 1172.05it/s]

| epoch   2 | 10000/15247 batches | perplexity    7.020 | loss    1.949 


10677it [00:09, 1196.52it/s]

| epoch   2 | 10500/15247 batches | perplexity    6.968 | loss    1.941 


11163it [00:09, 1200.70it/s]

| epoch   2 | 11000/15247 batches | perplexity    6.923 | loss    1.935 


11650it [00:09, 1189.09it/s]

| epoch   2 | 11500/15247 batches | perplexity    6.973 | loss    1.942 


12135it [00:10, 1199.65it/s]

| epoch   2 | 12000/15247 batches | perplexity    7.100 | loss    1.960 


12626it [00:10, 1203.42it/s]

| epoch   2 | 12500/15247 batches | perplexity    7.172 | loss    1.970 


13245it [00:11, 1226.82it/s]

| epoch   2 | 13000/15247 batches | perplexity    7.058 | loss    1.954 


13734it [00:11, 1208.69it/s]

| epoch   2 | 13500/15247 batches | perplexity    7.093 | loss    1.959 


14223it [00:12, 1213.02it/s]

| epoch   2 | 14000/15247 batches | perplexity    7.135 | loss    1.965 


14710it [00:12, 1196.17it/s]

| epoch   2 | 14500/15247 batches | perplexity    6.993 | loss    1.945 


15200it [00:12, 1208.48it/s]

| epoch   2 | 15000/15247 batches | perplexity    7.107 | loss    1.961 


15247it [00:12, 1177.04it/s]


-----------------------------------------------------------
| end of epoch   2 | time: 13.03s | valid perplexity    6.983 | valid loss    1.943
-----------------------------------------------------------


729it [00:00, 1225.54it/s]

| epoch   3 |   500/15247 batches | perplexity    6.924 | loss    1.935 


1221it [00:01, 1211.42it/s]

| epoch   3 |  1000/15247 batches | perplexity    7.085 | loss    1.958 


1707it [00:01, 1192.50it/s]

| epoch   3 |  1500/15247 batches | perplexity    6.950 | loss    1.939 


2194it [00:01, 1206.79it/s]

| epoch   3 |  2000/15247 batches | perplexity    7.124 | loss    1.964 


2679it [00:02, 1198.81it/s]

| epoch   3 |  2500/15247 batches | perplexity    6.945 | loss    1.938 


3163it [00:02, 1196.95it/s]

| epoch   3 |  3000/15247 batches | perplexity    7.025 | loss    1.949 


3648it [00:03, 1191.94it/s]

| epoch   3 |  3500/15247 batches | perplexity    6.738 | loss    1.908 


4126it [00:03, 1166.93it/s]

| epoch   3 |  4000/15247 batches | perplexity    7.140 | loss    1.966 


4740it [00:03, 1215.26it/s]

| epoch   3 |  4500/15247 batches | perplexity    6.856 | loss    1.925 


5230it [00:04, 1199.50it/s]

| epoch   3 |  5000/15247 batches | perplexity    6.777 | loss    1.914 


5717it [00:04, 1196.44it/s]

| epoch   3 |  5500/15247 batches | perplexity    6.880 | loss    1.929 


6188it [00:05, 1153.24it/s]

| epoch   3 |  6000/15247 batches | perplexity    6.928 | loss    1.936 


6646it [00:05, 1115.29it/s]

| epoch   3 |  6500/15247 batches | perplexity    7.009 | loss    1.947 


7222it [00:06, 1140.86it/s]

| epoch   3 |  7000/15247 batches | perplexity    7.069 | loss    1.956 


7678it [00:06, 1097.72it/s]

| epoch   3 |  7500/15247 batches | perplexity    6.863 | loss    1.926 


8127it [00:06, 1092.92it/s]

| epoch   3 |  8000/15247 batches | perplexity    7.024 | loss    1.949 


8714it [00:07, 1164.91it/s]

| epoch   3 |  8500/15247 batches | perplexity    6.909 | loss    1.933 


9185it [00:07, 1165.11it/s]

| epoch   3 |  9000/15247 batches | perplexity    7.073 | loss    1.956 


9667it [00:08, 1180.16it/s]

| epoch   3 |  9500/15247 batches | perplexity    6.987 | loss    1.944 


10143it [00:08, 1171.42it/s]

| epoch   3 | 10000/15247 batches | perplexity    7.073 | loss    1.956 


10626it [00:09, 1189.52it/s]

| epoch   3 | 10500/15247 batches | perplexity    6.942 | loss    1.938 


11242it [00:09, 1214.61it/s]

| epoch   3 | 11000/15247 batches | perplexity    6.864 | loss    1.926 


11730it [00:10, 1203.38it/s]

| epoch   3 | 11500/15247 batches | perplexity    7.168 | loss    1.970 


12209it [00:10, 1170.64it/s]

| epoch   3 | 12000/15247 batches | perplexity    6.975 | loss    1.942 


12682it [00:10, 1162.18it/s]

| epoch   3 | 12500/15247 batches | perplexity    6.954 | loss    1.939 


13148it [00:11, 1152.37it/s]

| epoch   3 | 13000/15247 batches | perplexity    7.094 | loss    1.959 


13730it [00:11, 1148.71it/s]

| epoch   3 | 13500/15247 batches | perplexity    7.055 | loss    1.954 


14191it [00:12, 1141.68it/s]

| epoch   3 | 14000/15247 batches | perplexity    6.885 | loss    1.929 


14657it [00:12, 1147.71it/s]

| epoch   3 | 14500/15247 batches | perplexity    6.698 | loss    1.902 


15124it [00:13, 1149.32it/s]

| epoch   3 | 15000/15247 batches | perplexity    6.969 | loss    1.941 


15247it [00:13, 1161.44it/s]


-----------------------------------------------------------
| end of epoch   3 | time: 13.20s | valid perplexity    6.926 | valid loss    1.935
-----------------------------------------------------------


691it [00:00, 1164.44it/s]

| epoch   4 |   500/15247 batches | perplexity    6.962 | loss    1.940 


1162it [00:01, 1127.34it/s]

| epoch   4 |  1000/15247 batches | perplexity    6.998 | loss    1.946 


1625it [00:01, 972.08it/s] 

| epoch   4 |  1500/15247 batches | perplexity    7.032 | loss    1.950 


2212it [00:02, 1137.58it/s]

| epoch   4 |  2000/15247 batches | perplexity    6.932 | loss    1.936 


2681it [00:02, 1157.34it/s]

| epoch   4 |  2500/15247 batches | perplexity    6.885 | loss    1.929 


3152it [00:02, 1160.83it/s]

| epoch   4 |  3000/15247 batches | perplexity    7.052 | loss    1.953 


3617it [00:03, 1146.26it/s]

| epoch   4 |  3500/15247 batches | perplexity    6.896 | loss    1.931 


4211it [00:03, 1176.36it/s]

| epoch   4 |  4000/15247 batches | perplexity    6.901 | loss    1.932 


4685it [00:04, 1165.83it/s]

| epoch   4 |  4500/15247 batches | perplexity    7.130 | loss    1.964 


5152it [00:04, 1155.25it/s]

| epoch   4 |  5000/15247 batches | perplexity    7.034 | loss    1.951 


5616it [00:05, 1140.78it/s]

| epoch   4 |  5500/15247 batches | perplexity    6.945 | loss    1.938 


6207it [00:05, 1165.57it/s]

| epoch   4 |  6000/15247 batches | perplexity    6.973 | loss    1.942 


6680it [00:05, 1169.92it/s]

| epoch   4 |  6500/15247 batches | perplexity    6.652 | loss    1.895 


7150it [00:06, 1155.83it/s]

| epoch   4 |  7000/15247 batches | perplexity    6.875 | loss    1.928 


7622it [00:06, 1161.65it/s]

| epoch   4 |  7500/15247 batches | perplexity    7.016 | loss    1.948 


8216it [00:07, 1161.03it/s]

| epoch   4 |  8000/15247 batches | perplexity    6.871 | loss    1.927 


8687it [00:07, 1155.51it/s]

| epoch   4 |  8500/15247 batches | perplexity    6.943 | loss    1.938 


9155it [00:08, 1148.01it/s]

| epoch   4 |  9000/15247 batches | perplexity    7.035 | loss    1.951 


9624it [00:08, 1152.59it/s]

| epoch   4 |  9500/15247 batches | perplexity    7.077 | loss    1.957 


10220it [00:09, 1174.36it/s]

| epoch   4 | 10000/15247 batches | perplexity    6.879 | loss    1.929 


10691it [00:09, 1161.47it/s]

| epoch   4 | 10500/15247 batches | perplexity    6.925 | loss    1.935 


11146it [00:09, 1034.20it/s]

| epoch   4 | 11000/15247 batches | perplexity    7.021 | loss    1.949 


11705it [00:10, 1115.88it/s]

| epoch   4 | 11500/15247 batches | perplexity    7.079 | loss    1.957 


12159it [00:10, 1123.92it/s]

| epoch   4 | 12000/15247 batches | perplexity    6.807 | loss    1.918 


12729it [00:11, 1126.30it/s]

| epoch   4 | 12500/15247 batches | perplexity    6.863 | loss    1.926 


13185it [00:11, 1121.73it/s]

| epoch   4 | 13000/15247 batches | perplexity    7.146 | loss    1.967 


13636it [00:12, 1100.68it/s]

| epoch   4 | 13500/15247 batches | perplexity    6.851 | loss    1.924 


14217it [00:12, 1154.98it/s]

| epoch   4 | 14000/15247 batches | perplexity    7.080 | loss    1.957 


14679it [00:13, 1135.35it/s]

| epoch   4 | 14500/15247 batches | perplexity    6.722 | loss    1.905 


15144it [00:13, 1147.89it/s]

| epoch   4 | 15000/15247 batches | perplexity    6.861 | loss    1.926 


15247it [00:13, 1126.12it/s]


-----------------------------------------------------------
| end of epoch   4 | time: 13.62s | valid perplexity    6.935 | valid loss    1.937
-----------------------------------------------------------


694it [00:00, 1176.03it/s]

| epoch   5 |   500/15247 batches | perplexity    6.822 | loss    1.920 


1167it [00:01, 1168.90it/s]

| epoch   5 |  1000/15247 batches | perplexity    7.051 | loss    1.953 


1632it [00:01, 1136.93it/s]

| epoch   5 |  1500/15247 batches | perplexity    7.050 | loss    1.953 


2217it [00:01, 1153.70it/s]

| epoch   5 |  2000/15247 batches | perplexity    7.094 | loss    1.959 


2682it [00:02, 1150.58it/s]

| epoch   5 |  2500/15247 batches | perplexity    6.964 | loss    1.941 


3148it [00:02, 1147.77it/s]

| epoch   5 |  3000/15247 batches | perplexity    7.104 | loss    1.961 


3620it [00:03, 1158.56it/s]

| epoch   5 |  3500/15247 batches | perplexity    6.827 | loss    1.921 


4219it [00:03, 1169.35it/s]

| epoch   5 |  4000/15247 batches | perplexity    6.901 | loss    1.932 


4691it [00:04, 1127.84it/s]

| epoch   5 |  4500/15247 batches | perplexity    6.939 | loss    1.937 


5157it [00:04, 1152.74it/s]

| epoch   5 |  5000/15247 batches | perplexity    6.952 | loss    1.939 


5628it [00:04, 1157.81it/s]

| epoch   5 |  5500/15247 batches | perplexity    6.996 | loss    1.945 


6227it [00:05, 1170.91it/s]

| epoch   5 |  6000/15247 batches | perplexity    7.018 | loss    1.948 


6700it [00:05, 1161.43it/s]

| epoch   5 |  6500/15247 batches | perplexity    6.956 | loss    1.940 


7162it [00:06, 1131.75it/s]

| epoch   5 |  7000/15247 batches | perplexity    6.787 | loss    1.915 


7618it [00:06, 1123.50it/s]

| epoch   5 |  7500/15247 batches | perplexity    6.974 | loss    1.942 


8200it [00:07, 1150.82it/s]

| epoch   5 |  8000/15247 batches | perplexity    6.962 | loss    1.940 


8660it [00:07, 1091.32it/s]

| epoch   5 |  8500/15247 batches | perplexity    6.819 | loss    1.920 


9224it [00:08, 1126.49it/s]

| epoch   5 |  9000/15247 batches | perplexity    6.854 | loss    1.925 


9677it [00:08, 1117.07it/s]

| epoch   5 |  9500/15247 batches | perplexity    6.823 | loss    1.920 


10126it [00:08, 1109.05it/s]

| epoch   5 | 10000/15247 batches | perplexity    6.982 | loss    1.943 


10694it [00:09, 1118.85it/s]

| epoch   5 | 10500/15247 batches | perplexity    6.743 | loss    1.909 


11147it [00:09, 1114.42it/s]

| epoch   5 | 11000/15247 batches | perplexity    6.913 | loss    1.933 


11726it [00:10, 1143.21it/s]

| epoch   5 | 11500/15247 batches | perplexity    6.991 | loss    1.945 


12189it [00:10, 1141.26it/s]

| epoch   5 | 12000/15247 batches | perplexity    6.862 | loss    1.926 


12653it [00:11, 1124.03it/s]

| epoch   5 | 12500/15247 batches | perplexity    6.977 | loss    1.943 


13112it [00:11, 1119.80it/s]

| epoch   5 | 13000/15247 batches | perplexity    6.989 | loss    1.944 


13700it [00:12, 1159.92it/s]

| epoch   5 | 13500/15247 batches | perplexity    6.907 | loss    1.933 


14159it [00:12, 1090.13it/s]

| epoch   5 | 14000/15247 batches | perplexity    6.999 | loss    1.946 


14620it [00:12, 1127.28it/s]

| epoch   5 | 14500/15247 batches | perplexity    6.911 | loss    1.933 


15209it [00:13, 1166.43it/s]

| epoch   5 | 15000/15247 batches | perplexity    7.076 | loss    1.957 


15247it [00:13, 1130.38it/s]


-----------------------------------------------------------
| end of epoch   5 | time: 13.57s | valid perplexity    6.940 | valid loss    1.937
-----------------------------------------------------------
Checking the results of test dataset.
test perplexity    7.089 | test loss    1.959 


Hint: For the above, you should see your loss around 2.0 and going down. Similarly to perplexity which should be aroud 7 to 8.

## Generate some text.

In [None]:
def generate_word(model, dataset, window):
    generated_word = []
    # Set the context to a window-1 length array having just the MARKER character's token_id.
    context = " FILL_IN "
    
    while True:
        logits = model(torch.tensor(context).view(1, -1))
        
        # Get the probabilities from the logits.
        # Hint: softmax!
        probs = " FILL_IN "
        
        # Get 1 sample from a multinomial having the above probabilities.
        token_id = torch.multinomial(" FILL_IN ").item()
        
        # Append the token_id to the generated word.
        " FILL_IN "
        
        # Move the context over 1, drop the first (oldest) token and apped the new one above.
        # The size of the resulting context should be the same.
        # For exaple, if it was "[0, 1, 2]" and you generated 4, it should now be [1, 2, 4].
        context = " FILL_IN "
        
        if token_id == 0:
            # If you generate token_id = 0, i.e. '.', break out.
            " FILL_IN "
    # Return and decode the generated word to a string.        
    return ''.join(dataset.decode(generated_word))

In [None]:
torch.manual_seed(1)
for _ in range(50):
    print(generate_word(model, train_dataset, n))

TypeError: new(): invalid data type 'str'