In [1]:
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch.nn as nn
from tqdm import tqdm
import pandas as pd
import numpy as np
import string
import random
import torch
import math
import os
import re

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
if device == 'cuda':
    torch.cuda.manual_seed_all(seed)

In [3]:
data = []
with open("/kaggle/input/khmer-text/general-text.txt") as f:
    for i, line in enumerate(f, 1):
        data.append({'text' : line.replace('\n', '')})
        
df = pd.DataFrame(data)
df.head()
print(df.shape)

(582511, 1)


In [4]:
df['text'] = df['text'].replace('', np.nan)
df.dropna(inplace=True)
print(df.shape)

(471632, 1)


In [5]:
import re
import string

unwanted_chars = [
    '\u200b', '\u200c', '\u200d', '\ufeff',
    '៙', '៚', '៖', 'ៗ', '៛', '៝', '៸', '៓'
]

khmer_punct = '។៕៘'

replace_map = {
    'ឝ': 'គ',
    'ឞ': 'ម',
}

def clean_text(text):
    text = ''.join(c for c in text if c not in unwanted_chars)
    text = re.sub(r'[A-Za-z0-9]+', '', text)
    text = ''.join(c for c in text if c not in string.punctuation)
    text = re.sub(r'[^\u1780-\u17FF\u17E0-\u17E9\s' + khmer_punct + ']', '', text)
    for old, new in replace_map.items():
        text = text.replace(old, new)
    text = re.sub(r'\s+', ' ', text).strip()
    return text

df['text'] = df['text'].apply(clean_text)

In [6]:
df['text']

0         សាលារាជធានីថា មិនទាន់ទទួលបាន លិខិតសុំធ្វើបាតុក...
2         យប់នេះប៉ូលិសដាក់ប៉ុស្តិ៍រហូតដល់ ៧កន្លែង បងប្អូ...
4         លោកស្រី ឃួន សុដារី អនុប្រធានកាកបាទក្រហមកម្ពុជា...
6         គ្រោះថ្នាក់ចរាចរណ៍ទូទាំងប្រទេសថ្ងៃ១៥ ខែកុម្ភៈម...
8         លោក ហ៊ុន ម៉ានី ជួបប្រជុំជាមួយអភិបាលខេត្តសៀមរាប...
                                ...                        
582504    ដូចទៅនិង ភរិយាទីពីររបស់ ដែលត្រូវគាត់បោះបង់ចោលដ...
582505    ទំនួលខុសត្រូវសង្គមរបស់ពួកគេគឺជាភាពស្មោះត្រង់ខា...
582506    នេះតម្រូវឲ្យមានការគោរពដល់ឯកជនភាព កិត្តិយសរបស់ប...
582507    យោងតាម ទាំងនេះគឺជាកត្តាដែលរារាំងដល់វិជ្ជាជីវៈស...
582508                       គំនិតដ៏វៀងវៃគឺជាសម្បត្តិមហាសាល
Name: text, Length: 471632, dtype: object

In [7]:
def tokenize(text):
    return list(text)

special_tokens = ["<pad>", "<unk>", "<sos>", "<eos>"]
unwanted_tokens = ['៘','។', '៕', '឴']
all_text = df['text'].tolist()
tokens = [t for sentence in all_text for t in tokenize(sentence) if t not in unwanted_tokens]

vocab = sorted(set(tokens))
vocab = vocab + special_tokens 
stoi = {ch: i for i, ch in enumerate(vocab)}
itos = {i: ch for ch, i in stoi.items()}

vocab_size = len(vocab)
print(vocab_size)
print(stoi)

94
{' ': 0, 'ក': 1, 'ខ': 2, 'គ': 3, 'ឃ': 4, 'ង': 5, 'ច': 6, 'ឆ': 7, 'ជ': 8, 'ឈ': 9, 'ញ': 10, 'ដ': 11, 'ឋ': 12, 'ឌ': 13, 'ឍ': 14, 'ណ': 15, 'ត': 16, 'ថ': 17, 'ទ': 18, 'ធ': 19, 'ន': 20, 'ប': 21, 'ផ': 22, 'ព': 23, 'ភ': 24, 'ម': 25, 'យ': 26, 'រ': 27, 'ល': 28, 'វ': 29, 'ស': 30, 'ហ': 31, 'ឡ': 32, 'អ': 33, 'ឣ': 34, 'ឤ': 35, 'ឥ': 36, 'ឦ': 37, 'ឧ': 38, 'ឩ': 39, 'ឪ': 40, 'ឫ': 41, 'ឬ': 42, 'ឭ': 43, 'ឮ': 44, 'ឯ': 45, 'ឰ': 46, 'ឱ': 47, 'ឲ': 48, 'ឳ': 49, '឵': 50, 'ា': 51, 'ិ': 52, 'ី': 53, 'ឹ': 54, 'ឺ': 55, 'ុ': 56, 'ូ': 57, 'ួ': 58, 'ើ': 59, 'ឿ': 60, 'ៀ': 61, 'េ': 62, 'ែ': 63, 'ៃ': 64, 'ោ': 65, 'ៅ': 66, 'ំ': 67, 'ះ': 68, 'ៈ': 69, '៉': 70, '៊': 71, '់': 72, '៌': 73, '៍': 74, '៎': 75, '៏': 76, '័': 77, '៑': 78, '្': 79, '០': 80, '១': 81, '២': 82, '៣': 83, '៤': 84, '៥': 85, '៦': 86, '៧': 87, '៨': 88, '៩': 89, '<pad>': 90, '<unk>': 91, '<sos>': 92, '<eos>': 93}


