# In [None]:
import torch
import torchvision.transforms as transforms
import numpy as np
from PIL import Image
import os
import glob
import faiss
import torchvision.models as models
import matplotlib.pyplot as plt
import pandas as pd
import time
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from collections import defaultdict
import seaborn as sns
import matplotlib.patches as mpatches


# ====================== Output Folder Setup =======================
# Directory for the CSV and PNG artifacts saved by the analysis functions
# below; created at import time if it does not already exist.
OUTPUT_DIR = "new_run"
os.makedirs(name=OUTPUT_DIR, exist_ok=True)


# ====================== Feature Extraction for Multiple Models =======================

def extract_features_model(image, model, transform):
    """Return the flattened embedding that ``model`` produces for one image.

    Args:
        image: a single input accepted by ``transform`` (e.g. a PIL image).
        model: a torch module mapping a batched tensor to features.
        transform: callable converting ``image`` into an unbatched tensor.

    Returns:
        1-D numpy array with the model output for this image.
    """
    # Run on whichever device holds the model's weights. Always moving the
    # input to CUDA when available crashes with a device mismatch if the
    # model was left on the CPU. Parameter-less modules fall back to CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    batch = transform(image).unsqueeze(0).to(device)  # add batch dimension

    with torch.no_grad():
        features = model(batch)

    return features.cpu().numpy().flatten()

# ====================== Load Images from Multiple Directories =======================

def load_images_from_directories(directories):
    """Load every ``*.jpg`` file from each directory, in a deterministic order.

    Args:
        directories: iterable of directory paths, scanned in the given order.

    Returns:
        ``(images, filenames)``: fully loaded PIL images and their base file
        names, index-aligned with each other.

    Notes:
        Files within each directory are sorted by path, because ``glob``
        returns them in arbitrary filesystem-dependent order and the rest of
        this script pairs images with BIRADS labels purely by position.
        NOTE(review): the sorted order must still match the row order of the
        label spreadsheet — confirm.
    """
    images, filenames = [], []
    for directory in directories:
        for file_path in sorted(glob.glob(os.path.join(directory, '*.jpg'))):
            # Image.open is lazy and keeps the file open until pixels are
            # read. Copying inside ``with`` loads the pixels and then closes
            # the handle, so thousands of images don't exhaust descriptors.
            with Image.open(file_path) as img:
                images.append(img.copy())
            filenames.append(os.path.basename(file_path))
    return images, filenames

def extract_features_from_dataset(directories, model):
    """Embed every image found in ``directories`` with ``model``.

    Each image is resized to 224x224, converted to 3-channel grayscale,
    turned into a tensor and normalized before being passed through the model.

    Returns:
        ``(features_array, filenames, images)``: a numpy array with one
        embedding per row, plus the file names and images in the same order.
    """
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        # Replicate the grayscale channel three times so backbones that expect
        # 3-channel input can consume it.
        transforms.Grayscale(3),
        transforms.ToTensor(),
        # Per-channel mean/std commonly used with torchvision's
        # ImageNet-pretrained weights.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    images, filenames = load_images_from_directories(directories)

    embeddings = []
    for picture in images:
        embeddings.append(extract_features_model(picture, model, preprocess))
    features_array = np.array(embeddings)

    return features_array, filenames, images

# ====================== FAISS Indexing =======================

def build_faiss_index(features_array, distance_metric='L2'):
    """Build an exact (flat) FAISS index over ``features_array``.

    Args:
        features_array: 2-D array, one embedding per row. It is converted to
            contiguous float32, the dtype FAISS requires.
        distance_metric: 'L2' for Euclidean distance, or 'Cosine' for cosine
            similarity (inner product over L2-normalized copies of the rows).

    Returns:
        A populated ``faiss.IndexFlatL2`` or ``faiss.IndexFlatIP``.

    Raises:
        ValueError: if ``features_array`` is not 2-D or ``distance_metric``
            is not 'L2' / 'Cosine'.
    """
    vectors = np.ascontiguousarray(features_array, dtype=np.float32)
    if vectors.ndim != 2:
        raise ValueError(f"features_array must be 2-D, got shape {vectors.shape}")
    if distance_metric not in ('L2', 'Cosine'):
        raise ValueError(f"Unsupported distance_metric {distance_metric!r}; use 'L2' or 'Cosine'")

    embedding_dim = vectors.shape[1]

    if distance_metric == 'L2':
        index = faiss.IndexFlatL2(embedding_dim)
    else:
        index = faiss.IndexFlatIP(embedding_dim)
        # normalize_L2 works in place; copy so the caller's array is untouched.
        vectors = vectors.copy()
        faiss.normalize_L2(vectors)

    index.add(vectors)
    return index


