In [1]:
import string

import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Vocabulary: space first, then punctuation, digits and ASCII letters (95 symbols).
abc = ''.join([' ', string.punctuation, string.digits, string.ascii_letters])
len_abc = len(abc)

# Lookup tables between a symbol's position in `abc` and the character itself.
itoc = dict(enumerate(abc))
ctoi = {char: idx for idx, char in itoc.items()}

print(len_abc, abc)

95  !"#$%&'()*+,-./:;<=>?@[\]^_`{|}~0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ


In [3]:
def encode(s: str) -> list[int]:
    """Map each character of ``s`` to its index in the alphabet.

    Raises KeyError for a character that is not in the alphabet.
    """
    return list(map(ctoi.__getitem__, s))

def decode(l: list[int]) -> str:
    """Turn a list of alphabet indices back into a string (KeyError on an unknown index)."""
    return ''.join(itoc[idx] for idx in l)

def str2seq(s: str) -> torch.Tensor:
    """One-hot encode ``s`` as a float tensor of shape (len(s), len_abc)."""
    indices = torch.tensor(encode(s), dtype=torch.long)
    one_hot = F.one_hot(indices, num_classes=len_abc)
    return one_hot.float()

def seq2str(s: torch.Tensor) -> str:
    """Decode a sequence of per-character score rows to text.

    Each row (iterating over the first dimension of ``s``) is reduced to the
    index of its largest entry, and those indices are decoded.
    """
    indices = [int(row.argmax()) for row in s]
    return decode(indices)

In [4]:
# Small demo RNN. batch_first only affects 3-D (batched) input; a 2-D tensor is run as one unbatched sequence.
rnn = nn.RNN(input_size=len_abc, hidden_size=8, batch_first=True)

In [5]:
# Pangram (uses every letter a-z), encoded as one one-hot row per character.
text = 'the quick brown fox jumps over the lazy dog'
sequence = str2seq(s=text)

In [6]:
# 2-D input is a single unbatched sequence: states is (seq_len, hidden) and the
# final hidden state is (1, hidden).
states, context = rnn(sequence)
print(context.shape, states.shape, sep='\n')

torch.Size([1, 8])
torch.Size([43, 8])


In [7]:
# Two equal-length sequences (the second is the first one reversed), so they
# stack into a single (batch, seq_len, len_abc) tensor.
text1 = 'the quick brown fox jumps over the lazy dog'
text2 = text1[::-1]
batch = torch.stack([str2seq(t) for t in (text1, text2)])
print(batch.shape)

torch.Size([2, 43, 95])


In [12]:
# Batched input with batch_first=True: states is (batch, seq_len, hidden) and
# the final hidden state is (1, batch, hidden).
states, context = rnn(batch)
print(states.shape, context.shape, sep='\n')

torch.Size([2, 43, 8])
torch.Size([1, 2, 8])


In [9]:
def calculate_attention(x: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
    """Dot-product attention of query ``x`` over the columns of memory ``m``.

    Args:
        x: query, shape (hidden,) or (batch, hidden).
        m: memory laid out column-per-step, i.e. (hidden, seq) such as
            ``states.T`` (shared by every query), or (batch, hidden, seq) with
            one memory per query.

    Returns:
        The softmax-weighted sum of the ``seq`` columns of ``m``, with the same
        leading shape as ``x``: (hidden,) or (batch, hidden).
    """
    # Treat each query as a 1 x hidden row so matmul broadcasts over any leading
    # batch dimension of either argument. The plain ``x @ m`` form mixes batches
    # when ``m`` is 3-D, and softmax(dim=1) fails for a 1-D ``x``.
    query = x.unsqueeze(-2)                        # (..., 1, hidden)
    scores = F.softmax(query @ m, dim=-1)          # (..., 1, seq)
    # Weighted sum over seq as a matmul, instead of materialising a
    # (..., hidden, seq) elementwise product and summing it.
    attended = scores @ m.transpose(-2, -1)        # (..., 1, hidden)
    return attended.squeeze(-2)

In [10]:
# Same sentence as the earlier single-sequence example; `sequence` is the
# encoder input for the attention cell that follows.
text = 'the quick brown fox jumps over the lazy dog'
sequence = str2seq(s=text)

In [11]:
# Untrained encoder-decoder with dot-product attention over the encoder states.
# The weights are random, so the decoded text is noise; this cell exercises the
# shapes end to end.
torch.manual_seed(0)  # seed before creating the layers so init (and the printed text) is reproducible

n_hidden = 64
max_steps = 128  # number of characters to generate; there is no stop symbol

encoder_rnn = nn.RNN(len_abc, n_hidden)
decoder_rnn = nn.RNN(len_abc, n_hidden)
# The output head sees [decoder hidden state, attention context] concatenated.
final = nn.Linear(2*n_hidden, len_abc)

# `states` holds one encoder output per input character; `context` is the final
# encoder hidden state and seeds the decoder.
states, context = encoder_rnn(sequence)

outputs = []
out = torch.zeros(1, len_abc)  # first decoder input: an all-zero vector

for _step in range(max_steps):
    _, context = decoder_rnn(out, context)
    attention = calculate_attention(context, states.T)
    ctx_att = torch.cat((context, attention), dim=1)
    # The full probability vector (not a one-hot of its argmax) is fed back as
    # the next decoder input.
    out = F.softmax(final(ctx_att), dim=1)
    outputs.append(out)

# NOTE(review): if these tensors are never backpropagated, wrapping the loop in
# torch.no_grad() avoids building an autograd graph over all max_steps.
outputs = torch.stack(outputs).squeeze(1)  # (max_steps, len_abc)
print(seq2str(outputs))

N,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
