In [9]:
# Mount Google Drive at /content/drive (Colab-only import) so files stored
# under that path can be read and written by the cells below.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [1]:
# Imports
import torch
import pandas as pd
import numpy as np
from transformers import CLIPTokenizer, CLIPTextModel
import re

In [14]:
# CLIP text encoder setup
# The tokenizer and model live at cell level so they can be passed into the
# helper functions defined in this notebook.
print("Loading CLIP model...")
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
# .to() and .eval() each return the module itself, so device placement and
# inference mode are chained directly onto the load.
model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32").to(device).eval()
def chunk_by_sentences(text, tokenizer, max_tokens=77, overlap_sentences=1):
    """Split text into sentence-aligned chunks that each fit a token budget.

    Sentences are obtained by splitting on runs of '.', '!' or '?'. They are
    packed greedily into chunks, and every emitted chunk is its sentences
    joined with '. ' plus a trailing '.'. The budget is checked against
    ``tokenizer.encode`` of that rendered string, so anything the tokenizer
    adds (special tokens, the re-inserted periods) is counted and no returned
    chunk encodes to more than ``max_tokens`` tokens.

    Args:
        text: String to chunk.
        tokenizer: Object exposing ``encode(str)`` that returns a sequence of
            token ids.
        max_tokens: Upper bound on ``len(tokenizer.encode(chunk))``.
        overlap_sentences: Number of trailing sentences of a finished chunk to
            repeat at the start of the next one. The overlap is shortened
            (oldest sentence first) when keeping it would break the budget.

    Returns:
        List of chunk strings; empty when ``text`` contains no sentences.
    """

    def render(parts):
        # Exact string that will be emitted for this group of sentences.
        return '. '.join(parts) + '.'

    def fits(parts):
        return len(tokenizer.encode(render(parts))) <= max_tokens

    def truncate(sentence):
        # Shorten an over-long sentence on word boundaries until it fits.
        words = sentence.split()
        while words and not fits([' '.join(words)]):
            words.pop()
        if words:
            return ' '.join(words)
        # Even the first word is too long: trim it character by character.
        piece = sentence.split()[0]
        while piece and not fits([piece]):
            piece = piece[:-1]
        return piece

    sentences = [s.strip() for s in re.split(r'[.!?]+', text) if s.strip()]

    chunks = []
    current_chunk = []

    for sentence in sentences:
        if fits(current_chunk + [sentence]):
            current_chunk.append(sentence)
            continue

        carried = []
        if current_chunk:
            chunks.append(render(current_chunk))
            overlap_start = max(0, len(current_chunk) - overlap_sentences)
            carried = current_chunk[overlap_start:]
            # Overlap is best-effort: drop the oldest carried sentences until
            # the new sentence fits alongside what remains.
            while carried and not fits(carried + [sentence]):
                carried = carried[1:]

        if not fits(carried + [sentence]):
            # Only reachable with no overlap left: the sentence alone is too long.
            sentence = truncate(sentence)
            if not sentence:
                # Budget too small to hold any text; skip this sentence.
                current_chunk = []
                continue

        current_chunk = carried + [sentence]

    if current_chunk:
        chunks.append(render(current_chunk))

    return chunks

Loading CLIP model...


pytorch_model.bin:   0%|          | 0.00/605M [00:00<?, ?B/s]

In [3]:
# embedding by chunk
def get_embeddings_for_chunks(chunks, tokenizer, model, device, max_length=None):
    """Encode each chunk separately and return one mean-pooled vector per chunk.

    Args:
        chunks: Iterable of text strings.
        tokenizer: Callable tokenizer returning a mapping of tensors when
            called with ``return_tensors="pt"``.
        model: Callable model whose output exposes ``last_hidden_state``.
        device: Device the input tensors are moved to before the forward pass.
        max_length: Optional token cap forwarded to the tokenizer. ``None``
            leaves the limit to the tokenizer's own configured maximum.

    Returns:
        List of 1-D numpy arrays, one per chunk, each the mean of
        ``last_hidden_state`` over the sequence dimension.
    """
    embeddings = []
    # Inference only: a single no-grad context covers every forward pass.
    with torch.no_grad():
        for chunk in chunks:
            # truncation=True keeps an over-long chunk from exceeding the
            # model's position limit instead of failing inside the model.
            inputs = tokenizer(
                chunk,
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            )
            inputs = {k: v.to(device) for k, v in inputs.items()}
            outputs = model(**inputs)
            # One chunk per call, so row 0 is the only row in the batch.
            embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy()[0]
            embeddings.append(embedding)
    return embeddings

