In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import numpy as np
from tqdm import tqdm
from utils import *
from network import CharRNN

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
with open('data/internet_archive_scifi_v3.txt') as f:
    text = f.read()

In [4]:
criterion = nn.CrossEntropyLoss()

# check if there are any checkpoint available from previous runs, else start from zero
try:
    checkpoint = torch.load('models/char_rnn_ckpt.pth', map_location=device)
    model = CharRNN(
        tokens=checkpoint['tokens'],
        n_hidden=checkpoint['n_hidden'],
        n_layers=checkpoint['n_layers'],
        ).to(device)
    model.load_state_dict(checkpoint['model'])
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=1/3, patience=8, verbose=True)
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])
    start_epoch = checkpoint['epoch'] + 1
    val_loss_min = checkpoint['val_loss']
    print(f'checkpoint found, training will start from epoch {start_epoch}')
    
    encoding = np.array([model.char2int[ch] for ch in text[580:]])
    
except:
    chars, encoding = tokenize(text[580:])  # scifi text starts from character no '580'
    model = CharRNN(tokens=chars).to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=1/3, patience=8, verbose=True)
    start_epoch = 0
    val_loss_min = np.Inf
    print(f'no checkpoint found, training will start from epoch {start_epoch}')

print(model)

# create training and validation data
val_frac = 0.01  # 1% of entire dataset (~1.5M chars)
val_idx = int(len(encoding)*(1-val_frac))
data, val_data = encoding[:val_idx], encoding[val_idx:]

no checkpoint found, training will start from epoch 0
CharRNN(
  (lstm): LSTM(75, 512, num_layers=2, batch_first=True, dropout=0.5)
  (dropout): Dropout(p=0.5, inplace=False)
  (fc): Linear(in_features=512, out_features=75, bias=True)
)


In [5]:
seq_length = 64
batch_size = 512
num_epochs = 32
n_chars = len(model.chars)
num_train_batches = len(data)//(batch_size*seq_length)
num_val_batches = len(val_data)//(batch_size*seq_length)

for epoch in range(start_epoch, num_epochs):
    model.train()
    train_loss = 0

    h = model.init_hidden(batch_size)  # initialize hidden state
    loop = tqdm(get_batches(data, batch_size, seq_length), total=num_train_batches)

    for x, y in loop:
        x = one_hot_encode(x, n_chars)  # one hot encode the data
        inputs, targets = torch.from_numpy(x).to(device), torch.from_numpy(y).to(device)
        # creates new variables for the hidden state, otherwise
        # the optimizer would backprop through the entire training history
        h = tuple([each.data for each in h])
        model.zero_grad()  # zero out accumulated gradients
        outputs, h = model(inputs, h)
        loss = criterion(outputs, targets.view(batch_size*seq_length).long())
        loss.backward()
        # clip gradients, prevents exploding gradient problem in RNNs
        nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=5)
        optimizer.step()

        train_loss += loss.item()
        loop.set_description(f'Epoch [{epoch+1:2d}/{num_epochs}]')
        loop.set_postfix(loss=loss.item())


    model.eval()
    val_loss = 0
    with torch.no_grad():
        val_h = model.init_hidden(batch_size)
        for x, y in get_batches(val_data, batch_size, seq_length):
            x = one_hot_encode(x, n_chars)
            inputs, targets = torch.from_numpy(x).to(device), torch.from_numpy(y).to(device)
            outputs, val_h = model(inputs, val_h)
            loss = criterion(outputs, targets.view(batch_size*seq_length).long())
            val_loss += loss.item()


    scheduler.step(val_loss)
    print(f'\n\t\tavg_train_loss={train_loss/num_train_batches:.4f}, avg_val_loss={val_loss/num_val_batches:.4f}')

    # save the model if the validation loss has decreased
    if val_loss <= val_loss_min:
        print(f'\t\tval_loss decreased ({val_loss_min:.4f} --> {val_loss:.4f}) saving model...\n')
        checkpoint = {
            'epoch': epoch,
            'n_hidden': model.n_hidden,
            'n_layers': model.n_layers,
            'tokens': model.chars,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'val_loss': val_loss
            }
        torch.save(checkpoint, 'models/char_rnn_ckpt.pth')
        val_loss_min = val_loss


