# 0. Change the hyperparameters below to build a new dataset of embedding points

In [None]:
NUM_SAMPLES = 5000              # Change this to have more or less samples
UMAP_N_NEIGHBORS = 30           # Change this to adjust UMAP's nearest neighbor hyperparameter

# Output filename (without extension), in the format
# "pairs_{sample size}_UMAPn{UMAP neighbor hyperparameter}".
# It is derived from the two settings above so the three values cannot drift
# apart; whole thousands are abbreviated with a "K" suffix (5000 -> "5K").
# Replace the right-hand side with a literal string if a custom name is needed.
_SAMPLE_TAG = (
    f"{NUM_SAMPLES // 1000}K"
    if NUM_SAMPLES >= 1000 and NUM_SAMPLES % 1000 == 0
    else str(NUM_SAMPLES)
)
FILENAME = f"pairs_{_SAMPLE_TAG}_UMAPn{UMAP_N_NEIGHBORS}"

# 1. Experiment with one sample from MS-COCO

In [None]:
''' pip install necessary packages '''

!pip install torch torchvision transformers datasets clip umap-learn
!pip uninstall clip
!pip install git+https://github.com/openai/CLIP.git

Collecting datasets
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting clip
  Downloading clip-0.2.0.tar.gz (5.5 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting umap-learn
  Downloading umap_learn-0.5.7-py3-none-any.whl.metadata (21 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec (from torch)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Collecting pynndescent>=0.5 (from umap-learn)
  Downloading pynndescent-0.5.13-py3-none-any.whl.metadata (6.8 kB)
Downloading datasets-3.2.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m639.2 kB/s

In [None]:
''' Create dataclass for CLIPPair '''

from PIL import Image
from dataclasses import dataclass
import numpy as np

@dataclass
class CLIPPair:
  """One image-caption pair together with its CLIP embeddings and UMAP projections."""

  # Original data
  caption: str
  image: Image.Image

  # CLIP embeddings
  # NOTE(review): the rest of this notebook stores torch tensors in these two
  # fields (the values returned by the CLIP encoders), not NumPy arrays.
  clip_text_embedding: np.ndarray
  clip_image_embedding: np.ndarray

  # UMAP projections (set to None until the pair has been projected)
  umap_text_embedding: np.ndarray
  umap_image_embedding: np.ndarray

  # Averaged projection (set to None until the pair has been projected)
  umap_average_embedding: np.ndarray

  # Metadata
  # Cosine similarity of text-image embeddings.
  # NOTE(review): the notebook stores the tensor produced by
  # `image_features @ text_features.T` here rather than a plain float.
  similarity_score: float

In [None]:
''' Load MS-COCO. Not necessary if you already have the .json with the projected CLIP pairs from a previous random sampling run. '''

import datasets as ds
import aiohttp

# Download MS-COCO from HuggingFace official datasets
dataset = ds.load_dataset(
    "HuggingFaceM4/COCO",
    # This repository ships a custom loading script. Opting in here avoids the
    # interactive "Do you wish to run the custom code? [y/N]" prompt, which
    # would otherwise stall a non-interactive run. Only enable this for
    # repositories whose loading script you have inspected.
    trust_remote_code=True,
    # Adjust timeout to 1 hr, or 3600 seconds.
    storage_options={'client_kwargs': {'timeout': aiohttp.ClientTimeout(total=3600)}},
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/3.66k [00:00<?, ?B/s]

COCO.py:   0%|          | 0.00/9.47k [00:00<?, ?B/s]

The repository for HuggingFaceM4/COCO contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/HuggingFaceM4/COCO.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y


Downloading data:   0%|          | 0.00/36.7M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/13.5G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.65G [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

In [None]:
# Display an image
# dataset["train"][29]['image']

# Display the raw caption
# dataset["train"][0]['sentences']['raw']

In [None]:
''' Import clip model and experiment with one sample '''

import torch
import clip

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the ViT-B/32 CLIP model together with its image preprocessing transform.
model, preprocess = clip.load("ViT-B/32", device=device)

# # Uncomment to experiment with one sample
# # Get sample
# idx = 0
# image = preprocess(dataset["train"][idx]["image"]).unsqueeze(0).to(device)
# text = clip.tokenize(dataset["train"][idx]["sentences"]["raw"]).to(device)

# # Get the feature of the samples, drop the gradient since we don't need it right now
# with torch.no_grad():
#   image_features = model.encode_image(image)
#   text_features = model.encode_text(text)

# # Normalize the features
# image_features /= image_features.norm(dim=-1, keepdim=True)
# text_features /= text_features.norm(dim=-1, keepdim=True)

# # Get the similarity score
# similarity_score = (image_features @ text_features.T)

# # Print feature embeddings and similarity score
# print("Image embedding:", image_features)
# print("Text embedding:", text_features)
# print("Cosine similarity:", similarity_score)

100%|████████████████████████████████████████| 338M/338M [00:02<00:00, 169MiB/s]


# 2. Feed in a larger subset of MS-COCO, get image-text embeddings, then use UMAP to project to 3D for visualization

In [None]:
''' Randomly sample MS-COCO while making sure every image in the sampled subset is unique. Not necessary to run if you already have the .json '''

import random
import hashlib


def image_to_hash(image: "Image.Image") -> str:
    """
    Generate an MD5 hex digest identifying the pixel content of an image.

    The digest covers the array shape and dtype as well as the raw pixel
    bytes, so two images whose flattened bytes happen to match but whose
    dimensions differ (e.g. 2x3 vs 3x2) do not collide.

    Args:
        image: A PIL Image (or anything NumPy can convert to an array).

    Returns:
        The 32-character hexadecimal MD5 digest.
    """
    pixels = np.asarray(image)

    digest = hashlib.md5()
    # Prefix with the layout so identical byte streams of different shapes differ.
    digest.update(f"{pixels.shape}|{pixels.dtype}|".encode("utf-8"))
    digest.update(pixels.tobytes())
    return digest.hexdigest()

def get_features(image: Image.Image, text: str):
    """
    Encode one image and one text with CLIP and return L2-normalized features.

    Uses the notebook-level `model`, `preprocess` and `device`.

    Args:
        image: PIL image to encode; it is preprocessed and given a batch
            dimension of 1 before encoding.
        text: Caption to encode. Captions longer than CLIP's context window
            are truncated instead of raising, so one over-long caption cannot
            abort a long sampling run.

    Returns:
        Tuple `(image_features, text_features)` of tensors on `device`, each
        divided by its norm along the last dimension.

    TODO: add support for encoding image-text pairs in batches.
    """
    image = preprocess(image).unsqueeze(0).to(device)
    text = clip.tokenize(text, truncate=True).to(device)

    # Get the features; gradients are not needed for inference
    with torch.no_grad():
        image_features = model.encode_image(image)
        text_features = model.encode_text(text)

    # Normalize the features so a dot product gives the cosine similarity
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    return image_features, text_features


# Shuffle the dataset (fixed seed, so the sampled subset is reproducible)
shuffled_dataset = dataset["train"].shuffle(seed=42)

# Initialize variables
num_samples = NUM_SAMPLES
pairs = []            # collected {"image", "caption"} dicts, one per unique image
seen_hashes = set()   # hashes of the images already collected
idx = 0               # number of dataset rows examined so far

print("Starting sampling unique images...")
# Sample unique image-caption pairs
while len(pairs) < num_samples and idx < len(shuffled_dataset):
    # Fetch the row once and reuse it; indexing the dataset twice would load
    # the same sample twice (presumably decoding the image each time).
    sample = shuffled_dataset[idx]
    idx += 1

    image = sample["image"]
    caption = sample["sentences"]["raw"]

    # Skip images that were already collected
    img_hash = image_to_hash(image)
    if img_hash in seen_hashes:
        continue

    # Add the pair and mark the image as seen
    pairs.append({"image": image, "caption": caption})
    seen_hashes.add(img_hash)

    # Print progress (optional)
    if len(pairs) % 1000 == 0:
        print(f"Collected {len(pairs)} unique pairs...")

# Make it visible when the dataset ran out before the requested sample size
if len(pairs) < num_samples:
    print(f"Warning: only {len(pairs)} unique pairs available (requested {num_samples}).")

print("===================================================================")
print("Starting encoding features...")
# Now extract features for the collected pairs
final_pairs = []
# Count from 1 so the progress message reports how many samples are finished
# (counting from 0 printed "0 samples" after the first one and never reported
# the final thousand).
for idx, pair in enumerate(pairs, start=1):
    image = pair["image"]
    caption = pair["caption"]

    image_features, text_features = get_features(image, caption)
    # Both feature tensors are normalized, so this product is the cosine similarity
    similarity_score = (image_features @ text_features.T)

    # UMAP fields stay None until the pairs are projected
    final_pairs.append(CLIPPair(
        caption=caption,
        image=image,
        clip_text_embedding=text_features,
        clip_image_embedding=image_features,
        umap_text_embedding=None,
        umap_image_embedding=None,
        umap_average_embedding=None,
        similarity_score=similarity_score
    ))

    # Print progress
    if idx % 1000 == 0:
        print(f"Processed features for {idx} samples...")

Starting sampling unique images...
Collected 1000 unique pairs...
Collected 2000 unique pairs...
Collected 3000 unique pairs...
Collected 4000 unique pairs...
Collected 5000 unique pairs...
Starting encoding features...
Processed features for 0 samples...
Processed features for 1000 samples...
Processed features for 2000 samples...
Processed features for 3000 samples...
Processed features for 4000 samples...


In [None]:
''' Save and load json file for CLIPPair '''

import json
from typing import List
import io
import base64

def save_clip_pairs(clip_pairs: List[CLIPPair], filename):
  """
  Serialize a list of CLIPPair objects to a JSON file.

  Each pair becomes one dictionary: the image is stored as a base64-encoded
  JPEG string, embeddings and the similarity score as nested lists, and UMAP
  fields that have not been computed yet as null.

  Args:
      clip_pairs: List of CLIPPair objects to save.
      filename: Path of the JSON file to write.
  """

  def _list_or_none(value):
    # Each optional field is checked on its own; the previous code guarded
    # umap_average_embedding with a check on umap_image_embedding.
    return value.tolist() if value is not None else None

  # Convert CLIPPair to a dictionary, store it in a new list
  clip_pairs_dict = []

  for pair in clip_pairs:
    # Convert PIL Image to base64 string
    # NOTE(review): JPEG is lossy, and Pillow cannot write some modes
    # (e.g. RGBA) as JPEG - confirm every sampled image has a JPEG-compatible mode.
    buffered = io.BytesIO()
    pair.image.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode()

    clip_pairs_dict.append({
        'caption': pair.caption,
        'image': img_str,
        'clip_text_embedding': pair.clip_text_embedding.tolist(),
        'clip_image_embedding': pair.clip_image_embedding.tolist(),
        'umap_text_embedding': _list_or_none(pair.umap_text_embedding),
        'umap_image_embedding': _list_or_none(pair.umap_image_embedding),
        'umap_average_embedding': _list_or_none(pair.umap_average_embedding),
        'similarity_score': pair.similarity_score.tolist()
    })

  # Save to json
  with open(filename, 'w') as f:
    json.dump(clip_pairs_dict, f)

def load_clip_pairs(filename):
  """
  Read a JSON file of serialized pairs and rebuild the CLIPPair objects.

  Images are decoded from their base64 string back into PIL Images, CLIP
  embeddings and similarity scores become float16 torch tensors, and UMAP
  projections become float32 NumPy arrays (or None when stored as null).

  Args:
      filename: Path of the JSON file to read.

  Returns:
      List of CLIPPair objects, in file order.
  """

  def _optional_array(values):
    # Null entries mean the projection was never computed
    if values is None:
      return None
    return np.array(values, dtype=np.float32)

  with open(filename, 'r') as handle:
    records = json.load(handle)

  restored = []
  for record in records:
    # base64 string -> raw bytes -> PIL Image
    raw_bytes = base64.b64decode(record['image'])
    decoded_image = Image.open(io.BytesIO(raw_bytes))

    restored.append(CLIPPair(
        caption=record['caption'],
        image=decoded_image,
        clip_text_embedding=torch.tensor(record['clip_text_embedding'], dtype=torch.float16),
        clip_image_embedding=torch.tensor(record['clip_image_embedding'], dtype=torch.float16),
        umap_text_embedding=_optional_array(record['umap_text_embedding']),
        umap_image_embedding=_optional_array(record['umap_image_embedding']),
        umap_average_embedding=_optional_array(record['umap_average_embedding']),
        similarity_score=torch.tensor(record['similarity_score'], dtype=torch.float16)
    ))

  return restored



In [None]:
import numpy as np
import umap
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd


def project_clip_pairs(clip_pairs: List[CLIPPair], random_state: int = 42, n_neighbors: int = 15) -> List[CLIPPair]:
    """
    Jointly project text and image CLIP embeddings to 3D with UMAP.

    Text and image embeddings are stacked into a single matrix and reduced by
    one UMAP fit, so both modalities share the same 3D space. The given
    CLIPPair objects are updated in place.

    Args:
        clip_pairs: List of CLIPPair objects.
        random_state: Random seed for UMAP.
        n_neighbors: UMAP's nearest-neighbor hyperparameter.

    Returns:
        The same list of CLIPPair objects, with the UMAP fields filled in.
    """
    def _stack(attribute: str) -> np.ndarray:
        # One row per pair, moved to the CPU and converted to NumPy
        return np.vstack([getattr(p, attribute).cpu().numpy() for p in clip_pairs])

    # Text rows first, image rows second
    joint = np.vstack([_stack("clip_text_embedding"), _stack("clip_image_embedding")])

    # A single fit places both modalities in one coordinate system
    projector = umap.UMAP(n_components=3, random_state=random_state, n_neighbors=n_neighbors)
    joint_3d = projector.fit_transform(joint)

    # Undo the stacking: the first half is text, the second half is images
    count = len(clip_pairs)
    text_points = joint_3d[:count]
    image_points = joint_3d[count:]

    # Midpoint between each caption and its image
    mean_points = (text_points + image_points) / 2

    for pair, text_point, image_point, mean_point in zip(clip_pairs, text_points, image_points, mean_points):
        pair.umap_text_embedding = text_point
        pair.umap_image_embedding = image_point
        pair.umap_average_embedding = mean_point

    return clip_pairs


def visualize_clip_pairs(clip_pairs: List[CLIPPair], view: List[str]=["text", "image"]) -> None:
    """
    Show an interactive 3D scatter plot of the UMAP projections of CLIP pairs.

    Args:
        clip_pairs: List of CLIPPair objects with UMAP projections
        view: Which embedding types to plot; any combination of
            'text', 'image' and 'average'
    """
    # (name accepted in `view`, legend label, CLIPPair attribute holding the 3D point)
    layers = (
        ("text", "Text", "umap_text_embedding"),
        ("image", "Image", "umap_image_embedding"),
        ("average", "Average", "umap_average_embedding"),
    )

    # One row per plotted point
    rows = []
    for position, pair in enumerate(clip_pairs):
        for key, label, attribute in layers:
            if key not in view:
                continue
            point = getattr(pair, attribute)
            rows.append({
                'x': point[0],
                'y': point[1],
                'z': point[2],
                'type': label,
                'index': position,
                'caption': pair.caption,
                'similarity': float(pair.similarity_score)
            })

    frame = pd.DataFrame(rows)

    # Scatter plot coloured by embedding type, with pair details on hover
    axis_labels = {'x': 'UMAP-1', 'y': 'UMAP-2', 'z': 'UMAP-3'}
    figure = px.scatter_3d(
        frame,
        x='x',
        y='y',
        z='z',
        color='type',
        title='3D UMAP Projection of CLIP Pairs',
        labels=axis_labels,
        hover_data=['caption', 'similarity', 'index'],
    )

    # Marker size and layout for better readability
    figure.update_traces(marker=dict(size=5))
    scene_settings = dict(
        xaxis_title='UMAP-1',
        yaxis_title='UMAP-2',
        zaxis_title='UMAP-3',
    )
    figure.update_layout(
        scene=scene_settings,
        width=1000,
        height=800,
        legend_title="Embedding Type",
    )

    # Show the plot
    figure.show()

# Example usage:
# First project the embeddings
# Load projected pairs
# projected_pairs = load_clip_pairs('projected_pairs_combined_50K.json')
# projected_pairs = project_clip_pairs(final_pairs)

# Then visualize them
# visualize_clip_pairs(projected_pairs)

# Save projected_pairs
# save_clip_pairs(projected_pairs, 'projected_pairs_50K.json')
# save_clip_pairs(projected_pairs_combined, 'projected_pairs_combined_50K.json')

In [None]:
# Use the pairs encoded earlier in this session as the working set
# (an alternative is to load previously saved pairs from a JSON file).
loaded_pairs = final_pairs

## UMAP for global view

In [None]:
# ''' Load json from drive '''

# from google.colab import drive
# drive.mount('/content/drive')

# loaded_pairs = load_clip_pairs('/content/drive/MyDrive/(Computer Vision) Visualizing CLIP\'s Latent Space/pairs_50K_UMAPn200.json')

Mounted at /content/drive


In [None]:
# Project the working set to 3D with the configured UMAP neighbor count, then plot it.
n_neighbors = UMAP_N_NEIGHBORS
# project_clip_pairs updates the pairs in place and returns the same list
testing_pairs = project_clip_pairs(loaded_pairs, random_state=42, n_neighbors=n_neighbors)
# Uses the default view (text and image points)
visualize_clip_pairs(testing_pairs)

In [None]:
# Save the projected pairs as "<FILENAME>.json" (path relative to the working directory)
filename = FILENAME
save_clip_pairs(testing_pairs, f'{filename}.json')

## PCA for local view given one sample

In [None]:
import numpy as np
from sklearn.neighbors import NearestNeighbors
from sklearn.decomposition import PCA
from typing import List, Tuple, Union
from dataclasses import dataclass

def _embedding_to_numpy(embedding) -> np.ndarray:
    """Return an embedding as a float32 NumPy array, moving torch tensors to the CPU first."""
    # Tensors expose .detach(); tensors on the GPU cannot be converted to NumPy directly
    if hasattr(embedding, "detach"):
        embedding = embedding.detach().cpu().numpy()
    return np.asarray(embedding, dtype=np.float32)


def get_local_neighborhood(
    clip_pairs: List[CLIPPair],
    selected_idx: int,
    n_neighbors: int = 50,
    embedding_type: str = 'image'
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Find nearest neighbors in original CLIP space and project to 2D using PCA.

    Embeddings may be torch tensors (on any device) or NumPy arrays; they are
    converted to float32 NumPy arrays before any computation, so normalization
    and averaging do not run in the stored half precision.

    Args:
        clip_pairs: List of CLIPPair objects
        selected_idx: Index of the selected point
        n_neighbors: Number of neighbors to find (including the selected point);
            capped at the number of pairs
        embedding_type: Which embedding to use ('image', 'text', or 'average')

    Returns:
        projected_points: ndarray of shape (n_neighbors, 2) containing PCA projections
        neighbor_indices: 1-D integer ndarray with the indices of the nearest neighbors
        explained_variance_ratio: Variance explained by each principal component

    Raises:
        ValueError: If embedding_type is not 'image', 'text' or 'average'.
    """
    # Extract embeddings based on type
    if embedding_type == 'image':
        per_pair = [_embedding_to_numpy(pair.clip_image_embedding) for pair in clip_pairs]
    elif embedding_type == 'text':
        per_pair = [_embedding_to_numpy(pair.clip_text_embedding) for pair in clip_pairs]
    elif embedding_type == 'average':
        per_pair = [
            (_embedding_to_numpy(pair.clip_image_embedding)
             + _embedding_to_numpy(pair.clip_text_embedding)) / 2
            for pair in clip_pairs
        ]
    else:
        raise ValueError(f"Unknown embedding type: {embedding_type}")

    embeddings = np.vstack(per_pair)

    # Normalize embeddings (in case they aren't already)
    embeddings = embeddings / np.linalg.norm(embeddings, axis=1)[:, np.newaxis]

    # Find nearest neighbors
    n_neighbors = min(n_neighbors, len(clip_pairs))  # Ensure we don't ask for too many neighbors
    nbrs = NearestNeighbors(n_neighbors=n_neighbors, metric='cosine')
    nbrs.fit(embeddings)

    # Get indices of nearest neighbors
    distances, indices = nbrs.kneighbors(embeddings[selected_idx].reshape(1, -1))
    neighbor_indices = indices[0]  # Flatten from 2D array

    # Get embeddings of neighborhood
    neighborhood_embeddings = embeddings[neighbor_indices]

    # Project to 2D using PCA, fitted on the neighborhood only
    pca = PCA(n_components=2)
    projected_points = pca.fit_transform(neighborhood_embeddings)

    return projected_points, neighbor_indices, pca.explained_variance_ratio_

def get_neighborhood_info(
    clip_pairs: List[CLIPPair],
    neighbor_indices: List[int]
) -> List[dict]:
    """
    Collect the displayable details of each neighbor, for visualization or analysis.

    Args:
        clip_pairs: List of CLIPPair objects
        neighbor_indices: Indices into clip_pairs of the nearest neighbors

    Returns:
        One dictionary per neighbor, in the order of neighbor_indices, with
        the keys 'index', 'caption', 'image' and 'similarity_score'
    """
    summaries = []
    for idx in neighbor_indices:
        neighbor = clip_pairs[idx]
        summaries.append({
            'index': idx,
            'caption': neighbor.caption,
            'image': neighbor.image,
            'similarity_score': neighbor.similarity_score
        })
    return summaries

# Example usage:
def analyze_local_neighborhood(
    clip_pairs: List[CLIPPair],
    selected_idx: int,
    n_neighbors: int = 50,
    embedding_type: str = 'image'
) -> dict:
    """
    Perform complete local neighborhood analysis for a selected point.

    Args:
        clip_pairs: List of CLIPPair objects
        selected_idx: Index of the selected point
        n_neighbors: Number of neighbors to find
        embedding_type: Which embedding to use ('image', 'text', or 'average')

    Returns:
        Dictionary with the keys 'projected_points', 'neighbor_indices',
        'neighbors_info', 'explained_variance' and 'selected_point'
    """
    # Get local projections and neighbor indices
    projected_points, neighbor_indices, explained_variance = get_local_neighborhood(
        clip_pairs, selected_idx, n_neighbors, embedding_type
    )

    # Get neighbor information
    neighbors_info = get_neighborhood_info(clip_pairs, neighbor_indices)

    # Describe the selected point from its own index rather than assuming it is
    # the first neighbor: when several embeddings are identical the distances
    # tie and another pair may be returned first.
    selected_point = get_neighborhood_info(clip_pairs, [selected_idx])[0]

    # Combine results
    results = {
        'projected_points': projected_points,  # 2D coordinates for visualization
        'neighbor_indices': neighbor_indices,  # Indices of neighbors
        'neighbors_info': neighbors_info,      # Detailed info about neighbors
        'explained_variance': explained_variance,  # PCA explained variance
        'selected_point': selected_point,      # Info about selected point
    }

    return results

# Example usage:
# Assuming you have a list of CLIPPair objects called clip_pairs:
selected_idx = 0  # Index of point user clicked
# Analyze the 50 nearest neighbors of the selected pair in image-embedding space;
# `results` holds the 2D PCA coordinates, neighbor indices and per-neighbor details.
results = analyze_local_neighborhood(
    clip_pairs=loaded_pairs,
    selected_idx=selected_idx,
    n_neighbors=50,
    embedding_type='image'
)

# # Access results
# print(f"Selected point caption: {results['selected_point']['caption']}")
# print(f"Variance explained by PCA: {results['explained_variance']}")

# # Get 2D coordinates for plotting
# x_coords = results['projected_points'][:, 0]
# y_coords = results['projected_points'][:, 1]


Selected point caption: This an image of a bird on a car mirror
Variance explained by PCA: [0.21518708 0.07746034]


# 3. Validate data

In [None]:
# def images_are_equal(img1: Image.Image, img2: Image.Image) -> bool:
#   ''' Check if two PIL Images are the same. '''

#   return np.array_equal(np.array(img1), np.array(img2))

# # Check if any images repeat in projected_pairs
# for i, pair in enumerate(projected_pairs):
#   for other_pair in projected_pairs[i+1:]:
#     if images_are_equal(pair.image, other_pair.image):
#       print("Images are equal!")
#       break

# print("All images are unique!")

In [None]:
# import torch
# import numpy as np
# from PIL import Image
# import io
# import base64
# from dataclasses import dataclass
# from typing import List

# def validate_clip_pair_serialization(pairs: List[CLIPPair], filename: str = 'test.json') -> dict:
#     """
#     Validates the save/load functionality of CLIPPair objects by checking type and value preservation.

#     Returns a dictionary with validation results and any detected issues.
#     """
#     validation_results = {
#         'type_matches': True,
#         'value_matches': True,
#         'issues': []
#     }

#     # Save and then load the pairs
#     # save_clip_pairs(pairs, filename)
#     loaded_pairs = load_clip_pairs(filename)



#     if len(pairs) != len(loaded_pairs):
#         validation_results['issues'].append(f"Length mismatch: original={len(pairs)}, loaded={len(loaded_pairs)}")
#         return validation_results

#     for i, (original, loaded) in enumerate(zip(pairs, loaded_pairs)):
#         # Check types
#         type_checks = {
#             'caption': isinstance(loaded.caption, str),
#             'image': isinstance(loaded.image, Image.Image),
#             'clip_text_embedding': isinstance(loaded.clip_text_embedding, torch.Tensor),
#             'clip_image_embedding': isinstance(loaded.clip_image_embedding, torch.Tensor),
#             'umap_text_embedding': isinstance(loaded.umap_text_embedding, np.ndarray),
#             'umap_image_embedding': isinstance(loaded.umap_image_embedding, np.ndarray),
#             'umap_average_embedding': isinstance(loaded.umap_average_embedding, np.ndarray),
#             'similarity_score': isinstance(loaded.similarity_score, torch.Tensor)
#         }

#         for field, type_match in type_checks.items():
#             if not type_match:
#                 validation_results['type_matches'] = False
#                 validation_results['issues'].append(
#                     f"Type mismatch in pair {i}, field {field}: "
#                     f"expected={type(getattr(original, field))}, "
#                     f"got={type(getattr(loaded, field))}"
#                 )

#         # Check values
#         try:
#           value_checks = {
#               'caption': original.caption == loaded.caption,
#               'clip_text_embedding': torch.allclose(original.clip_text_embedding, loaded.clip_text_embedding),
#               'clip_image_embedding': torch.allclose(original.clip_image_embedding, loaded.clip_image_embedding),
#               'similarity_score': torch.allclose(original.similarity_score, loaded.similarity_score)
#           }

#           if original.umap_text_embedding is not None:
#               value_checks.update({
#                   'umap_text_embedding': np.allclose(original.umap_text_embedding, loaded.umap_text_embedding),
#                   'umap_image_embedding': np.allclose(original.umap_image_embedding, loaded.umap_image_embedding),
#                   'umap_average_embedding': np.allclose(original.umap_average_embedding, loaded.umap_average_embedding)
#               })

#           for field, value_match in value_checks.items():
#               if not value_match:
#                   validation_results['value_matches'] = False
#                   validation_results['issues'].append(
#                       f"Value mismatch in pair {i}, field {field}"
#                   )

#         except Exception as e:
#             validation_results['value_matches'] = False
#             validation_results['issues'].append(f"Error comparing values in pair {i}: {str(e)}")

#     return validation_results

# # Example usage:
# results = validate_clip_pair_serialization(projected_pairs_combined, filename="projected_pairs_combined_50K.json")
# if results['type_matches'] and results['value_matches']:
#     print("Validation passed: All types and values preserved correctly")
# else:
#     print("Validation failed:")
#     for issue in results['issues']:
#         print(f"- {issue}")
