In [1]:
%load_ext autoreload
%autoreload 2

### Import dependencies

In [2]:
import os
import ast
import sys
import time
import json
import gdown
import random
import numpy as np
import pandas as pd
from math import ceil
from sklearn.preprocessing import StandardScaler

In [3]:
module_path = os.path.abspath(os.path.join('../src'))
if module_path not in sys.path:
    sys.path.append(module_path)

from training import *
from utils import User, VisualSimilarityHandler, get_decaying_learning_rates, load_embeddings_and_ids

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


###  Load pre-trained image embeddings

In [4]:
resnet_embedding_url = 'https://drive.google.com/uc?id=1fIplLP0Oyuilv-VgkbgAUzON7fKc2NUX'
inventory_url = 'https://drive.google.com/uc?id=1Eskmof-qehDlp0-aiw60dP0dFVs6f2Fa'
purchases_url = 'https://drive.google.com/uc?id=1tSTFCOj5WhQQFNPQsknfGF9pG1Cud7f4'
clusters_url = 'https://drive.google.com/uc?id=19eHFoWexE-iqvpQZHDimAZqTNLdqVGOg'

In [5]:
data_path = '../data/'
try:
    os.mkdir(data_path)
except FileExistsError:
    pass

gdown.download(resnet_embedding_url, data_path + 'resnet_embeddings.npy')
gdown.download(inventory_url, data_path + 'inventory.csv')
gdown.download(purchases_url, data_path + 'purchases.csv')
gdown.download(clusters_url, data_path + 'clusters.json')

Downloading...
From: https://drive.google.com/uc?id=1fIplLP0Oyuilv-VgkbgAUzON7fKc2NUX
To: C:\Users\victor_accete\CuratorNet\data\resnet_embeddings.npy
100%|███████████████████████████████████████████████████████████████████████████████| 272M/272M [00:10<00:00, 25.6MB/s]
Downloading...
From: https://drive.google.com/uc?id=1Eskmof-qehDlp0-aiw60dP0dFVs6f2Fa
To: C:\Users\victor_accete\CuratorNet\data\inventory.csv
100%|███████████████████████████████████████████████████████████████████████████████| 550k/550k [00:00<00:00, 1.57MB/s]
Downloading...
From: https://drive.google.com/uc?id=1tSTFCOj5WhQQFNPQsknfGF9pG1Cud7f4
To: C:\Users\victor_accete\CuratorNet\data\purchases.csv
100%|███████████████████████████████████████████████████████████████████████████████| 429k/429k [00:00<00:00, 1.10MB/s]
Downloading...
From: https://drive.google.com/uc?id=19eHFoWexE-iqvpQZHDimAZqTNLdqVGOg
To: C:\Users\victor_accete\CuratorNet\data\clusters.json
100%|███████████████████████████████████████████████████████

'../data/clusters.json'

In [6]:
resnet50 = load_embeddings_and_ids(data_path + 'resnet_embeddings.npy')

###  Concatenate embeddings + z-score normalization

In [7]:
embedding_list = [resnet50,]

In [8]:
artwork_ids_set = set()
for embedding in embedding_list:
    artwork_ids_set.update(embedding['index2id'].values())
artwork_ids = list(artwork_ids_set)
artwork_id2index = {_id:i for i,_id in enumerate(artwork_ids)}
n_artworks = len(artwork_ids)

In [9]:
featmat_list = [tmp['featmat'] for tmp in embedding_list]
id2index_list = [tmp['id2index'] for tmp in embedding_list]
concat_featmat = resnet50['featmat']

In [10]:
concat_featmat = StandardScaler().fit_transform(concat_featmat)

###  Load clusters

In [11]:
with open(data_path + 'clusters.json') as f:
    artworkId2clusterId = json.load(f)
cluster_ids = np.full((n_artworks,), -1, dtype=int)
for k, v in artworkId2clusterId.items():
    cluster_ids[artwork_id2index[k]] = v

In [12]:
cluster_ids.min(), cluster_ids.max(), cluster_ids.shape

(0, 99, (13297,))

In [13]:
n_clusters = len(set(cluster_ids))

