In [1]:
import os
import json
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from collections import Counter, defaultdict
from typing import List, Dict, Tuple
import pandas as pd
from tqdm.notebook import tqdm
import random



In [None]:
# Root directory of the analysis outputs, given relative to the notebook's
# working directory. The loaders below walk it as <problem>/<seed>/ subfolders.
analysis_dir = Path("analysis")

def load_analyzed_chunks(analysis_dir: Path) -> List[Dict]:
    """
    Load all analyzed chunks from the analysis directory.

    The directory is walked as ``analysis_dir/<problem>/<seed>/``. Each chunk
    read from a seed directory is tagged in place with ``problem_id``, ``seed``
    and ``relative_position`` (its ``index`` divided by the number of chunks in
    that file). Chunks whose category is 'Unknown' or missing are dropped.

    Args:
        analysis_dir: Path to the analysis directory

    Returns:
        List of dictionaries containing chunk data
    """
    all_chunks = []

    # Iterate through all problem directories
    for problem_dir in tqdm([d for d in analysis_dir.iterdir() if d.is_dir()]):
        problem_id = problem_dir.name

        # Iterate through seed directories
        for seed_dir in [d for d in problem_dir.iterdir() if d.is_dir()]:
            seed = seed_dir.name

            # Prefer chunks_analyzed.json, the file the other loaders in this
            # notebook read; fall back to the name this loader used before.
            chunks_file = seed_dir / "chunks_analyzed.json"
            if not chunks_file.exists():
                chunks_file = seed_dir / "chunks.json"
            if not chunks_file.exists():
                continue

            with open(chunks_file, 'r', encoding='utf-8') as f:
                chunks_data = json.load(f)

            # Positions are normalised by the full chunk count, i.e. before the
            # 'Unknown' chunks are filtered out below.
            n_chunks = len(chunks_data)
            for chunk in chunks_data:
                chunk['problem_id'] = problem_id
                chunk['seed'] = seed
                chunk['relative_position'] = chunk['index'] / n_chunks

            # A chunk without a category is treated like an 'Unknown' one
            # instead of aborting the whole load with a KeyError.
            all_chunks.extend(
                chunk for chunk in chunks_data
                if chunk.get('category', 'Unknown') != 'Unknown'
            )

    return all_chunks

# Load all analyzed chunks once; every analysis cell below reuses `all_chunks`.
all_chunks = load_analyzed_chunks(analysis_dir)
# Report the number of chunks kept and the number of distinct problems they cover.
print(f"Loaded {len(all_chunks)} chunks from {len(set(chunk['problem_id'] for chunk in all_chunks))} problems")

In [None]:
def plot_category_distribution(chunks: List[Dict]):
    """
    Draw a bar chart of how often each reasoning category occurs.

    Each bar is annotated with the category's share of all chunks.

    Args:
        chunks: List of chunk dictionaries

    Returns:
        DataFrame with 'Category', 'Count' and 'Percentage' columns,
        most frequent category first
    """
    # Tally the categories and rank them, most frequent first
    tally = Counter(chunk['category'] for chunk in chunks)
    ranked = sorted(tally.items(), key=lambda item: item[1], reverse=True)

    df = pd.DataFrame(ranked, columns=['Category', 'Count'])
    df['Percentage'] = df['Count'] / sum(df['Count']) * 100

    plt.figure(figsize=(12, 8))
    ax = sns.barplot(x='Category', y='Count', data=df)

    # Write each category's share just above its bar
    for bar_pos, (count, share) in enumerate(zip(df['Count'], df['Percentage'])):
        ax.text(bar_pos, count + 5, f"{share:.1f}%", ha='center')

    plt.xticks(rotation=45, ha='right')
    plt.title('Distribution of Reasoning Categories')
    plt.tight_layout()
    plt.show()

    return df

# Plot the category histogram and keep the summary table for later inspection.
category_distribution = plot_category_distribution(all_chunks)
print(category_distribution)

