In [1]:
from llama3_transformer_block import *
import math

In [2]:
# Model hyper-parameters shared by the Encoder and Decoder definitions below.
NUM_HEAD = 16
NUM_KV_HEAD = 8
NUM_LAYER = 4
EMBED_DIM = 512
# HEAD_DIM is obtained by floor division; guard against a silently truncated
# head size (the query projections below map EMBED_DIM -> EMBED_DIM and are
# used with NUM_HEAD heads of HEAD_DIM each).
assert EMBED_DIM % NUM_HEAD == 0, "EMBED_DIM must be divisible by NUM_HEAD"
HEAD_DIM = EMBED_DIM // NUM_HEAD
ROPE_BASE = 10000
# The feed-forward hidden width is int(EMBED_DIM * MLP_SCALE).
MLP_SCALE = 3.5
EPS_NORM = 1e-5
# sin(sqrt(e * pi)) evaluates to roughly 0.2176; it is passed as attn_dropout.
DROPOUT = math.sin(math.sqrt(math.e * math.pi))
assert 0.0 <= DROPOUT < 1.0, "DROPOUT must be a probability in [0, 1)"
MAX_SEQUENCE = 1024

In [3]:
class Encoder(nn.Module):
    """Source-side encoder: a token embedding followed by a stack of encoder layers.

    Every size except the vocabulary size is read from the module-level
    hyper-parameter constants (EMBED_DIM, NUM_HEAD, NUM_KV_HEAD, HEAD_DIM,
    MLP_SCALE, ROPE_BASE, EPS_NORM, DROPOUT, MAX_SEQUENCE, NUM_LAYER).
    The building blocks used here (CausalSelfAttention, FeedForward,
    TransformerEncoderLayer, TransformerEncoder, RMSNorm,
    RotaryPositionalEmbedding) are not defined in this notebook; presumably
    they come from the star import of ``llama3_transformer_block``.

    Args:
        INPUT_DIM: Number of entries in the token embedding table.
    """

    def __init__(self, INPUT_DIM):
        super().__init__()
        self.INPUT_DIM = INPUT_DIM

        kv_width = NUM_KV_HEAD * HEAD_DIM
        mlp_width = int(EMBED_DIM * MLP_SCALE)

        self.TOKEN_EMBEDDING = nn.Embedding(INPUT_DIM, EMBED_DIM)

        # Projections are instantiated in q, k, v, out order, then the rotary
        # embedding, so parameter initialisation happens in a fixed sequence.
        query_proj = nn.Linear(EMBED_DIM, EMBED_DIM, bias=False)
        key_proj = nn.Linear(EMBED_DIM, kv_width, bias=False)
        value_proj = nn.Linear(EMBED_DIM, kv_width, bias=False)
        out_proj = nn.Linear(EMBED_DIM, EMBED_DIM, bias=False)
        rotary = RotaryPositionalEmbedding(
            dim=HEAD_DIM,
            max_seq_len=MAX_SEQUENCE,
            base=ROPE_BASE,
        )
        self.SELF_ATTENTION = CausalSelfAttention(
            embed_dim=EMBED_DIM,
            num_heads=NUM_HEAD,
            num_kv_heads=NUM_KV_HEAD,
            head_dim=HEAD_DIM,
            q_proj=query_proj,
            k_proj=key_proj,
            v_proj=value_proj,
            output_proj=out_proj,
            pos_embeddings=rotary,
            max_seq_len=MAX_SEQUENCE,
            attn_dropout=DROPOUT,
        )

        gate = nn.Linear(EMBED_DIM, mlp_width, bias=False)
        down = nn.Linear(mlp_width, EMBED_DIM, bias=False)
        up = nn.Linear(EMBED_DIM, mlp_width, bias=False)
        self.MLP = FeedForward(gate_proj=gate, down_proj=down, up_proj=up)

        # The layer is given a deep copy of self.MLP, so self.MLP itself
        # remains a separate registered module.
        self.ENCODER_LAYER = TransformerEncoderLayer(
            attn=self.SELF_ATTENTION,
            mlp=copy.deepcopy(self.MLP),
            sa_norm=RMSNorm(dim=EMBED_DIM, eps=EPS_NORM),
            mlp_norm=RMSNorm(dim=EMBED_DIM, eps=EPS_NORM),
        )

        self.encoder = TransformerEncoder(
            tok_embedding=self.TOKEN_EMBEDDING,
            layer=self.ENCODER_LAYER,
            num_layers=NUM_LAYER,
            max_seq_len=MAX_SEQUENCE,
            num_heads=NUM_HEAD,
            head_dim=HEAD_DIM,
            norm=RMSNorm(EMBED_DIM, eps=EPS_NORM),
        )

    def forward(self, x):
        """Run ``x`` through the wrapped TransformerEncoder and return its output."""
        return self.encoder(x)

