In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

%matplotlib inline

In [2]:
with open("anna.txt", "r") as f:
  text = f.read()

In [3]:
text[:100]

'Chapter 1\n\n\nHappy families are all alike; every unhappy family is unhappy in its own\nway.\n\nEverythin'

In [4]:
chars = tuple(set(text))

int2char = dict(enumerate(chars))

char2int = {ch: ii for ii, ch in int2char.items()}

encoded = np.array([char2int[ch] for ch in text])

In [5]:
encoded[:100]

array([81,  4, 12, 25, 41, 32, 66, 20,  1, 45, 45, 45, 80, 12, 25, 25, 74,
       20, 82, 12, 71, 61, 56, 61, 32, 28, 20, 12, 66, 32, 20, 12, 56, 56,
       20, 12, 56, 61, 76, 32,  6, 20, 32, 24, 32, 66, 74, 20,  2,  3,  4,
       12, 25, 25, 74, 20, 82, 12, 71, 61, 56, 74, 20, 61, 28, 20,  2,  3,
        4, 12, 25, 25, 74, 20, 61,  3, 20, 61, 41, 28, 20, 38, 50,  3, 45,
       50, 12, 74, 55, 45, 45, 21, 24, 32, 66, 74, 41,  4, 61,  3])

In [6]:
def one_hot_encode(arr, n_labels):
  one_hot = np.zeros((arr.size, n_labels), dtype = np.float32)

  one_hot[np.arange(one_hot.shape[0]), arr.flatten()] = 1

  one_hot = one_hot.reshape((*arr.shape, n_labels))

  return one_hot

In [7]:
test = np.array([[2, 3, 6]])

print(one_hot_encode(test, 8))

[[[0. 0. 1. 0. 0. 0. 0. 0.]
  [0. 0. 0. 1. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 1. 0.]]]


In [8]:
def get_batches(arr, batch_size, seq_length):
  total_batch_size = batch_size * seq_length

  n_batches = len(arr)//total_batch_size

  arr = arr[:n_batches*total_batch_size]

  arr = arr.reshape((batch_size, -1))

  for n in range(0, arr.shape[1], seq_length):
    x = arr[:,n : n+seq_length]

    y = np.zeros_like(x)

    try:
      y[:,:-1], y[:, -1] = x[:,1:], arr[:,n+seq_length]
    except IndexError:
      y[:,:-1], y[:,-1] = x[:,1:], arr[:,0]
    yield x, y

In [9]:
batches = get_batches(encoded, 8, 50)

x, y = next(batches)

In [10]:
print(f'x=> {x[:10,:10]}')
print(f'y=> {y[:10,:10]}')

x=> [[81  4 12 25 41 32 66 20  1 45]
 [28 38  3 20 41  4 12 41 20 12]
 [32  3 51 20 38 66 20 12 20 82]
 [28 20 41  4 32 20 53  4 61 32]
 [20 28 12 50 20  4 32 66 20 41]
 [53  2 28 28 61 38  3 20 12  3]
 [20 57  3  3 12 20  4 12 51 20]
 [36 65 56 38  3 28 76 74 55 20]]
y=> [[ 4 12 25 41 32 66 20  1 45 45]
 [38  3 20 41  4 12 41 20 12 41]
 [ 3 51 20 38 66 20 12 20 82 38]
 [20 41  4 32 20 53  4 61 32 82]
 [28 12 50 20  4 32 66 20 41 32]
 [ 2 28 28 61 38  3 20 12  3 51]
 [57  3  3 12 20  4 12 51 20 28]
 [65 56 38  3 28 76 74 55 20 40]]


In [11]:
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")

print(device)

cuda


