In [2]:
# Mount Google Drive into the Colab filesystem (Colab-only).
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [3]:
# Imports
import torch
import pandas as pd
import numpy as np
from transformers import CLIPTokenizer, CLIPTextModel
import re

In [13]:
# Load the CLIP tokenizer and text encoder once, up front, so the helper
# functions defined below can receive them as arguments.
print("Loading CLIP model...")
CLIP_MODEL_NAME = "openai/clip-vit-base-patch32"  # tokenizer and model must come from the same checkpoint
device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = CLIPTokenizer.from_pretrained(CLIP_MODEL_NAME)
model = CLIPTextModel.from_pretrained(CLIP_MODEL_NAME)
model.to(device)
model.eval();  # inference mode; trailing ';' suppresses the long model repr


Loading CLIP model...


In [19]:
def _split_into_sentences(text):
    """Split ``text`` on runs of '.', '!' or '?', keeping each run attached to its sentence.

    Text after the last punctuation run is kept as a final sentence. The split is
    purely character-based, so decimals ("3.5") and abbreviations ("e.g.") are
    split as well.
    """
    # With a capturing group, re.split returns [body, punct, body, punct, ..., tail]:
    # bodies sit at even indices, their punctuation at the following odd index.
    parts = re.split(r'([.!?]+)', text)
    sentences = []
    for body, punctuation in zip(parts[0::2], parts[1::2]):
        body = body.strip()
        if body:  # punctuation with no text in front of it is dropped
            sentences.append(body + punctuation)
    tail = parts[-1].strip()
    if tail:  # trailing text that does not end with punctuation
        sentences.append(tail)
    return sentences


def _token_count(text, tokenizer):
    """Number of ids ``tokenizer.encode`` produces for ``text`` (special tokens included)."""
    return len(tokenizer.encode(text))


def _truncate_to_token_limit(text, tokenizer, max_tokens):
    """Return ``text`` if it fits ``max_tokens``, else its longest prefix verified to fit."""
    ids = tokenizer.encode(text)
    if len(ids) <= max_tokens:
        return text
    # encode() adds special tokens (e.g. BOS/EOS) that decode(skip_special_tokens=True)
    # strips and the next encode() re-adds, and a decode->encode round trip is not
    # guaranteed to reproduce the same ids -- so shrink until the result really fits.
    for keep in range(max_tokens, 0, -1):
        candidate = tokenizer.decode(ids[:keep], skip_special_tokens=True)
        if _token_count(candidate, tokenizer) <= max_tokens:
            return candidate
    return ''


def _split_to_token_windows(text, tokenizer, max_tokens):
    """Cut ``text`` into consecutive pieces that each fit within ``max_tokens``."""
    ids = tokenizer.encode(text)
    # Reserve room for the special tokens encode() adds around every piece.
    step = max(1, max_tokens - _token_count('', tokenizer))
    pieces = []
    for start in range(0, len(ids), step):
        piece = tokenizer.decode(ids[start:start + step], skip_special_tokens=True).strip()
        if piece:
            piece = _truncate_to_token_limit(piece, tokenizer, max_tokens)
            if piece.strip():
                pieces.append(piece)
    return pieces


def chunk_by_sentences(text, tokenizer, max_tokens=77, overlap_sentences=1,
                       split_long_sentences=False):
    """Split ``text`` into sentence-aligned chunks that each fit within ``max_tokens``.

    Sentences are packed greedily. When a chunk is full, the next one starts with
    the last ``overlap_sentences`` sentences of the previous chunk; the overlap is
    dropped if it does not fit together with the new sentence.

    Args:
        text: string to chunk.
        tokenizer: object exposing ``encode(text)`` and
            ``decode(ids, skip_special_tokens=True)`` (e.g. ``CLIPTokenizer``).
            Lengths are measured as ``len(tokenizer.encode(...))``, i.e. including
            any special tokens the tokenizer adds.
        max_tokens: maximum ``encode`` length of every returned chunk.
        overlap_sentences: number of trailing sentences repeated at the start of
            the next chunk (0 disables overlap).
        split_long_sentences: a single sentence longer than ``max_tokens`` is
            truncated by default, which drops its tail. If True it is instead cut
            into consecutive token windows so no text is lost.

    Returns:
        list of non-empty chunk strings, each satisfying
        ``len(tokenizer.encode(chunk)) <= max_tokens``.
    """
    # 1. Sentences, each made to fit on its own.
    sentences = []
    for sentence in _split_into_sentences(text):
        if _token_count(sentence, tokenizer) <= max_tokens:
            sentences.append(sentence)
        elif split_long_sentences:
            sentences.extend(_split_to_token_windows(sentence, tokenizer, max_tokens))
        else:
            truncated = _truncate_to_token_limit(sentence, tokenizer, max_tokens)
            if truncated.strip():
                sentences.append(truncated)

    # 2. Greedy packing with sentence overlap between consecutive chunks.
    chunks = []
    current = []
    for sentence in sentences:
        candidate = current + [sentence]
        if _token_count(' '.join(candidate), tokenizer) <= max_tokens:
            current = candidate
            continue
        if current:
            chunks.append(' '.join(current))
        overlap = current[max(0, len(current) - overlap_sentences):]
        current = overlap + [sentence]
        if _token_count(' '.join(current), tokenizer) > max_tokens:
            current = [sentence]
    if current:
        chunks.append(' '.join(current))

    # 3. Final guard: drop blank chunks and enforce the limit on every chunk.
    validated_chunks = []
    for chunk in chunks:
        if chunk.strip():
            fitted = _truncate_to_token_limit(chunk, tokenizer, max_tokens)
            if fitted.strip():
                validated_chunks.append(fitted)
    return validated_chunks

