In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

import numpy as np
from tqdm import tqdm
from utils import *
from network import CharRNN

In [None]:
# Train on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Read the whole training corpus into memory as a single string.
# Decode explicitly as UTF-8 so the result does not depend on the platform's locale
# encoding (open()'s default), which differs between e.g. Colab/Linux and Windows.
# NOTE(review): assumes the corpus file is UTF-8 (ASCII-compatible) — confirm.
with open('data/internet_archive_scifi_v3.txt', encoding='utf-8') as f:
    text = f.read()

In [None]:
# The sci-fi text proper begins at character 580; skip everything before it.
# `tokenize` comes from utils; presumably it returns (vocabulary, encoded text) — verify.
chars, encoding = tokenize(text[580:])

# Hold out the tail of the encoded corpus as the validation split.
val_frac = 0.01                                # 1% of the corpus (~1.5M chars)
val_idx = int(len(encoding) * (1 - val_frac))  # first index of the validation slice
data = encoding[:val_idx]
val_data = encoding[val_idx:]

In [None]:
# Fresh-run setup: the model sized to the corpus vocabulary, plus its loss,
# optimizer and learning-rate schedule.
model = CharRNN(tokens=chars).to(device)
print(model)

criterion = nn.CrossEntropyLoss()

# Adam at lr=1e-3. ReduceLROnPlateau divides the LR by 3 once the monitored
# (validation) loss has gone more than 8 scheduler steps without improving.
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=1/3,
    patience=8,
    verbose=True,
)

In [None]:
# Resume training from the last saved checkpoint, if one exists.
ckpt_path = 'drive/MyDrive/char_rnn_ckpt.pth'
try:
    checkpoint = torch.load(ckpt_path, map_location=device)
except FileNotFoundError:
    # Only a missing file means "start fresh". A corrupt or incompatible checkpoint
    # should raise instead of silently restarting from epoch 0 (and later overwriting
    # the saved checkpoint).
    start_epoch = 0
    print(f'no checkpoint found, training will start from epoch {start_epoch}')
else:
    # Rebuild the model with the architecture stored in the checkpoint.
    model = CharRNN(
        tokens=checkpoint['tokens'],
        n_hidden=checkpoint['n_hidden'],
        n_layers=checkpoint['n_layers'],
        ).to(device)
    model.load_state_dict(checkpoint['model'])

    # The optimizer/scheduler created earlier hold the *previous* model's parameters.
    # Recreate them around the rebuilt model before restoring their state. Otherwise
    # optimizer.step() keeps updating the discarded model and the resumed one never
    # trains. load_state_dict restores the saved lr / scheduler settings, so these
    # constructor arguments only mirror the fresh-run setup.
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=1/3, patience=8, verbose=True)
    scheduler.load_state_dict(checkpoint['scheduler'])

    start_epoch = checkpoint['epoch'] + 1
    print(f'checkpoint found, training will start from epoch {start_epoch}\n')
    print(model)

In [None]:
seq_length = 64
batch_size = 256
num_epochs = 64
n_chars = len(model.chars)
# np.inf, not np.Inf: the capitalised alias was removed in NumPy 2.0.
# NOTE(review): after a resume this restarts at inf, so the first resumed epoch always
# overwrites the checkpoint. The saved 'val_loss' could seed it instead.
val_loss_min = np.inf

# NOTE(review): assumes get_batches (utils) yields only full batch_size x seq_length
# windows, so this is the number of training batches per epoch. It only sizes the
# progress bar (it was previously hard-coded to 9023).
n_train_batches = len(data) // (batch_size * seq_length)

for epoch in range(start_epoch, num_epochs):
    # ---- training ----
    model.train()
    train_loss, n_train = 0.0, 0
    h = model.init_hidden(batch_size)  # initialize hidden state

    loop = tqdm(
        get_batches(data, batch_size, seq_length),
        total=n_train_batches,
        desc=f'Epoch [{epoch+1:2d}/{num_epochs}]',
    )
    for x, y in loop:
        x = one_hot_encode(x, n_chars)  # one-hot encode the inputs
        inputs, targets = torch.from_numpy(x).to(device), torch.from_numpy(y).to(device)
        # Detach the hidden state from the previous batch's graph; otherwise
        # backward() would propagate through the entire training history.
        h = tuple(each.detach() for each in h)
        model.zero_grad()  # zero out accumulated gradients
        outputs, h = model(inputs, h)
        # Flatten targets to one class index per character position.
        loss = criterion(outputs, targets.reshape(-1).long())
        loss.backward()
        # Clip gradients to prevent the exploding-gradient problem in RNNs.
        nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=5)
        optimizer.step()

        batch_loss = loss.item()
        train_loss += batch_loss
        n_train += 1
        loop.set_postfix(loss=batch_loss)

    # ---- validation ----
    model.eval()
    val_loss, n_val = 0.0, 0
    with torch.no_grad():
        val_h = model.init_hidden(batch_size)
        for x, y in get_batches(val_data, batch_size, seq_length):
            x = one_hot_encode(x, n_chars)
            inputs, targets = torch.from_numpy(x).to(device), torch.from_numpy(y).to(device)
            outputs, val_h = model(inputs, val_h)
            loss = criterion(outputs, targets.reshape(-1).long())
            val_loss += loss.item()
            n_val += 1

    if n_val == 0:
        # Without this check val_loss would stay 0 and every epoch would look like an
        # improvement.
        raise ValueError('validation split yielded no batches; '
                         'increase val_frac or reduce batch_size/seq_length')

    # Use mean per-batch losses so the numbers don't depend on how many batches each
    # split happens to contain.
    train_loss /= max(n_train, 1)
    val_loss /= n_val

    scheduler.step(val_loss)
    tqdm.write(f'\t\ttrain_loss={train_loss:.4f}, val_loss={val_loss:.4f}')

    # save the model if validation loss has decreased
    if val_loss <= val_loss_min:
        tqdm.write(f'\t\tval_loss decreased ({val_loss_min:.4f} --> {val_loss:.4f}) saving model...')
        checkpoint = {
            'epoch': epoch,
            'n_hidden': model.n_hidden,
            'n_layers': model.n_layers,
            'tokens': model.chars,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'val_loss': val_loss,
            }
        torch.save(checkpoint, 'drive/MyDrive/char_rnn_ckpt.pth')
        val_loss_min = val_loss