def build_ivf_pq_index(features_array, n_clusters, n_bits, n_subquantizers=8):
    """Build and populate an IVF-PQ index for hyperparameter testing.

    ``faiss.IndexIVFPQ(quantizer, d, nlist, M, nbits)`` takes the number of
    sub-quantizers ``M`` before the bits per sub-code. The old call passed
    ``n_bits`` as ``M`` and hard-coded 8 bits per code. Now ``n_bits``
    really sets the bits per sub-code, and ``M`` is ``n_subquantizers``.
    With the default ``n_subquantizers=8``, ``n_bits=8`` builds the same
    index as before.

    Args:
        features_array: 2-D array, one embedding per row (converted to
            contiguous float32). It is used both to train and to fill the index.
        n_clusters: number of IVF coarse clusters (``nlist``).
        n_bits: bits per PQ sub-code.
        n_subquantizers: number of PQ sub-quantizers ``M``. It must divide
            the embedding dimension.

    Returns:
        The trained ``faiss.IndexIVFPQ`` containing every row.

    Raises:
        ValueError: if the input is not 2-D or the dimension is not divisible
            by ``n_subquantizers``.
    """
    vectors = np.ascontiguousarray(features_array, dtype=np.float32)
    if vectors.ndim != 2:
        raise ValueError(f"features_array must be 2-D, got shape {vectors.shape}")
    embedding_dim = vectors.shape[1]
    if n_subquantizers <= 0 or embedding_dim % n_subquantizers != 0:
        raise ValueError(
            f"embedding dimension {embedding_dim} is not divisible by "
            f"n_subquantizers={n_subquantizers}"
        )

    # Coarse quantizer assigning vectors to IVF clusters.
    quantizer = faiss.IndexFlatL2(embedding_dim)

    index = faiss.IndexIVFPQ(quantizer, embedding_dim, n_clusters,
                             n_subquantizers, n_bits)

    index.train(vectors)
    index.add(vectors)

    return index

def perform_faiss_search(index, query_index, features_array, k=5, distance_metric='L2', n_trials=10):
    """Query ``index`` with one stored vector and time the search.

    The search runs ``n_trials`` times to estimate timing. The results of
    the last trial are returned.

    Args:
        index: a FAISS index (anything with ``search(x, k=...)``).
        query_index: row of ``features_array`` to use as the query.
        features_array: 2-D array of embeddings.
        k: number of neighbours to retrieve.
        distance_metric: 'Cosine' L2-normalizes the query, matching how
            ``build_faiss_index`` stores vectors for that metric. Any other
            value leaves the query as it is.
        n_trials: number of timed repetitions; must be >= 1.

    Returns:
        ``(distances, indices, search_time_mean, search_time_std)``. The first
        two are the result rows for the single query; the times are in seconds.

    Raises:
        ValueError: if ``n_trials < 1``.
        IndexError: if ``query_index`` does not select a row.
    """
    if n_trials < 1:
        raise ValueError(f"n_trials must be >= 1, got {n_trials}")

    # Copy as contiguous float32. FAISS requires that dtype, and
    # normalize_L2 below modifies its argument in place.
    query_vector = np.array(features_array[query_index:query_index + 1],
                            dtype=np.float32, order='C')
    if query_vector.shape[0] != 1:
        raise IndexError(
            f"query_index {query_index} is out of range for {len(features_array)} vectors"
        )

    if distance_metric == 'Cosine':
        faiss.normalize_L2(query_vector)

    times = []
    for _ in range(n_trials):
        start_time = time.perf_counter()  # high-resolution timer
        distances, indices = index.search(query_vector, k=k)
        times.append(time.perf_counter() - start_time)

    return distances[0], indices[0], np.mean(times), np.std(times)

# ====================== Load BIRADS Labels from Excel =======================

def load_birads_labels(excel_file):
    """Read the ``BIRADS`` column from the ``'all'`` sheet of ``excel_file``.

    Prints how many rows fall into each BIRADS category.

    Returns:
        ``(labels, series)``: the column as a plain Python list and as the
        original pandas Series, both in spreadsheet row order.
    """
    birads_column = pd.read_excel(excel_file, sheet_name='all')['BIRADS']
    as_list = birads_column.tolist()

    distribution = birads_column.value_counts().to_dict()
    print(f"BIRADS Category Distribution: {distribution}")

    return as_list, birads_column

# ====================== Precision, Recall, NDCG =======================