In [None]:
def compute_transition_matrix(chunks: List[Dict]):
    """
    Estimate how likely each reasoning category is to follow another.

    Chunks are grouped per (problem_id, seed), ordered by 'index', and every
    pair of neighbours in that ordering counts as one transition. Each row of
    the result is normalised by the row's total transition count; a category
    with no outgoing transitions gets a row of zeros.

    Args:
        chunks: List of chunk dictionaries

    Returns:
        Transition matrix as DataFrame (rows = from, columns = to)
    """
    labels = sorted({chunk['category'] for chunk in chunks})

    # Bucket the chunks by the run they belong to
    runs = defaultdict(list)
    for chunk in chunks:
        runs[(chunk['problem_id'], chunk['seed'])].append(chunk)

    # Tally neighbour-to-neighbour transitions inside every run
    counts = defaultdict(Counter)
    for run in runs.values():
        ordered = sorted(run, key=lambda c: c['index'])
        for current, following in zip(ordered, ordered[1:]):
            counts[current['category']][following['category']] += 1

    # Turn the tallies into row-normalised probabilities
    rows = []
    for source in labels:
        outgoing = counts[source]
        total = sum(outgoing.values())
        rows.append([outgoing[target] / total if total > 0 else 0 for target in labels])

    return pd.DataFrame(rows, index=labels, columns=labels)

def plot_transition_matrix(transition_matrix: pd.DataFrame, vmax: float = 0.5):
    """
    Plot transition matrix as heatmap.

    Args:
        transition_matrix: Transition matrix as DataFrame
        vmax: Upper end of the colour scale. Cells above it share the top
            colour; the annotations still show the real value. The default of
            0.5 keeps the previous appearance.
    """
    plt.figure(figsize=(14, 12))
    sns.heatmap(
        transition_matrix,
        annot=True,
        cmap='viridis',
        vmin=0,
        vmax=vmax,  # Capped below 1 by default for better color distribution
        fmt='.2f'
    )
    plt.title('Transition Probabilities Between Reasoning Categories')
    plt.xlabel('To Category')
    plt.ylabel('From Category')
    plt.tight_layout()
    plt.show()

# Compute the category-to-category transition probabilities and show them as a heatmap.
transition_matrix = compute_transition_matrix(all_chunks)
plot_transition_matrix(transition_matrix)

In [None]:
def analyze_category_positions(chunks: List[Dict]):
    """
    Analyze the relative positions of each category in the reasoning chain.

    Args:
        chunks: List of chunk dictionaries; each must provide 'category' and
            'relative_position'

    Returns:
        DataFrame with position statistics for each category, sorted by
        'Mean Position'. For empty input an empty DataFrame with the same
        columns is returned.
    """
    columns = [
        'Category', 'Count', 'Mean Position', 'Std Dev',
        'Min', 'Max', '25%', '50%', '75%',
    ]

    # Group positions by category
    category_positions = defaultdict(list)
    for chunk in chunks:
        category_positions[chunk['category']].append(chunk['relative_position'])

    # Calculate statistics
    stats = []
    for category, positions in category_positions.items():
        q25, q50, q75 = np.percentile(positions, [25, 50, 75])
        stats.append({
            'Category': category,
            'Count': len(positions),
            'Mean Position': np.mean(positions),
            'Std Dev': np.std(positions),
            'Min': np.min(positions),
            'Max': np.max(positions),
            '25%': q25,
            '50%': q50,
            '75%': q75
        })

    # Passing the column list keeps the sort key present even when `stats` is
    # empty, so empty input no longer raises KeyError in sort_values.
    df = pd.DataFrame(stats, columns=columns)
    df = df.sort_values('Mean Position')

    return df

