AUTOREGRESSIVE RNN

In [1]:
from torch.nn.utils.rnn import pad_sequence
from torch.nn.functional import one_hot
import torch
import torch.nn as nn
import torch.optim as optim

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Toy corpus: every word ends with the '*' end-of-sequence marker.
sequences = ["hello*", "help*", "held*", "hero*"]

# Character-level vocabulary: the distinct characters of the corpus in sorted order.
text = "".join(sequences)
chars = sorted(set(text))
char2idx = {symbol: position for position, symbol in enumerate(chars)}
idx2char = dict(enumerate(chars))
vocab_size = len(chars)


def encode_sequence(seq):
    """Return a 1-D long tensor holding the vocabulary index of each character in `seq`."""
    indices = [char2idx[symbol] for symbol in seq]
    return torch.tensor(indices, dtype=torch.long)


encoded_seqs = list(map(encode_sequence, sequences))

# Right-pad the shorter words with the index of '*' so the batch is rectangular
# (batch_first=True puts the batch dimension first).
padded_seqs = pad_sequence(
    encoded_seqs,
    batch_first=True,
    padding_value=char2idx['*'],
)

# Next-character prediction: the model reads positions 0..T-2 (inputs)
# and is asked to predict positions 1..T-1 (targets).
inputs = padded_seqs[:, :-1]
targets = padded_seqs[:, 1:]


In [20]:
# Number of distinct characters in the corpus, '*' terminator included.
vocab_size

8

In [2]:
class CharLSTM(nn.Module):
    """LSTM followed by a linear layer that is applied at every time step.

    Args:
        input_size: Number of features per time step of the input.
        hidden_size: Size of the LSTM hidden and cell states.
        output_size: Number of features (logits) produced per time step.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        # batch_first=True: inputs and outputs are laid out as (batch, seq, feature).
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden=None):
        """Run the LSTM over `x` and project every time step through `fc`.

        Args:
            x: Float tensor of shape (batch, seq, input_size).
            hidden: Optional `(h_0, c_0)` tuple. When omitted, `nn.LSTM` starts
                from an all-zero state.

        Returns:
            Tuple `(out, hidden)`: `out` has shape (batch, seq, output_size) and
            `hidden` is the `(h_n, c_n)` state after the last time step.
        """
        out, hidden = self.lstm(x, hidden)
        out = self.fc(out)
        return out, hidden

    def init_hidden(self, batch_size):
        """Return an all-zero `(h_0, c_0)` state for a batch of `batch_size` sequences.

        The tensors are created on the same device and with the same dtype as the
        model's own parameters, so the method does not depend on any module-level
        `device` variable and stays correct if the model is moved afterwards.
        """
        reference = self.fc.weight
        shape = (self.lstm.num_layers, batch_size, self.hidden_size)
        # new_zeros inherits device and dtype from `reference`; requires_grad is False.
        h_0 = reference.new_zeros(shape)
        c_0 = reference.new_zeros(shape)
        return (h_0, c_0)


In [3]:
# One-hot encode the whole batch: (batch, seq) indices -> (batch, seq, vocab_size) floats.
inputs_onehot = one_hot(inputs, num_classes=vocab_size).float().to(device)
targets = targets.to(device)

model = CharLSTM(vocab_size, hidden_size=32, output_size=vocab_size).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

n_epochs = 100
model.train()  # training mode is loop-invariant, so set it once
for epoch in range(n_epochs):
    # Full-batch training: every epoch starts from a fresh all-zero state.
    hidden = model.init_hidden(batch_size=inputs.size(0))

    output, hidden = model(inputs_onehot, hidden)
    # Flatten (batch, seq, vocab) logits and (batch, seq) targets for CrossEntropyLoss.
    # `targets` is a strided slice of `padded_seqs`, so `.view(-1)` raises on CPU,
    # where `.to(device)` does not copy; `.reshape` copies only when it has to.
    # Padded positions hold the '*' index and therefore also contribute to the loss.
    loss = criterion(output.reshape(-1, vocab_size), targets.reshape(-1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 10 == 0:
        print(f"Epoch [{epoch+1}/{n_epochs}], Loss: {loss.item():.4f}")



Epoch [10/100], Loss: 1.4937
Epoch [20/100], Loss: 1.1803
Epoch [30/100], Loss: 0.8045
Epoch [40/100], Loss: 0.5258
Epoch [50/100], Loss: 0.3842
Epoch [60/100], Loss: 0.3205
Epoch [70/100], Loss: 0.2973
Epoch [80/100], Loss: 0.2890
Epoch [90/100], Loss: 0.2854
Epoch [100/100], Loss: 0.2836


In [32]:
# Show the encoded model input for the first word next to its raw string.
print(inputs[0], sequences[0], sep="\n")

tensor([3, 2, 4, 4, 5])
hello*


In [43]:
# Peek at the first two training words.
sequences[:2]

['hello*', 'help*']

In [38]:
# One-hot input rows for the third training sequence: one row per time step,
# one column per vocabulary entry.
print(inputs_onehot[2])

tensor([[0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.]], device='cuda:0')


In [None]:
def generate_text(start_char='h', max_len=20, stop_char='*'):
    """Sample a string from the trained model one character at a time.

    Generation is autoregressive: each sampled character is fed back as the next
    input while the LSTM state is carried forward between steps.

    Args:
        start_char: First character of the output; must be a key of `char2idx`.
        max_len: Maximum number of characters to sample after `start_char`.
        stop_char: Sampling this character ends generation; it is not appended.

    Returns:
        `start_char` followed by the sampled characters.

    Raises:
        KeyError: If `start_char` is not in the vocabulary.
    """
    model.eval()
    next_idx = char2idx[start_char]
    hidden = model.init_hidden(batch_size=1)
    generated = start_char

    # Inference only: without no_grad the carried hidden state would keep the
    # autograd graph of every previous step alive.
    with torch.no_grad():
        for _ in range(max_len):
            # Shape (1, 1, vocab_size): a batch of one sequence with one time step.
            step_input = one_hot(torch.tensor([[next_idx]]), num_classes=vocab_size).float().to(device)
            output, hidden = model(step_input, hidden)

            # Sample from the predicted distribution instead of taking the argmax.
            probs = torch.softmax(output[0, -1], dim=0).cpu()
            next_idx = torch.multinomial(probs, 1).item()
            next_char = idx2char[next_idx]
            if next_char == stop_char:
                break
            generated += next_char

    return generated

print("Generated:", generate_text('h'))


tensor([2.8710e-04, 2.2229e-04, 9.9574e-01, 2.2217e-04, 1.2073e-03, 2.9980e-05,
        1.6896e-04, 2.1258e-03])
tensor([3.2202e-05, 1.0839e-03, 7.8603e-04, 1.1017e-05, 7.4959e-01, 2.8671e-05,
        1.1256e-03, 2.4735e-01])
tensor([1.7208e-04, 3.3177e-01, 1.5226e-05, 2.5069e-05, 3.3064e-01, 7.5784e-03,
        3.2781e-01, 1.9950e-03])
tensor([1.5544e-02, 2.0982e-03, 2.4264e-06, 1.6779e-05, 1.0912e-03, 9.7899e-01,
        2.2333e-03, 2.6440e-05])
tensor([9.9909e-01, 9.7474e-06, 8.7536e-06, 9.0646e-06, 2.1478e-05, 8.5166e-04,
        8.9895e-06, 1.7287e-06])
Generated: hello
