<a href="https://colab.research.google.com/github/tg-93/min-char-rnn/blob/main/semi_linear_rnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt # for making figures
%matplotlib inline

data = open('wot1.txt', 'r').read() # should be simple plain text file
chars = sorted(list(set(data)))
data_size, vocab_size = len(data), len(chars)
print(f'data has {data_size} characters, {vocab_size} unique.')
char_to_ix = { ch:i for i,ch in enumerate(chars) }
ix_to_char = { i:ch for i,ch in enumerate(chars) }
print(ix_to_char)

data has 1630939 characters, 76 unique.
{0: '\n', 1: ' ', 2: '!', 3: '"', 4: '$', 5: "'", 6: '(', 7: ')', 8: ',', 9: '-', 10: '.', 11: '0', 12: '1', 13: '2', 14: '3', 15: '4', 16: '5', 17: '6', 18: '7', 19: '8', 20: '9', 21: ':', 22: ';', 23: '?', 24: 'A', 25: 'B', 26: 'C', 27: 'D', 28: 'E', 29: 'F', 30: 'G', 31: 'H', 32: 'I', 33: 'J', 34: 'K', 35: 'L', 36: 'M', 37: 'N', 38: 'O', 39: 'P', 40: 'Q', 41: 'R', 42: 'S', 43: 'T', 44: 'U', 45: 'V', 46: 'W', 47: 'Y', 48: 'Z', 49: '`', 50: 'a', 51: 'b', 52: 'c', 53: 'd', 54: 'e', 55: 'f', 56: 'g', 57: 'h', 58: 'i', 59: 'j', 60: 'k', 61: 'l', 62: 'm', 63: 'n', 64: 'o', 65: 'p', 66: 'q', 67: 'r', 68: 's', 69: 't', 70: 'u', 71: 'v', 72: 'w', 73: 'x', 74: 'y', 75: 'z'}


In [2]:
class Linear:
  def __init__(self, fan_in, fan_out, device='cpu'):
    self.w = torch.randn(fan_in, fan_out, device=device) / fan_in**0.5
    self.b = torch.zeros(1, fan_out, requires_grad=True, device = device)
    self.w.requires_grad = True

  def __call__(self, input):
    return input @ self.w + self.b

  def sample(self, input):
    return self.__call__(input)

  def params(self):
    return [self.w, self.b]

  def is_stateful(self):
    return False

  def __str__(self):
    return "Linear: {}".format(self.w.shape)

class Tanh:
  def __call__(self, input):
    self.out = torch.tanh(input)
    return self.out

  def sample(self, input):
    return self.__call__(input)

  def params(self):
    return []

  def is_stateful(self):
    return False

  def __str__(self):
    return "Tanh"

In [3]:
class VanillaRNN:
  def __init__(self, fan_in, hidden_size, batch_size, device='cpu', linear_time=False):
    self.hidden_size = hidden_size
    self.linear_time = linear_time
    self.hidden = torch.zeros(batch_size, hidden_size, requires_grad=True, device=device)
    self.lin = Linear(fan_in + hidden_size, hidden_size, device=device)
    self.tanh = Tanh()

  def is_stateful(self):
    return True

  def __call__(self, input, sample=False):
    xh = torch.hstack((input, self.hidden))
    if self.linear_time:
      self.hidden = self.lin(xh) # no tanh on feedforward through time
      return self.tanh(self.hidden)
    self.hidden = self.tanh(self.lin(xh))
    return self.hidden

  def sample(self, input):
    with torch.no_grad():
      if self.hidden_sample is None:
        self.hidden_sample = self.hidden[0,:].view(1, self.hidden_size)
      xh = torch.hstack((input, self.hidden_sample))
      if self.linear_time:
        self.hidden_sample = self.lin(xh)
        return self.tanh(self.hidden_sample)
      self.hidden_sample = self.tanh(self.lin(xh))
      return self.hidden_sample

  def reset(self):
    self.hidden = torch.zeros_like(self.hidden)
    self.hidden_sample = None

  def reset_grads(self):
    self.hidden.detach_()

  def params(self):
    return self.lin.params()

  def __str__(self):
    return "VanillaRNN: {}".format([p.shape for p in self.params()]) if not self.linear_time else "TimeLinearRNN: {}".format([p.shape for p in self.params()])

class LinearRNN:
  def __init__(self, fan_in, hidden_size, batch_size, device='cpu'):
    self.hidden_size = hidden_size
    self.hidden = torch.zeros(batch_size, hidden_size, requires_grad=True, device=device)
    self.lin = Linear(fan_in + hidden_size, hidden_size, device=device)
    self.tanh = Tanh() # this is stored to access the outputs stored inside it for analysis

  def __call__(self, input, sample=False):
    xh = torch.hstack((input, self.hidden))
    self.hidden = self.lin(xh) # no tanh on feedforward through time
    return self.tanh(self.hidden)

  def sample(self, input):
    with torch.no_grad():
      if self.hidden_sample is None:
        self.hidden_sample = self.hidden[0,:].view(1, self.hidden_size)
      xh = torch.hstack((input, self.hidden_sample))
      self.hidden_sample = self.lin(xh)
      return self.tanh(self.hidden_sample)

  def reset(self):
    self.hidden = torch.zeros_like(self.hidden)
    self.hidden_sample = None

  def reset_grads(self):
    self.hidden.detach_()

  def params(self):
    return self.lin.params()

  def is_stateful(self):
    return True

  def __str__(self):
    return "LinearRNN: {}".format([p.shape for p in self.params()])

