In [18]:
import os, sys
import numpy as np
import itertools
from collections import defaultdict

from tqdm.notebook import trange
from tqdm import tqdm

In [19]:
# Make the local extension package importable. Relative entries are resolved
# against the current working directory and appended to sys.path only once.
extra_module_dirs = ['../spotlight_ext']
for abs_dir in map(os.path.abspath, extra_module_dirs):
    if abs_dir in sys.path:
        continue
    sys.path.append(abs_dir)

# Seeded NumPy generator shared by the cells below (passed to the train/test split).
random_state = np.random.RandomState(2020)

In [22]:
from spotlight.cross_validation import random_train_test_split
from spotlight.datasets.movielens import get_movielens_dataset

# Load the MovieLens '1M' variant, split it with the shared seeded generator,
# then turn each split into its sequence representation (train first, then test).
dataset = get_movielens_dataset(variant='1M')
max_sequence_length = 20
train, test = (
    interactions.to_sequence(max_sequence_length=max_sequence_length)
    for interactions in random_train_test_split(dataset, random_state=random_state)
)

In [20]:
# Execute helpers.ipynb to load its functions into this notebook.
# NOTE(review): `load_model` and `StaticVars`, used by later cells, are not
# defined or imported anywhere in this view -- presumably they come from
# helpers.ipynb; confirm.
%run helpers.ipynb

In [21]:
# Obtain the 'pooling' model through load_model. The helper is not defined in
# this view (presumably provided by helpers.ipynb -- TODO confirm what it loads
# and from where). The result is used as `model` in the scoring loop below.
pooling_model = load_model('pooling')

In [111]:
# --- Experiment configuration ---------------------------------------------
target_item_pos = [3]   # 1-based rank(s) in the recommendation list whose item is tracked

res = []
tuples_length = 3       # largest number of items removed from a sequence at once
top_k = 10              # size of the recommendation list
model = pooling_model
num_users = 50          # user ids 1 .. num_users - 1 are processed
# user_id -> list of [removed_items, target_item_score, kth_item_score]
# NOTE: entries are keyed by user only, so results of several target positions
# and several sequences of the same user end up in the same list.
materialized_pred_scores = defaultdict(list)

# For every fully populated test sequence: pick the item currently ranked at
# `pos`, then re-score after removing every combination of 1..tuples_length
# interacted items, recording the target's score next to the k-th best score.
# Scores are negated predictions, so an ascending argsort lists the
# best-ranked items first.
with tqdm(total=len(target_item_pos), desc='target position loop') as pbar:
    for pos in target_item_pos:
        # 0-based index into the ranked list, capped at the list size.
        target_rank = min(top_k, int(pos)) - 1

        for user_id in trange(1, num_users, desc='users loop', leave=False):
            # Boolean-mask lookup done once per user instead of once per access.
            user_sequences = test.sequences[test.user_ids == user_id]

            for items_interacted in user_sequences:
                # Only sequences whose entries are all positive are used
                # (presumably 0 marks padding -- TODO confirm).
                if not (items_interacted > 0).all():
                    continue

                predictions = -model.predict(items_interacted)
                # Push already-interacted items to the end of the ranking
                # (assumes StaticVars.FLOAT_MAX is a very large float -- it is
                # defined outside this view).
                predictions[items_interacted] = StaticVars.FLOAT_MAX

                target_item = predictions.argsort()[target_rank]

                for subset_size in range(1, tuples_length + 1):
                    for removed_items in itertools.combinations(items_interacted, subset_size):
                        removed = set(removed_items)
                        # Drop the removed items while keeping the remaining
                        # ones in their original order; a set difference would
                        # scramble the order fed to the sequence model.
                        remaining_items = [item for item in items_interacted if item not in removed]

                        preds = -model.predict(remaining_items)
                        # Mask every originally interacted item, including the
                        # removed ones, so they cannot enter the ranking.
                        preds[items_interacted] = StaticVars.FLOAT_MAX

                        kth_item_score = preds[preds.argsort()[top_k - 1]]
                        materialized_pred_scores[user_id].append(
                            [removed_items, preds[target_item], kth_item_score]
                        )

        # One target position finished -> advance the outer bar by one step.
        pbar.update(1)

target position loop:   0%|          | 0/1 [00:00<?, ?it/s]

HBox(children=(HTML(value='users loop'), FloatProgress(value=0.0, max=49.0), HTML(value='')))

target position loop: 10it [00:32,  3.23s/it]              


In [112]:
# Print every removal after which the target item scores worse than the item
# at rank top_k. Scores are negated predictions (lower is better), so a larger
# value means a worse rank. The scores are compared directly instead of by
# their ratio: `a / b < 1` only matches this condition when both scores are
# negative, and it divides by zero when the k-th score is 0.
for user_id, records in materialized_pred_scores.items():
    for record in records:
        _removed_items, target_score, kth_score = record
        if target_score > kth_score:
            print(user_id, record)

8 [(384, 461), -2.361434, -2.3661966]
8 [(384, 478), -2.458649, -2.4738796]
8 [(384, 511), -2.4070444, -2.4686155]
8 [(461, 478), -2.3311832, -2.3760533]
8 [(461, 511), -2.279579, -2.337636]
8 [(236, 511), -2.4137821, -2.440359]
8 [(478, 511), -2.3767934, -2.4386554]
8 [(457, 511), -2.4066133, -2.4261217]
8 [(511, 60), -2.4103146, -2.4145877]
8 [(511, 453), -2.4401903, -2.4441133]
8 [(384, 461, 236), -2.3088708, -2.380344]
8 [(384, 461, 478), -2.2698271, -2.3764615]
8 [(384, 461, 86), -2.3670297, -2.4064393]
8 [(384, 461, 457), -2.3013036, -2.3604395]
8 [(384, 461, 511), -2.2153559, -2.3512402]
8 [(384, 461, 60), -2.30521, -2.3613875]
8 [(384, 461, 443), -2.296328, -2.3234963]
8 [(384, 461, 514), -2.3728878, -2.3791914]
8 [(384, 461, 453), -2.336746, -2.368527]
8 [(384, 461, 444), -2.3593736, -2.3687687]
8 [(384, 236, 478), -2.4114864, -2.4428372]
8 [(384, 236, 511), -2.3570151, -2.4468393]
8 [(384, 478, 86), -2.4696453, -2.5127397]
8 [(384, 478, 457), -2.4039192, -2.4736218]
8 [(384, 