In [4]:
class Decoder(nn.Module):
    """Target-side decoder: token embedding, a stack of decoder layers and an output projection.

    Each decoder layer receives two CausalSelfAttention modules (``attn1`` and
    ``attn2``) that share a single rotary positional embedding instance, plus
    a deep copy of the feed-forward block. Sizes other than the vocabulary
    size are read from the module-level hyper-parameter constants. The
    building blocks (CausalSelfAttention, FeedForward,
    TransformerDecoderLayer, TransformerDecoder, RMSNorm,
    RotaryPositionalEmbedding) are not defined in this notebook; presumably
    they come from the star import of ``llama3_transformer_block``.

    Args:
        OUTPUT_DIM: Size of the token embedding table and of the final
            linear projection's output.
    """

    def __init__(self, OUTPUT_DIM):
        super().__init__()
        self.OUTPUT_DIM = OUTPUT_DIM

        kv_width = NUM_KV_HEAD * HEAD_DIM
        mlp_width = int(EMBED_DIM * MLP_SCALE)

        self.TOKEN_EMBEDDING = nn.Embedding(OUTPUT_DIM, EMBED_DIM)

        gate = nn.Linear(EMBED_DIM, mlp_width, bias=False)
        down = nn.Linear(mlp_width, EMBED_DIM, bias=False)
        up = nn.Linear(EMBED_DIM, mlp_width, bias=False)
        self.MLP = FeedForward(gate_proj=gate, down_proj=down, up_proj=up)

        self.ROPE = RotaryPositionalEmbedding(
            dim=HEAD_DIM,
            max_seq_len=MAX_SEQUENCE,
            base=ROPE_BASE,
        )

        def build_attention():
            # Fresh q/k/v/out projections on every call; the rotary embedding
            # instance (self.ROPE) is shared between both attention modules.
            return CausalSelfAttention(
                embed_dim=EMBED_DIM,
                num_heads=NUM_HEAD,
                num_kv_heads=NUM_KV_HEAD,
                head_dim=HEAD_DIM,
                q_proj=nn.Linear(EMBED_DIM, EMBED_DIM, bias=False),
                k_proj=nn.Linear(EMBED_DIM, kv_width, bias=False),
                v_proj=nn.Linear(EMBED_DIM, kv_width, bias=False),
                output_proj=nn.Linear(EMBED_DIM, EMBED_DIM, bias=False),
                pos_embeddings=self.ROPE,
                max_seq_len=MAX_SEQUENCE,
                attn_dropout=DROPOUT,
            )

        self.SELF_ATTENTION_1 = build_attention()
        self.SELF_ATTENTION_2 = build_attention()

        # The layer is given a deep copy of self.MLP, so self.MLP itself
        # remains a separate registered module.
        self.DECODER_LAYER = TransformerDecoderLayer(
            attn1=self.SELF_ATTENTION_1,
            attn2=self.SELF_ATTENTION_2,
            mlp=copy.deepcopy(self.MLP),
            sa_norm_x1=RMSNorm(dim=EMBED_DIM, eps=EPS_NORM),
            sa_norm_x2=RMSNorm(dim=EMBED_DIM, eps=EPS_NORM),
            mlp_norm=RMSNorm(dim=EMBED_DIM, eps=EPS_NORM),
        )

        self.OUT_PROJECTION = nn.Linear(EMBED_DIM, OUTPUT_DIM, bias=False)

        self.decoder = TransformerDecoder(
            tok_embedding=self.TOKEN_EMBEDDING,
            layer=self.DECODER_LAYER,
            num_layers=NUM_LAYER,
            max_seq_len=MAX_SEQUENCE,
            num_heads=NUM_HEAD,
            head_dim=HEAD_DIM,
            norm=RMSNorm(EMBED_DIM, eps=EPS_NORM),
            output=self.OUT_PROJECTION,
        )

    def forward(self, x, encoder_out):
        """Run target input ``x`` and ``encoder_out`` through the wrapped TransformerDecoder."""
        return self.decoder(x, encoder_out)

