In [1]:
import numpy as np
import pandas as pd

# Rank cutoffs (the "k" in metric@k); used as the default `k_list` of compute_eval_metrics.
K_LIST = [1, 2, 3, 5, 10, 15, 20]

def _ensure_columns(df: pd.DataFrame) -> pd.DataFrame:
    """
    Ensure that all required columns exist in the DataFrame.
    Missing columns are added with default values.
    """
    required_columns = {
        "cid": None, "pb_size": 0, "item_id": None, "ptcode": None,
        "label": 0, "urgency": 99, "relevance": 0,
        "elapse_days": 400, "score": -1, "is_consumable": 0
    }
    d = df.copy()
    for col, default_value in required_columns.items():
        if col not in d.columns:
            print(f"Column {col} not found in DataFrame. Adding with default value {default_value}.")
            d[col] = default_value
    return d

def _prep_df(df: pd.DataFrame) -> pd.DataFrame:
    """
    Build the cleaned, ranked frame that the metric code consumes.

    Steps:
    - fill in any missing columns via ``_ensure_columns`` (works on a copy)
    - cast ``label``/``urgency`` to int and ``relevance`` to float
    - coerce ``score`` to numeric; unparseable values become NaN
    - order rows by cid, then score (high to low), then urgency (low to high,
      lower meaning more urgent)

    NOTE(review): pandas places NaN sort keys last by default, so rows without
    a usable score end up at the bottom of their cid but remain part of the
    ranking -- confirm that is the intended treatment of unscored rows.
    """
    frame = _ensure_columns(df)

    casts = (("label", int), ("urgency", int), ("relevance", float))
    for column, target_type in casts:
        frame[column] = frame[column].astype(target_type)
    frame["score"] = pd.to_numeric(frame["score"], errors="coerce")

    # One sort call: cid ascending, score descending, urgency ascending.
    return frame.sort_values(
        by=["cid", "score", "urgency"],
        ascending=[True, False, True],
    )

def _dcg_from_rels(rels: np.ndarray) -> float:
    """
    Compute discounted cumulative gain (DCG) from a vector of relevances.
    rels: An array of relevance scores, ranked in order.
    """
    # Standard DCG: sum rel_i / log2(i+1), i starts at 1
    if rels.size == 0:
        return 0.0
    denom = np.log2(np.arange(2, rels.size + 2))  # log2(2..k+1)
    return float((rels.astype(float) / denom).sum())

def _group_arrays(d: pd.DataFrame):
    """
    Precompute per-cid arrays in model-ranked order.
    Returns list of tuples: (cid, y, u, rel, item_ids, total_purchased)
        - y: label array (binary, purchased)
        - u: urgency
        - rel: relevance
        - item_ids: item_id array
        - total_purchased: number of positives in y
    """
    grouped = []
    for cid, g in d.groupby("cid", sort=False):
        y = g["label"].to_numpy(dtype=int)
        u = g["urgency"].to_numpy(dtype=int)
        rel = g["relevance"].to_numpy(dtype=float)
        item_ids = g["item_id"].to_numpy()
        total_purchased = int(y.sum())
        grouped.append((cid, y, u, rel, item_ids, total_purchased))
    return grouped

def metric_precision_at_k(grouped, k: int) -> float:
    """
    Mean precision@k over groups (CIDs).

    ``grouped`` holds 6-tuples shaped like the output of ``_group_arrays``;
    only the label array (second element) is read. A group contributes
    ``hits_in_top_k / k`` when it has at least k rows; shorter groups are
    skipped. No filtering on purchases happens here, so callers that want to
    exclude purchase-free groups must do so beforehand.

    Returns NaN when no group qualifies.
    """
    per_group = [
        int(y[:k].sum()) / k
        for _cid, y, _u, _rel, _items, _total in grouped
        if y.size >= k
    ]
    if not per_group:
        return np.nan
    return float(np.mean(per_group))

