# ANLI with LLM

You have to implement in this notebook a better ANLI classifier using an LLM.
This classifier must be implemented using DSPy.


In [None]:
# Configure the DSPy environment with the language model - for grok the parameters must be:
# env variable should be in os.environ['XAI_API_KEY']
# "xai/grok-3-mini"
import os
import dspy
from typing import Literal
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import evaluate
import numpy as np
from tqdm.auto import tqdm


os.environ["XAI_API_KEY"] = ""

lm = dspy.LM('xai/grok-3-mini', api_key=os.environ['XAI_API_KEY'])
dspy.configure(lm=lm)

In [9]:
# Load sentence transformer for similarity measurement
print("Loading sentence transformer model...")
similarity_model = SentenceTransformer("all-MiniLM-L6-v2")
print("✓ Sentence transformer loaded")

Loading sentence transformer model...
✓ Sentence transformer loaded


## Load ANLI dataset

In [2]:
dataset = load_dataset("facebook/anli")
dataset = dataset.filter(lambda x: x['reason'] != None and x['reason'] != "")

## Evaluate Metrics

Let's use the huggingface `evaluate` package to compute the performance of the baseline.


In [4]:
from evaluate import load

accuracy = load("accuracy")
precision = load("precision")
recall = load("recall")
f1 = load("f1")


In [5]:
import evaluate
clf_metrics = evaluate.combine(["accuracy", "f1", "precision", "recall"])

## Your Turn

Compute the classification metrics on the baseline LLM model on each test section of the ANLI dataset for samples that have a non-empty 'reason' field.

You also must show a comparison between the DeBERTa baseline model and this LLM baseline model. The comparison metric should compute the agreement between the two models:
* On how many samples they are both correct [Correct]
* On how many samples Model1 is correct and Model2 is incorrect [Correct1]
* On how many samples Model1 is incorrect and Model2 is correct [Correct2]
* On how many samples both are incorrect [Incorrect]

In [6]:
## Task 1.4: Explanation CoT LLM for ANLI

### Step 1: Implement DSPy Signatures and Modules

# Define the signatures for both strategies
class JointPromptSignature(dspy.Signature):
    """Generate both an explanation and a label for the NLI task in one step."""
    premise = dspy.InputField()
    hypothesis = dspy.InputField()
    human_reason = dspy.InputField()
    label: Literal['entailment', 'neutral', 'contradiction'] = dspy.OutputField()

class PipelineExplanationSignature(dspy.Signature):
    """Generate an explanation for the relationship between premise and hypothesis."""
    premise = dspy.InputField()
    hypothesis = dspy.InputField()
    human_reason = dspy.InputField()
    explanation = dspy.OutputField(desc="Step by step thought process to classify the relationship between the premise and hypothesis: 'entailment', 'contradiction', 'neutral'.")

class PipelineResponseSignature(dspy.Signature):
    """Given an explanation, determine the entailment label."""
    explanation = dspy.InputField(desc="Explanation of the relationship")
    label: Literal['entailment', 'neutral', 'contradiction'] = dspy.OutputField()

# Pipeline approach
class PipelineModule(dspy.Module):
    def __init__(self):
        super().__init__()
        self.explanation_step = dspy.Predict(PipelineExplanationSignature)
        self.label_step = dspy.Predict(PipelineResponseSignature)
    
    def forward(self, premise, hypothesis, human_reason):
        explanation_result = self.explanation_step(
            premise=premise, 
            hypothesis=hypothesis, 
            human_reason=human_reason
        )
        label_result = self.label_step(explanation=explanation_result.explanation)
        
        return dspy.Prediction(
            explanation=explanation_result.explanation,
            label=label_result.label
        )

# Initialize modules
print("Initializing modules...")
joint_cot = dspy.ChainOfThought(JointPromptSignature)
pipeline_module = PipelineModule()
print("✓ Modules initialized")