def plot_category_positions(chunks: List[Dict]):
    """
    Show where in the reasoning chain each category tends to occur.

    Draws one box per category (ordered by mean relative position) and
    overlays the individual data points.

    Args:
        chunks: List of chunk dictionaries

    Returns:
        The statistics table produced by analyze_category_positions(chunks)
    """
    # Collect the relative positions seen for every category
    grouped = defaultdict(list)
    for chunk in chunks:
        grouped[chunk['category']].append(chunk['relative_position'])

    # Flatten into one row per observation for seaborn
    records = [
        {'Category': name, 'Relative Position': value}
        for name, values in grouped.items()
        for value in values
    ]
    df = pd.DataFrame(records)

    # Categories are laid out left-to-right by increasing mean position
    mean_by_category = df.groupby('Category')['Relative Position'].mean()
    category_order = mean_by_category.sort_values().index

    plt.figure(figsize=(14, 10))

    # Both layers share the same data mapping and category order
    shared = dict(x='Category', y='Relative Position', data=df, order=category_order)
    ax = sns.boxplot(**shared)
    sns.stripplot(size=4, color='black', alpha=0.3, **shared)

    plt.xticks(rotation=45, ha='right')
    plt.title('Distribution of Category Positions in Reasoning Chain')
    plt.ylabel('Relative Position (0=start, 1=end)')
    plt.tight_layout()
    plt.show()

    return analyze_category_positions(chunks)

# Plot per-category position distributions and keep the summary statistics table.
position_stats = plot_category_positions(all_chunks)
print(position_stats)

In [None]:
def analyze_backtracking(chunks: List[Dict]):
    """
    Analyze when backtracking occurs in the reasoning chain.

    Plots a histogram of the relative positions of 'Backtracking' chunks,
    prints summary statistics, and plots which categories come immediately
    before a backtracking chunk within the same (problem_id, seed) run.
    Prints a message and returns early when there is nothing to plot.

    Args:
        chunks: List of chunk dictionaries
    """
    # Filter for backtracking chunks
    backtracking_chunks = [chunk for chunk in chunks if chunk['category'] == 'Backtracking']

    if not backtracking_chunks:
        print("No backtracking chunks found in the dataset.")
        return

    # Get positions
    positions = [chunk['relative_position'] for chunk in backtracking_chunks]
    mean_position = np.mean(positions)

    # Plot histogram
    plt.figure(figsize=(10, 6))
    plt.hist(positions, bins=20, alpha=0.7, color='blue')
    plt.axvline(mean_position, color='red', linestyle='dashed', linewidth=2, label=f'Mean: {mean_position:.2f}')

    plt.title('Distribution of Backtracking in Reasoning Chain')
    plt.xlabel('Relative Position (0=start, 1=end)')
    plt.ylabel('Frequency')
    plt.legend()
    plt.grid(alpha=0.3)
    plt.tight_layout()
    plt.show()

    # Print statistics
    print(f"Backtracking Statistics:")
    print(f"  Count: {len(backtracking_chunks)}")
    print(f"  Mean Position: {mean_position:.3f}")
    print(f"  Std Dev: {np.std(positions):.3f}")
    print(f"  Min: {np.min(positions):.3f}")
    print(f"  Max: {np.max(positions):.3f}")

    # Group chunks by problem and seed
    problem_chunks = defaultdict(list)
    for chunk in chunks:
        key = (chunk['problem_id'], chunk['seed'])
        problem_chunks[key].append(chunk)

    # Find categories that precede backtracking. "Precede" means the previous
    # chunk in index order among the chunks passed in.
    preceding_categories = []
    for prob_chunks in problem_chunks.values():
        sorted_chunks = sorted(prob_chunks, key=lambda x: x['index'])
        for previous, current in zip(sorted_chunks, sorted_chunks[1:]):
            if current['category'] == 'Backtracking':
                preceding_categories.append(previous['category'])

    # Count preceding categories
    preceding_counts = Counter(preceding_categories)

    # Every backtracking chunk may be the first chunk of its run; unpacking an
    # empty most_common() below would then raise ValueError.
    if not preceding_counts:
        print("No backtracking chunk has a preceding chunk; skipping the predecessor plot.")
        return

    # Plot
    plt.figure(figsize=(10, 6))
    categories, counts = zip(*preceding_counts.most_common())

    # Calculate percentages
    total = sum(counts)
    percentages = [count / total * 100 for count in counts]

    # Create bars
    bars = plt.bar(categories, percentages)

    # Add percentage labels
    for bar, percentage in zip(bars, percentages):
        plt.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height() + 1,
            f"{percentage:.1f}%",
            ha='center'
        )

    plt.title('Categories That Precede Backtracking')
    plt.xlabel('Category')
    plt.ylabel('Percentage')
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.show()