def metric_recall_at_k(grouped, k: int) -> float:
    """
    Compute mean recall at k across all groups (CIDs).

    ``grouped`` holds 6-tuples shaped like the output of ``_group_arrays``;
    the label array (second element) and ``total_purchased`` (last element)
    are read. A group contributes ``hits_in_top_k / total_purchased``.

    Only groups with at least k items and at least 1 purchase are included.
    Recall is undefined without purchases, so such groups are skipped rather
    than dividing by zero.

    Returns NaN when no group qualifies.
    """
    vals = []
    for _, y, _, _, _, total_purchased in grouped:
        # Guard both the length rule and the zero denominator.
        if y.size < k or total_purchased <= 0:
            continue
        tp = int(y[:k].sum())
        vals.append(tp / total_purchased)
    return float(np.mean(vals)) if vals else np.nan

def metric_mean_pos_rank_at_k(grouped, k: int) -> float:
    """
    Average 1-based rank of the positives (label == 1) found in the top k.

    For every group with at least k rows, the ranks of its top-k positives are
    averaged; groups with no positive in the top k contribute nothing. The
    result is the mean of those per-group averages, or NaN if there are none.
    """
    group_means = []
    for _cid, y, _u, _rel, _items, _total in grouped:
        if y.size >= k:
            hit_positions = np.flatnonzero(y[:k] == 1)
            if hit_positions.size > 0:
                # flatnonzero is zero-based; ranks start at 1.
                group_means.append(float((hit_positions + 1).mean()))
    if group_means:
        return float(np.mean(group_means))
    return np.nan

def metric_mean_urgency_pos_at_k(grouped, k: int) -> float:
    """
    Average urgency of the positives (label == 1) found in the top k.

    For every group with at least k rows, the urgency values at the top-k
    positive positions are averaged; groups with no positive in the top k
    contribute nothing. The result is the mean of those per-group averages,
    or NaN if there are none.
    """
    group_means = []
    for _cid, y, u, _rel, _items, _total in grouped:
        if y.size >= k:
            hit_positions = np.flatnonzero(y[:k] == 1)
            if hit_positions.size > 0:
                group_means.append(float(u[:k][hit_positions].mean()))
    if group_means:
        return float(np.mean(group_means))
    return np.nan

def metric_ndcg_at_k_standard(grouped, k: int) -> float:
    """
    Mean nDCG@k computed from the graded ``relevance`` values.

    Per group with at least k rows: DCG of the first k relevances divided by
    the DCG of the k largest relevances of the whole group. A group whose
    ideal DCG is not positive scores 0.0. Returns NaN if no group qualifies.
    """
    scores = []
    for _cid, y, _u, rel, _items, _total in grouped:
        if y.size < k:
            continue
        achieved = _dcg_from_rels(rel[:k])
        # Best attainable ordering: all of the group's relevances, descending.
        best_order = np.sort(rel)[::-1][:k]
        best = _dcg_from_rels(best_order)
        if best > 0:
            scores.append(achieved / best)
        else:
            scores.append(0.0)
    if not scores:
        return np.nan
    return float(np.mean(scores))

def metric_ndcg_at_k_binary(grouped, k: int) -> float:
    """
    Mean nDCG@k across CIDs using the 0/1 label as the gain.

    Per group with at least k rows: DCG of the first k labels divided by the
    DCG of an ideal list that puts ``min(k, total_purchased)`` ones first.
    A group whose ideal DCG is not positive scores 0.0. Returns NaN if no
    group qualifies.
    """
    scores = []
    for _cid, y, _u, _rel, _items, total_purchased in grouped:
        if y.size < k:
            continue

        # Gain actually collected by the ranking.
        achieved = _dcg_from_rels(y[:k])

        # Perfect ranking: every purchase (capped at k) occupies the top slots.
        best = _dcg_from_rels(np.ones(min(k, total_purchased)))

        if best > 0:
            scores.append(achieved / best)
        else:
            scores.append(0.0)

    if not scores:
        return np.nan
    return float(np.mean(scores))