def precision_at_k(retrieved_labels, true_label, k=5):
    """Return the share of the first ``k`` retrieved labels equal to ``true_label``.

    The denominator is always ``k``, even when fewer than ``k`` labels were
    retrieved.
    """
    matches = [label for label in retrieved_labels[:k] if label == true_label]
    return len(matches) / k

def recall_at_k(retrieved_labels, true_label, total_relevant, k=5):
    """Return the fraction of the ``total_relevant`` items found in the top ``k``.

    Returns 0 when ``total_relevant`` is not positive.
    """
    hits = len([label for label in retrieved_labels[:k] if label == true_label])
    if not total_relevant > 0:
        return 0
    return hits / total_relevant

def ndcg_at_k(retrieved_labels, true_label, k=5, total_relevant=None):
    """Binary-relevance NDCG over the top ``k`` retrieved labels.

    A retrieved item is relevant when its label equals ``true_label``.
    The ideal DCG assumes ``min(k, R)`` relevant items at the top ranks:

    * by default ``R`` is the number of relevant labels in
      ``retrieved_labels``, so the score measures ordering within the
      returned list (original behaviour);
    * when ``total_relevant`` is given, ``R = total_relevant``, which gives
      the standard NDCG@k normalised by the whole collection.

    Fewer than ``k`` retrieved labels is allowed; missing positions count as
    non-relevant. (Previously this raised ``IndexError``, and callers do drop
    out-of-range indices.)

    Returns:
        A float in [0, 1], or 0 when the ideal DCG is 0.
    """
    top_k = list(retrieved_labels[:k])
    dcg = sum(1 / np.log2(rank + 2)
              for rank, label in enumerate(top_k) if label == true_label)

    if total_relevant is None:
        n_relevant = sum(1 for label in retrieved_labels if label == true_label)
    else:
        n_relevant = total_relevant
    idcg = sum(1 / np.log2(rank + 2) for rank in range(min(k, n_relevant)))

    return dcg / idcg if idcg > 0 else 0

def rank_of_true_label(retrieved_labels, true_label):
    """Return the 1-based rank of the first ``true_label`` in ``retrieved_labels``.

    Returns None when the label does not occur at all.
    """
    for position, label in enumerate(retrieved_labels, start=1):
        # Identity check first, mirroring list.index (so a shared NaN object still matches).
        if label is true_label or label == true_label:
            return position
    return None

# ====================== PCA and t-SNE Visualization =======================

def _scatter_by_birads_label(points, labels_array, unique_labels, colors):
    """Scatter the first two columns of ``points``, one coloured series per label."""
    for color, label in zip(colors, unique_labels):
        mask = labels_array == label
        plt.scatter(points[mask, 0], points[mask, 1],
                    color=color, label=f'BIRADS {int(label)}', alpha=0.7)