In [14]:
clusterId2artworkIndexes = [[] for _ in range(n_clusters)]
for i, cluster_id in enumerate(cluster_ids):
    clusterId2artworkIndexes[cluster_id].append(i)

###  Load PCA200 embeddings

In [15]:
from sklearn.decomposition import PCA

In [16]:
pca = PCA(n_components=2, svd_solver='full')
pca200_embeddings = pca.fit_transform(concat_featmat)

###  Load transactions

In [17]:
sales_df = pd.read_csv(data_path + 'purchases.csv')
artworks_df = pd.read_csv(data_path + 'inventory.csv')

In [18]:
artist2index = {a:i for i, a in enumerate(artworks_df.artist_id_hash.unique())}
index2artist = {i:a for i, a in enumerate(artworks_df.artist_id_hash.unique())}

In [19]:
artist_ids = np.full((n_artworks,), -1, dtype=int)
failed = []
for _artworkId, _artistId in zip(artworks_df.artwork_id_hash, artworks_df.artist_id_hash):
    i = artwork_id2index[_artworkId]
    artist_ids[i] = artist2index[_artistId]

In [20]:
artistId2artworkIndexes = dict()
for i, _artistId in enumerate(artist_ids):
    if _artistId == -1:
        continue
    try:
        artistId2artworkIndexes[_artistId].append(i)
    except KeyError:
        artistId2artworkIndexes[_artistId] = [i]

### Collect transactions per user (making sure we hide the last nonfirst purchase basket per user)

#### create list of users

In [21]:
user_ids = sales_df.user_id_hash.unique()
user_id2index = { _id:i for i,_id in enumerate(user_ids) }
users = [User(uid) for uid in user_ids]
n_users = len(user_ids)
n_users

2919

#### collect and sanity check transactions per user

In [22]:
sorted_sales_df = sales_df.sort_values('purchase_timestamp')

In [23]:
# clear structures to prevent possible duplicate elements
for user in users:
    user.clear()

# collect transactions per user sorted by timestamp
for uid, a_ids, t in zip(sorted_sales_df.user_id_hash,
                       sorted_sales_df.purchased_artwork_ids_hash,
                       sorted_sales_df.purchase_timestamp):
    for aid in ast.literal_eval(a_ids):
        users[user_id2index[uid]].append_transaction(aid, t, artwork_id2index, artist_ids, cluster_ids)
        assert users[user_id2index[uid]]._uid == uid
    
# bin transctions with same timestamps into purchase baskets
for user in users:
    user.build_purchase_baskets()
    user.sanity_check_purchase_baskets()
    user.remove_last_nonfirst_purchase_basket(
        artwork_id2index, artist_ids, cluster_ids)
    user.sanity_check_purchase_baskets()
    user.refresh_nonpurchased_cluster_ids(n_clusters)
    user.refresh_cluster_ids()
    user.refresh_artist_ids()

### Generate training data

In [24]:
_MOD = 402653189
_BASE = 92821

def hash_triple(profile, pi, ni):
    h = 0
    for x in profile:
        h = ((h * _BASE) % _MOD + x) % _MOD
    h = ((h * _BASE) % _MOD + pi) % _MOD
    h = ((h * _BASE) % _MOD + ni) % _MOD
    return h

In [25]:
def sanity_check_instance(instance,
                          pos_in_profile=True,
                          profile_set=None,
                          pos_sharing_cluster_artist=None,
                          clusters_set=None,
                          artists_set=None,
                          neg_notsharing_artist=None,
                          neg_notsharing_artist_cluster=None,
                         ):
    profile, pi, ni, ui = instance
    try:
        assert 0 <= pi < n_artworks
        assert 0 <= ni < n_artworks
        assert pi != ni        
        assert not vissimhandler.same(pi,ni)
        if ui == -1: return
        
        assert 0 <= ui < n_users
        user = users[ui]
        
        assert all(i in user.artwork_idxs_set for i in profile)
        
        user_profile = user.artwork_idxs_set if profile_set is None else profile_set
        
        if pos_in_profile is True:
            assert pi in user_profile
        elif pos_in_profile is False:
            assert pi not in user_profile
            
        if pos_sharing_cluster_artist:
            assert cluster_ids[pi] in clusters_set
            assert artist_ids[pi] in artists_set
            
        assert ni not in user_profile
        
        if neg_notsharing_artist:
            assert artist_ids[ni] not in artists_set
        if neg_notsharing_artist_cluster:
            assert artist_ids[ni] not in artists_set
            assert cluster_ids[ni] not in clusters_set

    except AssertionError:
        print('profile = ', profile)
        print('pi = ', pi)
        print('ni = ', ni)
        print('ui = ', ui)
        raise