In [8]:
def split_sentences(text):
    """Split Khmer text into sentences based on punctuation"""
    sentences = re.split(r'[។៕៘]', text)
    return [s.strip() for s in sentences if s.strip()]

def chunk_text_recursive(sentence, chunk_size=150, overlap=50):
    chunks = []
    start = 0
    while start < len(sentence):
        end = start + chunk_size
        chunks.append(sentence[start:end])
        # Move start by (chunk_size - overlap)
        start += chunk_size - overlap
    return chunks

def sentence_to_char_indices(sentence, vocab):
    """Convert sentence to list of token indices with <sos> and <eos>"""
    chars = list(sentence)
    return [vocab["<sos>"]] + [vocab.get(c, vocab["<unk>"]) for c in chars] + [vocab["<eos>"]]

def prepare_dataset(texts, vocab, chunk_size=150, overlap=50):
    inputs, targets = [], []
    for text in texts:     
        sentences = split_sentences(text)
        for sent in sentences:
            chunks = chunk_text_recursive(sent, chunk_size, overlap)
            for chunk in chunks:
                ids = sentence_to_char_indices(chunk, vocab)
                inputs.append(ids[:-1])
                targets.append(ids[1:])
    return inputs, targets

In [9]:
df_clean = df[df['text'].str.strip().astype(bool)]
df_clean = df_clean[df_clean['text'].str.len() >= 120]

# Randomly select 30,000 rows
df_sample = df_clean.sample(n=30000, random_state=seed)

In [10]:
train_texts, temp_texts = train_test_split(df_sample['text'].tolist(), test_size=0.2, random_state=seed)
val_texts, test_texts = train_test_split(temp_texts, test_size=0.5, random_state=seed)

X_train, Y_train = prepare_dataset(train_texts, stoi)
X_val, Y_val = prepare_dataset(val_texts, stoi)
X_test, Y_test = prepare_dataset(test_texts, stoi)

In [11]:
def decode_ids_to_text(ids, itos, remove_special_tokens=True):
    """Convert a list of token IDs back into readable text."""
    tokens = [itos[i] for i in ids if i in itos]
    
    if remove_special_tokens:
        tokens = [t for t in tokens if t not in ("<sos>", "<eos>", "<pad>", "<unk>")]
    
    return "".join(tokens)

ids = Y_train[0]
print(ids)
decoded_text = decode_ids_to_text(ids, itos)
print(decoded_text)

[33, 51, 6, 21, 5, 79, 1, 47, 79, 26, 25, 51, 20, 1, 57, 20, 25, 52, 20, 3, 79, 27, 21, 72, 2, 63, 0, 18, 51, 27, 1, 1, 59, 16, 25, 1, 25, 51, 20, 18, 25, 79, 5, 20, 72, 18, 51, 21, 0, 21, 10, 79, 31, 51, 20, 62, 68, 20, 54, 5, 21, 20, 79, 30, 28, 72, 18, 56, 1, 20, 57, 29, 22, 28, 29, 52, 21, 51, 1, 19, 79, 5, 20, 72, 19, 79, 5, 27, 8, 51, 6, 79, 27, 59, 20, 6, 67, 23, 65, 68, 18, 51, 27, 1, 11, 57, 6, 8, 51, 1, 79, 27, 52, 20, 0, 1, 5, 79, 29, 68, 33, 51, 31, 51, 27, 57, 21, 16, 79, 17, 25, 79, 24, 0, 2, 79, 30, 65, 26, 22, 79, 28, 57, 29, 11, 5, 79, 31, 59, 25, 93]
អាចបង្កឱ្យមានកូនមិនគ្រប់ខែ ទារកកើតមកមានទម្ងន់ទាប បញ្ហានេះនឹងបន្សល់ទុកនូវផលវិបាកធ្ងន់ធ្ងរជាច្រើនចំពោះទារកដូចជាក្រិន កង្វះអាហារូបត្ថម្ភ ខ្សោយផ្លូវដង្ហើម


In [12]:
class CharDataset(Dataset):
    def __init__(self, inputs, targets):
        self.inputs = inputs
        self.targets = targets

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return torch.tensor(self.inputs[idx], dtype=torch.long), torch.tensor(self.targets[idx], dtype=torch.long)

def collate_fn(batch):
    inputs, targets = zip(*batch)
    inputs_padded = pad_sequence(inputs, batch_first=True, padding_value=stoi["<pad>"])
    targets_padded = pad_sequence(targets, batch_first=True, padding_value=stoi["<pad>"])
    return inputs_padded, targets_padded

