In [None]:
import pandas as pd
import numpy as np
from caboose_nbr.tifuknn import TIFUKNN

In [None]:
# Load the train / test / validation basket splits of the Instacart-30k sample.
# The train split is gzip-compressed; pandas infers that from the `.gz` suffix.
train_baskets, test_baskets, valid_baskets = (
    pd.read_csv(csv_path)
    for csv_path in (
        'data/instacart_30k/train_baskets.csv.gz',
        'data/instacart_30k/test_baskets.csv',
        'data/instacart_30k/valid_baskets.csv',
    )
)

In [None]:
# Product catalogue enriched with aisle metadata (inner join on aisle_id).
aisles = pd.read_csv('data/instacart_30k/aisles.csv')
products = pd.read_csv('data/instacart_30k/products.csv')
products_with_aisles = pd.merge(products, aisles, on='aisle_id')

In [None]:
# Attach product and aisle metadata to every training interaction.
# Inner join: interactions whose item_id has no matching product_id are dropped.
train_baskets_with_aisles = pd.merge(
    train_baskets,
    products_with_aisles,
    how='inner',
    left_on='item_id',
    right_on='product_id',
)

In [None]:
# Draw two cohorts of users from the training data:
#   * "sensitive" users, who bought at least one item from `sensitive_aisles`;
#   * "non-sensitive" users, who never did, but who shopped in at least one of the
#     non-sensitive aisles visited by the sampled sensitive users.
# NOTE(review): an earlier note said this demo "works for seed=42 and sample_size=100",
# but sample_size is 250 here. Users are now drawn without replacement, so the exact
# sample differs from the original run — re-check which configuration reproduces
# the intended demo.

seed = 42
sample_size = 250
# aisle_ids treated as sensitive; presumably the baby-product aisles given the variable
# names below — confirm against aisles.csv.
sensitive_aisles = [82, 92, 102, 56]

np.random.seed(seed)


def _sample_users(pool, size):
    """Draw up to ``size`` distinct users from ``pool`` using NumPy's global RNG.

    Sampling is without replacement: ``np.random.choice`` defaults to sampling WITH
    replacement, which would yield duplicate users (fewer distinct users than requested,
    and repeated processing of the same user in the forgetting loop further down).
    The size is capped at ``len(pool)`` so a small pool does not raise.
    """
    return np.random.choice(pool, min(size, len(pool)), replace=False)


baby_baskets = train_baskets_with_aisles[train_baskets_with_aisles.aisle_id.isin(sensitive_aisles)]
all_baby_users = baby_baskets.user_id.unique()
baby_users = _sample_users(all_baby_users, sample_size)
baby_user_baskets = train_baskets_with_aisles[train_baskets_with_aisles.user_id.isin(baby_users)]

# Non-sensitive aisles that the sampled sensitive users shopped in.
other_aisles = [aisle for aisle in baby_user_baskets.aisle_id.unique() if aisle not in sensitive_aisles]

# Candidates: users active in those aisles who never bought from a sensitive aisle.
all_nonbaby_users = train_baskets_with_aisles[
    train_baskets_with_aisles.aisle_id.isin(other_aisles)
    & ~train_baskets_with_aisles.user_id.isin(all_baby_users)
].user_id.unique()

nonbaby_users = _sample_users(all_nonbaby_users, sample_size)

In [None]:
# Restrict every split to the union of both sampled cohorts.
users = np.concatenate((baby_users, nonbaby_users))


def _only_sampled_users(baskets):
    """Return the rows of ``baskets`` whose ``user_id`` is one of the sampled ``users``."""
    return baskets.loc[baskets['user_id'].isin(users)]


sampled_train_baskets = _only_sampled_users(train_baskets)
sampled_test_baskets = _only_sampled_users(test_baskets)
sampled_valid_baskets = _only_sampled_users(valid_baskets)

In [None]:
# Shape of the distinct (user_id, item_id) pairs in the sampled training data;
# the first entry is the number of unique user-item interactions.
sampled_train_baskets.loc[:, ['user_id', 'item_id']].drop_duplicates().shape

In [None]:
# Build and train TIFU-KNN on the sampled splits.
# NOTE(review): the trailing 'caboose' argument presumably selects the Caboose-backed
# variant of the model — confirm against caboose_nbr.tifuknn.TIFUKNN.
tifu_caboose = TIFUKNN(
    sampled_train_baskets,
    sampled_test_baskets,
    sampled_valid_baskets,
    'caboose',
)
tifu_caboose.train()

In [None]:
# Every item id seen in a sensitive-aisle training basket (across all users, not only
# the sampled ones).
baby_items = {item_id for item_id in baby_baskets.item_id.unique()}

In [None]:
# For each sampled sensitive user whose top-k recommendation contains sensitive items:
# ask the model to forget all of that user's sensitive training interactions, then
# report how many sensitive items remain in the user's top-k afterwards.
# NOTE(review): forget_interactions presumably updates `tifu_caboose` in place, so users
# later in the loop would be scored by the already-updated model — confirm.
top_k = 10

# pd.unique keeps first-seen order and drops repeats, so a user drawn more than once
# into `baby_users` is not predicted-for and forgotten a second time.
for user in pd.unique(baby_users):
    predictions = tifu_caboose.predict_for_user(user, top_k)
    predicted_baby_items = set(predictions) & baby_items
    has_baby_items = len(predicted_baby_items) > 0
    if has_baby_items:
        print(f'User: {user} has predictions {predictions} with sensitive items {predicted_baby_items}.')

        # All sensitive items this user interacted with in the sampled training data.
        chosen_users_items = sampled_train_baskets[sampled_train_baskets.user_id == user].item_id.unique()
        chosen_users_baby_items = set(chosen_users_items) & baby_items

        to_forget = [(user, item) for item in chosen_users_baby_items]
        tifu_caboose.forget_interactions(to_forget)
        predictions_after_forget = tifu_caboose.predict_for_user(user, top_k)
        remaining_baby_items = set(predictions_after_forget) & baby_items
        print(f'Number of sensitive items in prediction after forgetting {len(to_forget)} interactions: {len(remaining_baby_items)}')
        print("---")