In [1]:
# Simple RDM Generator for Google Colab
# Upload a folder of images and generate a Representational Dissimilarity Matrix

# =============================================================================
# SETUP - Run this cell first
# =============================================================================

# Install required packages
!pip install open-clip-torch scipy

# Import libraries
import os
import re

import cv2
import matplotlib.pyplot as plt
import numpy as np
import open_clip
import scipy.io # Import scipy.io for saving .mat files
import seaborn as sns
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from IPython.display import display
from PIL import Image, ImageOps
from sklearn.metrics.pairwise import cosine_similarity

from google.colab import drive # Import drive here for mounting
from google.colab import files

print("✅ Setup complete!")

# =============================================================================
# CONFIGURATION - Modify these settings if needed
# =============================================================================

# Models that setup_model() knows how to build.
SUPPORTED_MODELS = ('resnet18', 'vgg16', 'alexnet', 'clip')

# Choose your model (options: 'resnet18', 'vgg16', 'alexnet', 'clip')
# If pooledModel is False, this single model will be used.
MODEL_NAME = 'clip'

# Maximum number of images to process (set to None for all images)
MAX_IMAGES = None

# Image file extensions to look for (compared against the lower-cased file name).
# '.tif' is the common short spelling of '.tiff'.
VALID_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')

# True: average (pool) the RDMs of every model in POOLED_MODELS.
# False: use MODEL_NAME only.
pooledModel = True

# Models whose RDMs are averaged when pooledModel is True
POOLED_MODELS = ['resnet18', 'vgg16', 'alexnet', 'clip']

# Fail fast on a configuration typo, before mounting Drive and spending minutes
# extracting features with the models listed ahead of the bad name.
_configured_models = POOLED_MODELS if pooledModel else [MODEL_NAME]
if not _configured_models:
    raise ValueError("POOLED_MODELS is empty; list at least one model to pool.")
_unknown_models = [name for name in _configured_models if name not in SUPPORTED_MODELS]
if _unknown_models:
    raise ValueError(
        f"Unknown model(s) {_unknown_models}; choose from {list(SUPPORTED_MODELS)}"
    )

print("🔧 Configuration:")
if pooledModel:
    print(f"🔧 Using pooled models: {POOLED_MODELS}")
else:
    print(f"🔧 Using single model: {MODEL_NAME}")

if MAX_IMAGES:
    print(f"🔧 Will process maximum {MAX_IMAGES} images")
else:
    print("🔧 Will process all images in folder")

# =============================================================================
# GOOGLE DRIVE SETUP - Connect to your Drive folder
# =============================================================================

# Mount Google Drive (force_remount=True remounts even if Drive is already mounted)
print("📁 Mounting Google Drive...")
drive.mount('/content/drive', force_remount=True)
print("✅ Google Drive mounted successfully!")

# =============================================================================
# FOLDER CONFIGURATION - Set your image folder path here
# =============================================================================

# CHANGE THIS PATH TO YOUR IMAGE FOLDER IN GOOGLE DRIVE
# Paths are relative to the Drive mount point (/content/drive). Examples:
#   'MyDrive/your_folder_name'
#   'MyDrive/experiments/stimuli_set_1'
#   'MyDrive/research/images'

IMAGE_FOLDER = 'MyDrive/CVFolder/stimuli'  # ← CHANGE THIS PATH
OUTPUT_FOLDER = 'MyDrive/CVFolder/output'  # ← Set your output folder path

# Full paths under the Drive mount point
full_image_path = f'/content/drive/{IMAGE_FOLDER}'
full_output_path = f'/content/drive/{OUTPUT_FOLDER}'

print(f"📂 Looking for images in: {IMAGE_FOLDER}")
print(f"📂 Output will be saved to: {OUTPUT_FOLDER}")

