In [1]:
import torch
import torch.nn as nn


In [2]:
# load text
text = open("data.txt").read()

# build vocabulary
chars = sorted(list(set(text)))
vocab_size = len(chars)

stoi = {ch:i for i,ch in enumerate(chars)}
itos = {i:ch for ch,i in stoi.items()}

# encode text
encoded = torch.tensor([stoi[c] for c in text])


In [3]:
block_size = 8
X, Y = [], []

for i in range(len(encoded) - block_size):
    X.append(encoded[i:i+block_size])
    Y.append(encoded[i+1:i+block_size+1])

X = torch.stack(X)
Y = torch.stack(Y)

print(X.shape, Y.shape)


torch.Size([84, 8]) torch.Size([84, 8])


In [4]:
class CharLSTM(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, 32)
        self.lstm = nn.LSTM(32, 64, batch_first=True)
        self.fc = nn.Linear(64, vocab_size)

    def forward(self, x):
        x = self.embed(x)
        out, _ = self.lstm(x)
        out = self.fc(out)
        return out


In [5]:
model = CharLSTM()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(200):
    optimizer.zero_grad()
    logits = model(X)
    loss = loss_fn(logits.view(-1, vocab_size), Y.view(-1))
    loss.backward()
    optimizer.step()

    if epoch % 20 == 0:
        print(f"Epoch {epoch}, loss = {loss.item():.4f}")


Epoch 0, loss = 3.0058
Epoch 20, loss = 0.5064
Epoch 40, loss = 0.2285
Epoch 60, loss = 0.2077
Epoch 80, loss = 0.2030
Epoch 100, loss = 0.2025
Epoch 120, loss = 0.2010
Epoch 140, loss = 0.2005
Epoch 160, loss = 0.2001
Epoch 180, loss = 0.1999


In [6]:
def generate(start="h", length=100):
    model.eval()
    idx = torch.tensor([[stoi[start]]])
    result = start

    for _ in range(length):
        logits = model(idx)
        probs = torch.softmax(logits[0, -1], dim=0)
        next_idx = torch.multinomial(probs, 1).item()
        result += itos[next_idx]
        idx = torch.tensor([[next_idx]])

    return result

print(generate("h"))


helearnins worng
ls
m
he lorllo s
marng
lesellllelsquestmachearls
m marng
heleache s ls ls s llearlle