In [12]:
class charRNN(nn.Module):
  def __init__(self, tokens, hidden_dim = 512, n_layers = 2, drop_prob = 0.5, lr = 0.01):
    super().__init__()

    self.hidden_dim = hidden_dim
    self.n_layers = n_layers
    self.drop_prob = drop_prob
    self.lr = lr

    self.chars = tokens
    self.int2char = dict(enumerate(self.chars))
    self.char2int = {ch:ii for ii, ch in self.int2char.items()}

    self.lstm = nn.LSTM(len(self.chars), hidden_dim, n_layers, dropout=drop_prob, batch_first = True)

    self.dropout = nn.Dropout(p=drop_prob)

    self.fc = nn.Linear(hidden_dim, len(self.chars))
  
  def forward(self, x, hidden):

    output, hidden = self.lstm(x, hidden)

    output = self.dropout(output)

    output = output.contiguous().view(-1,self.hidden_dim)

    output = self.fc(output)

    return output, hidden
  
  def init_hidden(self, batch_size):
    weight = next(self.parameters()).data

    hidden = (weight.new(self.n_layers, batch_size, self.hidden_dim).zero_().to(device),
              weight.new(self.n_layers, batch_size, self.hidden_dim).zero_().to(device))

    return hidden

In [13]:
def train(net, data, epochs = 10, batch_size = 10, seq_length = 50, lr = 0.001, clip = 5, val_frac = 0.1, for_every = 10):
  net.train()

  opt = torch.optim.Adam(net.parameters(), lr = lr)
  criterion = torch.nn.CrossEntropyLoss()

  val_idx = int((1-val_frac)*len(data))
  data, val_data = data[:val_idx], data[val_idx:]

  net.to(device)

  counter = 0
  n_chars = len(net.chars)

  for epoch in range(epochs):
    h = net.init_hidden(batch_size)

    for x, y in get_batches(data, batch_size, seq_length):
      counter += 1

      x = one_hot_encode(x, n_chars)

      inputs, targets = torch.from_numpy(x), torch.from_numpy(y)

      h = tuple([each.data for each in h])

      inputs = inputs.to(device)
      targets = targets.to(device)

      outputs, h = net(inputs, h)

      loss = criterion(outputs, targets.view(batch_size*seq_length).long())
      opt.zero_grad()
      loss.backward()
      nn.utils.clip_grad_norm_(net.parameters(), clip)
      opt.step()

      if counter%for_every == 0:
        val_h = net.init_hidden(batch_size)
        val_losses = []
        net.eval()
        for x, y in get_batches(val_data, batch_size, seq_length):
          x = one_hot_encode(x, n_chars)
          inputs, targets = torch.from_numpy(x).to(device), torch.from_numpy(y).to(device)
          val_h = tuple([each.data for each in val_h])
          outputs, val_h = net(inputs, val_h)
          val_loss = criterion(outputs, targets.view(batch_size*seq_length).long())
          val_losses.append(val_loss.item())
      
        net.train()

        print(f"{epoch+1}/{epochs}, counter = {counter}, train_loss = {loss.item()}, val_loss = {np.mean(val_losses)}")


In [14]:
hidden_dim = 512
n_layers = 2

# print(hidden_dim)
net = charRNN(chars,hidden_dim, n_layers)
print(net)

charRNN(
  (lstm): LSTM(83, 512, num_layers=2, batch_first=True, dropout=0.5)
  (dropout): Dropout(p=0.5, inplace=False)
  (fc): Linear(in_features=512, out_features=83, bias=True)
)


In [15]:
batch_size = 128
seq_length = 100
n_epochs = 20

train(net, encoded, epochs = n_epochs, batch_size=batch_size, seq_length=seq_length, lr = 0.001)

1/20, counter = 10, train_loss = 3.2600297927856445, val_loss = 3.2011275768280028
1/20, counter = 20, train_loss = 3.1458301544189453, val_loss = 3.1305514494578044
1/20, counter = 30, train_loss = 3.1346373558044434, val_loss = 3.1215799331665037
1/20, counter = 40, train_loss = 3.1121761798858643, val_loss = 3.1190361340840655
1/20, counter = 50, train_loss = 3.1393866539001465, val_loss = 3.117169189453125
1/20, counter = 60, train_loss = 3.1170217990875244, val_loss = 3.1149874210357664
1/20, counter = 70, train_loss = 3.106755256652832, val_loss = 3.112496280670166
1/20, counter = 80, train_loss = 3.116550922393799, val_loss = 3.1056500752766927
1/20, counter = 90, train_loss = 3.1061971187591553, val_loss = 3.090000947316488
1/20, counter = 100, train_loss = 3.066688299179077, val_loss = 3.0538565158843993
1/20, counter = 110, train_loss = 3.0306363105773926, val_loss = 3.010475524266561
1/20, counter = 120, train_loss = 2.97794246673584, val_loss = 2.9088003158569338
1/20, coun

