In [3]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from scipy.stats import pearsonr

def load_similarity_data(vit_dir, vgg_dir):
    vit_data = pd.read_csv(os.path.join(vit_dir, 'similarity_metrics.csv'))
    
    vgg_data = {}
    for file in os.listdir(vgg_dir):
        if file.startswith('similarity_metrics_') and file.endswith('.csv'):
            layer_name = file.split('_', 2)[2].split('.')[0]
            vgg_data[layer_name] = pd.read_csv(os.path.join(vgg_dir, file))
    
    return vit_data, vgg_data

def calculate_correlation(vit_data, vgg_data):
    correlations = {}
    for layer, data in vgg_data.items():
        merged_data = pd.merge(vit_data, data, on=['Image1', 'Image2'], suffixes=('_vit', '_vgg'))
        correlation, _ = pearsonr(merged_data['Similarity_vit'], merged_data['Similarity_vgg'])
        correlations[layer] = correlation
    return correlations

def plot_correlation_heatmap(correlations):
    plt.figure(figsize=(12, 8))
    sns.heatmap(pd.DataFrame.from_dict(correlations, orient='index', columns=['Correlation']),
                annot=True, cmap='coolwarm', center=0)
    plt.title('Correlation between ViT and VGG16 Layer Similarities')
    plt.tight_layout()
    plt.savefig('correlation_heatmap.png')
    plt.close()

def plot_similarity_comparison(vit_data, vgg_data, num_layers=5):
    plt.figure(figsize=(15, 10))
    
    # Plot ViT similarities
    plt.scatter(range(len(vit_data)), vit_data['Similarity'], label='ViT', alpha=0.7)
    
    # Plot VGG16 similarities for selected layers
    layers = list(vgg_data.keys())
    selected_layers = layers[:num_layers] + layers[-num_layers:]
    for layer in selected_layers:
        plt.scatter(range(len(vgg_data[layer])), vgg_data[layer]['Similarity'], label=f'VGG16 - {layer}', alpha=0.7)
    
    plt.xlabel('Image Pair Index')
    plt.ylabel('Similarity')
    plt.title('Similarity Comparison: ViT vs VGG16 Layers')
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.savefig('similarity_comparison.png')
    plt.close()

def calculate_average_similarity(data):
    return data['Similarity'].mean()

def plot_average_similarities(vit_avg, vgg_avgs):
    plt.figure(figsize=(12, 6))
    
    layers = list(vgg_avgs.keys())
    vgg_values = list(vgg_avgs.values())
    
    plt.plot(layers, vgg_values, marker='o', label='VGG16')
    plt.axhline(y=vit_avg, color='r', linestyle='--', label='ViT')
    
    plt.xlabel('VGG16 Layers')
    plt.ylabel('Average Similarity')
    plt.title('Average Similarities: ViT vs VGG16 Layers')
    plt.xticks(rotation=90)
    plt.legend()
    plt.tight_layout()
    plt.savefig('average_similarities.png')
    plt.close()

def plot_similarity_distribution(vit_data, vgg_data):
    plt.figure(figsize=(12, 6))
    sns.kdeplot(data=vit_data, x='Similarity', label='ViT', shade=True)
    
    for layer, data in vgg_data.items():
        if layer in ['block1_conv1', 'block3_conv1', 'block5_conv1']:
            sns.kdeplot(data=data, x='Similarity', label=f'VGG16 - {layer}', shade=True)
    
    plt.xlabel('Similarity')
    plt.ylabel('Density')
    plt.title('Distribution of Similarities: ViT vs Selected VGG16 Layers')
    plt.legend()
    plt.tight_layout()
    plt.savefig('similarity_distribution.png')
    plt.close()

def plot_similarity_evolution(vit_data, vgg_data):
    vgg_layers = list(vgg_data.keys())
    vgg_means = [data['Similarity'].mean() for data in vgg_data.values()]
    vgg_stds = [data['Similarity'].std() for data in vgg_data.values()]

    plt.figure(figsize=(12, 6))
    plt.errorbar(range(len(vgg_layers)), vgg_means, yerr=vgg_stds, fmt='-o', capsize=5, label='VGG16')
    plt.axhline(y=vit_data['Similarity'].mean(), color='r', linestyle='--', label='ViT Mean')
    plt.fill_between(range(len(vgg_layers)), 
                     vit_data['Similarity'].mean() - vit_data['Similarity'].std(),
                     vit_data['Similarity'].mean() + vit_data['Similarity'].std(),
                     alpha=0.2, color='r', label='ViT Std Dev')

    plt.xlabel('VGG16 Layers')
    plt.ylabel('Mean Similarity')
    plt.title('Evolution of Similarities Through VGG16 Layers')
    plt.xticks(range(len(vgg_layers)), vgg_layers, rotation=90)
    plt.legend()
    plt.tight_layout()
    plt.savefig('similarity_evolution.png')
    plt.close()

