# Scoring and Citations Testbed
---

The objective of this notebook is to:
1. Experiment and figure out how to perform Scoring as described in the [paper](papers/TRLM_2412.02626.pdf)
2. Experiment with linear search for citation attribution

To further explore: 
1. Experiment with binary and exclusion search
2. Experiment with retrieval

## Import Libraries

In [1]:
import torch as t
import numpy as np
import pandas as pd
import torch.nn.functional as F
from tqdm.auto import tqdm

from transformers import GPTNeoXForCausalLM, AutoTokenizer
from datasets import load_dataset, Dataset
from sentence_transformers import SentenceTransformer

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
from nltk.tokenize import sent_tokenize


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
pd.set_option('display.max_colwidth', 200)
pd.set_option('display.float_format', '{:.4f}'.format)

## Define Util Functions

In [3]:
device="cuda" if t.cuda.is_available() else "cpu"

In [4]:
# TODO: may be helpful to move this over to a utils.py later, or define the models as a separate classes?

# Load models
def load_models():
    # Forward model
    fo_model = GPTNeoXForCausalLM.from_pretrained(
        "EleutherAI/pythia-160m-deduped",
        revision="step143000",
        cache_dir="./.cache/pythia-160m-deduped/step143000",
    ).to(device)
    
    fo_tokenizer = AutoTokenizer.from_pretrained(
        "EleutherAI/pythia-160m-deduped",
        revision="step143000",
        cache_dir="./.cache/pythia-160m-deduped/step143000",
    )
    
    # Backward model
    ba_model = GPTNeoXForCausalLM.from_pretrained(
        "afterless/reverse-pythia-160m",
        cache_dir="./.cache/reverse-pythia-160m",
    ).to(device)
    
    ba_tokenizer = AutoTokenizer.from_pretrained(
        "afterless/reverse-pythia-160m",
        cache_dir="./.cache/reverse-pythia-160m",
    )
    
    return fo_model, fo_tokenizer, ba_model, ba_tokenizer

In [5]:
# TODO : same with above

# Load dataset
def load_cnn_dataset(num_samples=10):
    try:
        # Try with a specific cache directory
        dataset = load_dataset("cnn_dailymail", "3.0.0", cache_dir=".cache")
        print("Dataset loaded successfully")
        
        # Verify the structure - this helps debug
        if num_samples > 0:
            print("Example dataset item:", dataset['train'][0])
            
        # Take only a small sample for testing
        if hasattr(dataset, 'train'):
            return dataset['train'].select(range(min(num_samples, len(dataset['train']))))
        
        return dataset['train'][:num_samples]
        
    except Exception as e:
        print(f"Error loading full dataset: {e}")
        
        # Create a tiny synthetic dataset for testing
        print("Creating synthetic test dataset instead...")
        
        sample_data = {
            'article': [
                "John likes to play basketball. He goes to the court every evening. His friends join him on weekends.",
                "The company announced record profits. Investors were pleased. The stock price increased by 10%."
            ],
            'highlights': [
                "John plays basketball regularly with friends.",
                "Company profits lead to stock price increase."
            ],
            'id': ['test1', 'test2']  # Added ID field
        }
        
        return Dataset.from_dict(sample_data)

In [6]:
def calculate_llm_score(query, answer, model, tokenizer, task='citation', backward=False, debug=False):
    """
    Calculate log probability of response given prompt or vice versa.
    
    Args:
        query (str): The prompt text
        answer (str): The response text
        model: The language model
        tokenizer: The corresponding tokenizer
        direction (str): "forward" for P(response|prompt) or "backward" for P(prompt|response)
    
    Returns:
        dict: Contains token-wise and sequence log probabilities
    """
    
    # The paper describes "Score" as conditional distribution (Section 4) which means the Log Probability, and therefore
    # this reimplementation uses Log Probability.
    
    # The notation used here is P(Query|Answer) to make it easier to compare with the paper
    
    # First, prepare the texts
    if not backward: 
        #Forward
        conditioning_prompt = ' is a summary of ' if task =='citation' else ' has an answer to '
    else:
        #Backward
        conditioning_prompt = ' is summarized by ' if task =='citation' else ' is answered by '
    
    # DEBUG 
    if debug:
        print(f"Context: {answer + conditioning_prompt}")
        print(f"Target: {query}")
    
    # convert to tokens, but because the model is auto-regressive, it predicts left -> right 
    # (i.e. tokens at t, t+1, t+2 predicts the t+3, t+4)
    # so we may need to reverse the tokens? 
    # input_ids = tokenizer.encode(query + (conditioning_prompt + answer), return_tensors="pt").to(model.device)
    
    target_ids = tokenizer.encode(query, return_tensors="pt")
    context_ids = tokenizer.encode(answer + conditioning_prompt, return_tensors="pt")

    # store length to "divide" the texts later
    target_len = target_ids.shape[1]
    context_len = context_ids.shape[1]
    
    if backward:
        # We need to reverse the tokens in backward
        target_ids = t.flip(target_ids, (1,))
        context_ids = t.flip(context_ids, (1,))

    input_ids = t.cat((context_ids, target_ids), dim=1).to(model.device)

    # Get model output
    with t.no_grad():
        outputs = model(input_ids)
        logits = outputs.logits
    
    # Extract token probabilities for the target text
    token_probs = []
    # Because the text is (query+answer), we just want to get the (answer) logits
    for i in range(context_len - 1, context_len + target_len - 1):
        
        # essentially, get the probability for the actual token at sequence, i.e. 
        # if "Harry Potter is the boy who survived" and i = 5
        # then we get the probabilities of the model output up until "who"
        # and find what's the probability of "survived"
        
        # get the logits [batch_size, sequence_length, vocabulary_size]
        next_token_logits = logits[0, i, :]  # no batch, sequence i, all vocab
        
        # get the actual token
        next_token_id = input_ids[0, i+1].item()
        
        # Convert logits to probabilities
        next_token_probs = F.softmax(next_token_logits, dim=0)
        prob = next_token_probs[next_token_id].item()
        log_prob = np.log(prob)
        
        token_text = tokenizer.decode([next_token_id])
        token_probs.append({
            'token': token_text,
            'token_id': next_token_id,
            'log_prob': log_prob
        })
    
    # Calculate sequence probability
    sequence_log_prob = sum(tp['log_prob'] for tp in token_probs)
    # Normalize by length to get per-token average
    normalized_log_prob = sequence_log_prob / len(token_probs)
    # Convert to perplexity if needed
    perplexity = np.exp(-sequence_log_prob / len(token_probs))
    
    return {
        'token_log_probs': token_probs,
        'sequence_log_prob': sequence_log_prob,
        'normalized_log_prob': normalized_log_prob,
        'perplexity': perplexity
    }



In [7]:
# Testing

fo_model, fo_tokenizer, ba_model, ba_tokenizer = load_models()


# Example Text
sentence = "Harry Potter star Daniel Radcliffe gains access to a reported £20 million ($41.1 million) fortune as he turns 18 on Monday, but he insists the money won't cast a spell on him."
highlight = "Harry Potter star Daniel Radcliffe gets £20M fortune as he turns 18 Monday"
adverse_highlight = "Daniel Craig is recasted as James Bond again"

# Defining sentence/highlight query/answer is still confusing tho sheesh


# Define prompts
ba_score = calculate_llm_score(sentence, highlight, ba_model, ba_tokenizer, backward=True)
fo_score = calculate_llm_score(sentence, highlight, fo_model, fo_tokenizer)

adv_ba_score = calculate_llm_score(sentence, adverse_highlight, ba_model, ba_tokenizer, backward=True)
adv_fo_score = calculate_llm_score(sentence, adverse_highlight, fo_model, fo_tokenizer)


scores_data = {
    'Model Type': ['Backward', 'Forward', 'Backward', 'Forward'],
    'Highlight': ['Correct', 'Correct', 'Adverse', 'Adverse'],
    'Sequence Log Prob': [
        ba_score['sequence_log_prob'],
        fo_score['sequence_log_prob'],
        adv_ba_score['sequence_log_prob'],
        adv_fo_score['sequence_log_prob']
    ],
    'Normalized Log Prob': [
        ba_score['normalized_log_prob'],
        fo_score['normalized_log_prob'],
        adv_ba_score['normalized_log_prob'],
        adv_fo_score['normalized_log_prob']
    ],
    'Perplexity': [
        ba_score['perplexity'],
        fo_score['perplexity'],
        adv_ba_score['perplexity'],
        adv_fo_score['perplexity']
    ]
}

# Create DataFrame
pd.DataFrame(scores_data)

Unnamed: 0,Model Type,Highlight,Sequence Log Prob,Normalized Log Prob,Perplexity
0,Backward,Correct,-113.2715,-2.7627,15.8429
1,Forward,Correct,-113.7015,-2.7732,16.0099
2,Backward,Adverse,-137.511,-3.3539,28.6148
3,Forward,Adverse,-151.0687,-3.6846,39.8293


##  Citation, Linear Search

