In [1]:
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Определим ключ и словарь
key = 2
vocab = [char for char in ' -ABCDEFGHIJKLMNOPQRSTUVWXYZ']

In [3]:
# Напишем функцию, которая делает шифр
def encrypt(text, key):
    """Returns the encrypted form of 'text'."""
    indexes = [vocab.index(char) for char in text]
    encrypted_indexes = [(idx + key) % len(vocab) for idx in indexes]
    encrypted_chars = [vocab[idx] for idx in encrypted_indexes]
    encrypted = ''.join(encrypted_chars)
    return encrypted

print(encrypt('RNN IS NOT AI', key))

TPPAKUAPQVACK


In [4]:
num_examples = 256 # размер датасета
seq_len = 18 # максимальная длина строки


def encrypted_dataset(dataset_len, k):
    """
    Return: List(Tuple(Tensor encrypted, Tensor source))
    """
    dataset = []
    for x in range(dataset_len):
        random_message  = ''.join([random.choice(vocab) for x in range(seq_len)])
        encrypt_random_message = encrypt(''.join(random_message), k)
        src = [vocab.index(x) for x in random_message]
        tgt = [vocab.index(x) for x in encrypt_random_message]
        dataset.append([torch.tensor(tgt), torch.tensor(src)])
    return dataset

In [5]:
class Decipher(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, 
                rnn_type='simple'):
        """
        :params: int vocab_size 
        :params: int embedding_dim
        :params
        """
        super(Decipher, self).__init__()
        self.embed = nn.Embedding(vocab_size, embedding_dim)
        if rnn_type == 'simple':
            self.rnn = nn.RNN(embedding_dim, hidden_dim, num_layers = 2)
        
        self.fc = nn.Linear(hidden_dim, vocab_size)
        self.initial_hidden = torch.zeros(2, 1, hidden_dim)

        
    def forward(self, cipher):
        # CHECK INPUT SIZE
        # Unsqueeze 1 dimension for batches
        embd_x = self.embed(cipher).unsqueeze(1)
        out_rnn, hidden = self.rnn(embd_x, self.initial_hidden)
        # Apply the affine transform and transpose output in appropriate way
        # because you want to get the softmax on vocabulary dimension
        # in order to get probability of every letter
        return self.fc(out_rnn).transpose(1, 2)

In [6]:
# Определим параметры нашей модели
embedding_dim = 5
hidden_dim = 10
vocab_size = len(vocab) 
lr = 1e-3

criterion = nn.CrossEntropyLoss()

# Инициализируйте модель
model = Decipher(vocab_size, embedding_dim, hidden_dim)

# Инициализируйте оптимизатор: рекомендуется Adam
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)


# Реализация ранней остановки

In [7]:
k = 10
# Инициализация переменных для ранней остановки
best_accuracy = 0.0
epochs_no_improve = 0
n_epochs_stop = 5
num_epochs = 20


for x in range(num_epochs):
    print('Epoch: {:2.0f}'.format(x), end=', ')
    for encrypted, original in encrypted_dataset(num_examples, k):

        scores = model(encrypted)
        original = original.unsqueeze(1)
        # Calculate loss
        loss = criterion(scores, original)
        # Zero grads
        optimizer.zero_grad()
        # Backpropagate
        loss.backward()
        # Update weights
        optimizer.step()
    print('Loss: {:6.4f}'.format(loss.item()), end=',  ')

    with torch.no_grad():
        matches, total = 0, 0
        for encrypted, original in encrypted_dataset(num_examples, k):
            # Compute a softmax over the outputs
            predictions = F.softmax(model(encrypted), 1)
            # Choose the character with the maximum probability (greedy decoding)
            _, batch_out = predictions.max(dim=1)
            # Remove batch
            batch_out = batch_out.squeeze(1)
            # Calculate accuracy
            matches += torch.eq(batch_out, original).sum().item()
            total += torch.numel(batch_out)
        accuracy = matches / total
        print('Accuracy: {:4.2f}%'.format(accuracy * 100))
        
        # Обработка ранней остановки
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
    
        # Остановка, если качество не улучшается
        if epochs_no_improve == n_epochs_stop:
            print('Early stopping!')
            break
        