def compute_eval_metrics(
    df: pd.DataFrame,
    k_list=K_LIST,
    pb_size_threshold=None
) -> pd.DataFrame:
    """
    Main evaluation function.

    Pipeline: ``_prep_df`` (fill missing columns, cast, sort) -> optional
    ``pb_size`` filter -> zero the relevance of non-purchased rows -> split
    into per-cid arrays -> drop cids without any purchase -> aggregate per k.

    Parameters
    ----------
    df : pd.DataFrame
        One row per (cid, item) candidate. Missing columns are filled in by
        ``_ensure_columns`` via ``_prep_df``.
    k_list : iterable of int
        Cutoffs to evaluate. For a given k, only cids with at least k rows
        are included.
    pb_size_threshold : number or None
        When given, only rows with ``pb_size >= pb_size_threshold`` are kept.

    Returns
    -------
    pd.DataFrame
        One row per k with the columns ``k, num_cids, total_purch, total_pos,
        true_pos, unique_items, precision, recall, ndcg_binary,
        ndcg_relevance, avg_pos_rank, avg_pos_urgency``. The same columns are
        returned (with zero rows) when no data survives the filter, so callers
        can select columns without special-casing the empty result.
    """
    # Single source of truth for the result schema (empty and non-empty case).
    output_columns = [
        "k", "num_cids", "total_purch", "total_pos", "true_pos", "unique_items",
        "precision", "recall", "ndcg_binary", "ndcg_relevance",
        "avg_pos_rank", "avg_pos_urgency",
    ]

    d = _prep_df(df)

    # Filter by PB size if required
    if pb_size_threshold is not None:
        d = d[d["pb_size"] >= pb_size_threshold].copy()

    if d.empty:
        return pd.DataFrame(columns=output_columns)

    # Ensure relevance is 0 for non-purchased items. This must happen BEFORE
    # the per-cid arrays are extracted, otherwise the masking never reaches
    # the relevance-based nDCG.
    d["relevance"] = d["relevance"] * d["label"]

    # Keep only cids with at least one purchase (last tuple element > 0).
    grouped = [row for row in _group_arrays(d) if row[-1] > 0]

    rows = []
    for k in k_list:
        cid_count = 0         # Number of CIDs included for this k
        total_pos = 0         # Recommended slots considered (k per included cid)
        true_pos = 0          # Purchased items found in the top k
        total_purch_all = 0   # Total purchases across included CIDs
        unique_items = set()  # Distinct items appearing in any top-k list
        for _, y, _, _, item_ids, total_purchased in grouped:
            if y.size < k:
                continue
            cid_count += 1
            total_pos += k
            true_pos += int(y[:k].sum())
            total_purch_all += total_purchased
            unique_items.update(item_ids[:k])

        rows.append({
            "k": k,
            "num_cids": cid_count,
            "total_purch": float(total_purch_all),
            "total_pos": float(total_pos),
            "true_pos": float(true_pos),
            "unique_items": len(unique_items),
            "precision": metric_precision_at_k(grouped, k),
            "recall": metric_recall_at_k(grouped, k),
            "ndcg_binary": metric_ndcg_at_k_binary(grouped, k),
            "ndcg_relevance": metric_ndcg_at_k_standard(grouped, k),
            "avg_pos_rank": metric_mean_pos_rank_at_k(grouped, k),
            "avg_pos_urgency": metric_mean_urgency_pos_at_k(grouped, k),
        })

    return pd.DataFrame(rows, columns=output_columns)

## Walmart

# Ablation

In [2]:
## Walmart CNN BCE (Ablation)
json_path = "gs://p13n-storage2/user/y0c07th/exp/cnn_set_pb/inference-job-output/7559272647188545536/predictions/predictions.json"

# One row per cid with parallel item/score lists -> one row per (cid, item).
predictions = pd.read_json(json_path)
predictions_exploded = predictions.explode(['top_500_prediction_items','top_500_prediction_scores'])[['cid','top_500_prediction_items','top_500_prediction_scores']]
predictions_exploded.columns = ['cid','item_id','score']

label = pd.read_parquet('gs://p13n-storage2/data/features/pb_inspirational/cnn_st_nbr/data_new/2025-10-15/test_label')

# Printed label matches this cell's header, so the output is not confused
# with the "CNN Ranker" run in the next cell.
print('Walmart CNN BCE (Ablation)')
print('before label join', predictions_exploded.shape)
# Outer join keeps label rows that were never predicted; predicted rows
# without a label get label 0.
predictions_labeled = predictions_exploded.merge(label, on=['cid','item_id'], how='outer').fillna({'label':0})
print('after label outer join', predictions_labeled.shape)

compute_eval_metrics(predictions_labeled, k_list=K_LIST, pb_size_threshold=None)[['k','num_cids','total_purch','total_pos','true_pos','precision','recall','ndcg_binary','unique_items']]

