# Scoring and Citations Testbed
---

The objective of this notebook is to:
1. Experiment and figure out how to perform Scoring as described in the [paper](papers/TRLM_2412.02626.pdf)
2. Experiment with linear search for citation attribution

To further explore: 
1. Experiment with binary and exclusion search
2. Experiment with retrieval

## Import Libraries

In [None]:
import torch as t
import numpy as np
import pandas as pd
import torch.nn.functional as F
from tqdm.auto import tqdm

from transformers import GPTNeoXForCausalLM, AutoTokenizer
from datasets import load_dataset, Dataset
from sentence_transformers import SentenceTransformer, util

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
from nltk.tokenize import sent_tokenize



In [5]:
pd.set_option('display.max_colwidth', 200)
pd.set_option('display.float_format', '{:.4f}'.format)

## Define Util Functions

In [6]:
device="cuda" if t.cuda.is_available() else "cpu"

In [7]:
# TODO: may be helpful to move this over to a utils.py later, or define the models as a separate classes?

# Load models
def load_models():
    # Forward model
    fo_model = GPTNeoXForCausalLM.from_pretrained(
        "EleutherAI/pythia-160m-deduped",
        revision="step143000",
        cache_dir="./.cache/pythia-160m-deduped/step143000",
    ).to(device)
    
    fo_tokenizer = AutoTokenizer.from_pretrained(
        "EleutherAI/pythia-160m-deduped",
        revision="step143000",
        cache_dir="./.cache/pythia-160m-deduped/step143000",
    )
    
    # Backward model
    ba_model = GPTNeoXForCausalLM.from_pretrained(
        "afterless/reverse-pythia-160m",
        cache_dir="./.cache/reverse-pythia-160m",
    ).to(device)
    
    ba_tokenizer = AutoTokenizer.from_pretrained(
        "afterless/reverse-pythia-160m",
        cache_dir="./.cache/reverse-pythia-160m",
    )
    
    return fo_model, fo_tokenizer, ba_model, ba_tokenizer

In [8]:
# TODO : same with above

# Load dataset
def load_cnn_dataset(num_samples=10):
    try:
        # Try with a specific cache directory
        dataset = load_dataset("cnn_dailymail", "3.0.0", cache_dir=".cache")
        print("Dataset loaded successfully")
        
        # Verify the structure - this helps debug
        if num_samples > 0:
            print("Example dataset item:", dataset['train'][0])
            
        # Take only a small sample for testing
        if hasattr(dataset, 'train'):
            return dataset['train'].select(range(min(num_samples, len(dataset['train']))))
        
        return dataset['train'][:num_samples]
        
    except Exception as e:
        print(f"Error loading full dataset: {e}")
        
        # Create a tiny synthetic dataset for testing
        print("Creating synthetic test dataset instead...")
        
        sample_data = {
            'article': [
                "John likes to play basketball. He goes to the court every evening. His friends join him on weekends.",
                "The company announced record profits. Investors were pleased. The stock price increased by 10%."
            ],
            'highlights': [
                "John plays basketball regularly with friends.",
                "Company profits lead to stock price increase."
            ],
            'id': ['test1', 'test2']  # Added ID field
        }
        
        return Dataset.from_dict(sample_data)

In [46]:
def calculate_llm_score(query, answer, model, tokenizer, task='citation', backward=False, debug=False):
    """
    Calculate log probability of response given prompt or vice versa.
    
    Args:
        query (str): The prompt text
        answer (str): The response text
        model: The language model
        tokenizer: The corresponding tokenizer
        direction (str): "forward" for P(response|prompt) or "backward" for P(prompt|response)
    
    Returns:
        dict: Contains token-wise and sequence log probabilities
    """
    
    # The paper describes "Score" as conditional distribution (Section 4) which means the Log Probability, and therefore
    # this reimplementation uses Log Probability.
    
    # The notation used here is P(Query|Answer) to make it easier to compare with the paper
    
    # First, prepare the texts
    if not backward: 
        #Forward
        conditioning_prompt = ' is a summary of ' if task =='citation' else ' has an answer to '
    else:
        #Backward
        conditioning_prompt = ' is summarized by ' if task =='citation' else ' is answered by '
    
    # DEBUG 
    if debug:
        print(f"Context: {answer + conditioning_prompt}")
        print(f"Target: {query}")
    
    # convert to tokens, but because the model is auto-regressive, it predicts left -> right 
    # (i.e. tokens at t, t+1, t+2 predicts the t+3, t+4)
    # so we may need to reverse the tokens? 
    # input_ids = tokenizer.encode(query + (conditioning_prompt + answer), return_tensors="pt").to(model.device)
    
    target_ids = tokenizer.encode(query, return_tensors="pt")
    context_ids = tokenizer.encode(answer + conditioning_prompt, return_tensors="pt")

    # store length to "divide" the texts later
    target_len = target_ids.shape[1]
    context_len = context_ids.shape[1]
    
    if backward:
        # We need to reverse the tokens in backward
        target_ids = t.flip(target_ids, (1,))
        context_ids = t.flip(context_ids, (1,))

    input_ids = t.cat((context_ids, target_ids), dim=1).to(model.device)

    # Get model output
    with t.no_grad():
        outputs = model(input_ids)
        logits = outputs.logits
    
    # Extract token probabilities for the target text
    token_probs = []
    # Because the text is (query+answer), we just want to get the (answer) logits
    for i in range(context_len - 1, context_len + target_len - 1):
        
        # essentially, get the probability for the actual token at sequence, i.e. 
        # if "Harry Potter is the boy who survived" and i = 5
        # then we get the probabilities of the model output up until "who"
        # and find what's the probability of "survived"
        
        # get the logits [batch_size, sequence_length, vocabulary_size]
        next_token_logits = logits[0, i, :]  # no batch, sequence i, all vocab
        
        # get the actual token
        next_token_id = input_ids[0, i+1].item()
        
        # Convert logits to probabilities
        next_token_probs = F.softmax(next_token_logits, dim=0)
        prob = next_token_probs[next_token_id].item()
        log_prob = np.log(prob)
        
        token_text = tokenizer.decode([next_token_id])
        token_probs.append({
            'token': token_text,
            'token_id': next_token_id,
            'log_prob': log_prob
        })
    
    # Calculate sequence probability
    sequence_log_prob = sum(tp['log_prob'] for tp in token_probs)
    # Normalize by length to get per-token average
    normalized_log_prob = sequence_log_prob / len(token_probs)
    # Convert to perplexity if needed
    perplexity = np.exp(-sequence_log_prob / len(token_probs))
    
    return {
        'token_log_probs': token_probs,
        'sequence_log_prob': sequence_log_prob,
        'normalized_log_prob': normalized_log_prob,
        'perplexity': perplexity
    }