def plot_scatter_comparison(vit_data, vgg_data, selected_layers=['block1_conv1', 'block3_conv1', 'block5_conv1']):
    fig, axes = plt.subplots(1, len(selected_layers), figsize=(20, 6), sharey=True)
    fig.suptitle('ViT vs VGG16 Similarity Scatter Plots', fontsize=16)

    for i, layer in enumerate(selected_layers):
        merged_data = pd.merge(vit_data, vgg_data[layer], on=['Image1', 'Image2'], suffixes=('_vit', '_vgg'))
        axes[i].scatter(merged_data['Similarity_vit'], merged_data['Similarity_vgg'], alpha=0.6)
        axes[i].set_xlabel('ViT Similarity')
        axes[i].set_title(f'VGG16 - {layer}')
        
        # Add diagonal line
        axes[i].plot([0, 1], [0, 1], transform=axes[i].transAxes, ls='--', c='r')

    axes[0].set_ylabel('VGG16 Similarity')
    plt.tight_layout()
    plt.savefig('scatter_comparison.png')
    plt.close()

def calculate_mse(vit_data, vgg_data):
    mse_scores = {}
    for layer, data in vgg_data.items():
        merged_data = pd.merge(vit_data, data, on=['Image1', 'Image2'], suffixes=('_vit', '_vgg'))
        mse = mean_squared_error(merged_data['Similarity_vit'], merged_data['Similarity_vgg'])
        mse_scores[layer] = mse
    return mse_scores

def plot_mse_comparison(mse_scores):
    plt.figure(figsize=(12, 6))
    plt.bar(mse_scores.keys(), mse_scores.values())
    plt.xlabel('VGG16 Layers')
    plt.ylabel('Mean Squared Error')
    plt.title('MSE between ViT and VGG16 Layer Similarities')
    plt.xticks(rotation=90)
    plt.tight_layout()
    plt.savefig('mse_comparison.png')
    plt.close()

def main():
    vit_dir = 'vit_visualizations'
    vgg_dir = 'vgg16_visualizations'
    
    vit_data, vgg_data = load_similarity_data(vit_dir, vgg_dir)
    
    # Existing plots
    correlations = calculate_correlation(vit_data, vgg_data)
    plot_correlation_heatmap(correlations)
    plot_similarity_comparison(vit_data, vgg_data)
    vit_avg = calculate_average_similarity(vit_data)
    vgg_avgs = {layer: calculate_average_similarity(data) for layer, data in vgg_data.items()}
    plot_average_similarities(vit_avg, vgg_avgs)
    
    # New plots
    plot_similarity_distribution(vit_data, vgg_data)
    plot_similarity_evolution(vit_data, vgg_data)
    plot_scatter_comparison(vit_data, vgg_data)
    
    mse_scores = calculate_mse(vit_data, vgg_data)
    plot_mse_comparison(mse_scores)
    
    print("Enhanced benchmark analysis complete. Check the output directory for visualization results.")

if __name__ == "__main__":
    main()


`shade` is now deprecated in favor of `fill`; setting `fill=True`.
This will become an error in seaborn v0.14.0; please update your code.

  sns.kdeplot(data=vit_data, x='Similarity', label='ViT', shade=True)

`shade` is now deprecated in favor of `fill`; setting `fill=True`.
This will become an error in seaborn v0.14.0; please update your code.

  sns.kdeplot(data=data, x='Similarity', label=f'VGG16 - {layer}', shade=True)

`shade` is now deprecated in favor of `fill`; setting `fill=True`.
This will become an error in seaborn v0.14.0; please update your code.

  sns.kdeplot(data=data, x='Similarity', label=f'VGG16 - {layer}', shade=True)

`shade` is now deprecated in favor of `fill`; setting `fill=True`.
This will become an error in seaborn v0.14.0; please update your code.

  sns.kdeplot(data=data, x='Similarity', label=f'VGG16 - {layer}', shade=True)


Enhanced benchmark analysis complete. Check the output directory for visualization results.