In [4]:
# average embedding
def average_embeddings(embeddings):
    """Collapse a collection of embedding vectors into their element-wise mean.

    An empty collection yields a 512-element zero vector instead.
    """
    if embeddings:
        return np.mean(embeddings, axis=0)
    # Nothing to average: fall back to zeros at the CLIP embedding size.
    return np.zeros(512)

In [10]:
# Load object descriptions
# Reads the ground-truth CSV from the mounted Drive folder into `df` and
# previews the first rows. The embedding loop in this notebook reads the
# 'object_name' and 'description' columns from it.
print("Loading CSV...")
df = pd.read_csv('/content/drive/MyDrive/Projects/zero/data_public/groundTruthObjectDescriptions.csv')
df.head()

Loading CSV...


Unnamed: 0,object_name,description,Unnamed: 2
0,scissors,Scissors are handheld cutting tools consisting...,
1,stethoscope,A medical instrument used by healthcare profes...,
2,french_press,A manual coffee brewing device invented in the...,
3,shoe_horn,A tool designed to aid in putting on shoes wit...,
4,fishing_reel,A mechanical device attached to a fishing rod ...,


In [15]:
# Build one averaged embedding record per described object
results = []
for object_name, description in zip(df['object_name'], df['description']):
    # Rows with a missing or empty description are left out of the results
    if pd.isna(description) or not description:
        continue

    print(f"Processing {object_name}...")

    # description -> sentence chunks -> one vector per chunk -> single mean vector
    chunks = chunk_by_sentences(description, tokenizer, max_tokens=77, overlap_sentences=1)
    chunk_embeddings = get_embeddings_for_chunks(chunks, tokenizer, model, device)
    avg_embedding = average_embeddings(chunk_embeddings)

    # Store the vector as a plain list so the record holds only built-in types
    record = {
        'object_name': object_name,
        'description': description,
        'embedding': avg_embedding.tolist(),
    }
    results.append(record)


Processing scissors...
Processing stethoscope...
Processing french_press...
Processing shoe_horn...
Processing fishing_reel...
Processing crank_flashlight...
Processing rolodex...
Processing floppy_disk...
Processing bulb_planter...
Processing three_hole_punch...
Processing pocket_radio...
Processing hand_mixer...
Processing blood_pressure_cuff...


In [17]:
# Display the first result record as a sanity check; each record is a dict
# with 'object_name', 'description', and 'embedding' (a list of floats).

results[0]

{'object_name': 'scissors',
 'description': 'Scissors are handheld cutting tools consisting of two metal blades pivoted together at a central point. They have two looped handles, usually made of plastic or metal, designed for gripping. When the handles are squeezed together, the sharp edges of the blades slide past each other, creating a cutting action. The blades can be straight or slightly curved, with some featuring serrated edges for specific tasks. Scissors are used for a wide range of cutting applications, from household and office tasks to specialized uses in sewing, crafting, and medical procedures. Kitchen scissors help cut food, while heavy-duty shears can handle tougher materials like fabric, cardboard, or plastic. There are also child-safe scissors with rounded tips designed for young users. The simple yet effective mechanism of scissors makes them essential in everyday life, offering precision and ease when cutting paper, thread, hair, and various other materials.',
 'embe

In [19]:
# Save results
# Turns the list of result dicts into a DataFrame and writes it to Drive as
# CSV without the index column, then previews the first rows. The 'embedding'
# column holds Python lists, so the CSV stores each as its string form (see
# the preview output) — presumably readers must parse it back; verify at the
# consumer.
output_df = pd.DataFrame(results)
output_df.to_csv('/content/drive/MyDrive/Projects/zero/data_public/ground_truth_embeddings_chunked.csv', index=False)
print(f"Saved {len(results)} embeddings to ground_truth_embeddings_chunked.csv")
output_df.head()

Saved 13 embeddings to ground_truth_embeddings_chunked.csv


Unnamed: 0,object_name,description,embedding
0,scissors,Scissors are handheld cutting tools consisting...,"[0.22075822949409485, -0.9771410226821899, -0...."
1,stethoscope,A medical instrument used by healthcare profes...,"[0.6519325971603394, 0.28762707114219666, -0.1..."
2,french_press,A manual coffee brewing device invented in the...,"[-0.18385392427444458, -0.028368502855300903, ..."
3,shoe_horn,A tool designed to aid in putting on shoes wit...,"[0.9078508615493774, -0.058848828077316284, 0...."
4,fishing_reel,A mechanical device attached to a fishing rod ...,"[0.6220713257789612, -0.29306548833847046, 0.4..."
