In [1]:
import torch
import numpy as np
import torch.nn as nn
import torchtext
from glob import glob

In [2]:
files = glob(f'../input/poemsdataset/forms/**/*.txt')
files[:5]

['../input/poemsdataset/forms/lay/LayPoemsLayAGarlandOnMyHearsePoembyFrancisBeaumont.txt',
 '../input/poemsdataset/forms/lay/LayPoemsLayHisSwordByHisSidePoembyThomasMoore.txt',
 '../input/poemsdataset/forms/lay/LayPoemsLayAGarlandOnMyHearsePoembyBeaumontandFletcher.txt',
 '../input/poemsdataset/forms/lay/LayPoemsAsILayWithHeadInYourLapCameradoPoembyWaltWhitman.txt',
 '../input/poemsdataset/forms/lay/LayPoemsTheDeerLayDownTheirBonesPoembyRobinsonJeffers.txt']

In [3]:
len(files)

6322

In [4]:
all_texts = [open(f, encoding='utf8').read() for f in files]

In [5]:
text = [f for sublist in all_texts for f in sublist]

In [6]:
char_set = set(text)

print('Total Length:', len(text))
print('Unique characters:', len(char_set))

Total Length: 6697076
Unique characters: 1054


In [7]:
chars_sorted = sorted(char_set)
char2int = {ch:i for i, ch in enumerate(chars_sorted)}

char_array = np.array(chars_sorted)

text_encoded = np.array([char2int[ch] for ch in text], dtype=np.int32)

print('Text encoded shape: ', text_encoded.shape)

print(text[:15], '     == Encoding ==> ', text_encoded[:15])
print(text_encoded[15:21], ' == Reverse  ==> ', ''.join(char_array[text_encoded[15:21]]))

Text encoded shape:  (6697076,)
['L', 'a', 'y', ' ', 'a', ' ', 'g', 'a', 'r', 'l', 'a', 'n', 'd', ' ', 'o']      == Encoding ==>  [47 68 92  3 68  3 74 68 85 79 68 81 71  3 82]
[81  3 80 92  3 75]  == Reverse  ==>  n my h


In [8]:
len(text), len(text_encoded)

(6697076, 6697076)

In [9]:
for ex in text_encoded[:5]:
    print('{} -> {}'.format(ex, char_array[ex]))

47 -> L
68 -> a
92 -> y
3 ->  
68 -> a


In [10]:
seq_length = 20
chunk_size = seq_length + 1

text_chunks = [text_encoded[i:i+chunk_size] 
               for i in range(len(text_encoded)-chunk_size)] 

## inspection:
for seq in text_chunks[:1]:
    input_seq = seq[:seq_length]
    target = seq[seq_length] 
    print(input_seq, ' -> ', target)
    print(repr(''.join(char_array[input_seq])), 
          ' -> ', repr(''.join(char_array[target])))

[47 68 92  3 68  3 74 68 85 79 68 81 71  3 82 81  3 80 92  3]  ->  75
'Lay a garland on my '  ->  'h'


In [11]:
import torch
from torch.utils.data import Dataset

class TextDataset(Dataset):
    def __init__(self, text_chunks):
        self.text_chunks = text_chunks

    def __len__(self):
        return len(self.text_chunks)
    
    def __getitem__(self, idx):
        text_chunk = self.text_chunks[idx]
        return text_chunk[:-1].long(), text_chunk[1:].long()
    


In [12]:
seq_dataset = TextDataset(torch.tensor(text_chunks))

In [13]:
for i, (seq, target) in enumerate(seq_dataset):
    print(' Input (x):', repr(''.join(char_array[seq])))
    print('Target (y):', repr(''.join(char_array[target])))
    print()
    if i == 1:
        break
    

 Input (x): 'Lay a garland on my '
Target (y): 'ay a garland on my h'

 Input (x): 'ay a garland on my h'
Target (y): 'y a garland on my he'



In [14]:
device = torch.device("cuda:0")
# device = 'cpu'

In [15]:
from torch.utils.data import DataLoader
 
batch_size = 64