In [26]:
def append_instance(container, instance, **kwargs):
    global _hash_collisions
    profile, pi, ni, ui = instance
    
    h = hash_triple(profile, pi, ni)
    if h in used_hashes:
        _hash_collisions += 1
        return False
    
    if vissimhandler.same(pi, ni):
        return False
    
    sanity_check_instance(instance, **kwargs)
    container.append(instance)
    used_hashes.add(h)
    return True

In [27]:
def print_triple(t):
    profile, pi, ni, ui = t
    print ('profile = ', [artwork_ids[i] for i in profile])
    print ('pi = ', artwork_ids[pi])
    print ('ni = ', artwork_ids[ni])
    print ('ui = ', user_ids[ui] if ui != -1 else -1)

In [28]:
def print_num_samples(sampler_func):
    def wrapper(instances_container, n_samples):
        len_before = len(instances_container)
        sampler_func(instances_container, n_samples)
        actual_samples = len(instances_container) - len_before
        print('  target samples: %d' % n_samples)
        print('  actual samples: %d' % actual_samples)
        print('  delta: %d' % (n_samples - actual_samples))
    return wrapper

In [29]:
vissimhandler = VisualSimilarityHandler(cluster_ids, pca200_embeddings)

In [30]:
vissimhandler.count = 0
used_hashes = set()
_hash_collisions = 0
train_instances = []
test_instances = []

In [31]:
N_STRATEGIES = 6
TOTAL_SAMPLES__TRAIN = int(1e7 + 2)
TOTAL_SAMPLES__TEST =  int(np.round(TOTAL_SAMPLES__TRAIN*.05))
N_SAMPLES_PER_STRATEGY__TRAIN = int(TOTAL_SAMPLES__TRAIN / N_STRATEGIES)
N_SAMPLES_PER_STRATEGY__TEST = int(TOTAL_SAMPLES__TEST / N_STRATEGIES)
N_SAMPLES_PER_STRATEGY__TRAIN, N_SAMPLES_PER_STRATEGY__TEST

(1666667, 83333)

In [32]:
FINE_GRAINED_THRESHOLD = 0.6
VISUAL_CONFIDENCE_THRESHOLD = 0.1

## Original BPR strategy

#### 1) given profile, recommend profile (real users)
Given a user's profile, all items in the profile should be ranked higher than items outside the profile

In [33]:
def sample_artwork_index__outsideprofile(profile_set, pi):
    while True:
        if random.random() <= FINE_GRAINED_THRESHOLD:
            ni = random.choice(clusterId2artworkIndexes[cluster_ids[pi]])
        else:            
            c = random.randint(0, n_clusters-1)
            ni = random.choice(clusterId2artworkIndexes[c])
        if ni not in profile_set:
            return ni

In [34]:
@print_num_samples
def generate_samples__rank_profile_above_nonprofile(instances_container, n_samples):
    n_samples_per_user = ceil(n_samples / n_users)    
    for ui, user in enumerate(users):
        profile = user.artwork_idxs
        profile_set = user.artwork_idxs_set
        for _ in range(n_samples_per_user):
            for __ in range(5):
                pi = random.choice(profile)
                ni = sample_artwork_index__outsideprofile(profile_set, pi)
                if append_instance(instances_container, (profile, pi, ni, ui),
                                   profile_set=profile_set):
                    break

In [35]:
print('sampling train instances ...')
generate_samples__rank_profile_above_nonprofile(train_instances, n_samples=N_SAMPLES_PER_STRATEGY__TRAIN)
print('sampling test instances ...')
generate_samples__rank_profile_above_nonprofile(test_instances, n_samples=N_SAMPLES_PER_STRATEGY__TEST)
print(len(train_instances), len(test_instances))
print('hash_collisions = ', _hash_collisions)
print('visual_collisions = ', vissimhandler.count)

