In [1]:
# Mount Google Drive so the project CSVs under /content/drive/MyDrive are readable.
from google.colab import drive

drive.mount(mountpoint='/content/drive/')

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


In [19]:
# Imports
import torch
import pandas as pd
import numpy as np
import ast
from scipy.stats import zscore
from sklearn.metrics.pairwise import cosine_similarity


In [29]:
# Load the BERT embedding tables: ground-truth object descriptions, and
# participant descriptions split by block.
DATA_DIR = '/content/drive/MyDrive/Projects/zero/data_public/'
groundTruthDF = pd.read_csv(DATA_DIR + 'groundTruthObjectEmbeddingsBERT.csv')
parTranscriptDF = pd.read_csv(DATA_DIR + 'participantDescriptionsByBlockEmbeddingsBERT.csv')

In [27]:
# Strip underscores from object names so the ground-truth keys take the same
# form as the participant file (e.g. 'frenchpress'); the two tables are joined
# on 'object_name' when computing similarities. The participant names get the
# identical normalisation (a no-op when they are already underscore-free) so the
# join key is built the same way on both sides.
# regex=False makes the literal replacement explicit and stable across pandas
# versions (older pandas defaulted to regex=True and warned for 1-char patterns).
groundTruthDF['object_name'] = groundTruthDF['object_name'].str.replace('_', '', regex=False)
parTranscriptDF['object_name'] = parTranscriptDF['object_name'].str.replace('_', '', regex=False)
print(groundTruthDF['object_name'])

0              scissors
1           stethoscope
2           frenchpress
3              shoehorn
4           fishingreel
5       crankflashlight
6               rolodex
7            floppydisk
8           bulbplanter
9        threeholepunch
10          pocketradio
11            handmixer
12    bloodpressurecuff
Name: object_name, dtype: object


In [30]:
# Inspect the loaded participant table (subject, block, object, description,
# token/chunk counts, and the embedding still stored as a string literal).
parTranscriptDF

Unnamed: 0,subject_name,block_number,object_name,description,total_tokens,num_chunks,embedding
0,tulip003,1,handmixer,A hand mixer is usually made out of metal and ...,68,1,"[0.0014718323945999146, -0.09571817517280579, ..."
1,tulip003,2,handmixer,There are like other forms of hand mixer so th...,72,1,"[-0.12219145894050598, -0.03579965978860855, 0..."
2,tulip003,1,fishingreel,This thing is usually made out of plastic and ...,53,1,"[0.21410039067268372, -0.10099238157272339, 0...."
3,tulip003,2,fishingreel,"This is a tool where a string is attached, so...",82,1,"[-0.05731167271733284, -0.09037087112665176, 0..."
4,tulip003,1,shoehorn,This object is usually made out of metal and i...,37,1,"[-0.13571789860725403, -0.05446772277355194, -..."
...,...,...,...,...,...,...,...
448,tulip024,2,stethoscope,it is this metal little circular thing and the...,85,1,"[0.058361250907182693, 0.12742501497268677, 0...."
449,tulip024,1,frenchpress,It's like a glass cylindrical pitcher and wit...,147,1,"[-0.18195031583309174, 0.12108269333839417, 0...."
450,tulip024,2,frenchpress,It's a glass cylinder that works as a pitcher...,109,1,"[-0.11101942509412766, 0.1576933115720749, 0.2..."
451,tulip024,1,threeholepunch,it is like it's its length is about the size o...,147,1,"[-0.13913710415363312, -0.09535489231348038, 0..."


In [31]:
# zscore the ground truth embeddings (per dimension, across the ground-truth objects).
# The CSV stores each embedding as a list literal. Parse it only while it is
# still a string, so re-running this cell does not crash: ast.literal_eval
# raises ValueError when handed an already-parsed list.
groundTruthDF['embedding'] = groundTruthDF['embedding'].apply(
    lambda e: ast.literal_eval(e) if isinstance(e, str) else e
)
# NOTE(review): parsed entries display as nested lists ([[...]]); np.vstack
# yields one row per object only if each entry holds a single inner vector.
embeddings_array = np.vstack(groundTruthDF['embedding'].values)
assert embeddings_array.shape[0] == len(groundTruthDF), (
    f"expected {len(groundTruthDF)} embedding rows, got {embeddings_array.shape[0]}"
)
# NOTE(review): these mean/std come from the ground-truth set only, while the
# participant embeddings are standardised with their own statistics, so the
# two are centred/scaled differently before cosine similarity — confirm intended.
zscored_embeddings = zscore(embeddings_array, axis=0)
groundTruthDF['embedding_zscored'] = list(zscored_embeddings)



In [24]:
# Inspect the ground-truth table: parsed 'embedding' alongside the per-dimension
# z-scored 'embedding_zscored'.
groundTruthDF

