In [27]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical

In [28]:
# Corpus location, relative to the notebook's working directory.
path_to_file = 'shakespeare.txt'

In [29]:
# Read raw bytes and decode explicitly as UTF-8 so the result does not depend on the
# platform's default text encoding, and newlines are kept exactly as stored.
# The `with` block guarantees the file handle is closed.
with open(path_to_file, 'rb') as corpus_file:
  text = corpus_file.read().decode(encoding='utf-8')
# length of text is the number of characters in it
print('Length of text: {} characters'.format(len(text)))

Length of text: 1115394 characters


In [30]:
# Peek at the first 400 characters of the corpus.
print(text[0:400])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it 


In [31]:
# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [32]:
# Distinct characters occurring in the corpus (an unordered set).
vocab = {*text}

In [33]:
# Character <-> index lookup tables.
# Sort before enumerating: iteration order of a set of str depends on the per-process
# string hash seed (PYTHONHASHSEED), so enumerating `vocab` directly would assign
# different indices on every kernel restart and silently invalidate saved weights.
char2idx = {ch: i for i, ch in enumerate(sorted(vocab))}
idx2char = dict(enumerate(sorted(vocab)))

In [37]:
class MyModel(nn.Module):
  """Character-level GRU language model: embedding -> stacked GRU -> linear decoder.

  Args:
    input_size: number of embedding rows (the vocabulary size).
    hidden_size: width of the GRU hidden state.
    output_size: number of logits produced per time step.
    embedding_dim: width of each embedding vector. Defaults to `input_size`,
      which reproduces the original hard-coded choice.
    num_layers: number of stacked GRU layers (default 2, as originally hard-coded).
  """
  def __init__(self, input_size, hidden_size, output_size, embedding_dim=None, num_layers=2):
    super(MyModel, self).__init__()
    if embedding_dim is None:
      embedding_dim = input_size
    self.input_size = input_size
    self.hidden_size = hidden_size
    self.output_size = output_size
    self.embedding_dim = embedding_dim
    self.num_layers = num_layers
    self.embedding = nn.Embedding(input_size, embedding_dim)
    # Default batch_first=False: the GRU expects (seq_len, batch, features).
    self.rnn = nn.GRU(input_size=embedding_dim, hidden_size=hidden_size, num_layers=num_layers)
    self.decoder = nn.Linear(hidden_size, output_size)

  def forward(self, x, hidden):
    """Run one window through the model.

    Args:
      x: LongTensor of token indices, shape (seq_len, batch).
      hidden: previous GRU hidden state of shape (num_layers, batch, hidden_size),
        or None to start from zeros.

    Returns:
      (logits, h_n): logits of shape (seq_len, batch, output_size), and the final
      hidden state detached from the autograd graph.
    """
    embedding = self.embedding(x)
    outputs, h_n = self.rnn(embedding, hidden)
    outputs = self.decoder(outputs)
    # Detaching lets callers carry the state into the next window without
    # backpropagating through earlier windows (truncated BPTT).
    return outputs, h_n.detach()

In [38]:
# Model / training hyperparameters.
vocab_size = len(vocab)
input_size = output_size = vocab_size  # indices in, logits out, over the same vocabulary
hidden_size = 512
learning_rate = 0.0001
seq_length = 100  # characters per training window
text_size = len(text)
EPOCHS = 15

In [39]:
# Illustrative (input, target) pair as plain index lists: the target is the input
# shifted one character ahead. Both names are reassigned in later cells.
inputs = list(map(char2idx.__getitem__, text[:seq_length]))
outputs = list(map(char2idx.__getitem__, text[1:seq_length + 1]))

In [41]:
# Encode the whole corpus once as a (num_chars, 1) column of int64 indices on `device`.
enc_text = torch.tensor([char2idx[c] for c in text], dtype=torch.long).to(device).unsqueeze(1)
enc_text.shape

torch.Size([1115394, 1])

In [42]:
# Shape check on the first input window (seq_length rows, one column).
inputs = enc_text[0:seq_length]
inputs.shape

torch.Size([100, 1])

In [None]:
# Build the model on the target device, together with its loss and optimizer.
model = MyModel(input_size=input_size, hidden_size=hidden_size, output_size=output_size).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()