sampling train instances ...
  target samples: 1666667
  actual samples: 1620236
  delta: 46431
sampling test instances ...
  target samples: 83333
  actual samples: 79403
  delta: 3930
1620236 79403
hash_collisions =  1267403
visual_collisions =  285


##### 2) Given profile, recommend profile (fake 1-item profiles)
Given a fake profile of a single item, such item should be ranked higher than any other item

In [36]:
def sample_artwork_index__nonidentical(pi):        
    while True:
        if random.random() <= FINE_GRAINED_THRESHOLD:
            ni = random.choice(clusterId2artworkIndexes[cluster_ids[pi]])
        else:            
            c = random.randint(0, n_clusters-1)
            ni = random.choice(clusterId2artworkIndexes[c])
        if ni != pi:
            return ni

In [37]:
@print_num_samples
def generate_samples__rank_single_item_above_anything_else(instances_container, n_samples):
    n_samples_per_item = ceil(n_samples / n_artworks)
    for pi in range(n_artworks):
        profile = (pi,)
        n = n_samples_per_item
        while n > 0:
            ni = sample_artwork_index__nonidentical(pi)
            if append_instance(instances_container, (profile, pi, ni, -1)):
                n -= 1

In [38]:
print('sampling train instances ...')
generate_samples__rank_single_item_above_anything_else(
    train_instances, n_samples=N_SAMPLES_PER_STRATEGY__TRAIN)
print('sampling test instances ...')
generate_samples__rank_single_item_above_anything_else(
    test_instances, n_samples=N_SAMPLES_PER_STRATEGY__TEST)
print(len(train_instances), len(test_instances))
print('hash_collisions = ', _hash_collisions)
print('visual_collisions = ', vissimhandler.count)

sampling train instances ...
  target samples: 1666667
  actual samples: 1675422
  delta: -8755
sampling test instances ...
  target samples: 83333
  actual samples: 93079
  delta: -9746
3295658 172482
hash_collisions =  2005324
visual_collisions =  562


## Domain-specific strategies

##### 3) Recommend visually similar items from favorite artists (real users)
Given a user, any item outside the user's profile that shares artist and visual cluster with items in the user's profile should be ranked higher than any item from an artist and visual cluster not present in the user's profile

In [39]:
def sample_artwork_index__outsideprofile__sharing_artist_cluster(profile_set, artists_list, clusters_set):
    for _ in range(20): # try at most 20 times
        # sharing artist
        a = random.choice(artists_list)
        i = random.choice(artistId2artworkIndexes[a])
        # sharing cluster
        if cluster_ids[i] not in clusters_set: continue
        # oustide profile
        if i in profile_set: continue
        # done
        return i
    return None # failed to find

In [40]:
def sample_artwork_index__notsharing_artist_cluster(artists_set, nonused_clusters_list):
    while True:
        # not sharing cluster
        c = random.choice(nonused_clusters_list)
        i = random.choice(clusterId2artworkIndexes[c])
        # not sharing artist
        if artist_ids[i] not in artists_set:
            return i

In [41]:
@print_num_samples
def generate_samples__rank_sharing_artist_cluster_above_notsharing_artist_cluster(instances_container, n_samples):
    n_samples_per_user = ceil(n_samples / n_users)    
    for ui, user in enumerate(users):
        profile = user.artwork_idxs
        profile_set = user.artwork_idxs_set
        artists_list = user.artist_ids
        artists_set = user.artist_ids_set
        clusters_set = user.cluster_ids_set
        nonused_clusters_list = user.nonp_cluster_ids
        for _ in range(n_samples_per_user):
            for __ in range(5):
                pi = sample_artwork_index__outsideprofile__sharing_artist_cluster(profile_set, artists_list, clusters_set)
                if pi is None: continue
                ni = sample_artwork_index__notsharing_artist_cluster(artists_set, nonused_clusters_list)
                if append_instance(instances_container, (profile, pi, ni, ui),
                                   pos_in_profile=False,
                                   pos_sharing_cluster_artist=True,
                                   neg_notsharing_artist_cluster=True,
                                   profile_set = profile_set,
                                   clusters_set = clusters_set,
                                   artists_set = artists_set,
                                  ):
                    break

