# Tutorial on inference with approximate nearest neighbours search

RePlay models such as Word2Vec and ALS can return user or item vectors. These vectors can be used to find similar items for item2item recommendations, or to retrieve relevant items for a user (ALS). The basic way to do so is to calculate the distance (e.g. cosine, negative dot product) between all available vectors and select the top-k, but this calculation is computationally expensive. Another way is to use approximate methods.
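For reference, the exact (brute-force) search can be sketched in a few lines of numpy. Every user is scored against every item with the dot product, and the k best items are kept per user. The k highest dot products are the k smallest negative dot product distances. This is a minimal illustration on random vectors with a hypothetical helper `exact_top_k`, not RePlay's implementation.

```python
import numpy as np


def exact_top_k(user_vectors: np.ndarray, item_vectors: np.ndarray, k: int):
    """Brute-force top-k items per user by dot product (hypothetical helper).

    Returns (indices, scores), each of shape (n_users, k), best item first.
    """
    scores = user_vectors @ item_vectors.T  # (n_users, n_items): every pair is scored
    # argpartition selects the k best items per user without a full sort
    top = np.argpartition(-scores, kth=k - 1, axis=1)[:, :k]
    top_scores = np.take_along_axis(scores, top, axis=1)
    # order the k selected items by score, highest first
    order = np.argsort(-top_scores, axis=1)
    return np.take_along_axis(top, order, axis=1), np.take_along_axis(top_scores, order, axis=1)


rng = np.random.default_rng(0)
users = rng.normal(size=(3, 16))
items = rng.normal(size=(100, 16))
idx, sc = exact_top_k(users, items, k=5)
print(idx)
```

The `users @ items.T` product costs O(n_users · n_items · d), which is what makes exact search expensive at scale.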

There are many approximate nearest neighbours search methods and libraries (see [ANN benchmarks](http://ann-benchmarks.com/)).

We will use one of them, NMSLIB, to build an index and make predictions for selected users.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%config Completer.use_jedi = False

In [3]:
import warnings
from optuna.exceptions import ExperimentalWarning
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=ExperimentalWarning)

In [4]:
from replay.session_handler import State

spark = State().session
spark



In [5]:
spark.sparkContext.setLogLevel("ERROR")

In [6]:
K = 5
SEED = 1234

## 0. Data preprocessing <a name='data-preparator'></a>
We will use MovieLens 20M as an example, as it contains over 20,000 items.

In [49]:
# !pip install rs_datasets

In [8]:
%%time
from rs_datasets import MovieLens
ml = MovieLens('20m')
ml.info()

ratings

|   | user_id | item_id | rating | timestamp |
|--:|--:|--:|--:|--:|
| 0 | 1 | 2 | 3.5 | 1112486027 |
| 1 | 1 | 29 | 3.5 | 1112484676 |
| 2 | 1 | 32 | 3.5 | 1112484819 |

items

|   | item_id | title | genres |
|--:|--:|:--|:--|
| 0 | 1 | Toy Story (1995) | Adventure\|Animation\|Children\|Comedy\|Fantasy |
| 1 | 2 | Jumanji (1995) | Adventure\|Children\|Fantasy |
| 2 | 3 | Grumpier Old Men (1995) | Comedy\|Romance |

tags

|   | user_id | item_id | tag | timestamp |
|--:|--:|--:|:--|--:|
| 0 | 18 | 4141 | Mark Waters | 1240597180 |
| 1 | 65 | 208 | dark hero | 1368150078 |
| 2 | 65 | 353 | dark hero | 1368150079 |

links

|   | item_id | imdb_id | tmdb_id |
|--:|--:|--:|--:|
| 0 | 1 | 114709 | 862.0 |
| 1 | 2 | 113497 | 8844.0 |
| 2 | 3 | 113228 | 15602.0 |


CPU times: user 2min 54s, sys: 2.19 s, total: 2min 56s
Wall time: 18.4 s


### 0.1. DataPreparator

In [9]:
from replay.data_preparator import DataPreparator

log = DataPreparator().transform(
    data=ml.ratings,
    columns_names={
        "user_id": "user_id",
        "item_id": "item_id",
        "relevance": "rating",
        "timestamp": "timestamp"
    }
)

                                                                                

In [10]:
log.show(3)

+-------+-------+---------+-------------------+
|user_id|item_id|relevance|          timestamp|
+-------+-------+---------+-------------------+
|      1|      2|      3.5|2005-04-02 23:53:47|
|      1|     29|      3.5|2005-04-02 23:31:16|
|      1|     32|      3.5|2005-04-02 23:33:39|
+-------+-------+---------+-------------------+
only showing top 3 rows



In [11]:
from replay.utils import get_log_info
get_log_info(log)

                                                                                

'total lines: 20000263, total users: 138493, total items: 26744'

### 0.2. Split

Let's move 2000 users, with K = 5 interactions each, to the test part. This is enough to show the difference in speed between direct prediction and prediction with ANN search.
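The idea of this split can be illustrated with a small pandas sketch. `user_split` is a hypothetical helper written for this tutorial, not RePlay's `UserSplitter`. It assumes that with `shuffle=False` the test part takes each selected user's most recent interactions, and it only mimics `drop_cold_items` (not `drop_cold_users`).

```python
import pandas as pd


def user_split(log: pd.DataFrame, n_users: int, n_items: int, seed: int = 1234):
    """Simplified user-based split (hypothetical, not RePlay's UserSplitter).

    Picks `n_users` random users and moves each one's `n_items` most recent
    interactions to test; test items that never occur in train are dropped.
    """
    test_users = log["user_id"].drop_duplicates().sample(n=n_users, random_state=seed)
    candidates = log[log["user_id"].isin(test_users)]
    # rank each selected user's interactions from newest (1) to oldest
    rank = candidates.groupby("user_id")["timestamp"].rank(method="first", ascending=False)
    test = candidates[rank <= n_items]
    train = log.drop(test.index)
    # drop cold items: test items that never occur in train
    test = test[test["item_id"].isin(train["item_id"])]
    return train, test


log = pd.DataFrame({
    "user_id": [1, 1, 1, 2, 2, 3, 3],
    "item_id": [10, 11, 12, 10, 12, 11, 12],
    "timestamp": [1, 2, 3, 1, 2, 4, 5],
})
train, test = user_split(log, n_users=2, n_items=1)
```

Dropping cold items is why the real test part above has 9996 rows rather than 2000 × 5 = 10000.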

In [42]:
from replay.splitters import UserSplitter

splitter = UserSplitter(
    drop_cold_items=True,
    drop_cold_users=True,
    item_test_size=K,
    user_test_size=2000,
    seed=SEED,
    shuffle=False
)
train, test = splitter.split(log)
(
    train.count(), 
    test.count()
)

                                                                                

(19990263, 9996)

In [44]:
train.write.parquet(path='train_ann.parquet', mode='overwrite')
test.write.parquet(path='test_ann.parquet', mode='overwrite')

                                                                                

## 1. Models training
We will train an ALS model and get direct predictions for two cases:
- with filtering items seen in user history
- without filtering items seen in user history

#### ALS

In [17]:
from replay.models import ALSWrap
from replay.model_handler import save, load

In [45]:
%%time
als = ALSWrap(rank=256, seed=SEED)
als.fit(log=train)

In [48]:
save(als, 'als_20m')

                                                                                

In [50]:
%%time
predict = als.predict(log=train, k=K, users=test.select('user_id').distinct(), filter_seen_items=True)
predict.cache()
predict.count()



CPU times: user 15.9 s, sys: 6.41 s, total: 22.3 s
Wall time: 7min 14s




10000

In [111]:
%%time
predict_seen = als.predict(log=train, k=K, users=test.select('user_id').distinct(), filter_seen_items=False)
predict_seen.cache()
predict_seen.count()



CPU times: user 16.4 s, sys: 5.34 s, total: 21.8 s
Wall time: 6min 49s




10000

Direct prediction takes about 7 minutes, and this time grows quickly with the number of users and items.
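A rough count shows why. Direct prediction scores every (user, item) pair, so the work grows as $n_{users} \times n_{items} \times d$, where $d$ is the ALS rank. Assuming all 26,744 items are scored for the 2000 test users with $d = 256$:

$$
2000 \times 26\,744 \times 256 = 13\,692\,928\,000 \approx 1.4 \times 10^{10} \text{ multiply-adds}
$$

Scoring all 138,493 users instead of 2000 would multiply this by about 69.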

In [112]:
predict_seen.write.parquet(path='./als_pred_seen_to_compare_with_ann.parquet', mode='overwrite')
predict.write.parquet(path='./als_pred_to_compare_with_ann.parquet', mode='overwrite')

                                                                                

## 2. Evaluate prediction quality

In [96]:
from replay.metrics import Coverage, HitRate, NDCG, MAP
from replay.experiment import Experiment

In [98]:
prediction_qualitity = Experiment(test, {NDCG(): K,
                                         MAP(): K,
                                         HitRate(): [1, K],
                                         Coverage(train): K})

                                                                                

In [99]:
%%time
prediction_qualitity.add_result("ALS_pred_with_seen", predict_seen)
prediction_qualitity.add_result("ALS_pred_filter_seen", predict)
prediction_qualitity.results.sort_values('NDCG@5', ascending=False)

                                                                                

CPU times: user 130 ms, sys: 108 ms, total: 238 ms
Wall time: 40.6 s


|   | Coverage@5 | HitRate@1 | HitRate@5 | MAP@5 | NDCG@5 |
|:--|--:|--:|--:|--:|--:|
| ALS_pred_filter_seen | 0.056545 | 0.0645 | 0.2205 | 0.030492 | 0.057042 |
| ALS_pred_with_seen | 0.046298 | 0.0045 | 0.024 | 0.002287 | 0.004903 |


## 3. NMSLIB

Let's get user and item vectors from the ALS model and build an ANN index.

In [102]:
import nmslib
import numpy as np
import pandas as pd

import pyspark.sql.functions as sf
from pyspark.sql import Window

from replay.utils import get_top_k_recs


In [103]:
als = load('./als_20m')

In [104]:
train = spark.read.parquet('train_ann.parquet')
test = spark.read.parquet('test_ann.parquet')
predict = spark.read.parquet('./als_pred_to_compare_with_ann.parquet')
predict_seen = spark.read.parquet('./als_pred_seen_to_compare_with_ann.parquet')

In [105]:
# !pip install nmslib
# !pip install --no-binary :all: nmslib

In [106]:
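def get_numpy_ids_vectors_from_als(id_vector_spark_df, id_name='item_id', vector_col_name='item_factors'):
    """Collect a Spark dataframe of (id, vector) pairs into numpy arrays.

    Returns an object array of vectors (combine it with np.stack) and an array of ids.
    """
    vectors = id_vector_spark_df.toPandas()
    ids = vectors[id_name].to_numpy()
    vectors = vectors[vector_col_name].to_numpy()
    return vectors, ids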

In [107]:
%%time
user_vectors, _ = als.get_features(test.select('user_id').distinct())
user_vectors_np, user_ids_np = get_numpy_ids_vectors_from_als(user_vectors, id_name='user_id', vector_col_name='user_factors')



CPU times: user 5.61 s, sys: 1.69 s, total: 7.3 s
Wall time: 21.9 s




In [108]:
%%time
item_vectors, _ = als.get_features(train.select('item_id').distinct())
item_vectors_np, item_ids_np = get_numpy_ids_vectors_from_als(item_vectors, id_name='item_id', vector_col_name='item_factors')



CPU times: user 1.16 s, sys: 471 ms, total: 1.63 s
Wall time: 16.1 s


                                                                                

Let's build an index on the item vectors and, for every test user vector, search for the nearest items using the negative dot product as the distance measure.

### Index with default parameters
We use the HNSW index from nmslib as is, with its default parameters.

In [109]:
%%time
index = nmslib.init(method='hnsw', space='negdotprod', data_type=nmslib.DataType.DENSE_VECTOR)
index.addDataPointBatch(data=np.stack(item_vectors_np))
index.createIndex()

CPU times: user 11.2 s, sys: 946 ms, total: 12.1 s
Wall time: 671 ms


In [110]:
def get_neighbours(index, vectors, user_ids_list, item_ids_list, k):
    """
    - find nearest items based on user vector
    - convert to spark and process columns to get `user_id, item_id, relevance` columns
    - replace item numbers in index with item ids
    """
    neighbours = index.knnQueryBatch(np.stack(vectors), k=k)
    pd_res = pd.DataFrame(neighbours, columns=['item_idx', 'distance'])
    pd_res['user_id'] = user_ids_list
    spark_res = spark.createDataFrame(pd_res)
    spark_res = spark_res.withColumn('zip_exp', sf.explode(sf.arrays_zip('item_idx', 'distance'))).select('user_id', 'zip_exp')
    spark_res = spark_res.withColumn('item_idx', sf.col('zip_exp.item_idx'))
    spark_res = spark_res.withColumn('distance', sf.col('zip_exp.distance'))
    spark_res = (spark_res.withColumn('relevance',  sf.lit(-1.) * sf.col('distance'))
                 .select('user_id', 'item_idx', 'relevance')
                )
    ids_mapping = spark.createDataFrame(list(zip(range(len(list(item_ids_list))),
                                                 list(item_ids_list))),
                                        schema="item_idx int, item_id string")
    spark_res = spark_res.join(ids_mapping, on='item_idx').drop('item_idx').orderBy('user_id')
    return spark_res
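The reshaping that `get_neighbours` does in Spark can be sketched with pandas alone, which may help when checking results on a small sample. The sketch assumes, as the cell above does, that `knnQueryBatch` returns one `(index_positions, distances)` pair per query vector in query order. It also assumes that the `negdotprod` distance is the negated dot product, so `relevance = -distance`. The helper name `neighbours_to_long` is hypothetical.

```python
import numpy as np
import pandas as pd


def neighbours_to_long(neighbours, user_ids, item_ids):
    """Turn knnQueryBatch-style output into a (user_id, item_id, relevance) table.

    `neighbours` holds one (index_positions, distances) pair per query, in the
    same order as `user_ids`; index positions are mapped back to `item_ids`.
    """
    item_ids = np.asarray(item_ids)
    frames = []
    for user_id, (positions, distances) in zip(user_ids, neighbours):
        frames.append(pd.DataFrame({
            "user_id": user_id,
            "item_id": item_ids[positions],
            # negdotprod distance is assumed to be -<user, item>, so relevance = -distance
            "relevance": -np.asarray(distances),
        }))
    return pd.concat(frames, ignore_index=True)


neighbours = [(np.array([2, 0]), np.array([-3.0, -1.0])), (np.array([1]), np.array([-0.5]))]
df = neighbours_to_long(neighbours, ["u1", "u2"], ["i0", "i1", "i2"])
```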

In [111]:
def filter_seen(log, pred, k):
    """
    Filter out items seen in `log` and keep the top-k most relevant items per user.
    """
    # number of items each user has already interacted with
    num_of_seen = log.groupBy("user_id").agg(sf.count("item_id").alias("seen_count"))
    max_seen = num_of_seen.select(sf.max("seen_count")).collect()[0][0]

    # rank candidates by relevance and drop those that can never reach the top-k
    recs = pred.withColumn(
        "temp_rank",
        sf.row_number().over(
            Window.partitionBy("user_id").orderBy(sf.col("relevance").desc())
        ),
    ).filter(sf.col("temp_rank") <= sf.lit(max_seen + k))

    # a user with n seen items needs at most n + k candidates
    recs = (
        recs.join(num_of_seen, on="user_id", how="left")
        .fillna(0)
        .filter(sf.col("temp_rank") <= sf.col("seen_count") + sf.lit(k))
        .drop("temp_rank", "seen_count")
    )

    # remove seen items and keep the k most relevant of the remaining ones
    recs = recs.join(log, on=["user_id", "item_id"], how="anti")
    return get_top_k_recs(recs, k)

In [112]:
max_items = train.groupBy('user_id').agg(sf.count('item_id').alias('num_items')).select(sf.max('num_items')).collect()[0][0]
max_items

                                                                                

9254

The largest number of seen items per user is 9254. This is a lot: to filter seen items reliably, we need to retrieve 9254 + K items for every user. There is a tradeoff between index construction and retrieval time, index size, and retrieval quality. Read about algorithm parameters and tuning [here](https://github.com/nmslib/nmslib/blob/master/manual/methods.md).
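The over-fetch-then-filter idea behind `filter_seen` can be sketched in plain Python for a single user. If the search returns the `n_seen + K` most relevant items, at most `n_seen` of them are already seen, so at least `K` unseen items remain. `top_k_unseen` is a hypothetical helper for illustration only.

```python
def top_k_unseen(ranked_items, seen, k):
    """Keep the first k items of a relevance-ordered candidate list that were not seen.

    `ranked_items` should hold at least len(seen) + k candidates; otherwise fewer
    than k items may remain after filtering.
    """
    return [item for item in ranked_items if item not in seen][:k]


ranked = ["a", "b", "c", "d", "e", "f"]  # candidates, most relevant first
seen = {"a", "c", "x"}                   # 3 seen items, one of them not among the candidates
print(top_k_unseen(ranked[: len(seen) + 2], seen, k=2))  # prints ['b', 'd']
```

With ANN search some of the true top items can be missed in such a deep list, so the filtered setting depends on how accurately the index retrieves all 9254 + K neighbours.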

In [113]:
ann_res_k = get_neighbours(index, user_vectors_np, user_ids_np, item_ids_np, K)

In [114]:
%%time
ann_res_k.cache()
ann_res_k.count()



CPU times: user 10.9 ms, sys: 13.8 ms, total: 24.7 ms
Wall time: 1.92 s


                                                                                

10000

In [115]:
ann_res = get_neighbours(index, user_vectors_np, user_ids_np, item_ids_np, K + max_items)
ann_res = filter_seen(train, ann_res, K)

In [116]:
%%time
ann_res.cache()
ann_res.count()



CPU times: user 81.2 ms, sys: 32.2 ms, total: 113 ms
Wall time: 26.2 s


                                                                                

10000

In [117]:
index.saveIndex('index')

In [118]:
! du -sh ./index

31M	./index


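Prediction with ANN is significantly faster: computing the results takes about 2 seconds for plain top-K and about 26 seconds with seen-items filtering (see the timings above), compared with about 7 minutes for direct prediction.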

#### Index precision evaluation

Let's compare the ANN results with the ground truth (exact model predictions) in two settings:
- selecting top-K (no filtering of seen items);
- selecting top-K plus the maximal number of seen items, then post-filtering.

We can expect better quality in the first setting, since the index then only has to find the K nearest items rather than a much longer list.
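Since the exact ALS predictions serve as ground truth and both lists hold K items per user, Precision@K here measures the share of the exact top-K that the index also found. This is often called the recall of the ANN search. Below is a minimal plain-Python sketch of that quantity; `ann_precision_at_k` is hypothetical, not RePlay's `Precision` implementation.

```python
def ann_precision_at_k(exact, approx, k):
    """Average share of each user's exact top-k items that the ANN search also returned.

    `exact` and `approx` map a user id to a list of item ids ordered by relevance.
    """
    shares = [
        len(set(exact[user][:k]) & set(approx.get(user, [])[:k])) / k
        for user in exact
    ]
    return sum(shares) / len(shares)


exact = {1: [10, 11, 12], 2: [20, 21, 22]}   # exact model predictions
approx = {1: [10, 12, 13], 2: [20, 21, 22]}  # ANN results, one miss for user 1
print(ann_precision_at_k(exact, approx, k=3))
```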

In [119]:
from replay.metrics import Precision
from replay.experiment import Experiment

In [120]:
metrics_filter_seen = Experiment(predict, {Precision(): K})
metrics_with_seen = Experiment(predict_seen, {Precision(): K})

In [121]:
%%time
metrics_with_seen.add_result("Default_HNSW", ann_res_k)
metrics_with_seen.results



CPU times: user 46.2 ms, sys: 7.8 ms, total: 54 ms
Wall time: 5.71 s


                                                                                

|   | Precision@5 |
|:--|--:|
| Default_HNSW | 0.9516 |


In [122]:
metrics_filter_seen.add_result("Default_HNSW", ann_res)
metrics_filter_seen.results

                                                                                

|   | Precision@5 |
|:--|--:|
| Default_HNSW | 0.8859 |


In [123]:
%%time
prediction_qualitity.add_result("HNSW_pred_with_seen", ann_res_k)
prediction_qualitity.add_result("HNSW_pred_filter_seen", ann_res)
prediction_qualitity.results.sort_values('NDCG@5', ascending=False)

                                                                                

CPU times: user 201 ms, sys: 49.9 ms, total: 251 ms
Wall time: 18.4 s


|   | Coverage@5 | HitRate@1 | HitRate@5 | MAP@5 | NDCG@5 |
|:--|--:|--:|--:|--:|--:|
| ALS_pred_filter_seen | 0.056545 | 0.0645 | 0.2205 | 0.030492 | 0.057042 |
| HNSW_pred_filter_seen | 0.057105 | 0.061 | 0.213 | 0.029632 | 0.055304 |
| ALS_pred_with_seen | 0.046298 | 0.0045 | 0.024 | 0.002287 | 0.004903 |
| HNSW_pred_with_seen | 0.048467 | 0.005 | 0.0235 | 0.00218 | 0.004717 |


### Index with more complex parameters

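Let's increase the parameters to improve retrieval quality:
- `M` defines the maximum number of neighbours of a node in the zero and above-zero layers of the graph;
- `efConstruction` and `efSearch` define how many candidate vertices are tracked on each level while building the index and while searching it; larger values improve quality at the cost of longer build and query time.

See an example of index building with nmslib [here](https://github.com/nmslib/nmslib/blob/master/python_bindings/notebooks/search_vector_dense_optim.ipynb).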
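To make the roles of `M` and `ef` more concrete, here is a minimal single-layer sketch of the best-first search that HNSW runs on each level. It is not nmslib's implementation. The graph is a plain nearest-neighbour graph plus a ring, standing in for HNSW's multi-layer construction. `m` plays the role of `M`, and `ef` is the size of the dynamic candidate list. Squared Euclidean distance replaces `negdotprod` for simplicity. The function names are hypothetical.

```python
import heapq

import numpy as np


def build_graph(data, m):
    """Link every point to its m nearest points plus the next point in a ring, so every
    node is reachable (hypothetical simplification of HNSW's layered graph)."""
    n = len(data)
    sq = ((data[:, None, :] - data[None, :, :]) ** 2).sum(axis=-1)
    np.fill_diagonal(sq, np.inf)  # no self-loops
    return [sorted(set(np.argsort(row)[:m].tolist()) | {(i + 1) % n}) for i, row in enumerate(sq)]


def search_layer(data, graph, query, k, ef, entry=0):
    """Best-first search on one graph layer, keeping at most `ef` nodes found so far."""
    def dist(i):
        return float(((data[i] - query) ** 2).sum())

    d0 = dist(entry)
    visited = {entry}
    candidates = [(d0, entry)]  # min-heap: closest unexplored node first
    found = [(-d0, entry)]      # max-heap via negation: farthest kept node on top
    while candidates:
        d, node = heapq.heappop(candidates)
        if d > -found[0][0]:
            break  # the closest candidate is farther than everything kept: stop
        for nb in graph[node]:
            if nb in visited:
                continue
            visited.add(nb)
            d_nb = dist(nb)
            if len(found) < ef or d_nb < -found[0][0]:
                heapq.heappush(candidates, (d_nb, nb))
                heapq.heappush(found, (-d_nb, nb))
                if len(found) > ef:
                    heapq.heappop(found)  # drop the farthest kept node
    return [i for _, i in sorted((-neg_d, i) for neg_d, i in found)[:k]]


rng = np.random.default_rng(1)
data = rng.normal(size=(200, 8))
graph = build_graph(data, m=8)
query = rng.normal(size=8)
print(search_layer(data, graph, query, k=5, ef=10))   # small ef: fast, may miss neighbours
print(search_layer(data, graph, query, k=5, ef=200))  # ef = n: visits every node, exact
```

With `ef` equal to the number of points, the search visits every reachable node and becomes exact. Smaller values stop earlier, which is the speed/quality tradeoff that `efSearch` controls.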

In [124]:
%%time
index = nmslib.init(method='hnsw', space='negdotprod', data_type=nmslib.DataType.DENSE_VECTOR)
index.addDataPointBatch(data=np.stack(item_vectors_np))
index.createIndex(index_params={'efConstruction':1000, 'M':64})

CPU times: user 1min 2s, sys: 2.32 s, total: 1min 4s
Wall time: 2.84 s


In [125]:
query_time_params = {'efSearch': 1000}
index.setQueryTimeParams(query_time_params)

In [126]:
%%time
ann_res_k = get_neighbours(index, user_vectors_np, user_ids_np, item_ids_np, K)

CPU times: user 2.42 s, sys: 0 ns, total: 2.42 s
Wall time: 257 ms


In [127]:
%%time
ann_res = get_neighbours(index, user_vectors_np, user_ids_np, item_ids_np, K + max_items)
ann_res = filter_seen(train, ann_res, K)

CPU times: user 3.23 s, sys: 229 ms, total: 3.46 s
Wall time: 29.2 s


In [128]:
%%time
ann_res.cache()
ann_res.count()



CPU times: user 88.8 ms, sys: 25.5 ms, total: 114 ms
Wall time: 41.7 s


                                                                                

10000

In [129]:
index.saveIndex('big_index')

In [130]:
! du -sh ./big_index

40M	./big_index


#### Index precision evaluation

In [131]:
%%time
metrics_with_seen.add_result("tuned_HNSW", ann_res_k)
metrics_with_seen.results



CPU times: user 31.7 ms, sys: 44.1 ms, total: 75.8 ms
Wall time: 7.39 s


                                                                                

|   | Precision@5 |
|:--|--:|
| Default_HNSW | 0.9516 |
| tuned_HNSW | 1.0 |


In [132]:
metrics_filter_seen.add_result("tuned_HNSW", ann_res)
metrics_filter_seen.results

                                                                                

|   | Precision@5 |
|:--|--:|
| Default_HNSW | 0.8859 |
| tuned_HNSW | 1.0 |


In [133]:
%%time
prediction_qualitity.add_result("tuned_HNSW_pred_with_seen", ann_res_k)
prediction_qualitity.add_result("tuned_HNSW_pred_filter_seen", ann_res)
prediction_qualitity.results.sort_values('NDCG@5', ascending=False)

                                                                                

CPU times: user 228 ms, sys: 125 ms, total: 353 ms
Wall time: 35 s


|   | Coverage@5 | HitRate@1 | HitRate@5 | MAP@5 | NDCG@5 |
|:--|--:|--:|--:|--:|--:|
| ALS_pred_filter_seen | 0.056545 | 0.0645 | 0.2205 | 0.030492 | 0.057042 |
| tuned_HNSW_pred_filter_seen | 0.056545 | 0.0645 | 0.2205 | 0.030492 | 0.057042 |
| HNSW_pred_filter_seen | 0.057105 | 0.061 | 0.213 | 0.029632 | 0.055304 |
| ALS_pred_with_seen | 0.046298 | 0.0045 | 0.024 | 0.002287 | 0.004903 |
| tuned_HNSW_pred_with_seen | 0.046298 | 0.0045 | 0.024 | 0.002287 | 0.004903 |
| HNSW_pred_with_seen | 0.048467 | 0.005 | 0.0235 | 0.00218 | 0.004717 |


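With the tuned parameters, index quality increased to precision = 1 in both settings, so the ANN results match the exact ALS predictions. The cost is a larger index (40M vs 31M) and a longer index construction (2.84 s vs 671 ms wall time).
You can try hnsw, faiss, annoy or any other ANN library to build your index and get faster predictions for vector models.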