In [None]:
from vislearnlabpy.embeddings.generate_embeddings import EmbeddingGenerator
from vislearnlabpy.embeddings.embedding_store import EmbeddingStore
from vislearnlabpy.embeddings.stimuli_loader import ImageExtractor
from vislearnlabpy.embeddings.utils import display_search_results, zscore_embeddings, filter_embeddings 
import numpy as np
import pandas as pd
from pathlib import Path

In [2]:
# Image preprocessing handed to the CLIP embedding generator: resize to 256,
# crop to 224, content crop on, center crop off, thumbnail mode on.
transforms_thumbnail = ImageExtractor.get_transformations(
    use_thumbnail=True,
    apply_center_crop=False,
    apply_content_crop=True,
    crop_dim=224,
    resize_dim=256,
)

# NOTE(review): the original inline comment here was cut off ("change
# device="cpu" if you are using"); presumably device="cpu" should be used on
# machines where the "mps" device is unavailable — confirm.
clip_generator = EmbeddingGenerator(
    model_type="clip",
    device="mps",
    output_type="doc",
    transform=transforms_thumbnail,
)

First, generate CLIP image embeddings for the Beijing and CDM drawings.

In [29]:
# Input folders for the drawing images (absolute paths on the lab volume).
beijing_dir = Path(
    "/Volumes/vislearnlab/experiments/drawing/data/beijing/object_drawings"
)
kisumu_dir = Path(
    "/Volumes/vislearnlab/experiments/drawing/data/kisumu/drawings"
)

# The twelve drawing categories, in the order used for the text embeddings.
categories = [
    "airplane",
    "bike",
    "bird",
    "car",
    "cat",
    "chair",
    "cup",
    "hat",
    "house",
    "rabbit",
    "tree",
    "watch",
]

In [None]:
# Embed every image under beijing_dir in batches of 100, overwriting any
# previous output written to "beijing_cdm_drawings".
clip_generator.generate_image_embeddings(
    input_dir=beijing_dir,
    output_path="beijing_cdm_drawings",
    batch_size=100,
    overwrite=True,
)

Next, embed the Kisumu drawings and generate a text embedding for each category using the prompt "drawing of a {category}".

In [None]:
# Image embeddings for the Kisumu drawings (same batch size / overwrite
# settings as the Beijing run).
clip_generator.generate_image_embeddings(
    input_dir=kisumu_dir,
    output_path="kisumu_drawings",
    batch_size=100,
    overwrite=True,
)

# Blank out the model-level text prompt before embedding the category prompts.
# NOTE(review): presumably this stops the model from adding its own prompt
# prefix, since the full prompt is spelled out below — confirm.
clip_generator.model.text_prompt = ""

# One text embedding per category; unlike the image runs this is saved with
# overwrite=False.
clip_generator.save_text_embeddings(
    [f"drawing of a {category}" for category in categories],
    "categories",
    overwrite=False,
)

In [None]:
# Load the saved embedding documents back from disk: one image store per
# embedding run plus the category text embeddings.
kisumu_store = EmbeddingStore.from_doc(
    "kisumu_drawings/image_embeddings/clip_image_embeddings_doc"
)
beijing_cdm_store = EmbeddingStore.from_doc(
    "beijing_cdm_drawings/image_embeddings/clip_image_embeddings_doc"
)
text_embeddings = EmbeddingStore.from_doc(
    "categories/text_embeddings/clip_text_embeddings_doc"
)

In [26]:
# Keep only drawings made from verbal cues (URL contains "S_"), dropping the
# picture-cue drawings, and route each kept drawing to the CDM or Beijing store.
# NOTE(review): both substring checks run against the whole URL, not just the
# file name — confirm no directory component can contain "S_" or "CDM".
cdm_store = EmbeddingStore()
beijing_store = EmbeddingStore()
for doc in beijing_cdm_store.EmbeddingList:
    if "S_" not in doc.url:
        # picture cue: skipped
        continue
    target_store = cdm_store if "CDM" in doc.url else beijing_store
    target_store.add_embedding(embedding=doc.embedding, url=doc.url)