In [42]:
print('sampling train instances ...')
generate_samples__rank_sharing_artist_cluster_above_notsharing_artist_cluster(
    train_instances, n_samples=N_SAMPLES_PER_STRATEGY__TRAIN)
print('sampling test instances ...')
generate_samples__rank_sharing_artist_cluster_above_notsharing_artist_cluster(
    test_instances, n_samples=N_SAMPLES_PER_STRATEGY__TEST)
print(len(train_instances), len(test_instances))
print('hash_collisions = ', _hash_collisions)
print('visual_collisions = ', vissimhandler.count)

sampling train instances ...
  target samples: 1666667
  actual samples: 1254404
  delta: 412263
sampling test instances ...
  target samples: 83333
  actual samples: 63684
  delta: 19649
4550062 236166
hash_collisions =  2034640
visual_collisions =  562


##### 4) Recommend visual similar items from favorite artists (fake 1-item profiles)
Given a fake profile of a single item, other items sharing same artist should be ranked higher than items from different artists as long as the PCA200 embedding agrees

In [43]:
def sample_artwork_index__nonidentical_sharing_artist(i):
    a = artist_ids[i]
    assert a != -1
    candidate_idxs = artistId2artworkIndexes[a]
    assert len(candidate_idxs) >= 2
    while True:
        pi = random.choice(candidate_idxs) # sharing artist
        if pi != i: # non-identical
            return pi

In [44]:
def sample_artwork_index__notsharing_artist__visually_acceptable(i, pi):
    for _ in range(20): # try at most 20 times
        if random.random() <= FINE_GRAINED_THRESHOLD:
            ni = random.choice(clusterId2artworkIndexes[cluster_ids[i]])
        else:
            c = random.randint(0, n_clusters-1)
            ni = random.choice(clusterId2artworkIndexes[c])
        if artist_ids[ni] == artist_ids[i]: # not sharing artist
            continue        
        if vissimhandler.validate_triple(i, pi, ni, margin=VISUAL_CONFIDENCE_THRESHOLD): # visually acceptable
            return ni
    return None


In [45]:
@print_num_samples
def generate_samples__rank_sharing_artist_above_notsharing_artist__visuallyacceptable__single_item(
        instances_container, n_samples):
    
    n_valid_items = sum(1 for i in range(n_artworks) if artist_ids[i] != -1 and\
                        len(artistId2artworkIndexes[artist_ids[i]]) >= 2)
    n_samples_per_item = ceil(n_samples / n_valid_items)
    
    for i in range(n_artworks):
        a = artist_ids[i]
        if a == -1 or len(artistId2artworkIndexes[a]) < 2:
            continue
        profile = (i,)
        for _ in range(n_samples_per_item):
            for __ in range(5):
                pi = sample_artwork_index__nonidentical_sharing_artist(i)
                ni = sample_artwork_index__notsharing_artist__visually_acceptable(i, pi)
                if ni is None:
                    continue
                if append_instance(instances_container, (profile, pi, ni, -1)):
                    break


In [46]:
print('sampling train instances ...')
generate_samples__rank_sharing_artist_above_notsharing_artist__visuallyacceptable__single_item(
    train_instances, n_samples=N_SAMPLES_PER_STRATEGY__TRAIN)
print('sampling test instances ...')
generate_samples__rank_sharing_artist_above_notsharing_artist__visuallyacceptable__single_item(
    test_instances, n_samples=N_SAMPLES_PER_STRATEGY__TEST)
print(len(train_instances), len(test_instances))
print('hash_collisions = ', _hash_collisions)
print('visual_collisions = ', vissimhandler.count)

sampling train instances ...
  target samples: 1666667
  actual samples: 1662072
  delta: 4595
sampling test instances ...
  target samples: 83333
  actual samples: 83077
  delta: 256
6212134 319243
hash_collisions =  2077762
visual_collisions =  562


##### 5) Predict next purchase basket
Given all previous purchases, rank each  item of the next purchase basket higher than any item from a never purchased artist and cluster