Epoch:  0, Loss: 2.6906,  Accuracy: 34.09%
Epoch:  1, Loss: 1.7214,  Accuracy: 77.78%
Epoch:  2, Loss: 1.1368,  Accuracy: 87.02%
Epoch:  3, Loss: 0.7830,  Accuracy: 97.22%
Epoch:  4, Loss: 0.5837,  Accuracy: 100.00%
Epoch:  5, Loss: 0.3572,  Accuracy: 100.00%
Epoch:  6, Loss: 0.2724,  Accuracy: 100.00%
Epoch:  7, Loss: 0.2102,  Accuracy: 100.00%
Epoch:  8, Loss: 0.1548,  Accuracy: 100.00%
Epoch:  9, Loss: 0.1329,  Accuracy: 100.00%
Early stopping!


# Добавим логирование результатов в Tensorboard

1. Импортировать TensorBoard: Для начала необходимо импортировать необходимый класс SummaryWriter из torch.utils.tensorboard.
2. Создать экземпляр SummaryWriter: Это будет использоваться для записи данных.
3. Логировать данные: Вы можете логировать различные метрики (например, потери и точность) во время обучения и проверки модели.
4. Запустить TensorBoard: После запуска скрипта, TensorBoard можно запустить локально для просмотра результатов.

In [8]:
from torch.utils.tensorboard import SummaryWriter

In [12]:
k = 10
# Инициализация переменных для ранней остановки
best_accuracy = 0.0
epochs_no_improve = 0
n_epochs_stop = 5
num_epochs = 20

# Создаем экземпляр для TensorBoard
writer = SummaryWriter()

for x in range(num_epochs):
    print('Epoch: {:2.0f}'.format(x), end=', ')
    for encrypted, original in encrypted_dataset(num_examples, k):

        scores = model(encrypted)
        original = original.unsqueeze(1)
        # Calculate loss
        loss = criterion(scores, original)
        # Zero grads
        optimizer.zero_grad()
        # Backpropagate
        loss.backward()
        # Update weights
        optimizer.step()
    print('Loss: {:6.4f}'.format(loss.item()), end=',  ')
    
    # Логирование потерь в TensorBoard
    writer.add_scalar('Loss/train', loss.item(), x)

    with torch.no_grad():
        matches, total = 0, 0
        for encrypted, original in encrypted_dataset(num_examples, k):
            # Compute a softmax over the outputs
            predictions = F.softmax(model(encrypted), 1)
            # Choose the character with the maximum probability (greedy decoding)
            _, batch_out = predictions.max(dim=1)
            # Remove batch
            batch_out = batch_out.squeeze(1)
            # Calculate accuracy
            matches += torch.eq(batch_out, original).sum().item()
            total += torch.numel(batch_out)
        accuracy = matches / total
        print('Accuracy: {:4.2f}%'.format(accuracy * 100))
        
        # Логирование точности в TensorBoard
        writer.add_scalar('Accuracy/train', accuracy, x)

        # Обработка ранней остановки
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
    
        # Остановка, если качество не улучшается
        if epochs_no_improve == n_epochs_stop:
            print('Early stopping!')
            break
        

Epoch:  0, Loss: 2.6422,  Accuracy: 32.92%
Epoch:  1, Loss: 1.8693,  Accuracy: 66.97%
Epoch:  2, Loss: 1.2075,  Accuracy: 84.96%
Epoch:  3, Loss: 0.8928,  Accuracy: 98.46%
Epoch:  4, Loss: 0.6485,  Accuracy: 100.00%
Epoch:  5, Loss: 0.4252,  Accuracy: 100.00%
Epoch:  6, Loss: 0.3391,  Accuracy: 100.00%
Epoch:  7, Loss: 0.3324,  Accuracy: 100.00%
Epoch:  8, Loss: 0.2439,  Accuracy: 100.00%
Epoch:  9, Loss: 0.1465,  Accuracy: 100.00%
Early stopping!


In [10]:
# !tensorboard --logdir=runs

TensorFlow installation not found - running with reduced feature set.

NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784

Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.15.1 at http://localhost:6006/ (Press CTRL+C to quit)