# The image path must be an existing directory (a plain file there is also an error)
if not os.path.isdir(full_image_path):
    print("❌ ERROR: Folder not found!")
    print(f"   Expected path: {full_image_path}")
    print("   Please check your IMAGE_FOLDER path above.")
    print("   Make sure the folder exists in your Google Drive.")
    raise FileNotFoundError(f"Image folder not found: {IMAGE_FOLDER}")

# Create output folder if it doesn't exist
if not os.path.exists(full_output_path):
    print(f"Creating output folder: {full_output_path}")
    os.makedirs(full_output_path, exist_ok=True)
    print("✅ Output folder created!")

# Collect image files. Hidden entries (e.g. macOS "._name.jpg" AppleDouble files,
# which match the extension filter but are not decodable images) and anything
# that is not a regular file are skipped.
all_files = os.listdir(full_image_path)
image_files = sorted(
    name
    for name in all_files
    if name.lower().endswith(VALID_EXTENSIONS)
    and not name.startswith('.')
    and os.path.isfile(os.path.join(full_image_path, name))
)
# NOTE: plain lexicographic order ('img10' sorts before 'img2'); this order
# defines the rows/columns of every RDM produced below.

if len(image_files) == 0:
    print(f"❌ ERROR: No images found in {IMAGE_FOLDER}")
    print(f"   Looking for files with extensions: {VALID_EXTENSIONS}")
    print(f"   Files found: {all_files[:10]}{'...' if len(all_files) > 10 else ''}")
    raise ValueError("No valid image files found")

if MAX_IMAGES:
    image_files = image_files[:MAX_IMAGES]
    print(f"🔢 Limited to first {MAX_IMAGES} images")

print(f"✅ Found {len(image_files)} images:")
print(f"   {image_files[:5]}{'...' if len(image_files) > 5 else ''}")

# Set the images directory for processing
images_dir = full_image_path
# =============================================================================
# MODEL SETUP
# =============================================================================

def setup_model(model_name):
    """Build a pretrained feature extractor and its matching preprocessing.

    Args:
        model_name: One of 'resnet18', 'vgg16', 'alexnet' or 'clip'.

    Returns:
        (feature_extractor, preprocess):
          * torchvision models: an ``nn.Sequential`` in eval mode, plus an
            ImageNet transform that begins with ``ToPILImage`` and therefore
            expects an HxWx3 uint8 NumPy array (or a tensor).
          * 'clip': the bound ``encode_image`` method of an eval-mode
            ViT-B-32 model, plus open_clip's validation (inference) transform,
            which expects a PIL image.

    Raises:
        ValueError: if ``model_name`` is not supported (checked before any
            model weights are loaded).
    """
    if model_name not in ('resnet18', 'vgg16', 'alexnet', 'clip'):
        raise ValueError(f"Unknown model: {model_name}")

    if model_name == 'clip':
        # create_model_and_transforms returns (model, train_transform, val_transform).
        # The train transform applies random augmentation, so the deterministic
        # validation transform (third element) is the one for feature extraction.
        model, _, preprocess = open_clip.create_model_and_transforms(
            "ViT-B-32", pretrained='openai'
        )
        # encode_image is a bound method, not a Module, so switch the model itself
        # to inference mode.
        model.eval()
        return model.encode_image, preprocess

    # ImageNet preprocessing shared by the torchvision CNNs
    preprocess = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    if model_name == 'resnet18':
        model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        # Every child module except the last one (the final classification layer)
        feature_extractor = torch.nn.Sequential(*list(model.children())[:-1])
    elif model_name == 'vgg16':
        # NOTE: this is the batch-normalised variant (vgg16_bn), not plain VGG16.
        model = models.vgg16_bn(weights=models.VGG16_BN_Weights.IMAGENET1K_V1)
        # Convolutional features followed by the model's avgpool; no classifier head
        feature_extractor = torch.nn.Sequential(*model.features.children(), model.avgpool)
    else:  # 'alexnet'
        model = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)
        # Convolutional features followed by the model's avgpool; no classifier head
        feature_extractor = torch.nn.Sequential(*model.features.children(), model.avgpool)

    # Inference mode (affects e.g. the BatchNorm layers of vgg16_bn)
    feature_extractor.eval()
    return feature_extractor, preprocess

