In [63]:
import pandas as pd
import time
import torch

# Encoding/decoding algorithm

In [245]:
# Vocabulary: the 26 lowercase letters (the only symbols the cipher shifts)
# followed by whitespace/punctuation symbols, which pass through unchanged.
letters = "abcdefghijklmnopqrstuvwxyz" + " .,!:;'\"?()"
# Number of leading entries of `letters` that take part in the shift.
letters_len = 26
# Total vocabulary size (letters plus punctuation).
all_chats_len = len(letters)


In [246]:
def encode(text, k = 3):
    """Caesar-shift every letter of `text` forward by `k` positions.

    Symbols located in the first `letters_len` positions of `letters` are
    rotated (wrapping around via modulo); every other symbol of `letters`
    is copied through unchanged.

    Args:
        text: iterable of single-character strings (e.g. a str).
        k: shift amount, 3 by default.

    Returns:
        A list with one output character per input character.

    Raises:
        ValueError: if a character does not occur in `letters`.
    """
    shifted = []
    for c in text:
        pos = letters.index(c)
        if pos < letters_len:
            shifted.append(letters[(pos + k) % letters_len])
        else:
            shifted.append(c)
    return shifted
def decode(text, k = 3):
    """Undo `encode`: shift every letter of `text` backward by `k` positions.

    Symbols located in the first `letters_len` positions of `letters` are
    rotated backward (wrapping around via modulo); every other symbol of
    `letters` is copied through unchanged.

    Args:
        text: iterable of single-character strings (e.g. a str).
        k: shift amount that was used for encoding, 3 by default.

    Returns:
        A list with one output character per input character.

    Raises:
        ValueError: if a character does not occur in `letters`.
    """
    restored = []
    for c in text:
        pos = letters.index(c)
        restored.append(letters[(pos - k) % letters_len] if pos < letters_len else c)
    return restored

def text_to_idx(text):
    """Convert a sequence of characters into a tensor of vocabulary indices.

    Args:
        text: a str or list of single-character strings, each present in
            `letters`.

    Returns:
        A 1-D torch.int32 tensor holding the position of each character
        in `letters`.

    Raises:
        ValueError: if a character does not occur in `letters`.
    """
    positions = [letters.index(ch) for ch in text]
    return torch.tensor(positions, dtype=torch.int32)

def idx_to_text(indices):
    """Convert an iterable of vocabulary indices back into a string.

    Args:
        indices: iterable whose items can be converted with `int(...)`
            (e.g. a 1-D tensor); each value indexes into `letters`.

    Returns:
        The string formed by looking up every index in `letters`.
    """
    return "".join(letters[int(position)] for position in indices)

def encode_idx(idx_tens, k = 3, alphabet_len = None):
    """Caesar-shift a tensor of vocabulary indices forward by `k`.

    Entries smaller than `alphabet_len` are treated as letters and rotated
    by `k` modulo `alphabet_len`; all other entries are returned unchanged.
    Because the rotation uses modular arithmetic, any integer `k` works
    (negative or larger than the alphabet), matching the `%`-based
    behaviour of `encode`.

    Args:
        idx_tens: tensor of vocabulary indices; it is not modified.
        k: shift amount, 3 by default.
        alphabet_len: number of leading vocabulary entries that are
            shifted. Defaults to the notebook-level `letters_len`.

    Returns:
        A new tensor with the same shape and dtype as `idx_tens`.
    """
    if alphabet_len is None:
        alphabet_len = letters_len
    result = idx_tens.clone().detach()
    is_letter = result < alphabet_len
    # Tensor `%` follows Python semantics (result takes the divisor's sign),
    # so the wrap-around is correct for any k, not only 0 <= k <= alphabet_len.
    result[is_letter] = (result[is_letter] + k) % alphabet_len
    return result

def decode_idx(idx_tens, k = 3, alphabet_len = None):
    """Undo `encode_idx`: shift a tensor of vocabulary indices backward by `k`.

    Entries smaller than `alphabet_len` are treated as letters and rotated
    by `-k` modulo `alphabet_len`; all other entries are returned unchanged.
    Because the rotation uses modular arithmetic, any integer `k` works
    (negative or larger than the alphabet), matching the `%`-based
    behaviour of `decode`.

    Args:
        idx_tens: tensor of vocabulary indices; it is not modified.
        k: shift amount that was used for encoding, 3 by default.
        alphabet_len: number of leading vocabulary entries that are
            shifted. Defaults to the notebook-level `letters_len`.

    Returns:
        A new tensor with the same shape and dtype as `idx_tens`.
    """
    if alphabet_len is None:
        alphabet_len = letters_len
    result = idx_tens.clone().detach()
    is_letter = result < alphabet_len
    # Tensor `%` follows Python semantics (result takes the divisor's sign),
    # so negative intermediate values wrap back into [0, alphabet_len).
    result[is_letter] = (result[is_letter] - k) % alphabet_len
    return result

In [247]:
# Sanity check: map a short string (letters plus one punctuation mark)
# to its positions in `letters`.
indices = text_to_idx('abc?')
indices

tensor([ 0,  1,  2, 34], dtype=torch.int32)

In [248]:
# Inverse mapping: the indices should turn back into the original string.
idx_to_text(indices)

'abc?'

In [249]:
# Shift each letter forward by the default k = 3; the result is a list of characters.
encode("abs")

['d', 'e', 'v']

In [250]:
# Shift each letter backward by the default k = 3; the result is a list of characters.
decode("dev")

['a', 'b', 's']

