In [36]:
import torch
import torch.nn as nn
import itertools

In [21]:
class RNN(nn.Module):
    """Minimal single-layer recurrent network with a tanh non-linearity.

    Each step computes ``h_t = tanh(x_t @ w_xh + h_{t-1} @ w_hh + bias)``.

    Args:
        dim_inputs: size of the last dimension of every input step.
        hidden_state_dim: size of the hidden state.
    """

    def __init__(
        self,
        dim_inputs,
        hidden_state_dim,
    ):
        super().__init__()

        self.hidden_state_dim = hidden_state_dim

        # Symmetric init U(-1/sqrt(H), 1/sqrt(H)), the scheme torch.nn.RNN uses.
        # Drawing from U[0, 1) makes every weight positive, so the pre-activations
        # grow with the layer width and tanh saturates from the very first step.
        bound = hidden_state_dim ** -0.5 if hidden_state_dim > 0 else 0.0
        self.w_xh = nn.Parameter(
            torch.empty(dim_inputs, hidden_state_dim).uniform_(-bound, bound)
        )
        self.w_hh = nn.Parameter(
            torch.empty(hidden_state_dim, hidden_state_dim).uniform_(-bound, bound)
        )

        self.bias = nn.Parameter(
            torch.empty(hidden_state_dim).uniform_(-bound, bound)
        )

    def forward(self, inputs, hidden_state = None):
        """Run the recurrence over the leading (time) dimension of ``inputs``.

        Args:
            inputs: tensor iterated along dim 0, one step at a time; the comment
                in the original code gives the layout as
                (ctx len, batch size, input dim).
            hidden_state: optional initial hidden state, broadcastable against
                ``x_t @ w_xh``. Defaults to zeros.

        Returns:
            ``(output_t, hidden_state)``: the hidden states of all steps stacked
            along a new dim 0, and the hidden state after the last step.
        """
        # input dims: (ctx len, batch size, input dim)
        if hidden_state is None:
            # Deterministic zero state, created with the weights' dtype/device.
            # A fresh random nn.Parameter here would make two identical calls
            # disagree and would never be registered with the module anyway.
            hidden_state = torch.zeros(
                *inputs.shape[1:-1],
                self.hidden_state_dim,
                dtype=self.w_hh.dtype,
                device=self.w_hh.device,
            )

        outputs = []
        for x_t in inputs:
            hidden_state = torch.tanh(
                torch.matmul(x_t, self.w_xh)
                + torch.matmul(hidden_state, self.w_hh)
                + self.bias
            )
            outputs.append(hidden_state)

        if outputs:
            output_t = torch.stack(outputs)
        else:
            # torch.stack rejects an empty list; return a zero-length sequence
            # shaped after the (unchanged) hidden state instead.
            output_t = hidden_state.new_zeros((0, *hidden_state.shape))

        return output_t, hidden_state

In [24]:
# Read the whole text file into one string.
# The encoding is pinned so decoding does not depend on the platform's locale default.
with open("./datasets/tinyshakespeare.txt", encoding="utf-8") as f:
    tiny_shakespeare = f.read()

In [38]:
# Number of tokens per context window; consumed by the cells that cut the
# token stream into the X / Y chunks.
context_len = 8

In [32]:
# Character-level vocabulary built from the loaded text.
chars = list(tiny_shakespeare)

# char -> id. The distinct characters are sorted so the ids are stable:
# iterating a set of strings directly can yield a different order in every
# interpreter session (string hash randomization), which would silently remap
# every token id after a kernel restart.
vocab = {c: i for i, c in enumerate(sorted(set(chars)))}

# id -> char, the inverse mapping used for decoding.
rev_vocab = {i: c for c, i in vocab.items()}

In [33]:
# Encode the text: one vocabulary id per character, in the original order.
tokenized = list(map(vocab.__getitem__, chars))

In [52]:
X = torch.chunk(torch.tensor(tokenized), len(tokenized) // context_len)

In [54]:
Y = torch.chunk(torch.tensor(tokenized[1:]), len(tokenized) // context_len)

In [62]:
# Flag roughly 20% of the chunks with 1 and the rest with 0.
# A dedicated seeded generator keeps the split identical across kernel restarts
# without touching the global RNG state used elsewhere.
# NOTE(review): which side (train vs. held-out) the 1s denote is decided by
# downstream code not visible here — confirm.
split_mask = (
    torch.rand(len(X), generator=torch.Generator().manual_seed(0)) < 0.2
).int()

In [65]:
# Echo the mask as the cell output for a quick visual check.
split_mask

tensor([1, 1, 1,  ..., 1, 1, 0], dtype=torch.int32)