In [None]:
# Report how many embeddings each site-specific store holds.
for site_label, site_store in (
    ("Kisumu", kisumu_store),
    ("Beijing", beijing_store),
    ("CDM", cdm_store),
):
    print(f"{site_label} embeddings: {len(site_store.EmbeddingList)}")

Extract metadata about location, category and age from the file names and calculate probability of the correct category being chosen

In [5]:
import os
import re
from vislearnlabpy.embeddings.similarity_utils import calculate_probability
def extract_beijing_metadata(url):
    """Parse location, age, category and participant id from a Beijing/CDM file name.

    Only the base name of ``url`` is inspected:

    * location: "Beijing" when the name contains "THU", otherwise "USA".
    * age: the integer following "age", or None when no such token exists.
    * category: the second underscore-separated token, or None when the name
      has no underscore.
    * participant_id: the last underscore-separated token once every ".png"
      occurrence has been removed from the name.

    Returns:
        dict with keys 'location', 'age', 'category', 'participant_id'.
    """
    fname = os.path.basename(url)
    tokens = fname.split('_')

    found_age = re.search(r'age(\d+)', fname)
    if found_age is None:
        age = None
    else:
        age = int(found_age.group(1))

    return {
        'location': "Beijing" if "THU" in fname else "USA",
        'age': age,
        'category': tokens[1] if len(tokens) > 1 else None,
        # last "_"-separated piece of the name with ".png" stripped out
        'participant_id': fname.replace('.png', '').rsplit("_", 1)[-1],
    }

def extract_kisumu_metadata(url):
    """Parse age, category and participant id from a Kisumu file name.

    Only the base name of ``url`` is inspected:

    * participant_id: everything before the first underscore.
    * age: the integer following "age" plus 3, or None when there is no
      "age<digits>" token.
      NOTE(review): the reason for the +3 offset is not documented here — confirm.
    * category: the token between "age<digits>_" and "_trial", with
      "Bicycle" renamed to "bike"; None when the pattern does not match.

    Returns:
        dict with keys 'location' (always 'Kisumu'), 'age', 'category',
        'participant_id'.
    """
    fname = os.path.basename(url)

    age_hit = re.search(r'age(\d+)', fname)
    category_hit = re.search(r'age\d+_([^_]+)_trial', fname)

    label = category_hit.group(1) if category_hit else None
    if label == "Bicycle":
        # align with the category vocabulary used for the other sites
        label = "bike"

    return {
        'location': 'Kisumu',
        'age': None if age_hit is None else int(age_hit.group(1)) + 3,
        'category': label,
        'participant_id': fname.partition('_')[0],
    }

def _process_site(
    text_embeddings,
    docs,
    extract_metadata_fn,
    recognizability_fn,
    results,
):
    """Helper to process embeddings for a single site and extend `results`."""
    for doc in docs:
        md = extract_metadata_fn(doc.url)
        if md["category"] and md["age"] is not None:
            results.append(
                {
                    "location": md["location"],
                    "recognizability": recognizability_fn(
                        doc.embedding,
                        text_embeddings,     
                        md["category"].lower()
                    ),
                    "age": md["age"],
                    "participant_id": md["participant_id"],
                    "drawing_category": md["category"].lower(),
                    "url": doc.url,
                }
            )

def process_embeddings(text_embeddings, beijing_list, cdm_list, kisumu_list):
    """Build the combined list of result rows for all three sites.

    Each site is handed to ``_process_site`` with ``calculate_probability``
    as the recognizability function. Rows are appended in the order
    Beijing, Kisumu, CDM; CDM reuses the Beijing file-name parser.

    Returns:
        list of per-drawing result dicts.
    """
    results = []
    # (documents, file-name parser) per site, in output order
    site_plan = (
        (beijing_list, extract_beijing_metadata),
        (kisumu_list, extract_kisumu_metadata),
        (cdm_list, extract_beijing_metadata),
    )
    for site_docs, metadata_parser in site_plan:
        _process_site(
            text_embeddings,
            site_docs,
            metadata_parser,
            calculate_probability,
            results,
        )
    return results