In [None]:
from sklearn.feature_extraction.text import TfidfVectorizer

def calculate_tfidf_score(highlight, sentences, citation):
    """
    Calculate the maximum TF-IDF similarity between the highlight and a given citation
    among the provided sentences.
    """
    vectorizer = TfidfVectorizer()
    tfidf_matrix = vectorizer.fit_transform([highlight] + [citation] + sentences)
    similarity = (tfidf_matrix * tfidf_matrix.T).toarray()[0][2:]  # highlight와 각 문장의 TF-IDF 유사도
    return max(similarity) if similarity.size > 0 else 0.0

In [47]:
# Testing

fo_model, fo_tokenizer, ba_model, ba_tokenizer = load_models()


# Example Text
sentence = "Harry Potter star Daniel Radcliffe gains access to a reported £20 million ($41.1 million) fortune as he turns 18 on Monday, but he insists the money won't cast a spell on him."
highlight = "Harry Potter star Daniel Radcliffe gets £20M fortune as he turns 18 Monday"
adverse_highlight = "Daniel Craig is recasted as James Bond again"

# Defining sentence/highlight query/answer is still confusing tho sheesh


# Define prompts
ba_score = calculate_llm_score(sentence, highlight, ba_model, ba_tokenizer, backward=True)
fo_score = calculate_llm_score(sentence, highlight, fo_model, fo_tokenizer)

adv_ba_score = calculate_llm_score(sentence, adverse_highlight, ba_model, ba_tokenizer, backward=True)
adv_fo_score = calculate_llm_score(sentence, adverse_highlight, fo_model, fo_tokenizer)


scores_data = {
    'Model Type': ['Backward', 'Forward', 'Backward', 'Forward'],
    'Highlight': ['Correct', 'Correct', 'Adverse', 'Adverse'],
    'Sequence Log Prob': [
        ba_score['sequence_log_prob'],
        fo_score['sequence_log_prob'],
        adv_ba_score['sequence_log_prob'],
        adv_fo_score['sequence_log_prob']
    ],
    'Normalized Log Prob': [
        ba_score['normalized_log_prob'],
        fo_score['normalized_log_prob'],
        adv_ba_score['normalized_log_prob'],
        adv_fo_score['normalized_log_prob']
    ],
    'Perplexity': [
        ba_score['perplexity'],
        fo_score['perplexity'],
        adv_ba_score['perplexity'],
        adv_fo_score['perplexity']
    ]
}

# Create DataFrame
pd.DataFrame(scores_data)

Unnamed: 0,Model Type,Highlight,Sequence Log Prob,Normalized Log Prob,Perplexity
0,Backward,Correct,-113.2715,-2.7627,15.8429
1,Forward,Correct,-113.7015,-2.7732,16.0099
2,Backward,Adverse,-137.511,-3.3539,28.6148
3,Forward,Adverse,-151.0687,-3.6846,39.8293


##  Citation, Linear Search