In [47]:
@print_num_samples
def generate_samples__given_past_rank_next(instances_container, n_samples):
    
    n_valid_users = sum(1 for user in users if len(user.baskets) >= 2) # at last 2 purchase baskets
    n_samples_per_user = ceil(n_samples / n_valid_users)
    
    for ui, user in enumerate(users):
        n = len(user.baskets)
        if n <= 1:
            continue
        past_items = []        
        n_samples_per_basket = ceil(n_samples_per_user / (n-1))
        for bi in range(n-1):
            cur_b = user.baskets[bi]
            for j in range(cur_b[0], cur_b[0] + cur_b[1]):
                past_items.append(user.artwork_idxs[j])
            next_b  = user.baskets[bi+1]
            profile = past_items.copy()
            for _ in range(n_samples_per_basket):
                for __ in range(5):
                    pi = user.artwork_idxs[random.randint(next_b[0], next_b[0] + next_b[1] - 1)]
                    ni = sample_artwork_index__notsharing_artist_cluster(user.artist_ids_set, user.nonp_cluster_ids)
                    if append_instance(instances_container, (profile, pi, ni, ui),
                                      neg_notsharing_artist_cluster=True,
                                      artists_set=user.artist_ids_set,
                                      clusters_set=user.cluster_ids_set,
                                      ):
                        break

In [48]:
print('sampling train instances ...')
generate_samples__given_past_rank_next(
    train_instances, n_samples=N_SAMPLES_PER_STRATEGY__TRAIN)
print('sampling test instances ...')
generate_samples__given_past_rank_next(
    test_instances, n_samples=N_SAMPLES_PER_STRATEGY__TEST)
print(len(train_instances), len(test_instances))
print('hash_collisions = ', _hash_collisions)
print('visual_collisions = ', vissimhandler.count)

sampling train instances ...
  target samples: 1666667
  actual samples: 1662943
  delta: 3724
sampling test instances ...
  target samples: 83333
  actual samples: 82827
  delta: 506
7875077 402070
hash_collisions =  2519598
visual_collisions =  562


##### 6) Predict hidden item in the k-th purchase basket given first k
Given the first k purchase baskets of a user, hide one item in the k-th purchase basket, use the rest as profile and rank the hidden item higher than any item from a never purchased artist and cluster

In [49]:
@print_num_samples
def generate_samples__hide_and_predict_one_from_last__first_k_purchase_baskets(instances_container, n_samples):
    
    n_valid_baskets_list = [sum(1 for b in user.baskets if b[1] >= 2) for user in users]
    n_valid_users = sum(1 for x in n_valid_baskets_list if x > 0)
    n_samples_per_user = ceil(n_samples / n_valid_users)
    
    for ui, (user, n_valid_baskets) in enumerate(zip(users, n_valid_baskets_list)):
        if n_valid_baskets == 0:
            continue
        n_samples_per_basket = ceil(n_samples_per_user / n_valid_baskets)
        u_artwork_idxs = user.artwork_idxs
        purchased = []
        for b in user.baskets:            
            bs = b[0]
            be = b[0] + b[1]
            purchased.extend(u_artwork_idxs[j] for j in range(bs, be))
            assert len(purchased) == be
            if b[1] < 2:
                continue            
            n_samples_per_item = ceil(n_samples_per_basket / b[1])            
            for i in range(bs, be):
                profile = [purchased[j] for j in range(be) if j != i]
                assert len(profile) == be - 1
                assert len(profile) > 0
                pi = purchased[i]
                for _ in range(n_samples_per_item):
                    for __ in range(5):
                        ni = sample_artwork_index__notsharing_artist_cluster(user.artist_ids_set, user.nonp_cluster_ids)
                        if append_instance(instances_container, (profile, pi, ni, ui),
                                          neg_notsharing_artist_cluster=True,
                                          artists_set=user.artist_ids_set,
                                          clusters_set=user.cluster_ids_set,
                                          ):
                            break

In [50]:
print('sampling train instances ...')
generate_samples__hide_and_predict_one_from_last__first_k_purchase_baskets(
    train_instances, n_samples=N_SAMPLES_PER_STRATEGY__TRAIN)
