In [1]:
# lstm_char_lm.py
import math, random, torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# -----------------------------
# 1) Toy corpus (replace with your own text)
# -----------------------------
# Training text for the character-level model.
# The triple-quoted literal opens and closes on its own line, so the string
# starts and ends with "\n"; those newlines are part of the data and therefore
# of the character set derived from it.
corpus = """
To be, or not to be, that is the question:
Whether 'tis nobler in the mind to suffer
The slings and arrows of outrageous fortune,
Or to take arms against a sea of troubles
And by opposing end them.
"""

# -----------------------------
# 2) Preprocessing: char vocab + encoding
# -----------------------------
# Vocabulary: every distinct character of the corpus, in sorted order.
chars = sorted(set(corpus))
vocab_size = len(chars)
# Two lookup tables that are inverses of each other:
#   itos: integer id -> character (the id is the position in `chars`)
#   stoi: character  -> integer id
itos = dict(enumerate(chars))
stoi = {ch: i for i, ch in itos.items()}

def encode(s):
    """Return the list of ``stoi`` ids for the characters of ``s``, in order.

    Raises:
        KeyError: If ``s`` contains a character that is not a key of ``stoi``.
    """
    return [stoi[ch] for ch in s]
def decode(ids):
    """Return the string obtained by looking each id of ``ids`` up in ``itos``.

    Raises:
        KeyError: If an id is not a key of ``itos``.
    """
    return "".join(itos[token_id] for token_id in ids)

# The whole corpus as one flat tensor of integer character ids.
data = torch.tensor(
    encode(corpus),
    dtype=torch.long,
)