In [11]:
dataset = load_cnn_dataset(num_samples=50)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Dataset loaded successfully
Example dataset item: {'article': 'LONDON, England (Reuters) -- Harry Potter star Daniel Radcliffe gains access to a reported £20 million ($41.1 million) fortune as he turns 18 on Monday, but he insists the money won\'t cast a spell on him. Daniel Radcliffe as Harry Potter in "Harry Potter and the Order of the Phoenix" To the disappointment of gossip columnists around the world, the young actor says he has no plans to fritter his cash away on fast cars, drink and celebrity parties. "I don\'t plan to be one of those people who, as soon as they turn 18, suddenly buy themselves a massive sports car collection or something similar," he told an Australian interviewer earlier this month. "I don\'t think I\'ll be particularly extravagant. "The things I like buying are things that cost about 10 pounds -- books and CDs and DVDs." At 18, Radcliffe will be able to gamble in a casino, buy a drink in a pub or see the horror film "Hostel: Part II," currently six places be

In [12]:
# Show dataframe
pd.DataFrame(dataset)

Unnamed: 0,article,highlights,id
0,"LONDON, England (Reuters) -- Harry Potter star Daniel Radcliffe gains access to a reported £20 million ($41.1 million) fortune as he turns 18 on Monday, but he insists the money won't cast a spell...",Harry Potter star Daniel Radcliffe gets £20M fortune as he turns 18 Monday .\nYoung actor says he has no plans to fritter his cash away .\nRadcliffe's earnings from first five Potter films have be...,42c027e4ff9730fbb3de84c1af0d2c506e41c3e4
1,"Editor's note: In our Behind the Scenes series, CNN correspondents share their experiences in covering news and analyze the stories behind the events. Here, Soledad O'Brien takes users inside a ja...","Mentally ill inmates in Miami are housed on the ""forgotten floor""\nJudge Steven Leifman says most are there as a result of ""avoidable felonies""\nWhile CNN tours facility, patient shouts: ""I am the...",ee8871b15c50d0db17b0179a6d2beab35065f1e9
2,"MINNEAPOLIS, Minnesota (CNN) -- Drivers who were on the Minneapolis bridge when it collapsed told harrowing tales of survival. ""The whole bridge from one side of the Mississippi to the other just ...","NEW: ""I thought I was going to die,"" driver says .\nMan says pickup truck was folded in half; he just has cut on face .\nDriver: ""I probably had a 30-, 35-foot free fall""\nMinnesota bridge collaps...",06352019a19ae31e527f37f7571c6dd7f0c5da37
3,"WASHINGTON (CNN) -- Doctors removed five small polyps from President Bush's colon on Saturday, and ""none appeared worrisome,"" a White House spokesman said. The polyps were removed and sent to the ...","Five small polyps found during procedure; ""none worrisome,"" spokesman says .\nPresident reclaims powers transferred to vice president .\nBush undergoes routine colonoscopy at Camp David .",24521a2abb2e1f5e34e6824e0f9e56904a2b0e88
4,"(CNN) -- The National Football League has indefinitely suspended Atlanta Falcons quarterback Michael Vick without pay, officials with the league said Friday. NFL star Michael Vick is set to appea...","NEW: NFL chief, Atlanta Falcons owner critical of Michael Vick's conduct .\nNFL suspends Falcons quarterback indefinitely without pay .\nVick admits funding dogfighting operation but says he did n...",7fe70cc8b12fab2d0a258fababf7d9c6b5e1262a
5,"BAGHDAD, Iraq (CNN) -- Dressed in a Superman shirt, 5-year-old Youssif held his sister's hand Friday, seemingly unaware that millions of people across the world have been touched by his story. Nea...","Parents beam with pride, can't stop from smiling from outpouring of support .\nMom: ""I was so happy I didn't know what to do""\nBurn center in U.S. has offered to provide treatment for reconstructi...",a1ebb8bb4d370a1fdf28769206d572be60642d70
6,"BAGHDAD, Iraq (CNN) -- The women are too afraid and ashamed to show their faces or have their real names used. They have been driven to sell their bodies to put food on the table for their childre...","Aid workers: Violence, increased cost of living drive women to prostitution .\nGroup is working to raise awareness of the problem with Iraq's political leaders .\nTwo Iraqi mothers tell CNN they t...",7c0e61ac829a3b3b653e2e3e7536cc4881d1f264
7,"BOGOTA, Colombia (CNN) -- A key rebel commander and fugitive from a U.S. drug trafficking indictment was killed over the weekend in an air attack on a guerrilla encampment, the Colombian military ...","Tomas Medina Caracas was a fugitive from a U.S. drug trafficking indictment .\n""El Negro Acacio"" allegedly helped manage extensive cocaine network .\nU.S. Justice Department indicted him in 2002 ....",f0d73bdab711763e745cdc75850861c9018f235d
8,"WASHINGTON (CNN) -- White House press secretary Tony Snow, who is undergoing treatment for cancer, will step down from his post September 14 and be replaced by deputy press secretary Dana Perino, ...","President Bush says Tony Snow ""will battle cancer and win"" Job of press secretary ""has been a dream for me,"" Snow says Snow leaving on September 14, will be succeeded by Dana Perino .",5e22bbfc7232418b8d2dd646b952e404df5bd048
9,"(CNN) -- Police and FBI agents are investigating the discovery of an empty rocket launcher tube on the front lawn of a Jersey City, New Jersey, home, FBI spokesman Sean Quinn said. Niranjan Desai ...","Empty anti-tank weapon turns up in front of New Jersey home .\nDevice handed over to Army ordnance disposal unit .\nWeapon not capable of being reloaded, experts say .",613d6311ec2c1985bd44707d1796d275452fe156


In [35]:
def linear_attribution_search(dataset, ba_model, ba_tokenizer, fo_model, fo_tokenizer):
    """
    Perform linear attribution search for citations as described in TRLM paper.
    
    For each highlight (summary sentence), find the most likely article sentence
    that it was derived from by scoring all possible pairs.
    """
    results = []
    
    # Process only the first few examples for demonstration
    for idx, example in tqdm(dataset.iterrows(), total=len(dataset)):
        # Split article and highlights into sentences
        article_sentences = sent_tokenize(example['article'])
        highlight_sentences = sent_tokenize(example['highlights'])
        
        # For demonstration, process just the first highlight sentence
        if not highlight_sentences:
            continue
            
        highlight = highlight_sentences[0]
        
        # Store best attribution for each model
        best_ba_sentence = None
        best_ba_score = float('-inf')
        best_fo_sentence = None
        best_fo_score = float('-inf')
        
        # Linear search through all article sentences
        for sentence in article_sentences:
            # Skip very short sentences
            if len(sentence.split()) < 3:
                continue
                
            # Calculate scores using both models
            ba_score = calculate_llm_score(sentence, highlight, ba_model, ba_tokenizer, backward=True)
            fo_score = calculate_llm_score(sentence, highlight, fo_model, fo_tokenizer)
            
            # Track best scores
            if ba_score['normalized_log_prob'] > best_ba_score:
                best_ba_score = ba_score['normalized_log_prob']
                best_ba_sentence = sentence
                
            if fo_score['normalized_log_prob'] > best_fo_score:
                best_fo_score = fo_score['normalized_log_prob']
                best_fo_sentence = sentence
        
        # Add results to our list
        results.append({
            'id': example['id'],
            'highlight': highlight,
            'ba_citation': best_ba_sentence,
            'ba_score': best_ba_score,
            'ba_perplexity': np.exp(-best_ba_score),
            'fo_citation': best_fo_sentence,
            'fo_score': best_fo_score,
            'fo_perplexity': np.exp(-best_fo_score)
        })
    
    
    return results


In [59]:
# concat 수정본
def binary_search_citation(article, highlight, model, tokenizer, backward=False, max_iterations=30):
    sentences = sent_tokenize(article)
    if not sentences:
        return {'citation': '', 'score': float('-inf'), 'perplexity': float('inf')}
    
    highlight_sentences = sent_tokenize(highlight)
    if not highlight_sentences:
        return {'citation': '', 'score': float('-inf'), 'perplexity': float('inf')}
    highlight = highlight_sentences[0]
    
    def binary_search_recursive(s, t, iteration=0):
        if t - s <= 0 or iteration >= max_iterations:
            if t < s:
                return '', float('-inf'), float('inf')
            a_half = ' '.join(sentences[s:t + 1])
            result = calculate_llm_score(a_half, highlight, model, tokenizer, backward=backward)
            score = result['normalized_log_prob']
            perplexity = result['perplexity']
            return a_half, score, perplexity
        
        mid = s + (t - s) // 2
        a_half1 = ' '.join(sentences[s:mid + 1])
        a_half2 = ' '.join(sentences[mid + 1:t + 1])
        result1 = calculate_llm_score(a_half1, highlight, model, tokenizer, backward=backward)
        result2 = calculate_llm_score(a_half2, highlight, model, tokenizer, backward=backward)
        s1, p1 = result1['normalized_log_prob'], result1['perplexity']
        s2, p2 = result2['normalized_log_prob'], result2['perplexity']
        
        print(f"Binary Search (Backward={backward}): s={s}, t={t}, Mid={mid}, s1={s1}, s2={s2}")
        
        if s1 > s2:
            return binary_search_recursive(s, mid, iteration + 1)
        else:
            return binary_search_recursive(mid + 1, t, iteration + 1)
    
    s, t = 0, len(sentences) - 1
    citation, score, perplexity = binary_search_recursive(s, t)
    
    if not citation:
        return {'citation': '', 'score': float('-inf'), 'perplexity': float('inf')}
    
    print(f"Binary Search (Backward={backward}) Final: Score={score}, Citation={citation[:50]}...")
    return {
        'citation': citation,
        'score': score,
        'perplexity': perplexity
    }

In [60]:
# exclusion_search_citation (dynamic threshold)
def exclusion_search_citation(article, highlight, model, tokenizer, backward=False, threshold=None):
    sentences = sent_tokenize(article)
    if not sentences:
        return {'citation': '', 'score': float('-inf'), 'perplexity': float('inf')}
    
    # Highlight를 문장 단위로 분리하고 첫 번째 문장만 사용
    highlight_sentences = sent_tokenize(highlight)
    if not highlight_sentences:
        return {'citation': '', 'score': float('-inf'), 'perplexity': float('inf')}
    highlight = highlight_sentences[0]  # 첫 번째 문장만 선택
    
    all_scores = []
    for i in range(len(sentences)):
        result = calculate_llm_score(sentences[i], highlight, model, tokenizer, task='citation', backward=backward)
        score = result['normalized_log_prob']
        all_scores.append(score)
    
    if threshold is None:
        threshold = sorted(all_scores)[-int(len(all_scores) * 0.05)]  # Top 5%
        threshold = max(threshold, np.mean(all_scores) + np.std(all_scores))  # 보정
    
    best_score = float('-inf')
    best_citation = ''
    best_perplexity = float('inf')
    for i in range(len(sentences)):
        if all_scores[i] >= threshold:
            continue
        result = calculate_llm_score(sentences[i], highlight, model, tokenizer, task='citation', backward=backward)
        score = result['normalized_log_prob']
        if score > best_score:
            best_score = score
            best_citation = sentences[i]
            best_perplexity = result['perplexity']
    
    print(f"Threshold: {threshold}, Backward: {backward}")
    print(f"All scores: {all_scores}")
    print(f"Min: {min(all_scores)}, Max: {max(all_scores)}, Mean: {np.mean(all_scores)}")
    
    return {
        'citation': best_citation,
        'score': best_score,
        'perplexity': best_perplexity
    }

In [None]:
# evaluate_citations_with_linear_binary_exclusion (dynamic threshold)
def evaluate_citations_with_linear_binary_exclusion(dataset, num_samples=10):
    fo_model, fo_tokenizer, ba_model, ba_tokenizer = load_models()
    sentence_transformer = SentenceTransformer('all-MiniLM-L6-v2')
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    
    if isinstance(dataset, list):
        df_dataset = pd.DataFrame(dataset)
    else:
        df_dataset = pd.DataFrame(dataset)
        
    df_dataset = df_dataset.iloc[:num_samples]
    
    linear_results = linear_attribution_search(df_dataset, ba_model, ba_tokenizer, fo_model, fo_tokenizer)
    
    results = []
    dataset_list = df_dataset.to_dict('records')
    
    for i in tqdm(range(min(num_samples, len(dataset_list)))):
        article = dataset_list[i]['article']
        highlight = dataset_list[i]['highlights']
        highlight_sentences = sent_tokenize(highlight)
        if not highlight_sentences:
            continue
        first_highlight = highlight_sentences[0]
        
        linear_result = linear_results[i]
        
        ba_binary = binary_search_citation(article, first_highlight, ba_model, ba_tokenizer, backward=True)
        fo_binary = binary_search_citation(article, first_highlight, fo_model, fo_tokenizer)
        
        ba_exclusion = exclusion_search_citation(article, first_highlight, ba_model, ba_tokenizer, backward=True, threshold=None)
        fo_exclusion = exclusion_search_citation(article, first_highlight, fo_model, fo_tokenizer, threshold=None)
        
        highlight_emb = sentence_transformer.encode(first_highlight)
        ba_linear_emb = sentence_transformer.encode(linear_result['ba_citation'])
        fo_linear_emb = sentence_transformer.encode(linear_result['fo_citation'])
        ba_binary_emb = sentence_transformer.encode(ba_binary['citation'])
        fo_binary_emb = sentence_transformer.encode(fo_binary['citation'])
        ba_exclusion_emb = sentence_transformer.encode(ba_exclusion['citation'])
        fo_exclusion_emb = sentence_transformer.encode(fo_exclusion['citation'])
        
        ba_linear_rouge = scorer.score(first_highlight, linear_result['ba_citation'])
        fo_linear_rouge = scorer.score(first_highlight, linear_result['fo_citation'])
        ba_binary_rouge = scorer.score(first_highlight, ba_binary['citation'])
        fo_binary_rouge = scorer.score(first_highlight, fo_binary['citation'])
        ba_exclusion_rouge = scorer.score(first_highlight, ba_exclusion['citation'])
        fo_exclusion_rouge = scorer.score(first_highlight, fo_exclusion['citation'])
        
        # TF-IDF 점수 계산
        sentences = sent_tokenize(article)
        ba_linear_tfidf = calculate_tfidf_score(first_highlight, sentences, linear_result['ba_citation'])
        fo_linear_tfidf = calculate_tfidf_score(first_highlight, sentences, linear_result['fo_citation'])
        ba_binary_tfidf = calculate_tfidf_score(first_highlight, sentences, ba_binary['citation'])
        fo_binary_tfidf = calculate_tfidf_score(first_highlight, sentences, fo_binary['citation'])
        ba_exclusion_tfidf = calculate_tfidf_score(first_highlight, sentences, ba_exclusion['citation'])
        fo_exclusion_tfidf = calculate_tfidf_score(first_highlight, sentences, fo_exclusion['citation'])
        
        result = {
            'id': dataset_list[i]['id'],
            'highlight': first_highlight,
            'ba_linear_citation': linear_result['ba_citation'],
            'fo_linear_citation': linear_result['fo_citation'],
            'ba_linear_score': linear_result['ba_score'],
            'fo_linear_score': linear_result['fo_score'],
            'ba_linear_perplexity': linear_result['ba_perplexity'],
            'fo_linear_perplexity': linear_result['fo_perplexity'],
            'ba_linear_emb_similarity': util.cos_sim([highlight_emb], [ba_linear_emb])[0][0].item(),
            'fo_linear_emb_similarity': util.cos_sim([highlight_emb], [fo_linear_emb])[0][0].item(),
            'ba_linear_rougeL_fmeasure': ba_linear_rouge['rougeL'].fmeasure,
            'fo_linear_rougeL_fmeasure': fo_linear_rouge['rougeL'].fmeasure,
            'ba_linear_tfidf': ba_linear_tfidf,
            'fo_linear_tfidf': fo_linear_tfidf,
            'ba_binary_citation': ba_binary['citation'],
            'fo_binary_citation': fo_binary['citation'],
            'ba_binary_score': ba_binary['score'],
            'fo_binary_score': fo_binary['score'],
            'ba_binary_perplexity': ba_binary['perplexity'],
            'fo_binary_perplexity': fo_binary['perplexity'],
            'ba_binary_emb_similarity': util.cos_sim([highlight_emb], [ba_binary_emb])[0][0].item(),
            'fo_binary_emb_similarity': util.cos_sim([highlight_emb], [fo_binary_emb])[0][0].item(),
            'ba_binary_rougeL_fmeasure': ba_binary_rouge['rougeL'].fmeasure,
            'fo_binary_rougeL_fmeasure': fo_binary_rouge['rougeL'].fmeasure,
            'ba_binary_tfidf': ba_binary_tfidf,
            'fo_binary_tfidf': fo_binary_tfidf,
            'ba_exclusion_citation': ba_exclusion['citation'],
            'fo_exclusion_citation': fo_exclusion['citation'],
            'ba_exclusion_score': ba_exclusion['score'],
            'fo_exclusion_score': fo_exclusion['score'],
            'ba_exclusion_perplexity': ba_exclusion['perplexity'],
            'fo_exclusion_perplexity': fo_exclusion['perplexity'],
            'ba_exclusion_emb_similarity': util.cos_sim([highlight_emb], [ba_exclusion_emb])[0][0].item(),
            'fo_exclusion_emb_similarity': util.cos_sim([highlight_emb], [fo_exclusion_emb])[0][0].item(),
            'ba_exclusion_rougeL_fmeasure': ba_exclusion_rouge['rougeL'].fmeasure,
            'fo_exclusion_rougeL_fmeasure': fo_exclusion_rouge['rougeL'].fmeasure,
            'ba_exclusion_tfidf': ba_exclusion_tfidf,
            'fo_exclusion_tfidf': fo_exclusion_tfidf,
        }
        results.append(result)
    
    return results

# display_comparison_results (수정된 버전)
def display_comparison_results(results):
    results_df = pd.DataFrame(results)
    
    # results_df의 열 이름 확인 (디버�ング용)
    print("Columns in results_df:", results_df.columns.tolist())
    
    metrics = ['score', 'perplexity', 'emb_similarity', 'rougeL_fmeasure', 'tfidf']
    comparison_data = {}
    
    for model_type in ['ba', 'fo']:
        for search_type in ['linear', 'binary', 'exclusion']:
            col_prefix = f'{model_type}_{search_type}_'
            comparison_data[f'{model_type.upper()} {search_type.capitalize()}'] = {
                metric: results_df[f'{col_prefix}{metric}'].mean() 
                for metric in metrics
            }
    
    comparison_df = pd.DataFrame(comparison_data)
    
    # display_df에 모든 열 포함
    display_df = results_df[[
        'highlight',
        # BA Linear
        'ba_linear_citation', 'ba_linear_score', 'ba_linear_perplexity', 'ba_linear_emb_similarity', 'ba_linear_rougeL_fmeasure', 'ba_linear_tfidf',
        # FO Linear
        'fo_linear_citation', 'fo_linear_score', 'fo_linear_perplexity', 'fo_linear_emb_similarity', 'fo_linear_rougeL_fmeasure', 'fo_linear_tfidf',
        # BA Binary
        'ba_binary_citation', 'ba_binary_score', 'ba_binary_perplexity', 'ba_binary_emb_similarity', 'ba_binary_rougeL_fmeasure', 'ba_binary_tfidf',
        # FO Binary
        'fo_binary_citation', 'fo_binary_score', 'fo_binary_perplexity', 'fo_binary_emb_similarity', 'fo_binary_rougeL_fmeasure', 'fo_binary_tfidf',
        # BA Exclusion
        'ba_exclusion_citation', 'ba_exclusion_score', 'ba_exclusion_perplexity', 'ba_exclusion_emb_similarity', 'ba_exclusion_rougeL_fmeasure', 'ba_exclusion_tfidf',
        # FO Exclusion
        'fo_exclusion_citation', 'fo_exclusion_score', 'fo_exclusion_perplexity', 'fo_exclusion_emb_similarity', 'fo_exclusion_rougeL_fmeasure', 'fo_exclusion_tfidf'
    ]]
    
    # pandas 출력 설정
    pd.set_option('display.max_columns', None)
    pd.set_option('display.max_colwidth', 50)
    
    return comparison_df, display_df

# execute
dataset = load_cnn_dataset(num_samples=50)
comparison_results = evaluate_citations_with_linear_binary_exclusion(dataset, num_samples=10)
comparison_df, display_df = display_comparison_results(comparison_results)

print("Comparison of Average Metrics:")
print(comparison_df.round(4))
print("\nDetailed Results:")
print(display_df)

Dataset loaded successfully
Example dataset item: {'article': 'LONDON, England (Reuters) -- Harry Potter star Daniel Radcliffe gains access to a reported £20 million ($41.1 million) fortune as he turns 18 on Monday, but he insists the money won\'t cast a spell on him. Daniel Radcliffe as Harry Potter in "Harry Potter and the Order of the Phoenix" To the disappointment of gossip columnists around the world, the young actor says he has no plans to fritter his cash away on fast cars, drink and celebrity parties. "I don\'t plan to be one of those people who, as soon as they turn 18, suddenly buy themselves a massive sports car collection or something similar," he told an Australian interviewer earlier this month. "I don\'t think I\'ll be particularly extravagant. "The things I like buying are things that cost about 10 pounds -- books and CDs and DVDs." At 18, Radcliffe will be able to gamble in a casino, buy a drink in a pub or see the horror film "Hostel: Part II," currently six places be

100%|██████████| 10/10 [00:25<00:00,  2.57s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

Binary Search (Backward=True): s=0, t=23, Mid=11, s1=-2.977705282264682, s2=-3.194227808243272
Binary Search (Backward=True): s=0, t=11, Mid=5, s1=-3.005053265259954, s2=-3.1527275183372483
Binary Search (Backward=True): s=0, t=5, Mid=2, s1=-2.856931128816732, s2=-3.658465028547803
Binary Search (Backward=True): s=0, t=2, Mid=1, s1=-2.9076363216575167, s2=-3.2597458041810436
Binary Search (Backward=True): s=0, t=1, Mid=0, s1=-2.670148073404129, s2=-3.2587536024251884
Binary Search (Backward=True) Final: Score=-2.670148073404129, Citation=LONDON, England (Reuters) -- Harry Potter star Dan...
Binary Search (Backward=False): s=0, t=23, Mid=11, s1=-2.9589545213599053, s2=-3.2560818978961974
Binary Search (Backward=False): s=0, t=11, Mid=5, s1=-3.0057152874467, s2=-3.2581350867401944
Binary Search (Backward=False): s=0, t=5, Mid=2, s1=-2.790441799686955, s2=-3.8938452826852963
Binary Search (Backward=False): s=0, t=2, Mid=1, s1=-2.780818273672728, s2=-3.468478783882734
Binary Search (Backwa

 10%|█         | 1/10 [00:05<00:47,  5.27s/it]

Threshold: -1.8230010077121352, Backward: False
All scores: [np.float64(-2.646627525020814), np.float64(-3.0883585649452576), np.float64(-3.468478783882734), np.float64(-4.230653373198902), np.float64(-4.66487109822079), np.float64(-4.030357543959578), np.float64(-5.377750106660205), np.float64(-4.76638888370393), np.float64(-4.141797468545643), np.float64(-4.234976801441224), np.float64(-3.771874656819574), np.float64(-3.1093667500317768), np.float64(-4.761407855445343), np.float64(-3.829812465532803), np.float64(-2.9522379277327744), np.float64(-6.472597771193319), np.float64(-6.160806200333863), np.float64(-3.8838144808530832), np.float64(-4.3107298486868055), np.float64(-4.328373863896082), np.float64(-3.7427461502485357), np.float64(-4.574478937327841), np.float64(-7.644641609227622), np.float64(-1.8230010077121352)]
Min: -7.644641609227622, Max: -1.8230010077121352, Mean: -4.250672903109193
Binary Search (Backward=True): s=0, t=44, Mid=22, s1=-2.9671204877814654, s2=-2.9751115267

 20%|██        | 2/10 [00:15<01:06,  8.35s/it]

Threshold: -2.886299525210059, Backward: False
All scores: [np.float64(-3.690593316122967), np.float64(-3.3931450827272265), np.float64(-2.4917715326104934), np.float64(-4.126935135230957), np.float64(-3.4995405796408097), np.float64(-4.742182463865164), np.float64(-3.4600866870758735), np.float64(-4.325595118079151), np.float64(-4.458342912532219), np.float64(-3.7191146373497306), np.float64(-4.000357781388404), np.float64(-4.782999713179824), np.float64(-3.7775730897966886), np.float64(-5.06012058681872), np.float64(-4.981569433296433), np.float64(-4.36824938319623), np.float64(-4.335263684553497), np.float64(-3.0603452752035487), np.float64(-3.956510008005014), np.float64(-3.8807567686003823), np.float64(-5.0698250196801355), np.float64(-2.886299525210059), np.float64(-4.883931241860544), np.float64(-7.241787550965899), np.float64(-4.105243759743119), np.float64(-4.271583964226092), np.float64(-3.9486035662572427), np.float64(-4.98398459755832), np.float64(-3.599048743355258), np.fl

 30%|███       | 3/10 [00:25<01:01,  8.76s/it]

Threshold: -3.1863910230960126, Backward: False
All scores: [np.float64(-3.5209296531180594), np.float64(-4.64179232256793), np.float64(-5.260960761870811), np.float64(-4.114305671789156), np.float64(-7.374074074819944), np.float64(-4.659273415807462), np.float64(-3.53176102330444), np.float64(-5.397294732084145), np.float64(-5.534572632451656), np.float64(-7.208723176435448), np.float64(-3.4746379001893404), np.float64(-3.8725285379146777), np.float64(-3.227724989662316), np.float64(-4.404577949419825), np.float64(-3.7992229993779696), np.float64(-4.47558509931042), np.float64(-4.996591755660581), np.float64(-5.300393211444209), np.float64(-3.964554994342896), np.float64(-4.088140747457815), np.float64(-4.003786915665077), np.float64(-4.0082839191552715), np.float64(-4.995683184021818), np.float64(-4.51708891463529), np.float64(-4.829074365219806), np.float64(-3.910599079794222), np.float64(-4.116285999496686), np.float64(-6.361574665224402), np.float64(-5.775086239050942), np.float64

 40%|████      | 4/10 [00:29<00:43,  7.23s/it]

Threshold: -3.060066021393866, Backward: False
All scores: [np.float64(-3.077107852258876), np.float64(-3.303316560743505), np.float64(-4.4316741951198075), np.float64(-3.3581115306768075), np.float64(-5.027226245607877), np.float64(-4.926309652651832), np.float64(-4.340502372743222), np.float64(-5.116176963378833), np.float64(-4.0189368339962375), np.float64(-5.006225339059681), np.float64(-3.4023349330636967), np.float64(-4.154950919322929), np.float64(-5.798686879136889), np.float64(-3.060066021393866), np.float64(-3.543157337275913), np.float64(-4.552812002456096), np.float64(-3.1857400414840553), np.float64(-3.558234677505427), np.float64(-4.958940860957651), np.float64(-4.508249201077284), np.float64(-7.363535813175554), np.float64(-5.303979226936529), np.float64(-4.047093534353271), np.float64(-4.46653544733011)]
Min: -7.363535813175554, Max: -3.060066021393866, Mean: -4.3545793517377485
Binary Search (Backward=True): s=0, t=45, Mid=22, s1=-2.891959569805292, s2=-2.9647612133405

 50%|█████     | 5/10 [00:39<00:39,  7.99s/it]

Threshold: -2.9035316867545053, Backward: False
All scores: [np.float64(-3.0223622380248196), np.float64(-3.172652419403788), np.float64(-4.076653162578594), np.float64(-3.434097227578984), np.float64(-4.482498618752335), np.float64(-2.95203449895906), np.float64(-4.127864794376364), np.float64(-4.032980572281739), np.float64(-3.6069620529605086), np.float64(-3.559512054086917), np.float64(-5.007780164849885), np.float64(-3.1983430544815312), np.float64(-3.27623801964935), np.float64(-2.8821513385546593), np.float64(-4.545401102396126), np.float64(-3.574670475738852), np.float64(-3.996969446219444), np.float64(-5.21534067721837), np.float64(-6.0251090333126305), np.float64(-4.977633990813212), np.float64(-4.772573593088308), np.float64(-4.55814434439669), np.float64(-5.610802736860269), np.float64(-4.7895564401330715), np.float64(-4.749537735311028), np.float64(-3.485578617032097), np.float64(-2.9035316867545053), np.float64(-4.868805645872046), np.float64(-3.9855849018031817), np.floa

 60%|██████    | 6/10 [00:47<00:31,  7.98s/it]

Threshold: -3.081344375530008, Backward: False
All scores: [np.float64(-3.468368791871809), np.float64(-4.492340066002298), np.float64(-5.67438451343674), np.float64(-4.860801170992751), np.float64(-4.689357379083881), np.float64(-3.081344375530008), np.float64(-4.015383010204013), np.float64(-3.0644928200733443), np.float64(-3.1863205276195457), np.float64(-5.2077263087919725), np.float64(-4.03417527094971), np.float64(-3.695178862205983), np.float64(-4.713775776809053), np.float64(-4.345203802405489), np.float64(-4.313334229451195), np.float64(-5.262520666782068), np.float64(-4.103322363611226), np.float64(-6.207869003662982), np.float64(-3.358037326877852), np.float64(-3.812571197724209), np.float64(-4.688583204310415), np.float64(-4.861562347992277), np.float64(-4.432905128217965), np.float64(-4.23367703352401), np.float64(-3.967218972276821), np.float64(-4.8179450401607395), np.float64(-4.322999804848522), np.float64(-3.692441580705788), np.float64(-4.7613401598769665), np.float64

 70%|███████   | 7/10 [00:57<00:26,  8.73s/it]

Threshold: -3.235469034984074, Backward: False
All scores: [np.float64(-3.381390298311477), np.float64(-3.235469034984074), np.float64(-4.0089342431155846), np.float64(-4.511851519630441), np.float64(-4.236305526252072), np.float64(-3.716551399318106), np.float64(-4.457993493716698), np.float64(-3.6423077761901292), np.float64(-3.487055207966021), np.float64(-5.1622998850283635), np.float64(-3.4075815914541203), np.float64(-5.379901978714801), np.float64(-4.799459161832853), np.float64(-4.873479241129162), np.float64(-6.952304678694447), np.float64(-4.5414169986320605), np.float64(-3.539924674770691), np.float64(-3.939981189132372), np.float64(-3.82630872696607), np.float64(-4.101316609659585), np.float64(-4.547271600616422), np.float64(-4.736059190798983), np.float64(-3.644012151231792), np.float64(-5.01759793874734), np.float64(-4.841366155281191), np.float64(-8.252988468470383), np.float64(-4.311220352935163), np.float64(-5.744506904387367), np.float64(-3.8735588012560003), np.float

 80%|████████  | 8/10 [01:01<00:14,  7.10s/it]

Threshold: -2.585837505443772, Backward: False
All scores: [np.float64(-2.510203838610998), np.float64(-3.3221730249895813), np.float64(-2.5568817510016077), np.float64(-2.629907670220384), np.float64(-3.167508568721702), np.float64(-2.5403842800461023), np.float64(-3.085035830525837), np.float64(-3.0486953579935916), np.float64(-4.141290927801318), np.float64(-4.738547938884302), np.float64(-4.818836830904138), np.float64(-2.710151205549179), np.float64(-4.2030873752245785), np.float64(-4.67061778380463)]
Min: -4.818836830904138, Max: -2.510203838610998, Mean: -3.4388087417341393
Binary Search (Backward=True): s=0, t=32, Mid=16, s1=-2.703443656561015, s2=-2.8507435641644983
Binary Search (Backward=True): s=0, t=16, Mid=8, s1=-2.379347749300833, s2=-3.1266311310297445
Binary Search (Backward=True): s=0, t=8, Mid=4, s1=-2.0952336488236907, s2=-3.0462891638784373
Binary Search (Backward=True): s=0, t=4, Mid=2, s1=-1.8583584046396189, s2=-2.8375439341637527
Binary Search (Backward=True): 

 90%|█████████ | 9/10 [01:08<00:07,  7.08s/it]

Threshold: -2.066936872982527, Backward: False
All scores: [np.float64(-2.066936872982527), np.float64(-2.5403354735714485), np.float64(-3.4464186158582297), np.float64(-3.4966587533004048), np.float64(-3.8770106405123013), np.float64(-6.332861803876009), np.float64(-4.08587681827599), np.float64(-3.40214993445431), np.float64(-3.0784245393323726), np.float64(-3.740394087564943), np.float64(-3.6579392483075868), np.float64(-3.780101308905614), np.float64(-4.238427397841526), np.float64(-4.022037669348154), np.float64(-4.0395506538404975), np.float64(-3.7268892785445074), np.float64(-4.402181229189035), np.float64(-3.655773977043192), np.float64(-3.823144814537463), np.float64(-4.151453107395803), np.float64(-3.556586079459477), np.float64(-4.197768019241783), np.float64(-3.148108743870127), np.float64(-4.1665607548595025), np.float64(-4.016658695286305), np.float64(-4.068893706106894), np.float64(-5.350072348563867), np.float64(-7.1727326976548555), np.float64(-3.8656893875876173), np.

100%|██████████| 10/10 [01:11<00:00,  7.16s/it]

Threshold: -3.3947288569374012, Backward: False
All scores: [np.float64(-3.254194563172916), np.float64(-4.0243493045693866), np.float64(-3.2003646629309177), np.float64(-3.4045958602931856), np.float64(-3.1880026340008394), np.float64(-3.9263891332545504), np.float64(-3.8190180523039583), np.float64(-3.8195622978694135), np.float64(-4.893436010871528), np.float64(-4.219800790864455), np.float64(-3.647971484216533), np.float64(-3.809536440247135), np.float64(-5.22855938390989), np.float64(-5.195973826221048), np.float64(-4.5549088038363355), np.float64(-4.816230907893958)]
Min: -5.22855938390989, Max: -3.1880026340008394, Mean: -4.062680884778503
Columns in results_df: ['id', 'highlight', 'ba_linear_citation', 'fo_linear_citation', 'ba_linear_score', 'fo_linear_score', 'ba_linear_perplexity', 'fo_linear_perplexity', 'ba_linear_emb_similarity', 'fo_linear_emb_similarity', 'ba_linear_rougeL_fmeasure', 'fo_linear_rougeL_fmeasure', 'ba_linear_tfidf', 'fo_linear_tfidf', 'ba_binary_citation'




In [None]:
results = linear_attribution_search(pd.DataFrame(dataset), ba_model, ba_tokenizer, fo_model, fo_tokenizer)
results_df = pd.DataFrame(results)

# Display results in a more readable format
display_df = results_df[['highlight', 'ba_citation', 'ba_score', 'ba_perplexity', 
                         'fo_citation', 'fo_score', 'fo_perplexity']]

In [None]:
display_df.T

## Further Analysis

1. Analyze first sample citations is to check for a single article, what are the sentences' individual scores
2. Evaluate citations and display is to get the benchmark metrics
3. 

In [None]:
def analyze_first_sample_citations(dataset, ba_model, ba_tokenizer, fo_model, fo_tokenizer):
    """
    Analyze all possible citation sentences for the first highlight in the first sample.
    
    This function calculates scores for all article sentences against the first
    highlight sentence and returns them sorted by backward model score.
    """
    
    
    # Get the first sample
    first_sample = dataset.iloc[0]
    
    # Split article into sentences
    article_sentences = sent_tokenize(first_sample['article'])
    
    # Get the first highlight sentence
    highlight_sentences = sent_tokenize(first_sample['highlights'])
    if not highlight_sentences:
        print("No highlight sentences found!")
        return None
    
    highlight = highlight_sentences[0]
    print(f"Analyzing citations for highlight: \n'{highlight}'\n")
    
    # Calculate scores for all article sentences
    results = []
    
    for i, sentence in enumerate(article_sentences):
        # Skip very short sentences
        if len(sentence.split()) < 3:
            continue
            
        # Calculate scores using both models
        ba_score = calculate_llm_score(sentence, highlight, ba_model, ba_tokenizer, backward=True)
        fo_score = calculate_llm_score(sentence, highlight, fo_model, fo_tokenizer)
        
        # Add to results
        results.append({
            'sentence_idx': i,
            'article_sentence': sentence,
            'ba_score': ba_score['normalized_log_prob'],
            'ba_perplexity': ba_score['perplexity'],
            'fo_score': fo_score['normalized_log_prob'],
            'fo_perplexity': fo_score['perplexity']
        })
    
    # Create DataFrame and sort by backward model score (descending)
    results_df = pd.DataFrame(results)
    results_df = results_df.sort_values('ba_score', ascending=False).reset_index(drop=True)
    
    # Set display options for better readability
    pd.set_option('display.max_colwidth', 70)
    pd.set_option('display.float_format', '{:.4f}'.format)
    
    return results_df

# Example usage:
citation_analysis = analyze_first_sample_citations(pd.DataFrame(dataset), 
                                                  ba_model, ba_tokenizer, 
                                                  fo_model, fo_tokenizer)
citation_analysis

In [None]:
def evaluate_citations(results, embedding_model_name='all-MiniLM-L6-v2'):
    """
    Evaluate citation quality using multiple metrics (ROUGE, embeddings, TF-IDF).
    
    Args:
        results: List of dictionaries with citation results
        embedding_model_name: Name of the sentence-transformers model to use
    
    Returns:
        Enhanced results with additional evaluation metrics
    """
    # Initialize ROUGE scorer
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    
    # Initialize sentence transformer for embeddings
    embedding_model = SentenceTransformer(embedding_model_name, cache_folder='.cache/')
    
    # Initialize TF-IDF vectorizer
    tfidf_vectorizer = TfidfVectorizer(stop_words='english')
    
    # Process all highlights and citations to prepare TF-IDF matrix
    all_texts = []
    for result in results:
        all_texts.append(result['highlight'])
        all_texts.append(result['ba_citation'])
        all_texts.append(result['fo_citation'])
    
    # Fit the TF-IDF vectorizer
    tfidf_matrix = tfidf_vectorizer.fit_transform(all_texts)
    
    enhanced_results = []
    
    for i, result in enumerate(results):
        # Get text indices for the current result
        highlight_idx = i * 3
        ba_citation_idx = i * 3 + 1
        fo_citation_idx = i * 3 + 2
        
        # Get text for the current result
        highlight = result['highlight']
        ba_citation = result['ba_citation']
        fo_citation = result['fo_citation']
        
        # Calculate embeddings
        highlight_emb = embedding_model.encode([highlight])[0]
        ba_citation_emb = embedding_model.encode([ba_citation])[0]
        fo_citation_emb = embedding_model.encode([fo_citation])[0]
        
        # Calculate embedding similarity (cosine)
        ba_emb_similarity = cosine_similarity(
            highlight_emb.reshape(1, -1), 
            ba_citation_emb.reshape(1, -1)
        )[0][0]
        
        fo_emb_similarity = cosine_similarity(
            highlight_emb.reshape(1, -1), 
            fo_citation_emb.reshape(1, -1)
        )[0][0]
        
        # Calculate TF-IDF similarity
        ba_tfidf_similarity = cosine_similarity(
            tfidf_matrix[highlight_idx], 
            tfidf_matrix[ba_citation_idx]
        )[0][0]
        
        fo_tfidf_similarity = cosine_similarity(
            tfidf_matrix[highlight_idx], 
            tfidf_matrix[fo_citation_idx]
        )[0][0]
        
        # Calculate ROUGE scores
        ba_rouge = scorer.score(highlight, ba_citation)
        fo_rouge = scorer.score(highlight, fo_citation)
        
        # Create enhanced result with all metrics
        enhanced_result = result.copy()
        
        # Add backward model metrics
        enhanced_result.update({
            'ba_emb_similarity': ba_emb_similarity,
            'ba_tfidf_similarity': ba_tfidf_similarity,
            'ba_rouge1_precision': ba_rouge['rouge1'].precision,
            'ba_rouge1_recall': ba_rouge['rouge1'].recall,
            'ba_rouge1_fmeasure': ba_rouge['rouge1'].fmeasure,
            'ba_rouge2_fmeasure': ba_rouge['rouge2'].fmeasure,
            'ba_rougeL_fmeasure': ba_rouge['rougeL'].fmeasure,
        })
        
        # Add forward model metrics
        enhanced_result.update({
            'fo_emb_similarity': fo_emb_similarity,
            'fo_tfidf_similarity': fo_tfidf_similarity,
            'fo_rouge1_precision': fo_rouge['rouge1'].precision,
            'fo_rouge1_recall': fo_rouge['rouge1'].recall,
            'fo_rouge1_fmeasure': fo_rouge['rouge1'].fmeasure,
            'fo_rouge2_fmeasure': fo_rouge['rouge2'].fmeasure,
            'fo_rougeL_fmeasure': fo_rouge['rougeL'].fmeasure,
        })
        
        enhanced_results.append(enhanced_result)
    
    return enhanced_results

In [None]:
def display_evaluation_results(results, metrics_to_show=None):
    """
    Display evaluation results in a DataFrame.
    
    Args:
        results: List of dictionaries with enhanced evaluation metrics
        metrics_to_show: List of metric columns to display (if None, shows a default set)
    
    Returns:
        DataFrame with evaluation metrics
    """
    if metrics_to_show is None:
        # Default metrics to show
        metrics_to_show = [
            'highlight', 'ba_citation', 'fo_citation',
            'ba_perplexity', 'fo_perplexity',
            'ba_emb_similarity', 'fo_emb_similarity',
            'ba_rougeL_fmeasure', 'fo_rougeL_fmeasure'
        ]
    
    # Create DataFrame
    df = pd.DataFrame(results)
    
    # Select columns to display
    display_df = df[metrics_to_show]
    
    return display_df

# Example usage:
enhanced_results = evaluate_citations(results)
display_df = display_evaluation_results(enhanced_results)
display_df

In [None]:
# After you have your enhanced_results
df = pd.DataFrame(enhanced_results)

# Create two separate dataframes - one for each model type
ba_metrics = {col.replace('ba_', ''): df[col].mean() for col in df.columns 
              if col.startswith('ba_') and col != 'ba_citation'}
              
fo_metrics = {col.replace('fo_', ''): df[col].mean() for col in df.columns 
              if col.startswith('fo_') and col != 'fo_citation'}

# Combine into single comparison dataframe
comparison = pd.DataFrame({
    'Backward Model': ba_metrics,
    'Forward Model': fo_metrics
})

# Display the properly formatted comparison
comparison.round(4)