![alternative text](../../data/rnn_chatgpt.png)
![alternative text](../../data/rnn_pytorch.png)


In [None]:
from torch import nn
import torch
import numpy as np
from matplotlib.pylab import plt
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torch.distributions import Categorical


In [None]:
# Load and normalise the training corpus.
#
# The result, `text_list`, is (despite its name) one long lowercase string:
# sentence markers are removed, whitespace inside every line is collapsed to
# single spaces, a fixed set of symbols is dropped, and the cleaned lines are
# joined with single spaces.

# Symbols deleted from every line, applied in a single pass by str.translate.
_STRIPPED_SYMBOLS = str.maketrans('', '', '*#@[]{}()%$')

# A context manager guarantees the file handle is closed after reading.
with open('data/training_text.txt', 'r') as corpus_file:
    raw_corpus = corpus_file.read()

# Remove the sentence boundary markers before splitting into lines.
corpus_lines = raw_corpus.replace('<s>', '').replace('</s>', '').split('\n')

# Order matters: whitespace is collapsed first and symbols are removed
# afterwards, so a symbol that stood between two spaces leaves a double space
# behind, exactly as the previous chained-replace version did.
text_list = ' '.join(
    ' '.join(line.split()).lower().translate(_STRIPPED_SYMBOLS)
    for line in corpus_lines
)


In [None]:
# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------

# Model architecture
hidden_size = 512  # width of the LSTM hidden state
num_layers = 3  # number of stacked LSTM layers

# Training
seq_len = 100  # characters per training window
lr = 0.01  # optimizer learning rate
epochs = 5  # upper bound on training epochs

# Sampling
op_seq_len = 200  # characters generated for each sample after an epoch
    


