In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
from transformers import AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig, AutoTokenizer # BitsAndBytesConfig Ï∂îÍ∞Ä
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training,PeftModel ,TaskType # peft Í¥ÄÎ†® Î™®Îìà Ï∂îÍ∞Ä
import pandas as pd
import shutil
import os
import numpy as np
from torch.nn import CrossEntropyLoss
from torch.utils.data import random_split
import random
import itertools
import gc
import nltk
from rouge_score import rouge_scorer
import jiwer
from sklearn.model_selection import train_test_split
import time
from tqdm import tqdm
from pydantic import BaseModel, Field
from typing import Optional, Dict, Any, List
import math
from pathlib import Path
import csv

In [3]:
import os
# Restrict this process to physical GPU 1. torch was already imported by the
# import cell above, so this only takes effect because no CUDA call has been
# made yet; keep this cell ahead of any code that touches the GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import torch
print("Available devices:", torch.cuda.device_count())  # expected to report exactly one device
print("Device name:", torch.cuda.get_device_name(0))    # index 0 now refers to physical GPU 1

Available devices: 1
Device name: NVIDIA H100 80GB HBM3


In [4]:
SEED = 42

def set_seeds(seed_value):
    """Seed Python's `random`, NumPy and PyTorch (CPU and CUDA) with one value.

    Args:
        seed_value: Integer seed handed to every random number generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed_value)

    if not torch.cuda.is_available():
        return

    # Seed the current CUDA device first, then every visible device.
    torch.cuda.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)
    # Deterministic cuDNN kernels are intentionally NOT enabled here (they can
    # cost some speed). To turn them on, set:
    #   torch.backends.cudnn.deterministic = True
    #   torch.backends.cudnn.benchmark = False

# Seed every RNG once for the whole notebook and print a confirmation message.
set_seeds(SEED)
print(f"Î™®Îì† ÎùºÏù¥Î∏åÎü¨Î¶¨Ïùò ÏãúÎìúÍ∞Ä {SEED}Î°ú Í≥†Ï†ïÎêòÏóàÏäµÎãàÎã§.")

Î™®Îì† ÎùºÏù¥Î∏åÎü¨Î¶¨Ïùò ÏãúÎìúÍ∞Ä 42Î°ú Í≥†Ï†ïÎêòÏóàÏäµÎãàÎã§.


In [5]:
class EEGDataset(Dataset):
    """In-memory dataset of (normalised EEG feature vector, text) pairs.

    The parquet file is read once in the constructor. Its "eeg" column is
    stacked into a float32 matrix, non-finite values are replaced by 0, and
    every feature column is z-scored with statistics computed over the whole
    file. The "text" column is kept unchanged.

    Attributes:
        eeg_arr:  Normalised float32 array, one row per sample.
        text_arr: NumPy array holding the "text" column.
        data:     List of (row tensor, text) tuples served by __getitem__.
    """

    def __init__(self,
                 data_dir = "/home/work/skku/hyo/hyo/dataset/sentence.parquet"):
        frame = pd.read_parquet(data_dir)

        features = np.stack(frame["eeg"].to_numpy()).astype(np.float32)
        features = np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0)

        # Per-feature statistics; the epsilon guards against zero-variance columns.
        feature_mean = features.mean(0, keepdims=True)
        feature_std = features.std(0, keepdims=True) + 1e-8

        self.eeg_arr = (features - feature_mean) / feature_std
        self.text_arr = frame["text"].to_numpy()

        eeg_rows = torch.tensor(self.eeg_arr)
        self.data = list(zip(eeg_rows, self.text_arr))

    def __len__(self):
        """Return the number of (EEG, text) pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the (EEG tensor, text) tuple stored at position `idx`."""
        return self.data[idx]

class ConvEEGEncoder(nn.Module):
    """Encode a flat EEG feature vector into a latent vector with two Conv1d layers.

    The feature vector is treated as a single-channel sequence, passed through
    two kernel-size-3 convolutions (each followed by ReLU) and then averaged
    over the sequence axis.

    Args:
        input_dim:  Accepted for interface compatibility; not used, because the
                    adaptive pooling makes the module independent of the length.
        latent_dim: Number of output channels, i.e. the size of the latent vector.
        hidden:     Number of channels of the first convolution.
    """

    def __init__(self, input_dim=840, latent_dim=128, hidden=256):
        super().__init__()
        first_conv = nn.Conv1d(1, hidden, kernel_size=3, padding=1)
        first_act = nn.ReLU()
        second_conv = nn.Conv1d(hidden, latent_dim, kernel_size=3, padding=1)
        second_act = nn.ReLU()
        # Layer positions (0..3) are part of the checkpoint key names.
        self.conv_stack = nn.Sequential(first_conv, first_act, second_conv, second_act)
        # Collapses the sequence axis to length 1.
        self.pool = nn.AdaptiveAvgPool1d(1)

    def forward(self, x):
        """Map x of shape (B, feat) to a latent tensor of shape (B, latent_dim)."""
        as_sequence = x.unsqueeze(1)                  # (B, 1, feat)
        feature_maps = self.conv_stack(as_sequence)   # (B, latent_dim, feat)
        pooled = self.pool(feature_maps)              # (B, latent_dim, 1)
        return pooled.squeeze(-1)

class RVQ(nn.Module):
    """Residual vector quantiser built from a stack of embedding codebooks.

    Each stage picks the codeword nearest (squared Euclidean distance) to the
    current residual, then subtracts that codeword from the residual before the
    next stage runs. The quantised output is the sum of the chosen codewords.

    Args:
        num_quantizers:  Number of codebooks / quantisation stages (n_q).
        num_embeddings:  Number of codewords per codebook.
        embedding_dim:   Dimension of each codeword (D).
        commitment_cost: Stored on the module; not used by forward().
    """

    def __init__(self, num_quantizers, num_embeddings, embedding_dim, commitment_cost=0.25):
        super().__init__()
        self.num_quantizers = num_quantizers
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost

        # One nn.Embedding per quantisation stage.
        self.codebooks = nn.ModuleList([
            nn.Embedding(self.num_embeddings, self.embedding_dim)
            for _ in range(self.num_quantizers)
        ])
        # Re-initialise every codebook uniformly in [-1/n_emb, 1/n_emb].
        for book in self.codebooks:
            nn.init.uniform_(book.weight, -1.0 / self.num_embeddings, 1.0 / self.num_embeddings)

    def forward(self, z_e):
        """Quantise z_e stage by stage.

        Args:
            z_e: Tensor of shape (B, L, D).

        Returns:
            Tuple (quantised, indices): `quantised` has shape (B, L, D) and is
            the sum of the selected codewords of all stages; `indices` has
            shape (B, L, num_quantizers) and holds the chosen codeword index of
            every stage. No loss term is returned.
        """
        batch, length, dim = z_e.shape
        residual = z_e.reshape(-1, dim)   # (B*L, D)

        stage_outputs = []
        stage_indices = []
        for stage in range(self.num_quantizers):
            book = self.codebooks[stage]
            codewords = book.weight

            # Squared distances ||r||^2 - 2 r.c + ||c||^2, shape (B*L, num_embeddings).
            sq_dist = torch.sum(residual**2, dim=1, keepdim=True) \
                      - 2 * torch.matmul(residual, codewords.t()) \
                      + torch.sum(codewords**2, dim=1, keepdim=True).t()

            nearest = torch.argmin(sq_dist, dim=1)   # (B*L,)
            chosen = book(nearest)                   # (B*L, D)

            stage_indices.append(nearest)
            stage_outputs.append(chosen.reshape(batch, length, dim))

            # The codeword is detached so the next residual carries no gradient into it.
            residual = residual - chosen.detach()

        quantised = torch.stack(stage_outputs, dim=0).sum(dim=0)
        indices = torch.stack(stage_indices, dim=1).reshape(batch, length, self.num_quantizers)
        return quantised, indices

class RVQTokenizer(nn.Module):
    """Frozen EEG tokenizer: ConvEEGEncoder followed by RVQ, restored from a checkpoint.

    Args:
        feat:   EEG feature dimension passed to ConvEEGEncoder.
        latent: Latent / codeword dimension shared by encoder and RVQ.
        n_q:    Number of RVQ codebooks (tokens produced per EEG vector).
        n_emb:  Number of codewords per codebook.
        hidden: Hidden channel count of the encoder.
        TOKENIZER_CHECKPOINT_PATH: File holding a dict with an "encoder" state
            dict and a "codebooks" sequence of weight tensors.

    Raises:
        ValueError: If the checkpoint does not contain exactly `n_q` codebooks
            or a codebook's shape differs from (n_emb, latent).
    """

    def __init__(self,
                 feat=840,
                 latent=128,
                 n_q=12,
                 n_emb=512,
                 hidden=256,
                 TOKENIZER_CHECKPOINT_PATH = "/home/work/skku/hyo/hyo/model/rvq_best_model_sen_512.pt"
                 ):
        super().__init__()
        self.n_q = n_q
        self.n_emb = n_emb
        self.enc = ConvEEGEncoder(feat, latent, hidden)
        self.rvq = RVQ(num_quantizers=n_q, num_embeddings=n_emb, embedding_dim=latent)

        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source (or pass weights_only=True where supported).
        checkpoint = torch.load(TOKENIZER_CHECKPOINT_PATH, map_location="cpu")
        self.enc.load_state_dict(checkpoint["encoder"])

        saved_codebooks = checkpoint["codebooks"]
        # A silent mismatch would leave some codebooks randomly initialised
        # (too few saved) or change a codebook's size behind the module's back.
        if len(saved_codebooks) != len(self.rvq.codebooks):
            raise ValueError(
                f"Checkpoint has {len(saved_codebooks)} codebooks, "
                f"but the tokenizer was built with n_q={len(self.rvq.codebooks)}."
            )
        with torch.no_grad():
            for i, cb_weight_tensor in enumerate(saved_codebooks):
                target = self.rvq.codebooks[i].weight
                if tuple(cb_weight_tensor.shape) != tuple(target.shape):
                    raise ValueError(
                        f"Codebook {i} has shape {tuple(cb_weight_tensor.shape)} in the checkpoint, "
                        f"expected {tuple(target.shape)} (n_emb, latent)."
                    )
                # Copy in place so the parameter keeps its identity and dtype.
                target.copy_(cb_weight_tensor)

    @torch.no_grad()
    def forward(self, x):
        """Tokenise a batch of EEG vectors without tracking gradients.

        Args:
            x: Tensor of shape (B, feat).

        Returns:
            Tuple (zq, indices) exactly as produced by the RVQ module for the
            encoder output reshaped to (B, 1, latent); with the RVQ defined in
            this notebook `indices` has shape (B, 1, n_q). Callers that need
            (B, n_q) must squeeze dimension 1 themselves.
        """
        z = self.enc(x)
        zq, indices = self.rvq(z.unsqueeze(1))
        return zq, indices

class UnifiedEEGTextTokenizer:
    """Build chat-formatted token sequences that mix EEG tokens and text tokens.

    Sequence layout produced for training:

        BOS | "<start_id>user<end_id>\\n" | EEG tokens |
        "<eot_id><start_id>assistant<end_id>\\n" | assistant text | EOS | padding

    EEG tokens are the RVQ codeword indices shifted by `v_text_original`.

    Args:
        rvq_tokenizer_instance: Callable returning (quantised, indices) for a
            batched EEG tensor.
        llada_text_tokenizer_instance: Hugging Face style text tokenizer.
        max_seq_length: Fixed length every training sample is padded/truncated to.
        v_text_original: Offset added to the RVQ indices to obtain global ids.
        eeg_token_length: Number of EEG tokens reserved when budgeting the text length.
        device: Device for all produced tensors. When None (default) the
            notebook-level `config.system.DEVICE` is used, which requires
            `config` to exist at construction time.
    """

    def __init__(self,
                rvq_tokenizer_instance,
                llada_text_tokenizer_instance,
                max_seq_length,
                v_text_original,
                eeg_token_length,
                device=None,
                ):

        self.rvq_tokenizer = rvq_tokenizer_instance
        self.llada_text_tokenizer = llada_text_tokenizer_instance
        self.max_seq_length = max_seq_length
        self.v_text_original = v_text_original
        self.eeg_token_length = eeg_token_length
        # Resolve the device once; the global config is only consulted as a fallback.
        self.device = device if device is not None else config.system.DEVICE

        text_tok = self.llada_text_tokenizer
        # BOS / EOS are stored as 1-element tensors so they can be concatenated directly.
        self.bos_token_id = torch.tensor([text_tok.bos_token_id], dtype=torch.long, device=self.device)
        self.eos_token_id = torch.tensor([text_tok.eos_token_id], dtype=torch.long, device=self.device)
        # Fall back to EOS when the tokenizer defines no PAD token.
        self.pad_token_id = text_tok.pad_token_id if text_tok.pad_token_id is not None else text_tok.eos_token_id

        self.user_prompt_intro_ids = text_tok.encode(
                "<start_id>user<end_id>\n",
                add_special_tokens=False,
                return_tensors="pt"
            ).squeeze(0).to(self.device)

        self.assistant_prompt_intro_ids = text_tok.encode(
                "<eot_id><start_id>assistant<end_id>\n",
                add_special_tokens=False,
                return_tensors="pt"
            ).squeeze(0).to(self.device)

        print(f"Unified Tokenizer Initialized:")
        print(f"  BOS ID: {self.bos_token_id.item()}")
        print(f"  EOS ID: {self.eos_token_id.item()}")
        print(f"  PAD ID: {self.pad_token_id}")
        print(f"  User Prompt Intro IDs ({len(self.user_prompt_intro_ids)} tokens): {self.user_prompt_intro_ids.tolist()}")
        print(f"  Assistant Prompt Intro IDs ({len(self.assistant_prompt_intro_ids)} tokens): {self.assistant_prompt_intro_ids.tolist()}")

    def _encode_eeg(self, eeg_tensor):
        """Turn one EEG feature vector into a 1-D tensor of global EEG token ids."""
        eeg_tensor = eeg_tensor.to(self.device)
        with torch.no_grad():
            _, local_eeg_indices = self.rvq_tokenizer(eeg_tensor.unsqueeze(0))
        # Drop the batch axis and the length-1 sequence axis.
        local_eeg_indices = local_eeg_indices.squeeze(0).squeeze(0)
        if local_eeg_indices.ndim == 0:
            # A single quantiser squeezes down to a scalar; keep it 1-D for torch.cat.
            local_eeg_indices = local_eeg_indices.unsqueeze(0)
        return (local_eeg_indices + self.v_text_original).to(self.device)

    def process_single_sample(self, eeg_tensor, assistant_response_text):
        """Build one fixed-length training sample.

        Args:
            eeg_tensor: 1-D EEG feature tensor for a single sample.
            assistant_response_text: Target text the assistant should produce.

        Returns:
            Dict with
              "input_ids":       long tensor of length max_seq_length,
              "attention_mask":  1 for real tokens, 0 for padding,
              "labels":          copy of input_ids with the prompt part and the
                                 padding set to -100 (EOS stays a target),
              "prompt_lengths":  0-dim long tensor (created on the CPU) with the
                                 number of prompt tokens.
        """
        global_eeg_ids = self._encode_eeg(eeg_tensor)

        # Everything except the assistant text has a fixed length.
        len_fixed_tokens = (
            len(self.bos_token_id) +
            len(self.user_prompt_intro_ids) +
            self.eeg_token_length +
            len(self.assistant_prompt_intro_ids) +
            len(self.eos_token_id)
        )
        max_assistant_text_len = self.max_seq_length - len_fixed_tokens

        if max_assistant_text_len <= 0:
            # No room for any text: fall back to an empty response.
            assistant_response_text = ""
            max_assistant_text_len = 1

        tokenized_assistant_text = self.llada_text_tokenizer(
            assistant_response_text,
            max_length=max(1, max_assistant_text_len),
            padding="do_not_pad",
            truncation=True,
            add_special_tokens=False,
            return_tensors="pt"
        )
        # An empty text can come back as a float tensor of shape (1, 0); force
        # long so the concatenated ids never get promoted to float.
        assistant_text_ids = tokenized_assistant_text.input_ids.squeeze(0).to(
            device=self.device, dtype=torch.long)

        input_ids = torch.cat([
            self.bos_token_id,
            self.user_prompt_intro_ids,
            global_eeg_ids,                   # {USER_CONTENT}
            self.assistant_prompt_intro_ids,
            assistant_text_ids,               # {ASSISTANT_CONTENT}
            self.eos_token_id,
        ], dim=0)

        prompt_len = (
            len(self.bos_token_id) +
            len(self.user_prompt_intro_ids) +
            len(global_eeg_ids) +
            len(self.assistant_prompt_intro_ids)
        )

        labels = input_ids.clone()
        labels[:prompt_len] = -100   # the prompt does not contribute to the loss
        attention_mask = torch.ones_like(input_ids)

        pad_len = self.max_seq_length - len(input_ids)
        if pad_len > 0:
            input_ids = torch.cat([input_ids, input_ids.new_full((pad_len,), self.pad_token_id)], dim=0)
            labels = torch.cat([labels, labels.new_full((pad_len,), -100)], dim=0)
            attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((pad_len,))], dim=0)
        elif pad_len < 0:
            # NOTE(review): prompt_len is not clamped here, so it can exceed
            # max_seq_length when the prompt alone is longer than the limit.
            input_ids = input_ids[:self.max_seq_length]
            labels = labels[:self.max_seq_length]
            attention_mask = attention_mask[:self.max_seq_length]

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "prompt_lengths": torch.tensor(prompt_len, dtype=torch.long)
        }

    def build_chat_template_prompt(self, eeg_tensor):
        """Build the inference prompt for one EEG vector (no assistant text, no EOS).

        Args:
            eeg_tensor: 1-D EEG feature tensor for a single sample.

        Returns:
            Dict with "input_ids" and "attention_mask" of shape (1, T) and
            "prompt_len" (int, equal to T).
        """
        global_eeg_ids = self._encode_eeg(eeg_tensor)

        input_ids = torch.cat([
            self.bos_token_id,
            self.user_prompt_intro_ids,
            global_eeg_ids,                    # USER_CONTENT (EEG tokens)
            self.assistant_prompt_intro_ids,   # stops right after the assistant header
        ], dim=0)

        prompt_len = len(input_ids)
        attention_mask = torch.ones_like(input_ids)

        return {
            "input_ids": input_ids.unsqueeze(0),            # (1, T)
            "attention_mask": attention_mask.unsqueeze(0),  # (1, T)
            "prompt_len": prompt_len                        # int
        }

class DataCollatorForEEGTextSFT:
    """Collate (EEG tensor, text) pairs into one batch of stacked training tensors.

    Args:
        unified_tokenizer_instance: Object exposing
            process_single_sample(eeg_tensor, text) -> dict of tensors.
    """

    def __init__(self, unified_tokenizer_instance):
        self.unified_tokenizer = unified_tokenizer_instance

    def __call__(self, batch_of_samples):
        """Tokenise every valid sample and stack the results.

        Samples whose EEG tensor or text is None are dropped.

        Returns:
            Dict with stacked "input_ids", "attention_mask", "labels" and
            "prompt_lengths", or None when no valid sample remains (the
            training loop has to handle that case).
        """
        encoded = [
            self.unified_tokenizer.process_single_sample(eeg_tensor, assistant_response_text)
            for eeg_tensor, assistant_response_text in batch_of_samples
            if eeg_tensor is not None and assistant_response_text is not None
        ]

        if len(encoded) == 0:
            return None

        return {
            field: torch.stack([sample[field] for sample in encoded])
            for field in ("input_ids", "attention_mask", "labels", "prompt_lengths")
        }

In [6]:
def collate_fn_for_evaluation(batch):
    """Split a batch of (eeg, reference_text) pairs into two parallel lists.

    Nothing is stacked or padded: the EEG items stay in a plain list so the
    generation loop can handle them one at a time.

    Returns:
        Dict with "batched_eeg_data" (first element of every sample) and
        "batched_reference_texts" (second element of every sample).
    """
    def column(position):
        # One full pass over the batch per requested position.
        return [sample[position] for sample in batch]

    return {
        "batched_eeg_data": column(0),
        "batched_reference_texts": column(1),
    }

In [7]:
import numpy as np
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

def calculate_bleu_scores(references, hypotheses):
    """Average sentence-level BLEU-1..4 over a set of samples.

    Args:
        references: List[List[List[str]]] - for every sample, one or more
            tokenised reference sentences.
        hypotheses: List[List[str]] - for every sample, the tokenised
            generated sentence.

    Returns:
        Dict with the keys "BLEU-1", "BLEU-2", "BLEU-3" and "BLEU-4", each the
        mean of the per-sample scores (0.0 when there are no samples). A
        sample with no reference, an empty first reference or an empty
        hypothesis scores 0.0.
    """
    smoothing = SmoothingFunction().method1

    # n-gram weight vectors for each reported BLEU order.
    bleu_variants = (
        ("BLEU-1", (1.0, 0.0, 0.0, 0.0)),
        ("BLEU-2", (0.5, 0.5, 0.0, 0.0)),
        ("BLEU-3", (1/3, 1/3, 1/3, 0.0)),
        ("BLEU-4", (0.25, 0.25, 0.25, 0.25)),
    )

    def score_sample(refs, hyp, weights):
        if refs and refs[0] and hyp:
            return sentence_bleu(
                refs,
                hyp,
                weights=weights,
                smoothing_function=smoothing
            )
        return 0.0

    averaged = {}
    for metric_name, weights in bleu_variants:
        per_sample = [
            score_sample(refs, hyp, weights)
            for refs, hyp in zip(references, hypotheses)
        ]
        averaged[metric_name] = float(np.mean(per_sample)) if per_sample else 0.0

    return averaged

In [8]:
def calculate_rouge_scores(references, hypotheses):
    """Average ROUGE-1 / ROUGE-2 / ROUGE-L F-measures over paired strings.

    Args:
        references: Iterable of reference strings.
        hypotheses: Iterable of generated strings, aligned with `references`.

    Returns:
        Dict with the keys 'rouge1', 'rouge2' and 'rougeL' holding the mean
        F-measure (0.0 when there are no pairs). A pair with an empty
        reference or hypothesis contributes 0.0 to every metric.
    """
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    collected = {'rouge1': [], 'rouge2': [], 'rougeL': []}

    for ref, hyp in zip(references, hypotheses):
        if ref and hyp:
            result = scorer.score(ref, hyp)
            collected['rouge1'].append(result['rouge1'].fmeasure)
            collected['rouge2'].append(result['rouge2'].fmeasure)
            collected['rougeL'].append(result['rougeL'].fmeasure)
        else:
            # Empty input: count as a zero F-measure for every metric.
            for bucket in collected.values():
                bucket.append(0.0)

    return {
        metric: (np.mean(values) if values else 0.0)
        for metric, values in collected.items()
    }

In [9]:
def calculate_wer(references, hypotheses):
    """Compute the corpus-level word error rate (WER) of `hypotheses`.

    The WER is pooled over all sentence pairs (total edit operations divided
    by total reference words) rather than averaged per sentence.

    Args:
        references: List of reference sentences (strings).
        hypotheses: List of predicted sentences (strings), same length.

    Returns:
        The WER as a float. Empty inputs or lists of different length are
        treated as invalid and yield 1.0 (100 % error).

    Raises:
        Whatever jiwer raises for inputs it rejects (recent versions reject
        empty reference strings, for example).
    """
    if not references or not hypotheses or len(references) != len(hypotheses):
        return 1.0

    # jiwer.wer exists in every jiwer release and returns the same pooled value
    # as compute_measures(...)['wer'], which newer releases no longer provide.
    return jiwer.wer(references, hypotheses)

In [11]:
def add_gumbel_noise(logits, temperature):
    '''
    Perturb logits for Gumbel-max sampling of a categorical distribution.

    According to arXiv:2409.02908, low-precision Gumbel-max improves the
    perplexity score of masked diffusion models but hurts generation quality,
    so all arithmetic is done in float64.

    With temperature == 0 the logits are returned untouched (greedy decoding).
    Otherwise the result is exp(logits) / (-log(u)) ** temperature with
    u ~ Uniform[0, 1), as a float64 tensor.
    '''
    if temperature == 0:
        return logits

    logits64 = logits.to(torch.float64)
    uniform = torch.rand_like(logits64, dtype=torch.float64)
    scaled_noise = (-torch.log(uniform)) ** temperature
    return logits64.exp() / scaled_noise


def get_num_transfer_tokens(mask_index, steps):
    '''
    Precompute how many masked tokens are revealed at every sampling step.

    The reverse process splits [0, 1] uniformly into `steps` intervals, and
    LLaDA uses a linear noise schedule (Eq. (8)), so each step should reveal
    the same expected number of tokens. Every row's mask count is therefore
    divided evenly across the steps, and the remainder is handed out one token
    at a time to the earliest steps.

    Args:
        mask_index: (B, L) boolean tensor, True where a token is still masked.
        steps: Number of sampling steps.

    Returns:
        (B, steps) tensor; each row sums to that row's number of masked tokens.
    '''
    masked_per_row = mask_index.sum(dim=1, keepdim=True)   # (B, 1)

    even_share = masked_per_row // steps
    leftover = masked_per_row % steps

    # Step s of a row receives one extra token exactly when s < leftover[row].
    step_ids = torch.arange(steps, device=mask_index.device).unsqueeze(0)   # (1, steps)
    extra = (step_ids < leftover).to(torch.int64)                           # (B, steps)

    return even_share.expand(-1, steps) + extra

In [12]:
@ torch.no_grad()
def generate(model, prompt, steps=128, gen_length=128, block_length=128, temperature=0.,
             cfg_scale=0., remasking='low_confidence', mask_id=126336):
    '''
    Generate an answer by iteratively unmasking a block of [MASK] tokens appended to `prompt`.

    Args:
        model: Mask predictor.
        prompt: A tensor of shape (1, L).
        steps: Sampling steps, less than or equal to gen_length.
        gen_length: Generated answer length.
        block_length: Block length, less than or equal to gen_length. If less than gen_length, it means using semi_autoregressive remasking.
        temperature: Categorical distribution sampling temperature.
        cfg_scale: Unsupervised classifier-free guidance scale.
        remasking: Remasking strategy. 'low_confidence' or 'random'.
        mask_id: The token id of [MASK] is 126336.

    Returns:
        Long tensor of shape (1, L + gen_length): the prompt followed by the generated tokens.

    Raises:
        AssertionError: If gen_length is not a multiple of block_length, or steps is
            not a multiple of the number of blocks.
        NotImplementedError: If `remasking` is neither 'low_confidence' nor 'random'.
    '''
    # Working sequence: the prompt followed by gen_length [MASK] tokens.
    x = torch.full((1, prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device)
    x[:, :prompt.shape[1]] = prompt.clone()

    # Positions that are not [MASK] at the start, i.e. the prompt tokens.
    prompt_index = (x != mask_id)

    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length

    # The step budget is split evenly across blocks; from here on `steps` is per block.
    assert steps % num_blocks == 0
    steps = steps // num_blocks

    for num_block in range(num_blocks):
        # Mask state of the current block only, used to schedule how many tokens each step reveals.
        block_mask_index = (x[:, prompt.shape[1] + num_block * block_length: prompt.shape[1] + (num_block + 1) * block_length:] == mask_id)
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)
        for i in range(steps):
            mask_index = (x == mask_id)
            if cfg_scale > 0.:
                # Classifier-free guidance: run the conditional input and a copy whose
                # prompt is masked out in one batch, then extrapolate the logits.
                un_x = x.clone()
                un_x[prompt_index] = mask_id
                x_ = torch.cat([x, un_x], dim=0)
                logits = model(x_).logits
                logits, un_logits = torch.chunk(logits, 2, dim=0)
                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
            else:
                logits = model(x).logits

            # Candidate token for every position (Gumbel-max sampling; plain argmax when temperature == 0).
            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
            x0 = torch.argmax(logits_with_noise, dim=-1) # b, l

            # Confidence assigned to each candidate token.
            if remasking == 'low_confidence':
                p = F.softmax(logits.to(torch.float64), dim=-1)
                x0_p = torch.squeeze(
                    torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l
            elif remasking == 'random':
                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
            else:
                raise NotImplementedError(remasking)

            # Positions after the current block must not be revealed yet.
            x0_p[:, prompt.shape[1] + (num_block + 1) * block_length:] = -np.inf

            # Already revealed positions keep their token and can never be selected again.
            x0 = torch.where(mask_index, x0, x)
            confidence = torch.where(mask_index, x0_p, -np.inf)

            # Reveal the scheduled number of highest-confidence positions for this step.
            transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
            for j in range(confidence.shape[0]):
                _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
                transfer_index[j, select_index] = True
            x[transfer_index] = x0[transfer_index]

    return x

In [13]:
def load_finetuned_eeg_llada(
        base_ckpt: str = "GSAI-ML/LLaDA-8B-Base",
        comp_dir: Path = Path(
            "/home/work/skku/hyo/hyo/checkpoints_v2/best_finetuned_model_components"  # wte.pth, lm_head.pth, adapters/
        ),
        rvq_n_emb: int = 512,
        device: str = "cuda",
        load_in_4bit: bool = True,
):
    """Rebuild the fine-tuned EEG-LLaDA model from a base checkpoint plus saved components.

    Steps: load the base model (optionally 4-bit NF4 quantised), grow the
    vocabulary by `rvq_n_emb + 1` rows, restore the saved input embedding and
    output head if present, and attach the LoRA adapter if present.

    Args:
        base_ckpt: Hub id or path of the base model.
        comp_dir: Directory expected to contain `wte.pth`, `lm_head.pth` and
            an `adapters/` sub-directory. A plain string is accepted as well.
        rvq_n_emb: Number of EEG codeword rows added to the vocabulary.
        device: Device the base model is moved to.
        load_in_4bit: Load the base model with a bitsandbytes 4-bit config.

    Returns:
        The model in eval mode: a PeftModel when an adapter directory exists,
        otherwise the resized base model.
    """
    comp_dir = Path(comp_dir)   # tolerate str paths; the `/` operator below needs a Path

    # --- 1. base LLaDA (4-bit optional) -----------------------------
    bnb_cfg = (BitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True)
        if load_in_4bit else None)

    base = AutoModelForCausalLM.from_pretrained(
        base_ckpt, quantization_config=bnb_cfg, torch_dtype="auto",
        trust_remote_code=True).to(device)

    # --- 2. vocab resize --------------------------------------------
    v_txt = base.config.vocab_size
    base.resize_token_embeddings(v_txt + rvq_n_emb + 1)   # +1 = EEG-MASK

    # --- 3. restore embedding / lm_head -----------------------------
    wte_path = comp_dir / "wte.pth"
    lm_path = comp_dir / "lm_head.pth"

    # Unwrap up to two `.model` levels, then take `.transformer` if it exists,
    # to reach the module that owns `ff_out`.
    llada_core = getattr(base, "model", base)
    llada_core = getattr(llada_core, "model", llada_core)
    transformer = getattr(llada_core, "transformer", llada_core)

    # NOTE(review): torch.load unpickles the files below; only load component
    # files from a trusted source.
    if wte_path.exists():
        base.get_input_embeddings().load_state_dict(torch.load(wte_path, map_location="cpu"))
        print("‚úî wte restored")
    else:
        # Without this file the added EEG rows keep their fresh initialisation.
        print(f"WARNING: {wte_path} not found, input embeddings not restored")

    if lm_path.exists():
        transformer.ff_out.load_state_dict(torch.load(lm_path, map_location="cpu"))
        print("‚úî lm_head (ff_out) restored")
    else:
        print(f"WARNING: {lm_path} not found, lm_head (ff_out) not restored")

    # --- 4. LoRA attach (is_trainable=False for inference) ----------
    adapter_dir = comp_dir / "adapters"
    if adapter_dir.exists():
        model = PeftModel.from_pretrained(base, adapter_dir, is_trainable=False)
        print("‚úî LoRA adapter loaded")
    else:
        model = base
        print("‚ö† LoRA adapter not found, base model only")

    model.eval()
    return model

In [14]:
class SystemConfig(BaseModel):
    """Runtime settings shared by the whole experiment."""
    SEED: int = 42
    # Evaluated once, when this class is defined: "cuda" if a GPU is visible, else "cpu".
    DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"
    NUM_WORKERS: int = 0

class PathsConfig(BaseModel):
    """File-system locations of the dataset, the RVQ checkpoint and saved models."""
    DATASET_PATH: str = "/home/work/skku/hyo/hyo/dataset/sentence.parquet"
    TOKENIZER_CHECKPOINT_PATH: str = "/home/work/skku/hyo/hyo/model/rvq_best_model_sen_512.pt"
    MODEL_SAVE_DIR: str = "./saved_models"
    BEST_MODEL_FILENAME: str = "eeg_llada_sft_best_model.pth"
    # LLADA_LOSS_FUNCTION_PATH: str = "/home/ubuntu/llada_loss_function.py" # add if needed
    # MODIFIED_TRAINING_LOOPS_PATH: str = "/home/ubuntu/modified_training_loops.py" # add if needed

class EEGEncoderConfig(BaseModel):
    """Hyper-parameters of the convolutional EEG encoder."""
    INPUT_DIM: int = 840
    LATENT_DIM: int = 128 # must match the RVQ embedding_dim
    HIDDEN_DIM: int = 256

class RVQConfig(BaseModel):
    """Hyper-parameters of the residual vector quantiser."""
    NUM_QUANTIZERS: int = 12 # RVQTokenizer's n_q; must equal UnifiedEEGTextTokenizer's eeg_token_length
    NUM_EMBEDDINGS: int = 512 # RVQTokenizer's n_emb
    EMBEDDING_DIM: int = 128 # must equal EEGEncoderConfig.LATENT_DIM
    COMMITMENT_COST: float = 0.25

class TokenizerConfig(BaseModel):
    """Settings for the unified EEG + text tokenizer."""
    # RVQTokenizer's own parameters can be taken from EEGEncoderConfig / RVQConfig.
    # The fields below are the UnifiedEEGTextTokenizer parameters.
    MAX_SEQ_LENGTH: int = 1024 # chosen with the LLaDA model's maximum context length in mind
    # NOTE(review): load_finetuned_eeg_llada() appends the EEG rows after
    # base.config.vocab_size and generate() defaults to mask_id=126336, both of
    # which lie far above 32000 - confirm this offset matches the vocabulary
    # the model was actually resized from.
    V_TEXT_ORIGINAL: int = 32000 # described by the author as the LLaMA text-tokenizer vocabulary size (for LLaDA-8B)
    # EEG_TOKEN_LENGTH: int = 12 # same as RVQConfig.NUM_QUANTIZERS
    LLM_MODEL_NAME: str = "GSAI-ML/LLaDA-8B-Base" # used to load the LLM text tokenizer

class ModelConfig(BaseModel):
    """Base LLM choice and (Q)LoRA hyper-parameters."""
    LLM_MODEL_NAME: str = "GSAI-ML/LLaDA-8B-Base"
    USE_QLORA: bool = True
    LORA_R: int = 16
    LORA_ALPHA: int = 32
    LORA_DROPOUT: float = 0.05
    LORA_BIAS: str = "none"
    # LLADA_MASK_TOKEN_ID: Optional[int] = None # may be set dynamically (vocabulary size + 1)

class TrainingConfig(BaseModel):
    """Optimisation hyper-parameters for supervised fine-tuning."""
    BATCH_SIZE: int = 4 # tune to the available GPU memory
    NUM_EPOCHS: int = 10
    START_EPOCH: int = 0
    LEARNING_RATE: float = 1e-4
    GRADIENT_ACCUMULATION_STEPS: int = 4 # BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS = effective batch size
    MAX_GRAD_NORM: float = 1.0
    # SCHEDULER: Optional[str] = None # e.g. "StepLR"
    TRAIN_LOG_INTERVAL: int = 1
    PATIENCE_EARLY_STOPPING: int = 3 # 0 disables early stopping

class GenerationConfig(BaseModel):
    """Settings for text generation during evaluation."""
    RUN_TEST_LOOP_EACH_EPOCH: bool = False
    USE_LLADA_SAMPLING_FOR_GENERATION: bool = True
    MAX_GEN_TOKENS: int = 64
    NUM_SAMPLING_STEPS_GEN: int = 10
    REMASKING_STRATEGY_GEN: str = "low_confidence"
    REMASKING_RATIO_GEN: float = 0.25
    TEMPERATURE_GEN: float = 0.7
    TOP_K_GEN: int = 50
    TOP_P_GEN: float = 0.9
    HF_NUM_BEAMS_GEN: int = 1
    # HF_MAX_LENGTH_GEN: Optional[int] = None # set dynamically (input length + MAX_GEN_TOKENS)

class ExperimentConfig(BaseModel):
    """Top-level container that groups every sub-configuration of the experiment.

    Each field uses ``default_factory`` so that constructing an
    ``ExperimentConfig`` builds a fresh instance of the corresponding
    sub-config with its own defaults.
    """
    system: SystemConfig = Field(default_factory=SystemConfig)
    paths: PathsConfig = Field(default_factory=PathsConfig)
    eeg_encoder: EEGEncoderConfig = Field(default_factory=EEGEncoderConfig)
    rvq: RVQConfig = Field(default_factory=RVQConfig)
    tokenizer: TokenizerConfig = Field(default_factory=TokenizerConfig)
    model: ModelConfig = Field(default_factory=ModelConfig)
    training: TrainingConfig = Field(default_factory=TrainingConfig)
    generation: GenerationConfig = Field(default_factory=GenerationConfig)

    # LLADA_MASK_TOKEN_ID may be determined dynamically, so it is recommended to set it after initialisation
    # e.g. config.model.LLADA_MASK_TOKEN_ID = tokenizer.llada_text_tokenizer.vocab_size + 1

# Notebook-wide configuration object built entirely from the defaults above.
config = ExperimentConfig()

In [15]:
# Directory containing the fine-tuned model components to restore.
comp_dir = Path("/home/work/skku/hyo/hyo/checkpoints_v2/best_finetuned_model_components")

# Checkpoint name and RVQ codebook size come from the shared `config` object
# (the same source the tokenizer-setup cell uses) instead of repeated literals,
# so the model and the tokenizers cannot drift apart. With the default config
# these resolve to "GSAI-ML/LLaDA-8B-Base" and 512, i.e. the previous values.
# NOTE(review): `device` stays the literal "cuda" as before; the RVQ tokenizer
# is moved to config.system.DEVICE in the next cell -- confirm both agree.
model = load_finetuned_eeg_llada(
    base_ckpt=config.model.LLM_MODEL_NAME,
    comp_dir=comp_dir,
    rvq_n_emb=config.rvq.NUM_EMBEDDINGS,
    device="cuda",
)

This can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.
If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=
If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH
For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:<path_to_cuda_dir/lib64



Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


‚úî wte restored
‚úî lm_head (ff_out) restored
‚úî LoRA adapter loaded


In [16]:
# Text tokenizer of the base LLM.
llada_txt_tokenizer = AutoTokenizer.from_pretrained(config.model.LLM_MODEL_NAME)

# EEG tokenizer, moved to the configured device and put in inference mode.
# NOTE(review): no checkpoint is loaded here; confirm that RVQTokenizer()'s
# constructor restores the trained encoder/codebooks, otherwise the EEG tokens
# fed to the model come from untrained weights.
rvq_eeg_tokenizer = RVQTokenizer().to(config.system.DEVICE)
rvq_eeg_tokenizer.eval()

v_original = llada_txt_tokenizer.vocab_size
# Number of EEG tokens per sample; the config documents that this must equal
# the RVQ quantizer count, so read it from there rather than repeating 12.
eeg_seq_len = config.rvq.NUM_QUANTIZERS
# NOTE(review): config.tokenizer.MAX_SEQ_LENGTH defaults to 1024 but 512 is
# used here; kept unchanged -- confirm which value is intended.
model_max_len = 512

unified_eeg_text_tokenizer = UnifiedEEGTextTokenizer(
    rvq_tokenizer_instance=rvq_eeg_tokenizer,
    llada_text_tokenizer_instance=llada_txt_tokenizer,
    max_seq_length=model_max_len,
    v_text_original=v_original,
    eeg_token_length=eeg_seq_len,
)

Unified Tokenizer Initialized:
  BOS ID: 126080
  EOS ID: 126081
  PAD ID: 126081
  User Prompt Intro IDs (10 tokens): [27, 7351, 2983, 29, 3840, 27, 486, 2983, 29, 198]
  Assistant Prompt Intro IDs (15 tokens): [27, 68, 335, 2983, 3583, 7351, 2983, 29, 598, 10450, 27, 486, 2983, 29, 198]


In [17]:
# Fraction of all samples held out for (validation + test), and the share of
# that hold-out that becomes the test set (0.2 * 0.5 = 10% of the data each).
HOLDOUT_FRACTION = 0.2
TEST_SHARE_OF_HOLDOUT = 0.5


def _build_loader(dataset, shuffle):
    """Create a DataLoader over `dataset` with the notebook-wide loader settings.

    Batch size and worker count come from `config`; batches are assembled by
    `collate_fn_for_evaluation` and memory pinning is disabled.

    Args:
        dataset: the Dataset / Subset to iterate over.
        shuffle: whether the loader reshuffles samples every epoch.

    Returns:
        The configured `DataLoader`.
    """
    return DataLoader(
        dataset,
        batch_size=config.training.BATCH_SIZE,
        collate_fn=collate_fn_for_evaluation,
        shuffle=shuffle,
        num_workers=config.system.NUM_WORKERS,
        pin_memory=False,
    )


# 1. Build the EEG dataset from the configured path.
#    (EEGDataset is defined earlier in this notebook.)
eeg_dataset = EEGDataset(data_dir=config.paths.DATASET_PATH)
num_total_samples = len(eeg_dataset)
indices = list(range(num_total_samples))

# 2. Split indices into train / validation / test, seeded from the config so
#    the partition is reproducible.
train_indices, temp_test_indices = train_test_split(
    indices,
    test_size=HOLDOUT_FRACTION,
    random_state=config.system.SEED,
    shuffle=True,
)
val_indices, test_indices = train_test_split(
    temp_test_indices,
    test_size=TEST_SHARE_OF_HOLDOUT,
    random_state=config.system.SEED,
    shuffle=True,
)

# Index-based views over the single underlying dataset.
train_dataset = Subset(eeg_dataset, train_indices)
val_dataset = Subset(eeg_dataset, val_indices)
test_dataset = Subset(eeg_dataset, test_indices)

print(f"Ï†ÑÏ≤¥ Îç∞Ïù¥ÌÑ∞ÏÖã ÌÅ¨Í∏∞: {num_total_samples}")
print(f"ÌïôÏäµ ÏÑ∏Ìä∏ ÌÅ¨Í∏∞: {len(train_dataset)} (Ï†ÑÏ≤¥Ïùò {len(train_dataset)/num_total_samples:.2%})")
print(f"Í≤ÄÏ¶ù ÏÑ∏Ìä∏ ÌÅ¨Í∏∞: {len(val_dataset)} (Ï†ÑÏ≤¥Ïùò {len(val_dataset)/num_total_samples:.2%})")
print(f"ÌÖåÏä§Ìä∏ ÏÑ∏Ìä∏ ÌÅ¨Í∏∞: {len(test_dataset)} (Ï†ÑÏ≤¥Ïùò {len(test_dataset)/num_total_samples:.2%})")

# SFT collator instance. It is created here but not passed to any loader in
# this cell.
# NOTE(review): every loader below, including the training one, uses
# collate_fn_for_evaluation -- confirm that is intended for this notebook.
data_collator = DataCollatorForEEGTextSFT(unified_eeg_text_tokenizer)

# 3. DataLoaders: only the training loader shuffles.
train_dataloader = _build_loader(train_dataset, shuffle=True)
val_dataloader = _build_loader(val_dataset, shuffle=False)
test_dataloader = _build_loader(test_dataset, shuffle=False)

print(f"\nÌïôÏäµ Îç∞Ïù¥ÌÑ∞Î°úÎçî Î∞∞Ïπò Ïàò: {len(train_dataloader)}")
print(f"Í≤ÄÏ¶ù Îç∞Ïù¥ÌÑ∞Î°úÎçî Î∞∞Ïπò Ïàò: {len(val_dataloader)}")
print(f"ÌÖåÏä§Ìä∏ Îç∞Ïù¥ÌÑ∞Î°úÎçî Î∞∞Ïπò Ïàò: {len(test_dataloader)}")


print("\nÎç∞Ïù¥ÌÑ∞ Î°úÎî© Î∞è Î∂ÑÌï† (config Í∏∞Î∞ò) ÏôÑÎ£å.")

Ï†ÑÏ≤¥ Îç∞Ïù¥ÌÑ∞ÏÖã ÌÅ¨Í∏∞: 25616
ÌïôÏäµ ÏÑ∏Ìä∏ ÌÅ¨Í∏∞: 20492 (Ï†ÑÏ≤¥Ïùò 80.00%)
Í≤ÄÏ¶ù ÏÑ∏Ìä∏ ÌÅ¨Í∏∞: 2562 (Ï†ÑÏ≤¥Ïùò 10.00%)
ÌÖåÏä§Ìä∏ ÏÑ∏Ìä∏ ÌÅ¨Í∏∞: 2562 (Ï†ÑÏ≤¥Ïùò 10.00%)

ÌïôÏäµ Îç∞Ïù¥ÌÑ∞Î°úÎçî Î∞∞Ïπò Ïàò: 5123
Í≤ÄÏ¶ù Îç∞Ïù¥ÌÑ∞Î°úÎçî Î∞∞Ïπò Ïàò: 641
ÌÖåÏä§Ìä∏ Îç∞Ïù¥ÌÑ∞Î°úÎçî Î∞∞Ïπò Ïàò: 641

Îç∞Ïù¥ÌÑ∞ Î°úÎî© Î∞è Î∂ÑÌï† (config Í∏∞Î∞ò) ÏôÑÎ£å.


In [None]:
from tqdm import tqdm
import torch, numpy as np
import json

# ----- sampling hyperparameters -----
GEN_STEPS   = 64     # diffusion steps
GEN_LEN     = 128    # number of tokens to generate
BLOCK_LEN   = 64
TEMP        = 0.3
CFG_SCALE   = 0.9
REMASKING   = "low_confidence"

device = next(model.parameters()).device   # device the model's parameters live on

model.eval()
all_hyp, all_ref = [], []

# ----- output locations -----
out_dir      = Path("/home/work/skku/hyo/hyo/generate")
# parents=True so a missing intermediate directory does not abort the run.
out_dir.mkdir(parents=True, exist_ok=True)
csv_path     = out_dir / "eeg_gen_results.csv"
metrics_path = out_dir / "final_metrics.json"

text_tokenizer = unified_eeg_text_tokenizer.llada_text_tokenizer

with torch.no_grad():
    # The CSV is opened once in append mode and flushed after every row, so an
    # interrupted run still keeps everything generated so far without paying
    # for a file open/close per sample.
    # NOTE(review): append mode means re-running this cell adds rows to an
    # existing file, while the metrics below only cover the current run.
    with open(csv_path, "a", newline="", encoding="utf-8") as csv_file:
        row_writer = csv.writer(csv_file)

        for batch in tqdm(test_dataloader, desc="üîÆ  Generating from EEG"):
            eeg_list = batch["batched_eeg_data"]          # per-sample EEG tensors (author note: list[Tensor(840)])
            ref_list = batch["batched_reference_texts"]   # per-sample reference strings

            for eeg_tensor, ref_text in zip(eeg_list, ref_list):
                # 1) EEG -> RVQ tokens -> chat-template prompt
                prompt_pack = unified_eeg_text_tokenizer.build_chat_template_prompt(eeg_tensor)
                prompt_ids  = prompt_pack["input_ids"].to(device)   # assumed (1, T), consistent with the 2-D slicing below -- TODO confirm
                prompt_len  = prompt_pack["prompt_len"]
                # Debug aids:
                #   print(prompt_ids[0, :20])        # EEG token ids should differ between samples
                #   print(model.peft_config.keys())  # an empty dict means the LoRA adapter is not applied

                # 2) LLaDA diffusion sampling
                out_ids = generate(
                    model,
                    prompt=prompt_ids,
                    steps=GEN_STEPS,
                    gen_length=GEN_LEN,
                    block_length=BLOCK_LEN,
                    temperature=TEMP,
                    cfg_scale=CFG_SCALE,
                    remasking=REMASKING,
                )

                # 3) Decode only what follows the prompt
                gen_txt = text_tokenizer.batch_decode(
                    out_ids[:, prompt_len:], skip_special_tokens=True
                )[0].strip()

                # 4) Persist the pair immediately
                row_writer.writerow([ref_text, gen_txt])
                csv_file.flush()

                all_ref.append(ref_text)
                all_hyp.append(gen_txt)

    # ----- final metrics -----
    bleu  = calculate_bleu_scores(all_ref, all_hyp)
    rouge = calculate_rouge_scores(all_ref, all_hyp)
    wer   = calculate_wer(all_ref, all_hyp)
    metrics = {"BLEU": bleu, "ROUGE": rouge, "WER": wer}
    with open(metrics_path, "w", encoding="utf-8") as f:
        # default=float is only consulted for values json cannot encode natively
        # (e.g. NumPy scalars), so a non-builtin numeric metric does not crash
        # the cell after the whole generation run has finished.
        # The return types of the calculate_* helpers are not visible here.
        json.dump(metrics, f, ensure_ascii=False, indent=2, default=float)

    print("üìù  finished! metrics saved ‚Üí", metrics_path)