Walmart CNN Ranker (Ablation)
before label join (20805506, 3)
after label outer join (21715897, 4)
Column pb_size not found in DataFrame. Adding with default value 0.
Column ptcode not found in DataFrame. Adding with default value None.
Column urgency not found in DataFrame. Adding with default value 99.
Column relevance not found in DataFrame. Adding with default value 0.
Column elapse_days not found in DataFrame. Adding with default value 400.
Column is_consumable not found in DataFrame. Adding with default value 0.


Unnamed: 0,k,num_cids,total_purch,total_pos,true_pos,precision,recall,ndcg_binary,unique_items
0,1,79064,1836082.0,79064.0,37786.0,0.477917,0.041996,0.477917,15649
1,2,78865,1835883.0,157730.0,68049.0,0.431427,0.064883,0.446494,23628
2,3,78628,1835600.0,235884.0,94576.0,0.400943,0.082044,0.425673,29101
3,5,78175,1834834.0,390875.0,139917.0,0.357958,0.109519,0.396983,36674
4,10,76872,1831182.0,768720.0,227218.0,0.29558,0.156954,0.35762,47765
5,15,75504,1825935.0,1132560.0,293290.0,0.258962,0.18998,0.337605,54588
6,20,74127,1819010.0,1482540.0,345537.0,0.233071,0.215343,0.326181,59383


In [3]:
## Walmart CNN Ranker (Ablation)
json_path = "gs://p13n-storage2/user/y0c07th/exp/cnn_set_pb/inference-job-output/5981279947283496960/predictions/predictions.json"

predictions = pd.read_json(json_path)
# Unnest the parallel item/score lists into one row per (cid, item).
predictions_exploded = (
    predictions
    .explode(['top_500_prediction_items', 'top_500_prediction_scores'])
    [['cid', 'top_500_prediction_items', 'top_500_prediction_scores']]
    .rename(columns={
        'top_500_prediction_items': 'item_id',
        'top_500_prediction_scores': 'score',
    })
)

label = pd.read_parquet('gs://p13n-storage2/data/features/pb_inspirational/cnn_st_nbr/data_new/2025-10-15/test_label')

print('Walmart CNN Ranker (Ablation)')
print('before label join', predictions_exploded.shape)
# Outer join keeps label rows that were never predicted; predicted rows
# without a label get label 0.
predictions_labeled = (
    predictions_exploded
    .merge(label, on=['cid', 'item_id'], how='outer')
    .fillna({'label': 0})
)
print('after label outer join', predictions_labeled.shape)

compute_eval_metrics(
    predictions_labeled,
    k_list=K_LIST,
    pb_size_threshold=None,
)[[
    'k', 'num_cids', 'total_purch', 'total_pos', 'true_pos',
    'precision', 'recall', 'ndcg_binary', 'unique_items',
]]

Walmart CNN Ranker (Ablation)
before label join (20805506, 3)
after label outer join (21719637, 4)
Column pb_size not found in DataFrame. Adding with default value 0.
Column ptcode not found in DataFrame. Adding with default value None.
Column urgency not found in DataFrame. Adding with default value 99.
Column relevance not found in DataFrame. Adding with default value 0.
Column elapse_days not found in DataFrame. Adding with default value 400.
Column is_consumable not found in DataFrame. Adding with default value 0.


Unnamed: 0,k,num_cids,total_purch,total_pos,true_pos,precision,recall,ndcg_binary,unique_items
0,1,79064,1836082.0,79064.0,37001.0,0.467988,0.040597,0.467988,15450
1,2,78865,1835883.0,157730.0,66929.0,0.424326,0.063109,0.438536,23204
2,3,78628,1835600.0,235884.0,93158.0,0.394931,0.080071,0.418488,28861
3,5,78175,1834834.0,390875.0,138231.0,0.353645,0.107,0.390861,36789
4,10,76872,1831182.0,768720.0,225241.0,0.293008,0.154196,0.352683,49030
5,15,75504,1825935.0,1132560.0,291261.0,0.25717,0.186819,0.333084,56825
6,20,74127,1819010.0,1482540.0,343543.0,0.231726,0.211855,0.321897,62605


