# Batch-run every query to get the top 40 most relevant products for each query

In [1]:
import os

import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from tqdm import tqdm

import utils

In [2]:
# Show which Python interpreter this kernel is running on.
import sys
sys.executable

'/opt/conda/envs/p312/bin/python'

In [16]:
# Model name; used as the directory prefix for the parquet inputs and the CSV output.
model = 'gemini'
# Dataset locale; part of the input/output file names and passed to utils.load_examples.
locale = 'us'

In [17]:
# Location of the ESCI dataset checkout. The ESCI_DATASET_ROOT environment
# variable overrides it, so the notebook can run on other machines without
# editing this cell; the default is the original hard-coded path.
ESCI_DATASET_ROOT = os.environ.get(
    'ESCI_DATASET_ROOT',
    '/usr/local/google/home/raulramos/projects/esci-data',
)

# Examples table for the chosen locale. Later cells rely on its `query_id`
# and `product_id` columns to find the products annotated for each query.
dgt = utils.load_examples(ESCI_DATASET_ROOT=ESCI_DATASET_ROOT, locale=locale)

In [18]:
# Query table for this model/locale; later cells read its `embeddings` column
# and look rows up by query id through `q.loc`.
q = pd.read_parquet(f'{model}/queries-{locale}.parquet')

In [19]:
# Product table for this model/locale; later cells read its `embeddings` column
# and look rows up by product id through `p.loc`.
p = pd.read_parquet(f'{model}/products-{locale}.parquet')

In [20]:
# Sanity check: (rows, columns) of the query and product tables.
q.shape, p.shape

((97345, 2), (1215851, 2))

In [21]:
# Stack the per-row embedding vectors of each table into one 2-D array
# (one row per query / product), then show the resulting shapes.
qe, pe = (np.stack(frame['embeddings'].values.copy()) for frame in (q, p))
qe.shape, pe.shape

((97345, 768), (1215851, 768))

In [22]:
def get_dotp_closests_idxs_onlyannotated(query_id):
    """Rank the products annotated for a query by dot-product similarity.

    Reads three notebook-level frames: `q` (query embeddings, looked up by
    query id), `dgt` (rows carrying `query_id` and `product_id`) and `p`
    (product embeddings, looked up by product id).

    Args:
        query_id: label of the query in the index of `q`.

    Returns:
        A list of product ids (labels from the index of `p`) ordered from the
        highest to the lowest dot product with the query embedding. The list
        is empty when `dgt` has no rows for `query_id`.

    Raises:
        KeyError: if `query_id` is not in `q`, or an annotated product id is
            not in `p`.
    """
    query_embedding = q.loc[query_id].embeddings

    # Only the products annotated for this query are candidates.
    annotated_ids = dgt[dgt.query_id == query_id].product_id.values
    if len(annotated_ids) == 0:
        # np.stack raises on an empty sequence, and a single failure would
        # abort the whole parallel batch, so return an empty ranking instead.
        return []

    candidates = p.loc[annotated_ids]
    candidate_embeddings = np.stack(candidates.embeddings.values)
    scores = candidate_embeddings.dot(query_embedding)

    # argsort is ascending; reverse it to put the best-scoring product first.
    return list(candidates.index[np.argsort(scores)[::-1]])


In [23]:
# Rank the annotated products of every query, one task per query id, on a
# thread pool; results come back in the same order as q.index.
qnn = Parallel(n_jobs=-1, verbose=5, prefer='threads')(
    delayed(get_dotp_closests_idxs_onlyannotated)(query_id)
    for query_id in q.index
)