In [None]:
# Run on the GPU whenever CUDA is usable; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class RNN(nn.Module):
    """Embedding -> stacked LSTM -> linear decoder over token indices.

    Args:
        input_size: Number of distinct input tokens. It is also used as the
            embedding width, so the LSTM consumes ``input_size`` features.
        output_size: Number of scores the decoder emits per position.
        hidden_size: Width of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, input_size, output_size, hidden_size, num_layers):
        super().__init__()
        # Square embedding table: one `input_size`-wide vector per token.
        self.embedding = nn.Embedding(input_size, input_size)
        # Built without `batch_first`, so the LSTM reads axis 0 of its input
        # as the time axis.
        self.rnn = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
        )
        self.decoder = nn.Linear(hidden_size, output_size)

    def forward(self, input_seq, hidden_state):
        """Score the next token at every position of ``input_seq``.

        Args:
            input_seq: Integer tensor of token indices.
            hidden_state: ``(h, c)`` tuple from a previous call, or ``None``
                to let the LSTM start from its default initial state.

        Returns:
            ``(scores, (h, c))`` where ``scores`` are the raw (unnormalised)
            decoder outputs and ``h``/``c`` are detached from the autograd
            graph, so feeding them back in does not backpropagate into
            earlier calls.
        """
        embedded = self.embedding(input_seq)
        lstm_out, (h_n, c_n) = self.rnn(embedded, hidden_state)
        scores = self.decoder(lstm_out)
        return scores, (h_n.detach(), c_n.detach())
    
# The cleaned corpus is treated as a plain sequence of characters.
data = text_list
chars = sorted(set(data))
data_size = len(data)
vocab_size = len(chars)
print("Data has {} characters, {} unique".format(data_size, vocab_size))

# Two-way lookup between characters and their integer ids (sorted order).
ix_to_char = dict(enumerate(chars))
char_to_ix = {ch: ix for ix, ch in ix_to_char.items()}

# Encode the corpus as ids, move it to the target device and add a trailing
# axis of size 1, giving one row per character.
data = torch.tensor([char_to_ix[ch] for ch in data]).to(device)
data = data.unsqueeze(dim=1)

In [None]:
# Character-level network: inputs and outputs both range over the vocabulary.
rnn = RNN(
    input_size=vocab_size,
    output_size=vocab_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
).to(device)

# Cross-entropy over next-character scores, optimised with Adam.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=rnn.parameters(), lr=lr)

In [None]:

# Build (input, target) training windows.
#
# The last 50,000 characters of `data` are excluded from training. Window
# starts advance by a random stride of 0..127 characters; each target is its
# input shifted one character to the right.
training = data[:-50000]
train_data = []
label_data = []
data_ptr = 0

# Largest start index for which both the input window and the shifted target
# window fit. Checking only the input length (as before) let through a final
# window whose target was one element short.
last_start = len(training) - seq_len - 1

for i in range(20000):
    data_ptr += np.random.randint(128)
    if data_ptr > last_start:
        # data_ptr never decreases, so no later window can fit either.
        break
    # Take seq_len + 1 characters once and split them. squeeze(dim=1) drops
    # only the trailing size-1 axis, so a one-element slice stays 1-D rather
    # than collapsing to a 0-d tensor that len() rejects.
    window = training[data_ptr : data_ptr + seq_len + 1].squeeze(dim=1)
    train_data.append(window[:-1])
    label_data.append(window[1:])

batch_size = 1
# These two loaders are meant to be iterated in lockstep. Shuffling each one
# independently would pair every input with an unrelated target, so they are
# kept in their original, aligned order. To shuffle, shuffle the
# (input, target) pairs together in a single DataLoader.
train_data_loader = DataLoader(train_data, batch_size=batch_size, shuffle=False)
label_data_loader = DataLoader(label_data, batch_size=batch_size, shuffle=False)


In [None]:
def _print_sample():
    """Print ``op_seq_len`` characters sampled from the global ``rnn``."""
    rnn.eval()
    hidden_state = None

    # random character from data to begin
    rand_index = np.random.randint(data_size - 1)
    # clone(): the slice is a view of `data`, and the loop below writes into
    # it; without the copy every sample would overwrite part of the corpus.
    input_seq = data[rand_index : rand_index + 1].clone()

    # Sampling needs no gradients, so skip building autograd graphs.
    with torch.no_grad():
        for _ in range(op_seq_len):
            output, hidden_state = rnn(input_seq, hidden_state)

            # turn the scores into a categorical distribution and sample
            probs = F.softmax(torch.squeeze(output), dim=0)
            index = Categorical(probs).sample()

            print(ix_to_char[index.item()], end='')

            # next input is the character just sampled
            input_seq[0][0] = index.item()


def train():
    """Train the global ``rnn`` and print a generated sample after each epoch.

    Reads the module-level ``train_data`` / ``label_data`` windows, ``rnn``,
    ``loss_fn``, ``optimizer`` and the hyperparameters ``epochs`` and
    ``op_seq_len``. Every epoch runs at most 10,000 optimizer steps over a
    fresh shuffle of the windows, prints the mean loss, then prints a sample.

    Raises:
        ValueError: If there are no training windows (previously this ended
            in a ZeroDivisionError when averaging the loss).
    """
    batch_size = 1
    max_steps_per_epoch = 10000

    # Keep each input together with its own target so that one shuffle
    # reorders both. Two independently shuffled loaders zipped together paired
    # every input with a random, unrelated target.
    paired_windows = list(zip(train_data, label_data))
    if not paired_windows:
        raise ValueError("no training sequences available")

    # training loop
    for i_epoch in range(1, epochs + 1):
        rnn.train()
        n = 0
        running_loss = 0
        pair_loader = DataLoader(paired_windows, batch_size=batch_size, shuffle=True)

        for input_seq, target_seq in pair_loader:
            print(n, end='\r')

            # The loader yields (batch, seq_len), but the LSTM is not
            # batch_first and reads axis 0 as time. Transposing makes it
            # unroll over the characters of the window instead of treating
            # them as independent one-step sequences.
            input_seq = input_seq.t()
            target_seq = target_seq.t()

            # forward pass; windows are shuffled and unrelated, so each one
            # starts from a fresh hidden state instead of inheriting the
            # state left behind by the previous window.
            output, _ = rnn(input_seq, None)

            # compute loss over every (time step, batch) position
            loss = loss_fn(
                output.reshape(-1, output.size(-1)),
                target_seq.reshape(-1),
            )
            running_loss += loss.item()

            # compute gradients and take optimizer step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            n += 1
            if n >= max_steps_per_epoch:
                break

        # print the mean loss of this epoch
        print("Epoch: {0} \t Loss: {1:.8f}".format(i_epoch, running_loss/n))

        print("--------------------------------------------\n")
        # sample / generate a text sequence after every epoch
        _print_sample()
        print("\n--------------------------------------------")

# Run the training loop defined above when this cell executes.
train()