def visualize_pca_and_tsne(features_array, labels, model_name, n_points=500):
    """Plot PCA and t-SNE projections of leading embeddings, coloured by BIRADS.

    PCA keeps up to 50 components. Its first two components are plotted
    directly, and t-SNE (2-D, ``random_state=42``) is run on the PCA output.
    Both figures are saved as PNGs in ``OUTPUT_DIR``.

    Args:
        features_array: 2-D array, one embedding per row.
        labels: BIRADS labels aligned by position with ``features_array``.
            They are rendered with ``int()`` in the legend, so they must be
            numeric.
        model_name: used in titles and output file names.
        n_points: maximum number of leading samples to project.

    Returns:
        ``(reduced_pca, reduced_tsne)``, one row per plotted sample.

    Raises:
        ValueError: if fewer than two samples or feature dimensions are available.
    """
    # Only use rows that have both an embedding and a label. A length mismatch
    # would otherwise misalign the label masks with the projected points.
    n_samples = min(n_points, len(features_array), len(labels))
    labels_array = np.array(labels[:n_samples])
    features_subset = features_array[:n_samples]

    # PCA cannot extract more than min(n_samples, n_features) components.
    n_components = min(50, n_samples, features_subset.shape[1])
    if n_components < 2:
        raise ValueError(
            f"Need at least 2 samples and 2 feature dimensions for {model_name}, "
            f"got {n_samples} samples with {features_subset.shape[1]} features"
        )
    pca = PCA(n_components=n_components)
    reduced_pca = pca.fit_transform(features_subset)

    explained_variance = np.sum(pca.explained_variance_ratio_)
    print(f"Explained variance ratio by PCA components for {model_name}: {explained_variance:.2f}")

    unique_labels = np.unique(labels_array)
    colors = plt.cm.tab10(np.linspace(0, 1, len(unique_labels)))

    # --- PCA scatter (first two components) ---
    plt.figure(figsize=(12, 8))
    _scatter_by_birads_label(reduced_pca, labels_array, unique_labels, colors)
    plt.xlabel(f"PC1 ({pca.explained_variance_ratio_[0]:.2f})")
    plt.ylabel(f"PC2 ({pca.explained_variance_ratio_[1]:.2f})")
    plt.title(f"PCA Visualization of {model_name} Feature Space")
    plt.legend(title='BIRADS Category')
    plt.grid(alpha=0.3)
    plt.savefig(os.path.join(OUTPUT_DIR, f'pca_{model_name}.png'), dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()  # release the figure; non-interactive backends keep it alive otherwise

    # --- t-SNE on the PCA-reduced data ---
    # sklearn requires perplexity < n_samples; 30 is its default value.
    tsne = TSNE(n_components=2, random_state=42, perplexity=min(30.0, n_samples - 1))
    reduced_tsne = tsne.fit_transform(reduced_pca)

    plt.figure(figsize=(12, 8))
    _scatter_by_birads_label(reduced_tsne, labels_array, unique_labels, colors)
    plt.title(f"t-SNE Visualization of {model_name} Embedding Space (after PCA)")
    plt.legend(title='BIRADS Category')
    plt.grid(alpha=0.3)
    plt.savefig(os.path.join(OUTPUT_DIR, f"tsne_{model_name}.png"), dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()

    return reduced_pca, reduced_tsne

# ====================== Distance Metric Analysis =======================

def _plot_distance_metric_panel(df, distance_metrics, k_values, column, err_column=None):
    """Plot per-k means of ``column`` as one labelled series per distance metric."""
    for metric in distance_metrics:
        metric_data = df[df['distance_metric'] == metric]
        means = [metric_data.loc[metric_data['k'] == k, column].mean() for k in k_values]
        if err_column is None:
            plt.plot(k_values, means, marker='o', label=metric)
        else:
            errors = [metric_data.loc[metric_data['k'] == k, err_column].mean() for k in k_values]
            plt.errorbar(k_values, means, yerr=errors, marker='o', capsize=3, label=metric)
    plt.legend()
    plt.grid(alpha=0.3)


def analyze_distance_metrics(features_array, birads_labels, model_name, query_indices, k_values):
    """Compare L2 distance vs Cosine similarity for flat FAISS retrieval.

    For each metric a flat index is built over ``features_array``. Every
    valid query is then searched at every k, and precision, recall, NDCG and
    search time are recorded. Results are saved to
    ``distance_metric_comparison_{model_name}.csv`` and plotted (one series
    per metric) to a PNG of the same name in ``OUTPUT_DIR``.

    Args:
        features_array: 2-D array, one embedding per row.
        birads_labels: list of labels aligned by position with the rows.
        model_name: used in file names and titles.
        query_indices: row indices used as queries. Indices outside both
            the label list and the feature rows are skipped.
        k_values: numbers of neighbours to evaluate.

    Returns:
        list of per-(metric, query, k) result dicts.
    """
    distance_metrics = ['L2', 'Cosine']
    n_labels = len(birads_labels)
    # A query needs both a label and a feature vector to be scored.
    n_searchable = min(n_labels, len(features_array))
    valid_queries = [q for q in query_indices if 0 <= q < n_searchable]
    results = []

    for distance_metric in distance_metrics:
        print(f"\nAnalyzing {distance_metric} distance for {model_name}...")

        index = build_faiss_index(features_array, distance_metric)

        for query_index in valid_queries:
            true_label = birads_labels[query_index]
            total_relevant = birads_labels.count(true_label)

            for k in k_values:
                distances, indices, search_time_mean, search_time_std = perform_faiss_search(
                    index, query_index, features_array, k=k,
                    distance_metric=distance_metric, n_trials=10
                )

                # FAISS pads missing results with -1. A bare ``idx < n`` test
                # lets -1 through, and it wraps around to the last label.
                retrieved_labels = [birads_labels[idx] for idx in indices if 0 <= idx < n_labels]

                results.append({
                    'model': model_name,
                    'distance_metric': distance_metric,
                    'query_image': query_index,
                    'k': k,
                    'precision': precision_at_k(retrieved_labels, true_label, k=k),
                    'recall': recall_at_k(retrieved_labels, true_label, total_relevant, k=k),
                    'ndcg': ndcg_at_k(retrieved_labels, true_label, k=k),
                    'search_time_mean': search_time_mean,
                    'search_time_std': search_time_std,
                    'birads_category': true_label
                })

    df = pd.DataFrame(results)
    df.to_csv(os.path.join(OUTPUT_DIR, f'distance_metric_comparison_{model_name}.csv'), index=False)

    if df.empty:
        print(f"No valid queries for distance-metric comparison of {model_name}; skipping plot.")
        return results

    # One subplot per metric column: (column, y-axis label, title prefix, error column).
    panels = [
        ('precision', 'Average Precision', 'Average Precision', None),
        ('recall', 'Average Recall', 'Average Recall', None),
        ('ndcg', 'Average NDCG', 'Average NDCG', None),
        ('search_time_mean', 'Average Search Time (s)', 'Average Search Time', 'search_time_std'),
    ]

    plt.figure(figsize=(15, 10))
    for position, (column, ylabel, title, err_column) in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        _plot_distance_metric_panel(df, distance_metrics, k_values, column, err_column)
        plt.xlabel('k value')
        plt.ylabel(ylabel)
        plt.title(f'{title} by Distance Metric ({model_name})')

    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, f'distance_metric_comparison_{model_name}.png'), dpi=300, bbox_inches='tight')
    plt.show()

    return results