# Analyze backtracking: position histogram, summary statistics, and predecessor categories.
analyze_backtracking(all_chunks)

In [None]:
def visualize_category_sequences(chunks: List[Dict], max_problems: int = 10):
    """
    Visualize the sequence of categories for each problem.

    A random subset of (problem_id, seed) runs is drawn as horizontal strips,
    one coloured block per chunk in index order.

    Args:
        chunks: List of chunk dictionaries
        max_problems: Maximum number of problems to visualize; fewer are shown
            when fewer runs are available
    """
    # Get unique categories and assign colors
    categories = sorted(set(chunk['category'] for chunk in chunks))
    color_map = dict(zip(categories, sns.color_palette("husl", len(categories))))

    # Group chunks by problem and seed
    problem_chunks = defaultdict(list)
    for chunk in chunks:
        key = (chunk['problem_id'], chunk['seed'])
        problem_chunks[key].append(chunk)

    # random.sample raises ValueError when asked for more items than exist,
    # and plt.subplots cannot create zero rows, so clamp and guard first.
    available = list(problem_chunks.keys())
    n_selected = min(max_problems, len(available))
    if n_selected <= 0:
        print("No problems available to visualize.")
        return

    # Select a subset of problems
    selected_problems = random.sample(available, n_selected)

    # squeeze=False always yields a 2-D array of axes, so the single-row case
    # needs no special handling.
    fig, axes_grid = plt.subplots(
        n_selected, 1, figsize=(15, n_selected * 1.5), squeeze=False
    )
    axes = axes_grid[:, 0]

    # Plot each problem
    for key, ax in zip(selected_problems, axes):
        problem_id, seed = key
        prob_chunks = sorted(problem_chunks[key], key=lambda x: x['index'])

        # Create colored blocks for each category
        for j, chunk in enumerate(prob_chunks):
            ax.barh(0, 1, left=j, color=color_map[chunk['category']], alpha=0.7)

        # Set labels
        ax.set_yticks([])
        ax.set_xlabel('Chunk Index')
        ax.set_title(f'Problem: {problem_id}, Seed: {seed}')
        ax.set_xlim(0, len(prob_chunks))

    # Create legend
    handles = [plt.Rectangle((0, 0), 1, 1, color=color_map[cat]) for cat in categories]
    fig.legend(handles, categories, loc='upper center', bbox_to_anchor=(0.5, 0),
               ncol=min(5, len(categories)), frameon=True)

    plt.tight_layout()
    plt.subplots_adjust(bottom=0.15 + 0.02 * min(5, len(categories)))
    plt.show()

# Visualize category sequences for a random sample of runs (unseeded, so the
# selection differs between executions).
visualize_category_sequences(all_chunks)

