In [1]:
import torch
from transformers import DistilBertModel, DistilBertTokenizer
import os
import pandas as pd
import dask.dataframe as dd
import numpy as np

  torch.utils._pytree._register_pytree_node(
Dask dataframe query planning is disabled because dask-expr is not installed.

You can install it with `pip install dask[dataframe]` or `conda install dask`.
This will raise in a future version.



# Load Data

In [2]:
examples_path = os.path.join('..', 'data', 'shopping_queries_dataset_examples.parquet')
products_path = os.path.join('..', 'data', 'shopping_queries_dataset_products.parquet')
sources_path = os.path.join('..', 'data', 'shopping_queries_dataset_sources.csv')

examples = dd.read_parquet(examples_path)
products = dd.read_parquet(products_path)
sources = dd.read_csv(sources_path)

In [3]:
examples_products = dd.merge(
    examples,
    products,
    how='left',
    left_on=['product_locale','product_id'],
    right_on=['product_locale', 'product_id']
)

examples_products = examples_products[examples_products['product_locale'] == 'us']

task_2 = examples_products[examples_products['large_version'] == 1]
task_2_train = task_2[task_2['split'] == 'train']
task_2_test = task_2[task_2['split'] == 'test']

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertModel.from_pretrained('distilbert-base-uncased').to(device)

for param in model.parameters():
    param.requires_grad = False

def generate_embeddings(texts):
    batch_size = 16  # Adjust this size
    embeddings = []

    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        inputs = tokenizer(batch.tolist(), return_tensors='pt', padding=True, truncation=True, max_length=512)
        inputs = {key: value.to(device) for key, value in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)

        batch_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
        embeddings.append(batch_embeddings)

    return np.vstack(embeddings)

def process_partition(partition):
    query_embeddings = generate_embeddings(partition['query'])
    product_title_embeddings = generate_embeddings(partition['product_title'])

    combined = torch.cat((torch.tensor(query_embeddings), torch.tensor(product_title_embeddings)), dim=1).numpy()
    
    print(f'Combined shape: {combined.shape}')  # Expecting (n, 1536)

    result = pd.DataFrame(combined, index=partition.index, columns=[f'embedding_{i}' for i in range(combined.shape[1])])

    return result



In [6]:
meta = pd.DataFrame(columns=[f'embedding_{i}' for i in range(2 * 768)], dtype='float64')

In [24]:
total_rows = task_2_train.shape[0].compute()

sample_fraction = 10000 / total_rows

task_2_train_sample = task_2_train.sample(frac=sample_fraction, random_state=42)

In [25]:
result = task_2_train_sample.map_partitions(process_partition, meta=meta)

In [26]:
result = result.compute()

In [20]:
result

Unnamed: 0,embedding_0,embedding_1,embedding_2,embedding_3,embedding_4,embedding_5,embedding_6,embedding_7,embedding_8,embedding_9,...,embedding_1526,embedding_1527,embedding_1528,embedding_1529,embedding_1530,embedding_1531,embedding_1532,embedding_1533,embedding_1534,embedding_1535
1322108,-0.243804,0.005680,-0.031020,-0.141934,0.066477,0.058729,0.221509,0.247692,-0.226389,-0.155458,...,0.066826,-0.105104,0.017408,-0.383734,0.075518,0.073233,-0.064225,-0.123030,0.319219,0.487798
686437,-0.069240,-0.094564,-0.027499,-0.062775,0.215005,-0.084165,0.315619,0.399373,-0.167063,-0.145985,...,0.074895,-0.415459,-0.061893,-0.310011,0.124662,0.092429,-0.048703,0.139839,0.265782,0.358987
2135583,-0.207307,-0.005689,-0.129575,-0.089402,0.195646,0.076074,0.084630,0.407377,-0.110144,-0.070404,...,0.060664,-0.558679,-0.059311,-0.290303,0.296790,0.005240,0.085023,-0.169601,0.160973,0.352225
1566068,-0.410879,-0.097279,-0.090982,-0.216974,-0.059183,-0.139799,0.370095,0.509796,-0.280695,-0.058542,...,0.237758,-0.205706,0.103665,-0.265193,0.184633,-0.103951,-0.013666,0.036094,0.131394,0.338542
2075274,-0.333989,-0.095462,-0.064409,-0.236993,-0.165798,-0.210848,-0.000422,0.194408,-0.200857,-0.070705,...,0.106757,-0.131075,-0.088616,-0.444016,0.352837,-0.088343,-0.015121,-0.224137,0.475204,0.239741
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1259046,-0.112701,-0.147042,-0.007530,-0.016934,-0.116315,-0.006384,0.278548,0.521637,-0.247869,-0.166447,...,-0.039587,-0.042352,-0.032296,-0.097632,0.382995,-0.108811,-0.093898,-0.107622,0.250654,0.187417
829792,-0.234443,0.007599,0.034869,0.003852,0.005563,-0.033420,-0.043082,0.367536,-0.230304,-0.089083,...,-0.089568,-0.388918,0.031594,-0.202298,0.100956,-0.181122,-0.097561,-0.126920,0.218845,0.035315
1890259,-0.146819,-0.550584,0.021597,-0.058736,-0.104899,-0.017471,0.185924,0.191096,-0.357679,-0.016716,...,0.107400,-0.502833,-0.110226,-0.541127,0.083311,-0.128851,0.192413,0.042367,0.014201,0.026295
215884,-0.269557,-0.109104,0.016349,-0.145224,0.082217,0.065115,0.185822,0.357128,-0.158699,-0.254989,...,0.398888,-0.667779,0.015896,-0.251051,0.163636,-0.091871,-0.018640,-0.073983,-0.003373,0.082522