In [47]:
# NOTE(review): this header originally read "LEARNING RATE = 0.0005", but the optimizer
# is built from `learning_rate` (0.0001 in the config cell) — confirm which value
# produced the output below.
for epoch in range(EPOCHS):
  # Training pass: walk the corpus in consecutive, non-overlapping seq_length windows,
  # carrying the (detached) hidden state from one window into the next.
  model.train()
  p = 0
  total_loss = 0
  hidden = None
  cnt = 0
  # `<=` so the last window is used whenever its shifted target is still full length.
  while p + seq_length + 1 <= text_size:
    inputs = enc_text[p:p+seq_length]
    targets = enc_text[p+1:p+seq_length+1]  # next-character targets
    outputs, hidden = model(inputs, hidden)
    loss = loss_fn(torch.squeeze(outputs), torch.squeeze(targets))
    total_loss += loss.item()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    cnt += 1
    p += seq_length
    if cnt % 500 == 0:
      print('|', end='')  # progress tick every 500 windows
  print('\n--------------------------')
  # max(cnt, 1) avoids ZeroDivisionError when the corpus is shorter than one window.
  print("Epoch : {} Epoch_loss : {:5f}".format(epoch+1, total_loss / max(cnt, 1)))
  total_loss = 0
  cnt = 0

  # Sampling: generate 200 characters starting from a random corpus character.
  model.eval()
  with torch.no_grad():
    rand_index = np.random.randint(text_size-1)
    # clone(): the slice is a view into enc_text, and the in-place write below would
    # otherwise overwrite a character of the training corpus every epoch.
    input_seq = enc_text[rand_index:rand_index+1].clone()
    print(idx2char[input_seq[0][0].item()], end='')
    hidden = None
    for step in range(200):
      outputs, hidden = model(input_seq, hidden)
      probs = F.softmax(torch.squeeze(outputs), dim=0)
      index = Categorical(probs).sample()
      input_seq[0][0] = index.item()  # feed the sampled character back in
      print(idx2char[index.item()], end='')
    print("\n-----------------------")

||||||||||||||||||||||
--------------------------
Epoch : 1 Epoch_loss : 1.376428
thy,
Unsure gentlemony toward last cause,
Should feable again'd gave time it a strepl
To our weets, fouches the dwelter minuble and
The lightsable. When, trelow,
The unlikent after men! hear, kemire,
W
-----------------------
||||||||||||||||||||||
--------------------------
Epoch : 2 Epoch_loss : 1.308181
ozQ
AN wise, was a words perfactionloban!

ANTONIO:
No, sir, my widowing.

Lard Katharine:
Is tending the biek of wancentate,
Than old has a little dlumbering frem it.

GONZALO:
So, in Byich, Good Trum
-----------------------
||||||||||||||||||||||
--------------------------
Epoch : 3 Epoch_loss : 1.260548
ent.

ALONSO:
What was not she hath like; fie, thou hast vent there weak;
now, pristress in youry-qound entreater, both the death:
Who would be dolination build thy lights be blind.

MENENIUS:
Nothing 
-----------------------
||||||||||||||||||||||
--------------------------
Epoch : 4 Epoch_loss : 1.2

In [48]:
# LEARNING RATE = 0.0001
# Continues training the existing `model` / `optimizer` (neither is re-created here).
for epoch in range(EPOCHS):
  # Training pass: walk the corpus in consecutive, non-overlapping seq_length windows,
  # carrying the (detached) hidden state from one window into the next.
  model.train()
  p = 0
  total_loss = 0
  hidden = None
  cnt = 0
  # `<=` so the last window is used whenever its shifted target is still full length.
  while p + seq_length + 1 <= text_size:
    inputs = enc_text[p:p+seq_length]
    targets = enc_text[p+1:p+seq_length+1]  # next-character targets
    outputs, hidden = model(inputs, hidden)
    loss = loss_fn(torch.squeeze(outputs), torch.squeeze(targets))
    total_loss += loss.item()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    cnt += 1
    p += seq_length
    if cnt % 500 == 0:
      print('|', end='')  # progress tick every 500 windows
  print('\n--------------------------')
  # max(cnt, 1) avoids ZeroDivisionError when the corpus is shorter than one window.
  print("Epoch : {} Epoch_loss : {:5f}".format(epoch+1, total_loss / max(cnt, 1)))
  total_loss = 0
  cnt = 0

  # Sampling: generate 300 characters starting from a random corpus character.
  model.eval()
  with torch.no_grad():
    rand_index = np.random.randint(text_size-1)
    # clone(): the slice is a view into enc_text, and the in-place write below would
    # otherwise overwrite a character of the training corpus every epoch.
    input_seq = enc_text[rand_index:rand_index+1].clone()
    print(idx2char[input_seq[0][0].item()], end='')
    hidden = None
    for step in range(300):
      outputs, hidden = model(input_seq, hidden)
      probs = F.softmax(torch.squeeze(outputs), dim=0)
      index = Categorical(probs).sample()
      input_seq[0][0] = index.item()  # feed the sampled character back in
      print(idx2char[index.item()], end='')
    print("\n-----------------------")