In [None]:
def analyze_cosine_vs_correlation(analysis_dir: Path):
    """
    Analyze the relationship between cosine similarity and Pearson correlation.

    For every ``analysis_dir/<problem>/<seed>/`` directory that contains
    ``chunk_cosine_similarity.npy``, ``chunk_correlation.npy`` and
    ``chunks_analyzed.json``, one record is produced per unordered chunk pair
    (upper triangle, self-comparisons excluded). Seed directories whose
    matrices differ in shape, are not square, or have more rows than there are
    chunks are skipped.

    Args:
        analysis_dir: Path to the analysis directory

    Returns:
        DataFrame with one row per chunk pair: problem_id, seed, chunk_i,
        chunk_j, category_i, category_j, abbrev_i, abbrev_j, cosine_sim,
        correlation
    """
    # Store data for all problems
    all_data = []

    # Iterate through all problem directories
    for problem_dir in tqdm([d for d in analysis_dir.iterdir() if d.is_dir()]):
        problem_id = problem_dir.name

        # Iterate through seed directories
        for seed_dir in [d for d in problem_dir.iterdir() if d.is_dir()]:
            seed = seed_dir.name

            # All three inputs are required
            cosine_path = seed_dir / "chunk_cosine_similarity.npy"
            corr_path = seed_dir / "chunk_correlation.npy"
            chunks_path = seed_dir / "chunks_analyzed.json"
            if not (cosine_path.exists() and corr_path.exists() and chunks_path.exists()):
                continue

            # Load matrices
            cosine_matrix = np.load(cosine_path)
            corr_matrix = np.load(corr_path)

            # Load chunk data
            with open(chunks_path, 'r', encoding='utf-8') as f:
                chunks_data = json.load(f)

            # Extract chunk labels; matrix row/column k is paired with the k-th
            # chunk in the file.
            categories = [chunk['category'] for chunk in chunks_data]
            abbreviations = [chunk['abbreviation'] for chunk in chunks_data]

            # Skip inputs that cannot be indexed pairwise rather than failing
            # with IndexError part-way through the directory walk.
            if cosine_matrix.shape != corr_matrix.shape:
                continue
            if cosine_matrix.ndim != 2 or cosine_matrix.shape[0] != cosine_matrix.shape[1]:
                continue
            n = cosine_matrix.shape[0]
            if n > len(chunks_data):
                continue

            # Extract all pairs (excluding self-comparisons)
            for i in range(n):
                for j in range(i + 1, n):  # Only upper triangle
                    all_data.append({
                        'problem_id': problem_id,
                        'seed': seed,
                        'chunk_i': i,
                        'chunk_j': j,
                        'category_i': categories[i],
                        'category_j': categories[j],
                        'abbrev_i': abbreviations[i],
                        'abbrev_j': abbreviations[j],
                        'cosine_sim': cosine_matrix[i, j],
                        'correlation': corr_matrix[i, j]
                    })

    # Convert to DataFrame
    df = pd.DataFrame(all_data)

    return df

# Load and analyze cosine vs correlation data: one row per chunk pair, collected
# across all problem/seed directories.
cosine_corr_df = analyze_cosine_vs_correlation(analysis_dir)

# Plot the relationship: one point per chunk pair
plt.figure(figsize=(10, 8))
plt.scatter(
    cosine_corr_df['cosine_sim'], 
    cosine_corr_df['correlation'],
    alpha=0.5,
    s=10
)

plt.xlabel('Cosine Similarity')
plt.ylabel('Pearson Correlation')
plt.title('Cosine Similarity vs Pearson Correlation for All Chunk Pairs')

# Add reference lines through the origin
plt.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
plt.axvline(x=0, color='gray', linestyle='--', alpha=0.5)

# Add a diagonal line for reference. plt.xlim() returns the current x-limits,
# so the y=x line spans the x-range the scatter plot already occupies.
x = np.linspace(*plt.xlim())
plt.plot(x, x, color='red', linestyle='--', alpha=0.5, label='y=x')

plt.legend()
plt.grid(alpha=0.3)
plt.tight_layout()
plt.show()

# Calculate correlation between cosine similarity and Pearson correlation
# (Series.corr computes Pearson's r by default)
corr = cosine_corr_df['cosine_sim'].corr(cosine_corr_df['correlation'])
print(f"Correlation between cosine similarity and Pearson correlation: {corr:.3f}")

