# Batch-run all queries to get the top 40 most relevant products for each query

In [12]:
import os

import numpy as np
import pandas as pd
import polars as pl
from joblib import Parallel, delayed
from tqdm import tqdm

import utils

In [7]:
# Run configuration: `model` is the directory prefix of the parquet inputs and the CSV
# output; `locale` selects the per-locale files and is forwarded to utils.load_examples.
model, locale = 'openai', 'us'

In [8]:
# Root directory handed to utils.load_examples. The ESCI_DATASET_ROOT environment
# variable overrides the location; when it is unset the original machine-specific
# path is used, so existing runs behave exactly as before.
ESCI_DATASET_ROOT = os.environ.get(
    'ESCI_DATASET_ROOT',
    '/usr/local/google/home/raulramos/projects/esci-data',
)

# Examples for the selected locale. Later cells use its `query_id` and `product_id`
# columns to find the products associated with each query.
dgt = utils.load_examples(ESCI_DATASET_ROOT=ESCI_DATASET_ROOT, locale=locale)

In [9]:
# Query table for the selected model/locale. Later cells look rows up by query id
# through its index (q.loc[query_id]) and read its `embeddings` column.
q = pd.read_parquet(f'{model}/queries-{locale}.parquet')

In [13]:

# Product table for the selected model/locale: read with polars, then converted to a
# pandas DataFrame in one step so `p` is only ever bound to the pandas frame.
p = pl.read_parquet(f'{model}/products-{locale}.parquet').to_pandas()


In [20]:
# Index the product table by its `__index_level_0__` column so that later cells can
# select products with p.loc[product_ids].
# NOTE(review): presumably this column is the original pandas index that polars loaded
# as a plain column - confirm against the code that wrote the parquet file.
p.index = p['__index_level_0__']

In [14]:
# Sanity check: (rows, columns) of the query table and of the product table.
q.shape, p.shape

((97345, 2), (1215851, 3))

In [15]:
# Stack the per-row embedding vectors of each table into a single 2-D array
# (one row per query / product), then display both shapes.
qe, pe = (np.stack(table.embeddings.values.copy()) for table in (q, p))
qe.shape, pe.shape

((97345, 3072), (1215851, 3072))

In [16]:
def get_dotp_closests_idxs_onlyannotated(query_id):
    """Rank the products listed for a query by embedding dot product.

    Only the products that appear in `dgt` for `query_id` are considered; they are
    ordered by the dot product between the query embedding and each product
    embedding, highest score first.

    Reads three module-level DataFrames:
      * `q`   - indexed by query id, with an `embeddings` column;
      * `dgt` - with `query_id` and `product_id` columns;
      * `p`   - indexed by product id, with an `embeddings` column.

    Parameters
    ----------
    query_id
        Label of the query in `q.index`.

    Returns
    -------
    list
        Index labels of `p` for the query's products, best match first. Empty when
        `dgt` has no rows for `query_id`.

    Raises
    ------
    KeyError
        Raised by pandas when `query_id` is not in `q.index` or one of the query's
        product ids is not in `p.index`.
    """
    q_embeddings = q.loc[query_id].embeddings

    prod_ids = dgt[dgt.query_id == query_id].product_id.values
    if len(prod_ids) == 0:
        # np.stack raises ValueError on an empty sequence; a query without any listed
        # product simply has nothing to rank, and must not abort a batch run.
        return []

    pq = p.loc[prod_ids]
    p_embeddings = np.stack(pq.embeddings.values)
    scores = p_embeddings.dot(q_embeddings)

    # argsort is ascending, so reverse it to put the highest dot product first.
    return list(pq.index[np.argsort(scores)[::-1]])


In [22]:
# Rank the products of every query, one task per query id. Iterating q.index means
# the tasks are submitted in the same order as the rows of q. Threads are preferred,
# so the workers read the module-level q / dgt / p frames directly.
qnn = Parallel(n_jobs=-1, verbose=5, prefer='threads')(
    delayed(get_dotp_closests_idxs_onlyannotated)(query_id)
    for query_id in q.index
)