torch.manual_seed(1)
seq_dl = DataLoader(seq_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

In [16]:
len(seq_dataset), len(seq_dl)

(6697055, 104641)

## Building a character-level RNN model

In [17]:
import torch.nn as nn

class RNN(nn.Module):
    def __init__(self, vocab_size, embed_dim, rnn_hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim) 
        self.rnn_hidden_size = rnn_hidden_size
        self.rnn = nn.LSTM(embed_dim, rnn_hidden_size, 
                           batch_first=True)
        self.fc = nn.Linear(rnn_hidden_size, vocab_size)

    def forward(self, x, hidden, cell):
        out = self.embedding(x).unsqueeze(1)
        out, (hidden, cell) = self.rnn(out, (hidden, cell))
        out = self.fc(out).reshape(out.size(0), -1)
        return out, hidden, cell

    def init_hidden(self, batch_size):
        hidden = torch.zeros(1, batch_size, self.rnn_hidden_size)
        cell = torch.zeros(1, batch_size, self.rnn_hidden_size)
        return hidden.to(device), cell.to(device)
    
vocab_size = len(char_array)
embed_dim = 64
rnn_hidden_size = 64

torch.manual_seed(1)
model = RNN(vocab_size, embed_dim, rnn_hidden_size) 
model = model.to(device)
model

RNN(
  (embedding): Embedding(1054, 64)
  (rnn): LSTM(64, 64, batch_first=True)
  (fc): Linear(in_features=64, out_features=1054, bias=True)
)

In [18]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

num_epochs = 10000

torch.manual_seed(1)

for epoch in range(num_epochs):
    hidden, cell = model.init_hidden(batch_size)
    seq_batch, target_batch = next(iter(seq_dl))
    seq_batch = seq_batch.to(device)
    target_batch = target_batch.to(device)
    optimizer.zero_grad()
    loss = 0
    for c in range(seq_length):
        pred, hidden, cell = model(seq_batch[:, c], hidden, cell) 
        loss += loss_fn(pred, target_batch[:, c])
    loss.backward()
    optimizer.step()
    loss = loss.item()/seq_length
    if epoch % 500 == 0:
        print(f'Epoch {epoch} loss: {loss:.4f}')

Epoch 0 loss: 6.9602
Epoch 500 loss: 2.6471
Epoch 1000 loss: 2.3748
Epoch 1500 loss: 2.3142
Epoch 2000 loss: 2.3239
Epoch 2500 loss: 2.3572
Epoch 3000 loss: 2.1681
Epoch 3500 loss: 2.1780
Epoch 4000 loss: 2.1150
Epoch 4500 loss: 2.0841
Epoch 5000 loss: 2.1928
Epoch 5500 loss: 2.1314
Epoch 6000 loss: 2.1072
Epoch 6500 loss: 1.9743
Epoch 7000 loss: 2.0691
Epoch 7500 loss: 2.0408
Epoch 8000 loss: 2.0885
Epoch 8500 loss: 2.0165
Epoch 9000 loss: 2.0640
Epoch 9500 loss: 2.0229


In [19]:
from torch.distributions.categorical import Categorical

def sample(model, starting_str, 
           len_generated_text=500, 
           scale_factor=1.0):

    encoded_input = torch.tensor([char2int[s] for s in starting_str])
    encoded_input = torch.reshape(encoded_input, (1, -1))

    generated_str = starting_str

    model.eval()
    hidden, cell = model.init_hidden(1)
    hidden = hidden.to('cpu')
    cell = cell.to('cpu')
    for c in range(len(starting_str)-1):
        _, hidden, cell = model(encoded_input[:, c].view(1), hidden, cell) 
    
    last_char = encoded_input[:, -1]
    for i in range(len_generated_text):
        logits, hidden, cell = model(last_char.view(1), hidden, cell) 
        logits = torch.squeeze(logits, 0)
        scaled_logits = logits * scale_factor
        m = Categorical(logits=scaled_logits)
        last_char = m.sample()
        generated_str += str(char_array[last_char])
        
    return generated_str

torch.manual_seed(1)
model.to('cpu')
print(sample(model, starting_str='The island'))

The island call love his and suétor by in too weel deem, bless to thight cheef daughs of teathings cynempictrymenten's, with crears wech menoty set nears camain.X.
The neak.
A fas wish is to knop con
Hand soulds
Cronger-moor, I wnoinialtrains a wordghing in the chee
Irrated to me , ald the addous live
hout to to patter wilf bey over bordibark shown,
Of the sink warl;
Xst doin'd smake of Ca giver.
And gellious the woult,
Bood and reature-past to, sfaid and excher,)"
She lay wanky then' perplied thing, seep


In [None]:
# https://github.com/rasbt/machine-learning-book/blob/main/ch15/ch15_part3.ipynb

In [20]:
print(sample(model, starting_str='The silence tree'))

The silence tree my fate/1
The fust think rodum do nothould; at Gatu wial cruckled moterness with man
Bate spirgued now dore they
To ulsoners to they usen's sautor comes to take,
This flownim nots!
Whines
Anto fare,
He vin by outhous, , a eoces the fan rail stausly, Reza do of roose foroched my ge a catched beat comion.Yess an love wafts rince, sacally is follave of musccitish sheind or eiv tyink the gell away
trure now dorgay:
all the verround,
An greass of bear,
Ou toem' the schiffed them loid.
Wher bacting a


In [21]:
print(sample(model, starting_str='The silence tree',len_generated_text=50))

The silence tree!
Nor wab thear all somate up birm at the erchill 


In [22]:
print(sample(model, starting_str='In the end',len_generated_text=50))

In the ends,
right
her fide dival to their sun,
Nocks;
selt



In [23]:
print(sample(model, starting_str='I see the light',len_generated_text=100))

I see the light ©4 6ই prais, cruthereswing wonded hin humme,
'ou blowry and res
their grave;
Loviced seete his came