# ====================== Hyperparameter Analysis for FAISS =======================

def analyze_faiss_hyperparameters(features_array, birads_labels, model_name, query_indices):
    """Sweep IVF-PQ settings and measure retrieval quality and speed at k=10.

    An index is built with ``build_ivf_pq_index`` for every
    (n_clusters, n_bits) pair, with ``nprobe`` set to a quarter of the
    clusters (at least 1). Each valid query is searched ``n_trials`` times
    for timing. A configuration that raises is reported and skipped, since
    this is a best-effort sweep. Per-query results are saved to CSV and
    summarised as heatmaps in ``OUTPUT_DIR``.

    Args:
        features_array: 2-D array, one embedding per row.
        birads_labels: list of labels aligned by position with the rows.
        model_name: used in file names.
        query_indices: row indices used as queries (out-of-range ones skipped).

    Returns:
        list of per-(n_clusters, n_bits, query) result dicts.
    """
    from collections import Counter

    n_clusters_values = [4, 8, 16, 32]
    n_bits_values = [4, 8]
    k = 10  # fixed k for hyperparameter analysis
    n_trials = 5
    n_labels = len(birads_labels)

    # FAISS requires contiguous float32 input.
    vectors = np.ascontiguousarray(features_array, dtype=np.float32)
    valid_queries = [q for q in query_indices if 0 <= q < min(n_labels, len(vectors))]
    results = []

    for n_clusters in n_clusters_values:
        for n_bits in n_bits_values:
            print(f"\nTesting FAISS with n_clusters={n_clusters}, n_bits={n_bits}...")

            try:
                index = build_ivf_pq_index(vectors, n_clusters, n_bits)

                # Number of clusters visited per search.
                index.nprobe = max(1, n_clusters // 4)

                for query_index in valid_queries:
                    true_label = birads_labels[query_index]
                    total_relevant = birads_labels.count(true_label)
                    query_vector = vectors[query_index:query_index + 1]

                    search_times = []
                    trial_results = []
                    for _ in range(n_trials):
                        start_time = time.perf_counter()
                        _, indices = index.search(query_vector, k=k)
                        search_times.append(time.perf_counter() - start_time)
                        trial_results.append(tuple(indices[0]))

                    # Keep the most frequent result list across trials
                    # (presumably all trials agree, since search should be
                    # deterministic).
                    most_common_indices = Counter(trial_results).most_common(1)[0][0]

                    # Drop FAISS's -1 padding so it does not wrap to the last label.
                    retrieved_labels = [birads_labels[idx] for idx in most_common_indices
                                        if 0 <= idx < n_labels]

                    results.append({
                        'model': model_name,
                        'n_clusters': n_clusters,
                        'n_bits': n_bits,
                        'query_image': query_index,
                        'precision': precision_at_k(retrieved_labels, true_label, k=k),
                        'recall': recall_at_k(retrieved_labels, true_label, total_relevant, k=k),
                        'ndcg': ndcg_at_k(retrieved_labels, true_label, k=k),
                        'search_time_mean': np.mean(search_times),
                        'search_time_std': np.std(search_times),
                        'birads_category': true_label
                    })
            except Exception as e:
                # Best-effort sweep: some settings may be rejected by FAISS;
                # report them and continue with the remaining ones.
                print(f"Error with n_clusters={n_clusters}, n_bits={n_bits}: {e}")

    df = pd.DataFrame(results)
    df.to_csv(os.path.join(OUTPUT_DIR, f'faiss_hyperparameter_analysis_{model_name}.csv'), index=False)

    if df.empty:
        print(f"No successful FAISS configurations for {model_name}; skipping heatmaps.")
        return results

    plt.figure(figsize=(15, 12))

    metrics = ['precision', 'recall', 'ndcg', 'search_time_mean']
    titles = ['Precision', 'Recall', 'NDCG', 'Search Time (s)']

    for i, (metric, title) in enumerate(zip(metrics, titles)):
        plt.subplot(2, 2, i + 1)

        # Mean per (clusters, bits), then pivot into a heatmap grid.
        pivot_data = (df.groupby(['n_clusters', 'n_bits'])[metric].mean()
                        .reset_index()
                        .pivot(index='n_clusters', columns='n_bits', values=metric))

        sns.heatmap(pivot_data, annot=True, cmap='viridis', fmt='.4f')
        plt.title(f'Average {title} by FAISS Parameters')
        plt.xlabel('Number of Bits')
        plt.ylabel('Number of Clusters')

    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, f'faiss_hyperparameter_analysis_{model_name}.png'), dpi=300, bbox_inches='tight')
    plt.show()

    return results

