In [1]:
!pip install torch sentencepiece



### 1. **Imports**

In [2]:
import math, random
from pathlib import Path
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import sentencepiece as spm

In [3]:
def make_split(
    in_path,
    train_out="train.txt",
    valid_out="valid.txt",
    valid_ratio=0.05,
    seed=42
):
    random.seed(seed)

    text = Path(in_path).read_text(encoding="utf-8", errors="ignore")
    lines = [l.strip() for l in text.splitlines() if l.strip()]
    random.shuffle(lines)

    n_valid = int(len(lines) * valid_ratio)

    Path(train_out).write_text("\n".join(lines[n_valid:]), encoding="utf-8")
    Path(valid_out).write_text("\n".join(lines[:n_valid]), encoding="utf-8")

    print(f"train lines: {len(lines[n_valid:])}")
    print(f"valid lines: {n_valid}")

In [4]:
make_split("/content/cleaned-general-text.txt")

train lines: 175066
valid lines: 9213


Build SentencePiece Tokenizer with Vocabsize = 8000 (khm_spm.model)

In [5]:
VOCAB_SIZE = 8000

spm.SentencePieceTrainer.train(
    input="train.txt",
    model_prefix="kh_spm",
    vocab_size=VOCAB_SIZE,
    model_type="bpe",
    character_coverage=0.9995,
    pad_id=0, unk_id=1, bos_id=2, eos_id=3,
)

sp = spm.SentencePieceProcessor(model_file="kh_spm.model")

print("Vocab size:", sp.get_piece_size())


Vocab size: 8000


Load Data

In [6]:
class LMDataset(Dataset):
    def __init__(self, path, sp, block_size=128, step=None):
        self.samples = []
        step = step or block_size // 2

        for line in Path(path).read_text(encoding="utf-8").splitlines():
            ids = [sp.bos_id()] + sp.encode(line, out_type=int) + [sp.eos_id()]
            for i in range(0, len(ids) - 1, step):
                chunk = ids[i:i + block_size + 1]
                if len(chunk) >= 2:
                    self.samples.append(chunk)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        seq = self.samples[idx]
        return (
            torch.tensor(seq[:-1], dtype=torch.long),
            torch.tensor(seq[1:], dtype=torch.long)
        )

In [7]:
def collate_pad(batch, pad_id):
    xs, ys = zip(*batch)
    max_len = max(x.size(0) for x in xs)

    x_pad = torch.full((len(xs), max_len), pad_id)
    y_pad = torch.full((len(xs), max_len), pad_id)

    for i, (x, y) in enumerate(zip(xs, ys)):
        x_pad[i, :x.size(0)] = x
        y_pad[i, :y.size(0)] = y

    return x_pad, y_pad


Model

In [8]:
class LSTMLM(nn.Module):
    def __init__(self, vocab_size, emb_dim=256, hid_dim=512, layers=2, pad_id=0):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_id)
        self.lstm = nn.LSTM(
            emb_dim, hid_dim, layers,
            batch_first=True,
            dropout=0.2 if layers > 1 else 0.0
        )
        self.fc = nn.Linear(hid_dim, vocab_size)

    def forward(self, x, h=None):
        x = self.emb(x)
        x, h = self.lstm(x, h)
        return self.fc(x), h

    @torch.no_grad()
    def step(self, token_id, h=None, device="cpu"):
        x = torch.tensor([[token_id]], device=device)
        logits, h = self.forward(x, h)
        return logits[0, -1], h

Configuration Setup

In [9]:
device = "cuda" if torch.cuda.is_available() else "cpu"

train_ds = LMDataset("train.txt", sp)
valid_ds = LMDataset("valid.txt", sp)

train_dl = DataLoader(
    train_ds, batch_size=64, shuffle=True,
    collate_fn=lambda b: collate_pad(b, sp.pad_id())
)
valid_dl = DataLoader(
    valid_ds, batch_size=64, shuffle=False,
    collate_fn=lambda b: collate_pad(b, sp.pad_id())
)

model = LSTMLM(sp.get_piece_size(), pad_id=sp.pad_id()).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
loss_fn = nn.CrossEntropyLoss(ignore_index=sp.pad_id())

In [10]:
@torch.no_grad()
def generate_text(model, sp, text, max_new_tokens=20):
    model.eval()

    ids = [sp.bos_id()] + sp.encode(text, out_type=int)
    x = torch.tensor(ids, device=device).unsqueeze(0)
    h = None

    for _ in range(max_new_tokens):
        logits, h = model(x[:, -1:], h)
        next_id = torch.argmax(logits[0, -1])
        x = torch.cat([x, next_id.view(1, 1)], dim=1)
        if next_id.item() == sp.eos_id():
            break

    return sp.decode(x[0].tolist())

