In [1]:
import numpy as np
import torch
import torch.nn as nn
device = torch.device("mps")

In [2]:
with open('../Data/shakespeare.txt') as f:
    text = f.read()

In [3]:
all_char = set(text)
n_unique_char = len(all_char)

In [4]:
decoder = dict(enumerate(all_char))
encoder = {char:ind for ind,char in decoder.items()}

In [5]:
encoded_text = np.array([encoder[char] for char in text])

In [6]:
def one_hot_encoder(encoded_text,n_unique_char):
    one_hot = np.zeros((encoded_text.size,n_unique_char)).astype(np.float32)
    one_hot[np.arange(one_hot.shape[0]),encoded_text.flatten()] = 1.0
    one_hot = one_hot.reshape((*encoded_text.shape,n_unique_char))
    return one_hot   

In [7]:
def generate_batches(encoded_text,sample_per_batch=10,seq_len=50):
    char_per_batch = sample_per_batch * seq_len
    avail_batch = int(len(encoded_text)/char_per_batch)
    encoded_text = encoded_text[:char_per_batch*avail_batch]
    encoded_text = encoded_text.reshape((sample_per_batch,-1))
    
    for n in range(0,encoded_text.shape[1],seq_len):
        x = encoded_text[:,n:n+seq_len]
        y = np.zeros_like(x)
        try : 
            y[:,:-1] = x[:,1:]
            y[:,-1] = encoded_text[:,n+seq_len]
        #for the very last case
        except : 
            y[:,:-1] = x[:,1:]
            y[:,-1] = encoded_text[:,0]
        yield x,y

In [8]:
batch_generator = generate_batches(encoded_text,sample_per_batch=10,seq_len=50)

In [9]:
x,y = next(batch_generator)

In [10]:
class LSTM(nn.Module):
    def __init__(self,all_char,num_hidden=256,num_layers=4,drop_prob=0.5):
        super(LSTM,self).__init__()
        self.all_char = all_char
        self.num_hidden = num_hidden
        self.num_layers = num_layers
        self.drop_prob = drop_prob
        
        self.decoder = dict(enumerate(all_char))
        self.encoder = {char:ind for ind,char in decoder.items()}
        
        self.lstm = nn.LSTM(len(self.all_char),num_hidden,num_layers,dropout = drop_prob,batch_first=True)
        self.fc_linear = nn.Linear(num_hidden,len(self.all_char))
        self.dropout = nn.Dropout(drop_prob)
    def forward(self,x,hidden):
        lstm_out, hidden = self.lstm(x,hidden)
        drop_out = self.dropout(lstm_out)
        drop_out = drop_out.contiguous().view(-1,self.num_hidden)
        final_out = self.fc_linear(drop_out)
        return final_out,hidden
    def init_hidden(self,batch_size):
        hidden = (torch.zeros(self.num_layers,batch_size,self.num_hidden).to(device),
                 torch.zeros(self.num_layers,batch_size,self.num_hidden).to(device))
        return hidden
        

In [11]:
model = LSTM(all_char,num_hidden=256,num_layers=4,drop_prob=0.4).to(device)

In [12]:
model

LSTM(
  (lstm): LSTM(84, 256, num_layers=4, batch_first=True, dropout=0.4)
  (fc_linear): Linear(in_features=256, out_features=84, bias=True)
  (dropout): Dropout(p=0.4, inplace=False)
)

In [13]:
optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
criterion = nn.CrossEntropyLoss()

In [14]:
train_percent = 0.9
train_ind = int(len(encoded_text) * (train_percent))
train_data = encoded_text[:train_ind]
val_data = encoded_text[train_ind:]

In [15]:
num_epoch = 75
batch_size = 100
seq_len = 100
tracker = 0
num_char = max(encoded_text)+1

In [16]:
model.train()

LSTM(
  (lstm): LSTM(84, 256, num_layers=4, batch_first=True, dropout=0.4)
  (fc_linear): Linear(in_features=256, out_features=84, bias=True)
  (dropout): Dropout(p=0.4, inplace=False)
)

In [17]:
model.train()
for epoch in range(num_epoch):
    hidden = model.init_hidden(batch_size)
    for x,y in generate_batches(train_data,batch_size,seq_len):
        tracker +=1 
        x = one_hot_encoder(x,num_char)
        inputs = torch.tensor(x).to(device)
        targets = torch.LongTensor(y).to(device)
        hidden = tuple([state.data for state in hidden])
        model.zero_grad()
        lstm_out,hidden = model.forward(inputs,hidden)
        loss = criterion(lstm_out,targets.view(batch_size*seq_len).long())
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(),max_norm=5)
        optimizer.step()
        if tracker % 25 == 0:
            val_hidden = model.init_hidden(batch_size)
            val_losses = []
            model.eval()
            for x,y in generate_batches(val_data,batch_size,seq_len):
                x = one_hot_encoder(x,num_char)
                inputs = torch.tensor(x).to(device)
                targets = torch.LongTensor(y).to(device)
                val_hidden = tuple([state.data for state in hidden])
                lstm_output,val_hidden = model.forward(inputs,val_hidden)
                val_loss = criterion(lstm_output,targets.view(batch_size*seq_len).long())
                val_losses.append(val_loss.item())
            model.train()
            print(f"epoch  : {epoch+1} step : {tracker} val_loss : {val_loss.item()}")

epoch  : 1 step : 25 val_loss : 3.2059872150421143
epoch  : 1 step : 50 val_loss : 3.197155475616455
epoch  : 1 step : 75 val_loss : 3.1957905292510986
epoch  : 1 step : 100 val_loss : 3.1974141597747803
epoch  : 1 step : 125 val_loss : 3.1971182823181152
epoch  : 1 step : 150 val_loss : 3.1970865726470947
epoch  : 1 step : 175 val_loss : 3.1945366859436035
epoch  : 1 step : 200 val_loss : 3.1953234672546387
epoch  : 1 step : 225 val_loss : 3.1963179111480713
epoch  : 1 step : 250 val_loss : 3.199916362762451
epoch  : 1 step : 275 val_loss : 3.1965813636779785
epoch  : 1 step : 300 val_loss : 3.1970553398132324
epoch  : 1 step : 325 val_loss : 3.1973752975463867
epoch  : 1 step : 350 val_loss : 3.196497678756714
epoch  : 1 step : 375 val_loss : 3.1967051029205322
epoch  : 1 step : 400 val_loss : 3.1948065757751465
epoch  : 1 step : 425 val_loss : 3.193148136138916
epoch  : 1 step : 450 val_loss : 3.1941583156585693
epoch  : 1 step : 475 val_loss : 3.1933465003967285
epoch  : 2 step : 5

KeyboardInterrupt: 