Initializing modules...
✓ Modules initialized


In [27]:
### Step 2: Utility Functions for Similarity Analysis

def compute_similarity_scores(premise, hypothesis, human_reason, predicted_explanation, similarity_model):
    """
    Compute similarity scores between different text combinations.
    Returns a dictionary with all possible combinations.
    """
    # Combine premise and hypothesis
    premise_hypothesis = f"{premise} {hypothesis}"
    
    # Encode all texts
    texts = [premise_hypothesis, human_reason, predicted_explanation]
    embeddings = similarity_model.encode(texts)
    
    # Compute similarities
    similarities = similarity_model.similarity(embeddings, embeddings).numpy()
    
    # Extract individual similarities
    ph_hr_sim = similarities[0, 1]  # premise+hypothesis vs human_reason
    ph_pe_sim = similarities[0, 2]  # premise+hypothesis vs predicted_explanation
    hr_pe_sim = similarities[1, 2]  # human_reason vs predicted_explanation
    
    return {
        'A': ph_hr_sim,      # (premise, hypothesis) vs human_reason
        'B': hr_pe_sim,      # human_reason vs predicted_explanation  
        'C': ph_pe_sim,      # (premise, hypothesis) vs predicted_explanation
        'A+B': ph_hr_sim + hr_pe_sim,
        'A+C': ph_hr_sim + ph_pe_sim,
        'B+C': hr_pe_sim + ph_pe_sim,
        'A+B+C': ph_hr_sim + hr_pe_sim + ph_pe_sim
    }

def normalize_label(label):
    """
    Normalize label to standard format, handling both string and integer inputs.
    """
    # Handle integer labels (0, 1, 2)
    if isinstance(label, int):
        label_map = {0: 'entailment', 1: 'neutral', 2: 'contradiction'}
        return label_map.get(label, str(label))
    
    # Handle string labels
    if isinstance(label, str):
        label = label.lower().strip()
        # Map common variations
        if 'entail' in label or label in ['yes', 'true', '0']:
            return 'entailment'
        elif 'neutral' in label or label in ['unknown', 'maybe', '1']:
            return 'neutral'
        elif 'contradic' in label or label in ['no', 'false', '2']:
            return 'contradiction'
        # Direct matches
        if label in ['entailment', 'neutral', 'contradiction']:
            return label
    
    # If none match, return as string for debugging
    return str(label).lower().strip()

def evaluate_joint_cot_with_similarity(examples, joint_cot, similarity_model):
    """
    Evaluate joint ChainOfThought predictions and compute similarity metrics.
    """
    results = []
    correct_predictions = 0
    
    print(f"Evaluating Joint CoT on {len(examples)} examples...")
    
    for example in tqdm(examples):
        try:
            # Get prediction from ChainOfThought
            prediction = joint_cot(
                premise=example['premise'],
                hypothesis=example['hypothesis'],
                human_reason=example['reason']
            )
            
            # Extract explanation (reasoning from CoT)
            predicted_explanation = getattr(prediction, 'reasoning', "")
            
            # Normalize labels for comparison
            predicted_label = normalize_label(prediction.label)
            true_label = normalize_label(example['label'])
            
            # Check if prediction is correct
            is_correct = predicted_label == true_label
            if is_correct:
                correct_predictions += 1
            
            # Compute similarity scores
            similarity_scores = compute_similarity_scores(
                example['premise'],
                example['hypothesis'], 
                example['reason'],
                predicted_explanation,
                similarity_model
            )
            
            results.append({
                'premise': example['premise'],
                'hypothesis': example['hypothesis'],
                'human_reason': example['reason'],
                'predicted_explanation': predicted_explanation,
                'predicted_label': predicted_label,
                'true_label': true_label,
                'is_correct': is_correct,
                **similarity_scores
            })
            
        except Exception as e:
            print(f"Error processing example: {e}")
            print(f"Prediction label type: {type(prediction.label)}, value: {prediction.label}")
            continue
    
    accuracy = correct_predictions / len(results) if results else 0
    print(f"✓ Joint CoT Accuracy: {accuracy:.3f}")
    
    return results, accuracy