In [5]:
class TranslationModel(nn.Module):
    """Sequence-to-sequence wrapper that chains an encoder module into a decoder module.

    Args:
        encoder: Module called as ``encoder(x)`` on the input sequence.
        decoder: Module called as ``decoder(y, encoder_out)``.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder, self.decoder = encoder, decoder

    def forward(self, x, y):
        """Encode input sequence ``x``, then decode target sequence ``y`` against it.

        Returns whatever the decoder returns.
        """
        return self.decoder(y, self.encoder(x))
    

In [6]:
# Relative paths to the two corpus files. Both are later read line by line and
# paired row-wise into one DataFrame, and each is used to train a BPE tokenizer.
vi_data_link = "viet-lao-dataset/data-lao-viet/data.vi"  # Vietnamese side
lao_data_link = "viet-lao-dataset/data-lao-viet/data.lo"  # Lao side

In [7]:
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
import pandas as pd

def train_bpe(input_file, tokenizer_name, vocab_size=8000):
    """Train a BPE tokenizer on one text file and write it to ``<tokenizer_name>.json``.

    Args:
        input_file: Path of the training text file.
        tokenizer_name: Output file stem; ``.json`` is appended.
        vocab_size: Target vocabulary size handed to the BPE trainer.
    """
    special_tokens = ["<pad>", "<s>", "</s>", "<unk>"]

    bpe_tokenizer = Tokenizer(BPE(unk_token="<unk>"))
    bpe_tokenizer.pre_tokenizer = Whitespace()
    bpe_tokenizer.train(
        [input_file],
        BpeTrainer(vocab_size=vocab_size, special_tokens=special_tokens),
    )
    bpe_tokenizer.save(f"{tokenizer_name}.json")

# Train one tokenizer per language: Lao first, then Vietnamese.
train_bpe(input_file=lao_data_link, tokenizer_name="lo_tokenizer")
train_bpe(input_file=vi_data_link, tokenizer_name="vi_tokenizer")


In [8]:
# Read both sides of the corpus; row i of each file ends up in row i of `df`.
with open(vi_data_link, 'r', encoding='utf-8') as f_vi, \
        open(lao_data_link, 'r', encoding='utf-8') as f_lo:
    vi_sentences = f_vi.readlines()
    lo_sentences = f_lo.readlines()

# Surrounding whitespace (including the trailing newline) is stripped per line.
df = pd.DataFrame({
    'vi': list(map(str.strip, vi_sentences)),
    'lo': list(map(str.strip, lo_sentences)),
})

# Reload the tokenizers from the JSON files written by train_bpe.
lo_tok = Tokenizer.from_file("lo_tokenizer.json")
vi_tok = Tokenizer.from_file("vi_tokenizer.json")

def encode_sentence(tok, sentence, add_special_tokens=True):
    """Convert a sentence to a list of token ids.

    Args:
        tok: Tokenizer exposing ``encode(text).ids`` and ``token_to_id``.
        sentence: Text to encode.
        add_special_tokens: When true, wrap the ids with the ids of ``<s>``
            and ``</s>``.

    Returns:
        The id list, with or without the begin/end markers.
    """
    body_ids = tok.encode(sentence).ids
    if not add_special_tokens:
        return body_ids
    return [tok.token_to_id("<s>"), *body_ids, tok.token_to_id("</s>")]

# Add one id-list column per language (ids include the <s> ... </s> wrapping).
df["lo_ids"] = df["lo"].map(lambda sentence: encode_sentence(lo_tok, sentence))
df["vi_ids"] = df["vi"].map(lambda sentence: encode_sentence(vi_tok, sentence))

print(df.head())


                                                  vi  \
0  Phán quyết của Tòa án quốc tế (PCA) năm 2016 l...   
1  8 di sản văn hóa có hình dạng: Nhóm Di tích Tù...   
2                                000 bà mẹ còn sống.   
3  Việt Nam tổ chức thành công Diễn đàn Kinh tế T...   
4  Một lần nữa, tại cuộc bầu cử lần này, số lượng...   

                                                  lo  \
0  ຄຳຕັດສິນຂອງສານກຳມະການສາກົນ (PCA) ປີ 2016 ແມ່ນສ...   
1  8 ມໍ​ລະ​ດົກ​​ວັດ​ທະ​ນະ​ທຳ​ມີ​ຮູບ​ຮ່າງຄື : ກຸ່ມ...   
2                           000 ແມ່ທີ່ຍັງມີຊີວິດຢູ່.   
3  ຫວຽດນາມ ຈັດຕັ້ງເວທີປາໄສເສດຖະກິດໂລກກ່ຽວກັບອາຊຽນ...   
4  ອັນໜຶ່ງອີກ, ທີ່ການເລືອກຕັ້ງຄັ້ງນີ້, ຈຳນວນພັກກາ...   

                                              lo_ids  \
0  [1, 4047, 6857, 1181, 770, 10, 4968, 11, 212, ...   
1  [1, 26, 1344, 187, 202, 187, 2324, 285, 512, 1...   
2         [1, 395, 1086, 2992, 234, 924, 272, 16, 2]   
3  [1, 256, 1030, 6033, 713, 212, 561, 3799, 6033...   
4  [1, 554, 837, 877, 14, 238, 2196, 2366, 14,

In [9]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

class TranslationDataset(Dataset):
    """Paired dataset of source and target token-id sequences.

    Args:
        src_ids: Sequence of source id lists.
        tgt_ids: Sequence of target id lists, aligned index-by-index with
            ``src_ids``.

    Raises:
        ValueError: If the two sequences do not have the same length.
    """

    def __init__(self, src_ids, tgt_ids):
        # A length mismatch would otherwise surface later as an IndexError
        # (or as silently ignored targets) in the middle of an epoch.
        if len(src_ids) != len(tgt_ids):
            raise ValueError(
                f"src_ids and tgt_ids must have the same length "
                f"(got {len(src_ids)} and {len(tgt_ids)})"
            )
        self.src = src_ids
        self.tgt = tgt_ids

    def __len__(self):
        return len(self.src)

    def __getitem__(self, idx):
        """Return the ``idx``-th (source, target) pair as two int64 tensors."""
        # dtype is pinned because torch.tensor([]) would otherwise default to
        # float32 for an empty id list.
        return (
            torch.tensor(self.src[idx], dtype=torch.long),
            torch.tensor(self.tgt[idx], dtype=torch.long),
        )

def collate_fn(batch):
    """Pad a list of (source, target) tensor pairs into two batch-first tensors.

    Sources are padded with the ``<pad>`` id of ``vi_tok`` and targets with
    the ``<pad>`` id of ``lo_tok`` (both module-level tokenizers).

    Returns:
        Tuple ``(padded_sources, padded_targets)``.
    """
    sources, targets = zip(*batch)
    padded_sources = pad_sequence(
        sources, batch_first=True, padding_value=vi_tok.token_to_id("<pad>")
    )
    padded_targets = pad_sequence(
        targets, batch_first=True, padding_value=lo_tok.token_to_id("<pad>")
    )
    return padded_sources, padded_targets

BATCH_SIZE = 16


from sklearn.model_selection import train_test_split

# Hold out 10% of the sentence pairs for validation, with a fixed seed.
vi_id_lists = df["vi_ids"].tolist()
lo_id_lists = df["lo_ids"].tolist()
src_train, src_valid, tgt_train, tgt_valid = train_test_split(
    vi_id_lists, lo_id_lists, test_size=0.1, random_state=42
)

train_dataset = TranslationDataset(src_train, tgt_train)
valid_dataset = TranslationDataset(src_valid, tgt_valid)

# Both loaders share batch size and collation; only the training loader shuffles.
loader_options = dict(batch_size=BATCH_SIZE, collate_fn=collate_fn)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
valid_loader = DataLoader(valid_dataset, shuffle=False, **loader_options)


In [10]:
# Vocabulary sizes reported by the two loaded tokenizers.
vi_vocab_size, lo_vocab_size = vi_tok.get_vocab_size(), lo_tok.get_vocab_size()

In [11]:
# The encoder is sized by the Vietnamese vocabulary, the decoder by the Lao one.
INPUT_DIM, OUTPUT_DIM = vi_vocab_size, lo_vocab_size

In [None]:
import torch
import torch.nn as nn
from tqdm import tqdm

# Build the model and move it to the available device.
encoder = Encoder(INPUT_DIM)
decoder = Decoder(OUTPUT_DIM)
model = TranslationModel(encoder, decoder)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of parameters: {num_params / 1e6:.2f}M")

# Optimizer and loss; padded target positions are excluded from the loss.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss(ignore_index=lo_tok.token_to_id("<pad>"))

# The scheduler's `verbose` argument is deprecated in recent PyTorch releases,
# so it is not passed; the learning rate is printed explicitly after each step.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

num_epochs = 100
best_val_loss = float('inf')

for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    running_loss = 0.0
    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

    for src_batch, tgt_batch in loop:
        src_batch, tgt_batch = src_batch.to(device), tgt_batch.to(device)
        optimizer.zero_grad()

        # Teacher forcing: the decoder sees tokens [0, T-1) and is scored
        # against tokens [1, T).
        output = model(src_batch, tgt_batch[:, :-1])  # output shape: (B, T, vocab)
        loss = criterion(output.reshape(-1, output.size(-1)), tgt_batch[:, 1:].reshape(-1))
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        loop.set_postfix(loss=loss.item())

    # Mean of the per-batch losses; max(..., 1) avoids dividing by zero on an
    # empty loader.
    train_loss = running_loss / max(len(train_loader), 1)

    # ---- validation pass ----
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for src_batch, tgt_batch in valid_loader:
            src_batch, tgt_batch = src_batch.to(device), tgt_batch.to(device)
            output = model(src_batch, tgt_batch[:, :-1])
            loss = criterion(output.reshape(-1, output.size(-1)), tgt_batch[:, 1:].reshape(-1))
            val_loss += loss.item()

    val_loss /= max(len(valid_loader), 1)
    print(f"Train Loss: {train_loss:.4f}")
    print(f"Validation Loss: {val_loss:.4f}")

    # Scheduler update (halves the LR after `patience` epochs without improvement).
    scheduler.step(val_loss)
    print(f"Learning rate: {optimizer.param_groups[0]['lr']:.2e}")

    # Keep the best checkpoint by validation loss, plus the latest one.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "best_model.pt")
        print("Saved best model.")

    torch.save(model.state_dict(), "last_model.pt")


Number of parameters: 57.13M


  from .autonotebook import tqdm as notebook_tqdm
  output = nn.functional.scaled_dot_product_attention(
Epoch 1/100:   6%|▋         | 44/689 [00:07<01:45,  6.09it/s, loss=5.72]


KeyboardInterrupt: 

In [13]:
def translate(model, src_sentence, src_tokenizer, tgt_tokenizer, device, max_len=50):
    """Greedy-decode a translation of ``src_sentence``.

    Args:
        model: Object exposing ``eval()``, ``encoder(src)`` and
            ``decoder(tgt, encoder_out)``; the decoder output is indexed as
            ``output[0, -1]`` to obtain the logits for the next token.
        src_sentence: Source text.
        src_tokenizer: Tokenizer used to encode the source text.
        tgt_tokenizer: Tokenizer used to map generated ids back to tokens.
        device: Device the input tensors are moved to.
        max_len: Maximum number of tokens to generate.

    Returns:
        The generated tokens joined by single spaces, without the ``<s>`` and
        ``</s>`` markers. When ``max_len`` is reached before ``</s>`` is
        produced, every generated token is kept.
    """
    bos_id = tgt_tokenizer.token_to_id("<s>")
    eos_id = tgt_tokenizer.token_to_id("</s>")  # looked up once, not per step

    model.eval()
    with torch.no_grad():
        # Encode source sentence, wrapped in <s> ... </s>.
        src_ids = (
            [src_tokenizer.token_to_id("<s>")]
            + src_tokenizer.encode(src_sentence).ids
            + [src_tokenizer.token_to_id("</s>")]
        )
        src_tensor = torch.tensor(src_ids).unsqueeze(0).to(device)  # shape: (1, src_len)
        encoder_out = model.encoder(src_tensor)

        # tgt_ids always starts with <s> and never contains </s>.
        tgt_ids = [bos_id]
        for _ in range(max_len):
            tgt_tensor = torch.tensor(tgt_ids).unsqueeze(0).to(device)
            output = model.decoder(tgt_tensor, encoder_out)

            # Pick the highest-scoring token at the last position.
            next_token_id = torch.argmax(output[0, -1]).item()

            # Stop on </s> without storing it, so the slice below only has to
            # drop the leading <s>. The previous `[1:-1]` slice also discarded
            # the last real token whenever decoding stopped at max_len.
            if next_token_id == eos_id:
                break
            tgt_ids.append(next_token_id)

    return " ".join(tgt_tokenizer.id_to_token(tid) for tid in tgt_ids[1:])


In [14]:
import torch
from nltk.translate.bleu_score import corpus_bleu
from nltk.translate.meteor_score import meteor_score
from rouge_score import rouge_scorer
from bert_score import score as bert_score
from tqdm import tqdm

class _WhitespaceTokenizer:
    """Minimal tokenizer for rouge_score: split on whitespace, keep every script."""

    def tokenize(self, text):
        return text.split()


def _tokens_without_specials(tokenizer, ids, special_ids):
    """Map a 1-D tensor of ids to token strings, skipping ids in ``special_ids``."""
    return [tokenizer.id_to_token(i) for i in ids.tolist() if i not in special_ids]


def evaluate_model(model, data_loader, src_tokenizer, tgt_tokenizer, device, max_length=50,
                   bertscore_lang="lo"):
    """
    Evaluate the translation model using BLEU, METEOR, ROUGE-L, and BERTScore metrics.

    All token-based metrics are computed over the tokenizer's sub-word tokens
    (the same units ``translate`` joins with spaces).

    Args:
        model: Trained translation model
        data_loader: DataLoader yielding (source_batch, target_batch) id tensors
        src_tokenizer: Tokenizer for source language
        tgt_tokenizer: Tokenizer for target language
        device: Device to run the model on (cuda/cpu)
        max_length: Maximum length for generated sequences
        bertscore_lang: Language code handed to BERTScore to select its model.
            Defaults to "lo" because this notebook translates into Lao; the
            previous hard-coded "en" selected an English-only model.

    Returns:
        Dictionary containing evaluation metrics, keyed by 'BLEU', 'METEOR',
        'ROUGE-L' and 'BERTScore_F1'.

    Raises:
        ValueError: If ``data_loader`` yields no examples.
    """
    model.eval()

    # Look the special ids up once instead of once per token.
    special_names = ("<s>", "</s>", "<pad>")
    src_special = {src_tokenizer.token_to_id(name) for name in special_names}
    tgt_special = {tgt_tokenizer.token_to_id(name) for name in special_names}

    reference_tokens = []   # one token list per example
    hypothesis_tokens = []  # one token list per example
    hypotheses = []         # space-joined hypothesis strings

    for src_batch, tgt_batch in tqdm(data_loader, desc="Evaluating"):
        for src_ids, tgt_ids in zip(src_batch, tgt_batch):
            # Rebuild a source string from the tokens for translate().
            # NOTE(review): translate() re-encodes this string, which may not
            # reproduce the original ids exactly - confirm this is acceptable.
            src_sentence = " ".join(_tokens_without_specials(src_tokenizer, src_ids, src_special))

            # Generate translation
            translated = translate(model, src_sentence, src_tokenizer, tgt_tokenizer,
                                   device, max_length)
            hypotheses.append(translated)
            hypothesis_tokens.append(translated.split())

            # Decode reference
            reference_tokens.append(_tokens_without_specials(tgt_tokenizer, tgt_ids, tgt_special))

    if not hypotheses:
        raise ValueError("data_loader yielded no examples to evaluate")

    reference_texts = [" ".join(tokens) for tokens in reference_tokens]
    metrics = {}

    # BLEU score: corpus_bleu expects, per example, a list of tokenised
    # references and a tokenised hypothesis. Passing raw strings makes it
    # score character n-grams.
    metrics['BLEU'] = corpus_bleu([[ref] for ref in reference_tokens], hypothesis_tokens)

    # METEOR score: the first argument is a list of tokenised references.
    meteor_scores = [meteor_score([ref], hyp)
                     for ref, hyp in zip(reference_tokens, hypothesis_tokens)]
    metrics['METEOR'] = sum(meteor_scores) / len(meteor_scores)

    # ROUGE-L score: the default tokenizer keeps only [a-z0-9] and the stemmer
    # is English-only, so a whitespace tokenizer is supplied and stemming is
    # disabled to keep non-Latin tokens.
    rouge_scorer_obj = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=False,
                                                tokenizer=_WhitespaceTokenizer())
    rouge_scores = [rouge_scorer_obj.score(ref, hyp)['rougeL'].fmeasure
                    for ref, hyp in zip(reference_texts, hypotheses)]
    metrics['ROUGE-L'] = sum(rouge_scores) / len(rouge_scores)

    # BERTScore
    P, R, F1 = bert_score(hypotheses, reference_texts,
                          lang=bertscore_lang, verbose=False)
    metrics['BERTScore_F1'] = F1.mean().item()

    return metrics

def print_metrics(metrics):
    """Write a header, a 30-dash rule, then one ``name: value`` line per metric (4 decimals)."""
    header = "\nEvaluation Metrics:"
    rule = "-" * 30
    print(header)
    print(rule)
    for name in metrics:
        print("{}: {:.4f}".format(name, metrics[name]))

In [None]:
# Score the current model on the validation split (Vietnamese -> Lao).
metrics = evaluate_model(
    model=model,
    data_loader=valid_loader,
    src_tokenizer=vi_tok,
    tgt_tokenizer=lo_tok,
    device=device,
)
print_metrics(metrics)

Evaluating:   6%|▋         | 5/77 [00:26<06:13,  5.19s/it]

In [None]:
import heapq

def translate_beam_search(model, src_sentence, src_tokenizer, tgt_tokenizer, device, beam_width=3, max_len=50):
    """Translate ``src_sentence`` with beam search.

    Hypotheses are scored by the sum of per-token log-probabilities
    (log-softmax of the decoder logits at the last position). A hypothesis
    that has produced ``</s>`` is finished: it stays in the beam with its
    score frozen and is not extended further.

    Args:
        model: Object exposing ``eval()``, ``encoder(src)`` and
            ``decoder(tgt, encoder_out)``; the decoder output is indexed as
            ``output[0, -1]`` to obtain the logits for the next token.
        src_sentence: Source text.
        src_tokenizer: Tokenizer used to encode the source text.
        tgt_tokenizer: Tokenizer used to map generated ids back to tokens.
        device: Device the input tensors are moved to.
        beam_width: Number of hypotheses kept after every step.
        max_len: Maximum number of decoding steps.

    Returns:
        Tokens of the best-scoring hypothesis joined by single spaces, without
        the leading ``<s>`` and without a trailing ``</s>`` (if one was
        produced).
    """
    bos_id = tgt_tokenizer.token_to_id("<s>")
    eos_id = tgt_tokenizer.token_to_id("</s>")

    model.eval()
    with torch.no_grad():
        src_ids = (
            [src_tokenizer.token_to_id("<s>")]
            + src_tokenizer.encode(src_sentence).ids
            + [src_tokenizer.token_to_id("</s>")]
        )
        src_tensor = torch.tensor(src_ids).unsqueeze(0).to(device)
        encoder_out = model.encoder(src_tensor)

        beams = [(0.0, [bos_id])]  # (cumulative log-probability, token_ids)

        for _ in range(max_len):
            candidates = []
            for score, seq in beams:
                # Finished hypotheses are carried over unchanged; extending
                # them would append tokens after </s>.
                if seq[-1] == eos_id:
                    candidates.append((score, seq))
                    continue

                tgt_tensor = torch.tensor(seq).unsqueeze(0).to(device)
                output = model.decoder(tgt_tensor, encoder_out)

                # Raw logits are not comparable across steps; normalise them
                # to log-probabilities before accumulating.
                log_probs = torch.log_softmax(output[0, -1], dim=-1)
                # Clamp k so a beam wider than the vocabulary does not fail.
                topk = torch.topk(log_probs, min(beam_width, log_probs.size(-1)))

                for token_score, token_id in zip(topk.values.tolist(), topk.indices.tolist()):
                    candidates.append((score + token_score, seq + [token_id]))

            beams = heapq.nlargest(beam_width, candidates, key=lambda item: item[0])

            if all(seq[-1] == eos_id for _, seq in beams):
                break

        best_seq = max(beams, key=lambda item: item[0])[1]

    # Drop the leading <s>; drop </s> only when it is actually present, so a
    # hypothesis cut off at max_len keeps its last real token.
    tokens = best_seq[1:]
    if tokens and tokens[-1] == eos_id:
        tokens = tokens[:-1]
    return " ".join(tgt_tokenizer.id_to_token(i) for i in tokens)


In [None]:
# Quick smoke test: greedy-translate one short Vietnamese phrase.
sentence = "di sản văn hóa"
translated = translate(
    model,
    sentence,
    src_tokenizer=vi_tok,
    tgt_tokenizer=lo_tok,
    device=device,
)
print("Translated:", translated)


Translated: 000 ຄົນ , ທ່ານ ​ ໃນ ​ ການ ​ ຄ້າ ​ ປະ ​ ເທດ ​ ໃນ ​ ການ ​ ຄ້າ ​ ໃນ ​ ການ ​ ຄ້າ ​ ການ ​ ຄ້າ ​ ສົ່ງ ​ ອອກ ​ ໃນ ​ ການ ​ ຄ້າ ​ ຮ່ວມ ​ ການ ​ ຄ້າ ​ ໃນ ​
