In [87]:
from tinygrad.tensor import Tensor
from tinygrad.nn import Linear
from extra.utils import print_tree
from tinygrad.nn.optim import Adam
from tinygrad.jit import TinyJit
from tinygrad.helpers import dtypes
from tinygrad.shape.shapetracker import ShapeTracker
import numpy as np

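## Data

Load the training corpus, build a character vocabulary from it, and define one-hot encode/decode helpers.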

In [2]:
# Read the whole training corpus as one string. The encoding is pinned so the
# vocabulary built below does not depend on the platform's default encoding
# (e.g. cp1252 on Windows vs UTF-8 elsewhere).
with open("./training.txt", encoding="utf-8") as f:
    data = f.read()

In [3]:
# Character vocabulary.
#  - index 0 ("") is the padding symbol: encode_text leaves padded positions as
#    all-zero rows, argmax of an all-zero row is 0, so padding decodes to "".
#  - index 1 is the token used for characters not seen in the corpus.
# sorted() makes the char -> index mapping identical across kernel restarts;
# iterating a set of str follows hash order, which is randomized per process.
chars = ["", "<unk>"] + sorted(set(data))
char2idx = {c: i for i, c in enumerate(chars)}
nchars = len(chars)

In [4]:
def encode_text(*text: str) -> Tensor:
    """One-hot encode one or more strings into a float32 Tensor.

    Args:
        *text: the strings to encode, one batch row each.

    Returns:
        Tensor of shape (len(text), longest string length, nchars). Row ``i``,
        position ``j`` has a 1 at ``char2idx[text[i][j]]``, or at the
        ``<unk>`` index for characters outside the vocabulary. Positions past
        the end of a shorter string stay all-zero (padding).

    Raises:
        ValueError: if called with no strings.
    """
    if not text:
        raise ValueError("encode_text needs at least one string")
    unk = char2idx["<unk>"]
    maxlen = max(len(t) for t in text)
    # allocate float32 directly instead of a float64 array that is then cast
    X = np.zeros((len(text), maxlen, nchars), dtype=np.float32)
    for i, line in enumerate(text):
        for j, char in enumerate(line):
            X[i, j, char2idx.get(char, unk)] = 1
    return Tensor(X, dtype=dtypes.float)

def decode_text(X:Tensor):
    """Turn one-hot (or probability) rows back into strings.

    Inverse of ``encode_text``. Assumes X is (batch, seq, nchars) — TODO
    confirm for other callers. For every position the argmax of its vector
    picks a vocabulary entry. All-zero padding rows argmax to index 0, the
    empty string, so padding disappears from the decoded text.
    """
    decoded = []
    for row in X.numpy():
        picked = (chars[vec.argmax()] for vec in row)
        decoded.append("".join(picked))
    return decoded

# Round-trip check of the codec. The expected value is derived from the
# vocabulary, so the check holds for any training.txt:
#  - in-vocabulary characters must survive;
#  - anything else must come back as "<unk>";
#  - the padding that stretches "hello" to the length of "world€€€" must decode to nothing.
_samples = ("hello", "world€€€")
_expected = ["".join(c if c in char2idx else "<unk>" for c in s) for s in _samples]
_roundtrip = decode_text(encode_text(*_samples))
assert _roundtrip == _expected, (_roundtrip, _expected)

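## Model

A single-layer recurrent network over one-hot characters. Each step combines the (layer-normalised) previous hidden state with the current input, then emits a softmax over the vocabulary.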

In [128]:
class RNN:
    """Parameters and hidden state of a single-layer recurrent network.

    The step computation lives in the separate ``forward`` function; this
    class only owns the layers and the running state.

    Attributes:
        hh: hidden -> hidden projection.
        xh: input -> hidden projection.
        hy: hidden -> output projection.
        h: current hidden state, shape (1, hidden_size), starts as zeros.
        layers: the three Linear layers, in construction order.
        params: weight and bias of every layer, flattened into one list.
    """

    def __init__(self, input_size, hidden_size, output_size):
        # Keep the construction order hh, xh, hy: Linear presumably draws
        # random initial weights, so reordering would change which values
        # each layer receives.
        self.hh = Linear(hidden_size, hidden_size)
        self.xh = Linear(input_size, hidden_size)
        self.hy = Linear(hidden_size, output_size)
        self.h = Tensor.zeros(1, hidden_size)
        self.layers = [self.hh, self.xh, self.hy]
        self.params = []
        for layer in self.layers:
            self.params.extend((layer.weight, layer.bias))


def forward(rnn: RNN, x: Tensor, h: Tensor):
    """Advance the RNN by one step.

    Args:
        rnn: model supplying the ``hh``, ``xh`` and ``hy`` layers.
        x: current input; ``call`` passes a (1, nchars) one-hot row.
        h: previous hidden state.

    Returns:
        ``(probs, h_next)``. ``h_next = relu(hh(layernorm(h)) + xh(x))`` and
        ``probs = softmax(hy(h_next))``; both are realized before returning.
    """
    recurrent = rnn.hh(h.layernorm())
    projected = rnn.xh(x)
    h_next = (recurrent + projected).relu().realize()
    probs = rnn.hy(h_next).softmax().realize()
    return probs, h_next

def call(rnn, X:Tensor):
    """Run ``rnn`` over every sequence of a one-hot batch; return per-step outputs.

    Args:
        rnn: the model. ``rnn.h`` is reset to zeros at the start of each
            sequence. Afterwards it holds the final state of the last sequence.
        X: batch of shape (batch, seq_len, nchars), e.g. from ``encode_text``.

    Returns:
        Tensor of shape (batch, seq_len, output size) holding the softmax
        output of every step.
    """
    assert X.shape[2] == nchars, f"expected {nchars} features per step, got {X.shape[2]}"
    # Copy the inputs to host once and build a fresh (1, nchars) Tensor per
    # step, instead of reaching into lazydata internals to detach each slice.
    X_np = X.numpy()
    res = []
    for seq_np in X_np:
        rnn.h = Tensor.zeros(rnn.h.shape)
        seq = []
        for step in seq_np:
            # forward is called directly, not through TinyJit. This relies on
            # TinyJit internals not visible here — confirm against the
            # installed tinygrad version. It appears to return the same output
            # Tensors on every replay and overwrite their buffers; if so, every
            # step after the first aliases one tensor holding only the last
            # step's output. Replays also record no autograd graph for training.
            p, rnn.h = forward(rnn, Tensor(step.reshape(1, -1)), rnn.h)
            seq.append(p)
        res.append(Tensor.cat(*seq, dim=0))
    return Tensor.cat(*res).reshape(X.shape[0], X.shape[1], -1)

# Width of the hidden state. Input and output sizes are both the vocabulary
# size because the model consumes and emits one value per character.
HIDDEN_SIZE = 256
rnn = RNN(nchars, HIDDEN_SIZE, nchars)


In [129]:
pred = call(rnn, encode_text("hello", "world"))
# two 5-character sequences in -> one vocabulary-sized distribution per step out
assert pred.shape == (2, 5, nchars), pred.shape
print_tree(pred)

  0 ━┳ RESHAPE (2, 5, 67)
  1  ┗━┳ ADD  
  2    ┣━┳ ADD 
  3    ┃ ┣━┳ ADD 
  4    ┃ ┃ ┣━┳ ADD 
  5    ┃ ┃ ┃ ┣━┳ ADD 
  6    ┃ ┃ ┃ ┃ ┣━┳ PAD ((0, 9), (0, 0))
  7    ┃ ┃ ┃ ┃ ┃ ┗━━ realized float (1, 67)  
  8    ┃ ┃ ┃ ┃ ┗━┳ PAD ((1, 8), (0, 0)) 
  9    ┃ ┃ ┃ ┃   ┗━━ realized float (1, 67)  
 10    ┃ ┃ ┃ ┗━┳ PAD ((2, 7), (0, 0)) 
 11    ┃ ┃ ┃   ┗━━ realized float (1, 67)  
 12    ┃ ┃ ┗━┳ PAD ((3, 6), (0, 0)) 
 13    ┃ ┃   ┗━━ realized float (1, 67)  
 14    ┃ ┗━┳ PAD ((4, 5), (0, 0)) 
 15    ┃   ┗━━ realized float (1, 67)  
 16    ┗━┳ ADD  
 17      ┣━┳ ADD 
 18      ┃ ┣━┳ ADD 
 19      ┃ ┃ ┣━┳ ADD 
 20      ┃ ┃ ┃ ┣━┳ PAD ((5, 4), (0, 0))
 21      ┃ ┃ ┃ ┃ ┗━━ realized float (1, 67)  
 22      ┃ ┃ ┃ ┗━┳ PAD ((6, 3), (0, 0)) 
 23      ┃ ┃ ┃   ┗━━ realized float (1, 67)  
 24      ┃ ┃ ┗━┳ PAD ((7, 2), (0, 0)) 
 25      ┃ ┃   ┗━━ realized float (1, 67)  
 26      ┃ ┗━┳ PAD ((8, 1), (0, 0)) 
 27      ┃   ┗━━ realized float (1, 67)  
 28      ┗━┳ PAD ((9, 0), (0, 0)) 
 29        ┗━━ realized fl