In [1]:
import numpy as np
import torch
import torch.nn as nn
device = torch.device("cuda")

In [2]:
with open('/kaggle/input/char-level-mode-text/shakespeare.txt') as f:
    text = f.read()

In [3]:
all_char = set(text)
n_unique_char = len(all_char)

In [4]:
decoder = dict(enumerate(all_char))
encoder = {char:ind for ind,char in decoder.items()}

In [5]:
encoded_text = np.array([encoder[char] for char in text])

In [6]:
def one_hot_encoder(encoded_text,n_unique_char):
    one_hot = np.zeros((encoded_text.size,n_unique_char)).astype(np.float32)
    one_hot[np.arange(one_hot.shape[0]),encoded_text.flatten()] = 1.0
    one_hot = one_hot.reshape((*encoded_text.shape,n_unique_char))
    return one_hot   

In [7]:
def generate_batches(encoded_text,sample_per_batch=10,seq_len=50):
    char_per_batch = sample_per_batch * seq_len
    avail_batch = int(len(encoded_text)/char_per_batch)
    encoded_text = encoded_text[:char_per_batch*avail_batch]
    encoded_text = encoded_text.reshape((sample_per_batch,-1))
    
    for n in range(0,encoded_text.shape[1],seq_len):
        x = encoded_text[:,n:n+seq_len]
        y = np.zeros_like(x)
        try : 
            y[:,:-1] = x[:,1:]
            y[:,-1] = encoded_text[:,n+seq_len]
        #for the very last case
        except : 
            y[:,:-1] = x[:,1:]
            y[:,-1] = encoded_text[:,0]
        yield x,y

In [8]:
batch_generator = generate_batches(encoded_text,sample_per_batch=10,seq_len=50)

In [9]:
x,y = next(batch_generator)

In [10]:
class LSTM(nn.Module):
    def __init__(self,all_char,num_hidden=512,num_layers=3,drop_prob=0.5):
        super(LSTM,self).__init__()
        self.all_char = all_char
        self.num_hidden = num_hidden
        self.num_layers = num_layers
        self.drop_prob = drop_prob
        
        self.decoder = dict(enumerate(all_char))
        self.encoder = {char:ind for ind,char in decoder.items()}
        
        self.lstm = nn.LSTM(len(self.all_char),num_hidden,num_layers,dropout = drop_prob,batch_first=True)
        self.fc_linear = nn.Linear(num_hidden,len(self.all_char))
        self.dropout = nn.Dropout(drop_prob)
    def forward(self,x,hidden):
        lstm_out, hidden = self.lstm(x,hidden)
        drop_out = self.dropout(lstm_out)
        drop_out = drop_out.contiguous().view(-1,self.num_hidden)
        final_out = self.fc_linear(drop_out)
        return final_out,hidden
    def init_hidden(self,batch_size):
        hidden = (torch.zeros(self.num_layers,batch_size,self.num_hidden).to(device),
                 torch.zeros(self.num_layers,batch_size,self.num_hidden).to(device))
        return hidden
        

In [11]:
model = LSTM(all_char,num_hidden=1024,num_layers=4,drop_prob=0.6).to(device)

In [12]:
model

LSTM(
  (lstm): LSTM(84, 1024, num_layers=4, batch_first=True, dropout=0.6)
  (fc_linear): Linear(in_features=1024, out_features=84, bias=True)
  (dropout): Dropout(p=0.6, inplace=False)
)

In [13]:
optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
criterion = nn.CrossEntropyLoss()

In [14]:
train_percent = 0.9
train_ind = int(len(encoded_text) * (train_percent))
train_data = encoded_text[:train_ind]
val_data = encoded_text[train_ind:]

In [15]:
num_epoch = 20
batch_size = 100
seq_len = 100
tracker = 0
num_char = max(encoded_text)+1

In [16]:
model.train()

LSTM(
  (lstm): LSTM(84, 1024, num_layers=4, batch_first=True, dropout=0.6)
  (fc_linear): Linear(in_features=1024, out_features=84, bias=True)
  (dropout): Dropout(p=0.6, inplace=False)
)