In [None]:
def find_interesting_chunk_pairs(
    df: pd.DataFrame,
    n_pairs: int = 10,
    high_threshold: float = 0.7,
    low_threshold: float = 0.3,
):
    """
    Find interesting chunk pairs with:
    1. High cosine similarity but low correlation
    2. Low cosine similarity but high correlation

    The input DataFrame is not modified; the difference columns are added to a
    copy, and the returned frames carry them.

    Args:
        df: DataFrame with 'cosine_sim' and 'correlation' columns
        n_pairs: Number of pairs to find in each category
        high_threshold: A value counts as "high" when strictly above this
        low_threshold: A value counts as "low" when strictly below this

    Returns:
        Tuple (high_cos_low_corr, low_cos_high_corr) of DataFrames, each with
        at most n_pairs rows, ordered by decreasing gap between the two
        measures. Both include 'cos_minus_corr' and 'corr_minus_cos' columns.
    """
    # assign() returns a new frame, so the caller's DataFrame does not
    # silently gain columns.
    scored = df.assign(
        cos_minus_corr=df['cosine_sim'] - df['correlation'],
        corr_minus_cos=df['correlation'] - df['cosine_sim'],
    )

    # Find pairs with high cosine but low correlation
    high_cos_mask = (scored['cosine_sim'] > high_threshold) & (scored['correlation'] < low_threshold)
    high_cos_low_corr = (
        scored[high_cos_mask]
        .sort_values('cos_minus_corr', ascending=False)
        .head(n_pairs)
    )

    # Find pairs with low cosine but high correlation
    low_cos_mask = (scored['cosine_sim'] < low_threshold) & (scored['correlation'] > high_threshold)
    low_cos_high_corr = (
        scored[low_cos_mask]
        .sort_values('corr_minus_cos', ascending=False)
        .head(n_pairs)
    )

    return high_cos_low_corr, low_cos_high_corr

def plot_interesting_pairs(df: pd.DataFrame):
    """
    Create a scatter plot highlighting interesting pairs.

    All pairs are drawn in gray. The pairs returned by
    find_interesting_chunk_pairs are highlighted, and the two threshold
    regions (cosine > 0.7 with correlation < 0.3, and the reverse) are shaded.

    Args:
        df: DataFrame with cosine similarity and correlation data

    Returns:
        Tuple (high_cos_low_corr, low_cos_high_corr) as returned by
        find_interesting_chunk_pairs
    """
    # Find interesting pairs
    high_cos_low_corr, low_cos_high_corr = find_interesting_chunk_pairs(df)

    # Create plot
    plt.figure(figsize=(12, 10))

    # Plot all points
    plt.scatter(
        df['cosine_sim'],
        df['correlation'],
        alpha=0.2,
        s=10,
        color='gray',
        label='All pairs'
    )

    # Highlight high cosine, low correlation pairs
    plt.scatter(
        high_cos_low_corr['cosine_sim'],
        high_cos_low_corr['correlation'],
        alpha=1.0,
        s=100,
        color='red',
        marker='o',
        edgecolor='black',
        label='High cosine, low correlation'
    )

    # Highlight low cosine, high correlation pairs
    plt.scatter(
        low_cos_high_corr['cosine_sim'],
        low_cos_high_corr['correlation'],
        alpha=1.0,
        s=100,
        color='blue',
        marker='s',
        edgecolor='black',
        label='Low cosine, high correlation'
    )

    # Add reference lines
    plt.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
    plt.axvline(x=0, color='gray', linestyle='--', alpha=0.5)
    plt.plot([0, 1], [0, 1], color='green', linestyle='--', alpha=0.5)

    # Shade the two regions in DATA coordinates. The xmax/ymax arguments of
    # axhspan/axvspan are axes fractions, so passing 0.3 there only lines up
    # with the 0.3 threshold when the axis limits are exactly [0, 1].
    ax = plt.gca()
    x_lo, x_hi = ax.get_xlim()
    y_lo, y_hi = ax.get_ylim()
    # Low cosine (< 0.3), high correlation (0.7 .. 1.0): from the left edge
    ax.add_patch(plt.Rectangle(
        (x_lo, 0.7), 0.3 - x_lo, 1.0 - 0.7,
        alpha=0.1, color='blue', label='_nolegend_'
    ))
    # High cosine (0.7 .. 1.0), low correlation (< 0.3): from the bottom edge
    ax.add_patch(plt.Rectangle(
        (0.7, y_lo), 1.0 - 0.7, 0.3 - y_lo,
        alpha=0.1, color='red', label='_nolegend_'
    ))
    # Pin the limits so the rectangles stay flush with the axes edges
    ax.set_xlim(x_lo, x_hi)
    ax.set_ylim(y_lo, y_hi)

    plt.xlabel('Cosine Similarity')
    plt.ylabel('Pearson Correlation')
    plt.title('Interesting Chunk Pairs: Divergent Cosine Similarity and Correlation')
    plt.legend()
    plt.grid(alpha=0.3)
    plt.tight_layout()
    plt.show()

    return high_cos_low_corr, low_cos_high_corr

