In [1]:
import pandas as pd
import numpy as np
from caboose_nbr.tifuknn import TIFUKNN

In [2]:
train_baskets = pd.read_csv('data/instacart_30k/train_baskets.csv.gz')
test_baskets = pd.read_csv('data/instacart_30k/test_baskets.csv')
valid_baskets = pd.read_csv('data/instacart_30k/valid_baskets.csv')

In [3]:
aisles = pd.read_csv('data/instacart_30k/aisles.csv')
products = pd.read_csv('data/instacart_30k/products.csv')
products_with_aisles = products.merge(aisles, on='aisle_id')

In [4]:
train_baskets_with_aisles = train_baskets.merge(products_with_aisles, left_on="item_id", right_on="product_id")

In [5]:
# Works for seed=42 and sample_size=100,250,1000

seed = 1312
sample_size = 1000
sensitive_aisles = [82, 92, 102, 56]

np.random.seed(seed)

baby_baskets = train_baskets_with_aisles[train_baskets_with_aisles.aisle_id.isin(sensitive_aisles)]
all_baby_users = baby_baskets.user_id.unique()
baby_users = np.array(np.random.choice(all_baby_users, sample_size))
baby_user_baskets = train_baskets_with_aisles[train_baskets_with_aisles.user_id.isin(baby_users)]

other_aisles = [aisle for aisle in baby_user_baskets.aisle_id.unique() if aisle not in sensitive_aisles]

all_nonbaby_users  = train_baskets_with_aisles[(train_baskets_with_aisles.aisle_id.isin(other_aisles)) \
                          & (~train_baskets_with_aisles.user_id.isin(all_baby_users))].user_id.unique()


nonbaby_users = np.array(np.random.choice(all_nonbaby_users, sample_size))

In [6]:
users = np.concatenate((baby_users, nonbaby_users))
sampled_train_baskets = train_baskets[train_baskets['user_id'].isin(users)]
sampled_test_baskets = test_baskets[test_baskets['user_id'].isin(users)]
sampled_valid_baskets = valid_baskets[valid_baskets['user_id'].isin(users)]

In [7]:
sampled_train_baskets[['user_id','item_id']].drop_duplicates().shape

(159382, 2)

In [8]:
tifu_caboose = TIFUKNN(sampled_train_baskets, sampled_test_baskets, sampled_valid_baskets, 'caboose')
tifu_caboose.train()

number of test users: 1881
filtered items: 22076
initial data processing
item count: 8702
compute basket reps
10000  baskets passed
20000  baskets passed
30000  baskets passed
compute user reps 1881
1000  users passed
(1881, 8702)
start of knn
knn finished


--Creating transpose of R...
--Computing row norms...
--Configuring for top-k -- num_threads: 8; pinning? false;
--Scheduling parallel top-k computation...


In [9]:
baby_items = set(baby_baskets.item_id.unique())

In [10]:
for user in baby_users:
    predictions = tifu_caboose.predict_for_user(user, 10)
    predicted_baby_items = set(predictions) & baby_items
    has_baby_items = len(predicted_baby_items) > 0
    if has_baby_items:
        chosen_users_items = sampled_train_baskets[sampled_train_baskets.user_id==user].item_id.unique()
        chosen_users_baby_items = set(chosen_users_items) & baby_items

        to_forget = [(user, item) for item in chosen_users_baby_items]
        tifu_caboose.forget_interactions(to_forget)
        predictions_after_forget = tifu_caboose.predict_for_user(user, 10)
        remaining_baby_items = set(predictions_after_forget) & baby_items
        print(f'User {user}: ({len(predicted_baby_items)},{len(to_forget)}) --> {len(remaining_baby_items)}')

User 206056: (1,4) --> 0
User 160617: (1,29) --> 0
User 194289: (1,1) --> 0
User 61832: (1,2) --> 0
User 160640: (1,1) --> 0
User 33263: (1,2) --> 0
User 197782: (1,2) --> 0
User 158226: (3,83) --> 0
User 84199: (1,1) --> 0
User 80520: (1,4) --> 0
User 45893: (1,1) --> 0
User 67810: (1,21) --> 0
User 44924: (1,2) --> 0
User 100195: (1,1) --> 0
User 205905: (1,1) --> 0
User 105379: (1,3) --> 0
User 165126: (3,27) --> 0
User 33477: (1,22) --> 0
User 140883: (1,32) --> 0
User 21491: (4,4) --> 0
User 90179: (3,17) --> 0
User 116013: (1,5) --> 0
User 154913: (3,8) --> 0
User 4775: (2,3) --> 0
User 171530: (2,6) --> 0
User 126291: (1,9) --> 0
User 177213: (1,4) --> 0
User 125429: (1,1) --> 0
User 28902: (7,13) --> 0
User 86390: (1,15) --> 0
User 151910: (1,7) --> 0
User 31016: (2,4) --> 0
User 132317: (4,8) --> 0
User 134258: (2,18) --> 0
User 105469: (1,14) --> 0
User 27971: (1,4) --> 0
User 135017: (3,16) --> 0
User 188195: (1,11) --> 0
User 80524: (1,5) --> 0
User 148644: (1,19) --> 0
Use