Unnamed: 0,object_name,description,embedding,embedding_zscored
0,scissors,Scissors are handheld cutting tools consisting...,"[[-0.006037797778844833, 0.155913844704628, 0....","[1.0447710313481469, 0.655182236138258, -1.982..."
1,stethoscope,A medical instrument used by healthcare profes...,"[[0.08775169402360916, 0.20279371738433838, 0....","[1.7031637188741124, 1.096332523626729, 0.1825..."
2,frenchpress,A manual coffee brewing device invented in the...,"[[-0.3064567446708679, 0.22607463598251343, 0....","[-1.0641394022915431, 1.315411274176016, -1.29..."
3,shoehorn,A tool designed to aid in putting on shoes wit...,"[[-0.17555826902389526, -0.015135371126234531,...","[-0.14524542474674182, -0.9544298729646093, 0...."
4,fishingreel,A mechanical device attached to a fishing rod ...,"[[-0.06742638349533081, 0.17384985089302063, 0...","[0.6138294071155932, 0.8239641426179001, -0.96..."
5,crankflashlight,A self-powered illumination device that conver...,"[[0.039940908551216125, 0.01705889403820038, 0...","[1.3675368705466049, -0.6514744995615289, 1.15..."
6,rolodex,A rotating file device used to store and organ...,"[[-0.21346072852611542, 0.16289684176445007, 0...","[-0.41131683313087236, 0.7208938301419022, 1.6..."
7,floppydisk,A portable data storage device popular from th...,"[[-0.41984307765960693, 0.06521817296743393, 0...","[-1.8600999205808282, -0.1982847076433463, -0...."
8,bulbplanter,A specialized gardening tool designed for effi...,"[[-0.31858029961586, -0.16505786776542664, 0.2...","[-1.1492455241949906, -2.365234738951236, -0.3..."
9,threeholepunch,An office tool used to create uniform holes in...,"[[-0.08882688730955124, 0.07960617542266846, 0...","[0.463600048262349, -0.06289032500760511, 0.79..."


In [32]:
# zscore the participant transcript embeddings (per dimension, across all participant rows).
# Parse the CSV list literals only while they are still strings, so re-running
# this cell does not crash (ast.literal_eval raises ValueError on a list).
parTranscriptDF['embedding'] = parTranscriptDF['embedding'].apply(
    lambda e: ast.literal_eval(e) if isinstance(e, str) else e
)
embeddings_array = np.vstack(parTranscriptDF['embedding'].values)
# np.vstack gives one row per participant row only if each entry is a single vector.
assert embeddings_array.shape[0] == len(parTranscriptDF), (
    f"expected {len(parTranscriptDF)} embedding rows, got {embeddings_array.shape[0]}"
)
# NOTE(review): standardised with participant-set statistics, separately from
# the ground-truth embeddings — confirm this is the intended comparison space.
zscored_embeddings = zscore(embeddings_array, axis=0)
parTranscriptDF['embedding_zscored'] = list(zscored_embeddings)


In [9]:
# Map each object name to its z-scored ground-truth embedding (used when scoring
# participant rows). Series.to_dict() silently keeps only the last entry for a
# repeated key, so require unique object names.
assert groundTruthDF['object_name'].is_unique, "groundTruthDF has duplicate object_name values"
groundTruth_lookup = groundTruthDF.set_index('object_name')['embedding_zscored'].to_dict()

# Report participant objects lacking ground truth up front; those rows will score NaN.
missing_gt_objects = sorted(set(parTranscriptDF['object_name']) - set(groundTruth_lookup), key=str)
if missing_gt_objects:
    print(f"Participant objects without ground truth: {missing_gt_objects}")


In [10]:
def row_cosine_similarity(row, lookup=None):
    """Cosine similarity between a participant row's embedding and its object's ground truth.

    Intended for ``parTranscriptDF.apply(row_cosine_similarity, axis=1)``.

    Parameters
    ----------
    row : pandas.Series or dict
        Must provide ``'object_name'`` and ``'embedding_zscored'``.
    lookup : dict, optional
        Maps object name -> ground-truth embedding. When omitted, the
        notebook-level ``groundTruth_lookup`` is used (resolved at call time),
        so existing calls behave as before; passing it explicitly removes the
        dependency on hidden global state.

    Returns
    -------
    float
        The cosine similarity; ``0.0`` if either vector has zero norm (the value
        sklearn's ``cosine_similarity`` gives); or ``np.nan``, with a printed
        diagnostic, if either embedding is missing, the shapes differ, the
        values are non-numeric, or they contain NaN.
    """
    if lookup is None:
        lookup = groundTruth_lookup

    obj = row['object_name']
    emb = row['embedding_zscored']
    gt_emb = lookup.get(obj)

    # Missing embedding on either side.
    if gt_emb is None:
        print(f"Missing ground truth for object: {obj}")
        return np.nan
    # None or a scalar NaN (empty cell) would otherwise fall through to a
    # misleading 0-d "shape mismatch" message.
    if emb is None or (np.isscalar(emb) and pd.isna(emb)):
        print(f"Missing participant embedding for object: {obj}")
        return np.nan

    emb = np.asarray(emb)
    gt_emb = np.asarray(gt_emb)

    if emb.shape != gt_emb.shape:
        print(f"Shape mismatch for object: {obj} | par shape: {emb.shape} | gt shape: {gt_emb.shape}")
        return np.nan

    # Only float ('f') and signed-integer ('i') arrays are accepted.
    if emb.dtype.kind not in 'fi' or gt_emb.dtype.kind not in 'fi':
        print(f"Non-numeric types for object: {obj}")
        return np.nan

    if np.isnan(emb).any() or np.isnan(gt_emb).any():
        print(f"NaNs found for object: {obj}")
        return np.nan

    # Direct formula for a single pair (no pairwise-matrix machinery needed).
    # Zero-norm vectors score 0.0, matching sklearn's cosine_similarity.
    par_vec = emb.ravel().astype(float)
    gt_vec = gt_emb.ravel().astype(float)
    denom = np.linalg.norm(par_vec) * np.linalg.norm(gt_vec)
    if denom == 0:
        return 0.0
    return np.dot(par_vec, gt_vec) / denom