# Main processing function
def create_drawing_analysis_csv(text_embeddings, beijing_list, cdm_list, kisumu_list, output_file='drawing_analysis.csv'):
    """
    Process all embeddings, write the result table to CSV and print a summary.

    Args:
        text_embeddings: category text embeddings passed through to
            ``process_embeddings``.
        beijing_list, cdm_list, kisumu_list: per-site image embedding documents.
        output_file: path of the CSV to write (written without the index).

    Returns:
        pandas DataFrame with one row per processed drawing. When no drawing
        could be processed the frame is empty and the per-column summary is
        skipped.
    """

    # Process all embeddings
    results = process_embeddings(text_embeddings, beijing_list, cdm_list, kisumu_list)
    df = pd.DataFrame(results)
    df.to_csv(output_file, index=False)
    print(f"Results saved to {output_file}")
    print(f"Total drawings: {len(results)}")

    if df.empty:
        # A frame built from zero rows has no columns, so the column lookups
        # below would raise KeyError; report the situation instead.
        print("No drawings with parseable metadata were found; skipping summary.")
        return df

    print(f"Locations: {df['location'].value_counts().to_dict()}")
    print(f"Age range: {df['age'].min()} - {df['age'].max()}")
    print(f"Categories: {df['drawing_category'].value_counts().to_dict()}")

    return df



In [None]:
# Build the per-drawing recognizability table (also written to the default
# CSV path by create_drawing_analysis_csv).
recognizability_df = create_drawing_analysis_csv(
    text_embeddings.EmbeddingList,
    beijing_store.EmbeddingList,
    cdm_store.EmbeddingList,
    kisumu_store.EmbeddingList,
)

Next, attach each drawing's category label to its entry in the embedding stores.

In [27]:
# Row lookup keyed by drawing URL: {url: {column: value, ...}}
url_to_fields = recognizability_df.set_index('url').to_dict(orient='index')

# Copy the category onto every document as its `text`; documents whose URL is
# not in the table (e.g. skipped for unparseable metadata) get None.
for store in (beijing_store, kisumu_store, cdm_store):
    for doc in store.EmbeddingList:
        doc.text = url_to_fields.get(doc.url, {}).get('drawing_category')

Calculating centroid distances

In [7]:
import pandas as pd
import numpy as np
from collections import defaultdict
from vislearnlabpy.embeddings.similarity_utils import cosine_sim
from docarray.utils.filter import filter_docs