# -----------------------------
# 3) Dataset: (input_seq -> next_char)
# -----------------------------
class CharDataset(Dataset):
    """Sliding-window dataset over a sequence of token ids.

    Sample ``i`` is the pair ``(data[i:i+seq_len], data[i+1:i+seq_len+1])``:
    an input window and the same window shifted one position to the right,
    so every input position is paired with the token that follows it.

    Args:
        data: Sequence supporting ``len()`` and slicing (a 1-D tensor of
            character ids in this file).
        seq_len: Length of each window; must be at least 1.

    Raises:
        ValueError: If ``seq_len`` is smaller than 1.
    """

    def __init__(self, data, seq_len=64):
        if seq_len < 1:
            raise ValueError(f"seq_len must be >= 1, got {seq_len}")
        self.data = data
        self.seq_len = seq_len

    def __len__(self):
        # Each window needs one extra element after it for the shifted
        # target, which leaves len(data) - seq_len valid start positions.
        # Clamp at 0 so data shorter than a window gives an empty dataset.
        return max(0, len(self.data) - self.seq_len)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` windows starting at ``idx``.

        Negative indices count from the end, as for a list.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        n = len(self)
        pos = idx + n if idx < 0 else idx
        # Slicing never fails on its own, so without this check an
        # out-of-range index would silently return truncated/empty windows
        # and plain iteration over the dataset would never stop.
        if not 0 <= pos < n:
            raise IndexError(f"index {idx} out of range for dataset of length {n}")
        x = self.data[pos:pos + self.seq_len]
        y = self.data[pos + 1:pos + self.seq_len + 1]
        return x, y

# Window length shared by the dataset and by everything else in the file
# that refers to `seq_len`.
seq_len = 64
dataset = CharDataset(data, seq_len)
# Shuffled batches of 64 windows; drop_last=True discards a final batch
# that would contain fewer than 64 windows.
loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    drop_last=True,
)

# -----------------------------
# 4) LSTM Language Model
# -----------------------------
class LSTMCharLM(nn.Module):
    """Character-level language model: embedding -> LSTM -> LayerNorm -> linear.

    Args:
        vocab_size: Number of distinct token ids; also the size of the
            output logits.
        emb_dim: Embedding dimension fed into the LSTM.
        hidden_dim: LSTM hidden size (and LayerNorm / head input size).
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability applied by ``nn.LSTM`` between stacked
            layers. It is ignored when ``num_layers == 1``.
    """

    def __init__(self, vocab_size, emb_dim=128, hidden_dim=256, num_layers=2, dropout=0.1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, emb_dim)
        self.lstm = nn.LSTM(
            input_size=emb_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            # nn.LSTM only applies dropout between layers, so with a single
            # layer a non-zero value does nothing except trigger a warning.
            dropout=dropout if num_layers > 1 else 0.0,
        )
        self.ln = nn.LayerNorm(hidden_dim)
        self.head = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hx=None):
        """Compute next-token logits for every position of ``x``.

        Args:
            x: Integer token ids of shape ``(B, T)``.
            hx: Optional LSTM state as accepted by ``nn.LSTM``; ``None``
                starts from the default (zero) state.

        Returns:
            ``(logits, hx)`` where ``logits`` has shape ``(B, T, V)`` and
            ``hx`` is the state returned by the LSTM after the last step.
        """
        x = self.embed(x)              # (B, T, E)
        out, hx = self.lstm(x, hx)     # out: (B, T, H)
        out = self.ln(out)
        logits = self.head(out)        # (B, T, V)
        return logits, hx

# -----------------------------
# 5) Training setup
# -----------------------------
# Run on the GPU when PyTorch reports one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Model with its default hyperparameters, moved to the chosen device.
model = LSTMCharLM(vocab_size).to(device)

# Mean cross-entropy over all predicted positions, optimized with AdamW.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3)

# -----------------------------
# 6) Training loop
# -----------------------------
epochs = 15
model.train()
for epoch in range(1, epochs + 1):
    # Running sums so the reported loss is a per-token average over the epoch.
    total_loss, total_tokens = 0.0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)   # (B, T)
        optimizer.zero_grad()
        logits, _ = model(x)                 # (B, T, V)
        # Flatten batch and time so each position is one classification.
        loss = criterion(logits.reshape(-1, vocab_size), y.reshape(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # stabilize training
        optimizer.step()

        # criterion returns the batch mean; weight it by the token count so
        # the epoch average is exact.
        total_loss += loss.item() * x.numel()
        total_tokens += x.numel()

    if total_tokens == 0:
        # With drop_last=True a corpus that yields fewer windows than one
        # batch produces no batches at all; fail with a clear message
        # instead of a ZeroDivisionError below.
        raise RuntimeError(
            "The data loader produced no batches; the corpus is too short "
            "for the configured seq_len and batch_size."
        )

    avg_loss = total_loss / total_tokens
    ppl = math.exp(avg_loss)
    print(f"Epoch {epoch:02d} | Loss: {avg_loss:.4f} | Perplexity: {ppl:.2f}")

# -----------------------------
# 7) Sampling (text generation)
# -----------------------------
@torch.no_grad()
def generate(model, start_text="To ", max_new_tokens=300, temperature=0.8, top_k=0):
    """Sample text from ``model`` one character at a time.

    The prompt is run through the model once to build up the recurrent
    state. After that only the newly sampled character is fed in together
    with the carried state. (Re-feeding earlier characters while also
    passing the state would push them through the LSTM more than once.)

    Args:
        model: Callable as ``model(ids, hx)`` returning ``(logits, hx)`` with
            ``logits`` of shape ``(B, T, V)``. It is switched to eval mode
            and left in that mode.
        start_text: Non-empty prompt; every character must be encodable by
            ``encode``.
        max_new_tokens: Number of characters to sample.
        temperature: Positive divisor applied to the logits before softmax.
        top_k: If greater than 0, sample only from the ``top_k`` highest
            logits (capped at the vocabulary size); 0 disables the filter.

    Returns:
        ``start_text`` followed by the sampled characters.

    Raises:
        ValueError: If ``start_text`` is empty or ``temperature`` is not
            positive.
    """
    if not start_text:
        raise ValueError("start_text must contain at least one character")
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    model.eval()
    # Put the input on the model's own device rather than assuming it
    # matches the module-level default.
    model_device = next(model.parameters()).device
    x = torch.tensor([encode(start_text)], dtype=torch.long, device=model_device)
    hx = None
    out_text = list(start_text)

    for _ in range(max_new_tokens):
        logits, hx = model(x, hx)
        logits = logits[:, -1, :] / temperature     # (B, V)

        if top_k > 0:
            # torch.topk fails when k exceeds the size of the dimension.
            k = min(top_k, logits.size(-1))
            v, _ = torch.topk(logits, k)
            # Everything below the k-th largest logit gets zero probability.
            logits[logits < v[:, [-1]]] = -float("Inf")

        probs = torch.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)  # (B, 1)
        out_text.append(itos[int(next_id)])
        # The state already summarizes everything seen so far, so the next
        # step only needs the token that was just sampled.
        x = next_id

    return "".join(out_text)

# Header first, then a 400-character sample drawn from the trained model.
print("\n--- Generated Sample ---")
print(
    generate(
        model,
        start_text="To ",
        max_new_tokens=400,
        temperature=0.9,
        top_k=20,
    )
)


Epoch 01 | Loss: 3.1554 | Perplexity: 23.46
Epoch 02 | Loss: 2.6180 | Perplexity: 13.71
Epoch 03 | Loss: 1.7365 | Perplexity: 5.68
Epoch 04 | Loss: 1.2622 | Perplexity: 3.53
Epoch 05 | Loss: 0.8496 | Perplexity: 2.34
Epoch 06 | Loss: 0.5468 | Perplexity: 1.73
Epoch 07 | Loss: 0.3446 | Perplexity: 1.41
Epoch 08 | Loss: 0.2292 | Perplexity: 1.26
Epoch 09 | Loss: 0.1613 | Perplexity: 1.18
Epoch 10 | Loss: 0.1258 | Perplexity: 1.13
Epoch 11 | Loss: 0.1003 | Perplexity: 1.11
Epoch 12 | Loss: 0.0855 | Perplexity: 1.09
Epoch 13 | Loss: 0.0777 | Perplexity: 1.08
Epoch 14 | Loss: 0.0693 | Perplexity: 1.07
Epoch 15 | Loss: 0.0615 | Perplexity: 1.06

--- Generated Sample ---
To be, or not to be, that is the question:
Whether 'tis nobler in the mind to suffer
The slings and arrows of outrageous fortune,
Or to take arms against a sea of troubles
And by opposing end ther 'tis nobler in the mind to suffer
The slings and arrows of outrageous fortune,
Or to take arms against a sea of troubles
And by op