batch_size = 32
train_dataset = CharDataset(X_train, Y_train)
val_dataset = CharDataset(X_val, Y_val)
test_dataset = CharDataset(X_test, Y_test)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

for x, y in train_loader:
    print("Input batch shape:", x.shape)
    print("Target batch shape:", y.shape)
    break

Input batch shape: torch.Size([32, 151])
Target batch shape: torch.Size([32, 151])


In [13]:
class LSTMAutoencoder(nn.Module):
    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256, num_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.encoder = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        self.decoder = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        emb = self.embedding(x)
        enc_out, enc_hidden = self.encoder(emb)
        dec_out, dec_hidden = self.decoder(emb, enc_hidden)
        logits = self.fc(dec_out)
        return logits

In [14]:
save_dir = "/kaggle/working/checkpoints"
os.makedirs(save_dir, exist_ok=True)

# Initialize model, optimizer, loss
model = LSTMAutoencoder(vocab_size=len(stoi)).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=stoi["<pad>"])
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

epochs = 20
patience = 3
best_val_loss = float('inf')
wait = 0

for epoch in range(epochs):

    # ---- Training Phase ----
    model.train()
    total_train_loss = 0
    train_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]", leave=False)

    for x_batch, y_batch in train_bar:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)

        optimizer.zero_grad()
        logits = model(x_batch)  # only returns logits
        loss = criterion(logits.reshape(-1, logits.size(-1)), y_batch.reshape(-1))
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5)
        optimizer.step()

        total_train_loss += loss.item()
        train_bar.set_postfix(loss=f"{loss.item():.4f}")

    avg_train_loss = total_train_loss / len(train_loader)
    train_ppl = math.exp(avg_train_loss)

    # ---- Validation Phase ----
    model.eval()
    total_val_loss = 0
    with torch.no_grad():
        for x_batch, y_batch in val_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            logits = model(x_batch)  # FIXED (no unpacking)
            loss = criterion(logits.reshape(-1, logits.size(-1)), y_batch.reshape(-1))
            total_val_loss += loss.item()

    avg_val_loss = total_val_loss / len(val_loader)
    val_ppl = math.exp(avg_val_loss)

    # ---- Logging ----
    print(f"Epoch {epoch+1}/{epochs} | "
          f"Train Loss: {avg_train_loss:.4f} (PPL {train_ppl:.4f}) | "
          f"Val Loss: {avg_val_loss:.4f} (PPL {val_ppl:.4f})")

    # ---- Checkpointing ----
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        wait = 0

        save_path = os.path.join(save_dir, "best_autoencoder.pt")
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'vocab_size': len(stoi),
            'stoi': stoi,
            'itos': itos,
            'embedding_dim': 128,
            'hidden_dim': 256,
            'num_layers': 2,
            'best_val_loss': best_val_loss,
            'epoch': epoch + 1
        }, save_path)

        print(f"  ** Validation improved, model saved to {save_path}")
    else:
        wait += 1
        print(f"  ** No improvement ({wait}/{patience})")
        if wait >= patience:
            print("Early stopping triggered.")
            break

print(f"Training finished. Best Validation Loss: {best_val_loss:.4f}")

In [15]:
checkpoint = torch.load("/kaggle/input/best/pytorch/default/1/best_autoencoder.pt", map_location="cpu")
stoi = checkpoint["stoi"]
itos = checkpoint["itos"]

model = LSTMAutoencoder(
    vocab_size=checkpoint["vocab_size"],
    embed_dim=checkpoint["embedding_dim"],   
    hidden_dim=checkpoint["hidden_dim"],
    num_layers=checkpoint["num_layers"]
)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)

LSTMAutoencoder(
  (embedding): Embedding(94, 128)
  (encoder): LSTM(128, 256, num_layers=2, batch_first=True)
  (decoder): LSTM(128, 256, num_layers=2, batch_first=True)
  (fc): Linear(in_features=256, out_features=94, bias=True)
)

In [16]:
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

model.eval()
predictions = []

with torch.no_grad():
    for x_batch, _ in test_loader:  # targets are same as input for autoencoder
        x_batch = x_batch.to(device)
        logits = model(x_batch)  # [batch, seq_len, vocab_size]
        pred_ids = logits.argmax(dim=-1)
        predictions.extend(pred_ids.cpu().tolist())


smoothie = SmoothingFunction().method4
total_bleu = 0

for pred_ids, target_ids in zip(predictions, X_test):  # X_test = input sequences
    pred_chars = [itos[idx] for idx in pred_ids if idx not in (stoi['<pad>'],)]
    target_chars = [itos[idx] for idx in target_ids if idx not in (stoi['<pad>'],)]
    total_bleu += sentence_bleu([target_chars], pred_chars, smoothing_function=smoothie)

avg_bleu = total_bleu / len(predictions)
print(f"Pre-trained model char-level BLEU: {avg_bleu:.4f}")


Pre-trained model char-level BLEU: 0.3089