def calculate_centroid_distances(df, kisumu_store, beijing_store, cdm_store):
    """
    Compute each drawing's distance to the centroid of its age/location/category group.

    A centroid is the element-wise mean of the embeddings of all drawings
    sharing the same ('age', 'location', 'drawing_category') values in ``df``.
    Embeddings are looked up by URL in the store matching the row's location:
    'kisumu' -> ``kisumu_store``, 'beijing' -> ``beijing_store``, 'usa' ->
    ``cdm_store`` (case-insensitive substring match). Any other location has
    no embeddings.

    Args:
        df: DataFrame with 'age', 'location', 'drawing_category' and 'url' columns.
        kisumu_store, beijing_store, cdm_store: objects exposing an
            ``EmbeddingList`` of documents with ``url`` and ``embedding``.

    Returns:
        (df_result, centroids):
        df_result is a copy of ``df`` with two added columns — 'distance'
        (1 - cosine_sim to the centroid) and 'distance_euclidan' (Euclidean
        norm of the difference; the misspelled name is kept so existing
        consumers of the output keep working). Rows without a centroid or
        without an embedding get NaN in both columns.
        centroids maps (age, location, category) to the mean embedding, or
        None when no embedding was found for the group.
    """

    # Build one url -> embedding table per site up front so that each lookup
    # is a dict access instead of a scan over the whole embedding list.
    def build_lookup(embedding_list):
        lookup = {}
        for doc in embedding_list:
            # keep the first document seen for a URL
            lookup.setdefault(doc.url, doc.embedding)
        return lookup

    site_lookups = {
        'kisumu': build_lookup(kisumu_store.EmbeddingList),
        'beijing': build_lookup(beijing_store.EmbeddingList),
        'usa': build_lookup(cdm_store.EmbeddingList),
    }

    # Helper function to get the url -> embedding table based on location
    def get_lookup(location):
        location_lower = location.lower()
        for site_key, lookup in site_lookups.items():
            if site_key in location_lower:
                return lookup
        return {}

    # Group by age, location and category to calculate centroids
    centroids = {}

    for (age, location, category), group in df.groupby(['age', 'location', 'drawing_category']):
        lookup = get_lookup(location)
        embeddings = []

        # Collect all embeddings for this age-location-category combination
        for url in group['url']:
            embedding = lookup.get(url)
            if embedding is not None:
                embeddings.append(np.array(embedding))

        # Calculate centroid if we have embeddings
        if embeddings:
            centroids[(age, location, category)] = np.mean(embeddings, axis=0)
        else:
            print(f"Warning: No embeddings found for age={age}, category={category}, location={location}")
            centroids[(age, location, category)] = None

    # Calculate distances for each row in the dataframe. Both lists receive
    # exactly one entry per row so they always line up with the frame's index.
    distances = []
    distances_euclidean = []
    for _, row in df.iterrows():
        location = row['location']
        centroid = centroids.get((row['age'], location, row['drawing_category']))
        embedding = get_lookup(location).get(row['url']) if centroid is not None else None

        if centroid is None or embedding is None:
            distances.append(np.nan)
            distances_euclidean.append(np.nan)
            continue

        embedding = np.stack(embedding)
        # Cosine distance calculation
        distances.append(1 - cosine_sim(embedding, np.stack(centroid)))
        distances_euclidean.append(np.linalg.norm(embedding - centroid))

    df_result = df.copy()
    df_result['distance'] = distances
    df_result['distance_euclidan'] = distances_euclidean
    return df_result, centroids


# Add centroid distances to the recognizability table and keep the centroids.
df_with_recognizability_distances, output_centroids = calculate_centroid_distances(
    df=recognizability_df,
    kisumu_store=kisumu_store,
    beijing_store=beijing_store,
    cdm_store=cdm_store,
)

In [8]:
from docarray import DocList
from vislearnlabpy.embeddings.embedding_store import CLIPImageEmbedding
# Concatenate the three site-specific document lists (CDM, then Beijing, then
# Kisumu) and wrap them in a single store for the similarity computations.
all_site_docs = (
    cdm_store.EmbeddingList
    + beijing_store.EmbeddingList
    + kisumu_store.EmbeddingList
)
full_embedding_store = EmbeddingStore()
full_embedding_store.EmbeddingList = DocList[CLIPImageEmbedding](all_site_docs)

Retrieving all relevant cosine similarity values

In [17]:
# Image-to-image similarities across all drawings.
# Caveat: image_sims ends up as a ~400M file, so consider filtering with the
# text_pairs param, or not computing similarities across all drawings.
image_sims = full_embedding_store.retrieve_similarities()

# Text-to-text similarities between the category prompts, saved to CSV.
text_sims = text_embeddings.retrieve_similarities(
    output_path="text_sims.csv",
)

# Image-to-text similarity values.
cross_sims = full_embedding_store.retrieve_cross_similarity(
    text_embeddings.EmbeddingList,
)

