# RNN Text Generation

## Imports

In [1]:
import torch
from torch import nn
import torch.nn.functional as F

import numpy as np
import matplotlib.pyplot as plt

## Text File Import

In [2]:
with open('shakespeare.txt','r',encoding='utf8') as f:
    text = f.read()

In [3]:
type(text)

str

In [21]:
print(text[:670])


                     1
  From fairest creatures we desire increase,
  That thereby beauty's rose might never die,
  But as the riper should by time decease,
  His tender heir might bear his memory:
  But thou contracted to thine own bright eyes,
  Feed'st thy light's flame with self-substantial fuel,
  Making a famine where abundance lies,
  Thy self thy foe, to thy sweet self too cruel:
  Thou that art now the world's fresh ornament,
  And only herald to the gaudy spring,
  Within thine own bud buriest thy content,
  And tender churl mak'st waste in niggarding:
    Pity the world, or else this glutton be,
    To eat the world's due, by the grave and thee.


  


In [7]:
len(text)

5445609

## Text Encoding

In [22]:
#all unique characters
all_characters = set(text)

In [25]:
len(all_characters)

84

In [31]:
# DECODER
# number to letter

'''
for pair in enumerate(all_characters):
    print(pair)
'''

decoder = dict(enumerate(all_characters))

In [33]:
# ENCODER
# letter to number

encoder = {char: ind for ind,char in decoder.items()}

In [41]:
# full text encoding (characters to numericall values)

encoded_text = np.array([encoder[char] for char in text])

In [43]:
len(encoded_text) == len(text)

True

In [44]:
encoded_text[:100]

array([75, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61, 61,
       61, 61, 61, 61, 61, 39, 75, 61, 61,  6, 17, 54, 20, 61, 38,  9, 12,
       17, 28, 40, 62, 61, 66, 17, 28,  9, 62, 13, 17, 28, 40, 61, 49, 28,
       61,  7, 28, 40, 12, 17, 28, 61, 12, 25, 66, 17, 28,  9, 40, 28, 29,
       75, 61, 61, 83, 14,  9, 62, 61, 62, 14, 28, 17, 28, 81,  0, 61, 81,
       28,  9, 13, 62,  0,  5, 40, 61, 17, 54, 40, 28, 61, 20, 12])

## One Hot Encoding

In [48]:
#number of unique characters
num_uni_chars = len(set(text))

In [70]:
def one_hot_encoder(encoded_text , num_uni_chars):
    
    #encoded text - batch of encoded text
    #num_uni_chars - number of unique characters in whole text file
    
    one_hot = np.zeros((encoded_text.size , num_uni_chars)) #prepare array with correct dimensions
    
    one_hot = one_hot.astype(np.float32) #data type for PyTorch
    
    one_hot[np.arange(one_hot.shape[0]),encoded_text.flatten()] = 1.0 #put ones in the position which coresponds to encoded char value
    
    one_hot = one_hot.reshape(*encoded_text.shape,num_uni_chars)
    
    return one_hot

In [71]:
# testing one_hot_encoder on small array
arr = np.array([1,2,0])
arr

array([1, 2, 0])

In [72]:
one_hot_encoder(arr,3)

array([[0., 1., 0.],
       [0., 0., 1.],
       [1., 0., 0.]], dtype=float32)

## Training Batches

In [76]:
example_text = np.arange(10)

In [77]:
example_text

array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

In [78]:
example_text.reshape(5,-1)

array([[0, 1],
       [2, 3],
       [4, 5],
       [6, 7],
       [8, 9]])

In [79]:
def generate_batches(encoded_text , sam_per_batch = 10 , seq_len=50):
    
    # X -> encoded text of length 'seq_len'
    # Y -> encoded text shifted by 1
    
    # how many characters per batch
    char_per_batch = sam_per_batch * seq_len
    
    # how many batches possible in entire text
    num_batches_avail = int(len(encoded_text)/char_per_batch)
    
    # cut off the end of the encoded text
    encoded_text = encoded_text[:num_batches_avail*char_per_batch]
    
    
    encoded_text = encoded_text.reshape(sam_per_batch,-1)
    
    for n in range(0,encoded_text.shape[1],seq_len):
        
        x = encoded_text[:,n:n+seq_len]
        
        y = np.zeros_like(x)
        
        try:
            
            y[:,:-1] = x[:,1:]
            y[:,-1] = encoded_text[:,n+seq_len]
            
        except:
            
            y[:,:-1] = x[:,1:]
            y[:,-1] = encoded_text[:,0]
            
        yield x,y         
            

In [140]:
# generator test
sample_text = np.arange(20)

In [141]:
sample_text

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19])

In [142]:
batch_generator = generate_batches(sample_text,sam_per_batch=2,seq_len=5)

In [143]:
type(batch_generator)

generator

In [144]:
x,y = next(batch_generator)

In [145]:
x

array([[ 0,  1,  2,  3,  4],
       [10, 11, 12, 13, 14]])

In [146]:
y

array([[ 1,  2,  3,  4,  5],
       [11, 12, 13, 14, 15]])

In [147]:
x,y = next(batch_generator)

In [148]:
x

array([[ 5,  6,  7,  8,  9],
       [15, 16, 17, 18, 19]])

In [149]:
y

array([[ 6,  7,  8,  9,  0],
       [16, 17, 18, 19, 10]])