# ====================== BIRADS Category Analysis =======================

def analyze_by_birads_category(all_results, model_names, k_values):
    """Plot and tabulate retrieval metrics per model, BIRADS category and k.

    Writes ``birads_category_analysis.png``,
    ``birads_category_detailed_stats.csv`` and
    ``summary_table_by_birads_and_k.csv``, all in ``OUTPUT_DIR``.

    Args:
        all_results: list of result dicts with 'model', 'birads_category',
            'k', 'precision', 'recall', 'ndcg' and either 'search_time_mean'
            or the legacy 'search_time_sec'.
        model_names: iterable of model names to include. It is materialised
            to a list because it is iterated more than once.
        k_values: k values to include.

    Returns:
        The summary pivot table (rows: Model, BIRADS; columns: metric x k), or
        an empty DataFrame when there is nothing to summarise.
    """
    df = pd.DataFrame(all_results)
    model_names = list(model_names)

    if df.empty:
        print("No results available for BIRADS category analysis.")
        return pd.DataFrame()

    df = df[df['k'].isin(k_values)]
    # Sorted for a stable legend/table order. NaN categories cannot be
    # rendered as integer BIRADS grades, so they are skipped.
    unique_categories = sorted(df['birads_category'].dropna().unique())

    metric_names = ['precision', 'recall', 'ndcg']
    titles = ['Precision', 'Recall', 'NDCG']

    plt.figure(figsize=(20, 15))
    for metric_idx, (metric, title) in enumerate(zip(metric_names, titles)):
        plt.subplot(3, 1, metric_idx + 1)

        for model in model_names:
            model_data = df[df['model'] == model]
            for category in unique_categories:
                per_k = (model_data[model_data['birads_category'] == category]
                         .groupby('k')[metric].mean())
                if not per_k.empty:
                    plt.plot(per_k.index, per_k.values,
                             marker='o', label=f'{model} - BIRADS {int(category)}')

        plt.title(f'{title} by BIRADS Category and k Value')
        plt.xlabel('k Value')
        plt.ylabel(f'Mean {title}')
        plt.grid(alpha=0.3)
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, 'birads_category_analysis.png'), dpi=300, bbox_inches='tight')
    plt.show()

    # Older result rows may carry 'search_time_sec' instead of 'search_time_mean'.
    time_column = 'search_time_mean' if 'search_time_mean' in df.columns else 'search_time_sec'

    stats_rows = []
    for model in model_names:
        model_data = df[df['model'] == model]
        for category in unique_categories:
            category_data = model_data[model_data['birads_category'] == category]
            for k in k_values:
                k_data = category_data[category_data['k'] == k]
                if k_data.empty:
                    continue
                stats_rows.append({
                    'Model': model,
                    'BIRADS': int(category),
                    'k': k,
                    'Precision': k_data['precision'].mean(),
                    'Recall': k_data['recall'].mean(),
                    'NDCG': k_data['ndcg'].mean(),
                    'Search Time (s)': k_data[time_column].mean(),
                    'Count': len(k_data)
                })

    stats_df = pd.DataFrame(stats_rows)
    stats_df.to_csv(os.path.join(OUTPUT_DIR, 'birads_category_detailed_stats.csv'), index=False)

    if stats_df.empty:
        print("No per-category statistics to summarize.")
        return pd.DataFrame()

    # Generate summary table (Table 2 replacement)
    summary_table = stats_df.pivot_table(
        index=['Model', 'BIRADS'],
        columns=['k'],
        values=['Precision', 'Recall', 'NDCG'],
        aggfunc='mean'
    )

    # Save alongside every other artifact. It used to be written to the
    # current working directory instead of OUTPUT_DIR.
    summary_table.to_csv(os.path.join(OUTPUT_DIR, 'summary_table_by_birads_and_k.csv'))
    print("Generated summary table to replace the missing Table 2")

    return summary_table