In [11]:
@torch.no_grad()
def evaluate():
    model.eval()
    loss_sum = tok = 0

    for x, y in valid_dl:
        x, y = x.to(device), y.to(device)
        logits, _ = model(x)

        loss = loss_fn(
            logits.view(-1, logits.size(-1)),
            y.view(-1)
        )

        n = (y != sp.pad_id()).sum().item()
        loss_sum += loss.item() * n
        tok += n

    model.train()
    avg_loss = loss_sum / tok
    return avg_loss, math.exp(avg_loss)


Training Loop

In [12]:
EPOCHS = 20

for epoch in range(EPOCHS):
    model.train()
    train_loss_sum = 0
    train_tok = 0

    for x, y in train_dl:
        x, y = x.to(device), y.to(device)
        logits, _ = model(x)

        loss = loss_fn(
            logits.view(-1, logits.size(-1)),
            y.view(-1)
        )

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        n = (y != sp.pad_id()).sum().item()
        train_loss_sum += loss.item() * n
        train_tok += n

    train_loss = train_loss_sum / train_tok
    val_loss, val_ppl = evaluate()

    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    print(f"Train loss: {train_loss:.4f}")
    print(f"Val loss  : {val_loss:.4f}")
    print(f"Val ppl   : {val_ppl:.2f}")



Epoch 1/20
Train loss: 6.9556
Val loss  : 6.1276
Val ppl   : 458.35

Epoch 2/20
Train loss: 5.7056
Val loss  : 5.3815
Val ppl   : 217.35

Epoch 3/20
Train loss: 5.1582
Val loss  : 5.0294
Val ppl   : 152.84

Epoch 4/20
Train loss: 4.8480
Val loss  : 4.8238
Val ppl   : 124.44

Epoch 5/20
Train loss: 4.6315
Val loss  : 4.6836
Val ppl   : 108.15

Epoch 6/20
Train loss: 4.4635
Val loss  : 4.5749
Val ppl   : 97.02

Epoch 7/20
Train loss: 4.3259
Val loss  : 4.4907
Val ppl   : 89.18

Epoch 8/20
Train loss: 4.2084
Val loss  : 4.4240
Val ppl   : 83.43

Epoch 9/20
Train loss: 4.1053
Val loss  : 4.3639
Val ppl   : 78.56

Epoch 10/20
Train loss: 4.0141
Val loss  : 4.3147
Val ppl   : 74.79

Epoch 11/20
Train loss: 3.9317
Val loss  : 4.2698
Val ppl   : 71.50

Epoch 12/20
Train loss: 3.8575
Val loss  : 4.2333
Val ppl   : 68.95

Epoch 13/20
Train loss: 3.7889
Val loss  : 4.1988
Val ppl   : 66.61

Epoch 14/20
Train loss: 3.7273
Val loss  : 4.1717
Val ppl   : 64.82

Epoch 15/20
Train loss: 3.6691
Val lo

Save Model

In [13]:
torch.save({
    "model_state": model.state_dict(),
    "vocab_size": sp.get_piece_size(),
    "pad_id": sp.pad_id(),
    "emb_dim": 256,
    "hid_dim": 512,
    "layers": 2,
}, "khmer_lstm_lm.pt")

print("Model saved: khmer_lstm_lm.pt")

Model saved: khmer_lstm_lm.pt


Testing

In [14]:
@torch.no_grad()
def predict_next_word(model, sp, text, top_k=5):
    model.eval()

    # Encode input
    ids = [sp.bos_id()] + sp.encode(text, out_type=int)
    h = None

    # Feed context except last token
    for tid in ids[:-1]:
        _, h = model.step(tid, h=h, device=device)

    last_token = ids[-1]
    logits, _ = model.step(last_token, h=h, device=device)

    probs = F.softmax(logits, dim=-1)
    top_probs, top_ids = torch.topk(probs, top_k)

    results = []
    for tid, p in zip(top_ids.tolist(), top_probs.tolist()):
        word = sp.decode([int(tid)])   # ✅ convert tensor to int
        results.append((word, float(p)))

    return results

In [37]:
text = "ខ្ញុំទៅសាលារៀនជាមួយបង"
preds = predict_next_word(model, sp, text)

print("Input:", text)
print("Next word predictions:")
for w, p in preds:
    print(f"  {w}  ({p:.3f})")


Input: ខ្ញុំទៅសាលារៀនជាមួយបង
Next word predictions:
  ប្រុស  (0.593)
  កែវ  (0.165)
  ស្រី  (0.060)
  ថ្លៃ  (0.016)
  ល  (0.009)
