In [None]:
!pip install loguru

In [None]:
import os
import sys
# Show the working directory: the relative paths used in this notebook
# ('Sim_tools', 'data/...', 'saved_model/...') are resolved against it.
print(os.getcwd())
# Also put the Sim_tools directory itself on the import path, presumably so modules
# inside the package can import each other by bare module name. NOTE(review): confirm this is needed.
sys.path.append('Sim_tools')

# Project-local SimCSE helpers: datasets, model + losses, training utilities, evidence embedding.
from Sim_tools.dataset_Sim import TrainDataset, TestDataset
from Sim_tools.model_Sim import SimcseModel, simcse_sup_loss, simcse_unsup_loss
from Sim_tools.train_Sim import load_train_data_supervised, train_sup
from Sim_tools.embed_evidence import embed_evidence

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import BertModel, BertConfig, BertTokenizer
from loguru import logger
from tqdm import tqdm

import pandas as pd
import json

In [None]:
# Inference settings shared by the cells below.
batch_size = 16
max_length = 256

# Pick the best available accelerator: CUDA first, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using device: {device}")

# Rebuild the SimCSE encoder on bert-base-uncased and restore the fine-tuned weights.
checkpoint = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(checkpoint)
model = SimcseModel(pretrained_model=checkpoint, pooling='pooler', dropout=0.1).to(device)
model.load_state_dict(torch.load("saved_model/best_model.pt", map_location=device))
model.eval()  # inference mode: turns off dropout

# Encode the evidence corpus with the project helper `embed_evidence`
# (presumably writes one embedding per evidence id to the output file — see Sim_tools).
# NOTE: despite the `_csv_` in their names, both paths are JSON files.
evidence_csv_path = "data/evidence.json"
output_csv_path = "data/evidence_embed.json"
embed_evidence(evidence_csv_path, output_csv_path, model, tokenizer, device, max_length)

In [None]:
import torch
import torch.nn.functional as F

def match_evidence_by_similarity(claim_embedding, evidence_embeddings_dict, top_k=5, temperature=0.05):
    """
    Return the ids of the evidence passages most similar to a claim.

    Evidence is ranked by cosine similarity between the claim embedding and
    each evidence embedding.

    Args:
        claim_embedding: torch.Tensor holding a single claim vector, shape [D]
            or [1, D] (e.g. D = 768 for BERT-base).
        evidence_embeddings_dict: dict mapping evidence id -> torch.Tensor of shape [D].
        top_k: maximum number of ids to return. If it exceeds the number of
            evidences, every id is returned (ranked); values <= 0 give [].
        temperature: kept for backward compatibility only. Softmax with a positive
            temperature is strictly monotonic and never changed the ranking, so
            it is no longer applied.

    Returns:
        List[str]: up to `top_k` evidence ids, most similar first.
    """
    evidence_ids = list(evidence_embeddings_dict.keys())
    k = min(top_k, len(evidence_ids))
    if k <= 0:
        # Also covers an empty evidence dict, where torch.stack would raise.
        return []

    evidence_tensor = torch.stack([evidence_embeddings_dict[eid] for eid in evidence_ids])  # [N, D]

    # Flatten to [1, D] so both [D] and [1, D] inputs work. Unsqueezing a [1, D]
    # input gave [1, 1, D], which made cosine_similarity reduce over the evidence
    # axis instead of the feature axis.
    # Move the (small) claim vector to the evidence tensor's device/dtype rather than
    # the (large) evidence matrix to the claim's device, so a GPU claim can be scored
    # against CPU-stored evidence.
    claim_vec = claim_embedding.reshape(1, -1).to(device=evidence_tensor.device, dtype=evidence_tensor.dtype)

    sim_scores = F.cosine_similarity(claim_vec, evidence_tensor, dim=1)  # [N]

    top_indices = torch.topk(sim_scores, k).indices.tolist()
    return [evidence_ids[i] for i in top_indices]