[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 64 concurrent workers.
[Parallel(n_jobs=-1)]: Done  34 tasks      | elapsed:    0.3s
[Parallel(n_jobs=-1)]: Done 160 tasks      | elapsed:    0.8s
[Parallel(n_jobs=-1)]: Done 322 tasks      | elapsed:    1.2s
[Parallel(n_jobs=-1)]: Done 520 tasks      | elapsed:    1.6s
[Parallel(n_jobs=-1)]: Done 754 tasks      | elapsed:    2.0s
[Parallel(n_jobs=-1)]: Done 1024 tasks      | elapsed:    2.4s
[Parallel(n_jobs=-1)]: Done 1330 tasks      | elapsed:    2.9s
[Parallel(n_jobs=-1)]: Done 1672 tasks      | elapsed:    3.5s
[Parallel(n_jobs=-1)]: Done 2050 tasks      | elapsed:    4.1s
[Parallel(n_jobs=-1)]: Done 2464 tasks      | elapsed:    4.8s
[Parallel(n_jobs=-1)]: Done 2914 tasks      | elapsed:    5.5s
[Parallel(n_jobs=-1)]: Done 3400 tasks      | elapsed:    6.2s
[Parallel(n_jobs=-1)]: Done 3922 tasks      | elapsed:    7.0s
[Parallel(n_jobs=-1)]: Done 4480 tasks      | elapsed:    7.9s
[Parallel(n_jobs=-1)]: Done 5074 tasks   

In [23]:
# Attach each query's ranked product-id list as a new column. This relies on qnn being
# ordered like q.index (it was built by iterating q.index).
# Note: despite the column name, the lists are not truncated to 40 here; each one holds
# every product that the ranking function returned for the query.
q['top40_products'] = qnn

In [24]:
# Inspect the query table with the new ranking column (pandas truncates the display).
q

Unnamed: 0_level_0,query,embeddings,top40_products
query_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
0,revent 80 cfm,"[-0.05282729119062424, 0.0162424948066473, -0....","[B00MARNO5Y, B07WDM7MQQ, B07X3Y6B1V, B06W2LB17..."
1,!awnmower tires without rims,"[-0.015263666398823261, -0.023168815299868584,...","[B08L3B9B9P, B07P4CF3DP, B07C1WZG12, B077QMNXT..."
2,!qscreen fence without holes,"[0.019541103392839432, 0.025528443977236748, -...","[B08NG85RHL, B07DS3J3MB, B07DS1YCRZ, B001OJXVK..."
5,# 10 self-seal envelopes without window,"[-0.010724861174821854, 0.0026416631881147623,...","[B07CXXVXLC, B071R9SBXJ, B078S5ZL5D, B01N175R8..."
6,# 2 pencils not sharpened,"[-0.013498208485543728, -0.011514219455420971,...","[B07JZJLHCF, B004X4KRW0, B00125Q75Y, B0188A3QR..."
...,...,...,...
129275,茶叶,"[0.03555547818541527, -0.01856878586113453, 0....","[B088HGYXRN, B07NTHKL15, B07PNZ2B39, B088TVHZT..."
130378,香奈儿,"[-0.0326329730451107, 0.0005094753578305244, 0...","[B081X6DRRT, B01E7KBXWC, B00KFKDOYO, B0010POWE..."
130537,가마솥,"[-0.005771995056420565, -0.016856269910931587,...","[B07Y5L6XYM, B07GKZZMSF, B000N4UX4Q, B0793N5ST..."
130538,골프공,"[0.03067653998732567, -0.008546626195311546, -...","[B081B117DH, B07Z44KPHS, B00QN6LC9I, B083T6LDT..."


In [25]:
# Flatten the per-query rankings into one [query_id, product_id] pair per ranked
# product, keeping the query order of q and the rank order inside each query.
qr = [
    [query_id, product_id]
    for query_id, row in tqdm(q.iterrows())
    for product_id in row.top40_products
]
    

97345it [00:06, 15879.18it/s]


In [26]:
# Persist the flattened ranking as a two-column CSV (query_id, product_id), without
# the DataFrame's positional index.
pd.DataFrame(qr, columns=['query_id', 'product_id']).to_csv(
    f'{model}/embeddings_dotp_ranking_onlyannotated_{locale}.csv',
    index=False,
)