In [16]:
model_name = "rnn_20_epoch_charRNN.net"

checkpoint = {
    'hidden_dim':net.hidden_dim,
    'n_layers':net.n_layers,
    'state_dict':net.state_dict(),
    'tokens':net.chars
}

In [17]:
with open(model_name, 'wb') as f:
  torch.save(checkpoint, f)

In [18]:
model = torch.load(model_name)

In [22]:
def predict(net, char, h = None, top_k = None):
  net.to(device)

  h = tuple([each.data for each in h])

  inp = np.array([[net.char2int[char]]])
  inp = one_hot_encode(inp, len(net.chars))

  inp = torch.from_numpy(inp)

  inp = inp.to(device)

  
  out, h = net(inp, h)
  
  p = torch.nn.functional.softmax(out, dim=1).data

  p = p.cpu()

  if top_k is None:
    top_ch = np.arange(len(net.chars))
  else:
    p, top_ch = p.topk(top_k)
    top_ch = top_ch.numpy().squeeze()
  
  p = p.numpy().squeeze()
  char = np.random.choice(top_ch, p = p/p.sum())

  return net.int2char[char], h

In [23]:
def sample(net, size, prime = "The", top_k = None):
  net.to(device)

  chars = [ch for ch in prime]

  net.eval()
  
  h = net.init_hidden(1)

  for ch in prime:
    char, h = predict(net, ch, h, top_k=top_k)
  
  chars.append(char)

  for ii in range(size):
    char, h = predict(net, chars[-1], h, top_k=top_k)
    chars.append(char)
  
  return ''.join(chars)

In [24]:
print(sample(net, 500, prime = "Anna", top_k = 5))

Anna's, that
there was nothing to go to him. The sight of anything she could not come on
all her ten thing, but the setsed conversation of the
candle, then talking on the crowd as the money them had been a
good-nut one to him, and they saw that his bride stour
she saw the sound and supple arressing the priest, who had
been busines, and
suppositively was not seeming, and he had standed out the sight and
things, he felt her husband to see his hands were the second
condition, his face too hade and ask i


In [25]:
print(sample(net, 500, prime = "Vathsa", top_k = 5))

Vathsa," said
Levin.

"Well, the chances of your face of the sort of terrible arouses in
his fruchlin, though, to that is they are, what I was standing and
tensters as the minutes, was in the state, but the most mosers for the
sour of her fact it are more into them, was a little again."

Alexey Alexandrovitch smedled as his face, was as how to be
service. Her hands almost come at the better of the story, was
as it all the prince. He can off on his brother, he felt that the princess
saw the same thing w


In [27]:
with open("rnn_20_epoch_charRNN.net", "rb") as f:
  checkpoint = torch.load(f)

loaded = charRNN(checkpoint["tokens"] ,hidden_dim = checkpoint["hidden_dim"], n_layers = checkpoint["n_layers"])
loaded.load_state_dict(checkpoint["state_dict"])

<All keys matched successfully>

In [29]:
print(sample(loaded, 500, 'Dara', top_k = 5))

Daral
Stepan Arkadyevitch, but at him till she had taken a strange time, and the
counting mare was at that moment with his way all that. The dropped he
had taken under the department with her. She stopped him. They
could not hear some starches with a legter. He could not have to say in
speaking, and their conversation, and threw happiness, and
his face was to tried. Still he had been discassed on the corn in
the marsh of them, and and she was daightering that the
servants of the morness of their posi
