In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import f1_score
from sklearn.preprocessing import LabelEncoder
from transformers import DistilBertModel, DistilBertTokenizer
import os
import pandas as pd
import dask.dataframe as dd
import numpy as np
from collections import Counter

  torch.utils._pytree._register_pytree_node(
Dask dataframe query planning is disabled because dask-expr is not installed.

You can install it with `pip install dask[dataframe]` or `conda install dask`.
This will raise in a future version.



# Load Data

In [2]:
examples_path = os.path.join('..', 'data', 'shopping_queries_dataset_examples.parquet')
products_path = os.path.join('..', 'data', 'shopping_queries_dataset_products.parquet')
sources_path = os.path.join('..', 'data', 'shopping_queries_dataset_sources.csv')

examples = dd.read_parquet(examples_path)
products = dd.read_parquet(products_path)
sources = dd.read_csv(sources_path)

In [3]:
examples_products = dd.merge(
    examples,
    products,
    how='left',
    left_on=['product_locale','product_id'],
    right_on=['product_locale', 'product_id']
)

examples_products = examples_products[examples_products['product_locale'] == 'us']

task_2 = examples_products[examples_products['large_version'] == 1]

label_mapping = {'E': 0, 
                 'S': 1, 
                 'C': 2, 
                 'I': 3}

task_2['encoded_labels'] = task_2['esci_label'].map(label_mapping).astype(int)

task_2_train = task_2[task_2['split'] == 'train']
task_2_test = task_2[task_2['split'] == 'test']

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertModel.from_pretrained('distilbert-base-uncased').to(device)

for param in model.parameters():
    param.requires_grad = False

def generate_embeddings(texts):
    batch_size = 128  # Adjust this size
    embeddings = []

    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        inputs = tokenizer(batch.tolist(), return_tensors='pt', padding=True, truncation=True, max_length=512)
        inputs = {key: value.to(device) for key, value in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)

        batch_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
        embeddings.append(batch_embeddings)

    return np.vstack(embeddings)

def process_partition(partition):
    query_embeddings = generate_embeddings(partition['query'])
    product_title_embeddings = generate_embeddings(partition['product_title'])

    combined = torch.cat((torch.tensor(query_embeddings), torch.tensor(product_title_embeddings)), dim=1).numpy()
    
    print(f'Combined shape: {combined.shape}')  # expecting (n, 1536)

    result = pd.DataFrame(combined, index=partition.index, columns=[f'embedding_{i}' for i in range(combined.shape[1])])

    return result



In [6]:
meta = pd.DataFrame(columns=[f'embedding_{i}' for i in range(2 * 768)], dtype='float64')

In [7]:
total_rows = task_2_train.shape[0].compute()

sample_fraction = 10000 / total_rows

task_2_train_sample = task_2_train.sample(frac=sample_fraction, random_state=21)

In [8]:
result = task_2_train.map_partitions(process_partition, meta=meta)

In [9]:
result = result.compute()

Combined shape: (10000, 1536)
Combined shape: (10000, 1536)


In [10]:
total_rows2 = task_2_test.shape[0].compute()

sample_fraction2 = 10000 / total_rows2

task_2_test_sample = task_2_test.sample(frac=sample_fraction2, random_state=21)

# query_texts = task_2_test_sample['query'].tolist()
# product_titles = task_2_test_sample['product_title'].tolist()

In [11]:
result2 = task_2_test.map_partitions(process_partition, meta=meta)

In [12]:
result2 = result2.compute()

In [13]:
result

Unnamed: 0,embedding_0,embedding_1,embedding_2,embedding_3,embedding_4,embedding_5,embedding_6,embedding_7,embedding_8,embedding_9,...,embedding_1526,embedding_1527,embedding_1528,embedding_1529,embedding_1530,embedding_1531,embedding_1532,embedding_1533,embedding_1534,embedding_1535
1135254,-0.290532,-0.063103,0.026229,-0.118062,-0.053797,0.001423,0.179665,0.128031,-0.258895,-0.140017,...,0.108044,-0.016486,-0.091957,-0.185485,0.208071,-0.012939,-0.199160,0.007739,0.308227,0.313024
507395,-0.285455,-0.111262,-0.113142,-0.158004,-0.117075,0.013010,0.261024,0.233638,-0.136493,-0.161162,...,0.081286,-0.443307,-0.073551,-0.483737,0.111754,-0.181978,0.077074,-0.073647,0.051551,0.116166
421122,-0.178601,0.053528,-0.211663,-0.139839,0.103700,-0.052566,0.326615,0.228888,-0.437036,-0.257336,...,0.257401,-0.436189,-0.060040,-0.092231,0.288025,-0.158821,-0.056691,-0.083708,-0.046926,0.269202
977862,-0.152321,-0.181627,-0.028486,-0.103172,-0.212782,-0.040475,0.177514,0.349715,-0.080296,-0.143331,...,0.073216,-0.087939,0.008919,-0.086273,0.165844,-0.034154,-0.117882,0.007349,0.295800,0.228783
2186808,-0.246065,-0.205255,-0.052422,-0.045568,0.063240,-0.169057,0.067012,0.268431,-0.144231,-0.409057,...,0.308713,-0.443466,0.051421,-0.294180,0.144461,0.061179,-0.215118,-0.250605,-0.057553,0.176653
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
904339,-0.111236,-0.239602,-0.068928,-0.018820,-0.234648,0.094480,0.295593,0.181494,-0.154788,-0.032624,...,0.140249,-0.542634,0.024627,-0.191631,0.156213,-0.042998,-0.054532,-0.069859,0.013905,0.149266
973076,-0.280886,-0.076505,0.004162,-0.162118,-0.040654,-0.014583,0.137229,0.208141,-0.280146,-0.179859,...,-0.007630,-0.461214,-0.157896,-0.087187,0.232041,-0.056665,-0.133207,-0.189570,0.073114,0.221584
1656276,-0.221028,0.025194,-0.078133,-0.123258,-0.090522,-0.171234,0.167856,0.217071,-0.198434,0.038669,...,0.093497,-0.406056,0.004769,-0.291771,0.251960,-0.204405,-0.226528,-0.396483,0.053727,0.434428
897766,-0.107969,-0.252370,0.010800,-0.037156,-0.121543,-0.076469,0.267313,0.412872,-0.231740,-0.202620,...,0.180342,-0.311242,-0.029214,-0.283201,0.087618,0.084257,-0.113983,-0.089989,0.035200,0.227009


In [14]:
task_2_train = task_2_train.compute()
task_2_test = task_2_test.compute()

In [15]:
class FullyConnected(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(FullyConnected, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.dropout = nn.Dropout(p=0.1)
        self.maxpool = nn.MaxPool1d(kernel_size=2, stride=2)
        pooled_output_size = hidden_size // 2
        self.fc2 = nn.Linear(pooled_output_size, num_layers)

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = x.unsqueeze(1)

        x = self.maxpool(x)
        x = x.view(x.size(0), -1)

        x = self.dropout(x)
        x = self.fc2(x)
        return x

In [16]:
input_size = 1536
hidden_size = 128
num_layers = 4

model = FullyConnected(input_size, hidden_size, num_layers).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=5e-5, eps=1e-8, weight_decay=0.01)

In [17]:
subset_indices = result.index
subset_indices = subset_indices.astype(int)
task_2_train_indicies = task_2_train.index.astype(int)

valid_indicies = task_2_train_indicies[task_2_train_indicies.isin(subset_indices)]
subset_labels = task_2_train.loc[valid_indicies, 'encoded_labels']
subset_labels = subset_labels.to_frame()

In [18]:
subset_indices2 = result2.index
subset_indices2 = subset_indices2.astype(int)
task_2_test_indices = task_2_test.index.astype(int)
 
valid_indices2 = task_2_test_indices[task_2_test_indices.isin(subset_indices2)]
subset_labels2 = task_2_test.loc[valid_indices2, 'encoded_labels'] 
subset_labels2 = subset_labels2.to_frame()

In [19]:
result = result.sort_index()
result2 = result2.sort_index()

In [20]:
class ESCIDataset(Dataset):
    def __init__(self, embeddings, labels):
        self.embeddings = embeddings.values

        print(f'Shape of embeddings: {self.embeddings.shape}')
        self.labels = labels

    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        return self.embeddings[idx], self.labels[idx]

In [21]:
train_dataset = ESCIDataset(embeddings=result, labels=subset_labels['encoded_labels'].values)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

Shape of embeddings: (10000, 1536)


In [22]:
test_dataset = ESCIDataset(embeddings=result2, labels=subset_labels2['encoded_labels'].values)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

Shape of embeddings: (10000, 1536)


In [23]:
print("Type of embeddings:", type(train_dataset.embeddings))
print("Type of labels:", type(train_dataset.labels))

Type of embeddings: <class 'numpy.ndarray'>
Type of labels: <class 'numpy.ndarray'>


In [24]:
def train_model(model, train_loader, criterion, optimizer, num_epochs=8):
    model.train()  # set model to training mode
    for epoch in range(num_epochs):
        epoch_loss = 0
        for batch_idx, (embeddings, labels) in enumerate(train_loader):
            embeddings, labels = embeddings.to(device), labels.to(device)

            optimizer.zero_grad()  # Clear previous gradients
            outputs = model(embeddings.float())  # Forward pass
            # converting the labels to long in order to 
            labels = labels.long()
            # calculate the loss 
            loss = criterion(outputs, labels) 
            # backpropogation 
            loss.backward() 
            # updating the weights 
            optimizer.step()  

            # add up the loss 
            epoch_loss += loss.item()  

        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss / len(train_loader):.4f}")

# run the training model with the 10000 samples 
train_model(model, train_loader, criterion, optimizer)

Epoch 1/8, Loss: 0.9037
Epoch 2/8, Loss: 0.8495
Epoch 3/8, Loss: 0.8400
Epoch 4/8, Loss: 0.8356
Epoch 5/8, Loss: 0.8326
Epoch 6/8, Loss: 0.8287
Epoch 7/8, Loss: 0.8273
Epoch 8/8, Loss: 0.8256


In [25]:
def evaluate_model(test_loader, model):
    model.eval()
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.numpy())
            
    # evaluate on the f1 score with micro averages
    return f1_score(all_labels, all_preds, average='micro')