# ====================== Save Results to CSV =======================

def save_results_to_csv(results, output_file):
    """Write ``results`` (e.g. a list of dicts) to ``output_file`` as CSV.

    The DataFrame index is not written. The destination path is printed afterwards.
    """
    pd.DataFrame(results).to_csv(output_file, index=False)
    print(f"Results saved to {output_file}")

# ====================== Generate Missing Table 2 =======================

def _table2_column_mean(frame, column):
    """Mean of ``column`` in ``frame``, or NaN when the column is absent."""
    return frame[column].mean() if column in frame.columns else float('nan')


def generate_table_2(all_results, model_names, k_values=None):
    """Generate the per-model comparison table ("Table 2") and save it as CSV.

    Args:
        all_results: list of result dicts with 'model', 'k', 'precision',
            'recall', 'ndcg' and either 'search_time_mean' (+ 'search_time_std')
            or the legacy 'search_time_sec'.
        model_names: iterable of model names, one table row each.
        k_values: k values to average over; defaults to ``[10]``. (``None``
            replaces the former mutable list default.)

    Returns:
        DataFrame with average precision/recall/NDCG and search time per
        model. It is also written to ``table_2_model_comparison.csv`` in
        ``OUTPUT_DIR``. Models without matching rows get NaN averages.
    """
    if k_values is None:
        k_values = [10]

    df = pd.DataFrame(all_results)

    # Restrict to the requested k values. No rows survive if results lack 'k'.
    if 'k' in df.columns:
        df_filtered = df[df['k'].isin(k_values)]
    else:
        df_filtered = df.iloc[0:0]

    table_data = []
    for model in model_names:
        if 'model' in df_filtered.columns:
            model_data = df_filtered[df_filtered['model'] == model]
        else:
            model_data = df_filtered

        if 'search_time_mean' in model_data.columns:
            avg_search_time = model_data['search_time_mean'].mean()
            std_search_time = _table2_column_mean(model_data, 'search_time_std')
        else:
            # Legacy rows only recorded a single timing, so no spread is available.
            avg_search_time = _table2_column_mean(model_data, 'search_time_sec')
            std_search_time = 0

        table_data.append({
            'Model': model,
            'Average Precision': _table2_column_mean(model_data, 'precision'),
            'Average Recall': _table2_column_mean(model_data, 'recall'),
            'Average NDCG': _table2_column_mean(model_data, 'ndcg'),
            'Average Search Time (s)': avg_search_time,
            'Std Search Time (s)': std_search_time
        })

    table_df = pd.DataFrame(table_data)
    table_df.to_csv(os.path.join(OUTPUT_DIR, 'table_2_model_comparison.csv'), index=False)

    print("Table 2 (Model Comparison) created and saved")
    return table_df

# ====================== Main Execution =======================