def extract_features(image_path, model, preprocess, model_name):
    """Extract a flattened feature vector for one image.

    Args:
        image_path: Path to an image file readable by PIL.
        model: Feature extractor returned by ``setup_model``.
        preprocess: The matching transform returned by ``setup_model``.
        model_name: Name passed to ``setup_model``. The 'clip' transform takes a
            PIL image; the other models' transforms start with ``ToPILImage``
            and take an HxWx3 uint8 RGB array.

    Returns:
        1-D NumPy array holding the flattened model output.

    Raises:
        OSError: if the file is missing or cannot be decoded
            (``PIL.UnidentifiedImageError`` is a subclass).
    """
    # Decode with PIL for every model so all models see the same pixels.
    # EXIF orientation is applied explicitly (PIL does not do it on its own), so
    # rotated photos are upright for every model. The file handle is closed on exit.
    with Image.open(image_path) as img:
        image = ImageOps.exif_transpose(img).convert('RGB')

    model_input = image if model_name == 'clip' else np.array(image)
    image_tensor = preprocess(model_input).unsqueeze(0)  # add a batch dimension of 1

    with torch.no_grad():
        features = model(image_tensor)

    return features.flatten().cpu().numpy()

# =============================================================================
# FEATURE EXTRACTION
# =============================================================================

# Feature matrices (one row per image, in image_files order), keyed by model name
features_by_model = {}

models_to_process = POOLED_MODELS if pooledModel else [MODEL_NAME]

for current_model_name in models_to_process:
    print(f"\n🤖 Setting up {current_model_name} model...")
    feature_extractor, preprocess = setup_model(current_model_name)
    print("✅ Model ready!")

    print(f"🔄 Extracting features from {len(image_files)} images using {current_model_name}...")
    feature_rows = []
    for done, img_file in enumerate(image_files, start=1):
        feature_rows.append(
            extract_features(
                os.path.join(images_dir, img_file),
                feature_extractor,
                preprocess,
                current_model_name,
            )
        )
        # Report progress every 10 images and once at the very end
        if done % 10 == 0 or done == len(image_files):
            print(f"   Processed {done}/{len(image_files)} images...")

    feature_matrix = np.array(feature_rows)
    features_by_model[current_model_name] = feature_matrix
    print(f"✅ Feature extraction complete for {current_model_name}! Shape: {feature_matrix.shape}")

# =============================================================================
# RDM GENERATION AND VISUALIZATION
# =============================================================================

def create_rdm_plot(rdm_data, image_names, title_suffix, output_path, output_filename=None):
    """Plot an RDM as a heatmap, save it as a 300-dpi PNG and display it.

    Args:
        rdm_data: Square dissimilarity matrix (one row/column per image).
        image_names: Tick labels, one per row/column of ``rdm_data``.
        title_suffix: Second line of the title, e.g. 'ALEXNET Model' or 'Pooled Models'.
        output_path: Directory the PNG is written to.
        output_filename: Optional file name. When omitted it is derived from
            ``title_suffix``: 'rdm_matrix_pooled.png' if it contains 'Pooled',
            otherwise 'rdm_matrix_<slug>.png', where the slug is the suffix
            without ' Model', lower-cased, with every run of characters other
            than a-z/0-9 replaced by '_' ('ALEXNET Model' -> 'rdm_matrix_alexnet.png').

    Returns:
        The file name (not the full path) of the saved PNG.
    """
    fig, ax = plt.subplots(figsize=(12, 10))

    # Cells are not annotated, so no number format is needed.
    sns.heatmap(rdm_data,
                cmap='viridis',
                square=True,
                xticklabels=image_names,
                yticklabels=image_names,
                cbar_kws={'label': 'Dissimilarity'},
                ax=ax)

    # Rotate labels for better readability
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    plt.setp(ax.get_yticklabels(), rotation=0)

    ax.set_title(f'Representational Dissimilarity Matrix\n{title_suffix}',
                 fontsize=16, fontweight='bold')
    fig.tight_layout()

    if output_filename is None:
        if 'Pooled' in title_suffix:
            output_filename = 'rdm_matrix_pooled.png'
        else:
            # Slugify so suffixes such as 'RESNET18 Model (Individual)' yield a
            # clean name without spaces or parentheses.
            slug = re.sub(r'[^a-z0-9]+', '_',
                          title_suffix.replace(' Model', '').lower()).strip('_')
            output_filename = f'rdm_matrix_{slug}.png' if slug else 'rdm_matrix.png'

    fig.savefig(os.path.join(output_path, output_filename), dpi=300, bbox_inches='tight')
    plt.show()
    # Free the figure; with a non-interactive backend plt.show() leaves it open.
    plt.close(fig)

    return output_filename