In [14]:
## Walmart No SetTran PMA (Ablation)
prediction_path = 'gs://p13n-storage2/user/y0c07th/exp/cnn_set_pb/inference-job-output/884936350158028800/predictions/predictions.parquet'
# input_data_path = 'gs://p13n-storage2/data/features/pb_inspirational/cnn_st_nbr/data_new/2025-10-15/test_df'

predictions = pd.read_parquet(prediction_path)
# Unnest the parallel item/score lists into one row per (cid, item).
predictions_exploded = (
    predictions
    .explode(['top_500_prediction_items', 'top_500_prediction_scores'])
    [['cid', 'top_500_prediction_items', 'top_500_prediction_scores']]
    .rename(columns={
        'top_500_prediction_items': 'item_id',
        'top_500_prediction_scores': 'score',
    })
)

label = pd.read_parquet('gs://p13n-storage2/data/features/pb_inspirational/cnn_st_nbr/data_new/2025-10-15/test_label')

# print('Walmart No SetTran PMA (Ablation)')
print('before label join', predictions_exploded.shape)
# Outer join keeps label rows that were never predicted; predicted rows
# without a label get label 0.
predictions_labeled = (
    predictions_exploded
    .merge(label, on=['cid', 'item_id'], how='outer')
    .fillna({'label': 0})
)
print('after label outer join', predictions_labeled.shape)

compute_eval_metrics(
    predictions_labeled,
    k_list=K_LIST,
    pb_size_threshold=None,
)[[
    'k', 'num_cids', 'total_purch', 'total_pos', 'true_pos',
    'precision', 'recall', 'ndcg_binary', 'unique_items',
]]

before label join (20805506, 3)
after label outer join (21716327, 4)
Column pb_size not found in DataFrame. Adding with default value 0.
Column ptcode not found in DataFrame. Adding with default value None.
Column urgency not found in DataFrame. Adding with default value 99.
Column relevance not found in DataFrame. Adding with default value 0.
Column elapse_days not found in DataFrame. Adding with default value 400.
Column is_consumable not found in DataFrame. Adding with default value 0.


Unnamed: 0,k,num_cids,total_purch,total_pos,true_pos,precision,recall,ndcg_binary,unique_items
0,1,79064,1836082.0,79064.0,37896.0,0.479308,0.041845,0.479308,14757
1,2,78865,1835883.0,157730.0,68182.0,0.43227,0.064637,0.447429,22437
2,3,78628,1835600.0,235884.0,94574.0,0.400934,0.081567,0.425916,27752
3,5,78175,1834834.0,390875.0,139690.0,0.357378,0.108713,0.396618,35292
4,10,76872,1831182.0,768720.0,227360.0,0.295764,0.156296,0.357576,46829
5,15,75504,1825935.0,1132560.0,293373.0,0.259035,0.188961,0.337339,53989
6,20,74127,1819010.0,1482540.0,345386.0,0.232969,0.214255,0.325746,59038


In [11]:
## Walmart No CNN (Ablation)
prediction_path = 'gs://p13n-storage2/user/y0c07th/exp/cnn_set_pb/inference-job-output/685775211949195264/predictions/predictions.parquet'
# input_data_path = 'gs://p13n-storage2/data/features/pb_inspirational/cnn_st_nbr/data_new/2025-10-15/test_df'

predictions = pd.read_parquet(prediction_path)
# Unnest the parallel item/score lists into one row per (cid, item).
predictions_exploded = (
    predictions
    .explode(['top_500_prediction_items', 'top_500_prediction_scores'])
    [['cid', 'top_500_prediction_items', 'top_500_prediction_scores']]
    .rename(columns={
        'top_500_prediction_items': 'item_id',
        'top_500_prediction_scores': 'score',
    })
)

label = pd.read_parquet('gs://p13n-storage2/data/features/pb_inspirational/cnn_st_nbr/data_new/2025-10-15/test_label')

print('Walmart No CNN (Ablation)')
print('before label join', predictions_exploded.shape)
# Outer join keeps label rows that were never predicted; predicted rows
# without a label get label 0.
predictions_labeled = (
    predictions_exploded
    .merge(label, on=['cid', 'item_id'], how='outer')
    .fillna({'label': 0})
)
print('after label outer join', predictions_labeled.shape)

