AUTOREGRESSIVE RNN

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim

# Reproducibility: seed torch's RNG so that weight initialisation and
# multinomial sampling repeat on a fresh "Restart & Run All".
SEED = 42
torch.manual_seed(SEED)

# Device config
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Toy dataset: a single word. '*' is the default `stop_char` of
# `generate_text`, i.e. it marks the end of the sequence.
text = "hello*"
chars = sorted(set(text))  # sorted() accepts any iterable; no list() needed
char2idx = {ch: idx for idx, ch in enumerate(chars)}
idx2char = {idx: ch for ch, idx in char2idx.items()}
vocab_size = len(chars)

In [6]:
class CharRNN(nn.Module):
    """Single-layer vanilla RNN followed by a linear read-out layer.

    Args:
        input_size: Size of each input feature vector.
        hidden_size: Number of features in the RNN hidden state.
        output_size: Number of features produced per time step by the
            linear layer.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        # batch_first=True: inputs and outputs are (batch, seq, feature).
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden):
        """Run the RNN over ``x`` and project every time step.

        Args:
            x: Input tensor passed to ``nn.RNN`` (batch-first layout).
            hidden: Initial hidden state for the RNN.

        Returns:
            Tuple ``(out, hidden)``: the linear layer applied to the RNN
            outputs, and the final hidden state.
        """
        out, hidden = self.rnn(x, hidden)
        out = self.fc(out)
        return out, hidden

    def init_hidden(self, batch_size=1):
        """Return an all-zero initial hidden state.

        The tensor is allocated with the same device and dtype as the
        model's parameters, so it stays valid after ``model.to(...)`` and
        does not depend on any module-level ``device`` variable.

        Args:
            batch_size: Number of sequences in the batch (default 1).

        Returns:
            Tensor of shape ``(1, batch_size, hidden_size)``.
        """
        reference = next(self.parameters())
        return reference.new_zeros(1, batch_size, self.hidden_size)


In [7]:
# One-hot encoding function
def one_hot_encode(char_idx, vocab_size):
    """Build a one-hot tensor of shape (1, 1, vocab_size) for ``char_idx``.

    Every entry is 0 except for a 1 at position ``char_idx`` along the last
    dimension. The result is moved to the module-level ``device`` before it
    is returned.
    """
    encoding = torch.zeros(vocab_size)
    encoding[char_idx] = 1
    # Add the leading (batch, seq) singleton dimensions.
    return encoding.reshape(1, 1, vocab_size).to(device)

# Model, loss and optimiser.
model = CharRNN(input_size=vocab_size, hidden_size=8, output_size=vocab_size).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# The (input, target) pairs never change between epochs, so build the
# tensors once instead of re-encoding every character on every epoch.
# Each pair is (one-hot of a character, index of the character after it),
# i.e. next-character prediction.
training_pairs = [
    (
        one_hot_encode(char2idx[current], vocab_size),
        torch.tensor([char2idx[following]], device=device),
    )
    for current, following in zip(text, text[1:])
]

# Training
n_epochs = 100
for epoch in range(n_epochs):
    # Fresh hidden state each epoch; it is then carried from character to
    # character within the sequence.
    hidden = model.init_hidden()
    loss = 0

    for input_char, target in training_pairs:
        output, hidden = model(input_char, hidden)
        # Flatten the model output to (-1, vocab_size) for CrossEntropyLoss.
        loss += criterion(output.view(-1, vocab_size), target)

    # One optimiser step per epoch on the loss summed over the sequence.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 20 == 0:
        print(f'Epoch [{epoch+1}/{n_epochs}], Loss: {loss.item():.4f}')


Epoch [20/100], Loss: 4.1096
Epoch [40/100], Loss: 1.2606
Epoch [60/100], Loss: 0.4244
Epoch [80/100], Loss: 0.2046
Epoch [100/100], Loss: 0.1272


In [9]:
def generate_text(start_char='h', max_length=20, stop_char='*'):
    """Sample text from the module-level ``model`` one character at a time.

    Generation is autoregressive: each sampled character is one-hot encoded
    and fed back in as the next input, with the hidden state carried along.
    The model is switched to eval mode by this call.

    Args:
        start_char: First character of the output; it is looked up in
            ``char2idx``, so an unknown character raises ``KeyError``.
        max_length: Maximum number of characters sampled after
            ``start_char``.
        stop_char: Sampling ends as soon as this character is drawn; it is
            not appended to the result.

    Returns:
        The generated string, beginning with ``start_char``.
    """
    model.eval()
    hidden = model.init_hidden()
    input_char = one_hot_encode(char2idx[start_char], vocab_size)
    pieces = [start_char]

    # Inference only: without no_grad() every step would extend an autograd
    # graph through `hidden`, wasting memory for gradients never used.
    with torch.no_grad():
        for _ in range(max_length):
            output, hidden = model(input_char, hidden)
            # Sample straight from the CPU tensor; a NumPy round trip is
            # unnecessary for torch.multinomial.
            probs = torch.softmax(output.view(-1), dim=0).cpu()
            char_idx = torch.multinomial(probs, 1).item()
            char = idx2char[char_idx]

            if char == stop_char:
                break

            pieces.append(char)
            input_char = one_hot_encode(char_idx, vocab_size)

    return ''.join(pieces)

# Sample one sequence starting from 'h'. Characters are drawn with
# torch.multinomial inside generate_text, so the result can vary between calls.
print("Generated:", generate_text('h'))


Generated: hello