In [17]:
model.train()
for epoch in range(num_epoch):
    hidden = model.init_hidden(batch_size)
    for x,y in generate_batches(train_data,batch_size,seq_len):
        tracker +=1 
        x = one_hot_encoder(x,num_char)
        inputs = torch.tensor(x).to(device)
        targets = torch.LongTensor(y).to(device)
        hidden = tuple([state.data for state in hidden])
        model.zero_grad()
        lstm_out,hidden = model.forward(inputs,hidden)
        loss = criterion(lstm_out,targets.view(batch_size*seq_len).long())
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(),max_norm=5)
        optimizer.step()
        if tracker % 50 == 0:
            val_hidden = model.init_hidden(batch_size)
            val_losses = []
            model.eval()
            for x,y in generate_batches(val_data,batch_size,seq_len):
                x = one_hot_encoder(x,num_char)
                inputs = torch.tensor(x).to(device)
                targets = torch.LongTensor(y).to(device)
                val_hidden = tuple([state.data for state in hidden])
                lstm_output,val_hidden = model.forward(inputs,val_hidden)
                val_loss = criterion(lstm_output,targets.view(batch_size*seq_len).long())
                val_losses.append(val_loss.item())
            model.train()
            print(f"epoch  : {epoch+1} step : {tracker} val_loss : {val_loss.item()}")

epoch  : 1 step : 50 val_loss : 3.1927735805511475
epoch  : 1 step : 100 val_loss : 3.0099895000457764
epoch  : 1 step : 150 val_loss : 2.521446943283081
epoch  : 1 step : 200 val_loss : 2.314289093017578
epoch  : 1 step : 250 val_loss : 2.1889312267303467
epoch  : 1 step : 300 val_loss : 2.1046571731567383
epoch  : 1 step : 350 val_loss : 2.030832290649414
epoch  : 1 step : 400 val_loss : 1.961775779724121
epoch  : 1 step : 450 val_loss : 1.9127683639526367
epoch  : 2 step : 500 val_loss : 1.8476203680038452
epoch  : 2 step : 550 val_loss : 1.8103877305984497
epoch  : 2 step : 600 val_loss : 1.7718286514282227
epoch  : 2 step : 650 val_loss : 1.7311971187591553
epoch  : 2 step : 700 val_loss : 1.7184250354766846
epoch  : 2 step : 750 val_loss : 1.6885802745819092
epoch  : 2 step : 800 val_loss : 1.6591687202453613
epoch  : 2 step : 850 val_loss : 1.661063313484192
epoch  : 2 step : 900 val_loss : 1.643154501914978
epoch  : 2 step : 950 val_loss : 1.6288139820098877
epoch  : 3 step : 1

In [18]:
def pred_next_char(model,char,hidden=None,k=1):
    encoded_text = model.encoder[char]
    encoded_text = np.array([[encoded_text]]) 
    encoded_text = one_hot_encoder(encoded_text,len(model.all_char))
    #create input by encoding and one_hotting the char
    inputs = torch.tensor(encoded_text).to(device)
    #create hidden state
    hidden = tuple([state.data for state in hidden])
    #make prediction
    lstm_out,hidden = model.forward(inputs,hidden)
    #get probabilities
    probs = F.softmax(lstm_out,dim=1).data
    probs = probs.cpu()
    probs,index_position = probs.topk(k)
    index_position = index_position.numpy().squeeze()
    probs = probs.numpy().flatten()
    probs = probs/probs.sum()
    #choose a char from top k
    char = np.random.choice(index_position,p=probs)
    return model.decoder[char],hidden

In [19]:
def generate_text(model,size,seed = 'the',k=1):
    model = model.to(device)
    model.eval()
    output_char = [c for c in seed]
    hidden = model.init_hidden(1)
    for char in seed:
        char,hidden = pred_next_char(model,char,hidden,k=k)
    output_char.append(char)
    for i in range(size):
        char,hidden = pred_next_char(model,output_char[-1],hidden,k=k)
        output_char.append(char)
    return ''.join(output_char)

In [35]:
import torch.nn.functional as F
print(generate_text(model,2000,seed='The',k=2))

Theres' hearts,
    To strike and see him there to think of them.
    The star the field the sea and soul of heaven
    Hath still success'd the senses of the world,
    And that his state and thieves are the contraries
    That stands and therefore shall be thought to see.
    If thou dost see his soul that to the war
    Was so their father's heart and sent thy son,
    To see them to be true as he is born.
    If thou hadst seen to stop their faces arm,
    And, without triumph, the third the fiend of France,
    The state, a soldier to the street of this
    With their disposition that thou art a most,
    Therefore the sea and three of them the sea,
    And she will be a most right strain of him.
    The stroke of this dear sister without son
    Was to the sense of thine arms to th' state,
    And that the sea and soul of this soldily
    Will be the form and strange as the statues that have seen.
    Therefore the world is to be thought to be
    That the sun of him that hath se