||||||||||||||||||||||
--------------------------
Epoch : 1 Epoch_loss : 0.890426
NGEL:
To-morrow.

ANTONIO:
Why thence pitch--home?

GONZALO:
It is foul postern built what I bones on me;
Nay, if thou fond pursues truly.

VINCENTIO:
Content thee.

GONZALO:
Who in love of is?

ANTONIO:
Stop the harvest freek, on Thomas that worthy spent
Is mock'd with a little; and would say, if st
-----------------------
||||||||||||||||||||||
--------------------------
Epoch : 2 Epoch_loss : 0.864858
ike:
He dropp'd upon my state, not again.

HORTENSIO:
Come; I fear the storm. Would I make money?

ANTONIO:
Faith, dear madam.

GONZALO:
It was against us.

GONZALO:
And she have both entreat me horse; on my hand
In sweet wagunation. I
affer this wide of more kind
In your instructions for Angelo;
I m
-----------------------
||||||||||||||||||||||
--------------------------
Epoch : 3 Epoch_loss : 0.839790
t,
O, I am put time to cross to soo!
So fain!

SEBASTIAN:
Why, away, toward Signior Baptista.

GONZALO

In [49]:
total_loss = 0
cnt = 0
# Generate 800 characters from the trained model, seeded with a random corpus character.
model.eval()
with torch.no_grad():
  rand_index = np.random.randint(text_size-1)
  # clone(): the slice is a view into enc_text; without it the in-place write below
  # would overwrite a character of the training corpus.
  input_seq = enc_text[rand_index:rand_index+1].clone()
  print(idx2char[input_seq[0][0].item()], end='')
  hidden = None
  for step in range(800):
    outputs, hidden = model(input_seq, hidden)
    probs = F.softmax(torch.squeeze(outputs), dim=0)
    index = Categorical(probs).sample()
    input_seq[0][0] = index.item()  # feed the sampled character back in
    print(idx2char[index.item()], end='')
  print("\n-----------------------")

eds:
That Claudio, Signior Capitoo,
Whom I ever sent for shame; for I have spoke,
To be thy need to make them obeys, good daughter!
O, how peaches it to do so! My father hath
stay; the frosting fury seal o' the sister.

COMINIUS:
Hear me, my drink.

SEBASTIAN:
Sir, but he's for the fight!

PETRUCHIO:
Go far on; marry, speak take, and then go with me;
Nor I am coming that you made, my life
Uncur to choose away our neck orderers;
The false Binnot dry your eye, but indeed
Name hath stifl us.

ADRIAN:
And most comfort's.

ANTONIO:
Not possible.

Boatswain:
Would I make mercy my sister with honey-gabse-winged
what when it would contrary at thy abject son.

GONZALO:
I would be very wrong, you think.

OnOLUME:
Should by this shame.
How now! what makes you?

ARIEL:
So long as violent for Tranio in 
-----------------------


In [55]:
    # Persist only the learned weights; the char2idx/idx2char tables are not saved here,
    # so reloading requires rebuilding them identically.
    # NOTE(review): this cell is indented 4 spaces at top level; IPython strips the leading
    # indent, but plain Python (e.g. a script export) would raise IndentationError.
    torch.save(obj=model.state_dict(), f='model.pt')