def evaluate_pipeline_with_similarity(examples, pipeline_module, similarity_model):
    """
    Evaluate pipeline module predictions and compute similarity metrics.
    """
    results = []
    correct_predictions = 0
    
    print(f"Evaluating Pipeline Module on {len(examples)} examples...")
    
    for example in tqdm(examples):
        try:
            # Get prediction from pipeline module
            prediction = pipeline_module(
                premise=example['premise'],
                hypothesis=example['hypothesis'],
                human_reason=example['reason']
            )
            
            # Extract explanation from pipeline module
            predicted_explanation = getattr(prediction, 'explanation', "")
            
            # Normalize labels for comparison
            predicted_label = normalize_label(prediction.label)
            true_label = normalize_label(example['label'])
            
            # Check if prediction is correct
            is_correct = predicted_label == true_label
            if is_correct:
                correct_predictions += 1
            
            # Compute similarity scores
            similarity_scores = compute_similarity_scores(
                example['premise'],
                example['hypothesis'], 
                example['reason'],
                predicted_explanation,
                similarity_model
            )
            
            results.append({
                'premise': example['premise'],
                'hypothesis': example['hypothesis'],
                'human_reason': example['reason'],
                'predicted_explanation': predicted_explanation,
                'predicted_label': predicted_label,
                'true_label': true_label,
                'is_correct': is_correct,
                **similarity_scores
            })
            
        except Exception as e:
            print(f"Error processing example: {e}")
            print(f"Prediction label type: {type(prediction.label)}, value: {prediction.label}")
            continue
    
    accuracy = correct_predictions / len(results) if results else 0
    print(f"✓ Pipeline Module Accuracy: {accuracy:.3f}")
    
    return results, accuracy

In [22]:
# Get a sample from dev_r3 for testing
print("Loading ANLI dev_r3 dataset...")
dev_data = dataset['dev_r3']
print(f"Total examples with reasons: {len(dev_data)}")

# Take a smaller sample for initial testing (to manage costs)
sample_size = 50  # Adjust based on your budget
test_examples = dev_data.select(range(min(sample_size, len(dev_data))))

### Step 3.1: Debug Label Issues

# Let's first inspect what the actual labels look like
print("=== DEBUGGING LABEL FORMATS ===")
sample_examples = test_examples.select(range(3))

print("\nTrue labels from dataset:")
for i, example in enumerate(sample_examples):
    print(f"Example {i}: '{example['label']}' (type: {type(example['label'])})")

print("\nTesting joint_cot predictions:")
for i, example in enumerate(sample_examples):
    try:
        prediction = joint_cot(
            premise=example['premise'],
            hypothesis=example['hypothesis'],
            human_reason=example['reason']
        )
        print(f"Example {i}:")
        print(f"  Raw prediction.label: '{prediction.label}' (type: {type(prediction.label)})")
        print(f"  Normalized predicted: '{normalize_label(prediction.label)}'")
        print(f"  Normalized true: '{normalize_label(example['label'])}'")
        print(f"  Match: {normalize_label(prediction.label) == normalize_label(example['label'])}")
        
        # Debug rationale access
        print(f"  Prediction attributes: {dir(prediction)}")
        reasoning = getattr(prediction, 'reasoning', 'NO REASONING')
        print(f"  Reasoning: {reasoning[:100] if reasoning != 'NO REASONING' else reasoning}...")
        print()
    except Exception as e:
        print(f"Example {i} failed: {e}")

Loading ANLI dev_r3 dataset...
Total examples with reasons: 1200
=== DEBUGGING LABEL FORMATS ===

True labels from dataset:
Example 0: '0' (type: <class 'int'>)
Example 1: '0' (type: <class 'int'>)
Example 2: '0' (type: <class 'int'>)