In [4]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from scipy.stats import pearsonr
from sklearn.metrics import mean_squared_error

def load_data(vit_file, vgg_file):
    vit_data = pd.read_csv(vit_file)
    vgg_data = pd.read_csv(vgg_file)
    return pd.merge(vit_data, vgg_data, on=['Image1', 'Image2'], suffixes=('_vit', '_vgg'))

def plot_scatter(data):
    plt.figure(figsize=(10, 10))
    plt.scatter(data['Similarity_vit'], data['Similarity_vgg'], alpha=0.6)
    plt.plot([0, 1], [0, 1], 'r--')  # Diagonal line
    plt.xlabel('ViT Similarity')
    plt.ylabel('VGG16 block1_conv1 Similarity')
    plt.title('Scatter Plot: ViT vs VGG16 block1_conv1 Similarities')
    plt.tight_layout()
    plt.savefig('vit_vs_vgg_block1_conv1_scatter.png')
    plt.close()

def plot_histogram(data):
    plt.figure(figsize=(12, 6))
    sns.histplot(data=data, x='Similarity_vit', kde=True, label='ViT', color='blue', alpha=0.5)
    sns.histplot(data=data, x='Similarity_vgg', kde=True, label='VGG16 block1_conv1', color='red', alpha=0.5)
    plt.xlabel('Similarity')
    plt.ylabel('Frequency')
    plt.title('Histogram: ViT vs VGG16 block1_conv1 Similarities')
    plt.legend()
    plt.tight_layout()
    plt.savefig('vit_vs_vgg_block1_conv1_histogram.png')
    plt.close()

def plot_difference_histogram(data):
    data['Difference'] = data['Similarity_vit'] - data['Similarity_vgg']
    plt.figure(figsize=(10, 6))
    sns.histplot(data=data, x='Difference', kde=True)
    plt.xlabel('Difference (ViT - VGG16 block1_conv1)')
    plt.ylabel('Frequency')
    plt.title('Histogram of Differences: ViT - VGG16 block1_conv1')
    plt.axvline(x=0, color='r', linestyle='--')
    plt.tight_layout()
    plt.savefig('vit_vs_vgg_block1_conv1_difference_histogram.png')
    plt.close()

def plot_bland_altman(data):
    mean = (data['Similarity_vit'] + data['Similarity_vgg']) / 2
    diff = data['Similarity_vit'] - data['Similarity_vgg']
    
    plt.figure(figsize=(10, 6))
    plt.scatter(mean, diff, alpha=0.6)
    plt.axhline(y=np.mean(diff), color='r', linestyle='--')
    plt.axhline(y=np.mean(diff) + 1.96 * np.std(diff), color='g', linestyle='--')
    plt.axhline(y=np.mean(diff) - 1.96 * np.std(diff), color='g', linestyle='--')
    plt.xlabel('Mean of ViT and VGG16 block1_conv1')
    plt.ylabel('Difference (ViT - VGG16 block1_conv1)')
    plt.title('Bland-Altman Plot: ViT vs VGG16 block1_conv1')
    plt.tight_layout()
    plt.savefig('vit_vs_vgg_block1_conv1_bland_altman.png')
    plt.close()

def calculate_metrics(data):
    correlation, _ = pearsonr(data['Similarity_vit'], data['Similarity_vgg'])
    mse = mean_squared_error(data['Similarity_vit'], data['Similarity_vgg'])
    mean_diff = np.mean(data['Similarity_vit'] - data['Similarity_vgg'])
    std_diff = np.std(data['Similarity_vit'] - data['Similarity_vgg'])
    
    return {
        'Correlation': correlation,
        'MSE': mse,
        'Mean Difference': mean_diff,
        'Std Difference': std_diff
    }

def main():
    vit_file = 'vit_visualizations/similarity_metrics.csv'
    vgg_file = 'vgg16_visualizations/similarity_metrics_block1_conv1.csv'
    
    data = load_data(vit_file, vgg_file)
    
    plot_scatter(data)
    plot_histogram(data)
    plot_difference_histogram(data)
    plot_bland_altman(data)
    
    metrics = calculate_metrics(data)
    
    print("Comparison Metrics:")
    for key, value in metrics.items():
        print(f"{key}: {value:.4f}")

if __name__ == "__main__":
    main()

Comparison Metrics:
Correlation: -0.1291
MSE: 0.0362
Mean Difference: -0.1802
Std Difference: 0.0608
