In [1]:
import math
import time
from typing import Callable

import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_text
from sklearn.metrics import ndcg_score
from google.cloud import bigquery

2024-04-18 11:34:37.566350: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
def get_predict_fn_cern(
    saved_model,
    expected_gains: list[float],
    signature: str = "serving_default",
):
    """Get a predict function which takes in a dataframe
    and returns a relevance score for each row.
    
    Args:
        saved_model: Tensorflow saved model with query and title input
        expected_gains: How much to weight each softmax output in the NDCG gain
        signature: Model serving signature
        
    Returns:
        Function to inference relevance score
    """
    gains = tf.constant([expected_gains], dtype=tf.float32)
    
    def wrapper(df: pd.DataFrame):
        outputs = saved_model.signatures[signature](
            queries=df["query"].to_numpy(),
            titles=df["listingTitle"].to_numpy()
        )
        weighted_scores = outputs["softmax"] * gains
        return tf.reduce_sum(weighted_scores, axis=-1)
    
    return wrapper

In [3]:
def evaluate_model(
    predict_fn: Callable[[pd.DataFrame], np.ndarray],
    df: pd.DataFrame,
    batch_size: int,
) -> pd.DataFrame:
    """Runs a model predict function on a dataframe of inputs and
    returns a float tensor for each row. Returns a dataframe of scores
    for each query, listing pair. Also includes metadata for each prediction.
    
    Args:
        predict_fn: Function to run model inference
        df: Dataframe containing batch of query listing data
        batch_size: Size of inference batches
        
    Returns:
        DataFrame of guid, listingId, and relevance score
    """
    df_batches = np.array_split(df, math.ceil(len(df) / batch_size))
    
    total_batches = len(df_batches)
    metric_dfs = []
    for i, df_batch in enumerate(df_batches):
        print(f"Batch {i}/{total_batches}", end="\r")
        y_true = predict_fn(df_batch)
        metric_df_batch = pd.DataFrame({
            "guid": df_batch["guid"],
            "listingId": df_batch["listingId"],
            "relevanceScore": y_true,
        })
        metric_dfs.append(metric_df_batch)
        
    return pd.concat(metric_dfs) 

In [4]:
model_name = "bert-cern-l24-h1024-a16"
model_path = "gs://training-dev-search-data-jtzn/user/ctran/semantic_relevance/cern4/bert-l24-h1024-a16-batch256-run1/export/saved_model"
# model_path = "gs://training-dev-search-data-jtzn/user/ctran/semantic_relevance/cern4/bert-l2-h128-a2-amazone2-run1/export/saved_model"
date = "2024-04-05"
batch_size = 16
expected_gains = [0.0, 0.5, 0.5, 1.0]
pairs_table_name = "etsy-data-warehouse-prod.search.sem_rel_query_listing_metrics"
requests_table_name = "etsy-data-warehouse-prod.search.sem_rel_requests_metrics"

In [5]:
model = tf.saved_model.load(model_path)
predict_fn = get_predict_fn_cern(model, expected_gains)

In [34]:
client = bigquery.Client()

sql = f"""
    with tmp as (
        SELECT distinct guid
        FROM `etsy-data-warehouse-prod.search.sem_rel_hydrated_daily_requests`
        WHERE date = "{date}"
        LIMIT 50
    )
    SELECT guid, query, listingTitle, listingId, pageNum, rankingRank, retrievalRank, bordaRank
    FROM `etsy-data-warehouse-prod.search.sem_rel_hydrated_daily_requests`
    WHERE date = "{date}"
    AND guid in (select guid from tmp)
"""
df = client.query_and_wait(sql).to_dataframe()
df = df[df["guid"].notna()]
df.loc[df['retrievalRank'].notna(), 'retrievalStage'] = "pre-borda"
df.loc[df['bordaRank'].notna(), 'retrievalStage'] = "post-borda"



In [35]:
df_dedup = df[["guid", "query", "listingTitle", "listingId"]].drop_duplicates()
df_metrics = evaluate_model(predict_fn, df_dedup, batch_size) 

Batch 193/194

In [36]:
df_scores = pd.merge(df, df_metrics, on=["guid", "listingId"])
df_bq_pairs = df_scores[["guid", "query", "listingId", "pageNum", "retrievalStage", "relevanceScore"]]
df_bq_pairs.insert(loc=0, column="date", value=[date] * len(df_bq_pairs))
df_bq_pairs.insert(loc=1, column="modelName", value=[model_name] * len(df_bq_pairs))

In [37]:
df_scores.head()

Unnamed: 0,guid,query,listingTitle,listingId,pageNum,rankingRank,retrievalRank,bordaRank,retrievalStage,relevanceScore
0,363d26c0-a88e-4ce4-8e79-eff3c50c83de,men nak,Knee-Gotiate Short-Sleeve Unisex T-Shirt,1694598290,1.0,1.0,,,,0.500869
1,363d26c0-a88e-4ce4-8e79-eff3c50c83de,men nak,Knee-Gotiate Short-Sleeve Unisex T-Shirt,1694598290,,,132.0,,pre-borda,0.500869
2,363d26c0-a88e-4ce4-8e79-eff3c50c83de,men nak,"Square Gay art poster, naked man, LGBTQ Print,...",1684065726,2.0,37.0,,,,0.70743
3,363d26c0-a88e-4ce4-8e79-eff3c50c83de,men nak,Photo print - Handsome muscular young man in b...,1654143651,1.0,13.0,,,,0.763924
4,363d26c0-a88e-4ce4-8e79-eff3c50c83de,men nak,Gay Photo | Gay Photograpy | Gay Photo Shoot |...,1508957549,1.0,12.0,,,,0.709824


