Ideally, you only need to run this notebook once, whenever you have a fresh pull of the data.

In [None]:
from vislearnlabpy.embeddings.generate_embeddings import EmbeddingGenerator
from vislearnlabpy.embeddings.embedding_store import EmbeddingStore
from vislearnlabpy.embeddings.stimuli_loader import ImageExtractor, ImgExtractionSettings
from vislearnlabpy.embeddings.utils import display_search_results, zscore_embeddings, filter_embeddings 
import numpy as np
import pandas as pd
from pathlib import Path

## Extraction settings and constants

In [None]:
# Baseline settings for producing the 104px drawings (used below for the
# New Delhi and Beijing/San Jose folders).
general_extraction_settings = ImgExtractionSettings(
    resize_dim=104,
    double_resize=False,
    use_thumbnail=False,
    apply_content_crop=True,
    apply_center_crop=False,
    filter_edge_artifacts=False,
    bg_component_size=1,
    normalize_stroke_thickness=True,
    stroke_target_thickness=1,
)

# Kisumu images have a higher resolution, so they get thumbnailing, a double
# resize and a thicker target stroke.
kisumu_extraction_settings = ImgExtractionSettings(
    resize_dim=104,
    double_resize=True,
    use_thumbnail=True,
    apply_content_crop=True,
    apply_center_crop=False,
    filter_edge_artifacts=False,
    bg_component_size=1,
    normalize_stroke_thickness=True,
    stroke_target_thickness=1.7,
)

# Settings applied when feeding (already resized) drawings to CLIP at 224px.
clip_extraction_settings = ImgExtractionSettings(
    resize_dim=224,
    use_thumbnail=False,
    apply_content_crop=True,
    apply_center_crop=False,
    filter_edge_artifacts=False,
    bg_component_size=0,
    normalize_stroke_thickness=False,
)

# Transform pipeline handed to the CLIP embedding generator.
clip_transforms = ImageExtractor.get_transformations(clip_extraction_settings)

In [None]:
# CLIP embedding generator. "mps" presumably targets the Apple-silicon GPU
# backend; change to device="cpu" if that backend is not available to you.
clip_generator = EmbeddingGenerator(
    model_type="clip",
    device="mps",
    output_type="doc",
    transform=clip_transforms,
)

Testing image extraction settings

In [None]:
# Write one transformed example per site to ../data/examples so the extraction
# settings can be inspected by eye. Kisumu uses its own settings; the other
# sites share the general ones.
# NOTE(review): this New Delhi example lives under data/newdelhi/..., while
# `newdelhi_dir` further down points at data/india/... — confirm which is current.
ImageExtractor.save_transformed("/Volumes/vislearnlab/experiments/drawing/data/newdelhi/sketches_full_dataset/a_rabbit/a_rabbit_sketch_HP355_newdelhi_run_v11757310528711.png", "../data/examples/newdelhi_example.png", general_extraction_settings)
# The San Jose and Beijing drawings are stored in the same beijing/object_drawings folder.
ImageExtractor.save_transformed("/Volumes/vislearnlab/experiments/drawing/data/beijing/object_drawings/S_chair_sketch_age9_sanjose_photodraw_e21550773711411.png", "../data/examples/sanjose_example.png", general_extraction_settings)
ImageExtractor.save_transformed("/Volumes/vislearnlab/experiments/drawing/data/beijing/object_drawings/P_airplane_sketch_age5_IPAD4_THU5F2.png", "../data/examples/beijing_example.png", general_extraction_settings)
ImageExtractor.save_transformed("/Volumes/vislearnlab/experiments/drawing/data/kisumu/transformed_drawings/S52_age3_House_trial010.png", "../data/examples/kisumu_example.png", kisumu_extraction_settings)

Getting drawings across sites. The New Delhi drawings are linked to the subject-data CSV to obtain age information; for the other sites that information is encoded in the file URLs.

In [68]:
# Root of the shared drawing data; every per-site folder hangs off it.
drawings_folder = Path("/Volumes/vislearnlab/experiments/drawing/data")

# Beijing (the San Jose drawings sit in the same folder)
beijing_dir = drawings_folder / "beijing/object_drawings"
beijing_resized_dir = drawings_folder / "beijing/resized_drawings"

# Kisumu: raw drawings -> cropped/cleaned drawings -> resized drawings
kisumu_dir = drawings_folder / "kisumu/drawings"
kisumu_cropped_drawings_path = drawings_folder / "kisumu/transformed_drawings"
kisumu_resized_dir = drawings_folder / "kisumu/resized_drawings"

# New Delhi (stored under the "india" folder), plus its descriptive tables
newdelhi_dir = drawings_folder / "india/sketches_full_dataset"
newdelhi_resized_dir = drawings_folder / "india/resized_sketches_full_dataset"
newdelhi_df = pd.read_csv(drawings_folder / "india/AllDescriptives_images_final_india_run_v1.csv")
newdelhi_subject_data = pd.read_csv(Path("/Volumes/vislearnlab/experiments/drawing/data/india/subject_data.csv"))

# Object categories of interest
categories = [
    "airplane", "bike", "bird", "car", "cat", "chair",
    "cup", "hat", "house", "rabbit", "tree", "watch",
]

## Resizing and cropping drawings

In [None]:
import os
from tqdm import tqdm
# Directory that receives the centre-cropped, cleaned-up Kisumu drawings.
kisumu_cropped_drawings_path = Path("/Volumes/vislearnlab/experiments/drawing/data/kisumu/transformed_drawings")
os.makedirs(kisumu_cropped_drawings_path, exist_ok=True)

# Cropping gets its own variable name. Reusing `kisumu_extraction_settings`
# here would overwrite the resize settings defined at the top of the notebook,
# and the later resize step would then silently run with these crop settings.
kisumu_crop_settings = ImgExtractionSettings(apply_center_crop=True, crop_dim=1400, change_stroke_color=True, bg_threshold=1, bg_component_size=100, filter_edge_artifacts=True)

# List the PNGs once so tqdm knows the total without globbing twice.
kisumu_png_paths = list(kisumu_dir.glob("*.png"))
for img_path in tqdm(kisumu_png_paths, total=len(kisumu_png_paths)):
    new_path = kisumu_cropped_drawings_path / img_path.name
    # Skip hidden files and anything that was already cropped on a previous run.
    if not img_path.name.startswith(".") and not new_path.exists():
        ImageExtractor.save_transformed(img_path, new_path, kisumu_crop_settings)

In [None]:
from tqdm import tqdm
def resize_drawings(input_dir, output_dir, extraction_settings):
    """
    Transform every PNG found under `input_dir` (recursively) and save it to
    the same relative location under `output_dir`.

    Hidden files and drawings whose output already exists are skipped. If the
    transform raises, the error is printed and the drawing is saved again with
    `None` in place of the extraction settings.
    """
    source_root = Path(input_dir)
    target_root = Path(output_dir)
    png_files = list(source_root.glob("**/*.png"))
    for img_path in tqdm(png_files, total=len(png_files)):
        # Mirror the sub-directory layout of the input folder
        destination = target_root / img_path.relative_to(source_root)
        destination.parent.mkdir(parents=True, exist_ok=True)
        if img_path.name.startswith(".") or destination.exists():
            continue
        try:
            ImageExtractor.save_transformed(img_path, destination, extraction_settings)
        except Exception as e:
            print(f"Error processing {img_path}: {e}")
            ImageExtractor.save_transformed(img_path, destination, None)
   

# Resize each site's drawings: (input folder, output folder, settings), in the
# order New Delhi, Kisumu, Beijing/San Jose.
for source_dir, resized_dir, site_settings in (
    (newdelhi_dir, newdelhi_resized_dir, general_extraction_settings),
    (kisumu_cropped_drawings_path, kisumu_resized_dir, kisumu_extraction_settings),
    (beijing_dir, beijing_resized_dir, general_extraction_settings),
):
    resize_drawings(source_dir, resized_dir, site_settings)

## Embedding extraction

Beijing and sanjose within the same folder

In [None]:
# CLIP image embeddings for the Beijing/San Jose folder: first for the resized
# drawings, then for the original ones. overwrite=True is passed, so existing
# output at these paths is presumably replaced — TODO confirm.
clip_generator.generate_image_embeddings(output_path="beijing_sanjose_drawings_resized", input_dir=beijing_resized_dir, batch_size=100, overwrite=True)
clip_generator.generate_image_embeddings(output_path="beijing_sanjose_drawings", input_dir=beijing_dir, batch_size=100, overwrite=True)

Kisumu

In [None]:
# CLIP image embeddings for Kisumu: first for the resized drawings, then for
# the cropped (transformed) ones.
clip_generator.generate_image_embeddings(output_path="kisumu_drawings_resized", input_dir=kisumu_resized_dir, batch_size=100, overwrite=True)
clip_generator.generate_image_embeddings(output_path="kisumu_drawings", input_dir=kisumu_cropped_drawings_path, batch_size=100, overwrite=True)

Now grabbing newdelhi drawings

In [None]:
# Participants count as valid when the subject table records a positive age in months.
valid_pids = newdelhi_subject_data[newdelhi_subject_data["Age (months)"] > 0]["PID"].unique().tolist()
print(f"Number of participants: {len(valid_pids)}")
# Rename to the image1/text1 column names used in the CSV passed to the embedding generator below.
newdelhi_df = newdelhi_df.rename(columns={
    'filename': 'image1',
    'category': 'text1'
})
# Replace underscores with spaces in the category label (a value such as
# "a_rabbit" becomes "a rabbit"). Note that this does NOT strip articles.
newdelhi_df['text1'] = newdelhi_df['text1'].apply(lambda x: " ".join(x.split('_')))
# Disabled: point image paths at the resized folder instead of the originals.
#newdelhi_df['image1'] = newdelhi_df['image1'].apply(lambda x: x.replace("sketches_full_dataset", "resized_sketches_full_dataset"))
#remap path if needed: new_base="/file/storage/path"
#drawings_df["image1"] = drawings_df["image1"].apply(lambda x: remap_path(x, new_base))

# Filtering to just our actual participants (participant ids are upper-cased
# before matching; assumes the PID column is already upper case — TODO confirm)
filtered_df = newdelhi_df[newdelhi_df['participant_id'].str.upper().isin(valid_pids)]
# Temporary CSV consumed by the embedding generator; the DataFrame index is written too.
filtered_df.to_csv("tmp_draw_df.csv")
# NOTE(review): only "newdelhi_drawings" is generated here, but a later cell
# loads "newdelhi_drawings_resized/...". With the line below (and the path
# replacement above) disabled, a fresh run does not produce that folder — confirm
# which variant is intended.
#clip_generator.generate_image_embeddings(output_path="newdelhi_drawings_resized", input_csv="tmp_draw_df.csv", batch_size=100, overwrite=True)
clip_generator.generate_image_embeddings(output_path="newdelhi_drawings", input_csv="tmp_draw_df.csv", batch_size=100, overwrite=True)

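Now loading the image embeddings generated above, together with the text embeddings for each category (prompt: 'drawing of a ...'). Note: the step that generates the text embeddings is missing from this notebook; they are loaded from `../data/embeddings/text_embeddings`.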

In [None]:
# Load the CLIP embeddings of the *resized* drawings for every site.
# NOTE(review): the active code earlier in this notebook writes
# "newdelhi_drawings", not "newdelhi_drawings_resized" — verify this folder
# exists before running.
newdelhi_store = EmbeddingStore.from_doc("newdelhi_drawings_resized/image_embeddings/clip_image_embeddings_doc")
kisumu_store =  EmbeddingStore.from_doc("kisumu_drawings_resized/image_embeddings/clip_image_embeddings_doc")
# Beijing and San Jose share one store at this point; they are separated in a later cell.
beijing_sanjose_store =  EmbeddingStore.from_doc("beijing_sanjose_drawings_resized/image_embeddings/clip_image_embeddings_doc")
# Precomputed text embeddings for the category prompts (not generated in this notebook).
text_embeddings = EmbeddingStore.from_doc("../data/embeddings/text_embeddings")

In [None]:
# Keep only drawings made from a verbal cue (not a picture cue) and split the
# shared Beijing/San Jose store into one store per site.
sanjose_store = EmbeddingStore()
beijing_store = EmbeddingStore()
for embedding in beijing_sanjose_store.EmbeddingList:
    # Look at the file name only. Matching "S_" / "sanjose" anywhere in the URL
    # would also hit directory names or other parts of the file name.
    filename = Path(str(embedding.url)).name
    # Verbal-cue drawings have file names that start with "S_"
    if not filename.startswith("S_"):
        continue
    target_store = sanjose_store if "sanjose" in filename else beijing_store
    target_store.add_embedding(embedding=embedding.embedding, url=embedding.url)

In [None]:
# Sanity check: number of embeddings held by each per-site store.
print(f"Kisumu embeddings: {len(kisumu_store.EmbeddingList)}")
print(f"Beijing embeddings: {len(beijing_store.EmbeddingList)}")
print(f"sanjose embeddings: {len(sanjose_store.EmbeddingList)}")
print(f"newdelhi embeddings: {len(newdelhi_store.EmbeddingList)}")

Kisumu embeddings: 1440
Beijing embeddings: 736
sanjose embeddings: 716
newdelhi embeddings: 3983


Extract metadata about location, category and age from the file names and calculate probability of the correct category being chosen

In [76]:
import os
import re
from vislearnlabpy.embeddings.similarity_utils import calculate_probability

def extract_beijing_metadata(url):
    """
    Parse location, age, category and participant id out of a Beijing / San Jose
    drawing file name.

    - location: "Beijing" when the file name contains "THU", otherwise "sanjose"
    - age: the digits following "age" (None when absent)
    - category: the second underscore-separated token (None when absent)
    - participant_id: the last underscore-separated token, without ".png"
    """
    filename = os.path.basename(url)
    tokens = filename.split('_')

    age = None
    age_found = re.search(r'age(\d+)', filename)
    if age_found:
        age = int(age_found.group(1))

    category = None
    if len(tokens) > 1:
        category = tokens[1]

    if "THU" in filename:
        location = "Beijing"
    else:
        location = "sanjose"

    # last token of the file name once the extension is stripped
    participant_id = filename.replace('.png', '').rsplit("_", 1)[-1]

    return {
        'location': location,
        'age': age,
        'category': category,
        'participant_id': participant_id
    }

def extract_newdelhi_metadata(url):
    """
    Extract metadata from a New Delhi drawing URL.

    The file name is expected to look like
    `<article>_<category>_sketch_<PID>_...`; a leading "a_", "an_", "three_" or
    "two_" is stripped before the name is split on underscores. The age is
    looked up in the "Age" column of `newdelhi_subject_data` via the upper-cased
    participant id.

    Returns a dict with 'location', 'age', 'category' and 'participant_id'.
    'age' is None when the participant id cannot be parsed or has no row in the
    subject table, in line with the other extractors, which also report missing
    information as None instead of raising.
    """
    filename = os.path.basename(url)
    # remove leading articles / quantifiers
    filename_clean = filename.removeprefix("a_").removeprefix("an_").removeprefix("three_").removeprefix("two_")
    parts = filename_clean.split('_')
    category = parts[0]
    # The participant id is the third token; tolerate unexpectedly short names.
    participant_id = parts[2].upper() if len(parts) > 2 else None

    age = None
    if participant_id is not None:
        matching_ages = newdelhi_subject_data.loc[
            newdelhi_subject_data["PID"] == participant_id, "Age"
        ].values
        # An unknown participant yields an empty array; leave the age as None.
        if len(matching_ages) > 0:
            age = matching_ages[0]

    return {
        'location': 'newdelhi',
        'age': age,
        'category': category,
        'participant_id': participant_id
    }

def extract_kisumu_metadata(url):
    """
    Parse participant id, age and category out of a Kisumu drawing file name
    such as `S52_age3_House_trial010.png`.

    The number after "age" is offset by 3 to obtain the reported age, and the
    category "Bicycle" is renamed to "bike".
    """
    filename = os.path.basename(url)

    # Everything before the first underscore identifies the participant
    participant_id = filename.partition('_')[0]

    age = None
    age_found = re.search(r'age(\d+)', filename)
    if age_found:
        age = int(age_found.group(1)) + 3

    # The category sits between the age token and the trial token
    category = None
    category_found = re.search(r'age\d+_([^_]+)_trial', filename)
    if category_found:
        category = category_found.group(1)
    if category == "Bicycle":
        category = "bike"

    return {
        'location': 'Kisumu',
        'age': age,
        'category': category,
        'participant_id': participant_id
    }

def _process_site(
    text_embeddings,
    docs,
    extract_metadata_fn,
    recognizability_fn,
    results,
):
    """Helper to process embeddings for a single site and extend `results`."""
    for doc in docs:
        md = extract_metadata_fn(doc.url)
        if md["category"] and md["age"] is not None:
            results.append(
                {
                    "location": md["location"],
                    "recognizability": recognizability_fn(
                        doc.embedding,
                        text_embeddings,     
                        md["category"].lower()
                    ),
                    "age": md["age"],
                    "participant_id": md["participant_id"],
                    "drawing_category": md["category"].lower(),
                    "url": doc.url,
                }
            )

def process_embeddings(text_embeddings, beijing_list, sanjose_list, kisumu_list, newdelhi_list):
    """
    Build the list of per-drawing result rows for all four sites.

    Sites are processed in the order Beijing, Kisumu, San Jose, New Delhi, each
    with its own metadata extractor (San Jose shares the Beijing one) and with
    `calculate_probability` as the recognizability function.
    """
    results = []
    site_jobs = (
        (beijing_list, extract_beijing_metadata),
        (kisumu_list, extract_kisumu_metadata),
        (sanjose_list, extract_beijing_metadata),
        (newdelhi_list, extract_newdelhi_metadata),
    )
    for site_docs, metadata_fn in site_jobs:
        _process_site(text_embeddings, site_docs, metadata_fn, calculate_probability, results)
    return results


# Main processing function
def create_drawing_analysis_csv(text_embeddings, beijing_list, sanjose_list, kisumu_list, newdelhi_list, output_file='drawing_analysis.csv'):
    """
    Process the embeddings of all sites, write the resulting table to
    `output_file` (without the index) and print a short summary.

    Returns the table as a DataFrame.
    """
    results = process_embeddings(text_embeddings, beijing_list, sanjose_list, kisumu_list, newdelhi_list)
    df = pd.DataFrame(results)
    df.to_csv(output_file, index=False)

    # Summary of what was written
    summary_lines = [f"Results saved to {output_file}", f"Total drawings: {len(results)}"]
    for line in summary_lines:
        print(line)
    location_counts = df['location'].value_counts().to_dict()
    print(f"Locations: {location_counts}")
    youngest, oldest = df['age'].min(), df['age'].max()
    print(f"Age range: {youngest} - {oldest}")
    category_counts = df['drawing_category'].value_counts().to_dict()
    print(f"Categories: {category_counts}")

    return df


In [77]:
# Build the recognizability table from the text embeddings and the four per-site stores.
recognizability_df = create_drawing_analysis_csv(
    text_embeddings=text_embeddings.EmbeddingList,
    beijing_list=beijing_store.EmbeddingList,
    sanjose_list=sanjose_store.EmbeddingList,
    kisumu_list=kisumu_store.EmbeddingList,
    newdelhi_list=newdelhi_store.EmbeddingList,
)

Results saved to drawing_analysis.csv
Total drawings: 6875
Locations: {'newdelhi': 3983, 'Kisumu': 1440, 'Beijing': 736, 'sanjose': 716}
Age range: 4 - 11
Categories: {'house': 397, 'tree': 390, 'watch': 377, 'cat': 372, 'bike': 369, 'cup': 369, 'car': 368, 'hat': 367, 'bird': 360, 'chair': 358, 'rabbit': 354, 'airplane': 352, 'lines': 279, 'circle': 140, 'line': 140, 'shapes': 139, 'square': 139, 'tutorial': 139, 'shape': 136, 'spoon': 114, 'face': 114, 'toothbrush': 113, 'man': 113, 'phone': 112, 'bus': 112, 'dog': 111, 'fish': 109, 'key': 109, 'woman': 108, 'eyeglasses': 108, 'train': 107}


Adding in categories to the embedding stores now

In [71]:
# Look-up table: drawing URL -> all fields of its row in the recognizability table
url_to_fields = recognizability_df.set_index('url').to_dict(orient='index')

# Store each drawing's category on its embedding (None when the URL has no row)
site_stores = (beijing_store, kisumu_store, sanjose_store, newdelhi_store)
for store in site_stores:
    for embedding in store.EmbeddingList:
        embedding.text = url_to_fields.get(embedding.url, {}).get('drawing_category')

In [72]:
from docarray import DocList
from vislearnlabpy.embeddings.embedding_store import CLIPImageEmbedding
# Combined store holding the embeddings of all four sites, concatenated in the
# order San Jose, Beijing, Kisumu, New Delhi.
full_embedding_store = EmbeddingStore()
full_embedding_store.EmbeddingList = DocList[CLIPImageEmbedding](
    sanjose_store.EmbeddingList + beijing_store.EmbeddingList + kisumu_store.EmbeddingList + newdelhi_store.EmbeddingList
)

In [None]:
# need to figure out url storing..
# recognizability_df['id'] = recognizability_df['url'].apply(lambda p: Path(p).name)
#for store in [beijing_store, kisumu_store, sanjose_store, newdelhi_store,full_embedding_store]:
#    store.EmbeddingList.url = list(map(lambda p: Path(p).name, store.EmbeddingList.url))
#os.makedirs(output_embedding_dir, exist_ok=True)

## Saving embeddings

In [75]:
# Save every store (per site, combined, and text) as a .doc file under output_embedding_dir.
# NOTE(review): the directory is not created here — presumably it already exists
# because the text embeddings were loaded from it; confirm on a fresh checkout.
output_embedding_dir = "../data/embeddings"
beijing_store.to_doc(f"{output_embedding_dir}/beijing_store.doc")
kisumu_store.to_doc(f"{output_embedding_dir}/kisumu_store.doc")
sanjose_store.to_doc(f"{output_embedding_dir}/sanjose_store.doc")
newdelhi_store.to_doc(f"{output_embedding_dir}/newdelhi_store.doc")
full_embedding_store.to_doc(f"{output_embedding_dir}/full_embedding_store.doc")
# Last expression of the cell: its return value is what the notebook displays.
text_embeddings.to_doc(f"{output_embedding_dir}/text_embeddings.doc")

'file:///Users/visuallearninglab/Documents/kenya_draw/analysis/../data/embeddings/text_embeddings.doc'

TODO: move everything below this point into a separate notebook.

Calculating centroid distances

In [78]:
import pandas as pd
import numpy as np
from collections import defaultdict
from vislearnlabpy.embeddings.similarity_utils import cosine_sim
from vislearnlabpy.embeddings.utils import zscore_embeddings
from docarray.utils.filter import filter_docs

def calculate_centroid_distances(df, kisumu_store, beijing_store, sanjose_store, newdelhi_store):
    """
    Calculate centroid embeddings for each age-location-category combination
    and compute every drawing's distance to those centroids.

    Parameters:
    df: DataFrame with 'age', 'location', 'drawing_category' and 'url' columns
    kisumu_store, beijing_store, sanjose_store, newdelhi_store: stores whose
        `EmbeddingList` holds documents with `url` and `embedding` attributes

    Returns a tuple (df_result, centroids, centroid_urls):
    df_result: copy of `df` with added 'distance' (cosine) and
        'distance_euclidean' columns (distance to the drawing's own centroid),
        plus 'distance_<location>' / 'distance_euclidean_<location>' columns
        with the distance to the same age/category centroid of every location.
        Rows without a usable embedding or centroid get NaN in all of them.
    centroids: {(age, location, category): centroid array, or None when the
        group has no embeddings}
    centroid_urls: {(age, location, category): {'url', 'embedding'} of the
        drawing closest (Euclidean) to the centroid, or None when the group has
        no embeddings}
    """
    # Location substring -> store, checked in this order.
    stores_by_site = (
        ('kisumu', kisumu_store),
        ('beijing', beijing_store),
        ('sanjose', sanjose_store),
        ('newdelhi', newdelhi_store),
    )
    url_indexes = {}

    def get_url_index(location):
        """Return a (cached) url -> embedding dict for the store matching `location`."""
        location_lower = location.lower()
        for site_key, store in stores_by_site:
            if site_key in location_lower:
                if site_key not in url_indexes:
                    # Built once per store so each lookup is O(1) instead of a
                    # scan over the whole embedding list.
                    index = {}
                    for doc in store.EmbeddingList:
                        # keep the first document when a URL occurs more than once
                        index.setdefault(str(doc.url), doc.embedding)
                    url_indexes[site_key] = index
                return url_indexes[site_key]
        return {}

    def get_embedding(url, location):
        """Return the embedding stored for `url` as a numpy array, or None."""
        embedding = get_url_index(location).get(str(url))
        return None if embedding is None else np.asarray(embedding)

    # Group by age, location and category to calculate centroids
    centroids = {}
    centroid_urls = {}
    for (age, location, category), group in df.groupby(['age', 'location', 'drawing_category']):
        key = (age, location, category)
        embeddings = []
        urls = []
        for url in group['url']:
            embedding = get_embedding(url, location)
            if embedding is not None:
                embeddings.append(embedding)
                urls.append(url)

        if not embeddings:
            print(f"Warning: No embeddings found for age={age}, category={category}, location={location}")
            centroids[key] = None
            # Keep both dicts keyed identically so consumers can index either one.
            centroid_urls[key] = None
            continue

        centroid = np.mean(embeddings, axis=0)
        centroids[key] = centroid

        # Find the drawing whose embedding is closest to the centroid
        min_distance = float('inf')
        closest_url = None
        closest_embedding = None
        for url, embedding in zip(urls, embeddings):
            distance = np.linalg.norm(embedding - centroid)
            if distance < min_distance:
                min_distance = distance
                closest_url = url
                closest_embedding = embedding
        centroid_urls[key] = {
            'url': closest_url,
            'embedding': closest_embedding
        }

    all_locations = df['location'].unique()
    print(all_locations)

    # One entry per dataframe row in every list below
    distances = []
    distances_euclidean = []
    location_distances = {location: [] for location in all_locations}
    location_distances_euclidean = {location: [] for location in all_locations}
    # (z-scoring the centroids per location was tried here and dropped, as it
    # does not make sense at this point)

    def append_missing_row():
        """Record NaN in every output list so all columns stay aligned with df."""
        distances.append(np.nan)
        distances_euclidean.append(np.nan)
        for location in all_locations:
            location_distances[location].append(np.nan)
            location_distances_euclidean[location].append(np.nan)

    for _, row in df.iterrows():
        age = row['age']
        location = row['location']
        category = row['drawing_category']
        url = row['url']

        # Centroid of this drawing's own age-location-category group
        centroid = centroids.get((age, location, category))
        embedding = None if centroid is None else get_embedding(url, location)
        if embedding is None:
            append_missing_row()
            continue

        # Cosine and Euclidean distance to the drawing's own centroid
        distances.append(1 - cosine_sim(embedding, centroid))
        distances_euclidean.append(np.linalg.norm(embedding - centroid))

        # Distance to the same age/category centroid of every location
        for target_location in all_locations:
            target_centroid = centroids.get((age, target_location, category))
            if target_centroid is None:
                location_distances[target_location].append(np.nan)
                location_distances_euclidean[target_location].append(np.nan)
            else:
                location_distances[target_location].append(1 - cosine_sim(embedding, target_centroid))
                location_distances_euclidean[target_location].append(np.linalg.norm(embedding - target_centroid))

    df_result = df.copy()
    df_result['distance'] = distances
    df_result['distance_euclidean'] = distances_euclidean
    for location in all_locations:
        df_result[f'distance_{location}'] = location_distances[location]
        df_result[f'distance_euclidean_{location}'] = location_distances_euclidean[location]

    return df_result, centroids, centroid_urls


# Distances of every drawing to the age/location/category centroids
(
    df_with_recognizability_distances,
    output_centroids,
    output_centroid_urls,
) = calculate_centroid_distances(
    recognizability_df,
    kisumu_store,
    beijing_store,
    sanjose_store,
    newdelhi_store,
)

['Beijing' 'Kisumu' 'sanjose' 'newdelhi']


## Saving df

In [79]:
# Save the final table without the url column. index=False is not passed, so
# the DataFrame index is written as the first column of the CSV.
df_with_recognizability_distances.drop(columns=['url']).to_csv("../data/clip_recognizability_final.csv")

In [80]:
# Object categories kept for the centroid plots below.
selected_categories = ["airplane", "bike", "bird", "hat", "rabbit", "watch",
                       "cat", "house", "cup", "chair", "tree", "car"]

In [82]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.linear_model import LinearRegression
import pandas as pd
from sklearn.linear_model import RANSACRegressor, LinearRegression
from PIL import Image
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from pathlib import Path
import matplotlib.colors as mcolors

# Toggle between direct connection and RANSAC regression
USE_RANSAC = False  # Set to True for RANSAC regression, False for direct connection

# Folder holding each location's drawings (used when loading thumbnail images)
paths = {
    'Beijing': beijing_dir,
    'sanjose': beijing_dir,
    'newdelhi': newdelhi_dir,
    'Kisumu': kisumu_cropped_drawings_path
}

centroids_dict = output_centroids
centroid_urls_dict = output_centroid_urls

# Keys are (age, location, category) tuples
centroids_keys = list(centroids_dict.keys())
centroids_values = list(centroids_dict.values())

# Create initial DataFrame with all data. Groups without embeddings may have no
# entry (or a None entry) in centroid_urls_dict, so look the URL up defensively
# instead of indexing directly.
df_all = pd.DataFrame({
    'age': [k[0] for k in centroids_keys],
    'location': [k[1] for k in centroids_keys],
    'category': [k[2] for k in centroids_keys],
    'embedding': centroids_values,
    'url': [(centroid_urls_dict.get(k) or {}).get('url') for k in centroids_keys]
})
# A None centroid cannot be stacked into the t-SNE input matrix; drop those rows.
df_all = df_all[df_all['embedding'].map(lambda emb: emb is not None)]

# Filter for specific categories and ages
selected_categories = ["airplane", "bike", "bird", "hat", "rabbit", "watch",
                       "cat", "house", "cup", "chair", "tree", "car"]
df_all = df_all[df_all['category'].isin(selected_categories)]
df_all = df_all[(df_all['age'] >= 4) & (df_all['age'] <= 9)]

# Get unique locations
locations = sorted(df_all['location'].unique())

# Define base colors for each location
base_colors = {
    'Kisumu': '#1f77b4', 'Beijing': '#ff7f0e', 'sanjose': '#2ca02c', 'newdelhi': '#d62728'
}

# Locations without a predefined colour fall back to matplotlib's cycle colour "C<i>"
location_base_colors = {loc: base_colors.get(loc, f'C{i}') for i, loc in enumerate(locations)}

# Function to adjust color brightness based on age
def get_age_color(base_color, age, min_age=4, max_age=9):
    """
    Scale the RGB channels of `base_color` by an age-dependent factor.

    The factor falls linearly from 1.8 at `min_age` to 0.6 at `max_age`, so
    younger ages come out brighter and older ages darker. Each channel is
    capped at 1.0. Returns an (r, g, b) tuple.
    """
    # Position of the age within [min_age, max_age], 0 = youngest
    age_fraction = (age - min_age) / (max_age - min_age)
    scale = 1.8 - (age_fraction * 1.2)

    channels = []
    for channel in mcolors.to_rgb(base_color):
        channels.append(min(1.0, channel * scale))
    return tuple(channels)

# Function to load and prepare image with consistent size
def load_image(url, location, target_size=80):
    """
    Load the drawing at `url` and return it as an OffsetImage thumbnail on a
    white background, or None when it cannot be loaded.

    The path is built from the location's folder in `paths`; New Delhi drawings
    additionally sit in a sub-folder named after the part of `url` before
    "_sketch". The image is shrunk to fit within `target_size` x `target_size`.
    Any error is printed and swallowed so one bad image does not abort a plot.
    """
    try:
        if location == "newdelhi":
            location_path = Path(paths[location]) / str(url).split("_sketch")[0]
        else:
            location_path = Path(paths[location])
        full_path = location_path / str(url)

        # Close the underlying file as soon as the pixel data has been converted,
        # so plotting many images does not leave file handles open.
        with Image.open(full_path) as source_img:
            img = source_img.convert("RGBA")
        img.thumbnail((target_size, target_size), Image.Resampling.LANCZOS)

        # Create white background and paste RGBA image onto it
        background = Image.new("RGB", img.size, (255, 255, 255))
        background.paste(img, mask=img.split()[3])  # 3 = alpha channel

        return OffsetImage(background, zoom=1.0)

    except Exception as e:
        print(f"Error loading image {url}: {e}")
        return None

# Create separate plot for each category with its own t-SNE
# Make sure the output folder exists; savefig does not create it.
tsne_output_dir = Path('tsne_plots')
tsne_output_dir.mkdir(parents=True, exist_ok=True)

for category in selected_categories:
    print(f"Processing {category}...")

    # Filter data for this category only
    category_df = df_all[df_all['category'] == category].copy()

    if len(category_df) == 0:
        print(f"No data for {category}, skipping...")
        continue
    if len(category_df) < 2:
        # With a single centroid the perplexity below would be 0, which t-SNE rejects.
        print(f"Only one centroid for {category}, cannot run t-SNE, skipping...")
        continue

    # Stack embeddings for this category only
    embeddings_matrix = np.vstack(category_df['embedding'].values)

    # Run t-SNE on this category's embeddings only
    # (perplexity has to stay below the number of samples)
    print(f"  Running t-SNE for {category}...")
    tsne = TSNE(n_components=2, random_state=42, perplexity=min(3, len(embeddings_matrix)-1))
    embeddings_2d = tsne.fit_transform(embeddings_matrix)

    # Add t-SNE coordinates to dataframe
    category_df['x'] = embeddings_2d[:, 0]
    category_df['y'] = embeddings_2d[:, 1]

    # Create plot
    fig, ax = plt.subplots(figsize=(20, 16))

    # First pass: Plot points and images (lower z-order)
    for location in locations:
        loc_data = category_df[category_df['location'] == location].sort_values('age')

        if len(loc_data) >= 1:
            # One marker per age, coloured by location
            for _, row in loc_data.iterrows():
                age_color = location_base_colors[location] #get_age_color(location_base_colors[location], row['age'])

                ax.scatter(row['x'], row['y'],
                          color=age_color,
                          s=400,
                          alpha=0.8,
                          edgecolors='white',
                          linewidths=2,
                          zorder=3)

                # Age label drawn inside the marker
                ax.annotate(f"{int(row['age'])}",
                           (row['x'], row['y']),
                           fontsize=14,
                           ha='center',
                           va='center',
                           color='white',
                           weight='bold',
                           zorder=4)

                # Thumbnail of the drawing closest to the centroid, placed below the marker
                if row['url'] is not None:
                    img = load_image(row['url'], row['location'], target_size=60)
                    if img is not None:
                        ab = AnnotationBbox(img, (row['x'], row['y']),
                                          xybox=(0, -55),
                                          xycoords='data',
                                          boxcoords='offset points',
                                          frameon=True,
                                          pad=0.3,
                                          bboxprops=dict(edgecolor=age_color,
                                                        facecolor='white',
                                                        linewidth=2))
                        ax.add_artist(ab)

    # Second pass: Draw arrows on top (higher z-order)
    for location in locations:
        loc_data = category_df[category_df['location'] == location].sort_values('age')

        if len(loc_data) > 1:
            if USE_RANSAC:
                # RANSAC regression approach: fit x and y separately against age
                X = loc_data['age'].values.reshape(-1, 1)
                y_x = loc_data['x'].values
                y_y = loc_data['y'].values

                model_x = RANSACRegressor(LinearRegression()).fit(X, y_x)
                model_y = RANSACRegressor(LinearRegression()).fit(X, y_y)

                age_start, age_end = X.min(), X.max()
                x_start, x_end = model_x.predict([[age_start]])[0], model_x.predict([[age_end]])[0]
                y_start, y_end = model_y.predict([[age_start]])[0], model_y.predict([[age_end]])[0]
            else:
                # Direct connection from lowest to highest age
                youngest = loc_data.iloc[0]  # Already sorted by age
                oldest = loc_data.iloc[-1]

                x_start, y_start = youngest['x'], youngest['y']
                x_end, y_end = oldest['x'], oldest['y']

            # Arrow from the youngest towards the oldest centroid
            ax.annotate(
                '', xy=(x_end, y_end), xytext=(x_start+1, y_start+1),
                arrowprops=dict(arrowstyle='->',
                                color=location_base_colors[location],
                                alpha=0.7,
                                linewidth=5,
                                mutation_scale=35),
                zorder=10  # High z-order to place arrows on top
            )

    ax.set_title(f'{category}', fontsize=22, fontweight='bold', pad=20)
    ax.set_xlabel('t-SNE 1', fontsize=14)
    ax.set_ylabel('t-SNE 2', fontsize=14)

    # Create custom legend for locations
    location_handles = []
    for location in locations:
        location_handles.append(plt.Line2D([0], [0], marker='o', color='w',
                                          markerfacecolor=location_base_colors[location],
                                          markersize=10, label=location,
                                          markeredgecolor='white', markeredgewidth=1.5))

    # Add legend to plot
    ax.legend(handles=location_handles, loc='upper left',
              fontsize=12, title='Location', title_fontsize=13)

    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(tsne_output_dir / f'tsne_centroid_{category}_age.png', dpi=300, bbox_inches='tight')
    # Close this specific figure so figures do not pile up in memory across categories
    plt.close(fig)

    print(f"  Saved visualization for {category}")

print("Done!")

Processing airplane...
  Running t-SNE for airplane...
  Saved visualization for airplane
Processing bike...
  Running t-SNE for bike...
  Saved visualization for bike
Processing bird...
  Running t-SNE for bird...
  Saved visualization for bird
Processing hat...
  Running t-SNE for hat...
  Saved visualization for hat
Processing rabbit...
  Running t-SNE for rabbit...
  Saved visualization for rabbit
Processing watch...
  Running t-SNE for watch...
  Saved visualization for watch
Processing cat...
  Running t-SNE for cat...
  Saved visualization for cat
Processing house...
  Running t-SNE for house...
  Saved visualization for house
Processing cup...
  Running t-SNE for cup...
  Saved visualization for cup
Processing chair...
  Running t-SNE for chair...
  Saved visualization for chair
Processing tree...
  Running t-SNE for tree...
  Saved visualization for tree
Processing car...
  Running t-SNE for car...
  Saved visualization for car
Done!


In [84]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from collections import defaultdict
from vislearnlabpy.embeddings.similarity_utils import cosine_sim
from vislearnlabpy.embeddings.utils import zscore_embeddings
from docarray.utils.filter import filter_docs
import seaborn as sns

def collect_all_embeddings(df, kisumu_store, beijing_store, sanjose_store, newdelhi_store):
    """
    Collect the stored embedding for every drawing listed in ``df``.

    Each row is routed to one of the four stores by a case-insensitive
    substring match on its ``location`` and then looked up by exact ``url``
    match; the embedding of the first matching document is used. Rows whose
    location matches no store, or whose URL has no matching document, are
    skipped.

    Args:
        df: DataFrame with columns ['age', 'location', 'drawing_category', 'url'].
        kisumu_store, beijing_store, sanjose_store, newdelhi_store: Objects
            exposing an ``EmbeddingList`` attribute that ``filter_docs`` can query.

    Returns:
        DataFrame with columns ['embedding', 'age', 'location', 'category', 'url'],
        one row per drawing whose embedding was found. The columns are present
        even when nothing matched, so downstream column filters keep working.
    """
    output_columns = ['embedding', 'age', 'location', 'category', 'url']

    # Checked in this order; the first keyword contained in the location wins.
    stores_by_keyword = (
        ('kisumu', kisumu_store),
        ('beijing', beijing_store),
        ('sanjose', sanjose_store),
        ('newdelhi', newdelhi_store),
    )

    def get_embedding_list(location):
        # str() keeps missing (NaN) locations from crashing on .lower().
        location_lower = str(location).lower()
        for keyword, store in stores_by_keyword:
            if keyword in location_lower:
                return store.EmbeddingList
        # Unknown site: signal "no store" instead of handing a dict to filter_docs.
        return None

    def get_embedding(url, embedding_list):
        matches = filter_docs(embedding_list, {'url': {'$eq': url}})
        return matches[0].embedding if matches else None

    embeddings_data = []
    for _, row in df.iterrows():
        embedding_list = get_embedding_list(row['location'])
        if embedding_list is None:
            continue

        embedding = get_embedding(row['url'], embedding_list)
        if embedding is None:
            continue

        embeddings_data.append({
            'embedding': np.array(embedding),
            'age': row['age'],
            'location': row['location'],
            'category': row['drawing_category'],
            'url': row['url']
        })

    return pd.DataFrame(embeddings_data, columns=output_columns)


def create_tsne_per_category(df, stores_dict, centroid_urls, emb_df, output_dir='tsne_plots',
                             perplexity=30, random_state=42, categories=None, age_range=(4, 9)):
    """
    Type 1: Create one t-SNE plot per category from all of its embeddings.
    Points are coloured by location and their opacity grows with age (no age
    labels). Thumbnails of the drawings listed in ``centroid_urls`` are placed
    at their t-SNE coordinates.

    Args:
        df: Unused; kept so existing calls keep working.
        stores_dict: Unused; kept so existing calls keep working.
        centroid_urls: Dict keyed by (age, location, category); each value must
            provide the drawing's URL under the key 'url'.
        emb_df: DataFrame with columns ['embedding', 'age', 'location', 'category', 'url'].
        output_dir: Directory to save plots (created if missing).
        perplexity: Upper bound for the t-SNE perplexity; lowered to
            ``n_samples - 1`` for small categories.
        random_state: Seed passed to TSNE.
        categories: Categories to plot. Defaults to the categories present in
            ``emb_df`` (in order of first appearance) so the function no longer
            depends on a notebook-level global.
        age_range: (youngest, oldest) ages mapped to the faintest / most opaque
            points.
    """
    import os

    os.makedirs(output_dir, exist_ok=True)

    if categories is None:
        categories = list(pd.unique(emb_df['category']))

    # Color palette for locations (keys spelled as they appear in the data)
    location_colors = {'Kisumu': '#1f77b4', 'Beijing': '#ff7f0e', 'sanjose': '#2ca02c', 'newdelhi': '#d62728'}

    for category in categories:
        print(f"Processing {category}...")
        # Fresh 0..n-1 index so row positions line up with the t-SNE output rows.
        cat_data = emb_df[emb_df['category'] == category].reset_index(drop=True)

        if len(cat_data) < 2:
            print(f"Skipping {category} - not enough data")
            continue

        _plot_category_tsne(category, cat_data, centroid_urls, location_colors,
                            output_dir, perplexity, random_state, age_range)
        print(f"Saved visualization for {category}")


def _plot_category_tsne(category, cat_data, centroid_urls, location_colors, output_dir,
                        perplexity, random_state, age_range):
    """Run t-SNE on one category's rows and save the annotated scatter plot."""
    from matplotlib.colors import to_rgba
    from matplotlib.offsetbox import AnnotationBbox
    from matplotlib.patches import Patch

    age_min, age_max = age_range
    age_span = (age_max - age_min) or 1  # guard against a degenerate (x, x) range

    def age_to_alpha(age):
        # 0.1 at age_min rising linearly to 1.0 at age_max; clipped so ages
        # outside the range still yield a valid matplotlib alpha.
        return 0.1 + 0.9 * np.clip((age - age_min) / age_span, 0.0, 1.0)

    embeddings = np.stack(cat_data['embedding'].values)

    print(f"Running t-SNE for {category}...")
    tsne = TSNE(n_components=2, perplexity=min(perplexity, len(embeddings) - 1),
                random_state=random_state)
    tsne_results = tsne.fit_transform(embeddings)

    # URL -> (x, y); for a duplicated URL the last row wins.
    url_to_tsne = dict(zip(cat_data['url'], tsne_results))

    alphas = age_to_alpha(cat_data['age'].to_numpy(dtype=float))

    fig, ax = plt.subplots(figsize=(12, 10))

    for location in cat_data['location'].unique():
        loc_mask = (cat_data['location'] == location).to_numpy()
        if not loc_mask.any():
            continue
        n_points = int(loc_mask.sum())

        # One scatter call per location: the per-point opacity is written into
        # the alpha channel of both the face and the (white) edge colours.
        face_colors = np.tile(to_rgba(location_colors.get(location, '#000000')), (n_points, 1))
        edge_colors = np.tile(to_rgba('white'), (n_points, 1))
        face_colors[:, 3] = edge_colors[:, 3] = alphas[loc_mask]

        ax.scatter(tsne_results[loc_mask, 0], tsne_results[loc_mask, 1],
                   c=face_colors, s=100, edgecolors=edge_colors, linewidth=0.5)

    # Add centroid images
    for (age, location, cat), centroid_info in centroid_urls.items():
        if cat != category:
            continue
        centroid_url = centroid_info['url']
        if centroid_url not in url_to_tsne:
            continue
        x, y = url_to_tsne[centroid_url]

        try:
            # load_image is a helper defined elsewhere in the notebook.
            img = load_image(centroid_url, location, target_size=40)

            ab = AnnotationBbox(img, (x, y),
                                frameon=True,
                                pad=0.3,
                                boxcoords="data",
                                box_alignment=(0.5, 0.5),
                                bboxprops=dict(edgecolor=location_colors.get(location, '#000000'),
                                               alpha=float(age_to_alpha(age)),
                                               facecolor='white',
                                               linewidth=2))
            ax.add_artist(ab)
        except Exception as e:
            # Best effort: a missing thumbnail should not abort the whole plot.
            print(f"Could not load image for {centroid_url}: {e}")

    ax.set_title(f't-SNE: {category.capitalize()} (colored by location, brightness by age)', fontsize=16)
    ax.set_xlabel('t-SNE Component 1', fontsize=12)
    ax.set_ylabel('t-SNE Component 2', fontsize=12)

    # Legend lists only the locations that actually occur in this category.
    present_locations = set(cat_data['location'])
    legend_elements = [Patch(facecolor=color, label=loc)
                       for loc, color in location_colors.items() if loc in present_locations]
    ax.legend(handles=legend_elements, title='Location', loc='best')

    fig.tight_layout()
    fig.savefig(f'{output_dir}/tsne_category_{category}.png', dpi=300, bbox_inches='tight')
    plt.close(fig)

def create_tsne_per_age(df, stores_dict, emb_df, output_dir='tsne_plots', perplexity=30, random_state=42):
    """
    Type 2: Create one t-SNE plot per age, with points coloured by category.

    All categories of an age are embedded together (animals and non-animals
    are not separated). The category -> colour mapping is built once from the
    whole of ``emb_df`` so a category keeps the same colour in every age plot.

    Args:
        df: Unused; kept so existing calls keep working.
        stores_dict: Unused; kept so existing calls keep working.
        emb_df: DataFrame with columns ['embedding', 'age', 'category'] (other
            columns are ignored).
        output_dir: Directory to save plots (created if missing).
        perplexity: Upper bound for the t-SNE perplexity; lowered to
            ``n_samples - 1`` for small groups.
        random_state: Seed passed to TSNE.
    """
    import os
    os.makedirs(output_dir, exist_ok=True)

    if emb_df.empty:
        return

    # Shared palette across all ages so the plots can be compared side by side.
    all_categories = list(emb_df['category'].unique())
    palette = sns.color_palette('husl', n_colors=len(all_categories))
    color_map = dict(zip(all_categories, palette))

    for age in sorted(emb_df['age'].unique()):
        # Fresh 0..n-1 index so row positions line up with the t-SNE output rows.
        age_data = emb_df[emb_df['age'] == age].reset_index(drop=True)

        if len(age_data) < 2:
            print(f"Skipping age {age} - not enough data")
            continue

        print(f"Processing age {age}...")

        embeddings = np.stack(age_data['embedding'].values)

        print(f"Running t-SNE for age {age}...")
        tsne = TSNE(n_components=2, perplexity=min(perplexity, len(embeddings) - 1),
                    random_state=random_state)
        tsne_results = tsne.fit_transform(embeddings)

        fig, ax = plt.subplots(figsize=(12, 10))

        for category in age_data['category'].unique():
            cat_mask = (age_data['category'] == category).to_numpy()
            ax.scatter(tsne_results[cat_mask, 0], tsne_results[cat_mask, 1],
                       c=[color_map[category]], label=category.capitalize(),
                       alpha=0.7, s=100, edgecolors='white', linewidth=0.5)

        ax.set_title(f't-SNE: Age {age}', fontsize=16)
        ax.set_xlabel('t-SNE Component 1', fontsize=12)
        ax.set_ylabel('t-SNE Component 2', fontsize=12)
        ax.legend(title='Category', loc='best', ncol=2)

        fig.tight_layout()
        fig.savefig(f'{output_dir}/tsne_age_{age}.png', dpi=300, bbox_inches='tight')
        plt.close(fig)
        print(f"Saved visualization for age {age}")


def create_tsne_all_embeddings(df, stores_dict, emb_df, output_dir='tsne_plots', perplexity=50, random_state=42):
    """
    Type 3: Create a single t-SNE plot of all embeddings.
    Points are coloured by category and their opacity grows with age
    (0.3 for the youngest age present, 1.0 for the oldest).

    The caller's ``emb_df`` is not modified.

    Args:
        df: Unused; kept so existing calls keep working.
        stores_dict: Unused; kept so existing calls keep working.
        emb_df: DataFrame with columns ['embedding', 'age', 'category'] (other
            columns are ignored).
        output_dir: Directory to save the plot (created if missing).
        perplexity: Upper bound for the t-SNE perplexity; lowered to
            ``n_samples - 1`` for small inputs.
        random_state: Seed passed to TSNE.
    """
    import os
    os.makedirs(output_dir, exist_ok=True)

    # Same minimum as the per-category / per-age plots: t-SNE needs >= 2 samples.
    if len(emb_df) < 2:
        print("Skipping all embeddings - not enough data")
        return

    from matplotlib.colors import to_rgba

    # Local re-binding with a 0..n-1 index; row positions now match t-SNE rows.
    emb_df = emb_df.reset_index(drop=True)

    embeddings = np.stack(emb_df['embedding'].values)

    print("Running t-SNE for all embeddings...")
    tsne = TSNE(n_components=2, perplexity=min(perplexity, len(embeddings) - 1),
                random_state=random_state)
    tsne_results = tsne.fit_transform(embeddings)

    # Normalize age to [0, 1] for opacity; with a single age every point is
    # drawn fully opaque instead of dividing by zero.
    ages = emb_df['age'].to_numpy(dtype=float)
    age_span = ages.max() - ages.min()
    if age_span > 0:
        age_normalized = (ages - ages.min()) / age_span
    else:
        age_normalized = np.ones(len(ages))
    alphas = 0.3 + 0.7 * age_normalized

    fig, ax = plt.subplots(figsize=(14, 12))

    categories = emb_df['category'].unique()
    colors = sns.color_palette('husl', n_colors=len(categories))
    color_map = dict(zip(categories, colors))

    for category in categories:
        cat_mask = (emb_df['category'] == category).to_numpy()
        n_points = int(cat_mask.sum())

        # One scatter call per category: the per-point opacity is written into
        # the alpha channel of both the face and the (white) edge colours.
        face_colors = np.tile(to_rgba(color_map[category]), (n_points, 1))
        edge_colors = np.tile(to_rgba('white'), (n_points, 1))
        face_colors[:, 3] = edge_colors[:, 3] = alphas[cat_mask]

        ax.scatter(tsne_results[cat_mask, 0], tsne_results[cat_mask, 1],
                   c=face_colors, s=80, edgecolors=edge_colors, linewidth=0.3,
                   label=category.capitalize())

    ax.legend(title='Category', loc='best', ncol=3, fontsize=10)

    ax.set_title('t-SNE: All Embeddings (colored by category, brightness by age)', fontsize=16)
    ax.set_xlabel('t-SNE Component 1', fontsize=12)
    ax.set_ylabel('t-SNE Component 2', fontsize=12)

    fig.tight_layout()
    fig.savefig(f'{output_dir}/tsne_all_embeddings.png', dpi=300, bbox_inches='tight')
    plt.close(fig)
    print("Saved visualization for all embeddings")

In [85]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.spatial.distance import cosine
import pandas as pd

def create_rdm_from_centroids(output_centroids, metric='cosine'):
    """
    Create RDMs (Representational Dissimilarity Matrices) from centroid embeddings.

    Creates two RDMs:
    1. Rows/columns are (age, location) pairs; each cell is the mean distance
       over all categories for which both pairs have a centroid.
    2. Rows/columns are (category, location) pairs; each cell is the mean
       distance over all ages for which both pairs have a centroid.

    Cells with no shared data are NaN (not 0.0), so "no data" cannot be
    mistaken for "identical centroids"; seaborn heatmaps leave such cells blank.

    Args:
        output_centroids: Dict keyed by (age, location, category) whose values
            are centroid vectors or None.
        metric: 'cosine' for cosine distance; any other value selects
            Euclidean distance.

    Returns:
        Tuple ``(rdm_by_age, rdm_by_category, age_location_pairs,
        category_location_pairs, categories)`` where the two pair lists give
        the row/column order of the corresponding matrix and ``categories`` is
        the ordered list of categories found in the data.
    """

    # Define category order: animals first, then non-animals
    category_order = ["bird", "rabbit", "cat", "airplane", "car", "bike", "hat",
                      "watch", "house", "cup", "chair", "tree"]

    # Only keys that actually hold a centroid define the axes.
    present_keys = [key for key, centroid in output_centroids.items() if centroid is not None]
    ages = sorted({key[0] for key in present_keys})
    locations = sorted({key[1] for key in present_keys})

    # Filter category_order to only include categories present in the data
    available_categories = {key[2] for key in present_keys}
    categories = [c for c in category_order if c in available_categories]

    print(f"Ages: {ages}")
    print(f"Locations: {locations}")
    print(f"Categories: {categories}")

    def centroid_distance(key1, key2):
        # None when either centroid is missing, otherwise the chosen distance.
        centroid1 = output_centroids.get(key1)
        centroid2 = output_centroids.get(key2)
        if centroid1 is None or centroid2 is None:
            return None
        if metric == 'cosine':
            return cosine(centroid1, centroid2)
        return np.linalg.norm(np.asarray(centroid1) - np.asarray(centroid2))

    def build_rdm(pairs, averaged_over, make_key):
        # Both metrics are symmetric, so fill the upper triangle and mirror it.
        size = len(pairs)
        rdm = np.full((size, size), np.nan)
        for i in range(size):
            for j in range(i, size):
                distances = []
                for shared in averaged_over:
                    dist = centroid_distance(make_key(pairs[i], shared), make_key(pairs[j], shared))
                    if dist is not None:
                        distances.append(dist)
                if distances:
                    rdm[i, j] = rdm[j, i] = np.mean(distances)
        return rdm

    # ===== RDM 1: Average across categories for each (age, location) pair =====
    age_location_pairs = [(age, loc) for age in ages for loc in locations]
    rdm_by_age = build_rdm(age_location_pairs, categories,
                           lambda pair, category: (pair[0], pair[1], category))

    # ===== RDM 2: Average across ages for each (category, location) pair =====
    category_location_pairs = [(cat, loc) for cat in categories for loc in locations]
    rdm_by_category = build_rdm(category_location_pairs, ages,
                                lambda pair, age: (age, pair[1], pair[0]))

    return rdm_by_age, rdm_by_category, age_location_pairs, category_location_pairs, categories


def plot_rdms(rdm_by_age, rdm_by_category, age_location_pairs, category_location_pairs, 
              categories, output_centroids, output_dir='rdm_plots', metric='cosine',
              age_range=(4, 9)):
    """
    Plot the RDMs with appropriate labels and save them as PNG files.

    Figures written to ``output_dir``:
    - ``rdm_by_age_<metric>.png``: the (age, location) RDM.
    - ``rdm_by_category_<metric>.png``: the (category, location) RDM.
    - ``rdm_category_only_<metric>.png``: category x category, averaging the
      same-location cells of ``rdm_by_category``.
    - ``rdm_age_<age>_all_categories_<metric>.png``: per age, distances between
      all (category, location) centroids.
    - ``rdm_age_<age>_categories_only_<metric>.png``: per age, category x
      category, averaged over locations.

    Cells without data (NaN) are masked and therefore left blank.

    Args:
        rdm_by_age, rdm_by_category: Square matrices from create_rdm_from_centroids.
        age_location_pairs: (age, location) row/column order of ``rdm_by_age``.
        category_location_pairs: (category, location) row/column order of
            ``rdm_by_category``; must be grouped by category, with the same
            locations in the same order inside every group.
        categories: Ordered category names (animals are expected first).
        output_centroids: Dict keyed by (age, location, category) whose values
            are centroid vectors or None.
        output_dir: Directory to save plots (created if missing).
        metric: 'cosine' for cosine distance; any other value selects Euclidean.
        age_range: Inclusive (min, max) ages that get their own per-age figures.
    """
    import os
    os.makedirs(output_dir, exist_ok=True)

    # Axes of the per-age figures come from the keys that hold a centroid.
    present_keys = [key for key, centroid in output_centroids.items() if centroid is not None]
    min_age, max_age = age_range
    ages = [age for age in sorted({key[0] for key in present_keys}) if min_age <= age <= max_age]
    locations = sorted({key[1] for key in present_keys})

    animal_categories = ["bird", "rabbit", "cat"]
    n_animals = len([c for c in categories if c in animal_categories])
    n_cats = len(categories)
    category_titles = [c.capitalize() for c in categories]

    def centroid_distance(key1, key2):
        # None when either centroid is missing, otherwise the chosen distance.
        centroid1 = output_centroids.get(key1)
        centroid2 = output_centroids.get(key2)
        if centroid1 is None or centroid2 is None:
            return None
        if metric == 'cosine':
            return cosine(centroid1, centroid2)
        return np.linalg.norm(centroid1 - centroid2)

    def draw_heatmap(matrix, labels, figsize, annot=False):
        # Shared heatmap styling; NaN cells are masked so they render blank.
        matrix = np.asarray(matrix, dtype=float)
        fig, ax = plt.subplots(figsize=figsize)
        sns.heatmap(matrix, annot=annot, fmt='.3f', cmap='viridis', square=True,
                    xticklabels=labels, yticklabels=labels,
                    cbar_kws={'label': f'{metric.capitalize()} Distance'},
                    ax=ax, mask=np.isnan(matrix))
        return fig, ax

    def mark_animal_boundary(ax, boundary, n_cells=None, label_y=None):
        # Dashed red lines between the animal and non-animal blocks; region
        # captions are added only when a label_y is given.
        ax.axhline(y=boundary, color='red', linewidth=2, linestyle='--', alpha=0.7)
        ax.axvline(x=boundary, color='red', linewidth=2, linestyle='--', alpha=0.7)
        if label_y is not None:
            ax.text(boundary / 2, label_y, 'Animals', ha='center', va='bottom',
                    fontsize=12, fontweight='bold', color='red')
            ax.text(boundary + (n_cells - boundary) / 2, label_y, 'Non-Animals',
                    ha='center', va='bottom', fontsize=12, fontweight='bold', color='red')

    def finish(fig, ax, title, axis_label, filename, label_fontsize=12):
        # Title, axis labels, tick rotation, then save and release the figure.
        ax.set_title(title, fontsize=16, pad=20)
        ax.set_xlabel(axis_label, fontsize=label_fontsize)
        ax.set_ylabel(axis_label, fontsize=label_fontsize)
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
        plt.setp(ax.get_yticklabels(), rotation=0)
        fig.tight_layout()
        fig.savefig(f'{output_dir}/{filename}', dpi=300, bbox_inches='tight')
        plt.close(fig)

    # ===== Plot RDM by Age =====
    age_loc_labels = [f"{age}\n{loc}" for age, loc in age_location_pairs]
    fig, ax = draw_heatmap(rdm_by_age, age_loc_labels, figsize=(14, 12))
    finish(fig, ax,
           f'RDM: Average Across All Categories by Age and Location\n({metric} distance)',
           'Age - Location', f'rdm_by_age_{metric}.png')
    print("Saved RDM by age")

    # ===== Plot RDM by Category =====
    cat_loc_labels = [f"{cat}\n{loc}" for cat, loc in category_location_pairs]
    n_pair_locations = len({loc for _, loc in category_location_pairs})
    fig, ax = draw_heatmap(rdm_by_category, cat_loc_labels, figsize=(16, 14))
    mark_animal_boundary(ax, n_animals * n_pair_locations,
                         n_cells=len(cat_loc_labels), label_y=-1)
    finish(fig, ax,
           f'RDM: Average Across All Ages by Category and Location\n({metric} distance)',
           'Category - Location', f'rdm_by_category_{metric}.png')
    print("Saved RDM by category")

    # ===== Plot simplified RDM: just categories (averaged across locations AND ages) =====
    # For each category pair, average the cells of rdm_by_category in which
    # both categories come from the same location.
    rdm_by_category = np.asarray(rdm_by_category, dtype=float)
    rdm_category_only = np.full((n_cats, n_cats), np.nan)
    for i in range(n_cats):
        for j in range(n_cats):
            same_location_cells = []
            for loc_idx in range(n_pair_locations):
                i_full = i * n_pair_locations + loc_idx
                j_full = j * n_pair_locations + loc_idx
                if i_full < rdm_by_category.shape[0] and j_full < rdm_by_category.shape[1]:
                    same_location_cells.append(rdm_by_category[i_full, j_full])
            # Ignore cells without data rather than letting NaN poison the mean.
            valid_cells = [value for value in same_location_cells if not np.isnan(value)]
            if valid_cells:
                rdm_category_only[i, j] = np.mean(valid_cells)

    fig, ax = draw_heatmap(rdm_category_only, category_titles, figsize=(12, 10), annot=True)
    mark_animal_boundary(ax, n_animals)
    finish(fig, ax,
           f'RDM: Categories (averaged across ages and locations)\n({metric} distance)',
           'Category', f'rdm_category_only_{metric}.png', label_fontsize=10)
    print("Saved simplified category RDM")

    # ===== Plot RDM by Category for each Age separately =====
    cat_loc_pairs_age = [(cat, loc) for cat in categories for loc in locations]
    cat_loc_labels_age = [f"{cat}-{loc}" for cat, loc in cat_loc_pairs_age]
    n_pairs = len(cat_loc_pairs_age)

    for age in ages:
        rdm_age_specific = np.full((n_pairs, n_pairs), np.nan)
        for i, (cat1, loc1) in enumerate(cat_loc_pairs_age):
            for j, (cat2, loc2) in enumerate(cat_loc_pairs_age):
                dist = centroid_distance((age, loc1, cat1), (age, loc2, cat2))
                if dist is not None:
                    rdm_age_specific[i, j] = dist

        fig, ax = draw_heatmap(rdm_age_specific, cat_loc_labels_age, figsize=(14, 12))
        mark_animal_boundary(ax, n_animals * len(locations))
        finish(fig, ax,
               f'RDM: All Categories by Location (Age {age})\n({metric} distance)',
               'Category - Location', f'rdm_age_{age}_all_categories_{metric}.png')
        print(f"Saved RDM for age {age}")

    # ===== Plot simplified RDM by Category for each Age (averaged across locations) =====
    for age in ages:
        rdm_cat_age = np.full((n_cats, n_cats), np.nan)
        for i, cat1 in enumerate(categories):
            for j, cat2 in enumerate(categories):
                distances = []
                for loc in locations:
                    dist = centroid_distance((age, loc, cat1), (age, loc, cat2))
                    if dist is not None:
                        distances.append(dist)
                if distances:
                    rdm_cat_age[i, j] = np.mean(distances)

        fig, ax = draw_heatmap(rdm_cat_age, category_titles, figsize=(12, 10), annot=True)
        mark_animal_boundary(ax, n_animals, n_cells=n_cats, label_y=-0.5)
        finish(fig, ax,
               f'RDM: Categories (Age {age}, averaged across locations)\n({metric} distance)',
               'Category', f'rdm_age_{age}_categories_only_{metric}.png')
        print(f"Saved simplified category RDM for age {age}")


# Example usage: build both RDMs from the centroid dictionary, then save every heatmap.
(
    rdm_by_age,
    rdm_by_category,
    age_loc_pairs,
    cat_loc_pairs,
    categories,
) = create_rdm_from_centroids(output_centroids, metric='cosine')

plot_rdms(
    rdm_by_age=rdm_by_age,
    rdm_by_category=rdm_by_category,
    age_location_pairs=age_loc_pairs,
    category_location_pairs=cat_loc_pairs,
    categories=categories,
    output_centroids=output_centroids,
    metric='cosine',
)

Ages: [np.int64(4), np.int64(5), np.int64(6), np.int64(7), np.int64(8), np.int64(9), np.int64(10), np.int64(11)]
Locations: ['Beijing', 'Kisumu', 'newdelhi', 'sanjose']
Categories: ['bird', 'rabbit', 'cat', 'airplane', 'car', 'bike', 'hat', 'watch', 'house', 'cup', 'chair', 'tree']
Saved RDM by age
Saved RDM by category
Saved simplified category RDM
Saved RDM for age 4
Saved RDM for age 5
Saved RDM for age 6
Saved RDM for age 7
Saved RDM for age 8
Saved RDM for age 9
Saved simplified category RDM for age 4
Saved simplified category RDM for age 5
Saved simplified category RDM for age 6
Saved simplified category RDM for age 7
Saved simplified category RDM for age 8
Saved simplified category RDM for age 9


In [86]:
# Embedding stores keyed by lower-case site name.
stores = {
    'kisumu': kisumu_store,
    'beijing': beijing_store,
    'sanjose': sanjose_store,
    'newdelhi': newdelhi_store
}
emb_df = collect_all_embeddings(recognizability_df, stores['kisumu'], stores['beijing'], 
                                 stores['sanjose'], stores['newdelhi'])
# Keep ages 4-9 (inclusive) and the selected categories in one pass.
# .copy() makes emb_df own its data, so later column assignments on it do not
# raise pandas' SettingWithCopyWarning.
emb_df = emb_df[emb_df['age'].between(4, 9) & emb_df['category'].isin(selected_categories)].copy()

create_tsne_per_category(recognizability_df, stores, centroid_urls=output_centroid_urls, emb_df=emb_df)

Processing airplane...
Running t-SNE for airplane...
Saved visualization for airplane
Processing bike...
Running t-SNE for bike...
Saved visualization for bike
Processing bird...
Running t-SNE for bird...
Saved visualization for bird
Processing hat...
Running t-SNE for hat...
Saved visualization for hat
Processing rabbit...
Running t-SNE for rabbit...
Saved visualization for rabbit
Processing watch...
Running t-SNE for watch...
Saved visualization for watch
Processing cat...
Running t-SNE for cat...
Saved visualization for cat
Processing house...
Running t-SNE for house...
Saved visualization for house
Processing cup...
Running t-SNE for cup...
Saved visualization for cup
Processing chair...
Running t-SNE for chair...
Saved visualization for chair
Processing tree...
Running t-SNE for tree...
Saved visualization for tree
Processing car...
Running t-SNE for car...
Saved visualization for car


In [87]:
# One t-SNE figure per age, drawn from the filtered embedding table.
create_tsne_per_age(
    df=recognizability_df,
    stores_dict=stores,
    emb_df=emb_df,
)

Processing age 4...
Running t-SNE for age 4...
Saved visualization for age 4
Processing age 5...
Running t-SNE for age 5...
Saved visualization for age 5
Processing age 6...
Running t-SNE for age 6...
Saved visualization for age 6
Processing age 7...
Running t-SNE for age 7...
Saved visualization for age 7
Processing age 8...
Running t-SNE for age 8...
Saved visualization for age 8
Processing age 9...
Running t-SNE for age 9...
Saved visualization for age 9


In [88]:
# Single t-SNE figure covering every embedding in the filtered table.
create_tsne_all_embeddings(
    df=recognizability_df,
    stores_dict=stores,
    emb_df=emb_df,
)

Running t-SNE for all embeddings...
Saved visualization for all embeddings


In [None]:
# Save the transformed drawing of every 'bike' row of df_all for visual inspection.
# The output directory is created once, up front, rather than on every iteration.
Path("../data/examples/bike").mkdir(parents=True, exist_ok=True)

for _, row in df_all[df_all['category'] == 'bike'].iterrows():
    # Kisumu drawings use their own extraction settings; every other site uses the general ones.
    if row['location'] == "Kisumu":
        extraction_settings = kisumu_extraction_settings
    else:
        extraction_settings = general_extraction_settings

    ImageExtractor.save_transformed(
        str(row['url']),
        f"../data/examples/bike/centroid_{row['location']}_{str(row['age'])}.png",
        extraction_settings,
    )

The file name alone is stored as the drawing ID: it is unique, easy to use, and easy to understand. We may want to include the parent directory as well, so that IDs are easy to map back to the URLs and so that it is clearer how the drawings differ, since the datasets use different file-naming conventions.