[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 64 concurrent workers.
[Parallel(n_jobs=-1)]: Done  34 tasks      | elapsed:    0.4s
[Parallel(n_jobs=-1)]: Done 160 tasks      | elapsed:    0.6s
[Parallel(n_jobs=-1)]: Done 322 tasks      | elapsed:    0.9s
[Parallel(n_jobs=-1)]: Done 520 tasks      | elapsed:    1.2s
[Parallel(n_jobs=-1)]: Done 754 tasks      | elapsed:    1.5s
[Parallel(n_jobs=-1)]: Done 1024 tasks      | elapsed:    2.0s
[Parallel(n_jobs=-1)]: Done 1330 tasks      | elapsed:    2.5s
[Parallel(n_jobs=-1)]: Done 1672 tasks      | elapsed:    3.0s
[Parallel(n_jobs=-1)]: Done 2050 tasks      | elapsed:    3.6s
[Parallel(n_jobs=-1)]: Done 2464 tasks      | elapsed:    4.2s
[Parallel(n_jobs=-1)]: Done 2914 tasks      | elapsed:    4.9s
[Parallel(n_jobs=-1)]: Done 3400 tasks      | elapsed:    5.6s
[Parallel(n_jobs=-1)]: Done 3922 tasks      | elapsed:    6.4s
[Parallel(n_jobs=-1)]: Done 4480 tasks      | elapsed:    7.3s
[Parallel(n_jobs=-1)]: Done 5074 tasks   

In [24]:
# Attach each query's ranked product-id list (qnn is ordered like q.index).
# NOTE(review): despite the column name, nothing visible in this notebook
# truncates the lists to 40; each holds every annotated product of the query.
q['top40_products'] = qnn

In [25]:
# Inspect the query table with the new ranking column.
# NOTE(review): in the displayed tail, several non-Latin queries show the same
# leading embedding values; worth checking whether their embeddings are degenerate.
q

Unnamed: 0_level_0,query,embeddings,top40_products
query_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
0,revent 80 cfm,"[-0.01598229, -0.046302106, -0.013507076, -0.0...","[B00MARNO5Y, B07QJ7WYFQ, B07JY1PQNT, B07RH6Z8K..."
1,!awnmower tires without rims,"[-0.030829305, -0.08287808, 0.04729122, 0.0214...","[B077QMNXTS, B07P4CF3DP, B06XX6BM2R, B01DBGLLU..."
2,!qscreen fence without holes,"[-0.028439859, -0.023042167, -0.051055882, 0.0...","[B07DS3J3MB, B07R6P8TK8, B07R3TNQDM, B07DS1YCR..."
5,# 10 self-seal envelopes without window,"[0.04445911, -0.056469392, -0.0016481074, -0.0...","[B07CXXVXLC, B071R9SBXJ, B01N175R8R, B007YX2KB..."
6,# 2 pencils not sharpened,"[0.0021929916, 0.0144573795, 0.00267528, -0.04...","[B07GJQJFG6, B07G2RYY6H, B07JZJLHCF, B07G7F6JZ..."
...,...,...,...
129275,茶叶,"[-0.010790561, 0.037765387, 0.00715581, 0.0184...","[B07NTHKL15, B088HGYXRN, B088TVHZTT, B07KD4PN5..."
130378,香奈儿,"[-0.010790561, 0.037765387, 0.00715581, 0.0184...","[B07KFLL4X7, B01M7W1MIU, B01HFH4DAI, B006IB5T4..."
130537,가마솥,"[-0.010790561, 0.037765387, 0.00715581, 0.0184...","[B07QDYP2NZ, B07Y5L6XYM, B01FGJAJ1Y, B0793N5ST..."
130538,골프공,"[-0.010790561, 0.037765387, 0.00715581, 0.0184...","[B07BDPLY63, B082273NWQ, B08TF1TTB8, B082274MR..."


In [26]:
# Flatten the per-query rankings into [query_id, product_id] pairs, keeping
# the ranked order within each query. Only the ranking column is iterated:
# DataFrame.iterrows() would build a full Series per row (embeddings included)
# just to read this one field.
qr = [
    [query_id, product_id]
    for query_id, ranked_products in tqdm(q['top40_products'].items())
    for product_id in ranked_products
]
    

97345it [00:05, 16587.64it/s]


In [27]:
# Write the flattened ranking as a two-column CSV (no index column).
(
    pd.DataFrame(qr, columns=['query_id', 'product_id'])
    .to_csv(f'{model}/embeddings_dotp_ranking_onlyannotated_{locale}.csv', index=False)
)