In [26]:
f1 = evaluate_model(test_loader, model)
print(f'Micro F1 Score: {f1:.4f}')

Micro F1 Score: 0.6517


In [27]:
def evaluate_and_capture_mismatches(test_loader, model, task_2_test):
    model.eval()
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # convert task_2_test to pandas df if it's a dask df
    if hasattr(task_2_test, 'compute'):
        test_df = task_2_test[['query', 'product_title', 'encoded_labels']].compute()
    else:
        test_df = task_2_test[['query', 'product_title', 'encoded_labels']]

    test_df['predicted_label'] = all_preds
    test_df['true_label'] = all_labels
    
    mismatch_df = test_df[test_df['true_label'] != test_df['predicted_label']]
    
    return mismatch_df

mismatch_df = evaluate_and_capture_mismatches(test_loader, model, task_2_test)

In [30]:
# count top 10mismatches per query
mismatch_counts_per_query = mismatch_df['query'].value_counts().head(10) 
mismatch_counts_per_product = mismatch_df['product_title'].value_counts().head(10)

all_text = ' '.join(mismatch_df['query'].tolist() + mismatch_df['product_title'].tolist())
word_counts = Counter(all_text.split()).most_common(10)  # Top 10 common words

print("Top 10 queries with the most mismatches:\n", mismatch_counts_per_query)
print("\nTop 10 most common words in mismatched entries:\n", word_counts)

Top 10 queries with the most mismatches:
 query
pet boundry flags                           4
lv coin purse                               3
lemon essential oils for body butter        3
gucci disco                                 3
tecnu wipes                                 3
die grinder set                             3
macbook air                                 3
toothpaste without sodium lauryl sulfate    3
taupe blackout curtains                     3
electric motor mongose                      3
Name: count, dtype: int64[pyarrow]

Top 10 most common words in mismatched entries:
 [('for', 1681), ('-', 1084), ('with', 981), ('and', 881), ('&', 412), ('of', 357), ('Black', 261), ('|', 258), ('2', 252), ('to', 218)]


In [32]:
print(device)

cpu
