# Starting code

In [29]:
import numpy as np
import lightning as L
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import Dataset, DataLoader

# Data

In [30]:
data = "ABCD#"
chars = list(set(data))
data_size, vocab_size = len(data), len(chars)
print("data has %d characters, %d unique." % (data_size, vocab_size))
char_to_idx = {ch: i for i, ch in enumerate(data)}
idx_to_char = {i: ch for i, ch in enumerate(data)}
str_to_idx = lambda x: list(map(lambda xx: char_to_idx[xx], x))
idx_to_str = lambda x: "".join(list(map(lambda xx: idx_to_char[xx], x)))
tensor_to_str = lambda x: "".join(list(map(lambda xx: idx_to_char[xx.item()], x)))


def str_to_tensor(x, device="cpu"):
    return torch.tensor(list(map(lambda xx: char_to_idx[xx], x)), device=device)

data has 5 characters, 5 unique.


In [31]:
print(chars)
print(char_to_idx)
str_to_idx(chars)

['#', 'C', 'D', 'B', 'A']
{'A': 0, 'B': 1, 'C': 2, 'D': 3, '#': 4}


[4, 2, 3, 1, 0]

In [32]:
idxes = str_to_idx(chars)
idxes_oh = F.one_hot(torch.tensor(idxes), vocab_size)

for char, idx, idx_oh in zip(chars, idxes, idxes_oh):
    print(char, idx, idx_oh)

# 4 tensor([0, 0, 0, 0, 1])
C 2 tensor([0, 0, 1, 0, 0])
D 3 tensor([0, 0, 0, 1, 0])
B 1 tensor([0, 1, 0, 0, 0])
A 0 tensor([1, 0, 0, 0, 0])


In [33]:
def random_string(ending_pattern: str):
    ret = torch.randint(0, vocab_size - 1, size=(torch.randint(5, 15, (1,)),))
    ret = torch.concatenate([ret, torch.tensor(str_to_idx(ending_pattern + "#"))])

    return ret


tensor_to_str(random_string("ABBCD"))

'DCABCAABBBCABBCD#'

In [34]:
class TextSampler(Dataset):
    def __init__(self, pattern, vocab_size):
        self.sample_size = 128
        self.pattern = pattern

    def __len__(self):
        return self.sample_size

    def __getitem__(self, idx):
        data = random_string(self.pattern)
        x = np.array(data[:-1], dtype=int)
        y = np.array(data[1:], dtype=int)

        return x, y

# Model

In [35]:
class RNN(L.LightningModule):
    def __init__(self, hidden_size, vocab_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.Wxh = nn.Parameter(
            torch.normal(0, 1, size=(hidden_size, vocab_size)) * 0.01
        )
        self.Whh = nn.Parameter(
            torch.normal(0, 1, size=(hidden_size, hidden_size)) * 0.01
        )
        self.bh = nn.Parameter(torch.zeros(size=(hidden_size, 1), dtype=torch.float32))

    def forward(self, x, h_prev=None):
        seq_length, _ = x.shape
        x = torch.unsqueeze(x, dim=-1)

        out = torch.zeros(
            size=(seq_length, self.hidden_size, 1),
            dtype=torch.float32,
            device=self.device,
        )

        if h_prev == None:
            h_prev = torch.zeros(
                size=(self.hidden_size, 1), dtype=torch.float32, device=self.device
            )

        for t in range(seq_length):
            h_prev = nn.LeakyReLU()(self.Wxh @ x[t] + self.Whh @ h_prev + self.bh)
            out[t] = h_prev

        return out


input = F.one_hot(random_string("ABBCD"), vocab_size)

RNN(7, vocab_size)(input.float()).shape

torch.Size([13, 7, 1])

In [36]:
class LanguageDetector(L.LightningModule):
    def __init__(self, hidden_size, vocab_size):
        super().__init__()
        self.train_loss = []

        self.hidden_size = hidden_size
        self.vocab_size = vocab_size

        self.rnn = RNN(hidden_size, vocab_size)
        self.dense = nn.Linear(hidden_size, vocab_size, bias=True)

    def forward(self, x, h_prev=None):
        hiddens = self.rnn(x, h_prev)
        out = self.dense(hiddens.squeeze(dim=-1))

        return out, hiddens[-1]

    def training_step(self, batch, batch_idx):
        x, y = batch
        x, y = x[0], y[0]  # remove batch
        x = F.one_hot(x, self.vocab_size).float()
        out = self.forward(x)[0]
        loss = nn.CrossEntropyLoss()(out, y)

        self.train_loss.append(loss.item())
        return loss

    def on_train_epoch_end(self):
        avg_train_acc = sum(self.train_loss) / len(self.train_loss)
        self.train_loss.clear()
        self.print(f"Epoch {self.current_epoch} Training Loss: {avg_train_acc:.5f}")
        print("------------------------------------")
        print(self.generate("A"))
        print("------------------------------------")

    def generate(self, start="A", max_len=100):
        ret = start

        last_hidden = torch.zeros(
            size=(self.hidden_size, 1), dtype=torch.float32, device=self.device
        )

        for t in range(max_len):
            input = F.one_hot(
                str_to_tensor(ret[-1], device=self.device), vocab_size
            ).float()
            out, last_hidden = self.forward(input, last_hidden)
            out = torch.softmax(out, dim=-1)
            next = torch.multinomial(out, num_samples=1).squeeze()
            input = F.one_hot(next, vocab_size).float()
            ret = ret + idx_to_char[next.item()]

            if ret[-1] == "#":
                break

        return ret

    def configure_optimizers(self):
        opt = torch.optim.Adam(self.parameters(), lr=1e-3)
        return opt


hidden_size = 7
pattern = "ABBCDA"

model = LanguageDetector(7, vocab_size)
model.generate()

train_dataset = TextSampler(pattern=pattern, vocab_size=vocab_size)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=False, num_workers=12)

