In [1]:
import torch
import torch.nn as nn
import numpy as np

# Get Data (Shakespeare)

In [2]:
# store text in one long string called 'data'
f = open("./shakespeare.txt")
data = f.read()

In [5]:
# create dictionary to one-hot encode each characters to
char2int = dict([(c, i) for i, c in enumerate(set(data))])
int2char = dict([(char2int[k], k) for k in char2int])

In [6]:
print(char2int)
print(int2char)

{'.': 0, 'n': 1, 'U': 2, 't': 3, 's': 4, 'K': 5, 'q': 6, 'u': 7, 'g': 8, 'F': 9, 'O': 10, 'N': 11, 'E': 12, 'v': 13, 'e': 14, 'w': 15, 'B': 16, '-': 17, 'W': 18, 'z': 19, ';': 20, '!': 21, ',': 22, 'j': 23, 'r': 24, 'D': 25, 'G': 26, 'l': 27, ')': 28, 'P': 29, 'c': 30, '\n': 31, 'd': 32, 'M': 33, 'p': 34, 'h': 35, ':': 36, 'T': 37, 'C': 38, 'Y': 39, 'H': 40, 'S': 41, 'y': 42, 'b': 43, 'R': 44, 'a': 45, 'm': 46, 'V': 47, 'i': 48, 'o': 49, ' ': 50, 'f': 51, "'": 52, 'L': 53, '?': 54, '(': 55, 'A': 56, 'I': 57, 'k': 58, 'J': 59, 'x': 60}
{0: '.', 1: 'n', 2: 'U', 3: 't', 4: 's', 5: 'K', 6: 'q', 7: 'u', 8: 'g', 9: 'F', 10: 'O', 11: 'N', 12: 'E', 13: 'v', 14: 'e', 15: 'w', 16: 'B', 17: '-', 18: 'W', 19: 'z', 20: ';', 21: '!', 22: ',', 23: 'j', 24: 'r', 25: 'D', 26: 'G', 27: 'l', 28: ')', 29: 'P', 30: 'c', 31: '\n', 32: 'd', 33: 'M', 34: 'p', 35: 'h', 36: ':', 37: 'T', 38: 'C', 39: 'Y', 40: 'H', 41: 'S', 42: 'y', 43: 'b', 44: 'R', 45: 'a', 46: 'm', 47: 'V', 48: 'i', 49: 'o', 50: ' ', 51: 'f',

In [7]:
# one-hot encode and move to Tensors
x_encoded = nn.functional.one_hot((torch.Tensor([char2int[x] for x in data]).long()))

In [8]:
# example of reverting tensor back to single char for use later
int2char[np.argmax(x_encoded[1].tolist())]

'H'

In [9]:
print(x_encoded.shape)

torch.Size([94275, 61])


# Build RNN

In [118]:

class RNN(nn.Module):
    def __init__(self, hidden_dim):
        super(RNN, self).__init__()

        self.hidden_dim = hidden_dim
        
        self.in_linear = nn.Linear(61, self.hidden_dim, bias=False)
        self.hid_linear = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.out_linear = nn.Linear(self.hidden_dim, 61)

    def forward(self, x, hidden):
        if hidden == None:
            hidden = torch.zeros(self.hidden_dim).to(device)
        
        ht = nn.functional.relu(self.in_linear(x) + self.hid_linear(hidden))
        ot = self.out_linear(ht)
        return ht, ot
        

In [119]:
model = RNN(200)

# Set System Standards

In [120]:
# set device for pytorch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# set loss function as cross entropy loss, optimizer and SGD
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

cpu


In [121]:
# move everything to CUDA
model = model.to(device)
x_encoded = x_encoded.float().to(device)

In [130]:
from tqdm import trange

def train(model, criterion, optimizer, data):

    hidden = None
    correct = 0
    
    # set to train mode
    model.train()
    for i in range(len(data)):
        # skip first index, since at this point we have no context to train with
        if i == 0:
            continue

        # label is the current char data[i] and input is the previous char data[i-1]
        optimizer.zero_grad()
        hidden, pred = model(data[i-1], None)
        loss = criterion(pred, data[i])
        
        correct += np.argmax(pred.tolist())==np.argmax(data[i].tolist())

        # grad. descent
        loss.backward()
        optimizer.step()
        
        if i%10000==0:
            print(f"accuracy={correct/10000.0}")
            correct = 0
        

In [131]:
epochs = 1
for t in range(epochs):
    print(f"Epoch {t} -----------------------------------")
    train(model, criterion, optimizer, x_encoded)

Epoch 0 -----------------------------------
accuracy=0.2845
accuracy=0.2917
accuracy=0.2945
accuracy=0.2763
accuracy=0.2854
accuracy=0.2825
accuracy=0.2704
accuracy=0.2697
accuracy=0.2868


# Generate Text from Trained Model

In [134]:
# switch to eval mode here to prevent gradients from being calculated
model.eval()

# extra characters to make
generate_size = 100

# keep running
text = x_encoded[2]
text = text.reshape(1, 61)

hidden = None
for i in range(generate_size):
    hidden, next_chr = model(text[i], hidden)
    text = torch.cat((text, next_chr.reshape(1, next_chr.shape[0])), dim=0)

text_real = [int2char[np.argmax(text[i].tolist())] for i in range(len(text))]

In [137]:
print(''.join(text_real))

E d   nene  nene e ene e ene e ene e e e e e e e e ..................................................
