In [2]:
import json
import random
import re

import torch
from datasets import load_from_disk
from sklearn.metrics import accuracy_score
from transformers import AutoTokenizer, AutoModelForCausalLM

# Configuration: run on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
k_values = [0, 1, 2, 4, 8, 16]  # support-set sizes used for k-shot prompts

# Load the train / dev / test splits that were saved to disk.
train_set, dev_set, test_set = (
    load_from_disk('../data/train'),
    load_from_disk('../data/dev'),
    load_from_disk('../data/test'),
)

In [3]:
# Show the training split summary (feature names and row count).
train_set

Dataset({
    features: ['pubid', 'question', 'context', 'long_answer', 'final_decision'],
    num_rows: 800
})

In [4]:
# Show the dev split summary (feature names and row count).
dev_set

Dataset({
    features: ['pubid', 'question', 'context', 'long_answer', 'final_decision'],
    num_rows: 100
})

In [5]:
# Show the test split summary (feature names and row count).
test_set

Dataset({
    features: ['pubid', 'question', 'context', 'long_answer', 'final_decision'],
    num_rows: 100
})

In [6]:
# Text generation helper
def generate(model, tokenizer, prompt, max_new_tokens=64):
    """Greedily generate a continuation of ``prompt`` and return only the new text.

    Args:
        model: Causal language model exposing ``generate(**inputs, ...)`` and a
            ``device`` attribute.
        tokenizer: Tokenizer that is callable on a string and provides ``decode``.
        prompt: Prompt text to continue.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded continuation (prompt excluded) with surrounding whitespace
        stripped.
    """
    # Put the inputs on the model's own device instead of relying on a
    # notebook-level global, so the helper works wherever the model lives.
    inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(model.device)
    prompt_len = inputs["input_ids"].shape[1]

    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)

    # Cut the prompt off at the token level. Slicing the decoded string by
    # len(prompt) breaks whenever decode() does not reproduce the prompt
    # character-for-character (whitespace normalization, special tokens, ...).
    new_tokens = outputs[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

In [7]:
def _parse_decision(answer):
    """Map a raw model answer to ``"yes"``, ``"no"``, ``"maybe"`` or ``"wrong"``.

    Only the first whole word is considered, so answers such as "not sure",
    "none" or "yesterday" are no longer mistaken for "no" / "yes".
    """
    first_word = re.match(r"[a-z]+", answer.lower().strip())
    if first_word and first_word.group(0) in ("yes", "no", "maybe"):
        return first_word.group(0)
    return "wrong"  # fallback for anything that is not one of the three labels


def evaluate(model, tokenizer, support_examples, test_set, result_file):
    """Run k-shot evaluation over ``test_set`` and dump per-example results to JSON.

    For every test example a prompt is built from the support examples
    (question / context / final decision) followed by the test question and
    context, the model's answer is generated with ``generate`` and mapped to
    one of ``yes`` / ``no`` / ``maybe`` (``wrong`` when none matches).

    Args:
        model: Model passed through to ``generate``.
        tokenizer: Tokenizer passed through to ``generate``.
        support_examples: Iterable of examples with ``question``, ``context``
            and ``final_decision`` keys; may be empty for zero-shot.
        test_set: Iterable of examples with the same keys.
        result_file: Path of the JSON file the per-example results are written to.

    Returns:
        Tuple ``(acc, preds, labels)``: the accuracy from ``accuracy_score`` and
        the lists of predicted and gold labels in test-set order.
    """
    preds, labels, results = [], [], []

    for example in test_set:
        # NOTE(review): 'context' is interpolated with str(); if it is a dict
        # rather than plain text the prompt will contain its repr - confirm
        # against the dataset preprocessing.
        prompt = ""
        for support in support_examples:
            prompt += (
                f"Q: {support['question']}\n"
                f"C: {support['context']}\n"
                f"A: {support['final_decision']}\n\n"
            )
        prompt += (
            f"Q: {example['question']}\n"
            f"C: {example['context']}\n"
            f"A: Based on the context, what is the correct answer? "
            f"Please answer only one word: yes, no, or maybe.\nAnswer:"
        )

        answer = generate(model, tokenizer, prompt)

        # Accept exact label matches only (first whole word of the answer).
        pred = _parse_decision(answer)

        preds.append(pred)
        labels.append(example["final_decision"])
        results.append({
            "question": example["question"],
            "context": example["context"],
            "prediction": pred,
            "raw_answer": answer,
            "label": example["final_decision"]
        })

    acc = accuracy_score(labels, preds)

    with open(result_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    return acc, preds, labels

In [8]:
# Pre-sample the k-shot support examples once, with a fixed seed, so that every
# evaluated model is prompted with the same shots.
train_list = list(train_set)
random.seed(42)

support_examples_dict = {}
for k in k_values:
    # k == 0 is the zero-shot setting: no support examples, no RNG draw.
    support_examples_dict[k] = random.sample(train_list, k) if k > 0 else []

In [10]:
metaicl_base_path = "../models/metaicl_qwen_meta"

# Accuracy per (model_k, eval_k) pair, kept so results can be inspected later
# instead of only being printed.
metaicl_accuracies = {}

model = tokenizer = None
for model_k in [8, 16]:
    # Drop the previous checkpoint before loading the next one; otherwise both
    # models are alive at the same time while from_pretrained() runs.
    model = tokenizer = None
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    print(f"\n==== Loading MetaICL Model trained with k={model_k} shots ====")
    model_path = f"{metaicl_base_path}/k_{model_k}/final"
    model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    model.eval()

    # Evaluate this model for every support-set size in k_values
    for eval_k in k_values:
        support_examples = support_examples_dict[eval_k]
        result_file = f"results_metaicl_model_k{model_k}_eval_k{eval_k}.json"
        print(f"Evaluating model trained on k={model_k} shots with support set k={eval_k}...")
        acc, preds, labels = evaluate(model, tokenizer, support_examples, test_set, result_file)
        metaicl_accuracies[(model_k, eval_k)] = acc
        print(f"[MetaICL Model k={model_k}] Accuracy with support k={eval_k}: {acc:.4f}")



==== Loading MetaICL Model trained with k=8 shots ====


Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]Cancellation requested; stopping current tasks.
Fetching 2 files:   0%|          | 0/2 [00:15<?, ?it/s]


KeyboardInterrupt: 