# Batch-run all queries to retrieve the top 40 most relevant products for each query

In [1]:
import pandas as pd
import numpy as np
from tqdm import tqdm
from joblib import Parallel, delayed

In [4]:
# Run configuration: both values are interpolated into the input/output file paths.
model, locale = 'gemini', 'es'

In [5]:
# Load the query table for the configured model and locale.
queries_path = f'{model}/queries-{locale}.parquet'
q = pd.read_parquet(queries_path)

In [6]:
# Load the product table for the configured model and locale.
products_path = f'{model}/products-{locale}.parquet'
p = pd.read_parquet(products_path)

In [7]:
# Sanity check: (rows, columns) of the query and product tables.
(q.shape, p.shape)

((15180, 2), (259973, 2))

In [8]:
# Stack each table's per-row embedding arrays into one dense matrix
# (one row per query / product).
pe = np.stack(p['embeddings'].values)
qe = np.stack(q['embeddings'].values)
(qe.shape, pe.shape)

((15180, 768), (259973, 768))

In [7]:
# First query embedding (row 0 of `qe`) — presumably kept as a sample input for ad-hoc checks.
qi = qe[0]

In [9]:
def get_dotp_closests_idxs(q_embedding,p_embeddings_matrix, topk=40):
    """Return the row indices of the `topk` products with the largest dot product.

    Parameters
    ----------
    q_embedding : numpy.ndarray
        Query embedding; expected to be a 1-D vector whose length matches the
        number of columns of `p_embeddings_matrix`.
    p_embeddings_matrix : numpy.ndarray
        Matrix with one product embedding per row.
    topk : int, default 40
        Number of indices to return. If it exceeds the number of rows, every
        row index is returned; if it is 0, an empty array is returned.

    Returns
    -------
    numpy.ndarray
        Row positions into `p_embeddings_matrix`, ordered from the highest
        dot-product score to the lowest.

    Raises
    ------
    ValueError
        If `topk` is negative.
    """
    if topk < 0:
        raise ValueError(f"topk must be non-negative, got {topk}")
    scores = p_embeddings_matrix.dot(q_embedding)
    # Reverse first, then slice from the front: a tail slice `[-topk:]` would
    # return the whole array when topk == 0 because `-0` is just `0`.
    best_first = np.argsort(scores)[::-1]
    return best_first[:topk]


In [None]:
# Rank the products for every query embedding, one task per query, using a
# thread-based joblib pool (n_jobs=-1). Results come back in query order.
thread_pool = Parallel(n_jobs=-1, verbose=5, prefer='threads')
qnn = thread_pool(
    delayed(get_dotp_closests_idxs)(query_vec, pe) for query_vec in qe
)

[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done   2 tasks      | elapsed:    0.3s
[Parallel(n_jobs=-1)]: Done  56 tasks      | elapsed:    2.2s
[Parallel(n_jobs=-1)]: Done 146 tasks      | elapsed:    5.0s
[Parallel(n_jobs=-1)]: Done 272 tasks      | elapsed:    9.1s
[Parallel(n_jobs=-1)]: Done 434 tasks      | elapsed:   14.2s
[Parallel(n_jobs=-1)]: Done 632 tasks      | elapsed:   20.4s
[Parallel(n_jobs=-1)]: Done 866 tasks      | elapsed:   27.6s
[Parallel(n_jobs=-1)]: Done 1136 tasks      | elapsed:   36.3s
[Parallel(n_jobs=-1)]: Done 1442 tasks      | elapsed:   46.5s
[Parallel(n_jobs=-1)]: Done 1784 tasks      | elapsed:   57.9s
[Parallel(n_jobs=-1)]: Done 2162 tasks      | elapsed:  1.2min
[Parallel(n_jobs=-1)]: Done 2576 tasks      | elapsed:  1.4min
[Parallel(n_jobs=-1)]: Done 3026 tasks      | elapsed:  1.7min
[Parallel(n_jobs=-1)]: Done 3512 tasks      | elapsed:  2.0min
[Parallel(n_jobs=-1)]: Done 4034 tasks      

In [21]:
# Inspect the ranked product row positions returned for the first query.
qnn[0]

array([ 15656,  52031,  45378, 101655,  52855,  95519,  81247,  78573,
        76855,  45839, 106639,  45927,  36356, 119329,  69711,  31470,
       114362, 112287, 111813, 141540,  72882, 114055, 114053,  84736,
        91571,  91570, 152210, 165589, 101158,  73089,  81492, 108548,
        44015, 133191,  48123,  84020,  54317,  41387,  39492,  82285])

In [11]:
# Inspect the ranked product row positions returned for the first query.
qnn[0]

array([ 82285,  39492,  41387,  54317,  84020,  48123, 133191,  44015,
       108548,  81492,  73089, 101158, 165589, 152210,  91570,  91571,
        84736, 114053, 114055,  72882, 141540, 111813, 112287, 114362,
        31470,  69711, 119329,  36356,  45927, 106639,  45839,  76855,
        78573,  81247,  95519,  52855, 101655,  45378,  52031,  15656])

In [22]:
# Translate each query's positional neighbour indices into product index labels
# and attach them to the query table (one entry per query).
top_products_per_query = [p.index[positions] for positions in qnn]
q['top40_products'] = top_products_per_query

In [23]:
# Flatten the per-query rankings into [query_id, product_id] pairs, preserving
# the rank order within each query.
# Iterating the index and the column directly avoids building a Series per row
# (as iterrows does), keeps loop variables out of the notebook namespace, and
# lets tqdm report a total.
qr = [
    [query_id, product_id]
    for query_id, top_products in tqdm(zip(q.index, q['top40_products']), total=len(q))
    for product_id in top_products.values
]
    

8049it [00:11, 725.28it/s] 


In [24]:
# Persist the (query_id, product_id) ranking pairs as CSV, without the row index.
ranking = pd.DataFrame(qr, columns = ['query_id', 'product_id'])
ranking.to_csv(f'{model}/embeddings_dotp_ranking_{locale}.csv', index=False)