def average_rdms(rdm_list):
    """Return the element-wise mean of several equally shaped RDMs.

    Args:
        rdm_list: Non-empty iterable of array-likes that all have the same shape
            (e.g. one square RDM per model).

    Returns:
        ``np.ndarray`` of that common shape holding the mean across the inputs.

    Raises:
        ValueError: if ``rdm_list`` is empty or the shapes differ.
    """
    rdms = [np.asarray(rdm) for rdm in rdm_list]
    if not rdms:
        # np.mean over an empty stack would silently give NaN with only a warning
        raise ValueError("average_rdms() needs at least one RDM")
    shapes = {rdm.shape for rdm in rdms}
    if len(shapes) != 1:
        raise ValueError(f"All RDMs must have the same shape; got {sorted(shapes)}")
    return np.mean(np.stack(rdms), axis=0)

# =============================================================================
# DOWNLOAD RESULTS
# =============================================================================

def download_results(rdm_matrix, output_path, model_name_identifier):
    """Write an RDM into ``output_path`` as both a .npy and a .mat file.

    The files are named ``rdm_data_<identifier>.npy`` and
    ``rdm_data_<identifier>.mat`` with the identifier lower-cased; inside the
    .mat file the matrix is stored under the variable name ``rdm_matrix``.
    Despite the function name, nothing is downloaded, and the heatmap image is
    not written here (``create_rdm_plot`` saves it). The printed locations are
    joined onto the module-level ``OUTPUT_FOLDER``.
    """
    stem = f'rdm_data_{model_name_identifier.lower()}'
    npy_filename = stem + '.npy'
    mat_filename = stem + '.mat'

    np.save(os.path.join(output_path, npy_filename), rdm_matrix)
    scipy.io.savemat(os.path.join(output_path, mat_filename), {'rdm_matrix': rdm_matrix})

    print("✅ Results data saved to Google Drive!")
    saved = (
        (npy_filename, "Raw RDM data (can be loaded with np.load())"),
        (mat_filename, "Raw RDM data in .mat format"),
    )
    for filename, description in saved:
        print(f"   - {os.path.join(OUTPUT_FOLDER, filename)}: {description}")


# Generate and display RDM(s)
print("📊 Creating RDM(s)...")

# Use just the filename (without extension) for cleaner labels
image_labels = [os.path.splitext(name)[0] for name in image_files]


def _cosine_rdm(feature_matrix):
    """Cosine-distance RDM (1 - cosine similarity) between the rows of ``feature_matrix``.

    Floating-point round-off in ``cosine_similarity`` (notably with float32
    features) can leave the diagonal at e.g. -1e-7 instead of 0 and the matrix
    very slightly asymmetric, which RDM consumers such as
    ``scipy.spatial.distance.squareform`` reject. The result is therefore
    symmetrised, clipped to the valid cosine-distance range [0, 2] and given
    an exact zero diagonal.
    """
    rdm = 1.0 - cosine_similarity(feature_matrix)
    rdm = (rdm + rdm.T) / 2.0
    np.clip(rdm, 0.0, 2.0, out=rdm)
    np.fill_diagonal(rdm, 0.0)
    return rdm