In [11]:
# Sanity check: flag rows whose z-scored embedding contains any NaN (e.g. from a
# zero-variance dimension); such rows will score NaN in the similarity step.
groundTruthDF['has_nan'] = groundTruthDF['embedding_zscored'].apply(lambda x: np.isnan(x).any())
num_nan_gt = groundTruthDF['has_nan'].sum()
print(f"Rows in groundTruthDF with NaNs: {num_nan_gt} / {len(groundTruthDF)}")

# Same check on the participant side (counted without adding a column).
num_nan_par = parTranscriptDF['embedding_zscored'].apply(lambda x: np.isnan(x).any()).sum()
print(f"Rows in parTranscriptDF with NaNs: {num_nan_par} / {len(parTranscriptDF)}")


Rows in groundTruthDF with NaNs: 0 / 13


In [33]:
# Score every participant description against its object's ground-truth embedding.
similarity_scores = parTranscriptDF.apply(row_cosine_similarity, axis=1)
parTranscriptDF['cosine_similarity_to_gt'] = similarity_scores

In [34]:
# Inspect participant rows with the new 'cosine_similarity_to_gt' column.
parTranscriptDF

Unnamed: 0,subject_name,block_number,object_name,description,total_tokens,num_chunks,embedding,embedding_zscored,cosine_similarity_to_gt
0,tulip003,1,handmixer,A hand mixer is usually made out of metal and ...,68,1,"[0.0014718323945999146, -0.09571817517280579, ...","[0.09163141249400854, -0.8263974732144712, -0....",0.333101
1,tulip003,2,handmixer,There are like other forms of hand mixer so th...,72,1,"[-0.12219145894050598, -0.03579965978860855, 0...","[-0.6746503297403398, -0.3844315920588837, 0.0...",0.293445
2,tulip003,1,fishingreel,This thing is usually made out of plastic and ...,53,1,"[0.21410039067268372, -0.10099238157272339, 0....","[1.4091879834620797, -0.8653006280071015, 0.47...",0.423366
3,tulip003,2,fishingreel,"This is a tool where a string is attached, so...",82,1,"[-0.05731167271733284, -0.09037087112665176, 0...","[-0.2726216024912567, -0.7869551417512888, -0....",0.259747
4,tulip003,1,shoehorn,This object is usually made out of metal and i...,37,1,"[-0.13571789860725403, -0.05446772277355194, -...","[-0.7584671491227059, -0.5221293780221273, -2....",0.271147
...,...,...,...,...,...,...,...,...,...
448,tulip024,2,stethoscope,it is this metal little circular thing and the...,85,1,"[0.058361250907182693, 0.12742501497268677, 0....","[0.44414768682145334, 0.8195324371673148, -0.3...",0.346237
449,tulip024,1,frenchpress,It's like a glass cylindrical pitcher and wit...,147,1,"[-0.18195031583309174, 0.12108269333839417, 0....","[-1.044947129050776, 0.7727507411973984, -0.06...",0.520145
450,tulip024,2,frenchpress,It's a glass cylinder that works as a pitcher...,109,1,"[-0.11101942509412766, 0.1576933115720749, 0.2...","[-0.6054226270997455, 1.0427948843733317, 0.27...",0.390811
451,tulip024,1,threeholepunch,it is like it's its length is about the size o...,147,1,"[-0.13913710415363312, -0.09535489231348038, 0...","[-0.7796543159435918, -0.8237178569389181, 0.0...",0.300263


In [35]:
# Save the scored table. 'embedding_zscored' holds numpy arrays, which to_csv
# would write via str(ndarray): space-separated, line-wrapped and rounded to
# numpy's print precision, so it cannot be read back with ast.literal_eval.
# Convert to plain lists first so the column is stored as a full-precision
# "[a, b, ...]" literal like 'embedding'. assign() leaves parTranscriptDF intact.
# NOTE(review): the default index is also written; reading back will give both
# 'Unnamed: 0' and 'Unnamed: 0.1' columns — consider index=False.
parTranscriptDF_out = parTranscriptDF.assign(
    embedding_zscored=parTranscriptDF['embedding_zscored'].apply(lambda v: np.asarray(v).tolist())
)
parTranscriptDF_out.to_csv('/content/drive/MyDrive/Projects/zero/data_public/participantDescriptionsByBlockEmbeddings_CosSimBERT.csv')