In [1]:
!git clone https://github.com/UniversalDependencies/UD_Vietnamese-VTB.git

Cloning into 'UD_Vietnamese-VTB'...


remote: Enumerating objects: 802, done.[K
remote: Counting objects: 100% (39/39), done.[K
remote: Compressing objects: 100% (16/16), done.[K
remote: Total 802 (delta 32), reused 27 (delta 23), pack-reused 763 (from 3)[K
Receiving objects: 100% (802/802), 20.76 MiB | 14.19 MiB/s, done.
Resolving deltas: 100% (431/431), done.


In [None]:
!pip install conllu
!pip install torch
!pip install numpy
!pip install gensim
!pip install matplotlib

[0m

In [2]:
!pip install transformers

[0m

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from conllu import parse_incr
from collections import Counter
import numpy as np
import random
import os
import copy
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer


In [2]:
import torch
import numpy as np
import random
import os

def set_seed(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # N·∫øu d√πng multi-GPU
    
    # ƒê·∫£m b·∫£o t√≠nh to√°n tr√™n cuDNN l√† deterministic (c√≥ th·ªÉ l√†m ch·∫≠m t·ªëc ƒë·ªô train m·ªôt ch√∫t)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
    print(f"‚úÖ ƒê√£ thi·∫øt l·∫≠p random seed: {seed}")

set_seed(42)

‚úÖ ƒê√£ thi·∫øt l·∫≠p random seed: 42


In [3]:
def map_sentence_to_ids(sentences, p_vocab, d_vocab):
    """Chuy·ªÉn ƒë·ªïi to√†n b·ªô nh√£n UPOS v√† DEPREL trong c√¢u sang ID tr∆∞·ªõc khi training."""
    for sent in sentences:
        for token in sent:
            token['upos_id'] = p_vocab.get(token['upos'], p_vocab.get("<UNK>", 0))
            token['deprel_id'] = d_vocab.get(token['deprel'], 0)
    return sentences

In [35]:
import random
import copy

class Oracle:
    def __init__(self, vocab_pos, vocab_deps):
        self.p_vocab = vocab_pos
        self.d_vocab = vocab_deps
        # Danh s√°ch ID c√°c nh√£n quan h·ªá ƒë·ªÉ gi·∫£ l·∫≠p vi·ªác "g√°n nh·∫ßm"
        self.all_dep_ids = [v for k, v in vocab_deps.items() if k not in ["<NULL>", "<ROOT>", "<UNK>"]]
        self.action_map = self._build_action_map()

    def _build_action_map(self):
        mapping = {"SHIFT": 0}
        idx = 1
        for dep in self.d_vocab:
            if dep not in ["<NULL>", "<ROOT>", "<UNK>"]:
                mapping[f"LEFT-ARC_{dep}"] = idx
                idx += 1
                mapping[f"RIGHT-ARC_{dep}"] = idx
                idx += 1
        return mapping

    def extract_features(self, stack, buffer, sentence, sent_idx, left_children_map, right_children_map, 
                         stack_size=3, buffer_size=3, mutate=False):
        # 1. STACK VIEW
        s_window = stack[-stack_size:] if stack else [0]
        if len(s_window) < stack_size:
            # Pad LEFT ƒë·ªÉ c√°c ph·∫ßn t·ª≠ quan tr·ªçng (s0, s1) n·∫±m ·ªü cu·ªëi list gi·ªëng nh∆∞ stack th·∫≠t
            s_window = [None] * (stack_size - len(s_window)) + s_window 
            
        s_feats = {
            'words': [], 'pos': [],
            'l_child_idx': [], 'l_child_dep': [],  'l_child_pos': [],
            'r_child_idx': [], 'r_child_dep': [],  'r_child_pos': [] 
        }

        for i, idx in enumerate(s_window):
            if idx is None:
                for k in s_feats: s_feats[k].append(0)
                continue

            s_feats['words'].append(idx)
            if idx == 0: # ROOT
                s_feats['pos'].append(0)
                for k in ['l_child_idx', 'l_child_dep', 'l_child_pos', 'r_child_idx', 'r_child_dep', 'r_child_pos']:
                    s_feats[k].append(0)
            else:
                token = sentence[idx - 1]
                s_feats['pos'].append(token['upos_id'])

                l_children = left_children_map.get(idx, [])
                r_children = right_children_map.get(idx, [])

                # L·∫•y th√¥ng tin con th·ª±c t·∫ø (Gold)
                l_idx, l_dep, l_pos = l_children[0] if l_children else (0, 0, 0)
                r_idx, r_dep, r_pos = r_children[-1] if r_children else (0, 0, 0)

                # --- LOGIC MUTATION (GI·∫¢ L·∫¨P SAI B∆Ø·ªöC TR∆Ø·ªöC) ---
                # i < stack_size - 1 ƒë·∫£m b·∫£o ch·ªâ mutate s1, s2... (c√°c t·ª´ ƒë√£ n·∫±m l√¢u tr√™n stack)
                if mutate and i < (stack_size - 1):
                    # Gi·∫£ l·∫≠p vi·ªác g√°n nh·∫ßm nh√£n quan h·ªá cho c√°c con c·ªßa s1, s2
                    if l_dep != 0 and random.random() < 0.5:
                        l_dep = random.choice(self.all_dep_ids)
                    if r_dep != 0 and random.random() < 0.5:
                        r_dep = random.choice(self.all_dep_ids)
                # -----------------------------------------------

                s_feats['l_child_idx'].append(l_idx)
                s_feats['l_child_dep'].append(l_dep)
                s_feats['l_child_pos'].append(l_pos)
                s_feats['r_child_idx'].append(r_idx)
                s_feats['r_child_dep'].append(r_dep)
                s_feats['r_child_pos'].append(r_pos)

        # 2. BUFFER VIEW
        b_window = buffer[:buffer_size] if buffer else [0]
        if len(b_window) < buffer_size:
            b_window = b_window + [None] * (buffer_size - len(b_window))
            
        b_feats = {'words': [], 'pos': []}
        for idx in b_window:
            if idx is None:
                b_feats['words'].append(0); b_feats['pos'].append(0)
                continue
            b_feats['words'].append(idx)
            b_feats['pos'].append(0 if idx == 0 else sentence[idx-1]['upos_id'])

        return {'stack': s_feats, 'buffer': b_feats, 'sent_id': sent_idx}

    def create_training_data_multiview(self, sentence, sent_idx, aug_p=0.2):
        """
        aug_p: X√°c su·∫•t t·∫°o th√™m m·ªôt m·∫´u d·ªØ li·ªáu nhi·ªÖu t·∫°i m·ªói b∆∞·ªõc transition.
        """
        stack, buffer = [0], list(range(1, len(sentence) + 1))
        left_children, right_children = {}, {}
        train_examples = []
        gold_heads = {token['id']: token['head'] for token in sentence}

        while buffer or len(stack) > 1:
            # 1. Tr√≠ch xu·∫•t ƒë·∫∑c tr∆∞ng GOLD (D·ªØ li·ªáu s·∫°ch)
            feat_gold = self.extract_features(stack, buffer, sentence, sent_idx, left_children, right_children, mutate=False)

            # --- X√°c ƒë·ªãnh Action ID ƒë√∫ng ---
            target_action_id = None
            if len(stack) >= 2:
                s1, s2 = stack[-1], stack[-2]
                if gold_heads.get(s2) == s1: # Left-Arc
                    token_s2 = sentence[s2 - 1]
                    target_action_id = self.action_map.get(f"LEFT-ARC_{token_s2['deprel']}", 0)
                elif gold_heads.get(s1) == s2 and not any(gold_heads.get(b) == s1 for b in buffer): # Right-Arc
                    token_s1 = sentence[s1 - 1]
                    target_action_id = self.action_map.get(f"RIGHT-ARC_{token_s1['deprel']}", 0)
                elif buffer: # Shift
                    target_action_id = self.action_map["SHIFT"]
            elif buffer:
                target_action_id = self.action_map["SHIFT"]

            if target_action_id is None: break

            # 2. TH√äM D·ªÆ LI·ªÜU V√ÄO T·∫¨P HU·∫§N LUY·ªÜN
            # Th√™m m·∫´u Gold (lu√¥n gi·ªØ nguy√™n d·ªØ li·ªáu g·ªëc)
            train_examples.append((feat_gold, target_action_id))

            # Augmentation: Th√™m m·∫´u b·ªã mutation (gi·∫£ l·∫≠p sai b∆∞·ªõc tr∆∞·ªõc)
            if aug_p > 0 and random.random() < aug_p:
                feat_mutated = self.extract_features(stack, buffer, sentence, sent_idx, 
                                                     left_children, right_children, mutate=True)
                # D·∫°y m√¥ h√¨nh: "D√π ƒë·∫∑c tr∆∞ng n√†y b·ªã nhi·ªÖu (mutation), h√†nh ƒë·ªông ƒë√∫ng v·∫´n l√† target_action_id"
                train_examples.append((feat_mutated, target_action_id))

            # 3. C·∫¨P NH·∫¨T TR·∫†NG TH√ÅI (S·ª≠ d·ª•ng Gold Action ƒë·ªÉ ƒëi ti·∫øp - Teacher Forcing)
            if target_action_id == self.action_map["SHIFT"]:
                stack.append(buffer.pop(0))
            else:
                action_name = [k for k, v in self.action_map.items() if v == target_action_id][0]
                if action_name.startswith("LEFT-ARC"):
                    child = stack.pop(-2)
                    head = stack[-1]
                    if head not in left_children: left_children[head] = []
                    left_children[head].append((child, sentence[child-1]['deprel_id'], sentence[child-1]['upos_id']))
                else: # RIGHT-ARC
                    child = stack.pop(-1)
                    head = stack[-1]
                    if head not in right_children: right_children[head] = []
                    right_children[head].append((child, sentence[child-1]['deprel_id'], sentence[child-1]['upos_id']))

        return train_examples

In [36]:
def load_data_and_build_vocab(file_path, is_train=True, word_v=None, pos_v=None, dep_v=None):
    sentences = []
    word_counts = Counter()
    pos_counts = Counter()
    dep_counts = Counter()

    for path in file_path:

        with open(path, "r", encoding="utf-8") as f:
          for tokenlist in parse_incr(f):
              # Ch·ªâ th√™m nh·ªØng c√¢u c√≥ n·ªôi dung
              if len(tokenlist) > 0:
                  if not is_train or 'train' in path:
                      sentences.append(tokenlist)
                  if is_train:
                      for token in tokenlist:
                          word_counts[token['form'].lower()] += 1
                          pos_counts[token['upos']] += 1
                          dep_counts[token['deprel']] += 1

    if is_train:
        # Kh·ªüi t·∫°o ID theo y√™u c·∫ßu: NULL=0, ROOT=1, UNK=2
        def create_mapping(counts, is_dep=False):
            vocab = {"<NULL>": 0, "<ROOT>": 1}
            if not is_dep: vocab["<UNK>"] = 2
            for item, count in counts.items():
                if item not in vocab and (is_dep or count > 1):
                    vocab[item] = len(vocab)
            return vocab

        word_v = create_mapping(word_counts)
        pos_v = create_mapping(pos_counts)
        dep_v = create_mapping(dep_counts, is_dep=True)
        return sentences, word_v, pos_v, dep_v

    return sentences

In [37]:
def preprocess_data(train_path, dev_path, test_path):
    # B∆Ø·ªöC 1: Load d·ªØ li·ªáu v√† ch·ªâ x√¢y d·ª±ng Vocab cho POS v√† Dep
    # B·∫°n c√≥ th·ªÉ d√πng l·∫°i h√†m load_data_and_build_vocab c≈© nh∆∞ng b·ªè qua ph·∫ßn word_v
    print("ƒêang ƒë·ªçc t·∫≠p Train v√† x√¢y d·ª±ng Vocab cho POS/Dep...")
    train_sentences, _, p_vocab, d_vocab = load_data_and_build_vocab([train_path], is_train=True)

    # B∆Ø·ªöC 2: Kh·ªüi t·∫°o Oracle m·ªõi (Kh√¥ng d√πng w_vocab)
    # ƒê·∫£m b·∫£o b·∫°n ƒë√£ c·∫≠p nh·∫≠t l·ªõp Oracle nh∆∞ h∆∞·ªõng d·∫´n ·ªü b∆∞·ªõc tr∆∞·ªõc
    oracle = Oracle(p_vocab, d_vocab)

    # B∆Ø·ªöC 3: T·∫°o d·ªØ li·ªáu hu·∫•n luy·ªán (X, y) d·∫°ng Sequence
    all_training_data = []
    print(f"ƒêang t·∫°o chu·ªói transitions cho {len(train_sentences)} c√¢u t·∫≠p Train...")

    for idx, sentence in enumerate(train_sentences):
        try:
            # create_training_data b√¢y gi·ªù tr·∫£ v·ªÅ danh s√°ch c√°c dict ch·ª©a chu·ªói indices
            examples = oracle.create_training_data(sentence, idx)
            all_training_data.extend(examples)
        except Exception as e:
            print(f"L·ªói logic t·∫°i c√¢u: {sentence.metadata['text']} - {e}")
            continue

    print(f"‚úÖ Th√†nh c√¥ng! T·ªïng s·ªë b∆∞·ªõc transition hu·∫•n luy·ªán: {len(all_training_data)}")

    # B∆Ø·ªöC 4: Load t·∫≠p Dev v√† Test (gi·ªØ nguy√™n c√¢u g·ªëc ƒë·ªÉ PhoBERT encode sau n√†y)
    print("ƒêang chu·∫©n b·ªã t·∫≠p Dev v√† Test...")
    dev_sentences = load_data_and_build_vocab([dev_path], is_train=False)
    test_sentences = load_data_and_build_vocab([test_path], is_train=False)

    print(f"S·ªë c√¢u: Train={len(train_sentences)}, Dev={len(dev_sentences)}, Test={len(test_sentences)}")

    # Tr·∫£ v·ªÅ th√™m train_sentences ƒë·ªÉ t√≠ n·ªØa d√πng l·∫•y PhoBERT embedding
    return all_training_data, (p_vocab, d_vocab), (train_sentences, dev_sentences, test_sentences)

In [38]:
import torch
from torch.utils.data import Dataset

class TransitionDataset(Dataset):
    def __init__(self, training_examples):
        self.examples = training_examples

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        feat, action_id = self.examples[idx]

        # Stack features
        s_words = torch.tensor(feat['stack']['words'], dtype=torch.long)
        s_pos = torch.tensor(feat['stack']['pos'], dtype=torch.long)
        s_lc_idx = torch.tensor(feat['stack']['l_child_idx'], dtype=torch.long)
        s_lc_dep = torch.tensor(feat['stack']['l_child_dep'], dtype=torch.long) # M·ªöI
        s_lc_pos = torch.tensor(feat['stack']['l_child_pos'], dtype=torch.long) # M·ªöI
        s_rc_idx = torch.tensor(feat['stack']['r_child_idx'], dtype=torch.long)
        s_rc_dep = torch.tensor(feat['stack']['r_child_dep'], dtype=torch.long) # M·ªöI
        s_rc_pos = torch.tensor(feat['stack']['r_child_pos'], dtype=torch.long) # M·ªöI

        # Buffer features
        b_words = torch.tensor(feat['buffer']['words'], dtype=torch.long)
        b_pos = torch.tensor(feat['buffer']['pos'], dtype=torch.long)

        return (s_words, s_pos, s_lc_idx, s_lc_dep, s_lc_pos, s_rc_idx, s_rc_dep, s_rc_pos,
                b_words, b_pos, torch.tensor(action_id), feat['sent_id'])

In [39]:
from torch.nn.utils.rnn import pad_sequence
def multiview_collate_fn(batch):
    # Unpack 11 th√†nh ph·∫ßn
    (s_words, s_pos, s_lc_idx, s_lc_dep, s_lc_pos, s_rc_idx, s_rc_dep, s_rc_pos,
     b_words, b_pos, labels, sent_ids) = zip(*batch)

    # Helper function ƒë·ªÉ pad cho g·ªçn
    def pad(tensors): return pad_sequence(tensors, batch_first=True, padding_value=0)

    return (
        pad(s_words), pad(s_pos),
        pad(s_lc_idx), pad(s_lc_dep), pad(s_lc_pos), # Pad LC
        pad(s_rc_idx), pad(s_rc_dep), pad(s_rc_pos), # Pad RC
        pad(b_words), pad(b_pos),
        torch.stack(labels), sent_ids
    )

In [40]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SubViewProjector(nn.Module):
    """Chi·∫øu c√°c nh√≥m ƒë·∫∑c tr∆∞ng (Word+POS/Dep) v·ªÅ c√πng m·ªôt kh√¥ng gian ·∫©n."""
    def __init__(self, input_dim, hidden_dim, output_dim, dropout_rate=0.2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Mish(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, output_dim),
            nn.LayerNorm(output_dim),
            nn.Mish(),
            nn.Dropout(dropout_rate)
        )
    def forward(self, x):
        return self.net(x)
        
class GatedFusion(nn.Module):
    """S·ª≠ d·ª•ng c·ªïng Sigmoid ƒë·ªÉ ki·ªÉm so√°t lu·ªìng th√¥ng tin t·ª´ Main v√† c√°c con."""
    def __init__(self, dim):
        super().__init__()
        # Input l√† 3 views n·ªëi t·∫ßng (Main, LC, RC) -> Output l√† 3 b·ªô tr·ªçng s·ªë t∆∞∆°ng ·ª©ng
        self.gate_layer = nn.Sequential(
            nn.Linear(dim * 3, dim * 3),
            nn.Sigmoid()
        )
        self.project = nn.Linear(dim * 3, dim)
        self.layer_norm = nn.LayerNorm(dim)

    def forward(self, views):
        # views shape: [batch, 3, dim] (ƒë√£ ƒë∆∞·ª£c tr√≠ch xu·∫•t cho t·ª´ng token tr√™n stack)
        # L∆∞u √Ω: ·ªû ƒë√¢y x·ª≠ l√Ω cho t·ª´ng v·ªã tr√≠ tr√™n c·ª≠a s·ªï stack (S0, S1, S2)
        batch, n_tokens, n_views, dim = views.shape
        
        # Flatten views: [batch, n_tokens, dim * 3]
        flat_views = views.view(batch, n_tokens, -1)
        
        # T√≠nh to√°n b·ªô l·ªçc (gate)
        gates = self.gate_layer(flat_views) # [batch, n_tokens, dim * 3]
        
        # √Åp d·ª•ng l·ªçc th√¥ng tin
        gated_info = flat_views * gates
        
        # Chi·∫øu v·ªÅ l·∫°i kh√¥ng gian ·∫©n node_dim
        fused = self.project(gated_info)
        return self.layer_norm(fused)

class BiaffineRelationalProjector(nn.Module):
    """S·ª≠ d·ª•ng Biaffine ƒë·ªÉ tr√≠ch xu·∫•t ƒë·∫∑c tr∆∞ng quan h·ªá Cha-Con thay cho MLP."""
    def __init__(self, head_dim, child_dim, output_dim, dropout_rate=0.2):
        super().__init__()
        # Ma tr·∫≠n Biaffine ƒë·ªÉ h·ªçc t∆∞∆°ng t√°c (Bilinear)
        # output_dim ·ªü ƒë√¢y ƒë√≥ng vai tr√≤ nh∆∞ c√°c 'feature maps' c·ªßa quan h·ªá
        self.weight = nn.Parameter(torch.Tensor(output_dim, head_dim + 1, child_dim + 1))
        
        self.norm = nn.LayerNorm(output_dim)
        self.mish = nn.Mish()
        self.dropout = nn.Dropout(dropout_rate)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.weight)

    def forward(self, head_embed, child_embed):
        # 1. Th√™m bias cho c·∫£ head v√† child
        # head: [batch, dim] -> [batch, dim+1]
        h = torch.cat([head_embed, head_embed.new_ones(head_embed.shape[:-1]).unsqueeze(-1)], -1)
        c = torch.cat([child_embed, child_embed.new_ones(child_embed.shape[:-1]).unsqueeze(-1)], -1)
        
        # 2. T√≠nh t∆∞∆°ng t√°c Biaffine: [batch, output_dim]
        # M·ªói chi·ªÅu c·ªßa output_dim s·∫Ω h·ªçc m·ªôt lo·∫°i 'kh·ªõp n·ªëi' kh√°c nhau gi·ªØa cha v√† con
        rel_feat = torch.einsum('bi,oij,bj->bo', h, self.weight, c)
        
        return self.dropout(self.mish(self.norm(rel_feat)))

class BiaffineDependencyModel(nn.Module):
    def __init__(self,
                 pos_vocab_size, dep_vocab_size, num_actions,
                 phobert_dim=768, pos_dim=64, dep_dim=64,
                 node_dim=256, hidden_dim=1024, dropout_rate=0.5):
        super().__init__()

        self.node_dim = node_dim

        self.pos_embed = nn.Embedding(pos_vocab_size, pos_dim, padding_idx=0)
        self.dep_embed = nn.Embedding(dep_vocab_size, dep_dim, padding_idx=0)

        # 1. Projector cho Main Word
        self.main_proj = SubViewProjector(phobert_dim + pos_dim, node_dim * 2, node_dim)
        
        # 2. Relational Projectors cho Children (D√πng th√¥ng tin cha ƒë·ªÉ hi·ªÉu con)
        # Input dim = node_dim (c·ªßa cha) + (phobert + dep + pos c·ªßa con)
        child_input_dim = phobert_dim + dep_dim + pos_dim
        self.lc_rel_proj = BiaffineRelationalProjector(node_dim, child_input_dim, node_dim)
        self.rc_rel_proj = BiaffineRelationalProjector(node_dim, child_input_dim, node_dim)

        self.fusion = GatedFusion(node_dim)
        self.buffer_proj = SubViewProjector(phobert_dim + pos_dim, node_dim * 2, node_dim)

        combined_dim = 4 * node_dim

        # 3. Classifier
        self.classifier = nn.Sequential(
            nn.Linear(combined_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Mish(), # D√πng Mish cho m∆∞·ª£t h∆°n ReLU
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.Mish(),
            nn.Linear(hidden_dim // 2, num_actions)
        )
        self.apply(self._init_weights)

    def forward(self, s_vecs, s_pos, s_lc_vecs, s_lc_dep, s_lc_pos, s_rc_vecs, s_rc_dep, s_rc_pos, b_vecs, b_pos):
        # --- B∆Ø·ªöC 1: Chi·∫øu n√∫t g·ªëc (Main Head) ---
        v_main_s = self.main_proj(torch.cat([s_vecs, self.pos_embed(s_pos)], dim=-1))

        # --- B∆Ø·ªöC 2: Chi·∫øu Relational cho n√∫t con ---
        # Ch√∫ng ta d√πng v_main_s nh∆∞ ng·ªØ c·∫£nh ƒë·ªÉ "nh√¨n" c√°c n√∫t con
        feat_lc = torch.cat([s_lc_vecs, self.dep_embed(s_lc_dep), self.pos_embed(s_lc_pos)], dim=-1)
        feat_rc = torch.cat([s_rc_vecs, self.dep_embed(s_rc_dep), self.pos_embed(s_rc_pos)], dim=-1)
        
        v_lc = self.lc_rel_proj(v_main_s.view(-1, self.node_dim), feat_lc.view(-1, feat_lc.size(-1)))
        v_rc = self.rc_rel_proj(v_main_s.view(-1, self.node_dim), feat_rc.view(-1, feat_rc.size(-1)))
        
        # ƒê∆∞a v·ªÅ shape c≈©: [batch, 3, node_dim]
        v_lc = v_lc.view(v_main_s.shape)
        v_rc = v_rc.view(v_main_s.shape)

        # --- B∆Ø·ªöC 3: G·ªôp th√¥ng tin ---
        stack_views = torch.stack([v_main_s, v_lc, v_rc], dim=2)
        s_node_emb = self.fusion(stack_views) + v_main_s 
        
        # Buffer & Feature Selection gi·ªØ nguy√™n logic c≈©
        b_node_emb = self.buffer_proj(torch.cat([b_vecs, self.pos_embed(b_pos)], dim=-1))

        s0 = s_node_emb[:, -1, :] # Top 1 stack
        s1 = s_node_emb[:, -2, :] # Top 2 stack

        b0 = b_node_emb[:, 0, :]
        b1 = b_node_emb[:, 1, :]
  
        combined = torch.cat([s0, s1, b0, b1], dim=-1)
        return self.classifier(combined)
        
    def _init_weights(self, module):
        """H√†m n√†y s·∫Ω ƒë∆∞·ª£c apply ƒë·ªá quy cho m·ªçi module con"""
        if isinstance(module, nn.Linear):
            # Kh·ªüi t·∫°o Linear: Xavier ph√π h·ª£p v·ªõi Mish/Tanh
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0.0)
                
        elif isinstance(module, nn.LayerNorm):
            # KH√îNG d√πng Xavier cho LayerNorm. Gi·ªØ nguy√™n 1.0 v√† 0.0
            nn.init.constant_(module.weight, 1.0)
            nn.init.constant_(module.bias, 0.0)
            
        elif isinstance(module, nn.Embedding):
            # Kh·ªüi t·∫°o Embedding m∆∞·ª£t h∆°n cho ti·∫øng Vi·ªát
            nn.init.normal_(module.weight, mean=0, std=0.02)


In [41]:
!pip install scikit-learn
from sklearn.feature_extraction.text import TfidfVectorizer

def compute_tfidf_weights(sentences):
    """
    sentences: List c√°c list ch·ª©a token dict: [[{'form': 'H·ªçc_sinh', ...}, ...], ...]
    Tr·∫£ v·ªÅ: Dictionary {word: weight}
    """
    # Chuy·ªÉn ƒë·ªïi format t·ª´ list of dicts sang list of strings (c√°ch nhau b·ªüi kho·∫£ng tr·∫Øng)
    corpus = [" ".join([token['form'] for token in sent]) for sent in sentences]
    
    vectorizer = TfidfVectorizer(token_pattern=r"(?u)\b\w+\b") # Gi·ªØ nguy√™n c√°c t·ª´ c√≥ d·∫•u/g·∫°ch d∆∞·ªõi
    tfidf_matrix = vectorizer.fit_transform(corpus)
    
    # L·∫•y gi√° tr·ªã TF-IDF trung b√¨nh c·ªßa t·ª´ng t·ª´ tr√™n to√†n b·ªô t·∫≠p d·ªØ li·ªáu l√†m tr·ªçng s·ªë
    # (Ho·∫∑c b·∫°n c√≥ th·ªÉ l·∫•y gi√° tr·ªã Max t√πy v√†o m·ª•c ƒë√≠ch)
    weights = tfidf_matrix.mean(axis=0).A1
    words = vectorizer.get_feature_names_out()
    
    return dict(zip(words, weights))

[0m

In [42]:
def pre_encode_sentences(sentences, model_name="vinai/phobert-base", device='cuda'):
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    phobert = AutoModel.from_pretrained(model_name).to(device)
    phobert.eval()

    encoded_bank = {}
    for idx, sent in enumerate(tqdm(sentences)):
        # CH·ªàNH S·ª¨A T·∫†I ƒê√ÇY: Chuy·ªÉn kho·∫£ng tr·∫Øng th√†nh g·∫°ch d∆∞·ªõi ƒë·ªÉ kh·ªõp v·ªõi vocab PhoBERT
        words = [token['form'].replace(' ', '_') for token in sent] 

        inputs = tokenizer(words,
                           is_split_into_words=True,
                           return_tensors="pt",
                           truncation=True,
                           max_length=256).to(device)

        with torch.no_grad():
            outputs = phobert(**inputs)
            full_embeds = outputs.last_hidden_state.squeeze(0)

        word_level_embeds = []
        ptr = 1 
        for word in words:
            # Tokenizer l√∫c n√†y s·∫Ω x·ª≠ l√Ω 'Tr·∫£_l·ªùi' nh∆∞ m·ªôt ƒë∆°n v·ªã
            sub_tokens_for_word = tokenizer.tokenize(word)
            num_subtokens = len(sub_tokens_for_word)

            word_vecs = full_embeds[ptr : ptr + num_subtokens]
            # N·∫øu 'Tr·∫£_l·ªùi' b·ªã t√°ch th√†nh nhi·ªÅu m·∫£nh (sub-words), 
            # Mean Pooling ·ªü ƒë√¢y s·∫Ω g·ªôp ch√∫ng l·∫°i th√†nh 1 vector duy nh·∫•t cho t·ª´ ƒë√≥
            word_mean_vec = word_vecs.mean(dim=0) if word_vecs.size(0) > 0 else torch.zeros(768, device=device)
            word_level_embeds.append(word_mean_vec)
            ptr += num_subtokens

        root_vector = full_embeds[0]
        final_matrix = torch.stack([root_vector] + word_level_embeds)
        encoded_bank[idx] = final_matrix.cpu()

    return encoded_bank

In [43]:
import torch
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm

def pre_encode_sentences_tfidf(sentences, model_name="vinai/phobert-base", device='cuda'):
    # 1. Kh·ªüi t·∫°o Model & Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    phobert = AutoModel.from_pretrained(model_name).to(device)
    phobert.eval()

    # 2. T√≠nh to√°n TF-IDF weights tr∆∞·ªõc
    print("--- ƒêang t√≠nh to√°n tr·ªçng s·ªë TF-IDF ---")
    tfidf_dict = compute_tfidf_weights(sentences)

    encoded_bank = {}
    print(f"--- ƒêang tr√≠ch xu·∫•t TF-IDF Weighted Embedding cho {len(sentences)} c√¢u ---")

    for idx, sent in enumerate(tqdm(sentences)):
        words = [token['form'] for token in sent]

        # Encode c√¢u
        inputs = tokenizer(words, is_split_into_words=True, return_tensors="pt",
                           truncation=True, max_length=256).to(device)

        with torch.no_grad():
            outputs = phobert(**inputs)
            full_embeds = outputs.last_hidden_state.squeeze(0)

        word_level_embeds = []
        ptr = 1  # B·ªè qua token <s>

        for word in words:
            sub_tokens = tokenizer.tokenize(word)
            num_subtokens = len(sub_tokens)
            
            # L·∫•y c√°c vector sub-tokens v√† t√≠nh Mean ƒë·ªÉ ra vector t·ª´ g·ªëc
            word_vecs = full_embeds[ptr : ptr + num_subtokens]
            
            if word_vecs.size(0) > 0:
                # Mean pooling cho c√°c sub-tokens
                mean_vec = word_vecs.mean(dim=0)
                
                # √Åp d·ª•ng tr·ªçng s·ªë TF-IDF (m·∫∑c ƒë·ªãnh l√† 1.0 n·∫øu t·ª´ kh√¥ng c√≥ trong TF-IDF)
                # L∆∞u √Ω: .lower() ƒë·ªÉ ƒë·ªìng b·ªô h√≥a key
                weight = tfidf_dict.get(word.lower(), 1.0)
                
                # ƒê·ªÉ tr√°nh tr·ªçng s·ªë qu√° nh·ªè (g·∫ßn b·∫±ng 0), b·∫°n c√≥ th·ªÉ d√πng c√¥ng th·ª©c: (1 + weight)
                weighted_vec = mean_vec * weight
                word_level_embeds.append(weighted_vec)
            else:
                word_level_embeds.append(torch.zeros(768, device=device))

            ptr += num_subtokens

        # Vector ROOT (<s>) - Th∆∞·ªùng gi·ªØ nguy√™n tr·ªçng s·ªë 1.0
        root_vector = full_embeds[0]

        final_matrix = torch.stack([root_vector] + word_level_embeds)
        encoded_bank[idx] = final_matrix.cpu()

    return encoded_bank

In [44]:
# T·∫°o b·∫£n ƒë·ªì metadata cho h√†nh ƒë·ªông ƒë·ªÉ tr√°nh parse chu·ªói trong v√≤ng l·∫∑p
def build_action_metadata(id_to_action, d_vocab):
    action_meta = {}
    for act_id, name in id_to_action.items():
        if name == "SHIFT":
            action_meta[act_id] = {'type': 'S', 'label': None, 'dep_id': 0}
        elif name.startswith("LEFT-ARC"):
            label = name.split("_")[-1]
            action_meta[act_id] = {'type': 'L', 'label': label, 'dep_id': d_vocab.get(label, 0)}
        elif name.startswith("RIGHT-ARC"):
            label = name.split("_")[-1]
            action_meta[act_id] = {'type': 'R', 'label': label, 'dep_id': d_vocab.get(label, 0)}
    return action_meta

# G·ªçi h√†m n√†y TR∆Ø·ªöC khi v√†o v√≤ng l·∫∑p ƒë√°nh gi√° (ngo√†i h√†m process_one_sentence)
# action_meta = build_action_metadata(id_to_action, oracle.d_vocab)

In [45]:
import copy
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import torch.optim as optim

def train(model, train_loader, train_encoded_bank, dev_sentences, dev_encoded_bank,
          oracle, epochs=20, lr=0.001, weight_decay=1e-4, device='cuda'):

    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    
    # --- TH√äM SCHEDULER ---
    # Gi·∫£m LR ƒëi 0.5 l·∫ßn n·∫øu UAS kh√¥ng tƒÉng sau 3 epoch
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=3, min_lr=1e-5,
    )

    criterion = nn.CrossEntropyLoss(label_smoothing=0.05) # Gi·∫£m smoothing xu·ªëng 0.05 nh∆∞ ƒë√£ b√†n
    model.to(device)

    best_uas = 0.0
    best_model_state = None
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0

        batch_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")

        for (s_words, s_pos, s_lc_idx, s_lc_dep, s_lc_pos, s_rc_idx, s_rc_dep, s_rc_pos,
             b_words, b_pos, labels, sent_ids) in batch_bar: 

            optimizer.zero_grad()

            # Helper ƒë·ªÉ fetch vector t·ª´ bank
            def fetch_vecs(indices, sids):
                return torch.stack([
                    train_encoded_bank[sid][indices[i]]
                    for i, sid in enumerate(sids)
                ]).to(device)

            s_vecs    = fetch_vecs(s_words, sent_ids)
            s_lc_vecs = fetch_vecs(s_lc_idx, sent_ids)
            s_rc_vecs = fetch_vecs(s_rc_idx, sent_ids)
            b_vecs    = fetch_vecs(b_words, sent_ids)

            # Move auxiliary data to device
            s_pos, s_lc_dep, s_lc_pos, s_rc_dep, s_rc_pos, b_pos, labels = [
                x.to(device) for x in [s_pos, s_lc_dep, s_lc_pos, s_rc_dep, s_rc_pos, b_pos, labels]
            ]

            # Forward
            logits = model(s_vecs, s_pos, s_lc_vecs, s_lc_dep, s_lc_pos, s_rc_vecs, s_rc_dep, s_rc_pos, b_vecs, b_pos)

            loss = criterion(logits, labels)
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_loss += loss.item()
            batch_bar.set_postfix(loss=f"{loss.item():.4f}", lr=f"{optimizer.param_groups[0]['lr']:.6f}")

        # --- 4. EVALUATION ---
        print(f"\nüîç ƒêang ƒë√°nh gi√° Epoch {epoch+1}...")
        # should_eval_full = (epoch + 1) % 3 == 0 or epoch > 15
        # eval_sentences = dev_sentences if should_eval_full else dev_sentences[:200] # Subset 200 c√¢u
        uas, las = evaluate_model_batch(
            model, oracle, dev_sentences, dev_encoded_bank, device=device
        )
        print(f"‚ú® K·∫øt qu·∫£: UAS: {uas*100:.2f}% | LAS: {las*100:.2f}%")

        # --- C·∫¨P NH·∫¨T SCHEDULER D·ª∞A TR√äN UAS ---
        scheduler.step(uas)

        # --- 5. SAVE MODEL IF IMPROVED ---
        if uas > best_uas:
            best_uas = uas
            best_model_state = copy.deepcopy(model.state_dict())
            uas_str = f"{uas:.4f}"
            save_path = f"best_model_bi_linear_scorer_{epoch+1}_{uas_str}.pth"
            torch.save(best_model_state, save_path)
            print(f"‚úÖ New best model saved: {save_path}")

    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f"ü•á Ho√†n t·∫•t. UAS cao nh·∫•t: {best_uas*100:.2f}%")

    return model

In [46]:
def process_one_sentence(sentence, sent_idx, model, encoded_bank, oracle, action_meta, device):
    model.eval()
    stack, buffer, arcs = [0], list(range(1, len(sentence) + 1)), []
    left_children, right_children = {}, {}

    # Chuy·ªÉn matrix sang device 1 l·∫ßn duy nh·∫•t
    sent_matrix = encoded_bank[sent_idx].to(device)

    while buffer or len(stack) > 1:
        # 1. Tr√≠ch xu·∫•t ƒë·∫∑c tr∆∞ng (v·∫´n d√πng h√†m extract_features ƒë√£ t·ªëi ∆∞u c·ªßa b·∫°n)
        feat = oracle.extract_features(stack, buffer, sentence, sent_idx, left_children, right_children)
        s_f, b_f = feat['stack'], feat['buffer']

        # 2. CHU·∫®N B·ªä TENSORS (T·ªëi ∆∞u: Ch·ªâ g·ªçi .to(device) 1 l·∫ßn cho m·ªói nh√≥m)
        # C√°c feature ID th∆∞·ªùng nh·ªè, gom nh√≥m ƒë·ªÉ chuy·ªÉn device hi·ªáu qu·∫£ h∆°n
        s_idx = torch.tensor(s_f['words'], dtype=torch.long, device=device)
        slc_idx = torch.tensor(s_f['l_child_idx'], dtype=torch.long, device=device)
        src_idx = torch.tensor(s_f['r_child_idx'], dtype=torch.long, device=device)
        
        # Nh·∫∑t vector PhoBERT
        s_vecs, s_lc_vecs, s_rc_vecs = sent_matrix[s_idx].unsqueeze(0), sent_matrix[slc_idx].unsqueeze(0), sent_matrix[src_idx].unsqueeze(0)

        # Labels ph·ª•
        s_pos = torch.tensor(s_f['pos'], dtype=torch.long, device=device).unsqueeze(0)
        s_lc_dep = torch.tensor(s_f['l_child_dep'], dtype=torch.long, device=device).unsqueeze(0)
        s_rc_dep = torch.tensor(s_f['r_child_dep'], dtype=torch.long, device=device).unsqueeze(0)
        s_lc_pos = torch.tensor(s_f['l_child_pos'], dtype=torch.long, device=device).unsqueeze(0)
        s_rc_pos = torch.tensor(s_f['r_child_pos'], dtype=torch.long, device=device).unsqueeze(0)

        b_idx = torch.tensor(b_f['words'], dtype=torch.long, device=device)
        b_vecs = sent_matrix[b_idx].unsqueeze(0)
        b_pos = torch.tensor(b_f['pos'], dtype=torch.long, device=device).unsqueeze(0)

        # 3. D·ª∞ ƒêO√ÅN
        with torch.no_grad():
            logits = model(s_vecs, s_pos, s_lc_vecs, s_lc_dep, s_lc_pos, s_rc_vecs, s_rc_dep, s_rc_pos, b_vecs, b_pos)
            # L·∫•y list ID h√†nh ƒë·ªông ƒë√£ s·∫Øp x·∫øp
            sorted_act_ids = torch.argsort(logits[0], descending=True).tolist()

        # 4. √ÅP D·ª§NG H√ÄNH ƒê·ªòNG (D√πng action_meta ƒë·ªÉ b·ªè qua parse chu·ªói)
        applied = False
        for act_id in sorted_act_ids:
            meta = action_meta[act_id]
            act_type = meta['type']
            
            if act_type == 'S' and buffer: # SHIFT
                stack.append(buffer.pop(0))
                applied = True
                break
                
            elif act_type == 'L' and len(stack) >= 2 and stack[-2] != 0: # LEFT-ARC
                head, child = stack[-1], stack[-2]
                arcs.append((head, child, meta['label']))
                
                # L∆∞u v√†o children map d√πng ID c√≥ s·∫µn trong meta v√† sentence
                if head not in left_children: left_children[head] = []
                left_children[head].append((child, meta['dep_id'], sentence[child-1]['upos_id']))
                
                stack.pop(-2)
                applied = True
                break
                
            elif act_type == 'R' and len(stack) >= 2: # RIGHT-ARC
                head, child = stack[-2], stack[-1]
                arcs.append((head, child, meta['label']))
                
                # L∆∞u v√†o children map
                if head not in right_children: right_children[head] = []
                right_children[head].append((child, meta['dep_id'], sentence[child-1]['upos_id']))
                
                stack.pop(-1)
                applied = True
                break

        if not applied: break

    # 5. T√çNH UAS/LAS (Gi·ªØ nguy√™n logic ch√≠nh nh∆∞ng vi·∫øt g·ªçn)
    gold_heads = {t['id']: t['head'] for t in sentence}
    gold_deps = {t['id']: t['deprel'] for t in sentence}
    pred_info = {child: (head, label) for head, child, label in arcs}

    correct_uas = sum(1 for c, g_h in gold_heads.items() if c in pred_info and pred_info[c][0] == g_h)
    correct_las = sum(1 for c, g_h in gold_heads.items() if c in pred_info and pred_info[c][0] == g_h and pred_info[c][1] == gold_deps[c])

    return correct_uas, correct_las, len(gold_heads)

In [47]:
import torch
import torch.nn.functional as F

def process_one_sentence_beam(sentence, sent_idx, model, encoded_bank, oracle, id_to_action, device, beam_size=3):
    model.eval()
    
    # Ma tr·∫≠n PhoBERT cho c√¢u (Lookup table)
    sent_matrix = encoded_bank[sent_idx].to(device)
    
    # M·ªôt tr·∫°ng th√°i bao g·ªìm: (stack, buffer, arcs, left_children, right_children, log_score, is_finished)
    # Kh·ªüi t·∫°o: stack=[0], buffer=1->N, arcs=[], lc={}, rc={}, score=0.0, finished=False
    initial_state = ([0], list(range(1, len(sentence) + 1)), [], {}, {}, 0.0, False)
    beam = [initial_state]

    # S·ªë b∆∞·ªõc t·ªëi ƒëa trong transition-based parsing (Arc-Standard/Arc-Eager) th∆∞·ªùng l√† 2*n
    max_steps = 2 * len(sentence)

    for _ in range(max_steps):
        candidates = []
        active_indices = []
        
        # --- 1. BATCHING: Collect features for all active candidates ---
        batch_inputs = {
            's_idx': [], 'slc_idx': [], 'src_idx': [],
            's_pos': [], 's_lc_dep': [], 's_lc_pos': [], 's_rc_dep': [], 's_rc_pos': [],
            'b_idx': [], 'b_pos': []
        }

        for i, state in enumerate(beam):
            stack, buffer, arcs, lc, rc, score, finished = state
            # N·∫øu tr·∫°ng th√°i ƒë√£ k·∫øt th√∫c ho·∫∑c kh√¥ng th·ªÉ th·ª±c hi·ªán th√™m h√†nh ƒë·ªông
            if finished or (not buffer and len(stack) <= 1):
                candidates.append(state)
                continue
            
            active_indices.append(i)
            
            # --- 1. TR√çCH XU·∫§T ƒê·∫∂C TR∆ØNG T·ª™ TR·∫†NG TH√ÅI HI·ªÜN T·∫†I ---
            feat = oracle.extract_features(stack, buffer, sentence, sent_idx, lc, rc)
            
            # Collect features for batching
            s_f, b_f = feat['stack'], feat['buffer']
            batch_inputs['s_idx'].append(s_f['words'])
            batch_inputs['slc_idx'].append(s_f['l_child_idx'])
            batch_inputs['src_idx'].append(s_f['r_child_idx'])
            batch_inputs['s_pos'].append(s_f['pos'])
            batch_inputs['s_lc_dep'].append(s_f['l_child_dep'])
            batch_inputs['s_rc_dep'].append(s_f['r_child_dep'])
            batch_inputs['s_lc_pos'].append(s_f['l_child_pos'])
            batch_inputs['s_rc_pos'].append(s_f['r_child_pos'])
            batch_inputs['b_idx'].append(b_f['words'])
            batch_inputs['b_pos'].append(b_f['pos'])

        if not active_indices:
            break

        # --- 2. PREPARE BATCH TENSORS ---
        def to_tensor(data):
            return torch.tensor(data, dtype=torch.long, device=device)

        s_idx_t = to_tensor(batch_inputs['s_idx'])
        slc_idx_t = to_tensor(batch_inputs['slc_idx'])
        src_idx_t = to_tensor(batch_inputs['src_idx'])
        s_pos_t = to_tensor(batch_inputs['s_pos'])
        s_lc_dep_t = to_tensor(batch_inputs['s_lc_dep'])
        s_rc_dep_t = to_tensor(batch_inputs['s_rc_dep'])
        s_lc_pos_t = to_tensor(batch_inputs['s_lc_pos'])
        s_rc_pos_t = to_tensor(batch_inputs['s_rc_pos'])
        b_idx_t = to_tensor(batch_inputs['b_idx'])
        b_pos_t = to_tensor(batch_inputs['b_pos'])

        # Lookup embeddings (Batch lookup)
        s_vecs = sent_matrix[s_idx_t]
        s_lc_vecs = sent_matrix[slc_idx_t]
        s_rc_vecs = sent_matrix[src_idx_t]
        b_vecs = sent_matrix[b_idx_t]

        # --- 3. RUN MODEL ON BATCH ---
        with torch.no_grad():
            logits = model(
                s_vecs, s_pos_t, s_lc_vecs, s_lc_dep_t, s_lc_pos_t,
                s_rc_vecs, s_rc_dep_t, s_rc_pos_t, b_vecs, b_pos_t
            )
            log_probs = F.log_softmax(logits, dim=1)

        # --- 4. GENERATE CANDIDATES ---
        for idx_in_batch, beam_idx in enumerate(active_indices):
            stack, buffer, arcs, lc, rc, score, finished = beam[beam_idx]
            probs = log_probs[idx_in_batch]
            
            # L·∫•y top K h√†nh ƒë·ªông ƒë·ªÉ gi·∫£m chi ph√≠ t√≠nh to√°n
            top_k = torch.topk(probs, k=min(beam_size * 2, len(probs)))
            
            for i in range(top_k.values.size(0)):
                act_id = top_k.indices[i].item()
                log_p = top_k.values[i].item()
                
                action_name = id_to_action[act_id]
                new_score = score + log_p
                
                # 1. H√†nh ƒë·ªông SHIFT
                if action_name == "SHIFT" and buffer:
                    new_stack = stack + [buffer[0]]
                    new_buffer = buffer[1:]
                    candidates.append((new_stack, new_buffer, list(arcs), lc, rc, new_score, False))
                
                # 2. H√†nh ƒë·ªông LEFT-ARC
                elif action_name.startswith("LEFT-ARC") and len(stack) >= 2:
                    if stack[-2] != 0: # Kh√¥ng ƒë·ªÉ ROOT l√†m con
                        label = action_name.split("_")[-1]
                        dep_id = oracle.d_vocab.get(label, 0)
                        
                        # --- M·ªöI: Tr√≠ch xu·∫•t POS ID c·ªßa child (stack[-2]) ---
                        child_idx = stack[-2]
                        child_pos_str = sentence[child_idx - 1]['upos']
                        child_pos_id = oracle.p_vocab.get(child_pos_str, 0)
                        
                        new_arcs = arcs + [(stack[-1], stack[-2], label)]
                        new_stack = stack[:]
                        new_stack.pop(-2)
                        
                        # C·∫≠p nh·∫≠t map con tr√°i: th√™m b·ªô 3 (index, dep_id, pos_id)
                        new_lc = lc.copy()
                        head = stack[-1]
                        new_lc[head] = new_lc.get(head, []) + [(child_idx, dep_id, child_pos_id)]
                        
                        candidates.append((new_stack, list(buffer), new_arcs, new_lc, rc, new_score, False))
                
                # 3. H√†nh ƒë·ªông RIGHT-ARC
                elif action_name.startswith("RIGHT-ARC") and len(stack) >= 2:
                    label = action_name.split("_")[-1]
                    dep_id = oracle.d_vocab.get(label, 0)
                    
                    # --- M·ªöI: Tr√≠ch xu·∫•t POS ID c·ªßa child (stack[-1]) ---
                    child_idx = stack[-1]
                    child_pos_str = sentence[child_idx - 1]['upos']
                    child_pos_id = oracle.p_vocab.get(child_pos_str, 0)
                    
                    new_arcs = arcs + [(stack[-2], stack[-1], label)]
                    new_stack = stack[:]
                    new_stack.pop(-1)

                    # C·∫≠p nh·∫≠t map con ph·∫£i: th√™m b·ªô 3 (index, dep_id, pos_id)
                    new_rc = rc.copy()
                    head = stack[-2]
                    new_rc[head] = new_rc.get(head, []) + [(child_idx, dep_id, child_pos_id)]
                    
                    candidates.append((new_stack, list(buffer), new_arcs, lc, new_rc, new_score, False))

        # --- 5. PRUNING ---
        beam = sorted(candidates, key=lambda x: x[5], reverse=True)[:beam_size]

    # --- 5. L·∫§Y K·∫æT QU·∫¢ T·ªêT NH·∫§T T·ª™ BEAM ---
    # State: (stack, buffer, arcs, lc, rc, score, finished)
    best_arcs = beam[0][2]

    # --- 6. T√çNH UAS/LAS ---
    gold_heads = {t['id']: t['head'] for t in sentence}
    gold_deps = {t['id']: t['deprel'] for t in sentence}
    pred_info = {child: (head, label) for head, child, label in best_arcs}

    correct_uas, correct_las = 0, 0
    for child, g_head in gold_heads.items():
        if child in pred_info:
            p_head, p_label = pred_info[child]
            if p_head == g_head:
                correct_uas += 1
                if p_label == gold_deps[child]:
                    correct_las += 1

    return correct_uas, correct_las, len(gold_heads)

In [48]:
def evaluate_model_parallel(model, oracle, sentences, encoded_bank, device='cpu', is_training=True):
    model.eval()
    model.to(device) # Chuy·ªÉn model v·ªÅ device (cuda ho·∫∑c cpu)

    id_to_action = {v: k for k, v in oracle.action_map.items()}
    total_uas, total_las, total_deps = 0, 0, 0

    action_meta = build_action_metadata(id_to_action, oracle.d_vocab)

    # S·ª≠ d·ª•ng tqdm ƒë·ªÉ th·∫•y ti·∫øn tr√¨nh
    with torch.no_grad(): # T·∫Øt gradient ƒë·ªÉ ti·∫øt ki·ªám b·ªô nh·ªõ
        for idx, sent in enumerate(tqdm(sentences, desc="Evaluating")):
            # G·ªçi process_one_sentence v·ªõi logic tr√≠ch xu·∫•t feature m·ªõi
            if not is_training:
                c_uas, c_las, c_deps = process_one_sentence_beam(
                    sent, idx, model, encoded_bank, oracle, id_to_action, device
                )
            else:
                c_uas, c_las, c_deps = process_one_sentence(
                    sent, idx, model, encoded_bank, oracle, action_meta, device
                )
            total_uas += c_uas
            total_las += c_las
            total_deps += c_deps

    uas = total_uas / total_deps if total_deps > 0 else 0
    las = total_las / total_deps if total_deps > 0 else 0

    return uas, las

In [49]:
def process_batch_sentences(batch_sentences, batch_indices, model, encoded_bank, oracle, action_meta, device):
    model.eval()
    n_sentences = len(batch_sentences)
    
    states = []
    for sent in batch_sentences:
        states.append({
            'stack': [0],
            'buffer': list(range(1, len(sent) + 1)),
            'arcs': [],
            'lc': {},
            'rc': {},
            'finished': False
        })

    max_steps = max(len(s) for s in batch_sentences) * 2
    
    for _ in range(max_steps):
        active_indices = [i for i, s in enumerate(states) if not s['finished']]
        if not active_indices:
            break
            
        # 1. Thu th·∫≠p Tensors ƒë√£ tr√≠ch xu·∫•t (c√≥ k√≠ch th∆∞·ªõc c·ªë ƒë·ªãnh theo window size)
        list_s_vecs, list_slc_vecs, list_src_vecs, list_b_vecs = [], [], [], []
        list_s_pos, list_slc_dep, list_slc_pos, list_src_dep, list_src_pos = [], [], [], [], []
        list_b_pos = []
        
        for i in active_indices:
            st = states[i]
            sent_idx = batch_indices[i]
            # Tr√≠ch xu·∫•t ƒë·∫∑c tr∆∞ng th√¥ (indices)
            f = oracle.extract_features(st['stack'], st['buffer'], batch_sentences[i], 
                                        sent_idx, st['lc'], st['rc'])
            
            # L·∫•y ma tr·∫≠n embedding c·ªßa c√¢u t·ª´ bank
            sent_mat = encoded_bank[sent_idx].to(device)
            
            # --- TR√çCH XU·∫§T VECTOR C·ªê ƒê·ªäNH T·∫†I ƒê√ÇY ---
            # f['stack']['words'] lu√¥n c√≥ k√≠ch th∆∞·ªõc c·ªë ƒë·ªãnh (m·∫∑c ƒë·ªãnh l√† 3)
            s_idx_t = torch.tensor(f['stack']['words'], dtype=torch.long, device=device)
            slc_idx_t = torch.tensor(f['stack']['l_child_idx'], dtype=torch.long, device=device)
            src_idx_t = torch.tensor(f['stack']['r_child_idx'], dtype=torch.long, device=device)
            b_idx_t = torch.tensor(f['buffer']['words'], dtype=torch.long, device=device)

            list_s_vecs.append(sent_mat[s_idx_t])
            list_slc_vecs.append(sent_mat[slc_idx_t])
            list_src_vecs.append(sent_mat[src_idx_t])
            list_b_vecs.append(sent_mat[b_idx_t])

            # Thu th·∫≠p c√°c nh√£n POS/Dep
            def to_long_t(data): return torch.tensor(data, dtype=torch.long, device=device)
            list_s_pos.append(to_long_t(f['stack']['pos']))
            list_slc_dep.append(to_long_t(f['stack']['l_child_dep']))
            list_slc_pos.append(to_long_t(f['stack']['l_child_pos']))
            list_src_dep.append(to_long_t(f['stack']['r_child_dep']))
            list_src_pos.append(to_long_t(f['stack']['r_child_pos']))
            list_b_pos.append(to_long_t(f['buffer']['pos']))

        # 2. B√¢y gi·ªù c√≥ th·ªÉ stack v√¨ m·ªçi tensor trong list ƒë·ªÅu c√πng k√≠ch th∆∞·ªõc window
        s_vecs = torch.stack(list_s_vecs)
        s_lc_vecs = torch.stack(list_slc_vecs)
        s_rc_vecs = torch.stack(list_src_vecs)
        b_vecs = torch.stack(list_b_vecs)
        
        s_pos = torch.stack(list_s_pos)
        s_lc_dep = torch.stack(list_slc_dep)
        s_lc_pos = torch.stack(list_slc_pos)
        s_src_dep = torch.stack(list_src_dep)
        s_src_pos = torch.stack(list_src_pos)
        b_pos = torch.stack(list_b_pos)

        # 3. Batch Forward
        with torch.no_grad():
            logits = model(s_vecs, s_pos, s_lc_vecs, s_lc_dep, s_lc_pos, 
                           s_rc_vecs, s_src_dep, s_src_pos, b_vecs, b_pos)
            top_actions = torch.argsort(logits, dim=1, descending=True)

        # 4. C·∫≠p nh·∫≠t tr·∫°ng th√°i t·ª´ng c√¢u (Gi·ªØ nguy√™n logic transition)
        for idx_in_active, i in enumerate(active_indices):
            st = states[i]
            sentence = batch_sentences[i]
            applied = False
            
            for act_id in top_actions[idx_in_active].tolist():
                meta = action_meta[act_id]
                # Logic x·ª≠ l√Ω SHIFT, LEFT-ARC, RIGHT-ARC t∆∞∆°ng t·ª± code c≈© c·ªßa b·∫°n
                # ... (gi·ªØ nguy√™n ph·∫ßn √°p d·ª•ng h√†nh ƒë·ªông) ...
                if meta['type'] == 'S' and st['buffer']:
                    st['stack'].append(st['buffer'].pop(0)); applied = True; break
                elif meta['type'] == 'L' and len(st['stack']) >= 2 and st['stack'][-2] != 0:
                    h, c = st['stack'][-1], st['stack'][-2]
                    st['arcs'].append((h, c, meta['label']))
                    if h not in st['lc']: st['lc'][h] = []
                    st['lc'][h].append((c, meta['dep_id'], sentence[c-1]['upos_id']))
                    st['stack'].pop(-2); applied = True; break
                elif meta['type'] == 'R' and len(st['stack']) >= 2:
                    h, c = st['stack'][-2], st['stack'][-1]
                    st['arcs'].append((h, c, meta['label']))
                    if h not in st['rc']: st['rc'][h] = []
                    st['rc'][h].append((c, meta['dep_id'], sentence[c-1]['upos_id']))
                    st['stack'].pop(-1); applied = True; break
            
            if not applied or (not st['buffer'] and len(st['stack']) <= 1):
                st['finished'] = True

    # 5. T√≠nh to√°n k·∫øt qu·∫£ UAS/LAS cu·ªëi batch
    b_uas, b_las, b_deps = 0, 0, 0
    for i in range(n_sentences):
        gold_heads = {t['id']: t['head'] for t in batch_sentences[i]}
        gold_deps = {t['id']: t['deprel'] for t in batch_sentences[i]}
        pred_info = {c: (h, l) for h, c, l in states[i]['arcs']}
        b_uas += sum(1 for c, gh in gold_heads.items() if c in pred_info and pred_info[c][0] == gh)
        b_las += sum(1 for c, gh in gold_heads.items() if c in pred_info and pred_info[c][0] == gh and pred_info[c][1] == gold_deps[c])
        b_deps += len(gold_heads)
        
    return b_uas, b_las, b_deps

In [50]:
def evaluate_model_batch(model, oracle, sentences, encoded_bank, batch_size=64, device='cuda'):
    model.eval()
    model.to(device)
    
    id_to_action = {v: k for k, v in oracle.action_map.items()}
    action_meta = build_action_metadata(id_to_action, oracle.d_vocab)
    
    total_uas, total_las, total_deps = 0, 0, 0
    
    # Chia sentences th√†nh c√°c batch
    for i in tqdm(range(0, len(sentences), batch_size), desc="Batch Evaluating"):
        batch_sents = sentences[i : i + batch_size]
        batch_idxs = list(range(i, i + len(batch_sents)))
        
        c_uas, c_las, c_deps = process_batch_sentences(
            batch_sents, batch_idxs, model, encoded_bank, oracle, action_meta, device
        )
        
        total_uas += c_uas
        total_las += c_las
        total_deps += c_deps

    uas = total_uas / total_deps if total_deps > 0 else 0
    las = total_las / total_deps if total_deps > 0 else 0
    return uas, las

In [51]:
# 1. Khai b√°o ƒë∆∞·ªùng d·∫´n v√† thi·∫øt b·ªã
train_path = "UD_Vietnamese-VTB/vi_vtb-ud-train.conllu"
dev_path = "UD_Vietnamese-VTB/vi_vtb-ud-dev.conllu"
test_path = "UD_Vietnamese-VTB/vi_vtb-ud-test.conllu"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("ƒêang ƒë·ªçc d·ªØ li·ªáu v√† x√¢y d·ª±ng Vocab...")
# L∆∞u √Ω: Th√™m word_vocab ƒë·ªÉ qu·∫£n l√Ω index cho c√°c node con n·∫øu c·∫ßn
# Trong cell th·ª±c hi·ªán load d·ªØ li·ªáu (t∆∞∆°ng t·ª± Cell 16 trong notebook)
train_sentences, word_vocab, p_vocab, d_vocab = load_data_and_build_vocab([train_path], is_train=True)

# G·ªåI H√ÄM N√ÄY NGAY T·∫†I ƒê√ÇY
train_sentences = map_sentence_to_ids(train_sentences, p_vocab, d_vocab)
dev_sentences = map_sentence_to_ids(load_data_and_build_vocab([dev_path], is_train=False), p_vocab, d_vocab)
test_sentences = map_sentence_to_ids(load_data_and_build_vocab([test_path], is_train=False), p_vocab, d_vocab)

vocabs = (p_vocab, d_vocab)

# 3. Kh·ªüi t·∫°o Oracle
# Oracle c·∫ßn ch·ª©a h√†m create_training_data_multiview s·ª≠ d·ª•ng h√†m extract_features m·ªõi c·ªßa ch√∫ng ta
oracle = Oracle(p_vocab, d_vocab)

ƒêang ƒë·ªçc d·ªØ li·ªáu v√† x√¢y d·ª±ng Vocab...


In [52]:
# 4. T·∫°o d·ªØ li·ªáu hu·∫•n luy·ªán
print(f"ƒêang t·∫°o chu·ªói transitions cho {len(train_sentences)} c√¢u...")
training_samples = []
for idx, sentence in enumerate(tqdm(train_sentences)):
    # H√†m n√†y s·∫Ω gi·∫£ l·∫≠p c√°c b∆∞·ªõc ƒëi (Shift, Arc) v√† g·ªçi extract_features t·∫°i m·ªói b∆∞·ªõc
    examples = oracle.create_training_data_multiview(sentence, idx)
    training_samples.extend(examples)

# 6. X√¢y d·ª±ng ng√¢n h√†ng Embedding PhoBERT (M·ªü comment)
# Bank n√†y l∆∞u ma tr·∫≠n [sent_len + 1, 768] cho t·ª´ng c√¢u
print("--- ƒêang tr√≠ch xu·∫•t PhoBERT Embedding Bank ---")
train_encoded_bank = pre_encode_sentences(train_sentences, device=device)
dev_encoded_bank = pre_encode_sentences(dev_sentences, device=device)
test_encoded_bank = pre_encode_sentences(test_sentences, device=device)

# 7. C·∫•u h√¨nh DataLoader (M·ªü comment v√† c·∫≠p nh·∫≠t)
train_dataset = TransitionDataset(training_samples)
train_loader = DataLoader(
    train_dataset,
    batch_size=256, # Gi·∫£m nh·∫π batch size n·∫øu VRAM b·ªã tr√†n do feature qu√° nhi·ªÅu
    shuffle=True,
    collate_fn=multiview_collate_fn # Ph·∫£i l√† h√†m tr·∫£ v·ªÅ 11 th√†nh ph·∫ßn
)

ƒêang t·∫°o chu·ªói transitions cho 1400 c√¢u...


  0%|          | 0/1400 [00:00<?, ?it/s]

--- ƒêang tr√≠ch xu·∫•t PhoBERT Embedding Bank ---


  0%|          | 0/1400 [00:00<?, ?it/s]

  0%|          | 0/1123 [00:00<?, ?it/s]

  0%|          | 0/800 [00:00<?, ?it/s]

In [53]:
print(len(training_samples))

48561


In [None]:
# 8. Kh·ªüi t·∫°o M√¥ h√¨nh (B·ªî SUNG)
num_actions = len(oracle.action_map)
model = BiaffineDependencyModel(
    pos_vocab_size=len(p_vocab),
    dep_vocab_size=len(d_vocab),
    num_actions=num_actions,
    phobert_dim=768,
    pos_dim=64,
    dep_dim=64,
    node_dim=256,
    hidden_dim=1024,
    dropout_rate=0.3
).to(device)

print(f"‚úÖ Model ƒë√£ s·∫µn s√†ng v·ªõi {num_actions} h√†nh ƒë·ªông.")

# 9. B·∫Øt ƒë·∫ßu hu·∫•n luy·ªán
model = train(
    model=model,
    train_loader=train_loader,
    train_encoded_bank=train_encoded_bank,
    dev_sentences=dev_sentences,
    dev_encoded_bank=dev_encoded_bank,
    oracle=oracle,
    epochs=50,
    lr=3e-4,
    device=device
)

# 3. ƒê√°nh gi√° cu·ªëi c√πng tr√™n t·∫≠p TEST
print("\n--- ƒê√°nh gi√° cu·ªëi c√πng tr√™n t·∫≠p TEST ---")
# T·∫£i l·∫°i tr·ªçng s·ªë t·ªët nh·∫•t tr∆∞·ªõc khi test
uas, las = evaluate_model_parallel(
    model, oracle, test_sentences, test_encoded_bank, device=device, is_training=False
)
print(f"Final Result: UAS: {uas*100:.2f}% | LAS: {las*100:.2f}%")


‚úÖ Model ƒë√£ s·∫µn s√†ng v·ªõi 159 h√†nh ƒë·ªông.


Epoch 1/50:   0%|          | 0/190 [00:00<?, ?it/s]


üîç ƒêang ƒë√°nh gi√° Epoch 1...


Batch Evaluating:   0%|          | 0/18 [00:00<?, ?it/s]

‚ú® K·∫øt qu·∫£: UAS: 68.35% | LAS: 51.96%
‚úÖ New best model saved: best_model_bi_linear_scorer_1_0.6835.pth


Epoch 2/50:   0%|          | 0/190 [00:00<?, ?it/s]


üîç ƒêang ƒë√°nh gi√° Epoch 2...


Batch Evaluating:   0%|          | 0/18 [00:00<?, ?it/s]

‚ú® K·∫øt qu·∫£: UAS: 73.02% | LAS: 58.55%
‚úÖ New best model saved: best_model_bi_linear_scorer_2_0.7302.pth


Epoch 3/50:   0%|          | 0/190 [00:00<?, ?it/s]


üîç ƒêang ƒë√°nh gi√° Epoch 3...


Batch Evaluating:   0%|          | 0/18 [00:00<?, ?it/s]

‚ú® K·∫øt qu·∫£: UAS: 74.45% | LAS: 60.15%
‚úÖ New best model saved: best_model_bi_linear_scorer_3_0.7445.pth


Epoch 4/50:   0%|          | 0/190 [00:00<?, ?it/s]


üîç ƒêang ƒë√°nh gi√° Epoch 4...


Batch Evaluating:   0%|          | 0/18 [00:00<?, ?it/s]

‚ú® K·∫øt qu·∫£: UAS: 75.37% | LAS: 61.18%
‚úÖ New best model saved: best_model_bi_linear_scorer_4_0.7537.pth


Epoch 5/50:   0%|          | 0/190 [00:00<?, ?it/s]


üîç ƒêang ƒë√°nh gi√° Epoch 5...


Batch Evaluating:   0%|          | 0/18 [00:00<?, ?it/s]

‚ú® K·∫øt qu·∫£: UAS: 74.33% | LAS: 60.21%


Epoch 6/50:   0%|          | 0/190 [00:00<?, ?it/s]


üîç ƒêang ƒë√°nh gi√° Epoch 6...


Batch Evaluating:   0%|          | 0/18 [00:00<?, ?it/s]

In [33]:
model_save = BiaffineDependencyModel(
    pos_vocab_size=len(p_vocab),
    dep_vocab_size=len(d_vocab),
    num_actions=num_actions,
    phobert_dim=768,
    pos_dim=64,
    dep_dim=64,
    node_dim=256,
    hidden_dim=1024,
    dropout_rate=0.3
).to(device)

# 3. N·∫°p "tr·ªçng s·ªë" (weights) v√†o khung ƒë√£ t·∫°o
weights_path = 'best_model_biaffine/best_model_bi_linear_scorer_38_0.7754.pth'
state_dict = torch.load(weights_path, map_location=device)

# N·∫øu file .pth c·ªßa b·∫°n ch·ª©a tr·ª±c ti·∫øp state_dict:
model_save.load_state_dict(state_dict)

# 4. Chuy·ªÉn sang ch·∫ø ƒë·ªô ƒë√°nh gi√°
model_save.eval()

print("‚úÖ ƒê√£ load weights th√†nh c√¥ng!")

‚úÖ ƒê√£ load weights th√†nh c√¥ng!


In [34]:
uas, las = evaluate_model_parallel(
    model_save, oracle, test_sentences, test_encoded_bank, device=device, is_training=False
)
print(f"Final Result: UAS: {uas*100:.2f}% | LAS: {las*100:.2f}%")

Evaluating:   0%|          | 0/800 [00:00<?, ?it/s]

Final Result: UAS: 80.68% | LAS: 68.92%