Epoch [ 1/32]: 100%|██████████| 4511/4511 [16:54<00:00,  4.45it/s, loss=1.45]



		avg_train_loss=1.6226, avg_val_loss=1.2837
		val_loss decreased (inf --> 57.7671) saving model...



Epoch [ 2/32]: 100%|██████████| 4511/4511 [16:56<00:00,  4.44it/s, loss=1.39]



		avg_train_loss=1.3518, avg_val_loss=1.2196
		val_loss decreased (57.7671 --> 54.8805) saving model...



Epoch [ 3/32]: 100%|██████████| 4511/4511 [16:57<00:00,  4.43it/s, loss=1.37]



		avg_train_loss=1.3085, avg_val_loss=1.1924
		val_loss decreased (54.8805 --> 53.6597) saving model...



Epoch [ 4/32]: 100%|██████████| 4511/4511 [16:58<00:00,  4.43it/s, loss=1.35]



		avg_train_loss=1.2863, avg_val_loss=1.1763
		val_loss decreased (53.6597 --> 52.9334) saving model...



Epoch [ 5/32]: 100%|██████████| 4511/4511 [16:58<00:00,  4.43it/s, loss=1.34]



		avg_train_loss=1.2719, avg_val_loss=1.1649
		val_loss decreased (52.9334 --> 52.4198) saving model...



Epoch [ 6/32]: 100%|██████████| 4511/4511 [16:59<00:00,  4.42it/s, loss=1.33]



		avg_train_loss=1.2615, avg_val_loss=1.1566
		val_loss decreased (52.4198 --> 52.0464) saving model...



Epoch [ 7/32]: 100%|██████████| 4511/4511 [16:57<00:00,  4.43it/s, loss=1.33]



		avg_train_loss=1.2533, avg_val_loss=1.1504
		val_loss decreased (52.0464 --> 51.7688) saving model...



Epoch [ 8/32]: 100%|██████████| 4511/4511 [16:53<00:00,  4.45it/s, loss=1.32]



		avg_train_loss=1.2468, avg_val_loss=1.1446
		val_loss decreased (51.7688 --> 51.5086) saving model...



Epoch [ 9/32]: 100%|██████████| 4511/4511 [16:57<00:00,  4.44it/s, loss=1.32]



		avg_train_loss=1.2413, avg_val_loss=1.1399
		val_loss decreased (51.5086 --> 51.2976) saving model...



Epoch [10/32]: 100%|██████████| 4511/4511 [16:57<00:00,  4.43it/s, loss=1.31]



		avg_train_loss=1.2366, avg_val_loss=1.1359
		val_loss decreased (51.2976 --> 51.1175) saving model...



Epoch [11/32]: 100%|██████████| 4511/4511 [16:59<00:00,  4.43it/s, loss=1.31]



		avg_train_loss=1.2327, avg_val_loss=1.1326
		val_loss decreased (51.1175 --> 50.9681) saving model...



Epoch [12/32]: 100%|██████████| 4511/4511 [16:58<00:00,  4.43it/s, loss=1.31]



		avg_train_loss=1.2292, avg_val_loss=1.1298
		val_loss decreased (50.9681 --> 50.8414) saving model...



Epoch [13/32]: 100%|██████████| 4511/4511 [17:10<00:00,  4.38it/s, loss=1.31]



		avg_train_loss=1.2260, avg_val_loss=1.1270
		val_loss decreased (50.8414 --> 50.7153) saving model...



Epoch [14/32]: 100%|██████████| 4511/4511 [17:12<00:00,  4.37it/s, loss=1.3]



		avg_train_loss=1.2233, avg_val_loss=1.1248
		val_loss decreased (50.7153 --> 50.6138) saving model...



Epoch [15/32]: 100%|██████████| 4511/4511 [17:10<00:00,  4.38it/s, loss=1.3]



		avg_train_loss=1.2208, avg_val_loss=1.1224
		val_loss decreased (50.6138 --> 50.5058) saving model...



Epoch [16/32]: 100%|██████████| 4511/4511 [17:15<00:00,  4.36it/s, loss=1.3]



		avg_train_loss=1.2186, avg_val_loss=1.1208
		val_loss decreased (50.5058 --> 50.4351) saving model...



Epoch [17/32]: 100%|██████████| 4511/4511 [17:19<00:00,  4.34it/s, loss=1.3]



		avg_train_loss=1.2164, avg_val_loss=1.1185
		val_loss decreased (50.4351 --> 50.3334) saving model...