Store each drawing's file name (rather than its full path) as the drawing ID: file names are unique and easier to read.

In [None]:
# Replace full paths with bare file names so drawings can be identified by a
# short ID. Each column is converted with Series.apply rather than
# DataFrame.applymap, which is deprecated in recent pandas releases.
for path_column in ['text1', 'text2']:
    image_sims[path_column] = image_sims[path_column].apply(lambda p: Path(p).name)
cross_sims['text1'] = cross_sims['text1'].apply(lambda p: Path(p).name)
# The recognizability table keeps its 'url' column and gains an 'id' column.
df_with_recognizability_distances['id'] = df_with_recognizability_distances['url'].apply(lambda p: Path(p).name)

Saving similarities and recognizability values with new paths

In [None]:
# Image-to-text similarities.
cross_sims.to_csv("cross_sims.csv")

# Recognizability + centroid distances, with the full-path 'url' column
# dropped before writing.
recognizability_without_urls = df_with_recognizability_distances.drop(columns=['url'])
recognizability_without_urls.to_csv("../data/clip_recognizability.csv")

# Image-to-image similarities.
image_sims.to_csv("image_sims.csv")

Plotting RDMs

In [85]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from vislearnlabpy.embeddings.similarity_utils import correlate_rdms

def plot_heatmaps(matrices, categories, titles, corr_values=None, cmap="viridis", vmin=None, vmax=None, figsize=(18,6), center=None, cbar=True, suptitle=None):
    """Draw several square matrices side by side as heatmaps on a shared colour scale.

    Args:
        matrices: sequence of 2-D arrays, one per panel.
        categories: tick labels used on both axes of every panel.
        titles: one title per panel. A title containing " vs " is also split
            at its first " vs " to label the x axis (left part) and y axis
            (right part).
        corr_values: optional sequence (list, tuple or array) with one number
            per panel, shown under the title as "Mean diagonal: <value>".
            None or an empty sequence adds nothing.
        cmap, center: forwarded to ``sns.heatmap``.
        vmin, vmax: colour limits; default to the min / max over all matrices.
        figsize: figure size passed to ``plt.subplots``.
        cbar: whether to draw a colour bar (on the last panel only).
        suptitle: optional figure-level title.
    """
    n = len(matrices)
    fig, axs = plt.subplots(1, n, figsize=figsize)
    if n == 1:
        # plt.subplots returns a bare Axes for a single panel
        axs = [axs]
    if vmin is None:
        vmin = min(m.min() for m in matrices)
    if vmax is None:
        vmax = max(m.max() for m in matrices)

    # Explicit None / length check: plain truthiness raises for NumPy arrays
    # with more than one element.
    has_corr_values = corr_values is not None and len(corr_values) > 0

    for i, (mat, title) in enumerate(zip(matrices, titles)):
        sns.heatmap(mat, ax=axs[i], xticklabels=categories, yticklabels=categories,
                    cmap=cmap, vmin=vmin, vmax=vmax, center=center, cbar=cbar if i == n-1 else False)
        corr_text = f"\nMean diagonal: {corr_values[i]:.3f}" if has_corr_values else ""
        axs[i].set_title(title + corr_text)
        if " vs " in title:
            # maxsplit=1 keeps the unpacking valid when " vs " occurs more than once
            # NOTE(review): the heatmap puts matrix columns on x and rows on y;
            # confirm "A vs B" -> x=A, y=B matches how the matrices are built.
            x_label, y_label = title.split(" vs ", 1)
            axs[i].set_xlabel(x_label.strip())
            axs[i].set_ylabel(y_label.strip())
    if suptitle is not None:
        fig.suptitle(suptitle, fontsize=16)
    plt.tight_layout()
    plt.show()

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from vislearnlabpy.embeddings.similarity_utils import correlate_rdms