# Plot interesting pairs and keep both selections for the cells below
high_cos_low_corr, low_cos_high_corr = plot_interesting_pairs(cosine_corr_df)

# Print details about interesting pairs (identifying columns and the two scores only)
print("High Cosine Similarity, Low Correlation Pairs:")
print(high_cos_low_corr[['problem_id', 'category_i', 'category_j', 'cosine_sim', 'correlation']])
print("\nLow Cosine Similarity, High Correlation Pairs:")
print(low_cos_high_corr[['problem_id', 'category_i', 'category_j', 'cosine_sim', 'correlation']])

In [None]:
def analyze_category_pairs(df: pd.DataFrame):
    """
    Analyze which category pairs tend to have high cosine but low correlation
    and which have low cosine but high correlation.

    Pairs are treated as unordered: the two category names are sorted before
    being joined into a label. The input DataFrame is not modified.

    Args:
        df: DataFrame with 'category_i', 'category_j', 'cosine_sim' and
            'correlation' columns

    Returns:
        Tuple (top_high_cos_low_corr, top_low_cos_high_corr) of Series mapping
        category-pair label to count, at most ten entries each, largest first
    """
    # Build the unordered pair label column-wise instead of with a row-wise
    # apply: the smaller name goes first, as in the original comparison.
    in_order = df['category_i'] <= df['category_j']
    first = df['category_i'].where(in_order, df['category_j'])
    second = df['category_j'].where(in_order, df['category_i'])

    # assign() works on a copy so the caller's frame keeps its columns
    labelled = df.assign(
        category_pair=first + " → " + second,
        high_cos_low_corr=(df['cosine_sim'] > 0.7) & (df['correlation'] < 0.3),
        low_cos_high_corr=(df['cosine_sim'] < 0.3) & (df['correlation'] > 0.7),
    )

    # Count occurrences by category pair
    high_cos_low_corr_counts = labelled[labelled['high_cos_low_corr']].groupby('category_pair').size()
    low_cos_high_corr_counts = labelled[labelled['low_cos_high_corr']].groupby('category_pair').size()

    # Get top pairs
    top_high_cos_low_corr = high_cos_low_corr_counts.sort_values(ascending=False).head(10)
    top_low_cos_high_corr = low_cos_high_corr_counts.sort_values(ascending=False).head(10)

    # Plot
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))

    # High cosine, low correlation
    if not top_high_cos_low_corr.empty:
        top_high_cos_low_corr.plot.barh(ax=ax1, color='red')
        ax1.set_title('Top Category Pairs with High Cosine, Low Correlation')
        ax1.set_xlabel('Count')
        ax1.set_ylabel('Category Pair')
    else:
        ax1.text(0.5, 0.5, "No pairs found", ha='center', va='center')
        ax1.set_title('High Cosine, Low Correlation (None Found)')

    # Low cosine, high correlation
    if not top_low_cos_high_corr.empty:
        top_low_cos_high_corr.plot.barh(ax=ax2, color='blue')
        ax2.set_title('Top Category Pairs with Low Cosine, High Correlation')
        ax2.set_xlabel('Count')
        ax2.set_ylabel('Category Pair')
    else:
        ax2.text(0.5, 0.5, "No pairs found", ha='center', va='center')
        ax2.set_title('Low Cosine, High Correlation (None Found)')

    plt.tight_layout()
    plt.show()

    return top_high_cos_low_corr, top_low_cos_high_corr