In [None]:
def varify_evidence(train_json_path, evidence_embeddings_dict, top_k=5, temperature=0.05):
    """
    Report the top-k evidence retrieval hit rate on a labelled claims file.

    Uses the notebook-level `model`, `tokenizer`, `device` and `max_length`.
    NOTE: a later cell redefines `varify_evidence` with those passed explicitly;
    once that cell has run, this definition is shadowed.

    Args:
        train_json_path: path to a JSON file mapping claim id -> record with
            "claim_text" (str) and "evidences" (list of gold evidence ids).
        evidence_embeddings_dict: dict mapping evidence id -> torch.Tensor embedding.
        top_k: number of evidences retrieved per claim.
        temperature: forwarded to `match_evidence_by_similarity`.

    Returns:
        float: fraction of claims with at least one gold evidence id among the
        top-k retrieved ids (0.0 when the file has no claims).
    """
    # Honour the caller's path (it used to be overwritten with a hard-coded one).
    with open(train_json_path, "r", encoding="utf-8") as f:
        train_data = json.load(f)

    total_claims = 0
    total_hits = 0

    for claim_info in train_data.values():
        claim_text = claim_info["claim_text"]
        positive_ids = set(claim_info["evidences"])

        inputs = tokenizer(
            claim_text,
            max_length=max_length,
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        )

        input_ids = inputs['input_ids'].to(device)
        attention_mask = inputs['attention_mask'].to(device)
        token_type_ids = inputs['token_type_ids'].to(device)

        with torch.no_grad():
            # The model returns a batch of one (presumably [1, D]); flatten to a single [D] vector.
            claim_embedding = model(input_ids, attention_mask, token_type_ids).reshape(-1)

        # Forward the caller's top_k / temperature (they were previously hard-coded).
        result_lst = match_evidence_by_similarity(
            claim_embedding, evidence_embeddings_dict, top_k=top_k, temperature=temperature
        )

        total_hits += int(any(eid in positive_ids for eid in result_lst))
        total_claims += 1

    accuracy = total_hits / total_claims if total_claims > 0 else 0.0
    print(f"Top-{top_k} Accuracy: {accuracy:.4f} ({total_hits}/{total_claims})")
    return accuracy

In [None]:
def varify_evidence(train_json_path, evidence_embeddings_dict, model, tokenizer, device, max_length=256, top_k=5, temperature=0.05):
    """
    Measure the top-k evidence retrieval hit rate over a labelled claims file.

    Each claim is encoded with `model`. All evidence passages are ranked by cosine
    similarity to it. The claim counts as a hit when at least one of its gold
    evidence ids is among the top-k retrieved ids.

    Args:
        train_json_path: path to a JSON file mapping claim id -> record with
            "claim_text" (str) and "evidences" (list of gold evidence ids).
        evidence_embeddings_dict: dict mapping evidence id -> torch.Tensor of shape [D].
        model: encoder called as model(input_ids, attention_mask, token_type_ids),
            returning the embedding of the single tokenized claim.
        tokenizer: tokenizer called with return_tensors='pt'; must return
            input_ids, attention_mask and token_type_ids.
        device: device the tokenized inputs are moved to before encoding.
        max_length: tokenizer truncation / padding length.
        top_k: evidences retrieved per claim; capped at the number of evidences.
        temperature: unused, kept for backward compatibility. The old softmax
            scaling is monotonic and never affected the ranking.

    Returns:
        float: fraction of claims with a hit (0.0 when the file has no claims).
    """
    with open(train_json_path, "r", encoding="utf-8") as f:
        train_data = json.load(f)

    # Build the evidence matrix ONCE. Rebuilding it inside the loop re-stacked the
    # whole corpus for every claim. Rows are L2-normalised, so a dot product
    # with a normalised claim vector equals cosine similarity.
    evidence_ids = list(evidence_embeddings_dict.keys())
    k = min(top_k, len(evidence_ids))
    evidence_matrix = None
    if k > 0:
        evidence_matrix = F.normalize(
            torch.stack([evidence_embeddings_dict[eid] for eid in evidence_ids]), dim=1
        )  # [N, D]

    total_claims = 0
    total_hits = 0

    for claim_info in train_data.values():
        claim_text = claim_info["claim_text"]
        positive_ids = set(claim_info["evidences"])  # ground truth ids as set

        # Tokenize claim
        inputs = tokenizer(
            claim_text,
            max_length=max_length,
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        )
        input_ids = inputs['input_ids'].to(device)
        attention_mask = inputs['attention_mask'].to(device)
        token_type_ids = inputs['token_type_ids'].to(device)

        result_lst = []
        if evidence_matrix is not None:
            with torch.no_grad():
                # Flatten the single-claim output (e.g. [1, 768]) to [D]. Move it to the
                # evidence matrix's device/dtype instead of moving the large matrix.
                claim_embedding = model(input_ids, attention_mask, token_type_ids).reshape(-1)
                claim_vec = F.normalize(
                    claim_embedding.to(device=evidence_matrix.device, dtype=evidence_matrix.dtype), dim=0
                )
                scores = evidence_matrix @ claim_vec  # cosine similarity per evidence, [N]
                top_indices = torch.topk(scores, k).indices.tolist()
            result_lst = [evidence_ids[i] for i in top_indices]

        # Hit if any of the top-k retrieved ids is a gold evidence id.
        hit = any(eid in positive_ids for eid in result_lst)
        total_hits += int(hit)
        total_claims += 1

    accuracy = total_hits / total_claims if total_claims > 0 else 0.0
    print(f"Top-{top_k} Accuracy: {accuracy:.4f} ({total_hits}/{total_claims})")
    return accuracy