trainer = L.Trainer(max_epochs=100, accelerator="cpu")
trainer.fit(model, train_loader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type   | Params
---------------------------------
0 | rnn   | RNN    | 91    
1 | dense | Linear | 40    
---------------------------------
131       Trainable params
0         Non-trainable params
131       Total params
0.001     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

Epoch 0 Training Loss: 1.55319
------------------------------------
ADCCBCCCABCB#
------------------------------------
Epoch 1 Training Loss: 1.42898
------------------------------------
ADBBACADAADABACDABDBBBAAABBCD#
------------------------------------
Epoch 2 Training Loss: 1.28116
------------------------------------
ADCABCCDDABBDACBBBCDABBCCACBBCDAABADA#
------------------------------------
Epoch 3 Training Loss: 1.21343
------------------------------------
ABADD#
------------------------------------
Epoch 4 Training Loss: 1.17392
------------------------------------
ABBBDAA#
------------------------------------
Epoch 5 Training Loss: 1.16157
------------------------------------
AADBBCCDACBBBAA#
------------------------------------
Epoch 6 Training Loss: 1.14426
------------------------------------
AAADAABCDA#
------------------------------------
Epoch 7 Training Loss: 1.12554
------------------------------------
AAABCAABCDAABBCDA#
------------------------------------
Epoch 8 Trai

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 99 Training Loss: 0.92316
------------------------------------
ABABBCAABCDA#
------------------------------------


In [37]:
cnt = 0
for i in range(100):
    gen = model.generate()[:-1]
    print(gen)
    cnt += int(gen[-len(pattern) :] == pattern)

print(cnt)

ABCBABCABBCCA
ADCCBDBBBCAAABBCDA
ABDCDACDAABBCDA
ABDDAABBCDA
ACCBCDABBCDA
ADABCAACBABBCDA
AADDCDABBCDA
AACBACCCCABBCDA
ACDBBBBABBCDA
ACBDABBCDBABBCDA
ADDCBBACAABBCDA
AAADDCDDDCDBABBCDA
ACBBAABBCDA
ABABBADABBCDA
ADABADDDABBCDA
ADDCDCAAAABBCDA
ABDDABCDDBACABBCDA
ADBCCAABBCDA
AABBBAACBBCDA
ADCCDADAABBCDA
ABBCABBCDA
ABCADBBCDADADBBCDA
ADBDBBBBCDABABBCDA
AAADABABBCDA
ADDABCABABBCDA
AADDABBCDA
ADABCBBCCABBCDA
ABCDACBBBCDCBBCDA
ADDACADABBCDA
ABABBCCBABBCDA
AABABCCAABABBCDA
AADBABADBCABBCDA
ABACAAABBCDA
ACACBABBBCDABCDA
ACBDBBCDCADABBCDA
ACDBCABBCDA
ADCDDDACBABBCDA
ABACDADAAAABDDABBCDA
AADCABBAABBCDA
ABCCAADBBCDACBBCDA
ACACAADACBABBCDA
ADADCDACDACAABBCDA
AABDCAABBCDA
ACCAABBADABCDABBCDA
ADAACACBCCABBCDA
ADBDABACAABBCDA
ABCCCCABABBCDA
ADDBCDADADABBCDA
ABDAABBCDDAABBCDA
ACCDDACADABBCDA
AABADBABCDDABABBCDA
ADCADCBCAABCCABBCDA
AABDCDDCCABBCDA
ABAAAAABCDABBCDA
ABBDBBCADBABABBCDA
ACABBDABBCDA
AACBADBAAAABBCDA
AABACACACAAABBCDA
ADACBBDABAABBCDA
ADDDBCCCABBCDA
ACCDADACCBCBBABBCDA
ACBAABBDDBBDAABBCDA
A