if __name__ == "__main__":
    directories = ['../results/resized_low_energy', '../results/resized_subtracted']

    # Inputs are sent to CUDA when it is available, so the models must be moved there too.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # List of models for comparison (name -> pretrained torchvision backbone)
    model_names = {
        'DenseNet121': models.densenet121(pretrained=True),
        'ResNet50': models.resnet50(pretrained=True),
        'VGG16': models.vgg16(pretrained=True),
        'EfficientNet': models.efficientnet_b0(pretrained=True)
    }

    k_values = [1, 5, 10, 20, 50, 100]

    # Use more query images to address reviewer concern about sample size
    # Select images from each BIRADS category to ensure comprehensive evaluation
    query_indices = [0, 10, 20, 30, 50, 100, 150, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 1500]

    # Load BIRADS labels
    birads_labels, birads_categories = load_birads_labels('../annotations/Radiology-manual-annotations(2).xlsx')

    # Flat-L2 main-loop results only; these feed the BIRADS analysis and Table 2.
    all_results = []
    # L2-vs-Cosine results are kept separately. They exist only for some
    # models and include Cosine rows plus duplicate L2 rows, so merging them
    # into all_results skewed those models' Table 2 averages.
    distance_metric_results = []

    for model_name, model in model_names.items():
        print(f"\nProcessing {model_name}...")

        # Replace the classification head so the model outputs embeddings
        if model_name in ['DenseNet121', 'EfficientNet']:
            model.classifier = torch.nn.Identity()
        elif model_name in ['ResNet50']:
            model.fc = torch.nn.Identity()
        elif model_name in ['VGG16']:
            model.classifier[6] = torch.nn.Identity()

        model = model.to(device)
        model.eval()

        # Extract features
        features_array, filenames, images = extract_features_from_dataset(directories, model)

        print(f"Feature vector shape for {model_name}: {features_array.shape}")

        n_labels = len(birads_labels)
        if len(features_array) != n_labels:
            print(f"WARNING: {len(features_array)} feature vectors but {n_labels} BIRADS labels "
                  f"for {model_name}; labels are matched to images by position, so results may be misaligned.")
        # A query needs both a label and a feature vector.
        n_searchable = min(n_labels, len(features_array))

        # The flat L2 index does not depend on the query or k, so build it once per model.
        index = build_faiss_index(features_array, 'L2')

        # Basic evaluation with L2 distance
        results = []

        for query_index in query_indices:
            if query_index >= n_searchable:
                continue

            true_label = birads_labels[query_index]
            total_relevant = birads_labels.count(true_label)
            birads_category = birads_categories[query_index]

            print(f"Query {query_index}: BIRADS category {birads_category}, Total relevant: {total_relevant}")

            for k in k_values:
                distances, indices, search_time_mean, search_time_std = perform_faiss_search(
                    index, query_index, features_array, k=k, distance_metric='L2', n_trials=10
                )

                # NOTE(review): the query vector is itself in the index, at
                # distance 0, so it will typically be the top hit. That
                # inflates precision/NDCG and makes rank_of_true_label 1.
                # Consider searching k+1 and dropping query_index.
                # FAISS pads missing results with -1, which must not wrap to the last label.
                retrieved_labels = [birads_labels[idx] for idx in indices if 0 <= idx < n_labels]

                precision = precision_at_k(retrieved_labels, true_label, k=k)
                recall = recall_at_k(retrieved_labels, true_label, total_relevant, k=k)
                ndcg = ndcg_at_k(retrieved_labels, true_label, k=k)
                rank_of_true = rank_of_true_label(retrieved_labels, true_label)

                result = {
                    'model': model_name,
                    'query_image': query_index,
                    'k': k,
                    'precision': precision,
                    'recall': recall,
                    'ndcg': ndcg,
                    'rank_of_true_label': rank_of_true,
                    'min_distance': min(distances),
                    'max_distance': max(distances),
                    'mean_distance': np.mean(distances),
                    'search_time_mean': search_time_mean,
                    'search_time_std': search_time_std,
                    'birads_category': birads_category
                }

                results.append(result)
                all_results.append(result)

                print(f"k={k}, Precision={precision:.4f}, Recall={recall:.4f}, NDCG={ndcg:.4f}, "
                      f"Search Time={search_time_mean:.8f} ± {search_time_std:.8f}s")

        # Save results for each model
        save_results_to_csv(results, os.path.join(OUTPUT_DIR, f'faiss_detailed_results_{model_name}.csv'))

        # Analyze distance metrics (L2 vs Cosine)
        if model_name in ['DenseNet121', 'ResNet50']:  # Only for selected models to save time
            distance_results = analyze_distance_metrics(
                features_array, birads_labels, model_name, query_indices[:5], k_values
            )
            distance_metric_results.extend(distance_results)

        # Analyze FAISS hyperparameters
        if model_name == 'DenseNet121':  # Only for primary model
            hyperparameter_results = analyze_faiss_hyperparameters(
                features_array, birads_labels, model_name, query_indices[:3]
            )
            # Don't add to all_results as they use different indexing method

        # Create visualizations for each model
        if model_name in ['DenseNet121', 'ResNet50']:  # Create for at least two models
            print(f"\nCreating visualizations for {model_name}...")
            reduced_pca, reduced_tsne = visualize_pca_and_tsne(
                features_array, birads_labels, model_name, n_points=500
            )

    # Analyze by BIRADS category
    print("\nAnalyzing performance by BIRADS category...")
    category_analysis = analyze_by_birads_category(
        all_results, model_names.keys(), k_values
    )

    # Generate missing Table 2
    print("\nGenerating Table 2 for the paper...")
    table_2 = generate_table_2(all_results, model_names.keys())
    print(table_2)