# One RDM per site, computed from each store's documents.
cdm_rdm = cdm_store.compute_text_rdm()
beijing_rdm = beijing_store.compute_text_rdm()
kisumu_rdm = kisumu_store.compute_text_rdm()

# Reorder rows/columns so related categories sit next to each other.
# TODO: support custom ordering for our rdm functions
# NOTE(review): this assumes compute_text_rdm orders rows/columns like
# `categories` — confirm.
rdm_categories = [
    "airplane", "bike", "car",
    "bird", "cat", "rabbit",
    "tree", "house",
    "chair", "cup", "hat", "watch",
]
reorder_idx = list(map(categories.index, rdm_categories))
reorder_selector = np.ix_(reorder_idx, reorder_idx)
cdm_rdm = cdm_rdm[reorder_selector]
beijing_rdm = beijing_rdm[reorder_selector]
kisumu_rdm = kisumu_rdm[reorder_selector]

# Pairwise RDM correlations between the sites.
corr_beijing_kisumu = correlate_rdms(beijing_rdm, kisumu_rdm)
corr_beijing_cdm = correlate_rdms(beijing_rdm, cdm_rdm)
corr_kisumu_cdm = correlate_rdms(kisumu_rdm, cdm_rdm)

rdm_suptitle = (
    f"Correlations:\nBeijing-Kisumu: {corr_beijing_kisumu:.3f} | "
    f"Beijing-CDM: {corr_beijing_cdm:.3f} | Kisumu-CDM: {corr_kisumu_cdm:.3f}"
)
plot_heatmaps(
    [beijing_rdm, kisumu_rdm, cdm_rdm],
    categories=rdm_categories,
    titles=["Beijing RDM", "Kisumu RDM", "CDM RDM"],
    corr_values=None,
    cmap="viridis",
    figsize=(18,6),
    cbar=True,
    suptitle=rdm_suptitle,
)

In [None]:
from vislearnlabpy.embeddings.similarity_utils import cosine_matrix

def get_mean_embeddings(embedding_list, categories):
    """Return one mean embedding per category, stacked in ``categories`` order.

    A document belongs to a category when its ``text`` equals the category.
    A category with no matching documents contributes an all-zero vector
    shaped like the first document's embedding.

    Returns:
        array with one row per entry of ``categories``.
    """
    def _category_mean(label):
        members = [item.embedding for item in embedding_list if item.text == label]
        if members:
            return np.mean(members, axis=0)
        # no drawings for this category: zero placeholder (or np.nan)
        return np.zeros_like(embedding_list[0].embedding)

    return np.stack([_category_mean(label) for label in categories])

# Mean embedding per category for each site, rows ordered like rdm_categories.
kisumu_means, beijing_means, cdm_means = (
    get_mean_embeddings(site_store.EmbeddingList, rdm_categories)
    for site_store in (kisumu_store, beijing_store, cdm_store)
)

# Cross-site category-by-category matrices: one minus cosine_matrix of the means.
# NOTE(review): the variables are named "sim" but hold 1 - cosine_matrix(...);
# confirm whether these values are similarities or distances.
sim_kisumu_beijing = 1 - cosine_matrix(kisumu_means, beijing_means)
sim_kisumu_cdm = 1 - cosine_matrix(kisumu_means, cdm_means)
sim_cdm_beijing = 1 - cosine_matrix(cdm_means, beijing_means)

# Average of the same-category (diagonal) entries of each matrix.
mean_diags = [
    np.mean(np.diag(pair_matrix))
    for pair_matrix in (sim_kisumu_beijing, sim_kisumu_cdm, sim_cdm_beijing)
]

plot_heatmaps(
    [sim_kisumu_beijing, sim_kisumu_cdm, sim_cdm_beijing],
    titles=["Kisumu vs Beijing", "Kisumu vs CDM", "CDM vs Beijing"],
    categories=rdm_categories,
    corr_values=mean_diags,
    center=0,
    cmap="coolwarm",
    figsize=(20,6),
)