In [38]:
df_grouped = df_scores[np.logical_and(df_scores["pageNum"] == 1, df_scores["rankingRank"] <= 9)].groupby(by=["guid"]).agg({"rankingRank": list, "relevanceScore": list})

In [39]:
df_grouped.head()

Unnamed: 0_level_0,rankingRank,relevanceScore
guid,Unnamed: 1_level_1,Unnamed: 2_level_1
07d94ede-6ac5-4e0e-911f-d27c137e6c2e,"[9, 6, 3, 0, 4, 5, 8, 2, 7, 1]","[0.9790623784065247, 0.9904165267944336, 0.979..."
085bea15-8c6a-4602-9173-ae3503fcfb18,"[3, 7, 4, 5, 2, 0, 8, 9, 6, 1]","[0.596187174320221, 0.4723755717277527, 0.3251..."
089695d1-37d2-4e5a-ad3c-9b0eaaed8d4c,"[9, 1, 3, 8, 5, 4, 6, 7, 2, 0]","[0.9970335364341736, 0.9979644417762756, 0.998..."
08e266d9-2d6d-4dbc-ac9e-1951f77cc719,"[0, 8, 2, 7, 6, 3, 1, 9, 4, 5]","[0.9982101321220398, 0.9994362592697144, 0.999..."
0e15ef18-6199-467f-8eee-b32981ae9670,"[6, 1, 9, 4, 0, 7, 3, 2, 8, 5]","[0.6200137138366699, 0.9964001178741455, 0.714..."


In [50]:
df_grouped[df_grouped.index == "07d94ede-6ac5-4e0e-911f-d27c137e6c2e"].rankingRank.values

array([list([9, 6, 3, 0, 4, 5, 8, 2, 7, 1])], dtype=object)

In [49]:
df_grouped[df_grouped.index == "07d94ede-6ac5-4e0e-911f-d27c137e6c2e"].relevanceScore.values

array([list([0.9790623784065247, 0.9904165267944336, 0.9790623784065247, 0.9904165267944336, 0.9635042548179626, 0.852206826210022, 0.7836074829101562, 0.8845033645629883, 0.9659247398376465, 0.8799769878387451])],
      dtype=object)

In [57]:
ranking_scores = [-x for x in df_grouped[df_grouped.index == "07d94ede-6ac5-4e0e-911f-d27c137e6c2e"].rankingRank.values[0]]
ranking_scores

[-9, -6, -3, 0, -4, -5, -8, -2, -7, -1]

In [61]:
df_scores[df_scores.guid == "07d94ede-6ac5-4e0e-911f-d27c137e6c2e"].sort_values("rankingRank").head(n=20)

Unnamed: 0,guid,query,listingTitle,listingId,pageNum,rankingRank,retrievalRank,bordaRank,retrievalStage,relevanceScore
1541,07d94ede-6ac5-4e0e-911f-d27c137e6c2e,artificial garden,3 x Decorative Artificial Plant Eucalyptus Ar...,1469803986,1,0,,,,0.990417
1581,07d94ede-6ac5-4e0e-911f-d27c137e6c2e,artificial garden,Artificial Maranta 70cm With or Without Pot,1504818387,1,1,,,,0.879977
1576,07d94ede-6ac5-4e0e-911f-d27c137e6c2e,artificial garden,Artificial Potted String of Hearts Vines,1181602409,1,2,,,,0.884503
1537,07d94ede-6ac5-4e0e-911f-d27c137e6c2e,artificial garden,Artificial Fern Hanging Plants fake plant Plas...,1657471635,1,3,,,,0.979062
1544,07d94ede-6ac5-4e0e-911f-d27c137e6c2e,artificial garden,Artificial Hanging Plants Fake Plant Outdoor W...,1444331324,1,4,,,,0.963504
1551,07d94ede-6ac5-4e0e-911f-d27c137e6c2e,artificial garden,Artificial Trailing Gunni Eucalyptus - 95cm - ...,1144136408,1,5,,,,0.852207
1521,07d94ede-6ac5-4e0e-911f-d27c137e6c2e,artificial garden,3 x Decorative Artificial Plant Eucalyptus Art...,1642638152,1,6,,,,0.990417
1578,07d94ede-6ac5-4e0e-911f-d27c137e6c2e,artificial garden,"Artificial Hanging Plants, Decorative Artifici...",1511424408,1,7,,,,0.965925
1556,07d94ede-6ac5-4e0e-911f-d27c137e6c2e,artificial garden,"Artificial Eucalyptus Leaf Stem, Artificial Gr...",1312343958,1,8,,,,0.783607
1515,07d94ede-6ac5-4e0e-911f-d27c137e6c2e,artificial garden,Artificial Fern Hanging Plants fake plant Pla...,1503269089,1,9,,,,0.979062


In [51]:

guids = []
ndcgs = []
for row in df_grouped.itertuples():
    if len(row.rankingRank) > 1:
        guids.append(row.Index)
        historical_score = - np.array([row.rankingRank])
        ndcg = ndcg_score([row.relevanceScore], historical_score)
        ndcgs.append(ndcg)
        
df_ndcg = pd.DataFrame({
    "guid": guids,
    "relevanceNDCG": ndcgs,
})

In [60]:
df_ndcg[df_ndcg.guid == "07d94ede-6ac5-4e0e-911f-d27c137e6c2e"]

Unnamed: 0,guid,relevanceNDCG
0,07d94ede-6ac5-4e0e-911f-d27c137e6c2e,0.986212