class Sequential:
  def __init__(self, vocab_size, layers, device='cpu'):
    self.vocab_size = vocab_size
    self.layers = layers
    self.device = device
    params = []
    for layer in layers:
      params += layer.params()
    self.params = params
    count = sum([p.nelement() for p in self.params])
    print(self)
    print(f'Parameter Count: {count}')

  def __call__(self, inputs):
    logits = torch.zeros(inputs.shape[0], inputs.shape[1], self.vocab_size, dtype=torch.float32, device=self.device)
    for t in range(inputs.shape[1]):
      x = F.one_hot(inputs[:,t], self.vocab_size).float()
      for layer in self.layers:
        x = layer(x)
      logits[:, t] = x
    self.reset()
    return logits

  def sample(self, input, n):
    self.reset()
    samples = [input]
    for t in range(n):
      x = F.one_hot(samples[-1], self.vocab_size).float().view(1, self.vocab_size)
      for layer in self.layers:
        x = layer.sample(x)
      probs = torch.softmax(x, dim=1)
      samples.append(torch.multinomial(probs, 1)[0])
    return samples

  def reset(self):
    for layer in self.layers:
      if layer.is_stateful():
        layer.reset()

  def __str__(self):
    s = "Sequential:\n"
    for layer in self.layers:
      s += str(layer) + "\n"
    return s

In [4]:
batch_size = 64
model = Sequential(vocab_size,
 [VanillaRNN(vocab_size, 500, batch_size, device='cuda', linear_time=True),
  VanillaRNN(500, 400, batch_size, device='cuda'),
  Linear(400, vocab_size, device='cuda')], device='cuda')

Sequential:
VanillaRNN: [torch.Size([576, 500]), torch.Size([1, 500])]
VanillaRNN: [torch.Size([900, 400]), torch.Size([1, 400])]
Linear: torch.Size([400, 76])

Parameter Count: 679376


In [5]:
smooth_loss = -torch.log(torch.tensor(1.0/vocab_size, device='cuda')) # loss at iteration 0
batch_gap = int(len(data)//batch_size)
seq_length = 64
n = 0
optimizer = torch.optim.Adam(model.params)
input_end = len(data) - seq_length
for n in range(50000):
  batch_starts = np.random.randint(0, input_end, batch_size)
  inputs = torch.tensor([[char_to_ix[ch] for ch in data[b_i : b_i + seq_length]] for b_i in batch_starts], device='cuda')
  targets = torch.tensor([[char_to_ix[ch] for ch in data[b_i + 1 : b_i + seq_length + 1]] for b_i in batch_starts], device='cuda')
  # sample from the model now and then
  if n>0 and n%200 == 0:
    samples = model.sample(inputs[0, 0].detach(), 200)
    # print(samples)
    txt = ''.join(ix_to_char[ix.item()] for ix in samples)
    print('*** Sample: ***')
    print(f'----\n{txt}\n----')
  # forward seq_length characters through the net and fetch gradient
  logits = model(inputs)
  loss = torch.tensor(0.0, device='cuda')
  for b_i in range(batch_size):
    loss += F.cross_entropy(logits[b_i], targets[b_i])
  loss /= batch_size
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  smooth_loss = smooth_loss * 0.995 + loss * 0.005
  if n>0 and n % 50 == 0:
    print(f'iter: {n}, loss: {smooth_loss.item()}') # print progress
    # plt.figure(figsize=(20, 4)) # width and height of the plot
    # legends = []
    for i, layer in enumerate(model.layers[:-1]): # note: exclude the output layer
      if isinstance(layer, LinearRNN) or isinstance(layer, VanillaRNN):
        t = layer.tanh.out
        print('layer %d (%10s): mean %+.2f, std %.2f, saturated: %.2f%%' % (i, layer.__class__.__name__, t.mean(), t.std(), (t.abs() > 0.97).float().mean()*100))
    #     hy, hx = torch.histogram(t, density=True)
    #     plt.plot(hx[:-1].detach(), hy.detach())
    #     legends.append(f'layer {i} ({layer.__class__.__name__}')
    # plt.legend(legends);
    # plt.title('activation distribution')

iter: 50, loss: 4.102715492248535
layer 0 (VanillaRNN): mean +0.01, std 1.00, saturated: 100.00%
layer 1 (VanillaRNN): mean +0.06, std 0.95, saturated: 77.07%
iter: 100, loss: 3.8804380893707275
layer 0 (VanillaRNN): mean +0.01, std 1.00, saturated: 100.00%
layer 1 (VanillaRNN): mean +0.06, std 0.96, saturated: 79.84%
iter: 150, loss: 3.7070417404174805
layer 0 (VanillaRNN): mean +0.02, std 1.00, saturated: 100.00%
layer 1 (VanillaRNN): mean +0.06, std 0.97, saturated: 81.11%


RuntimeError: probability tensor contains either `inf`, `nan` or element < 0