print('sampling test instances ...')
generate_samples__hide_and_predict_one_from_last__first_k_purchase_baskets(
    test_instances, n_samples=N_SAMPLES_PER_STRATEGY__TEST)
print(len(train_instances), len(test_instances))
print('hash_collisions = ', _hash_collisions)
print('visual_collisions = ', vissimhandler.count)

sampling train instances ...
  target samples: 1666667
  actual samples: 1667575
  delta: -908
sampling test instances ...
  target samples: 83333
  actual samples: 84437
  delta: -1104
9542652 486507
hash_collisions =  2695203
visual_collisions =  562


#### sort train and test instances by profile size

In [51]:
random.shuffle(train_instances)
train_instances.sort(key=lambda x: len(x[0]))
test_instances.sort(key=lambda x: len(x[0]))

### Train CuratorNet

In [52]:
import tensorflow as tf

In [53]:
train_minibatches = generate_minibatches(train_instances, max_users_items_per_batch=5000*10*2)
sanity_check_minibatches(train_minibatches)

n_tuples =  9542652
n_batches =  484


In [54]:
test_minibatches = generate_minibatches(test_instances, max_users_items_per_batch=5000*10*2)
sanity_check_minibatches(test_minibatches)

n_tuples =  486507
n_batches =  25


In [55]:
learning_rates = get_decaying_learning_rates(1e-4, 1e-6, 0.6)
learning_rates

[0.0001,
 6e-05,
 3.6e-05,
 2.16e-05,
 1.296e-05,
 7.776e-06,
 4.6656e-06,
 2.79936e-06,
 1.679616e-06,
 1.0077696e-06]

In [56]:
FINE_GRAINED_THRESHOLD, VISUAL_CONFIDENCE_THRESHOLD

(0.6, 0.1)

In [57]:
MODEL_PATH = 'C:\\Users\\victor_accete\\CuratorNet\\experiments\\curatornet_10m\\'

In [58]:
avg_train_batch_size = ceil(np.mean([b.shape[0] for b in train_minibatches['profile_indexes_batches']]))
avg_train_batch_size

19717

In [60]:
train_network(
    train_minibatches, test_minibatches,
    len(train_instances), len(test_instances),
    batch_size=avg_train_batch_size,
    pretrained_embeddings=concat_featmat,
    user_layer_units=[300,300,200],
    item_layer_units=[200,200],
    profile_pooling_mode='AVG+MAX',
    model_path = MODEL_PATH,
    epochs=10,
    early_stopping_checks=2,
    weight_decay=.0001,
    learning_rates=learning_rates,
)

learning_rates =  [0.0001, 6e-05, 3.6e-05, 2.16e-05, 1.296e-05, 7.776e-06, 4.6656e-06, 2.79936e-06, 1.679616e-06, 1.0077696e-06]


Variables to be trained:
	 <tf.Variable 'trainable_item_embedding/fc1/kernel:0' shape=(2048, 200) dtype=float32_ref>
	 <tf.Variable 'trainable_item_embedding/fc1/bias:0' shape=(200,) dtype=float32_ref>
	 <tf.Variable 'trainable_item_embedding/fc2/kernel:0' shape=(200, 200) dtype=float32_ref>
	 <tf.Variable 'trainable_item_embedding/fc2/bias:0' shape=(200,) dtype=float32_ref>
	 <tf.Variable 'user_hidden_1/kernel:0' shape=(400, 300) dtype=float32_ref>
	 <tf.Variable 'user_hidden_1/bias:0' shape=(300,) dtype=float32_ref>
	 <tf.Variable 'user_hidden_2/kernel:0' shape=(300, 300) dtype=float32_ref>
	 <tf.Variable 'user_hidden_2/bias:0' shape=(300,) dtype=float32_ref>
	 <tf.Variable 'user_vector/kernel:0' shape=(300, 200) dtype=float32_ref>
	 <tf.Variable 'user_vector/bias:0' shape=(200,) dtype=float32_ref>
	 <tf.Variable 'beta1_power:0' shape=() dtype=float32_ref>
	 <tf.Variable 'beta2_power:0' shape=() dtype=float32_ref>
	 <tf.Variable 'trainable_item_embedding/fc1/kernel/Adam:0' shape=(204