In [1]:
import numpy as np
import pandas as pd
import torch

import helpful

In [2]:
# Load the dataset; the "clip(image)" column holds the precomputed CLIP
# image embeddings that ClipQuery scores against.
df = pd.read_parquet(path="./dummy.parquet")

In [3]:
class ClipQuery:
    """Score the rows of a DataFrame against free-text queries using CLIP.

    Image embeddings must already be stored in ``clip_encoding_column``
    (one vector per row). Each call embeds a text query and writes the
    image/text dot products into a new ``clip('<text>', image)`` column.

    Note: ``self.df`` is the very object passed in, so score columns are
    also added to the caller's DataFrame.
    """

    def __init__(self, df: pd.DataFrame, clip_encoding_column: str):
        self.df = df
        # helpful.load_clip() is project-local; assumed to return
        # (model, preprocess, tokenizer) in that order — TODO confirm.
        self.model, self.preprocess, self.tokenizer = helpful.load_clip()
        self.clip_encoding_column = clip_encoding_column

    @torch.no_grad()
    @torch.autocast(device_type="cuda")
    def __call__(self, text: str):
        """Add a ``clip('<text>', image)`` score column to ``self.df``.

        Args:
            text: Free-text query to embed with the CLIP text encoder.

        Each row's score is the dot product of its stored image embedding
        with the query's text embedding. Returns ``None``; the result is
        written in place into ``self.df``.
        """
        score_column = f"clip('{text}', image)"
        encodings = self.df[self.clip_encoding_column]

        # Nothing to score: still create the (empty) column so code that
        # sorts/filters on it keeps working instead of crashing in the matmul.
        if encodings.empty:
            self.df[score_column] = pd.Series(dtype="float64")
            return

        # encode the text in CLIP space; bring it to CPU float32 so it matches
        # the precomputed image embeddings whatever the model's device or the
        # autocast output dtype.
        # NOTE(review): tokens are not moved to the model's device — assumes
        # the tokenizer/model from helpful.load_clip() handle that; confirm.
        tokens = self.tokenizer([text])
        text_encoding = self.model.encode_text(tokens).float().cpu()

        # images are already encoded in CLIP space; stacking into a single
        # ndarray first avoids torch's very slow list-of-ndarrays conversion.
        images_encoding = torch.from_numpy(
            np.stack(encodings.tolist()).astype(np.float32, copy=False)
        )

        # shapes: [n, vector_size] @ [vector_size, 1] = [n, 1]; take column 0
        # explicitly (not squeeze()) so a single row still yields a 1-d result.
        clip_scores = (images_encoding @ text_encoding.T)[:, 0]
        self.df[score_column] = clip_scores.tolist()



In [4]:
# Wrap the frame so free-text queries can be scored against its stored image embeddings.
clipql = ClipQuery(df=df, clip_encoding_column="clip(image)")

  from .autonotebook import tqdm as notebook_tqdm


Which images match this description?

In [5]:
# Adds a "clip('a radio set', image)" score column to clipql.df.
clipql(text="a radio set")

  images_encoding = torch.tensor(df[self.clip_encoding_column].tolist(), dtype=torch.float32)


As expected, the CLIP scores are highest for cassette players:

In [10]:
# Rank every image by its score for the query, best match first.
sorted_df = (
    clipql.df
    .sort_values("clip('a radio set', image)", ascending=False)
)
sorted_df

Unnamed: 0,id,label,split,clip(image),"clip('a radio set', image)"
303,val/n02979186/n02979186_15441.JPEG,cassette player,valid,"[-0.19440984725952148, -0.05704239010810852, -...",29.824308
269,val/n02979186/n02979186_11790.JPEG,cassette player,valid,"[-0.012168984860181808, 0.7639139294624329, -0...",29.571449
73,val/n02979186/n02979186_13231.JPEG,cassette player,valid,"[-0.09194931387901306, 0.24397143721580505, -0...",28.948217
279,val/n02979186/n02979186_1621.JPEG,cassette player,valid,"[0.230721578001976, -0.11661261320114136, -0.4...",28.810951
176,val/n02979186/n02979186_18940.JPEG,cassette player,valid,"[-0.14828135073184967, 0.2738904058933258, -0....",28.378252
...,...,...,...,...,...
2143,val/n03888257/n03888257_31301.JPEG,parachute,valid,"[-0.2291850596666336, 0.4319629669189453, -0.4...",2.114641
2068,val/n03888257/n03888257_66552.JPEG,parachute,valid,"[-0.1232389360666275, 0.13655883073806763, 0.0...",1.754479
2290,val/n03888257/n03888257_18002.JPEG,parachute,valid,"[-0.5015037059783936, -0.02663572132587433, 0....",1.692758
2037,val/n03888257/n03888257_58270.JPEG,parachute,valid,"[-0.8005543351173401, 0.007637642323970795, 0....",1.313354


In [8]:
# Top 100 scores among images whose label is not "cassette player".
# .head() makes the positional row limit explicit instead of relying on
# DataFrame []-slicing semantics.
sorted_df.loc[sorted_df["label"] != "cassette player"].head(100)

Unnamed: 0,id,label,split,clip(image),"clip('a radio set', image)"
3837,val/n03425413/n03425413_3021.JPEG,gas pump,valid,"[-0.0727212131023407, 0.4719665050506592, -0.6...",24.645161
2697,val/n03394916/n03394916_2151.JPEG,French horn,valid,"[0.07101297378540039, -0.21776656806468964, 0....",24.512970
2785,val/n03000684/n03000684_30141.JPEG,chainsaw,valid,"[-0.033758118748664856, 0.25843387842178345, -...",23.603638
3618,val/n03425413/n03425413_11061.JPEG,gas pump,valid,"[-0.16238315403461456, 0.2804635763168335, -0....",23.148726
3576,val/n03425413/n03425413_19221.JPEG,gas pump,valid,"[-0.08666130900382996, -0.2441093772649765, -0...",23.103786
...,...,...,...,...,...
3649,val/n03425413/n03425413_13811.JPEG,gas pump,valid,"[-0.21496453881263733, 0.2393254041671753, -0....",18.375860
3387,val/n03445777/n03445777_4520.JPEG,golf ball,valid,"[0.7031760215759277, 0.44363221526145935, 0.06...",18.337053
2379,val/n03394916/n03394916_32422.JPEG,French horn,valid,"[-0.5351296067237854, -0.4313502311706543, 0.3...",18.333483
2690,val/n03394916/n03394916_30432.JPEG,French horn,valid,"[0.08296029269695282, -0.11597874760627747, 0....",18.319208


A more natural API would be declarative, like SQL:

```sql
SELECT *, CLIP('a radio set', image) as score FROM table 
WHERE label != 'cassette player'
ORDER BY score DESC
LIMIT 25
```

would return the 25 highest-scoring images that are not cassette players.

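Or it could filter directly on the CLIP scores (note that in standard SQL a `SELECT` alias such as `score` cannot be referenced in `WHERE`, so a real implementation would need to repeat the expression or use a subquery):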

```sql
SELECT *, CLIP('a radio set', image) as score FROM table
WHERE label != 'cassette player' AND score > 25
```

To turn this into a classification task, I could write:

```sql
SELECT *, ARGMAX(CLIP( ('a dog', 'a radio'), image)) as prediction_index FROM table
```