In [5]:
# embedding by chunk
def get_embeddings_for_chunks(chunks, tokenizer, model, device, batch_size=32):
    """Embed each text chunk as the mean of the encoder's last hidden state.

    The mean is taken over each chunk's real (non-padding) token positions,
    including the special tokens the tokenizer adds. Chunks are encoded in
    batches rather than one forward pass per chunk.

    Args:
        chunks: list of strings. Chunks longer than the tokenizer's maximum
            length are truncated instead of crashing the model.
        tokenizer: Hugging Face tokenizer (callable) matching ``model``.
        model: text encoder whose output exposes ``last_hidden_state``.
        device: device the model lives on; inputs are moved there.
        batch_size: number of chunks per forward pass.

    Returns:
        list of 1-D numpy arrays, one per chunk, in input order
        (an empty list for no chunks).
    """
    embeddings = []
    if not chunks:
        return embeddings
    with torch.no_grad():
        for start in range(0, len(chunks), batch_size):
            batch = list(chunks[start:start + batch_size])
            inputs = tokenizer(batch, padding=True, truncation=True, return_tensors="pt")
            inputs = {k: v.to(device) for k, v in inputs.items()}
            hidden = model(**inputs).last_hidden_state  # assumes (batch, seq, hidden)
            mask = inputs.get("attention_mask")
            if mask is None:
                mask = torch.ones_like(inputs["input_ids"])
            # Weight by the attention mask so padding added for batching does not
            # contribute to any chunk's mean.
            mask = mask.unsqueeze(-1).to(hidden.dtype)
            summed = (hidden * mask).sum(dim=1)
            counts = mask.sum(dim=1).clamp(min=1)
            pooled = (summed / counts).cpu().numpy()
            embeddings.extend(pooled)
    return embeddings

In [6]:
# average embedding
def average_embeddings(embeddings, dim=512):
    """Element-wise mean of a sequence of equal-length embedding vectors.

    Args:
        embeddings: list of 1-D arrays, or a 2-D array with one row per vector.
            ``None`` or an empty sequence is allowed.
        dim: length of the zero vector returned when there is nothing to
            average (512 matches the CLIP ViT-B/32 text model used here).

    Returns:
        1-D numpy array: the mean over axis 0, or ``np.zeros(dim)`` if empty.
    """
    # len() instead of truthiness: `not array` raises for multi-element numpy arrays.
    if embeddings is None or len(embeddings) == 0:
        return np.zeros(dim)
    return np.mean(embeddings, axis=0)

In [22]:
# Load the ground-truth object descriptions from Drive and preview them
GROUND_TRUTH_DESCRIPTIONS_CSV = '/content/drive/MyDrive/Projects/zero/data_public/groundTruthObjectDescriptions.csv'
print("Loading CSV...")
df = pd.read_csv(GROUND_TRUTH_DESCRIPTIONS_CSV)
df.head()

Loading CSV...


Unnamed: 0,object_name,description,Unnamed: 2
0,scissors,Scissors are handheld cutting tools consisting...,
1,stethoscope,A medical instrument used by healthcare profes...,
2,french_press,A manual coffee brewing device invented in the...,
3,shoe_horn,A tool designed to aid in putting on shoes wit...,
4,fishing_reel,A mechanical device attached to a fishing rod ...,