In [8]:
dataset = load_cnn_dataset(num_samples=50)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Dataset loaded successfully
Example dataset item: {'article': 'LONDON, England (Reuters) -- Harry Potter star Daniel Radcliffe gains access to a reported £20 million ($41.1 million) fortune as he turns 18 on Monday, but he insists the money won\'t cast a spell on him. Daniel Radcliffe as Harry Potter in "Harry Potter and the Order of the Phoenix" To the disappointment of gossip columnists around the world, the young actor says he has no plans to fritter his cash away on fast cars, drink and celebrity parties. "I don\'t plan to be one of those people who, as soon as they turn 18, suddenly buy themselves a massive sports car collection or something similar," he told an Australian interviewer earlier this month. "I don\'t think I\'ll be particularly extravagant. "The things I like buying are things that cost about 10 pounds -- books and CDs and DVDs." At 18, Radcliffe will be able to gamble in a casino, buy a drink in a pub or see the horror film "Hostel: Part II," currently six places be

In [9]:
# Show dataframe
pd.DataFrame(dataset)

Unnamed: 0,article,highlights,id
0,"LONDON, England (Reuters) -- Harry Potter star Daniel Radcliffe gains access to a reported £20 million ($41.1 million) fortune as he turns 18 on Monday, but he insists the money won't cast a spell...",Harry Potter star Daniel Radcliffe gets £20M fortune as he turns 18 Monday .\nYoung actor says he has no plans to fritter his cash away .\nRadcliffe's earnings from first five Potter films have be...,42c027e4ff9730fbb3de84c1af0d2c506e41c3e4
1,"Editor's note: In our Behind the Scenes series, CNN correspondents share their experiences in covering news and analyze the stories behind the events. Here, Soledad O'Brien takes users inside a ja...","Mentally ill inmates in Miami are housed on the ""forgotten floor""\nJudge Steven Leifman says most are there as a result of ""avoidable felonies""\nWhile CNN tours facility, patient shouts: ""I am the...",ee8871b15c50d0db17b0179a6d2beab35065f1e9
2,"MINNEAPOLIS, Minnesota (CNN) -- Drivers who were on the Minneapolis bridge when it collapsed told harrowing tales of survival. ""The whole bridge from one side of the Mississippi to the other just ...","NEW: ""I thought I was going to die,"" driver says .\nMan says pickup truck was folded in half; he just has cut on face .\nDriver: ""I probably had a 30-, 35-foot free fall""\nMinnesota bridge collaps...",06352019a19ae31e527f37f7571c6dd7f0c5da37
3,"WASHINGTON (CNN) -- Doctors removed five small polyps from President Bush's colon on Saturday, and ""none appeared worrisome,"" a White House spokesman said. The polyps were removed and sent to the ...","Five small polyps found during procedure; ""none worrisome,"" spokesman says .\nPresident reclaims powers transferred to vice president .\nBush undergoes routine colonoscopy at Camp David .",24521a2abb2e1f5e34e6824e0f9e56904a2b0e88
4,"(CNN) -- The National Football League has indefinitely suspended Atlanta Falcons quarterback Michael Vick without pay, officials with the league said Friday. NFL star Michael Vick is set to appea...","NEW: NFL chief, Atlanta Falcons owner critical of Michael Vick's conduct .\nNFL suspends Falcons quarterback indefinitely without pay .\nVick admits funding dogfighting operation but says he did n...",7fe70cc8b12fab2d0a258fababf7d9c6b5e1262a
5,"BAGHDAD, Iraq (CNN) -- Dressed in a Superman shirt, 5-year-old Youssif held his sister's hand Friday, seemingly unaware that millions of people across the world have been touched by his story. Nea...","Parents beam with pride, can't stop from smiling from outpouring of support .\nMom: ""I was so happy I didn't know what to do""\nBurn center in U.S. has offered to provide treatment for reconstructi...",a1ebb8bb4d370a1fdf28769206d572be60642d70
6,"BAGHDAD, Iraq (CNN) -- The women are too afraid and ashamed to show their faces or have their real names used. They have been driven to sell their bodies to put food on the table for their childre...","Aid workers: Violence, increased cost of living drive women to prostitution .\nGroup is working to raise awareness of the problem with Iraq's political leaders .\nTwo Iraqi mothers tell CNN they t...",7c0e61ac829a3b3b653e2e3e7536cc4881d1f264
7,"BOGOTA, Colombia (CNN) -- A key rebel commander and fugitive from a U.S. drug trafficking indictment was killed over the weekend in an air attack on a guerrilla encampment, the Colombian military ...","Tomas Medina Caracas was a fugitive from a U.S. drug trafficking indictment .\n""El Negro Acacio"" allegedly helped manage extensive cocaine network .\nU.S. Justice Department indicted him in 2002 ....",f0d73bdab711763e745cdc75850861c9018f235d
8,"WASHINGTON (CNN) -- White House press secretary Tony Snow, who is undergoing treatment for cancer, will step down from his post September 14 and be replaced by deputy press secretary Dana Perino, ...","President Bush says Tony Snow ""will battle cancer and win"" Job of press secretary ""has been a dream for me,"" Snow says Snow leaving on September 14, will be succeeded by Dana Perino .",5e22bbfc7232418b8d2dd646b952e404df5bd048
9,"(CNN) -- Police and FBI agents are investigating the discovery of an empty rocket launcher tube on the front lawn of a Jersey City, New Jersey, home, FBI spokesman Sean Quinn said. Niranjan Desai ...","Empty anti-tank weapon turns up in front of New Jersey home .\nDevice handed over to Army ordnance disposal unit .\nWeapon not capable of being reloaded, experts say .",613d6311ec2c1985bd44707d1796d275452fe156


In [10]:
def linear_attribution_search(dataset, ba_model, ba_tokenizer, fo_model, fo_tokenizer):
    """
    Perform linear attribution search for citations as described in TRLM paper.
    
    For each highlight (summary sentence), find the most likely article sentence
    that it was derived from by scoring all possible pairs.
    """
    results = []
    
    # Process only the first few examples for demonstration
    for idx, example in tqdm(dataset.iterrows(), total=len(dataset)):
        # Split article and highlights into sentences
        article_sentences = sent_tokenize(example['article'])
        highlight_sentences = sent_tokenize(example['highlights'])
        
        # For demonstration, process just the first highlight sentence
        if not highlight_sentences:
            continue
            
        highlight = highlight_sentences[0]
        
        # Store best attribution for each model
        best_ba_sentence = None
        best_ba_score = float('-inf')
        best_fo_sentence = None
        best_fo_score = float('-inf')
        
        # Linear search through all article sentences
        for sentence in article_sentences:
            # Skip very short sentences
            if len(sentence.split()) < 3:
                continue
                
            # Calculate scores using both models
            ba_score = calculate_llm_score(sentence, highlight, ba_model, ba_tokenizer, backward=True)
            fo_score = calculate_llm_score(sentence, highlight, fo_model, fo_tokenizer)
            
            # Track best scores
            if ba_score['normalized_log_prob'] > best_ba_score:
                best_ba_score = ba_score['normalized_log_prob']
                best_ba_sentence = sentence
                
            if fo_score['normalized_log_prob'] > best_fo_score:
                best_fo_score = fo_score['normalized_log_prob']
                best_fo_sentence = sentence
        
        # Add results to our list
        results.append({
            'id': example['id'],
            'highlight': highlight,
            'ba_citation': best_ba_sentence,
            'ba_score': best_ba_score,
            'ba_perplexity': np.exp(-best_ba_score),
            'fo_citation': best_fo_sentence,
            'fo_score': best_fo_score,
            'fo_perplexity': np.exp(-best_fo_score)
        })
    
    
    return results


In [None]:
# Binary Search 구현 (이전과 동일)
def binary_search_citation(article, highlight, model, tokenizer, backward=False, max_iterations=10):
    """
    Binary search로 citation을 찾는 함수
    
    Args:
        article (str): 원본 기사 텍스트
        highlight (str): 하이라이트 텍스트
        model: 언어 모델
        tokenizer: 토크나이저
        backward (bool): backward 모델 여부
        max_iterations: 최대 반복 횟수
    
    Returns:
        dict: 최적의 citation과 스코어 정보
    """
    sentences = sent_tokenize(article)
    if not sentences:
        return {
            'citation': '',
            'score': float('-inf'),
            'perplexity': float('inf')
        }
    
    left, right = 0, len(sentences) - 1
    best_score = float('-inf')
    best_citation = ''
    iteration = 0
    
    while left <= right and iteration < max_iterations:
        mid = (left + right) // 2
        
        # mid를 중심으로 문맥 생성 (주변 2문장 포함)
        start = max(0, mid - 1)
        end = min(len(sentences), mid + 2)
        candidate = ' '.join(sentences[start:end])
        
        # 스코어 계산
        result = calculate_llm_score(highlight, candidate, model, tokenizer, 
                                   task='citation', backward=backward)
        score = result['normalized_log_prob']
        perplexity = result['perplexity']
        
        if score > best_score:
            best_score = score
            best_citation = candidate
        
        # 이진 탐색 방향 결정
        if mid > 0:
            left_candidate = ' '.join(sentences[max(0, mid-2):mid+1])
            left_score = calculate_llm_score(highlight, left_candidate, model, 
                                          tokenizer, task='citation', 
                                          backward=backward)['normalized_log_prob']
        else:
            left_score = float('-inf')
            
        if mid < len(sentences) - 1:
            right_candidate = ' '.join(sentences[mid:end+1])
            right_score = calculate_llm_score(highlight, right_candidate, model, 
                                           tokenizer, task='citation', 
                                           backward=backward)['normalized_log_prob']
        else:
            right_score = float('-inf')
            
        if left_score > score and left_score >= right_score:
            right = mid - 1
        elif right_score > score:
            left = mid + 1
        else:
            break
            
        iteration += 1
    
    return {
        'citation': best_citation,
        'score': best_score,
        'perplexity': perplexity
    }

# Linear과 Binary를 비교하는 평가 함수
def evaluate_citations_with_linear_and_binary(dataset, num_samples=10):
    """
    Linear_attribution_search와 Binary search를 사용하여 citation 평가
    
    Args:
        dataset: 평가에 사용할 데이터셋 (dict 리스트 또는 DataFrame 변환 가능)
        num_samples: 평가할 샘플 수
    
    Returns:
        list: Linear와 Binary 결과 포함한 평가 결과
    """
    fo_model, fo_tokenizer, ba_model, ba_tokenizer = load_models()
    sentence_transformer = SentenceTransformer('all-MiniLM-L6-v2')
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    
    # 데이터셋을 pandas DataFrame으로 변환
    if isinstance(dataset, list):
        df_dataset = pd.DataFrame(dataset)
    else:
        df_dataset = pd.DataFrame(dataset)  # dataset이 이미 다른 형식일 경우
        
    # num_samples만큼 슬라이싱
    df_dataset = df_dataset.iloc[:num_samples]
    
    # Linear attribution search 실행
    linear_results = linear_attribution_search(df_dataset, ba_model, ba_tokenizer, 
                                              fo_model, fo_tokenizer)
    
    results = []
    
    # 리스트로 변환된 데이터셋을 사용
    dataset_list = df_dataset.to_dict('records')  # DataFrame을 리스트로 변환
    
    for i in tqdm(range(min(num_samples, len(dataset_list)))):
        article = dataset_list[i]['article']
        highlight = dataset_list[i]['highlights']
        highlight_sentences = sent_tokenize(highlight)
        if not highlight_sentences:
            continue
        first_highlight = highlight_sentences[0]
        
        # Linear 결과 가져오기
        linear_result = linear_results[i]
        
        # Binary search 실행
        ba_binary = binary_search_citation(article, first_highlight, ba_model, ba_tokenizer, backward=True)
        fo_binary = binary_search_citation(article, first_highlight, fo_model, fo_tokenizer)
        
        # 임베딩 유사도 계산
        highlight_emb = sentence_transformer.encode(first_highlight)
        ba_linear_emb = sentence_transformer.encode(linear_result['ba_citation'])
        fo_linear_emb = sentence_transformer.encode(linear_result['fo_citation'])
        ba_binary_emb = sentence_transformer.encode(ba_binary['citation'])
        fo_binary_emb = sentence_transformer.encode(fo_binary['citation'])
        
        # ROUGE 스코어 계산
        ba_linear_rouge = scorer.score(first_highlight, linear_result['ba_citation'])
        fo_linear_rouge = scorer.score(first_highlight, linear_result['fo_citation'])
        ba_binary_rouge = scorer.score(first_highlight, ba_binary['citation'])
        fo_binary_rouge = scorer.score(first_highlight, fo_binary['citation'])
        
        result = {
            'id': dataset_list[i]['id'],
            'highlight': first_highlight,
            # Linear search 결과
            'ba_linear_citation': linear_result['ba_citation'],
            'fo_linear_citation': linear_result['fo_citation'],
            'ba_linear_score': linear_result['ba_score'],
            'fo_linear_score': linear_result['fo_score'],
            'ba_linear_perplexity': linear_result['ba_perplexity'],
            'fo_linear_perplexity': linear_result['fo_perplexity'],
            'ba_linear_emb_similarity': cosine_similarity([highlight_emb], [ba_linear_emb])[0][0],
            'fo_linear_emb_similarity': cosine_similarity([highlight_emb], [fo_linear_emb])[0][0],
            'ba_linear_rougeL_fmeasure': ba_linear_rouge['rougeL'].fmeasure,
            'fo_linear_rougeL_fmeasure': fo_linear_rouge['rougeL'].fmeasure,
            # Binary search 결과
            'ba_binary_citation': ba_binary['citation'],
            'fo_binary_citation': fo_binary['citation'],
            'ba_binary_score': ba_binary['score'],
            'fo_binary_score': fo_binary['score'],
            'ba_binary_perplexity': ba_binary['perplexity'],
            'fo_binary_perplexity': fo_binary['perplexity'],
            'ba_binary_emb_similarity': cosine_similarity([highlight_emb], [ba_binary_emb])[0][0],
            'fo_binary_emb_similarity': cosine_similarity([highlight_emb], [fo_binary_emb])[0][0],
            'ba_binary_rougeL_fmeasure': ba_binary_rouge['rougeL'].fmeasure,
            'fo_binary_rougeL_fmeasure': fo_binary_rouge['rougeL'].fmeasure,
        }
        results.append(result)
    
    return results

# 비교 결과 표시 함수
def display_comparison_results(results):
    results_df = pd.DataFrame(results)
    
    # 평균 메트릭 계산
    # score: 값이 높을수록 좋은 결과
    # perplexity: 값이 낮을수록 모델이 더 확신을 가지고 예측했다는 뜻
    # emb_similarity: 하이라이트와 인용문 간의 의미적 유사도를 나타내며, 0-1 값이 클수록 유사도 높음
    # rougeL_fmeasure: 하이라이트와 인용문 간의 텍스트 중첩, 0-1 값이 클수록 더 많은 공통 단어
    metrics = ['score', 'perplexity', 'emb_similarity', 'rougeL_fmeasure']
    comparison_data = {}
    
    for model_type in ['ba', 'fo']:
        for search_type in ['linear', 'binary']:
            col_prefix = f'{model_type}_{search_type}_'
            comparison_data[f'{model_type.upper()} {search_type.capitalize()}'] = {
                metric: results_df[f'{col_prefix}{metric}'].mean() 
                for metric in metrics
            }
    
    comparison_df = pd.DataFrame(comparison_data)
    
    # 이전 스타일의 결과 표시를 위해 선택된 열로 DataFrame 생성
    display_df = results_df[['highlight', 
                            'ba_linear_citation', 'ba_linear_score', 'ba_linear_perplexity',
                            'fo_linear_citation', 'fo_linear_score', 'fo_linear_perplexity',
                            'ba_binary_citation', 'ba_binary_score', 'ba_binary_perplexity',
                            'fo_binary_citation', 'fo_binary_score', 'fo_binary_perplexity']]
    
    return comparison_df, display_df

# 실행 및 결과 비교
dataset = load_cnn_dataset(num_samples=50)  # dataset이 dict 리스트라고 가정
comparison_results = evaluate_citations_with_linear_and_binary(dataset, num_samples=10)
comparison_df, display_df = display_comparison_results(comparison_results)

# 결과 출력
print("Comparison of Average Metrics:")
print(comparison_df.round(4))
print("\nDetailed Results:")
print(display_df)

In [None]:
import pandas as pd
import numpy as np
from tqdm import tqdm
from nltk import sent_tokenize
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from rouge_score import rouge_scorer

# 기존 함수들 (Linear와 Binary는 그대로 유지)
# Binary Search 구현
def binary_search_citation(article, highlight, model, tokenizer, backward=False, max_iterations=10):
    sentences = sent_tokenize(article)
    if not sentences:
        return {
            'citation': '',
            'score': float('-inf'),
            'perplexity': float('inf')
        }
    
    left, right = 0, len(sentences) - 1
    best_score = float('-inf')
    best_citation = ''
    iteration = 0
    
    while left <= right and iteration < max_iterations:
        mid = (left + right) // 2
        start = max(0, mid - 1)
        end = min(len(sentences), mid + 2)
        candidate = ' '.join(sentences[start:end])
        
        result = calculate_llm_score(highlight, candidate, model, tokenizer, 
                                   task='citation', backward=backward)
        score = result['normalized_log_prob']
        perplexity = result['perplexity']
        
        if score > best_score:
            best_score = score
            best_citation = candidate
        
        if mid > 0:
            left_candidate = ' '.join(sentences[max(0, mid-2):mid+1])
            left_score = calculate_llm_score(highlight, left_candidate, model, 
                                          tokenizer, task='citation', 
                                          backward=backward)['normalized_log_prob']
        else:
            left_score = float('-inf')
            
        if mid < len(sentences) - 1:
            right_candidate = ' '.join(sentences[mid:end+1])
            right_score = calculate_llm_score(highlight, right_candidate, model, 
                                           tokenizer, task='citation', 
                                           backward=backward)['normalized_log_prob']
        else:
            right_score = float('-inf')
            
        if left_score > score and left_score >= right_score:
            right = mid - 1
        elif right_score > score:
            left = mid + 1
        else:
            break
            
        iteration += 1
    
    return {
        'citation': best_citation,
        'score': best_score,
        'perplexity': perplexity
    }

# Exclusion Search 구현
def exclusion_search_citation(article, highlight, model, tokenizer, backward=False, threshold=-7.0):
    """
    Exclusion search로 citation을 찾는 함수
    
    Args:
        article (str): 원본 기사 텍스트
        highlight (str): 하이라이트 텍스트
        model: 언어 모델
        tokenizer: 토크나이저
        backward (bool): backward 모델 여부
        threshold (float): 제외 임계값 (이보다 낮은 점수는 제외)
    
    Returns:
        dict: 최적의 citation과 스코어 정보
    """
    sentences = sent_tokenize(article)
    if not sentences:
        return {
            'citation': '',
            'score': float('-inf'),
            'perplexity': float('inf')
        }
    
    best_score = float('-inf')
    best_citation = ''
    best_perplexity = float('inf')
    
    # 모든 문장에 대해 점수 계산 후 제외
    for i in range(len(sentences)):
        candidate = sentences[i]  # 단일 문장만 사용 (Linear와 유사)
        result = calculate_llm_score(highlight, candidate, model, tokenizer, 
                                   task='citation', backward=backward)
        score = result['normalized_log_prob']
        perplexity = result['perplexity']
        
        # 임계값 이상인 경우에만 고려
        if score >= threshold and score > best_score:
            best_score = score
            best_citation = candidate
            best_perplexity = perplexity
    
    # 유효한 결과가 없으면 기본값 반환
    if best_score == float('-inf'):
        return {
            'citation': '',
            'score': float('-inf'),
            'perplexity': float('inf')
        }
    
    return {
        'citation': best_citation,
        'score': best_score,
        'perplexity': best_perplexity
    }

# Linear, Binary, Exclusion을 비교하는 평가 함수
def evaluate_citations_with_linear_binary_exclusion(dataset, num_samples=10):
    """
    Linear, Binary, Exclusion search를 사용하여 citation 평가
    
    Args:
        dataset: 평가에 사용할 데이터셋 (dict 리스트 또는 DataFrame 변환 가능)
        num_samples: 평가할 샘플 수
    
    Returns:
        list: Linear, Binary, Exclusion 결과 포함한 평가 결과
    """
    fo_model, fo_tokenizer, ba_model, ba_tokenizer = load_models()
    sentence_transformer = SentenceTransformer('all-MiniLM-L6-v2')
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    
    # 데이터셋을 pandas DataFrame으로 변환
    if isinstance(dataset, list):
        df_dataset = pd.DataFrame(dataset)
    else:
        df_dataset = pd.DataFrame(dataset)
        
    # num_samples만큼 슬라이싱
    df_dataset = df_dataset.iloc[:num_samples]
    
    # Linear attribution search 실행
    linear_results = linear_attribution_search(df_dataset, ba_model, ba_tokenizer, 
                                              fo_model, fo_tokenizer)
    
    results = []
    dataset_list = df_dataset.to_dict('records')
    
    for i in tqdm(range(min(num_samples, len(dataset_list)))):
        article = dataset_list[i]['article']
        highlight = dataset_list[i]['highlights']
        highlight_sentences = sent_tokenize(highlight)
        if not highlight_sentences:
            continue
        first_highlight = highlight_sentences[0]
        
        # Linear 결과
        linear_result = linear_results[i]
        
        # Binary search
        ba_binary = binary_search_citation(article, first_highlight, ba_model, ba_tokenizer, backward=True)
        fo_binary = binary_search_citation(article, first_highlight, fo_model, fo_tokenizer)
        
        # Exclusion search
        ba_exclusion = exclusion_search_citation(article, first_highlight, ba_model, ba_tokenizer, backward=True)
        fo_exclusion = exclusion_search_citation(article, first_highlight, fo_model, fo_tokenizer)
        
        # 임베딩 유사도 계산
        highlight_emb = sentence_transformer.encode(first_highlight)
        ba_linear_emb = sentence_transformer.encode(linear_result['ba_citation'])
        fo_linear_emb = sentence_transformer.encode(linear_result['fo_citation'])
        ba_binary_emb = sentence_transformer.encode(ba_binary['citation'])
        fo_binary_emb = sentence_transformer.encode(fo_binary['citation'])
        ba_exclusion_emb = sentence_transformer.encode(ba_exclusion['citation'])
        fo_exclusion_emb = sentence_transformer.encode(fo_exclusion['citation'])
        
        # ROUGE 스코어 계산
        ba_linear_rouge = scorer.score(first_highlight, linear_result['ba_citation'])
        fo_linear_rouge = scorer.score(first_highlight, linear_result['fo_citation'])
        ba_binary_rouge = scorer.score(first_highlight, ba_binary['citation'])
        fo_binary_rouge = scorer.score(first_highlight, fo_binary['citation'])
        ba_exclusion_rouge = scorer.score(first_highlight, ba_exclusion['citation'])
        fo_exclusion_rouge = scorer.score(first_highlight, fo_exclusion['citation'])
        
        result = {
            'id': dataset_list[i]['id'],
            'highlight': first_highlight,
            # Linear search 결과
            'ba_linear_citation': linear_result['ba_citation'],
            'fo_linear_citation': linear_result['fo_citation'],
            'ba_linear_score': linear_result['ba_score'],
            'fo_linear_score': linear_result['fo_score'],
            'ba_linear_perplexity': linear_result['ba_perplexity'],
            'fo_linear_perplexity': linear_result['fo_perplexity'],
            'ba_linear_emb_similarity': cosine_similarity([highlight_emb], [ba_linear_emb])[0][0],
            'fo_linear_emb_similarity': cosine_similarity([highlight_emb], [fo_linear_emb])[0][0],
            'ba_linear_rougeL_fmeasure': ba_linear_rouge['rougeL'].fmeasure,
            'fo_linear_rougeL_fmeasure': fo_linear_rouge['rougeL'].fmeasure,
            # Binary search 결과
            'ba_binary_citation': ba_binary['citation'],
            'fo_binary_citation': fo_binary['citation'],
            'ba_binary_score': ba_binary['score'],
            'fo_binary_score': fo_binary['score'],
            'ba_binary_perplexity': ba_binary['perplexity'],
            'fo_binary_perplexity': fo_binary['perplexity'],
            'ba_binary_emb_similarity': cosine_similarity([highlight_emb], [ba_binary_emb])[0][0],
            'fo_binary_emb_similarity': cosine_similarity([highlight_emb], [fo_binary_emb])[0][0],
            'ba_binary_rougeL_fmeasure': ba_binary_rouge['rougeL'].fmeasure,
            'fo_binary_rougeL_fmeasure': fo_binary_rouge['rougeL'].fmeasure,
            # Exclusion search 결과
            'ba_exclusion_citation': ba_exclusion['citation'],
            'fo_exclusion_citation': fo_exclusion['citation'],
            'ba_exclusion_score': ba_exclusion['score'],
            'fo_exclusion_score': fo_exclusion['score'],
            'ba_exclusion_perplexity': ba_exclusion['perplexity'],
            'fo_exclusion_perplexity': fo_exclusion['perplexity'],
            'ba_exclusion_emb_similarity': cosine_similarity([highlight_emb], [ba_exclusion_emb])[0][0],
            'fo_exclusion_emb_similarity': cosine_similarity([highlight_emb], [fo_exclusion_emb])[0][0],
            'ba_exclusion_rougeL_fmeasure': ba_exclusion_rouge['rougeL'].fmeasure,
            'fo_exclusion_rougeL_fmeasure': fo_exclusion_rouge['rougeL'].fmeasure,
        }
        results.append(result)
    
    return results

# 비교 결과 표시 함수
def display_comparison_results(results):
    results_df = pd.DataFrame(results)
    
    # 평균 메트릭 계산
    metrics = ['score', 'perplexity', 'emb_similarity', 'rougeL_fmeasure']
    comparison_data = {}
    
    for model_type in ['ba', 'fo']:
        for search_type in ['linear', 'binary', 'exclusion']:
            col_prefix = f'{model_type}_{search_type}_'
            comparison_data[f'{model_type.upper()} {search_type.capitalize()}'] = {
                metric: results_df[f'{col_prefix}{metric}'].mean() 
                for metric in metrics
            }
    
    comparison_df = pd.DataFrame(comparison_data)
    
    # 상세 결과 표시를 위한 DataFrame (Exclusion 추가)
    display_df = results_df[['highlight', 
                            'ba_linear_citation', 'ba_linear_score', 'ba_linear_perplexity',
                            'fo_linear_citation', 'fo_linear_score', 'fo_linear_perplexity',
                            'ba_binary_citation', 'ba_binary_score', 'ba_binary_perplexity',
                            'fo_binary_citation', 'fo_binary_score', 'fo_binary_perplexity',
                            'ba_exclusion_citation', 'ba_exclusion_score', 'ba_exclusion_perplexity',
                            'fo_exclusion_citation', 'fo_exclusion_score', 'fo_exclusion_perplexity']]
    
    return comparison_df, display_df

# 실행 및 결과 비교
dataset = load_cnn_dataset(num_samples=50)  # dataset이 dict 리스트라고 가정
comparison_results = evaluate_citations_with_linear_binary_exclusion(dataset, num_samples=10)
comparison_df, display_df = display_comparison_results(comparison_results)

# 결과 출력
print("Comparison of Average Metrics:")
print(comparison_df.round(4))
print("\nDetailed Results:")
print(display_df)

In [13]:
## 임계값 수정
import pandas as pd
import numpy as np
from tqdm import tqdm
from nltk import sent_tokenize
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from rouge_score import rouge_scorer

# Exclusion Search 수정
def exclusion_search_citation(article, highlight, model, tokenizer, backward=False, threshold=-5.5):
    sentences = sent_tokenize(article)
    if not sentences:
        return {
            'citation': '',
            'score': float('-inf'),
            'perplexity': float('inf')
        }
    
    best_score = float('-inf')
    best_citation = ''
    best_perplexity = float('inf')
    all_scores = []  # 점수 분포 확인용
    
    for i in range(len(sentences)):
        candidate = sentences[i]
        result = calculate_llm_score(highlight, candidate, model, tokenizer, 
                                   task='citation', backward=backward)
        score = result['normalized_log_prob']
        perplexity = result['perplexity']
        all_scores.append(score)  # 모든 점수 기록
        
        if score >= threshold and score > best_score:
            best_score = score
            best_citation = candidate
            best_perplexity = perplexity
    
    # 디버깅: 점수 분포 출력
    if best_score == float('-inf'):
        print(f"Exclusion Search failed: No scores above threshold {threshold}")
        print(f"All scores: {all_scores}")
        print(f"Min: {min(all_scores)}, Max: {max(all_scores)}, Mean: {np.mean(all_scores)}")
    
    return {
        'citation': best_citation,
        'score': best_score,
        'perplexity': best_perplexity
    }

# 나머지 함수 (Binary Search 등)는 그대로 유지
def binary_search_citation(article, highlight, model, tokenizer, backward=False, max_iterations=10):
    sentences = sent_tokenize(article)
    if not sentences:
        return {
            'citation': '',
            'score': float('-inf'),
            'perplexity': float('inf')
        }
    
    left, right = 0, len(sentences) - 1
    best_score = float('-inf')
    best_citation = ''
    iteration = 0
    
    while left <= right and iteration < max_iterations:
        mid = (left + right) // 2
        start = max(0, mid - 1)
        end = min(len(sentences), mid + 2)
        candidate = ' '.join(sentences[start:end])
        
        result = calculate_llm_score(highlight, candidate, model, tokenizer, 
                                   task='citation', backward=backward)
        score = result['normalized_log_prob']
        perplexity = result['perplexity']
        
        if score > best_score:
            best_score = score
            best_citation = candidate
        
        if mid > 0:
            left_candidate = ' '.join(sentences[max(0, mid-2):mid+1])
            left_score = calculate_llm_score(highlight, left_candidate, model, 
                                          tokenizer, task='citation', 
                                          backward=backward)['normalized_log_prob']
        else:
            left_score = float('-inf')
            
        if mid < len(sentences) - 1:
            right_candidate = ' '.join(sentences[mid:end+1])
            right_score = calculate_llm_score(highlight, right_candidate, model, 
                                           tokenizer, task='citation', 
                                           backward=backward)['normalized_log_prob']
        else:
            right_score = float('-inf')
            
        if left_score > score and left_score >= right_score:
            right = mid - 1
        elif right_score > score:
            left = mid + 1
        else:
            break
            
        iteration += 1
    
    return {
        'citation': best_citation,
        'score': best_score,
        'perplexity': perplexity
    }

# 평가 함수 ( Exclusion 포함)
def evaluate_citations_with_linear_binary_exclusion(dataset, num_samples=10):
    fo_model, fo_tokenizer, ba_model, ba_tokenizer = load_models()
    sentence_transformer = SentenceTransformer('all-MiniLM-L6-v2')
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    
    if isinstance(dataset, list):
        df_dataset = pd.DataFrame(dataset)
    else:
        df_dataset = pd.DataFrame(dataset)
        
    df_dataset = df_dataset.iloc[:num_samples]
    
    linear_results = linear_attribution_search(df_dataset, ba_model, ba_tokenizer, 
                                              fo_model, fo_tokenizer)
    
    results = []
    dataset_list = df_dataset.to_dict('records')
    
    for i in tqdm(range(min(num_samples, len(dataset_list)))):
        article = dataset_list[i]['article']
        highlight = dataset_list[i]['highlights']
        highlight_sentences = sent_tokenize(highlight)
        if not highlight_sentences:
            continue
        first_highlight = highlight_sentences[0]
        
        linear_result = linear_results[i]
        
        ba_binary = binary_search_citation(article, first_highlight, ba_model, ba_tokenizer, backward=True)
        fo_binary = binary_search_citation(article, first_highlight, fo_model, fo_tokenizer)
        
        ba_exclusion = exclusion_search_citation(article, first_highlight, ba_model, ba_tokenizer, backward=True)
        fo_exclusion = exclusion_search_citation(article, first_highlight, fo_model, fo_tokenizer)
        
        highlight_emb = sentence_transformer.encode(first_highlight)
        ba_linear_emb = sentence_transformer.encode(linear_result['ba_citation'])
        fo_linear_emb = sentence_transformer.encode(linear_result['fo_citation'])
        ba_binary_emb = sentence_transformer.encode(ba_binary['citation'])
        fo_binary_emb = sentence_transformer.encode(fo_binary['citation'])
        ba_exclusion_emb = sentence_transformer.encode(ba_exclusion['citation'])
        fo_exclusion_emb = sentence_transformer.encode(fo_exclusion['citation'])
        
        ba_linear_rouge = scorer.score(first_highlight, linear_result['ba_citation'])
        fo_linear_rouge = scorer.score(first_highlight, linear_result['fo_citation'])
        ba_binary_rouge = scorer.score(first_highlight, ba_binary['citation'])
        fo_binary_rouge = scorer.score(first_highlight, fo_binary['citation'])
        ba_exclusion_rouge = scorer.score(first_highlight, ba_exclusion['citation'])
        fo_exclusion_rouge = scorer.score(first_highlight, fo_exclusion['citation'])
        
        result = {
            'id': dataset_list[i]['id'],
            'highlight': first_highlight,
            'ba_linear_citation': linear_result['ba_citation'],
            'fo_linear_citation': linear_result['fo_citation'],
            'ba_linear_score': linear_result['ba_score'],
            'fo_linear_score': linear_result['fo_score'],
            'ba_linear_perplexity': linear_result['ba_perplexity'],
            'fo_linear_perplexity': linear_result['fo_perplexity'],
            'ba_linear_emb_similarity': cosine_similarity([highlight_emb], [ba_linear_emb])[0][0],
            'fo_linear_emb_similarity': cosine_similarity([highlight_emb], [fo_linear_emb])[0][0],
            'ba_linear_rougeL_fmeasure': ba_linear_rouge['rougeL'].fmeasure,
            'fo_linear_rougeL_fmeasure': fo_linear_rouge['rougeL'].fmeasure,
            'ba_binary_citation': ba_binary['citation'],
            'fo_binary_citation': fo_binary['citation'],
            'ba_binary_score': ba_binary['score'],
            'fo_binary_score': fo_binary['score'],
            'ba_binary_perplexity': ba_binary['perplexity'],
            'fo_binary_perplexity': fo_binary['perplexity'],
            'ba_binary_emb_similarity': cosine_similarity([highlight_emb], [ba_binary_emb])[0][0],
            'fo_binary_emb_similarity': cosine_similarity([highlight_emb], [fo_binary_emb])[0][0],
            'ba_binary_rougeL_fmeasure': ba_binary_rouge['rougeL'].fmeasure,
            'fo_binary_rougeL_fmeasure': fo_binary_rouge['rougeL'].fmeasure,
            'ba_exclusion_citation': ba_exclusion['citation'],
            'fo_exclusion_citation': fo_exclusion['citation'],
            'ba_exclusion_score': ba_exclusion['score'],
            'fo_exclusion_score': fo_exclusion['score'],
            'ba_exclusion_perplexity': ba_exclusion['perplexity'],
            'fo_exclusion_perplexity': fo_exclusion['perplexity'],
            'ba_exclusion_emb_similarity': cosine_similarity([highlight_emb], [ba_exclusion_emb])[0][0],
            'fo_exclusion_emb_similarity': cosine_similarity([highlight_emb], [fo_exclusion_emb])[0][0],
            'ba_exclusion_rougeL_fmeasure': ba_exclusion_rouge['rougeL'].fmeasure,
            'fo_exclusion_rougeL_fmeasure': fo_exclusion_rouge['rougeL'].fmeasure,
        }
        results.append(result)
    
    return results

# 결과 표시 함수
def display_comparison_results(results):
    results_df = pd.DataFrame(results)
    
    metrics = ['score', 'perplexity', 'emb_similarity', 'rougeL_fmeasure']
    comparison_data = {}
    
    for model_type in ['ba', 'fo']:
        for search_type in ['linear', 'binary', 'exclusion']:
            col_prefix = f'{model_type}_{search_type}_'
            comparison_data[f'{model_type.upper()} {search_type.capitalize()}'] = {
                metric: results_df[f'{col_prefix}{metric}'].mean() 
                for metric in metrics
            }
    
    comparison_df = pd.DataFrame(comparison_data)
    
    display_df = results_df[['highlight', 
                            'ba_linear_citation', 'ba_linear_score', 'ba_linear_perplexity',
                            'fo_linear_citation', 'fo_linear_score', 'fo_linear_perplexity',
                            'ba_binary_citation', 'ba_binary_score', 'ba_binary_perplexity',
                            'fo_binary_citation', 'fo_binary_score', 'fo_binary_perplexity',
                            'ba_exclusion_citation', 'ba_exclusion_score', 'ba_exclusion_perplexity',
                            'fo_exclusion_citation', 'fo_exclusion_score', 'fo_exclusion_perplexity']]
    
    return comparison_df, display_df

# 실행
dataset = load_cnn_dataset(num_samples=50)
comparison_results = evaluate_citations_with_linear_binary_exclusion(dataset, num_samples=10)
comparison_df, display_df = display_comparison_results(comparison_results)

print("Comparison of Average Metrics:")
print(comparison_df.round(4))
print("\nDetailed Results:")
print(display_df)

Dataset loaded successfully
Example dataset item: {'article': 'LONDON, England (Reuters) -- Harry Potter star Daniel Radcliffe gains access to a reported £20 million ($41.1 million) fortune as he turns 18 on Monday, but he insists the money won\'t cast a spell on him. Daniel Radcliffe as Harry Potter in "Harry Potter and the Order of the Phoenix" To the disappointment of gossip columnists around the world, the young actor says he has no plans to fritter his cash away on fast cars, drink and celebrity parties. "I don\'t plan to be one of those people who, as soon as they turn 18, suddenly buy themselves a massive sports car collection or something similar," he told an Australian interviewer earlier this month. "I don\'t think I\'ll be particularly extravagant. "The things I like buying are things that cost about 10 pounds -- books and CDs and DVDs." At 18, Radcliffe will be able to gamble in a casino, buy a drink in a pub or see the horror film "Hostel: Part II," currently six places be

100%|██████████| 10/10 [00:27<00:00,  2.70s/it]
100%|██████████| 10/10 [00:34<00:00,  3.49s/it]

Comparison of Average Metrics:
                 BA Linear  BA Binary  BA Exclusion  FO Linear  FO Binary  \
score              -2.2773    -4.4284       -3.9165    -2.7101    -4.7341   
perplexity         10.7705   123.5415       60.2297    16.4731   142.0270   
emb_similarity      0.3561     0.5044        0.5186     0.3670     0.4591   
rougeL_fmeasure     0.1663     0.0970        0.2568     0.1704     0.1168   

                 FO Exclusion  
score                 -4.2481  
perplexity            86.0093  
emb_similarity         0.5733  
rougeL_fmeasure        0.2523  

Detailed Results:
                                                                                                                                                                                                 highlight  \
0                                                                                                                             Harry Potter star Daniel Radcliffe gets £20M fortune as he turns 18 Mond




In [None]:
import pandas as pd
import numpy as np
from tqdm import tqdm
from nltk import sent_tokenize
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from rouge_score import rouge_scorer

# Exclusion Search 수정
def exclusion_search_citation(article, highlight, model, tokenizer, backward=False, threshold=-5.5):
    sentences = sent_tokenize(article)
    if not sentences:
        return {
            'citation': '',
            'score': float('-inf'),
            'perplexity': float('inf')
        }
    
    best_score = float('-inf')
    best_citation = ''
    best_perplexity = float('inf')
    all_scores = []  # 점수 분포 확인용
    
    for i in range(len(sentences)):
        candidate = sentences[i]
        result = calculate_llm_score(highlight, candidate, model, tokenizer, 
                                   task='citation', backward=backward)
        score = result['normalized_log_prob']
        perplexity = result['perplexity']
        all_scores.append(score)  # 모든 점수 기록
        
        if score >= threshold and score > best_score:
            best_score = score
            best_citation = candidate
            best_perplexity = perplexity
    
    # 디버깅: 점수 분포 항상 출력
    print(f"Threshold: {threshold}, Backward: {backward}")
    print(f"All scores: {all_scores}")
    print(f"Min: {min(all_scores)}, Max: {max(all_scores)}, Mean: {np.mean(all_scores)}")
    if best_score == float('-inf'):
        print(f"Exclusion Search failed: No scores above threshold {threshold}")
    
    return {
        'citation': best_citation,
        'score': best_score,
        'perplexity': best_perplexity
    }

# 나머지 함수 (Binary Search 등)는 그대로 유지
def binary_search_citation(article, highlight, model, tokenizer, backward=False, max_iterations=10):
    sentences = sent_tokenize(article)
    if not sentences:
        return {
            'citation': '',
            'score': float('-inf'),
            'perplexity': float('inf')
        }
    
    left, right = 0, len(sentences) - 1
    best_score = float('-inf')
    best_citation = ''
    iteration = 0
    
    while left <= right and iteration < max_iterations:
        mid = (left + right) // 2
        start = max(0, mid - 1)
        end = min(len(sentences), mid + 2)
        candidate = ' '.join(sentences[start:end])
        
        result = calculate_llm_score(highlight, candidate, model, tokenizer, task='citation', backward=backward)
        score = result['normalized_log_prob']
        perplexity = result['perplexity']
        
        if score > best_score:
            best_score = score
            best_citation = candidate
        
        if mid > 0:
            left_candidate = ' '.join(sentences[max(0, mid-2):mid+1])
            left_score = calculate_llm_score(highlight, left_candidate, model, tokenizer, task='citation', backward=backward)['normalized_log_prob']
        else:
            left_score = float('-inf')
            
        if mid < len(sentences) - 1:
            right_candidate = ' '.join(sentences[mid:end+1])
            right_score = calculate_llm_score(highlight, right_candidate, model, tokenizer, task='citation', backward=backward)['normalized_log_prob']
        else:
            right_score = float('-inf')
            
        if left_score > score and left_score >= right_score:
            right = mid - 1
        elif right_score > score:
            left = mid + 1
        else:
            break
            
        iteration += 1
    
    return {
        'citation': best_citation,
        'score': best_score,
        'perplexity': perplexity
    }

# 평가 함수 ( Exclusion 포함)
def evaluate_citations_with_linear_binary_exclusion(dataset, num_samples=10):
    fo_model, fo_tokenizer, ba_model, ba_tokenizer = load_models()
    sentence_transformer = SentenceTransformer('all-MiniLM-L6-v2')
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    
    if isinstance(dataset, list):
        df_dataset = pd.DataFrame(dataset)
    else:
        df_dataset = pd.DataFrame(dataset)
        
    df_dataset = df_dataset.iloc[:num_samples]
    
    linear_results = linear_attribution_search(df_dataset, ba_model, ba_tokenizer, fo_model, fo_tokenizer)
    
    results = []
    dataset_list = df_dataset.to_dict('records')
    
    for i in tqdm(range(min(num_samples, len(dataset_list)))):
        article = dataset_list[i]['article']
        highlight = dataset_list[i]['highlights']
        highlight_sentences = sent_tokenize(highlight)
        if not highlight_sentences:
            continue
        first_highlight = highlight_sentences[0]
        
        linear_result = linear_results[i]
        
        ba_binary = binary_search_citation(article, first_highlight, ba_model, ba_tokenizer, backward=True)
        fo_binary = binary_search_citation(article, first_highlight, fo_model, fo_tokenizer)
        
        ba_exclusion = exclusion_search_citation(article, first_highlight, ba_model, ba_tokenizer, backward=True)
        fo_exclusion = exclusion_search_citation(article, first_highlight, fo_model, fo_tokenizer)
        
        highlight_emb = sentence_transformer.encode(first_highlight)
        ba_linear_emb = sentence_transformer.encode(linear_result['ba_citation'])
        fo_linear_emb = sentence_transformer.encode(linear_result['fo_citation'])
        ba_binary_emb = sentence_transformer.encode(ba_binary['citation'])
        fo_binary_emb = sentence_transformer.encode(fo_binary['citation'])
        ba_exclusion_emb = sentence_transformer.encode(ba_exclusion['citation'])
        fo_exclusion_emb = sentence_transformer.encode(fo_exclusion['citation'])
        
        ba_linear_rouge = scorer.score(first_highlight, linear_result['ba_citation'])
        fo_linear_rouge = scorer.score(first_highlight, linear_result['fo_citation'])
        ba_binary_rouge = scorer.score(first_highlight, ba_binary['citation'])
        fo_binary_rouge = scorer.score(first_highlight, fo_binary['citation'])
        ba_exclusion_rouge = scorer.score(first_highlight, ba_exclusion['citation'])
        fo_exclusion_rouge = scorer.score(first_highlight, fo_exclusion['citation'])
        
        result = {
            'id': dataset_list[i]['id'],
            'highlight': first_highlight,
            'ba_linear_citation': linear_result['ba_citation'],
            'fo_linear_citation': linear_result['fo_citation'],
            'ba_linear_score': linear_result['ba_score'],
            'fo_linear_score': linear_result['fo_score'],
            'ba_linear_perplexity': linear_result['ba_perplexity'],
            'fo_linear_perplexity': linear_result['fo_perplexity'],
            'ba_linear_emb_similarity': cosine_similarity([highlight_emb], [ba_linear_emb])[0][0],
            'fo_linear_emb_similarity': cosine_similarity([highlight_emb], [fo_linear_emb])[0][0],
            'ba_linear_rougeL_fmeasure': ba_linear_rouge['rougeL'].fmeasure,
            'fo_linear_rougeL_fmeasure': fo_linear_rouge['rougeL'].fmeasure,
            'ba_binary_citation': ba_binary['citation'],
            'fo_binary_citation': fo_binary['citation'],
            'ba_binary_score': ba_binary['score'],
            'fo_binary_score': fo_binary['score'],
            'ba_binary_perplexity': ba_binary['perplexity'],
            'fo_binary_perplexity': fo_binary['perplexity'],
            'ba_binary_emb_similarity': cosine_similarity([highlight_emb], [ba_binary_emb])[0][0],
            'fo_binary_emb_similarity': cosine_similarity([highlight_emb], [fo_binary_emb])[0][0],
            'ba_binary_rougeL_fmeasure': ba_binary_rouge['rougeL'].fmeasure,
            'fo_binary_rougeL_fmeasure': fo_binary_rouge['rougeL'].fmeasure,
            'ba_exclusion_citation': ba_exclusion['citation'],
            'fo_exclusion_citation': fo_exclusion['citation'],
            'ba_exclusion_score': ba_exclusion['score'],
            'fo_exclusion_score': fo_exclusion['score'],
            'ba_exclusion_perplexity': ba_exclusion['perplexity'],
            'fo_exclusion_perplexity': fo_exclusion['perplexity'],
            'ba_exclusion_emb_similarity': cosine_similarity([highlight_emb], [ba_exclusion_emb])[0][0],
            'fo_exclusion_emb_similarity': cosine_similarity([highlight_emb], [fo_exclusion_emb])[0][0],
            'ba_exclusion_rougeL_fmeasure': ba_exclusion_rouge['rougeL'].fmeasure,
            'fo_exclusion_rougeL_fmeasure': fo_exclusion_rouge['rougeL'].fmeasure,
        }
        results.append(result)
    
    return results

# 결과 표시 함수
def display_comparison_results(results):
    results_df = pd.DataFrame(results)
    
    metrics = ['score', 'perplexity', 'emb_similarity', 'rougeL_fmeasure']
    comparison_data = {}
    
    for model_type in ['ba', 'fo']:
        for search_type in ['linear', 'binary', 'exclusion']:
            col_prefix = f'{model_type}_{search_type}_'
            comparison_data[f'{model_type.upper()} {search_type.capitalize()}'] = {
                metric: results_df[f'{col_prefix}{metric}'].mean() 
                for metric in metrics
            }
    
    comparison_df = pd.DataFrame(comparison_data)
    
    display_df = results_df[['highlight', 
                            'ba_linear_citation', 'ba_linear_score', 'ba_linear_perplexity',
                            'fo_linear_citation', 'fo_linear_score', 'fo_linear_perplexity',
                            'ba_binary_citation', 'ba_binary_score', 'ba_binary_perplexity',
                            'fo_binary_citation', 'fo_binary_score', 'fo_binary_perplexity',
                            'ba_exclusion_citation', 'ba_exclusion_score', 'ba_exclusion_perplexity',
                            'fo_exclusion_citation', 'fo_exclusion_score', 'fo_exclusion_perplexity']]
    
    return comparison_df, display_df

# 실행
dataset = load_cnn_dataset(num_samples=50)
comparison_results = evaluate_citations_with_linear_binary_exclusion(dataset, num_samples=10)
comparison_df, display_df = display_comparison_results(comparison_results)

print("Comparison of Average Metrics:")
print(comparison_df.round(4))
print("\nDetailed Results:")
print(display_df)

Dataset loaded successfully
Example dataset item: {'article': 'LONDON, England (Reuters) -- Harry Potter star Daniel Radcliffe gains access to a reported £20 million ($41.1 million) fortune as he turns 18 on Monday, but he insists the money won\'t cast a spell on him. Daniel Radcliffe as Harry Potter in "Harry Potter and the Order of the Phoenix" To the disappointment of gossip columnists around the world, the young actor says he has no plans to fritter his cash away on fast cars, drink and celebrity parties. "I don\'t plan to be one of those people who, as soon as they turn 18, suddenly buy themselves a massive sports car collection or something similar," he told an Australian interviewer earlier this month. "I don\'t think I\'ll be particularly extravagant. "The things I like buying are things that cost about 10 pounds -- books and CDs and DVDs." At 18, Radcliffe will be able to gamble in a casino, buy a drink in a pub or see the horror film "Hostel: Part II," currently six places be

100%|██████████| 10/10 [00:25<00:00,  2.59s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

Threshold: -5.5, Backward: True
All scores: [np.float64(-2.846222613936687), np.float64(-4.6519432085338055), np.float64(-5.172498497143561), np.float64(-5.938754569582826), np.float64(-5.6581623533672305), np.float64(-4.511843394325089), np.float64(-5.182403379436483), np.float64(-5.567099073543462), np.float64(-5.742396932401047), np.float64(-5.94110779345042), np.float64(-4.699259599719864), np.float64(-5.262418250318978), np.float64(-5.599672200349584), np.float64(-5.948666970422905), np.float64(-5.25255796773786), np.float64(-5.634611066028285), np.float64(-5.812859001252745), np.float64(-5.462642501882827), np.float64(-5.522810877381604), np.float64(-5.641740091935267), np.float64(-5.347815617730112), np.float64(-6.092893389283156), np.float64(-5.703105197770184), np.float64(-6.734847225850045)]
Min: -6.734847225850045, Max: -2.846222613936687, Mean: -5.413680490557667


 10%|█         | 1/10 [00:02<00:18,  2.08s/it]

Threshold: -5.5, Backward: False
All scores: [np.float64(-3.645555167755457), np.float64(-5.146176291128523), np.float64(-5.3287479232043795), np.float64(-5.660415499693124), np.float64(-5.488964610649174), np.float64(-5.392789259952316), np.float64(-5.456674705223214), np.float64(-5.694067906426286), np.float64(-5.604448268918845), np.float64(-5.530918438444369), np.float64(-5.428648871127159), np.float64(-5.3551674805461635), np.float64(-5.432228406202273), np.float64(-5.552268874733248), np.float64(-5.5578569511390805), np.float64(-5.269416796197849), np.float64(-5.703439512510779), np.float64(-5.622301872281226), np.float64(-5.775106915700825), np.float64(-5.642941042908834), np.float64(-5.436789925145045), np.float64(-5.71919906227357), np.float64(-5.796105285545565), np.float64(-5.596944351374337)]
Min: -5.796105285545565, Max: -3.645555167755457, Mean: -5.451548892461735
Threshold: -5.5, Backward: True
All scores: [np.float64(-4.3159322319033535), np.float64(-3.7714666476034022)

 20%|██        | 2/10 [00:07<00:32,  4.08s/it]

Threshold: -5.5, Backward: False
All scores: [np.float64(-4.39186917583983), np.float64(-3.844838930437283), np.float64(-3.950094040782461), np.float64(-4.180288623277586), np.float64(-3.845903016087016), np.float64(-4.289958946974849), np.float64(-4.115544693303592), np.float64(-4.299873767124957), np.float64(-4.172623782815531), np.float64(-4.155329351163796), np.float64(-4.368578199902906), np.float64(-4.179436356051932), np.float64(-4.3422800594930315), np.float64(-4.337851401514609), np.float64(-4.403820126413883), np.float64(-4.2900630421376444), np.float64(-4.427350427290788), np.float64(-3.995431225422759), np.float64(-4.358408772510547), np.float64(-4.362714149463571), np.float64(-4.376907672369219), np.float64(-4.097866044387334), np.float64(-4.379996694766954), np.float64(-4.36987794847397), np.float64(-4.3570810215522755), np.float64(-4.12260235107426), np.float64(-4.244046817521921), np.float64(-4.291095862035564), np.float64(-4.202515013555763), np.float64(-4.384544281339

 30%|███       | 3/10 [00:12<00:30,  4.36s/it]

Threshold: -5.5, Backward: False
All scores: [np.float64(-4.242310528086156), np.float64(-4.802663642775189), np.float64(-4.799825624765951), np.float64(-4.866386631234403), np.float64(-4.884776189511272), np.float64(-4.804787435468897), np.float64(-4.728115497223038), np.float64(-4.686593785914508), np.float64(-4.745654852531253), np.float64(-4.541746495758565), np.float64(-4.653714989249453), np.float64(-4.716837319468694), np.float64(-4.868784282955479), np.float64(-4.8933995477882055), np.float64(-4.681587219515623), np.float64(-4.6175829752161315), np.float64(-4.751499241814232), np.float64(-4.815246945060794), np.float64(-4.911251310254948), np.float64(-4.799814411323785), np.float64(-4.728789990721078), np.float64(-4.9189048119292575), np.float64(-4.803613322802412), np.float64(-4.75267973760951), np.float64(-4.87487123888092), np.float64(-4.704554213520304), np.float64(-4.7580194825847375), np.float64(-4.951658268402977), np.float64(-4.878947286415431), np.float64(-4.7086102693

 40%|████      | 4/10 [00:15<00:22,  3.76s/it]

Threshold: -5.5, Backward: False
All scores: [np.float64(-4.407444818295728), np.float64(-5.525767833955772), np.float64(-6.327699038636361), np.float64(-6.152820951083948), np.float64(-6.151578356704914), np.float64(-6.274053616955063), np.float64(-6.067773171498153), np.float64(-6.075551651790042), np.float64(-6.159039969491084), np.float64(-6.290450187262086), np.float64(-6.136269699431862), np.float64(-6.205775231951863), np.float64(-6.268498376310681), np.float64(-6.052438248106661), np.float64(-5.768885930435765), np.float64(-5.985265601843643), np.float64(-5.874258889621061), np.float64(-5.553537916938882), np.float64(-5.632213198924582), np.float64(-5.919055265897052), np.float64(-5.930950383667339), np.float64(-6.2239333547612965), np.float64(-6.176015902989254), np.float64(-6.4033806671539555)]
Min: -6.4033806671539555, Max: -4.407444818295728, Mean: -5.9817774276544595
Threshold: -5.5, Backward: True
All scores: [np.float64(-4.910687889107966), np.float64(-4.796926707098869)

 50%|█████     | 5/10 [00:18<00:19,  3.81s/it]

Threshold: -5.5, Backward: False
All scores: [np.float64(-4.8264197458156), np.float64(-5.239263520460517), np.float64(-5.54435194772343), np.float64(-5.355968873696927), np.float64(-5.541600241623098), np.float64(-5.122981224378168), np.float64(-5.433863920230701), np.float64(-5.536452786932605), np.float64(-5.300726192602114), np.float64(-5.063737987529788), np.float64(-5.335569760147363), np.float64(-5.197882966829074), np.float64(-5.546951452568003), np.float64(-5.516494514772479), np.float64(-5.546649074895373), np.float64(-5.618963985538239), np.float64(-5.50777113403779), np.float64(-5.531209178925955), np.float64(-5.642677161921784), np.float64(-5.500389437307288), np.float64(-5.830848500772731), np.float64(-5.4825985426050385), np.float64(-5.531250641151445), np.float64(-5.57980550010846), np.float64(-5.500183224033832), np.float64(-5.66217142363067), np.float64(-5.527525954440247), np.float64(-5.423497117758464), np.float64(-5.579771307046634), np.float64(-5.319675620189178),

 60%|██████    | 6/10 [00:23<00:15,  3.89s/it]

Threshold: -5.5, Backward: False
All scores: [np.float64(-5.490757296783327), np.float64(-5.577543727185654), np.float64(-5.688816520352704), np.float64(-5.608470593129719), np.float64(-5.665069960035823), np.float64(-5.414088030147108), np.float64(-5.739236926016666), np.float64(-5.199151075015089), np.float64(-5.638462769936603), np.float64(-5.712573698327699), np.float64(-5.390600411274621), np.float64(-5.586453674062789), np.float64(-5.946328798172154), np.float64(-5.653751076052574), np.float64(-5.607958813278789), np.float64(-5.418524195189206), np.float64(-5.760799498630521), np.float64(-5.7357799595418815), np.float64(-5.64494867213765), np.float64(-5.869105725451281), np.float64(-6.053731940544534), np.float64(-5.791145292094618), np.float64(-5.802828202615809), np.float64(-5.822545267175033), np.float64(-5.7637967150446), np.float64(-5.7756655117548465), np.float64(-5.7875200918816825), np.float64(-5.7241855505138375), np.float64(-5.807420924191902), np.float64(-5.51909190381

 70%|███████   | 7/10 [00:27<00:12,  4.05s/it]

Threshold: -5.5, Backward: False
All scores: [np.float64(-5.642167603506443), np.float64(-5.814455190705936), np.float64(-6.116211484963222), np.float64(-5.857146543944097), np.float64(-5.746554505380795), np.float64(-5.998833523152045), np.float64(-6.115039922124997), np.float64(-5.939582590449629), np.float64(-5.834795881414798), np.float64(-6.058544466883386), np.float64(-6.0515870757560934), np.float64(-5.3975999568952515), np.float64(-6.198340304979225), np.float64(-5.858682757644538), np.float64(-5.9869077991893125), np.float64(-5.925653450577576), np.float64(-6.0186952393584425), np.float64(-5.991088024928585), np.float64(-6.173225769056433), np.float64(-6.032765554313029), np.float64(-6.079297298626374), np.float64(-6.172930458206077), np.float64(-6.196325071749893), np.float64(-5.943791586722238), np.float64(-5.990233279561264), np.float64(-6.006517742726955), np.float64(-5.3457428724574765), np.float64(-6.06118226686152), np.float64(-6.181094030620899), np.float64(-6.10942133

 80%|████████  | 8/10 [00:29<00:06,  3.42s/it]

Threshold: -5.5, Backward: False
All scores: [np.float64(-2.9230986136506076), np.float64(-2.4640345937161485), np.float64(-2.593958739415301), np.float64(-3.8761057924646627), np.float64(-3.7005967912040867), np.float64(-3.025345239554942), np.float64(-3.6956022498564027), np.float64(-3.1949059092989374), np.float64(-3.877955739365053), np.float64(-3.8546561754356192), np.float64(-3.335420250291718), np.float64(-3.8673574766705308), np.float64(-3.9862676756618463), np.float64(-3.6989303027332463)]
Min: -3.9862676756618463, Max: -2.4640345937161485, Mean: -3.435302539237079
Threshold: -5.5, Backward: True
All scores: [np.float64(-4.223143735011711), np.float64(-4.308420925670111), np.float64(-4.5253771850854365), np.float64(-4.417228125654098), np.float64(-4.618673421952801), np.float64(-4.683571937508197), np.float64(-4.753128290166204), np.float64(-4.08683692002764), np.float64(-4.711368701066884), np.float64(-4.3916693175370245), np.float64(-4.664408245481364), np.float64(-4.5191204

 90%|█████████ | 9/10 [00:32<00:03,  3.34s/it]

Threshold: -5.5, Backward: False
All scores: [np.float64(-4.444862988769586), np.float64(-4.741393163750566), np.float64(-4.862099926691276), np.float64(-5.077642825364696), np.float64(-5.0782068565394365), np.float64(-5.155929959306233), np.float64(-5.065217737275989), np.float64(-4.635756376381157), np.float64(-5.16516511219852), np.float64(-5.022490106193841), np.float64(-5.080581255012485), np.float64(-4.951126808991562), np.float64(-5.166096386448184), np.float64(-4.9794187493251), np.float64(-5.05325155409011), np.float64(-5.060115328526896), np.float64(-5.074188808719837), np.float64(-5.049552625829539), np.float64(-5.127253712586026), np.float64(-5.091463359650426), np.float64(-5.023752967655584), np.float64(-5.0877889975639405), np.float64(-4.931869541523753), np.float64(-5.005168805636184), np.float64(-5.027905183107824), np.float64(-4.995852770950215), np.float64(-5.144439310422392), np.float64(-5.213466890084352), np.float64(-5.079657081771606), np.float64(-5.15417953317271

100%|██████████| 10/10 [00:34<00:00,  3.46s/it]

Threshold: -5.5, Backward: False
All scores: [np.float64(-4.800941940406676), np.float64(-4.695501436821312), np.float64(-5.022932637212317), np.float64(-5.253361870699207), np.float64(-5.265110791560191), np.float64(-5.460401638996527), np.float64(-5.097616790695083), np.float64(-5.625913514033305), np.float64(-5.5195562627236505), np.float64(-5.4621291941147625), np.float64(-5.650941508272733), np.float64(-4.91924796911515), np.float64(-5.565171981703483), np.float64(-5.88768922146215), np.float64(-5.5440165872104314), np.float64(-5.720110467059553)]
Min: -5.88768922146215, Max: -4.695501436821312, Mean: -5.343165238255408
Comparison of Average Metrics:
                 BA Linear  BA Binary  BA Exclusion  FO Linear  FO Binary  \
score              -2.2773    -4.4284       -3.9165    -2.7101    -4.7341   
perplexity         10.7705   123.5415       60.2297    16.4731   142.0270   
emb_similarity      0.3561     0.5044        0.5186     0.3670     0.4591   
rougeL_fmeasure     0.1663  




In [None]:
results = linear_attribution_search(pd.DataFrame(dataset), ba_model, ba_tokenizer, fo_model, fo_tokenizer)
results_df = pd.DataFrame(results)

# Display results in a more readable format
display_df = results_df[['highlight', 'ba_citation', 'ba_score', 'ba_perplexity', 
                         'fo_citation', 'fo_score', 'fo_perplexity']]

In [None]:
display_df.T

## Further Analysis

1. Analyze first sample citations is to check for a single article, what are the sentences' individual scores
2. Evaluate citations and display is to get the benchmark metrics
3. 

In [None]:
def analyze_first_sample_citations(dataset, ba_model, ba_tokenizer, fo_model, fo_tokenizer):
    """
    Analyze all possible citation sentences for the first highlight in the first sample.
    
    This function calculates scores for all article sentences against the first
    highlight sentence and returns them sorted by backward model score.
    """
    
    
    # Get the first sample
    first_sample = dataset.iloc[0]
    
    # Split article into sentences
    article_sentences = sent_tokenize(first_sample['article'])
    
    # Get the first highlight sentence
    highlight_sentences = sent_tokenize(first_sample['highlights'])
    if not highlight_sentences:
        print("No highlight sentences found!")
        return None
    
    highlight = highlight_sentences[0]
    print(f"Analyzing citations for highlight: \n'{highlight}'\n")
    
    # Calculate scores for all article sentences
    results = []
    
    for i, sentence in enumerate(article_sentences):
        # Skip very short sentences
        if len(sentence.split()) < 3:
            continue
            
        # Calculate scores using both models
        ba_score = calculate_llm_score(sentence, highlight, ba_model, ba_tokenizer, backward=True)
        fo_score = calculate_llm_score(sentence, highlight, fo_model, fo_tokenizer)
        
        # Add to results
        results.append({
            'sentence_idx': i,
            'article_sentence': sentence,
            'ba_score': ba_score['normalized_log_prob'],
            'ba_perplexity': ba_score['perplexity'],
            'fo_score': fo_score['normalized_log_prob'],
            'fo_perplexity': fo_score['perplexity']
        })
    
    # Create DataFrame and sort by backward model score (descending)
    results_df = pd.DataFrame(results)
    results_df = results_df.sort_values('ba_score', ascending=False).reset_index(drop=True)
    
    # Set display options for better readability
    pd.set_option('display.max_colwidth', 70)
    pd.set_option('display.float_format', '{:.4f}'.format)
    
    return results_df

# Example usage:
citation_analysis = analyze_first_sample_citations(pd.DataFrame(dataset), 
                                                  ba_model, ba_tokenizer, 
                                                  fo_model, fo_tokenizer)
citation_analysis

In [None]:
def evaluate_citations(results, embedding_model_name='all-MiniLM-L6-v2'):
    """
    Evaluate citation quality using multiple metrics (ROUGE, embeddings, TF-IDF).
    
    Args:
        results: List of dictionaries with citation results
        embedding_model_name: Name of the sentence-transformers model to use
    
    Returns:
        Enhanced results with additional evaluation metrics
    """
    # Initialize ROUGE scorer
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    
    # Initialize sentence transformer for embeddings
    embedding_model = SentenceTransformer(embedding_model_name, cache_folder='.cache/')
    
    # Initialize TF-IDF vectorizer
    tfidf_vectorizer = TfidfVectorizer(stop_words='english')
    
    # Process all highlights and citations to prepare TF-IDF matrix
    all_texts = []
    for result in results:
        all_texts.append(result['highlight'])
        all_texts.append(result['ba_citation'])
        all_texts.append(result['fo_citation'])
    
    # Fit the TF-IDF vectorizer
    tfidf_matrix = tfidf_vectorizer.fit_transform(all_texts)
    
    enhanced_results = []
    
    for i, result in enumerate(results):
        # Get text indices for the current result
        highlight_idx = i * 3
        ba_citation_idx = i * 3 + 1
        fo_citation_idx = i * 3 + 2
        
        # Get text for the current result
        highlight = result['highlight']
        ba_citation = result['ba_citation']
        fo_citation = result['fo_citation']
        
        # Calculate embeddings
        highlight_emb = embedding_model.encode([highlight])[0]
        ba_citation_emb = embedding_model.encode([ba_citation])[0]
        fo_citation_emb = embedding_model.encode([fo_citation])[0]
        
        # Calculate embedding similarity (cosine)
        ba_emb_similarity = cosine_similarity(
            highlight_emb.reshape(1, -1), 
            ba_citation_emb.reshape(1, -1)
        )[0][0]
        
        fo_emb_similarity = cosine_similarity(
            highlight_emb.reshape(1, -1), 
            fo_citation_emb.reshape(1, -1)
        )[0][0]
        
        # Calculate TF-IDF similarity
        ba_tfidf_similarity = cosine_similarity(
            tfidf_matrix[highlight_idx], 
            tfidf_matrix[ba_citation_idx]
        )[0][0]
        
        fo_tfidf_similarity = cosine_similarity(
            tfidf_matrix[highlight_idx], 
            tfidf_matrix[fo_citation_idx]
        )[0][0]
        
        # Calculate ROUGE scores
        ba_rouge = scorer.score(highlight, ba_citation)
        fo_rouge = scorer.score(highlight, fo_citation)
        
        # Create enhanced result with all metrics
        enhanced_result = result.copy()
        
        # Add backward model metrics
        enhanced_result.update({
            'ba_emb_similarity': ba_emb_similarity,
            'ba_tfidf_similarity': ba_tfidf_similarity,
            'ba_rouge1_precision': ba_rouge['rouge1'].precision,
            'ba_rouge1_recall': ba_rouge['rouge1'].recall,
            'ba_rouge1_fmeasure': ba_rouge['rouge1'].fmeasure,
            'ba_rouge2_fmeasure': ba_rouge['rouge2'].fmeasure,
            'ba_rougeL_fmeasure': ba_rouge['rougeL'].fmeasure,
        })
        
        # Add forward model metrics
        enhanced_result.update({
            'fo_emb_similarity': fo_emb_similarity,
            'fo_tfidf_similarity': fo_tfidf_similarity,
            'fo_rouge1_precision': fo_rouge['rouge1'].precision,
            'fo_rouge1_recall': fo_rouge['rouge1'].recall,
            'fo_rouge1_fmeasure': fo_rouge['rouge1'].fmeasure,
            'fo_rouge2_fmeasure': fo_rouge['rouge2'].fmeasure,
            'fo_rougeL_fmeasure': fo_rouge['rougeL'].fmeasure,
        })
        
        enhanced_results.append(enhanced_result)
    
    return enhanced_results

In [None]:
def display_evaluation_results(results, metrics_to_show=None):
    """
    Display evaluation results in a DataFrame.
    
    Args:
        results: List of dictionaries with enhanced evaluation metrics
        metrics_to_show: List of metric columns to display (if None, shows a default set)
    
    Returns:
        DataFrame with evaluation metrics
    """
    if metrics_to_show is None:
        # Default metrics to show
        metrics_to_show = [
            'highlight', 'ba_citation', 'fo_citation',
            'ba_perplexity', 'fo_perplexity',
            'ba_emb_similarity', 'fo_emb_similarity',
            'ba_rougeL_fmeasure', 'fo_rougeL_fmeasure'
        ]
    
    # Create DataFrame
    df = pd.DataFrame(results)
    
    # Select columns to display
    display_df = df[metrics_to_show]
    
    return display_df

# Example usage:
enhanced_results = evaluate_citations(results)
display_df = display_evaluation_results(enhanced_results)
display_df

In [None]:
# After you have your enhanced_results
df = pd.DataFrame(enhanced_results)

# Create two separate dataframes - one for each model type
ba_metrics = {col.replace('ba_', ''): df[col].mean() for col in df.columns 
              if col.startswith('ba_') and col != 'ba_citation'}
              
fo_metrics = {col.replace('fo_', ''): df[col].mean() for col in df.columns 
              if col.startswith('fo_') and col != 'fo_citation'}

# Combine into single comparison dataframe
comparison = pd.DataFrame({
    'Backward Model': ba_metrics,
    'Forward Model': fo_metrics
})

# Display the properly formatted comparison
comparison.round(4)