Epoch [18/32]: 100%|██████████| 4511/4511 [18:06<00:00,  4.15it/s, loss=1.3]



		avg_train_loss=1.2145, avg_val_loss=1.1169
		val_loss decreased (50.3334 --> 50.2624) saving model...



Epoch [19/32]: 100%|██████████| 4511/4511 [18:10<00:00,  4.14it/s, loss=1.29]



		avg_train_loss=1.2128, avg_val_loss=1.1152
		val_loss decreased (50.2624 --> 50.1818) saving model...



Epoch [20/32]: 100%|██████████| 4511/4511 [18:16<00:00,  4.12it/s, loss=1.3]



		avg_train_loss=1.2112, avg_val_loss=1.1136
		val_loss decreased (50.1818 --> 50.1104) saving model...



Epoch [21/32]: 100%|██████████| 4511/4511 [18:13<00:00,  4.12it/s, loss=1.3]



		avg_train_loss=1.2097, avg_val_loss=1.1129
		val_loss decreased (50.1104 --> 50.0825) saving model...



Epoch [22/32]: 100%|██████████| 4511/4511 [18:09<00:00,  4.14it/s, loss=1.3]



		avg_train_loss=1.2083, avg_val_loss=1.1112
		val_loss decreased (50.0825 --> 50.0058) saving model...



Epoch [23/32]: 100%|██████████| 4511/4511 [18:16<00:00,  4.12it/s, loss=1.29]



		avg_train_loss=1.2069, avg_val_loss=1.1104
		val_loss decreased (50.0058 --> 49.9665) saving model...



Epoch [24/32]: 100%|██████████| 4511/4511 [18:14<00:00,  4.12it/s, loss=1.3]



		avg_train_loss=1.2057, avg_val_loss=1.1095
		val_loss decreased (49.9665 --> 49.9267) saving model...



Epoch [25/32]: 100%|██████████| 4511/4511 [18:09<00:00,  4.14it/s, loss=1.29]



		avg_train_loss=1.2045, avg_val_loss=1.1085
		val_loss decreased (49.9267 --> 49.8840) saving model...



Epoch [26/32]: 100%|██████████| 4511/4511 [18:09<00:00,  4.14it/s, loss=1.29]



		avg_train_loss=1.2034, avg_val_loss=1.1071
		val_loss decreased (49.8840 --> 49.8187) saving model...



Epoch [27/32]: 100%|██████████| 4511/4511 [18:16<00:00,  4.11it/s, loss=1.29]



		avg_train_loss=1.2024, avg_val_loss=1.1059
		val_loss decreased (49.8187 --> 49.7641) saving model...



Epoch [28/32]: 100%|██████████| 4511/4511 [18:15<00:00,  4.12it/s, loss=1.28]



		avg_train_loss=1.2013, avg_val_loss=1.1056
		val_loss decreased (49.7641 --> 49.7538) saving model...



Epoch [29/32]: 100%|██████████| 4511/4511 [18:09<00:00,  4.14it/s, loss=1.29]



		avg_train_loss=1.2003, avg_val_loss=1.1051
		val_loss decreased (49.7538 --> 49.7274) saving model...



Epoch [30/32]: 100%|██████████| 4511/4511 [18:15<00:00,  4.12it/s, loss=1.29]



		avg_train_loss=1.1995, avg_val_loss=1.1034
		val_loss decreased (49.7274 --> 49.6527) saving model...



Epoch [31/32]: 100%|██████████| 4511/4511 [18:14<00:00,  4.12it/s, loss=1.29]



		avg_train_loss=1.1986, avg_val_loss=1.1025
		val_loss decreased (49.6527 --> 49.6115) saving model...



Epoch [32/32]: 100%|██████████| 4511/4511 [18:09<00:00,  4.14it/s, loss=1.29]



		avg_train_loss=1.1978, avg_val_loss=1.1019
		val_loss decreased (49.6115 --> 49.5843) saving model...



In [None]:
# save only required stuff which will be needed during inference
checkpoint = {
            'n_hidden': model.n_hidden,
            'n_layers': model.n_layers,
            'tokens': model.chars,
            'model': model.state_dict()
            }
torch.save(checkpoint, 'models/char_rnn.pth')