In [24]:
# Get one averaged embedding per ground-truth object description
MAX_TOKENS = 77  # per-chunk token budget (CLIP's text context length)
results = []
for _, row in df.iterrows():
    object_name = row['object_name']
    description = row['description']

    if pd.isna(description) or not description:
        continue

    print(f"Processing {object_name}...")

    # Chunk the description into pieces that fit the encoder's context window
    chunks = chunk_by_sentences(description, tokenizer, max_tokens=MAX_TOKENS, overlap_sentences=1)

    # Validate chunk lengths right after chunking
    for i, chunk in enumerate(chunks):
        token_count = len(tokenizer.encode(chunk))
        if token_count > MAX_TOKENS:
            print(f"Warning: Chunk {i} for {object_name} exceeds token limit ({token_count} tokens)")
            print(f"Problematic chunk: {chunk[:100]}...")  # Show first 100 chars

    # No usable text (e.g. whitespace/punctuation only): averaging zero chunks
    # would store an all-zeros vector, so skip the row instead.
    if not chunks:
        print(f"Warning: no text chunks for {object_name}; skipping")
        continue

    # Embed each chunk, then average into a single vector per description
    chunk_embeddings = get_embeddings_for_chunks(chunks, tokenizer, model, device)
    avg_embedding = average_embeddings(chunk_embeddings)

    results.append({
        'object_name': object_name,
        'description': description,
        'embedding': avg_embedding.tolist()
    })


Processing scissors...
Processing stethoscope...
Processing french_press...
Processing shoe_horn...
Processing fishing_reel...
Processing crank_flashlight...
Processing rolodex...
Processing floppy_disk...
Processing bulb_planter...
Processing three_hole_punch...
Processing pocket_radio...
Processing hand_mixer...
Processing blood_pressure_cuff...


In [20]:
# OR get one averaged embedding per participant transcription
MAX_TOKENS = 77  # per-chunk token budget (CLIP's text context length)

# `df` must hold the participant transcriptions here, not the ground-truth
# descriptions; fail clearly instead of with a bare KeyError mid-loop.
required_columns = {'SubjectID', 'TalkBlock', 'ObjectID', 'Transcription'}
missing_columns = required_columns - set(df.columns)
if missing_columns:
    raise KeyError(
        f"df is missing participant columns {sorted(missing_columns)}; "
        "load the participant transcription CSV into df before running this cell"
    )

results = []
for _, row in df.iterrows():
    subject_name = row['SubjectID']
    talk_block = row['TalkBlock']
    object_name = row['ObjectID']
    description = row['Transcription']

    if pd.isna(description) or not description:
        continue

    # Chunk the transcription into pieces that fit the encoder's context window
    chunks = chunk_by_sentences(description, tokenizer, max_tokens=MAX_TOKENS, overlap_sentences=1)

    # Validate chunk lengths right after chunking
    for i, chunk in enumerate(chunks):
        token_count = len(tokenizer.encode(chunk))
        if token_count > MAX_TOKENS:
            print(f"Warning: Chunk {i} for {object_name} exceeds token limit ({token_count} tokens)")
            print(f"Problematic chunk: {chunk[:100]}...")  # Show first 100 chars

    # No usable text: averaging zero chunks would store an all-zeros vector
    if not chunks:
        print(f"Warning: no text chunks for {subject_name}/{talk_block}/{object_name}; skipping")
        continue

    # Embed each chunk, then average into a single vector per transcription
    chunk_embeddings = get_embeddings_for_chunks(chunks, tokenizer, model, device)
    avg_embedding = average_embeddings(chunk_embeddings)

    results.append({
        'subject_name': subject_name,
        'talk_block': talk_block,
        'object_name': object_name,
        'description': description,
        'embedding': avg_embedding.tolist()
    })

In [None]:
# Display the first computed record (IndexError if no row was processed)
first_result = results[0]
first_result

In [25]:
# Save the ground-truth embeddings (embedding lists are stored as strings in the CSV)
OUTPUT_CSV = '/content/drive/MyDrive/Projects/zero/data_public/groundTruthObjectEmbeddings.csv'

# Records carrying 'subject_name' come from the participant cell; writing them
# here would overwrite the ground-truth embeddings file.
if results and 'subject_name' in results[0]:
    raise ValueError(
        "results contain participant embeddings; choose a different output path "
        "instead of overwriting the ground-truth embeddings file"
    )

output_df = pd.DataFrame(results)
output_df.to_csv(OUTPUT_CSV, index=False)
print(f"Saved {len(results)} embeddings to {OUTPUT_CSV}")
output_df.head()

Saved 13 embeddings to ground_truth_embeddings_chunked.csv


Unnamed: 0,object_name,description,embedding
0,scissors,Scissors are handheld cutting tools consisting...,"[0.22075822949409485, -0.9771410226821899, -0...."
1,stethoscope,A medical instrument used by healthcare profes...,"[0.6519325971603394, 0.28762707114219666, -0.1..."
2,french_press,A manual coffee brewing device invented in the...,"[-0.18385392427444458, -0.028368502855300903, ..."
3,shoe_horn,A tool designed to aid in putting on shoes wit...,"[0.9078508615493774, -0.058848828077316284, 0...."
4,fishing_reel,A mechanical device attached to a fishing rod ...,"[0.6220713257789612, -0.29306548833847046, 0.4..."