Testing joint_cot predictions:
Example 0:
  Raw prediction.label: 'neutral' (type: <class 'str'>)
  Normalized predicted: 'neutral'
  Normalized true: 'entailment'
  Match: False
  Prediction attributes: ['__add__', '__class__', '__contains__', '__delattr__', '__delitem__', '__dict__', '__dir__', '__doc__', '__eq__', '__float__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__getitem__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__iter__', '__le__', '__len__', '__lt__', '__module__', '__ne__', '__new__', '__radd__', '__reduce__', '__reduce_ex__', '__repr__', '__rtruediv__', '__setattr__', '__setitem__', '__sizeof__', '__str__', '__subclasshook__', '__truediv__', '__weakref__', '_completions', '_lm_usage', '_st

In [32]:
### Step 3: Run Initial Evaluation and Correlation Analysis

# Load and prepare test data
print("Loading ANLI dev_r3 dataset...")
dev_data = dataset['dev_r3']
print(f"Total examples with reasons: {len(dev_data)}")

# Take a sample for testing (adjust based on your budget)
sample_size = 200
test_examples = dev_data.select(range(min(sample_size, len(dev_data))))
print(f"Using {len(test_examples)} examples for testing")

# Evaluate both approaches
print("\n" + "="*50)
print("EVALUATING JOINT CHAIN-OF-THOUGHT")
print("="*50)
joint_results, joint_accuracy = evaluate_joint_cot_with_similarity(
    test_examples, joint_cot, similarity_model
)

print("\n" + "="*50)  
print("EVALUATING PIPELINE MODULE")
print("="*50)
pipeline_results, pipeline_accuracy = evaluate_pipeline_with_similarity(
    test_examples, pipeline_module, similarity_model
)

# Summary
print("\n" + "="*60)
print("EVALUATION SUMMARY")
print("="*60)
print(f"Joint CoT Accuracy: {joint_accuracy:.3f}")
print(f"Pipeline Module Accuracy: {pipeline_accuracy:.3f}")
print(f"Performance difference: {abs(pipeline_accuracy - joint_accuracy):.3f}")
print("\nBoth approaches are ready for optimization and further analysis!")

Loading ANLI dev_r3 dataset...
Total examples with reasons: 1200
Using 200 examples for testing

EVALUATING JOINT CHAIN-OF-THOUGHT
Evaluating Joint CoT on 200 examples...


  0%|          | 0/200 [00:00<?, ?it/s]

✓ Joint CoT Accuracy: 0.885

EVALUATING PIPELINE MODULE
Evaluating Pipeline Module on 200 examples...


  0%|          | 0/200 [00:00<?, ?it/s]

✓ Pipeline Module Accuracy: 0.900

EVALUATION SUMMARY
Joint CoT Accuracy: 0.885
Pipeline Module Accuracy: 0.900
Performance difference: 0.015

Both approaches are ready for optimization and further analysis!


In [35]:
def analyze_correct_predictions_std(results, module_name):
    """
    Analyze standard deviation of similarity metrics A, B, C for both correctly and incorrectly predicted samples.
    """
    print(f"\n=== STD ANALYSIS FOR PREDICTIONS - {module_name.upper()} ===")
    
    if not results:
        print("No results to analyze")
        return None
    
    df = pd.DataFrame(results)
    
    # Split into correct and incorrect predictions
    correct_predictions = df[df['is_correct'] == True]
    incorrect_predictions = df[df['is_correct'] == False]
    
    total_samples = len(df)
    correct_count = len(correct_predictions)
    incorrect_count = len(incorrect_predictions)
    
    print(f"Total samples: {total_samples}")
    print(f"Correct predictions: {correct_count} ({correct_count/total_samples:.1%})")
    print(f"Incorrect predictions: {incorrect_count} ({incorrect_count/total_samples:.1%})")
    
    metrics = ['A', 'B', 'C']
    
    # Analyze correct predictions
    print(f"\nCORRECT Predictions - Metrics A, B, C:")
    print(f"{'Metric':<8} {'Mean':<8} {'Std Dev':<8}")
    print("-" * 24)
    
    correct_stats = {}
    for metric in metrics:
        if metric in correct_predictions.columns and len(correct_predictions) > 0:
            values = correct_predictions[metric]
            mean_val = values.mean()
            std_val = values.std()
            
            correct_stats[metric] = {
                'mean': mean_val,
                'std': std_val,
                'count': len(values)
            }
            
            print(f"{metric:<8} {mean_val:<8.4f} {std_val:<8.4f}")
    
    # Analyze incorrect predictions
    incorrect_stats = {}
    if incorrect_count > 0:
        print(f"\nINCORRECT Predictions - Metrics A, B, C:")
        print(f"{'Metric':<8} {'Mean':<8} {'Std Dev':<8}")
        print("-" * 24)
        
        for metric in metrics:
            if metric in incorrect_predictions.columns:
                values = incorrect_predictions[metric]
                mean_val = values.mean()
                std_val = values.std()
                
                incorrect_stats[metric] = {
                    'mean': mean_val,
                    'std': std_val,
                    'count': len(values)
                }
                
                print(f"{metric:<8} {mean_val:<8.4f} {std_val:<8.4f}")
        
        # Comparison
        print(f"\nCORRECT vs INCORRECT Comparison:")
        print(f"{'Metric':<8} {'Correct':<12} {'Incorrect':<12} {'Mean Diff':<10}")
        print(f"{'':8} {'Mean±Std':<12} {'Mean±Std':<12} {'(C-I)':<10}")
        print("-" * 50)
        
        for metric in metrics:
            if metric in correct_stats and metric in incorrect_stats:
                c_mean = correct_stats[metric]['mean']
                c_std = correct_stats[metric]['std']
                i_mean = incorrect_stats[metric]['mean']
                i_std = incorrect_stats[metric]['std']
                mean_diff = c_mean - i_mean
                
                print(f"{metric:<8} {c_mean:.3f}±{c_std:.3f}   {i_mean:.3f}±{i_std:.3f}   {mean_diff:+.4f}")
    
    else:
        print(f"\nNo incorrect predictions to analyze!")
    
    return correct_stats, incorrect_stats

# Run the analysis on both approaches
print("\n" + "="*60)
print("STANDARD DEVIATION ANALYSIS FOR CORRECT PREDICTIONS")
print("="*60)

joint_correct_stats, joint_incorrect_stats = analyze_correct_predictions_std(joint_results, "Joint CoT")
pipeline_correct_stats, pipeline_incorrect_stats = analyze_correct_predictions_std(pipeline_results, "Pipeline Module")

# Summary comparison
if joint_correct_stats and pipeline_correct_stats:
    print(f"\n=== APPROACH COMPARISON - CORRECT PREDICTIONS ===")
    print(f"Standard Deviations for Correct Predictions:")
    print(f"{'Metric':<8} {'Joint CoT':<12} {'Pipeline':<12}")
    print("-" * 32)
    
    for metric in ['A', 'B', 'C']:
        if metric in joint_correct_stats and metric in pipeline_correct_stats:
            joint_std = joint_correct_stats[metric]['std']
            pipeline_std = pipeline_correct_stats[metric]['std']
            
            print(f"{metric:<8} {joint_std:<12.4f} {pipeline_std:<12.4f}")
    
    print(f"\n=== APPROACH COMPARISON - INCORRECT PREDICTIONS ===")
    if joint_incorrect_stats and pipeline_incorrect_stats:
        print(f"Standard Deviations for Incorrect Predictions:")
        print(f"{'Metric':<8} {'Joint CoT':<12} {'Pipeline':<12}")
        print("-" * 32)
        
        for metric in ['A', 'B', 'C']:
            if metric in joint_incorrect_stats and metric in pipeline_incorrect_stats:
                joint_std = joint_incorrect_stats[metric]['std']
                pipeline_std = pipeline_incorrect_stats[metric]['std']
                
                print(f"{metric:<8} {joint_std:<12.4f} {pipeline_std:<12.4f}")
    else:
        print("Insufficient incorrect predictions for comparison")
    
    print(f"\nLower std dev = more consistent similarity scores")
    print(f"Mean differences show which metrics best distinguish correct/incorrect predictions")



STANDARD DEVIATION ANALYSIS FOR CORRECT PREDICTIONS

=== STD ANALYSIS FOR PREDICTIONS - JOINT COT ===
Total samples: 200
Correct predictions: 177 (88.5%)
Incorrect predictions: 23 (11.5%)

CORRECT Predictions - Metrics A, B, C:
Metric   Mean     Std Dev 
------------------------
A        0.3662   0.1941  
B        0.5790   0.1622  
C        0.5931   0.1494  

INCORRECT Predictions - Metrics A, B, C:
Metric   Mean     Std Dev 
------------------------
A        0.3789   0.1956  
B        0.5412   0.1646  
C        0.6127   0.1309  

CORRECT vs INCORRECT Comparison:
Metric   Correct      Incorrect    Mean Diff 
         Mean±Std     Mean±Std     (C-I)     
--------------------------------------------------
A        0.366±0.194   0.379±0.196   -0.0127
B        0.579±0.162   0.541±0.165   +0.0379
C        0.593±0.149   0.613±0.131   -0.0196

=== STD ANALYSIS FOR PREDICTIONS - PIPELINE MODULE ===
Total samples: 200
Correct predictions: 180 (90.0%)
Incorrect predictions: 20 (10.0%)

CORRECT 

In [33]:
### Step 4: Correlation Analysis

from scipy.stats import spearmanr
import pandas as pd

def analyze_correlations(results, module_name):
    """
    Analyze correlations between similarity metrics and accuracy.
    """
    print(f"\n=== CORRELATION ANALYSIS FOR {module_name.upper()} ===")
    
    if not results:
        print("No results to analyze")
        return None, None
    
    df = pd.DataFrame(results)
    
    # Define all similarity metrics
    similarity_metrics = ['A', 'B', 'C', 'A+B', 'A+C', 'B+C', 'A+B+C']
    
    # Compute correlations
    correlations = {}
    for metric in similarity_metrics:
        if metric in df.columns:
            corr, p_value = spearmanr(df[metric], df['is_correct'])
            correlations[metric] = {'correlation': corr, 'p_value': p_value}
    
    # Sort by correlation strength
    sorted_correlations = sorted(correlations.items(), 
                               key=lambda x: abs(x[1]['correlation']), 
                               reverse=True)
    
    print("\nSpearman Correlations (sorted by absolute correlation):")
    print("-" * 60)
    for metric, stats in sorted_correlations:
        print(f"{metric:>6}: {stats['correlation']:>8.4f} (p={stats['p_value']:.4f})")
    
    # Average scores for A, B, C alone
    print(f"\nAverage Similarity Scores:")
    print("-" * 30)
    for metric in ['A', 'B', 'C']:
        if metric in df.columns:
            avg_score = df[metric].mean()
            print(f"{metric}: {avg_score:.4f}")
    
    return sorted_correlations, df

# Run correlation analysis
joint_correlations, joint_df = analyze_correlations(joint_results, "Joint CoT")
pipeline_correlations, pipeline_df = analyze_correlations(pipeline_results, "Pipeline Module")



=== CORRELATION ANALYSIS FOR JOINT COT ===

Spearman Correlations (sorted by absolute correlation):
------------------------------------------------------------
     B:   0.0704 (p=0.3216)
   B+C:   0.0295 (p=0.6789)
     C:  -0.0254 (p=0.7213)
   A+C:  -0.0251 (p=0.7241)
   A+B:   0.0219 (p=0.7587)
     A:  -0.0156 (p=0.8264)
 A+B+C:   0.0004 (p=0.9954)

Average Similarity Scores:
------------------------------
A: 0.3676
B: 0.5747
C: 0.5953

=== CORRELATION ANALYSIS FOR PIPELINE MODULE ===

Spearman Correlations (sorted by absolute correlation):
------------------------------------------------------------
     C:  -0.0990 (p=0.1630)
   A+C:  -0.0947 (p=0.1823)
     A:  -0.0748 (p=0.2927)
 A+B+C:  -0.0554 (p=0.4357)
   A+B:  -0.0315 (p=0.6583)
     B:   0.0289 (p=0.6849)
   B+C:  -0.0289 (p=0.6849)

Average Similarity Scores:
------------------------------
A: 0.3676
B: 0.4698
C: 0.5090


In [31]:
### Step 5: Top Performers Analysis

def analyze_top_performers(df, correlations, module_name, top_n=20):
    """
    Analyze top performers for the best correlation metrics.
    """
    if correlations is None or df is None:
        return
        
    print(f"\n=== TOP PERFORMERS ANALYSIS FOR {module_name.upper()} ===")
    
    # Get top 3 metrics
    top_3_metrics = [corr[0] for corr in correlations[:3]]
    
    for i, metric in enumerate(top_3_metrics, 1):
        print(f"\n{i}. TOP {top_n} EXAMPLES FOR METRIC '{metric}' (correlation: {correlations[i-1][1]['correlation']:.4f})")
        print("-" * 80)
        
        # Sort by metric score and get top performers
        top_examples = df.nlargest(top_n, metric)
        
        for idx, (_, row) in enumerate(top_examples.iterrows(), 1):
            correct_indicator = "✓" if row['is_correct'] else "✗"
            print(f"{idx:2d}. {correct_indicator} Score: {row[metric]:.4f} | "
                  f"Pred: {row['predicted_label']} | True: {row['true_label']}")
            print(f"    Premise: {row['premise']}")
            print(f"    Hypothesis: {row['hypothesis']}")
            print(f"    Predicted Explanation: {row['predicted_explanation']}")
            print()

# Analyze top performers for both modules
if joint_correlations:
    analyze_top_performers(joint_df, joint_correlations, "Joint Module")

if pipeline_correlations:
    analyze_top_performers(pipeline_df, pipeline_correlations, "Pipeline Module")

print("\n" + "="*60)
print("INITIAL ANALYSIS COMPLETE")
print("="*60)
print(f"Joint Module Accuracy: {joint_accuracy:.3f}")
print(f"Pipeline Module Accuracy: {pipeline_accuracy:.3f}")
print("\nReady for next steps and further optimization!")


=== TOP PERFORMERS ANALYSIS FOR JOINT MODULE ===

1. TOP 20 EXAMPLES FOR METRIC 'C' (correlation: -0.1501)
--------------------------------------------------------------------------------
 1. ✓ Score: 0.8202 | Pred: entailment | True: entailment
    Premise: * India's benchmark 10-year bond yield up 1 basis point at 8.16 percent as some traders cut positions following allotment of securities at the 130 billion rupees debt sale. * However, a sharper rise was averted with traders preferring to stay on the sidelines ahead of the inflation data due on Oct. 15, which will offer cues on the likely central bank policy action. * The 10-year bond yield seen ranged between 8.12 and 8.18 percent until inflation data, traders say. (swati.bhat@thomsonreuters.com/; swati.bhat.thomsonreuters.com@reuters.net)
    Hypothesis: India's benchmark 10-year bond yield is under 20 percent.
    Predicted Explanation: The premise states that India's benchmark 10-year bond yield is at 8.16 percent, which is a s