compute_eval_metrics(
    predictions_labeled,
    k_list=K_LIST,
    pb_size_threshold=None,
)[[
    'k', 'num_cids', 'total_purch', 'total_pos', 'true_pos',
    'precision', 'recall', 'ndcg_binary', 'unique_items',
]]

Column pb_size not found in DataFrame. Adding with default value 0.
Column ptcode not found in DataFrame. Adding with default value None.
Column urgency not found in DataFrame. Adding with default value 99.
Column relevance not found in DataFrame. Adding with default value 0.
Column elapse_days not found in DataFrame. Adding with default value 400.
Column is_consumable not found in DataFrame. Adding with default value 0.


Unnamed: 0,k,num_cids,total_purch,total_pos,true_pos,precision,recall,ndcg_binary,unique_items
0,1,79064,1836082.0,79064.0,19471.0,0.246269,0.023437,0.246269,6806
1,2,78865,1835883.0,157730.0,35164.0,0.222938,0.036332,0.231222,12537
2,3,78628,1835600.0,235884.0,47855.0,0.202875,0.046056,0.217763,17480
3,5,78175,1834834.0,390875.0,69113.0,0.176816,0.061514,0.200246,25634
4,10,76872,1831182.0,768720.0,113303.0,0.147392,0.0904,0.181731,39114
5,15,75504,1825935.0,1132560.0,150024.0,0.132465,0.112025,0.174466,47976
6,20,74127,1819010.0,1482540.0,181994.0,0.122758,0.130116,0.171671,53981


In [None]:
# PROD PPM
prod_ppm = pd.read_parquet('gs://p13n-storage2/user/y0c07th/tmp/ppm_prod_predictions_1015.parquet')
# prod_ppm = prod_ppm[prod_ppm['cid'].isin(cid_list)]

# Drop rows without an item id; .copy() so the cast below writes to an
# independent frame instead of a filtered view.
prod_ppm = prod_ppm[prod_ppm['item_id'].notna()].copy()
prod_ppm['item_id'] = prod_ppm['item_id'].astype(int)

# Keep positive g3 only and retain the 100 highest-g3 rows per cid.
prod_ppm = prod_ppm[prod_ppm['g3'] > 0].sort_values('g3', ascending=False)
prod_ppm = prod_ppm.groupby('cid').head(100)

# Select the three needed columns by name and expose g3 as the ranking score.
# (Assigning three names to .columns fails once the frame has extra columns.)
prod_ppm = prod_ppm[['cid', 'item_id', 'g3']].rename(columns={'g3': 'score'})

label = pd.read_parquet('gs://p13n-storage2/data/features/pb_inspirational/cnn_st_nbr/data_new/2025-10-15/test_label')
# Outer join keeps label rows that were never predicted; predicted rows
# without a label get label 0.
predictions_labeled = prod_ppm.merge(label, on=['cid','item_id'], how='outer').fillna({'label':0})

compute_eval_metrics(predictions_labeled, k_list=K_LIST, pb_size_threshold=None)[['k','num_cids','total_purch','total_pos','true_pos','precision','recall','ndcg_binary','unique_items']]

Column pb_size not found in DataFrame. Adding with default value 0.
Column ptcode not found in DataFrame. Adding with default value None.
Column urgency not found in DataFrame. Adding with default value 99.
Column relevance not found in DataFrame. Adding with default value 0.
Column elapse_days not found in DataFrame. Adding with default value 400.
Column is_consumable not found in DataFrame. Adding with default value 0.


Unnamed: 0,k,num_cids,total_purch,total_pos,true_pos,precision,recall,ndcg_binary,unique_items
0,1,79064,1836082.0,79064.0,33211.0,0.420052,0.036509,0.420052,14918
1,2,78780,1835798.0,157560.0,60837.0,0.38612,0.056797,0.397218,22335
2,3,78479,1835379.0,235437.0,85318.0,0.362381,0.072806,0.3811,27645
3,5,77957,1834323.0,389785.0,128226.0,0.328966,0.098917,0.359036,34861
4,10,76567,1829799.0,765670.0,212095.0,0.277006,0.144785,0.326843,45993
5,15,75137,1823385.0,1127055.0,275668.0,0.244591,0.176601,0.309448,52465
6,20,73779,1815362.0,1475580.0,326562.0,0.221311,0.201537,0.299516,57210


: 