if pooledModel:
    print("Calculating RDMs for each model for pooling...")
    rdm_list = []
    for current_model_name, features in features_by_model.items():
        print(f"  Calculating RDM for {current_model_name}...")
        rdm_list.append(_cosine_rdm(features))
        # Optionally, visualize individual RDMs during pooling (can be commented out)
        # create_rdm_plot(rdm_list[-1], image_labels, f'{current_model_name.upper()} Model (Individual)', full_output_path)
    print("✅ Individual RDMs calculated!")

    # Average the RDMs
    print("Averaging individual RDMs...")
    pooled_rdm = average_rdms(rdm_list)
    print("✅ RDMs averaged!")

    # Visualize the pooled RDM
    print("📊 Visualizing pooled RDM...")
    pooled_rdm_filename = create_rdm_plot(pooled_rdm, image_labels, 'Pooled Models', full_output_path)
    print("✅ Pooled RDM visualization created!")
    print(f"💾 Pooled RDM plot saved as '{os.path.join(OUTPUT_FOLDER, pooled_rdm_filename)}'")

    # Save the pooled RDM data
    print("💾 Saving pooled RDM data...")
    download_results(pooled_rdm, full_output_path, 'pooled')  # 'pooled' names the output files
    print("✅ Pooled RDM data saved!")

    # Count the RDMs actually pooled (duplicate names in POOLED_MODELS collapse in the dict)
    summary_line = f"📊 Generated pooled RDM for {len(image_files)} images using {len(rdm_list)} models"
    next_steps_line = "⚙️  To use a different model or change pooling, modify configuration"
else:
    # Single-model RDM
    single_model_name = models_to_process[0]
    all_features = features_by_model[single_model_name]
    rdm_matrix = _cosine_rdm(all_features)

    rdm_filename = create_rdm_plot(rdm_matrix, image_labels, f'{single_model_name.upper()} Model', full_output_path)

    print("✅ RDM created successfully!")
    print(f"💾 RDM saved as '{os.path.join(OUTPUT_FOLDER, rdm_filename)}'")

    download_results(rdm_matrix, full_output_path, single_model_name)

    summary_line = f"📊 Generated RDM for {len(image_files)} images using {single_model_name}"
    next_steps_line = "⚙️  To use a different model, change MODEL_NAME and restart"

print("\n" + "=" * 60)
print("🎉 ANALYSIS COMPLETE!")
print("=" * 60)
print(summary_line)
print(f"💾 Results saved to {full_output_path}")
print("🔄 To process new images, change IMAGE_FOLDER path and restart")
print(next_steps_line)

Collecting open-clip-torch
  Downloading open_clip_torch-3.0.0-py3-none-any.whl.metadata (32 kB)
Collecting torch>=2.0 (from open-clip-torch)
  Downloading torch-2.7.1-cp312-cp312-win_amd64.whl.metadata (28 kB)
Collecting torchvision (from open-clip-torch)
  Downloading torchvision-0.22.1-cp312-cp312-win_amd64.whl.metadata (6.1 kB)
Collecting ftfy (from open-clip-torch)
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Collecting huggingface-hub (from open-clip-torch)
  Downloading huggingface_hub-0.34.3-py3-none-any.whl.metadata (14 kB)
Collecting safetensors (from open-clip-torch)
  Downloading safetensors-0.5.3-cp38-abi3-win_amd64.whl.metadata (3.9 kB)
Collecting timm>=1.0.17 (from open-clip-torch)
  Downloading timm-1.0.19-py3-none-any.whl.metadata (60 kB)
     ---------------------------------------- 0.0/60.8 kB ? eta -:--:--
     ---------------------------------------- 60.8/60.8 kB 3.2 MB/s eta 0:00:00
Collecting sympy>=1.13.3 (from torch>=2.0->open-clip-torch)
  Downl

ModuleNotFoundError: No module named 'cv2'