In [251]:
# Draw five random vocabulary indices. torch.randint excludes its upper
# bound, so the bound is the vocabulary size itself; with `all_chats_len - 1`
# the last symbol of `letters` could never be sampled.
test = torch.randint(0, all_chats_len, (5, ))
test

tensor([ 3, 10, 35, 21, 31])

In [252]:
# Apply the index-level cipher with the default shift; entries at or above
# letters_len (punctuation) should come back unchanged.
encoded = encode_idx(test)
encoded

tensor([ 6, 13, 35, 24, 31])

In [253]:
# Decoding the encoded indices should reproduce `test`.
decode_idx(encoded)

tensor([ 3, 10, 35, 21, 31])

In [254]:
# Full round trip: text -> indices -> encoded indices -> decoded indices -> text.
# The displayed string should equal the input phrase.
indices = text_to_idx("hello world!")
encoded = encode_idx(indices)
decoded = decode_idx(encoded)
idx_to_text(decoded)

'hello world!'

# Create data

In [255]:
import random
# Synthetic data: uniformly random vocabulary indices that play the role of
# encoded characters; the targets are derived later with decode_idx.
DATA_CHARS = 5000
TEST_FRACTION = 0.4  # test set size relative to the training set
SEED = 0
# Fix the RNG so the sampled data (and anything initialised afterwards)
# is reproducible on Restart & Run All.
torch.manual_seed(SEED)
# torch.randint excludes its upper bound, so the bound is the vocabulary size
# itself; with `all_chats_len - 1` the last symbol of `letters` would never
# appear in the training or test data.
X_train = torch.randint(0, all_chats_len, (DATA_CHARS, ))
X_test = torch.randint(0, all_chats_len, (int(TEST_FRACTION * DATA_CHARS), ))

In [256]:
# Shape check: an embedding table with one 30-dimensional row per vocabulary
# symbol, applied to the first four training indices.
embedding = torch.nn.Embedding(all_chats_len, 30)
embedding(X_train[:4]).shape

torch.Size([4, 30])

# Create model

In [267]:
# Per-symbol classifier: each input index is embedded, passed through one
# hidden layer, and mapped to one score per vocabulary symbol.
model = torch.nn.Sequential(
    # index -> 128-dimensional vector (one row per vocabulary symbol)
    torch.nn.Embedding(all_chats_len, 128),
    # hidden layer: 128 -> 64
    torch.nn.Linear(128, 64),
    torch.nn.ReLU(),
    # output layer: 64 -> one score per vocabulary symbol
    torch.nn.Linear(64, all_chats_len)
)

In [268]:
# Classification loss over the model's per-symbol scores.
loss = torch.nn.CrossEntropyLoss()
# Adam over all model parameters with learning rate 0.05.
optimizer = torch.optim.Adam(model.parameters(), lr=.05)
# Number of indices taken from X_train per optimisation step.
BATCH_SIZE = 100

In [269]:
# Train for 10 epochs. Inputs are the (encoded) indices in X_train and the
# targets are their decoded counterparts, so the network learns decode_idx.
# Only whole batches are used; a trailing partial batch is skipped.
n_full_batches = len(X_train) // BATCH_SIZE
for ep in range(10):
    start = time.time()
    train_loss = 0.
    train_passed = 0
    for offset in range(0, n_full_batches * BATCH_SIZE, BATCH_SIZE):
        batch = X_train[offset:offset + BATCH_SIZE]
        Y = decode_idx(batch)
        optimizer.zero_grad()
        l = loss(model(batch), Y)
        l.backward()
        optimizer.step()
        # Accumulate the batch loss to report the epoch average below.
        train_loss += l.item()
        train_passed += 1
    elapsed = time.time() - start
    print("Epoch {}. Time: {:.3f}, Train loss: {:.3f}".format(ep, elapsed, train_loss / train_passed))

Epoch 0. Time: 0.066, Train loss: 0.121
Epoch 1. Time: 0.070, Train loss: 0.000
Epoch 2. Time: 0.070, Train loss: 0.000
Epoch 3. Time: 0.070, Train loss: 0.000
Epoch 4. Time: 0.069, Train loss: 0.000
Epoch 5. Time: 0.075, Train loss: 0.000
Epoch 6. Time: 0.066, Train loss: 0.000
Epoch 7. Time: 0.074, Train loss: 0.000
Epoch 8. Time: 0.068, Train loss: 0.000
Epoch 9. Time: 0.068, Train loss: 0.000


# Check accuracy

In [270]:
# Predicted symbol index (highest score) for the first five test inputs.
model(X_test[:5]).argmax(axis = 1)

tensor([11, 33, 16, 29, 29])

In [271]:
# Ground truth for the same five inputs, for comparison with the predictions.
decode_idx(X_test[:5])

tensor([11, 33, 16, 29, 29])

In [272]:
# Accuracy on the test set: fraction of inputs whose highest-scoring
# symbol equals the decode_idx target.
Y_test = decode_idx(X_test)
predicted = model(X_test).argmax(axis=1)
n_correct = (predicted == Y_test).sum()
accuracy = float(n_correct / len(Y_test))
accuracy

1.0

In [273]:
# End-to-end demo: encrypt a phrase with the reference cipher, then let the
# trained model decode it symbol by symbol.
phrase = "hello world ?"
# encode returns a list of characters; text_to_idx only needs len() and indexing.
encoded = encode(phrase)
indices = text_to_idx(encoded)
# Take the highest-scoring symbol for every position.
decoded = model(indices).argmax(axis = 1)
idx_to_text(decoded)

'hello world ?'