In [1]:
import warnings
warnings.filterwarnings('ignore')

In [2]:
import torch
import torch.nn as nn

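# 2. Preprocessor network

This section defines a small token-level autoencoder. Each token id is embedded, compressed through a 16-unit bottleneck, decoded back to embedding space, and projected to logits over the vocabulary. The network is trained to reproduce its own input sequence, checked on a fixed probe sequence, and its weights are then saved to disk.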

## 2.1. Class definition

In [6]:
from torch.nn import Module, Linear, ReLU, Sequential, Embedding

class Preprocessor(Module):
    """Token-level autoencoder: embed -> compress -> expand -> per-token vocab logits.

    Each token id is embedded, squeezed through a small MLP bottleneck,
    decoded back to embedding space and projected onto the vocabulary.
    In this notebook it is trained with targets equal to its inputs, i.e.
    to reconstruct the token sequence it is given.

    Args:
        vocab_size: Number of distinct token ids; inputs must lie in
            ``[0, vocab_size)`` (``Embedding`` raises otherwise).
        embedding_dim: Width of the token embedding.
        hidden_dim: Width of the intermediate encoder/decoder layer
            (default 32, the previously hard-coded value).
        latent_dim: Width of the bottleneck code (default 16, the
            previously hard-coded value).
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim=32, latent_dim=16):
        super().__init__()

        self.embedding = Embedding(vocab_size, embedding_dim)

        # Attribute names and Sequential layouts are kept exactly as before so
        # state_dict keys (e.g. "encoder.0.weight") stay checkpoint-compatible.
        self.encoder = Sequential(
            Linear(embedding_dim, hidden_dim),
            ReLU(),
            Linear(hidden_dim, latent_dim),
        )

        self.decoder = Sequential(
            Linear(latent_dim, hidden_dim),
            ReLU(),
            Linear(hidden_dim, embedding_dim),
        )

        self.output = Linear(embedding_dim, vocab_size)

    def forward(self, x):
        """Map integer token ids to unnormalised logits over the vocabulary.

        Args:
            x: Integer (long) tensor of token ids, any shape ``(*)``.

        Returns:
            Float tensor of shape ``(*, vocab_size)``.
        """
        embedded = self.embedding(x)        # (*, embedding_dim)
        encoded = self.encoder(embedded)    # (*, latent_dim)
        decoded = self.decoder(encoded)     # (*, embedding_dim)
        return self.output(decoded)         # (*, vocab_size)

## 2.2. Initialization

In [7]:
VOCAB_SIZE = 7
EMBEDDING_DIM = 50
SEED = 42  # reproducibility: weight init and later random batches draw from torch's global RNG

# Seed before building the model so its initial weights (and, when the cells
# are run top to bottom, the random training batches) are reproducible.
torch.manual_seed(SEED)

model = Preprocessor(VOCAB_SIZE, EMBEDDING_DIM)
model

tensor([[1., 2., 3., 4., 5., 6.]])


In [None]:
from torch import tensor, argmax

# Fixed probe sequence, reused after training to compare predictions.
# Token ids must be < VOCAB_SIZE.
fixed_input = tensor([[1, 2, 3, 4, 5, 6]])  # shape (1, 6)

# Inference only: eval() disables training-only layer behaviour and
# no_grad() skips building the autograd graph for these forward passes.
model.eval()
with torch.no_grad():
    logits = model(fixed_input)      # (1, 6, VOCAB_SIZE)
output = argmax(logits, dim=-1)      # predicted token id per position, (1, 6)
output.shape, logits.shape, logits

## 2.3. Training

In [None]:
from torch import argmax, randint
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset

BATCH_SIZE = 32
SEQ_LENGTH = 6
LEARNING_RATE = 1e-3
LOG_EVERY = 250  # epochs between progress lines

model.train()

criterion = CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)

# Each "epoch" is one optimisation step on a freshly sampled batch of random
# token sequences. The target is the input itself (reconstruction objective).
# The batch already has exactly BATCH_SIZE rows, so no Dataset/DataLoader is
# needed; rebuilding one every epoch only added overhead and a no-op shuffle.
num_epochs = 2500
for epoch in range(num_epochs):
    inputs = randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LENGTH))
    targets = inputs  # never modified in place, so no copy is required

    logits = model(inputs)  # (BATCH_SIZE, SEQ_LENGTH, VOCAB_SIZE)
    # CrossEntropyLoss wants class scores on dim 1: (N, C, L) against targets (N, L).
    loss = criterion(logits.permute(0, 2, 1), targets)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % LOG_EVERY == 0:
        # .item() forces a host sync, so only pay for it when actually logging.
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}")

## 2.4. Testing

In [None]:
# Inference on the fixed probe: eval() disables training-only layer behaviour
# and no_grad() avoids building an autograd graph that is never used.
model.eval()
with torch.no_grad():
    logits = model(fixed_input)      # (1, 6, VOCAB_SIZE)
output = argmax(logits, dim=-1)      # reconstructed token ids; ideally equal to fixed_input
output

## 2.5. Save model

In [None]:
from pathlib import Path
from torch import save

# Single source of truth for the checkpoint location, so the directory that is
# created and the file that is written cannot drift apart.
MODEL_DIR = Path("models")
MODEL_PATH = MODEL_DIR / "preprocessor.pt"

# exist_ok keeps the cell re-runnable; parents allows a nested MODEL_DIR.
MODEL_DIR.mkdir(parents=True, exist_ok=True)
save(obj=model.state_dict(), f=MODEL_PATH)