# Analyze category pairs: count which category combinations fall in each divergent region
top_high_cos_low_corr, top_low_cos_high_corr = analyze_category_pairs(cosine_corr_df)

In [None]:
def load_chunk_text(problem_id, seed, chunk_idx, analysis_dir):
    """Return the 'text' of the chunk whose 'index' equals chunk_idx.

    Reads analysis_dir/<problem_id>/<seed>/chunks_analyzed.json. Returns the
    placeholder string "Chunk text not found" when the file is missing or no
    chunk carries that index.
    """
    not_found = "Chunk text not found"
    source = analysis_dir / problem_id / seed / "chunks_analyzed.json"
    if not source.exists():
        return not_found

    with open(source, 'r', encoding='utf-8') as handle:
        entries = json.load(handle)

    # First matching entry wins; fall back to the placeholder otherwise
    return next(
        (entry['text'] for entry in entries if entry['index'] == chunk_idx),
        not_found,
    )

def _draw_pair_panel(ax, example, cosine_label, correlation_label, empty_message, analysis_dir):
    """Render one example chunk pair (or a placeholder message) as text on `ax`."""
    ax.axis('off')

    if example is None:
        ax.text(0.5, 0.5, empty_message, fontsize=14, ha='center', va='center')
        return

    problem_id = example['problem_id']
    seed = example['seed']
    text_i = load_chunk_text(problem_id, seed, example['chunk_i'], analysis_dir)
    text_j = load_chunk_text(problem_id, seed, example['chunk_j'], analysis_dir)

    # Header: scores, provenance and categories
    ax.text(0.01, 0.99,
            f"{cosine_label} ({example['cosine_sim']:.2f}), {correlation_label} ({example['correlation']:.2f})",
            fontsize=14, fontweight='bold', va='top', ha='left')
    ax.text(0.01, 0.90, f"Problem: {problem_id}, Seed: {seed}", fontsize=12, va='top', ha='left')
    ax.text(0.01, 0.85, f"Categories: {example['category_i']} → {example['category_j']}",
            fontsize=12, va='top', ha='left')

    # Chunk bodies, truncated to 500 characters with an ellipsis when cut
    for heading, body, heading_y, body_y in (
        ("Chunk 1:", text_i, 0.75, 0.70),
        ("Chunk 2:", text_j, 0.40, 0.35),
    ):
        ax.text(0.01, heading_y, heading, fontsize=12, fontweight='bold', va='top', ha='left')
        ax.text(0.01, body_y, body[:500] + ("..." if len(body) > 500 else ""),
                fontsize=10, va='top', ha='left', wrap=True)


def visualize_example_pairs(high_cos_low_corr, low_cos_high_corr, analysis_dir):
    """
    Visualize example pairs of chunks with interesting properties.

    The first row of each DataFrame is shown in its own panel: scores, problem
    and seed, categories, and the first 500 characters of both chunk texts.
    An empty DataFrame yields a placeholder message in its panel.

    Args:
        high_cos_low_corr: DataFrame with high cosine, low correlation pairs
        low_cos_high_corr: DataFrame with low cosine, high correlation pairs
        analysis_dir: Path to the analysis directory
    """
    # Select one example from each category
    high_cos_example = high_cos_low_corr.iloc[0] if not high_cos_low_corr.empty else None
    low_cos_example = low_cos_high_corr.iloc[0] if not low_cos_high_corr.empty else None

    fig, axs = plt.subplots(2, 1, figsize=(12, 12))

    # Both panels share one layout; only the labels and the example differ
    _draw_pair_panel(
        axs[0], high_cos_example, "High Cosine", "Low Correlation",
        "No high cosine, low correlation example found", analysis_dir,
    )
    _draw_pair_panel(
        axs[1], low_cos_example, "Low Cosine", "High Correlation",
        "No low cosine, high correlation example found", analysis_dir,
    )

    plt.tight_layout()
    plt.show()

# Visualize example pairs: show the top-ranked pair from each divergent group with its chunk text
visualize_example_pairs(high_cos_low_corr, low_cos_high_corr, analysis_dir)