diff --git a/notebooks/02_model/ripplenet_deep_dive.ipynb b/notebooks/02_model/ripplenet_deep_dive.ipynb
new file mode 100644
index 0000000000..e497067307
--- /dev/null
+++ b/notebooks/02_model/ripplenet_deep_dive.ipynb
@@ -0,0 +1,1177 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# RippleNet on MovieLens using Wikidata (Python, GPU)¶"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In this example, we will walk through each step of the [RippleNet](https://arxiv.org/pdf/1803.03467.pdf) algorithm.\n",
+ "RippleNet is an end-to-end framework that naturally incorporates knowledge graphs into recommender systems.\n",
+ "To make the results of the paper reproducible we have used MovieLens as our dataset and Wikidata as our Knowledge Graph.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Introduction"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To address the sparsity and cold start problem of collaborative filtering, researchers usually make use of side information, such as social networks or item attributes, to improve recommendation performance. This paper considers the knowledge graph as the source of side information. To address the limitations of existing embedding-based and path-based methods for knowledge-graph-aware recommendation, we propose RippleNet, an end-to-end framework that naturally incorporates the knowledge graph into recommender systems. Similar to actual ripples propagating on the water, RippleNet stimulates the propagation of user preferences over the set of knowledge entities by automatically and iteratively extending a user’s potential interests along links in the knowledge graph. The multiple \"ripples\" activated by a user’s historically clicked items are thus superposed to form the preference distribution of the user with respect to a candidate item, which could be used for predicting the final clicking probability. Through extensive experiments on real-world datasets, we demonstrate that RippleNet achieves substantial gains in a variety of scenarios, including movie, book and news recommendation, over several state-of-the-art baselines.\n",
+ "\n",
+ "![alt text](https://github.com/hwwang55/RippleNet/raw/master/framework.jpg)\n",
+ "\n",
+ "The overall framework of the RippleNet. It takes one user and one item as input, and outputs the predicted probability that the user will click the item. The KGs in the upper part illustrate the corresponding ripple sets activated by the user’s click history."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Implementation\n",
+ "Details of the python implementation can be found [here](../../reco_utils/recommender/ripplenet). The implementation is based on the original code of RippleNet: https://github.com/hwwang55/RippleNet"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## RippleNet Movie Recommender"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "System version: 3.6.7 | packaged by conda-forge | (default, Nov 21 2018, 03:09:43) \n",
+ "[GCC 7.3.0]\n",
+ "Pandas version: 0.23.4\n",
+ "Tensorflow version: 1.12.0\n"
+ ]
+ }
+ ],
+ "source": [
+ "import sys\n",
+ "sys.path.append(\"../../\")\n",
+ "import pandas as pd\n",
+ "import numpy as np\n",
+ "import tensorflow as tf\n",
+ "import os\n",
+ "import papermill as pm\n",
+ "\n",
+ "from reco_utils.common.timer import Timer\n",
+ "from reco_utils.dataset import movielens\n",
+ "from reco_utils.dataset.python_splitters import python_stratified_split\n",
+ "from reco_utils.recommender.ripplenet.preprocess import (read_item_index_to_entity_id_file, \n",
+ " convert_rating, \n",
+ " convert_kg)\n",
+ "from reco_utils.recommender.ripplenet.data_loader import load_kg, get_ripple_set\n",
+ "from reco_utils.recommender.ripplenet.model import RippleNet\n",
+ "from reco_utils.evaluation.python_evaluation import auc, precision_at_k, recall_at_k\n",
+ "\n",
+ "print(\"System version: {}\".format(sys.version))\n",
+ "print(\"Pandas version: {}\".format(pd.__version__))\n",
+ "print(\"Tensorflow version: {}\".format(tf.__version__))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "tags": [
+ "parameters"
+ ]
+ },
+ "outputs": [],
+ "source": [
+ "# Select MovieLens data size: 100k, 1M, 10M\n",
+ "MOVIELENS_DATA_SIZE = '100k'\n",
+ "rating_threshold = 4 #Minimum rating of a movie to be considered positive\n",
+ "remove_negative_ratings = True #Items rated below the threshold will be removed from train and test \n",
+ "\n",
+ "# Ripple parameters\n",
+ "n_epoch = 10 #the number of epochs\n",
+ "batch_size = 1024 #batch size\n",
+ "dim = 16 #dimension of entity and relation embeddings\n",
+ "n_hop = 2 #maximum hops\n",
+ "kge_weight = 0.01 #weight of the KGE term\n",
+ "l2_weight = 1e-7 #weight of the l2 regularization term\n",
+ "lr = 0.02 #learning rate\n",
+ "n_memory = 32 #size of ripple set for each hop\n",
+ "item_update_mode = 'plus_transform' #how to update item at the end of each hop. \n",
+ " #possible options are replace, plus, plus_transform or replace transform\n",
+ "using_all_hops = True #whether using outputs of all hops or just the last hop when making prediction\n",
+ "optimizer_method = \"adam\" #optimizer method from adam, adadelta, adagrad, ftrl (FtrlOptimizer),\n",
+ " #gd (GradientDescentOptimizer), rmsprop (RMSPropOptimizer)\n",
+ "show_loss = False #whether or not to show the loss\n",
+ "seed = 12\n",
+ "\n",
+ "#Evaluation parameters\n",
+ "TOP_K = 10\n",
+ "remove_seen = True"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Read original data and transform entity ids to numerical"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "RippleNet is built on:\n",
+ "- Ratings from users on Movies\n",
+ "- Knowledge Graph (KG) linking Movies to their connected entities in Wikidata. See [this notebook](../01_prepare_data/wikidata_knowledge_graph.ipynb) to understand better how the knowledge graph was created."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 4.81k/4.81k [00:01<00:00, 4.52kKB/s]\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " UserId | \n",
+ " ItemId | \n",
+ " Rating | \n",
+ " Timestamp | \n",
+ " Title | \n",
+ " Genres | \n",
+ " Year | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 196 | \n",
+ " 242 | \n",
+ " 3.0 | \n",
+ " 881250949 | \n",
+ " Kolya (1996) | \n",
+ " Comedy | \n",
+ " 1996 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 63 | \n",
+ " 242 | \n",
+ " 3.0 | \n",
+ " 875747190 | \n",
+ " Kolya (1996) | \n",
+ " Comedy | \n",
+ " 1996 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 226 | \n",
+ " 242 | \n",
+ " 5.0 | \n",
+ " 883888671 | \n",
+ " Kolya (1996) | \n",
+ " Comedy | \n",
+ " 1996 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " UserId ItemId Rating Timestamp Title Genres Year\n",
+ "0 196 242 3.0 881250949 Kolya (1996) Comedy 1996\n",
+ "1 63 242 3.0 875747190 Kolya (1996) Comedy 1996\n",
+ "2 226 242 5.0 883888671 Kolya (1996) Comedy 1996"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "ratings_original = movielens.load_pandas_df(MOVIELENS_DATA_SIZE,\n",
+ " ('UserId', 'ItemId', 'Rating', 'Timestamp'),\n",
+ " title_col='Title',\n",
+ " genres_col='Genres',\n",
+ " year_col='Year')\n",
+ "ratings_original.head(3)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " original_entity | \n",
+ " linked_entities | \n",
+ " name_linked_entities | \n",
+ " movielens_title | \n",
+ " movielens_id | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " Q1141186 | \n",
+ " Q130232 | \n",
+ " drama film | \n",
+ " Kolya (1996) | \n",
+ " 242 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " Q1141186 | \n",
+ " Q157443 | \n",
+ " comedy film | \n",
+ " Kolya (1996) | \n",
+ " 242 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " Q1141186 | \n",
+ " Q10819887 | \n",
+ " Andrei Chalimon | \n",
+ " Kolya (1996) | \n",
+ " 242 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " original_entity linked_entities name_linked_entities movielens_title \\\n",
+ "0 Q1141186 Q130232 drama film Kolya (1996) \n",
+ "1 Q1141186 Q157443 comedy film Kolya (1996) \n",
+ "2 Q1141186 Q10819887 Andrei Chalimon Kolya (1996) \n",
+ "\n",
+ " movielens_id \n",
+ "0 242 \n",
+ "1 242 \n",
+ "2 242 "
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "kg_original = pd.read_csv(\"https://recodatasets.blob.core.windows.net/wikidata/movielens_{}_wikidata.csv\".format(MOVIELENS_DATA_SIZE))\n",
+ "kg_original.head(3)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To be able to link the Ratings and KG ids we create two dictionaries match the KG original IDs to homogeneous numerical IDs. This will be done in two steps:\n",
+ "1. Transforming both Rating ID and KG ID to numerical\n",
+ "2. Matching the IDs using a dictionary"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def transform_id(df, entities_id, col_transform, col_name = \"unified_id\"):\n",
+ " df = df.merge(entities_id, left_on = col_transform, right_on = \"entity\")\n",
+ " df = df.rename(columns = {\"unified_id\": col_name})\n",
+ " return df.drop(columns = [col_transform, \"entity\"])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " unified_id | \n",
+ " entity | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 0 | \n",
+ " Q607910 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 1 | \n",
+ " Q657259 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 2 | \n",
+ " Q491185 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " unified_id entity\n",
+ "0 0 Q607910\n",
+ "1 1 Q657259\n",
+ "2 2 Q491185"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Create Dictionary that matches KG Wikidata ID to internal numerical KG ID\n",
+ "entities_id = pd.DataFrame({\"entity\":list(set(kg_original.original_entity)) + list(set(kg_original.linked_entities))}).reset_index()\n",
+ "entities_id = entities_id.rename(columns = {\"index\": \"unified_id\"})\n",
+ "entities_id.head(3)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " original_entity_id | \n",
+ " relation | \n",
+ " linked_entities_id | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 1177 | \n",
+ " 1 | \n",
+ " 15580 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 16107 | \n",
+ " 1 | \n",
+ " 15580 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 1278 | \n",
+ " 1 | \n",
+ " 15580 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " original_entity_id relation linked_entities_id\n",
+ "0 1177 1 15580\n",
+ "1 16107 1 15580\n",
+ "2 1278 1 15580"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Tranforming KG IDs to internal numerical KG IDs created above \n",
+ "kg = kg_original[[\"original_entity\", \"linked_entities\"]].drop_duplicates()\n",
+ "kg = transform_id(kg, entities_id, \"original_entity\", \"original_entity_id\")\n",
+ "kg = transform_id(kg, entities_id, \"linked_entities\", \"linked_entities_id\")\n",
+ "kg[\"relation\"] = 1\n",
+ "kg_wikidata = kg[[\"original_entity_id\",\"relation\", \"linked_entities_id\"]]\n",
+ "kg_wikidata.head(3)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " movielens_id | \n",
+ " unified_id | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 242 | \n",
+ " 1177 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 242 | \n",
+ " 16107 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 302 | \n",
+ " 1278 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " movielens_id unified_id\n",
+ "0 242 1177\n",
+ "1 242 16107\n",
+ "2 302 1278"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Create Dictionary matching Movielens ID to internal numerical KG ID created above\n",
+ "var_id = \"movielens_id\"\n",
+ "item_to_entity = kg_original[[var_id, \"original_entity\"]].drop_duplicates().reset_index().drop(columns = \"index\")\n",
+ "item_to_entity = transform_id(item_to_entity, entities_id, \"original_entity\")\n",
+ "item_to_entity.head(3)"
+ ]
+ },
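+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a quick sanity check of the linkage described above, we can look up one of the sample movies (*Kolya (1996)*, ItemId 242, Wikidata entity `Q1141186`) in both mappings. This snippet is purely illustrative:\n",
+ "```python\n",
+ "# MovieLens ItemId -> Wikidata QID -> unified numerical ID\n",
+ "print(entities_id[entities_id[\"entity\"] == \"Q1141186\"])\n",
+ "print(item_to_entity[item_to_entity[\"movielens_id\"] == 242])\n",
+ "```"
+ ]
+ },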
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "vars_movielens = [\"UserId\", \"ItemId\", \"Rating\", \"Timestamp\"]\n",
+ "ratings = ratings_original[vars_movielens].sort_values(vars_movielens[1])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Preprocess module from RippleNet"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " The dictionaries created above will be used on the Ratings and KG dataframes and unify their IDs. Also the Ratings will be converted from a numerical rating (1-5) to a binary rating (0-1) using the rating_threshold"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Use dictionary Movielens ID - numerical KG ID to extract two dictionaries to be used on Ratings and KG\n",
+ "item_index_old2new, entity_id2index = read_item_index_to_entity_id_file(item_to_entity)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In the original paper, items are divided into those rated and above the threshold marked as 1, and those unwatched marked as 0. Items watched with a rating below the threshold are removed from train and test:\n",
+ "\n",
+ "> Since MovieLens-1M and Book-Crossing are explicit feedback data, we transform them into implicit feedback where each entry is marked with 1 indicating that the user has rated the item (the threshold of rating is 4 for MovieLens-1M, while no threshold is set for Book-Crossing due to its sparsity), and sample an unwatched set marked as 0 for each user, which is of equal size with the rated ones.\n",
+ "\n",
+ "We have added a param with the option to keep or remove the items watched and rated below the threshold marked as 0, *remove_negative_ratings*"
+ ]
+ },
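+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Schematically, the per-user labeling performed by `convert_rating` can be sketched as follows. The helper name `label_user_ratings` is purely illustrative, not part of reco_utils; the actual implementation lives in `reco_utils/recommender/ripplenet/preprocess.py`:\n",
+ "```python\n",
+ "import numpy as np\n",
+ "\n",
+ "def label_user_ratings(watched, all_items, threshold=4, remove_negative_ratings=True, seed=12):\n",
+ "    \"\"\"Illustrative sketch of the per-user labeling done by convert_rating.\"\"\"\n",
+ "    positives = {i for i, r in watched.items() if r >= threshold}  # label 1\n",
+ "    low_rated = {i for i, r in watched.items() if r < threshold}\n",
+ "    candidates = set(all_items) - positives\n",
+ "    if remove_negative_ratings:\n",
+ "        candidates -= low_rated  # drop low-rated items from the negative pool\n",
+ "    np.random.seed(seed)\n",
+ "    # sample unwatched items as negatives, equal in number to the positives\n",
+ "    sampled = np.random.choice(sorted(candidates), size=len(positives), replace=False)\n",
+ "    labels = {i: 1 for i in positives}\n",
+ "    labels.update({i: 0 for i in sampled})\n",
+ "    return labels\n",
+ "```"
+ ]
+ },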
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:reco_utils.recommender.ripplenet.preprocess:converting rating file ...\n",
+ "INFO:reco_utils.recommender.ripplenet.preprocess:number of users: 942\n",
+ "INFO:reco_utils.recommender.ripplenet.preprocess:number of items: 1677\n"
+ ]
+ }
+ ],
+ "source": [
+ "ratings_final = convert_rating(ratings, item_index_old2new = item_index_old2new,\n",
+ " threshold = rating_threshold,\n",
+ " remove_negative_ratings=remove_negative_ratings,\n",
+ " seed = 12)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:reco_utils.recommender.ripplenet.preprocess:converting kg file ...\n",
+ "INFO:reco_utils.recommender.ripplenet.preprocess:number of entities (containing items): 22994\n",
+ "INFO:reco_utils.recommender.ripplenet.preprocess:number of relations: 1\n"
+ ]
+ }
+ ],
+ "source": [
+ "kg_final = convert_kg(kg_wikidata, entity_id2index = entity_id2index)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Split Data and Build RippleSet"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The data is divided into train, test and evaluation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "train_data, test_data, eval_data = python_stratified_split(ratings_final, ratio=[0.6, 0.2, 0.2], col_user='user_index', col_item='item', seed=12)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " item | \n",
+ " original_rating | \n",
+ " rating | \n",
+ " user_index | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 129 | \n",
+ " 3281 | \n",
+ " 0.0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 231 | \n",
+ " 1407 | \n",
+ " 0.0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 52 | \n",
+ " 461 | \n",
+ " 4.0 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 229 | \n",
+ " 3273 | \n",
+ " 0.0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 250 | \n",
+ " 2007 | \n",
+ " 0.0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " item original_rating rating user_index\n",
+ "129 3281 0.0 0 0\n",
+ "231 1407 0.0 0 0\n",
+ "52 461 4.0 1 0\n",
+ "229 3273 0.0 0 0\n",
+ "250 2007 0.0 0 0"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "train_data.head()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The original KG dataframe is transformed into a dictionary, and the number of entities and relations extracted as parameters"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:reco_utils.recommender.ripplenet.data_loader:reading KG file ...\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of entities: 22908\n",
+ "Number of relations: 1\n"
+ ]
+ }
+ ],
+ "source": [
+ "n_entity, n_relation, kg = load_kg(kg_final)\n",
+ "print(\"Number of entities:\", n_entity)\n",
+ "print(\"Number of relations:\", n_relation)"
+ ]
+ },
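+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The resulting `kg` object is an adjacency-list style dictionary, mapping each head entity to its list of `(tail, relation)` tuples. A quick, purely illustrative inspection:\n",
+ "```python\n",
+ "# head entity id -> [(tail entity id, relation id), ...]\n",
+ "some_head = next(iter(kg))\n",
+ "print(some_head, kg[some_head][:3])\n",
+ "```"
+ ]
+ },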
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The rippleset dictionary is built on the positive ratings (relevant entities) of the training data, and using the KG to build set of knowledge triples per user positive rating, from 0 until `n_hop`.\n",
+ "\n",
+ "**Relevant entity**: Given interaction matrix Y and knowledge graph G, the set of k-hop relevant entities for user u is defined as\n",
+ "\n",
+ "$$E^{k}_{u} = \\{t\\ |\\ (h,r,t) ∈ G\\ and\\ h ∈ E^{k−1}_{u}\\}, k=1,2,...,H$$\n",
+ "\n",
+ "Where $E_{u} = V_{u} = \\{v|yuv =1\\}$ is the set of user’s clicked items in the past, which can be seen as the seed set of user $u$ in KG\n",
+ "\n",
+ "**RippleSet**: The k-hop rippleset of user $u$ is defined as the set of knowledge triples starting from $E_{k−1}$:\n",
+ "\n",
+ "$$S^{k}_{u} = \\{(h,r,t)\\ |\\ (h,r,t) ∈ G\\ and\\ h ∈ E^{k−1}_{u}\\}, k = 1,2,...,H$$"
+ ]
+ },
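+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A minimal sketch of how the hops expand for a single user, using a toy KG in the same adjacency-list form consumed by `get_ripple_set`. All IDs below are made up for illustration, and the fixed-size sampling to `n_memory` is omitted:\n",
+ "```python\n",
+ "import collections\n",
+ "\n",
+ "# toy KG: head -> [(tail, relation), ...]\n",
+ "toy_kg = collections.defaultdict(list, {\n",
+ "    0: [(10, 0), (11, 0)],  # e.g. a movie linked to its genre and its director\n",
+ "    10: [(20, 0)],\n",
+ "    11: [(21, 0), (22, 0)],\n",
+ "})\n",
+ "\n",
+ "tails_of_last_hop = [0]  # E^0_u: the user's clicked items (seed set)\n",
+ "for hop in range(2):     # n_hop = 2\n",
+ "    triples = [(h, r, t) for h in tails_of_last_hop for (t, r) in toy_kg[h]]\n",
+ "    print(\"hop\", hop, \"->\", triples)  # S^{hop+1}_u before sampling to n_memory\n",
+ "    tails_of_last_hop = [t for (_, _, t) in triples]\n",
+ "```"
+ ]
+ },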
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:reco_utils.recommender.ripplenet.data_loader:constructing ripple set ...\n"
+ ]
+ }
+ ],
+ "source": [
+ "user_history_dict = train_data.loc[train_data.rating == 1].groupby('user_index')['item'].apply(list).to_dict()\n",
+ "ripple_set = get_ripple_set(kg, user_history_dict, n_hop=n_hop, n_memory=n_memory)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Build model and predict"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ripple = RippleNet(dim=dim,\n",
+ " n_hop=n_hop,\n",
+ " kge_weight=kge_weight, \n",
+ " l2_weight=l2_weight, \n",
+ " lr=lr,\n",
+ " n_memory=n_memory,\n",
+ " item_update_mode=item_update_mode, \n",
+ " using_all_hops=using_all_hops,\n",
+ " n_entity=n_entity,\n",
+ " n_relation=n_relation,\n",
+ " optimizer_method=optimizer_method,\n",
+ " seed=seed)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:reco_utils.recommender.ripplenet.model:epoch 0 train auc: 0.9051 acc: 0.8202\n",
+ "INFO:reco_utils.recommender.ripplenet.model:epoch 1 train auc: 0.9162 acc: 0.8308\n",
+ "INFO:reco_utils.recommender.ripplenet.model:epoch 2 train auc: 0.9326 acc: 0.8527\n",
+ "INFO:reco_utils.recommender.ripplenet.model:epoch 3 train auc: 0.9407 acc: 0.8631\n",
+ "INFO:reco_utils.recommender.ripplenet.model:epoch 4 train auc: 0.9515 acc: 0.8775\n",
+ "INFO:reco_utils.recommender.ripplenet.model:epoch 5 train auc: 0.9615 acc: 0.8932\n",
+ "INFO:reco_utils.recommender.ripplenet.model:epoch 6 train auc: 0.9690 acc: 0.9076\n",
+ "INFO:reco_utils.recommender.ripplenet.model:epoch 7 train auc: 0.9747 acc: 0.9173\n",
+ "INFO:reco_utils.recommender.ripplenet.model:epoch 8 train auc: 0.9789 acc: 0.9248\n",
+ "INFO:reco_utils.recommender.ripplenet.model:epoch 9 train auc: 0.9818 acc: 0.9316\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Took 72.98155543790199 seconds for training.\n"
+ ]
+ }
+ ],
+ "source": [
+ "with Timer() as train_time:\n",
+ " ripple.fit(n_epoch=n_epoch, batch_size=batch_size,\n",
+ " train_data=train_data[[\"user_index\", \"item\", \"rating\"]], \n",
+ " ripple_set=ripple_set,\n",
+ " show_loss=show_loss)\n",
+ "\n",
+ "print(\"Took {} seconds for training.\".format(train_time.interval))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Took 0.7585273641161621 seconds for prediction.\n"
+ ]
+ }
+ ],
+ "source": [
+ "with Timer() as test_time:\n",
+ " labels, scores = ripple.predict(batch_size=batch_size, \n",
+ " data=test_data[[\"user_index\", \"item\", \"rating\"]])\n",
+ " predictions = [1 if i >= 0.5 else 0 for i in scores]\n",
+ "\n",
+ "test_data['scores'] = scores\n",
+ "print(\"Took {} seconds for prediction.\".format(test_time.interval))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:reco_utils.recommender.ripplenet.model:Removing seen items\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Took 2.4120034659281373 seconds for top_k_items.\n"
+ ]
+ }
+ ],
+ "source": [
+ "with Timer() as topk_time:\n",
+ " top_k_items = ripple.recommend_k_items(batch_size=batch_size, \n",
+ " data=test_data[[\"user_index\", \"item\", \"rating\", \"original_rating\"]],\n",
+ " top_k=TOP_K, remove_seen=remove_seen)\n",
+ "print(\"Took {} seconds for top_k_items.\".format(topk_time.interval))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In case you need to re-create the RippleNet again, simply run:\n",
+ "```python\n",
+ "tf.reset_default_graph()```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Results and Evaluation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The auc score is 0.9012968931693994\n"
+ ]
+ }
+ ],
+ "source": [
+ "auc_score = auc(test_data, test_data, \n",
+ " col_user=\"user_index\",\n",
+ " col_item=\"item\",\n",
+ " col_rating=\"rating\",\n",
+ " col_prediction=\"scores\")\n",
+ "print(\"The auc score is {}\".format(auc_score))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The accuracy is 0.8271610513955379\n"
+ ]
+ }
+ ],
+ "source": [
+ "acc_score = np.mean(np.equal(predictions, labels)) # same result as in sklearn.metrics.accuracy_score \n",
+ "print(\"The accuracy is {}\".format(acc_score))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Drop column rank, not necessary for evaluation\n",
+ "top_k_items = top_k_items.drop(columns = \"rank\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The precision_k_score score at k = 10, is 0.8679405520169851\n"
+ ]
+ }
+ ],
+ "source": [
+ "precision_k_score = precision_at_k(top_k_items, top_k_items, \n",
+ " col_user=\"user_index\",\n",
+ " col_item=\"item\",\n",
+ " col_rating=\"original_rating\",\n",
+ " col_prediction=\"scores\",\n",
+ " relevancy_method=\"top_k\",\n",
+ " k=TOP_K)\n",
+ "print(\"The precision_k_score score at k = {}, is {}\".format(TOP_K, precision_k_score))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The recall_k_score score at k = 10, is 1.0\n"
+ ]
+ }
+ ],
+ "source": [
+ "recall_k_score = recall_at_k(top_k_items, top_k_items, \n",
+ " col_user=\"user_index\",\n",
+ " col_item=\"item\",\n",
+ " col_rating=\"original_rating\",\n",
+ " col_prediction=\"scores\",\n",
+ " relevancy_method=\"top_k\",\n",
+ " k=TOP_K)\n",
+ "print(\"The recall_k_score score at k = {}, is {}\".format(TOP_K, recall_k_score))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/papermill.record+json": {
+ "auc": 0.9012968931693994
+ }
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/papermill.record+json": {
+ "accuracy": 0.8271610513955379
+ }
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/papermill.record+json": {
+ "precision": 0.8679405520169851
+ }
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/papermill.record+json": {
+ "recall": 1
+ }
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/papermill.record+json": {
+ "train_time": 72.98155543790199
+ }
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/papermill.record+json": {
+ "test_time": 0.7585273641161621
+ }
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/papermill.record+json": {
+ "topk_time": 2.4120034659281373
+ }
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# Record results with papermill for tests - ignore this cell\n",
+ "pm.record(\"auc\", auc_score)\n",
+ "pm.record(\"accuracy\", acc_score)\n",
+ "pm.record(\"precision\", precision_k_score)\n",
+ "pm.record(\"recall\", recall_k_score)\n",
+ "pm.record(\"train_time\", train_time.interval)\n",
+ "pm.record(\"test_time\", test_time.interval)\n",
+ "pm.record(\"topk_time\", topk_time.interval)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## References\n",
+ "\n",
+ "1. Hongwei Wang, Fuzheng Zhang, Jialin Wang, Miao Zhao, Wenjie Li, Xing Xie, Minyi Guo, \"RippleNet: Propagating User Preferences on the Knowledge Graph for Recommender Systems\", *The 27th ACM International Conference on Information and Knowledge Management (CIKM 2018)*, 2018. https://arxiv.org/pdf/1803.03467.pdf\n",
+ "1. The original implementation of RippleNet: https://github.com/hwwang55/RippleNet"
+ ]
+ }
+ ],
+ "metadata": {
+ "celltoolbar": "Tags",
+ "kernelspec": {
+ "display_name": "Python (reco_gpu)",
+ "language": "python",
+ "name": "reco_gpu"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/reco_utils/recommender/ripplenet/data_loader.py b/reco_utils/recommender/ripplenet/data_loader.py
new file mode 100644
index 0000000000..d8cddc91b6
--- /dev/null
+++ b/reco_utils/recommender/ripplenet/data_loader.py
@@ -0,0 +1,88 @@
+# This code is modified from RippleNet
+# Online code of RippleNet: https://github.com/hwwang55/RippleNet
+
+import collections
+import os
+import numpy as np
+import logging
+
+logging.basicConfig(level=logging.INFO)
+log = logging.getLogger(__name__)
+
+
+def load_kg(kg_final):
+ """Standarize indexes for items and entities
+
+ Args:
+ kg_final (pd.DataFrame): knowledge graph converted with columns head,
+ relation and tail, with internal entity IDs
+
+ Returns:
+ n_entity (int): number of entities in KG
+ n_relation (int): number of relations in KG
+ kg (dict): KG as a dictionary, mapping each head entity to its list of (tail, relation) tuples
+ """
+ log.info("reading KG file ...")
+
+ n_entity = len(set(kg_final.iloc[:, 0]) | set(kg_final.iloc[:, 2]))
+ n_relation = len(set(kg_final.iloc[:, 1]))
+
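+ # adjacency-list form: each head entity id maps to a list of (tail entity id, relation id) tuples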
+ kg = collections.defaultdict(list)
+ for index, row in kg_final.iterrows():
+ kg[row["head"]].append((row["tail"], row["relation"]))
+
+ return n_entity, n_relation, kg
+
+
+def get_ripple_set(kg, user_history_dict, n_hop=2, n_memory=36):
+ """Build Ripple Set, dictionary for the related entities in the KG
+ given the paths of users, number of hops and memory
+
+ Args:
+ kg (dictionary): KG in dictionary shape
+ user_history_dict (dictionary): positive ratings from train data, to build ripple structure
+ n_hop (int): maximum hops in the KG
+ n_memory (int): size of ripple set for each hop
+
+ Returns:
+ ripple_set (dictionary): set of knowledge triples per user positive rating, from 0 until n_hop
+ """
+ log.info("constructing ripple set ...")
+
+ # user -> [(hop_0_heads, hop_0_relations, hop_0_tails), (hop_1_heads, hop_1_relations, hop_1_tails), ...]
+ ripple_set = collections.defaultdict(list)
+
+ for user in user_history_dict:
+ for h in range(n_hop):
+ memories_h = []
+ memories_r = []
+ memories_t = []
+
+ if h == 0:
+ tails_of_last_hop = user_history_dict[user]
+ else:
+ tails_of_last_hop = ripple_set[user][-1][2]
+
+ for entity in tails_of_last_hop:
+ for tail_and_relation in kg[entity]:
+ memories_h.append(entity)
+ memories_r.append(tail_and_relation[1])
+ memories_t.append(tail_and_relation[0])
+
+ # if the current ripple set of the given user is empty, we simply copy the ripple set of the last hop here
+ # this won't happen for h = 0, because only the items that appear in the KG have been selected
+ # in the original RippleNet code this only happened for 154 users of the Book-Crossing dataset (since both the BX dataset and its KG are sparse)
+ if len(memories_h) == 0:
+ ripple_set[user].append(ripple_set[user][-1])
+ else:
+ # sample a fixed-size 1-hop memory for each user
+ replace = len(memories_h) < n_memory
+ indices = np.random.choice(
+ len(memories_h), size=n_memory, replace=replace
+ )
+ memories_h = [memories_h[i] for i in indices]
+ memories_r = [memories_r[i] for i in indices]
+ memories_t = [memories_t[i] for i in indices]
+ ripple_set[user].append((memories_h, memories_r, memories_t))
+
+ return ripple_set
diff --git a/reco_utils/recommender/ripplenet/model.py b/reco_utils/recommender/ripplenet/model.py
new file mode 100644
index 0000000000..38920edbb0
--- /dev/null
+++ b/reco_utils/recommender/ripplenet/model.py
@@ -0,0 +1,432 @@
+# This code is modified from RippleNet
+# Online code of RippleNet: https://github.com/hwwang55/RippleNet
+
+import tensorflow as tf
+import numpy as np
+import pandas as pd
+import logging
+from sklearn.metrics import roc_auc_score
+
+logging.basicConfig(level=logging.INFO)
+log = logging.getLogger(__name__)
+
+
+class RippleNet(object):
+ """RippleNet Implementation. RippleNet is an end-to-end framework that naturally
+ incorporates knowledge graphs into recommender systems.
+ Similar to actual ripples propagating on the water, RippleNet stimulates the propagation
+ of user preferences over the set of knowledge entities by automatically and iteratively
+ extending a user’s potential interests along links in the knowledge graph.
+ """
+
+ def __init__(
+ self,
+ dim,
+ n_hop,
+ kge_weight,
+ l2_weight,
+ lr,
+ n_memory,
+ item_update_mode,
+ using_all_hops,
+ n_entity,
+ n_relation,
+ optimizer_method="adam",
+ seed=None,
+ ):
+
+ """Initialize model parameters
+
+ Args:
+ dim (int): dimension of entity and relation embeddings
+ n_hop (int): maximum hops to create ripples using the KG
+ kge_weight (float): weight of the KGE term
+ l2_weight (float): weight of the l2 regularization term
+ lr (float): learning rate
+ n_memory (int): size of ripple set for each hop
+ item_update_mode (string): how to update item at the end of each hop.
+ possible options are replace, plus, plus_transform or replace_transform
+ using_all_hops (bool): whether to use outputs of all hops or just the
+ last hop when making prediction
+ n_entity (int): number of entities in the KG
+ n_relation (int): number of types of relations in the KG
+ optimizer_method (string): optimizer method from adam, adadelta, adagrad, ftrl (FtrlOptimizer),
+ gd (GradientDescentOptimizer), rmsprop (RMSPropOptimizer)
+ seed (int): initial seed value
+ """
+ self.seed = seed
+ tf.set_random_seed(seed)
+ np.random.seed(seed)
+
+ self.n_entity = n_entity
+ self.n_relation = n_relation
+ self.dim = dim
+ self.n_hop = n_hop
+ self.kge_weight = kge_weight
+ self.l2_weight = l2_weight
+ self.lr = lr
+ self.n_memory = n_memory
+ self.item_update_mode = item_update_mode
+ self.using_all_hops = using_all_hops
+ self.optimizer_method = optimizer_method
+
+ self._build_inputs()
+ self._build_embeddings()
+ self._build_model()
+ self._build_loss()
+ self._build_optimizer()
+
+ self.init_op = tf.global_variables_initializer()
+
+ # set GPU use with demand growth
+ gpu_options = tf.GPUOptions(allow_growth=True)
+ self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
+ self.sess.run(self.init_op)
+
+ def _build_inputs(self):
+ self.items = tf.placeholder(dtype=tf.int32, shape=[None], name="items")
+ self.labels = tf.placeholder(dtype=tf.float64, shape=[None], name="labels")
+ self.memories_h = []
+ self.memories_r = []
+ self.memories_t = []
+
+ for hop in range(self.n_hop):
+ self.memories_h.append(
+ tf.placeholder(
+ dtype=tf.int32,
+ shape=[None, self.n_memory],
+ name="memories_h_" + str(hop),
+ )
+ )
+ self.memories_r.append(
+ tf.placeholder(
+ dtype=tf.int32,
+ shape=[None, self.n_memory],
+ name="memories_r_" + str(hop),
+ )
+ )
+ self.memories_t.append(
+ tf.placeholder(
+ dtype=tf.int32,
+ shape=[None, self.n_memory],
+ name="memories_t_" + str(hop),
+ )
+ )
+
+ def _build_embeddings(self):
+ self.entity_emb_matrix = tf.get_variable(
+ name="entity_emb_matrix",
+ dtype=tf.float64,
+ shape=[self.n_entity, self.dim],
+ initializer=tf.contrib.layers.xavier_initializer(),
+ )
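+ # each relation is represented as a dim x dim matrix used to project head entities during key addressing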
+ self.relation_emb_matrix = tf.get_variable(
+ name="relation_emb_matrix",
+ dtype=tf.float64,
+ shape=[self.n_relation, self.dim, self.dim],
+ initializer=tf.contrib.layers.xavier_initializer(),
+ )
+
+ def _build_model(self):
+ # transformation matrix for updating item embeddings at the end of each hop
+ self.transform_matrix = tf.get_variable(
+ name="transform_matrix",
+ shape=[self.dim, self.dim],
+ dtype=tf.float64,
+ initializer=tf.contrib.layers.xavier_initializer(),
+ )
+
+ # [batch size, dim]
+ self.item_embeddings = tf.nn.embedding_lookup(
+ self.entity_emb_matrix, self.items
+ )
+
+ self.h_emb_list = []
+ self.r_emb_list = []
+ self.t_emb_list = []
+ for i in range(self.n_hop):
+ # [batch size, n_memory, dim]
+ self.h_emb_list.append(
+ tf.nn.embedding_lookup(self.entity_emb_matrix, self.memories_h[i])
+ )
+
+ # [batch size, n_memory, dim, dim]
+ self.r_emb_list.append(
+ tf.nn.embedding_lookup(self.relation_emb_matrix, self.memories_r[i])
+ )
+
+ # [batch size, n_memory, dim]
+ self.t_emb_list.append(
+ tf.nn.embedding_lookup(self.entity_emb_matrix, self.memories_t[i])
+ )
+
+ o_list = self._key_addressing()
+
+ self.scores = tf.squeeze(self._predict_scores(self.item_embeddings, o_list))
+ self.scores_normalized = tf.sigmoid(self.scores)
+
+ def _key_addressing(self):
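+ # For each hop, use the current item embedding as a query against the relation-projected
+ # heads (Rh) to compute attention weights over the ripple memories, then output the
+ # attention-weighted sum of the tail embeddings (o).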
+ o_list = []
+ for hop in range(self.n_hop):
+ # [batch_size, n_memory, dim, 1]
+ h_expanded = tf.expand_dims(self.h_emb_list[hop], axis=3)
+
+ # [batch_size, n_memory, dim]
+ Rh = tf.squeeze(tf.matmul(self.r_emb_list[hop], h_expanded), axis=3)
+
+ # [batch_size, dim, 1]
+ v = tf.expand_dims(self.item_embeddings, axis=2)
+
+ # [batch_size, n_memory]
+ probs = tf.squeeze(tf.matmul(Rh, v), axis=2)
+
+ # [batch_size, n_memory]
+ probs_normalized = tf.nn.softmax(probs)
+
+ # [batch_size, n_memory, 1]
+ probs_expanded = tf.expand_dims(probs_normalized, axis=2)
+
+ # [batch_size, dim]
+ o = tf.reduce_sum(self.t_emb_list[hop] * probs_expanded, axis=1)
+
+ self.item_embeddings = self._update_item_embedding(self.item_embeddings, o)
+ o_list.append(o)
+ return o_list
+
+ def _update_item_embedding(self, item_embeddings, o):
+
+ if self.item_update_mode == "replace":
+ item_embeddings = o
+ elif self.item_update_mode == "plus":
+ item_embeddings = item_embeddings + o
+ elif self.item_update_mode == "replace_transform":
+ item_embeddings = tf.matmul(o, self.transform_matrix)
+ elif self.item_update_mode == "plus_transform":
+ item_embeddings = tf.matmul(item_embeddings + o, self.transform_matrix)
+ else:
+ raise Exception("Unknown item updating mode: " + self.item_update_mode)
+ return item_embeddings
+
+ def _predict_scores(self, item_embeddings, o_list):
+ y = o_list[-1]
+ if self.using_all_hops:
+ for i in range(self.n_hop - 1):
+ y += o_list[i]
+
+ scores = tf.reduce_sum(item_embeddings * y, axis=1)
+ return scores
+
+ def _build_loss(self):
+ self.base_loss = tf.reduce_mean(
+ tf.nn.sigmoid_cross_entropy_with_logits(
+ labels=self.labels, logits=self.scores
+ )
+ )
+
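+ # knowledge graph embedding (KGE) term: rewards plausible (h, r, t) triples in the ripple sets by maximizing sigmoid(h^T R t)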
+ self.kge_loss = 0
+ for hop in range(self.n_hop):
+ h_expanded = tf.expand_dims(self.h_emb_list[hop], axis=2)
+ t_expanded = tf.expand_dims(self.t_emb_list[hop], axis=3)
+ hRt = tf.squeeze(
+ tf.matmul(tf.matmul(h_expanded, self.r_emb_list[hop]), t_expanded)
+ )
+ self.kge_loss += tf.reduce_mean(tf.sigmoid(hRt))
+ self.kge_loss = -self.kge_weight * self.kge_loss
+
+ self.l2_loss = 0
+ for hop in range(self.n_hop):
+ self.l2_loss += tf.reduce_mean(
+ tf.reduce_sum(self.h_emb_list[hop] * self.h_emb_list[hop])
+ )
+ self.l2_loss += tf.reduce_mean(
+ tf.reduce_sum(self.t_emb_list[hop] * self.t_emb_list[hop])
+ )
+ self.l2_loss += tf.reduce_mean(
+ tf.reduce_sum(self.r_emb_list[hop] * self.r_emb_list[hop])
+ )
+ if (
+ self.item_update_mode == "replace_transform"
+ or self.item_update_mode == "plus_transform"
+ ):
+ self.l2_loss += tf.nn.l2_loss(self.transform_matrix)
+ self.l2_loss = self.l2_weight * self.l2_loss
+
+ self.loss = self.base_loss + self.kge_loss + self.l2_loss
+
+ def _build_optimizer(self):
+
+ if self.optimizer_method == "adam":
+ self.optimizer = tf.train.AdamOptimizer(self.lr).minimize(self.loss)
+ elif self.optimizer_method == "adadelta":
+ self.optimizer = tf.train.AdadeltaOptimizer(self.lr).minimize(self.loss)
+ elif self.optimizer_method == "adagrad":
+ self.optimizer = tf.train.AdagradOptimizer(self.lr).minimize(self.loss)
+ elif self.optimizer_method == "ftrl":
+ self.optimizer = tf.train.FtrlOptimizer(self.lr).minimize(self.loss)
+ elif self.optimizer_method == "gd":
+ self.optimizer = tf.train.GradientDescentOptimizer(self.lr).minimize(
+ self.loss
+ )
+ elif self.optimizer_method == "rmsprop":
+ self.optimizer = tf.train.RMSPropOptimizer(self.lr).minimize(self.loss)
+ else:
+ raise Exception("Unkown optimizer method: " + self.optimizer_method)
+
+ def _train(self, feed_dict):
+ return self.sess.run([self.optimizer, self.loss], feed_dict)
+
+ def _return_scores(self, feed_dict):
+ labels, scores = self.sess.run([self.labels, self.scores_normalized], feed_dict)
+ return labels, scores
+
+ def _eval(self, feed_dict):
+ labels, scores = self.sess.run([self.labels, self.scores_normalized], feed_dict)
+ auc = roc_auc_score(y_true=labels, y_score=scores)
+ predictions = [1 if i >= 0.5 else 0 for i in scores]
+ acc = np.mean(np.equal(predictions, labels))
+ return auc, acc
+
+ def _get_feed_dict(self, data, start, end):
+ feed_dict = dict()
+ feed_dict[self.items] = data[start:end, 1]
+ feed_dict[self.labels] = data[start:end, 2]
+ for i in range(self.n_hop):
+ feed_dict[self.memories_h[i]] = [
+ self.ripple_set[user][i][0] for user in data[start:end, 0]
+ ]
+ feed_dict[self.memories_r[i]] = [
+ self.ripple_set[user][i][1] for user in data[start:end, 0]
+ ]
+ feed_dict[self.memories_t[i]] = [
+ self.ripple_set[user][i][2] for user in data[start:end, 0]
+ ]
+ return feed_dict
+
+ def _print_metrics_evaluation(self, data, batch_size):
+ start = 0
+ auc_list = []
+ acc_list = []
+ while start < data.shape[0]:
+ auc, acc = self._eval(
+ self._get_feed_dict(data=data, start=start, end=start + batch_size)
+ )
+ auc_list.append(auc)
+ acc_list.append(acc)
+ start += batch_size
+ return float(np.mean(auc_list)), float(np.mean(acc_list))
+
+ def fit(self, n_epoch, batch_size, train_data, ripple_set, show_loss):
+ """Main fit method for RippleNet.
+
+ Args:
+ n_epoch (int): the number of epochs
+ batch_size (int): batch size
+ train_data (pd.DataFrame): User id, item and rating dataframe
+ ripple_set (dictionary): set of knowledge triples per user positive rating, from 0 until n_hop
+ show_loss (bool): whether to show loss update
+ """
+ self.ripple_set = ripple_set
+ self.train_data = train_data.values
+ for step in range(n_epoch):
+ # training
+ np.random.shuffle(self.train_data)
+ start = 0
+ while start < self.train_data.shape[0]:
+ _, loss = self._train(
+ self._get_feed_dict(
+ data=self.train_data, start=start, end=start + batch_size
+ )
+ )
+ start += batch_size
+ if show_loss:
+ log.info(
+ "%.1f%% %.4f" % (start / self.train_data.shape[0] * 100, loss)
+ )
+
+ train_auc, train_acc = self._print_metrics_evaluation(
+ data=self.train_data, batch_size=batch_size
+ )
+
+ log.info(
+ "epoch %d train auc: %.4f acc: %.4f" % (step, train_auc, train_acc)
+ )
+
+ def predict(self, batch_size, data):
+ """Main predict method for RippleNet.
+
+ Args:
+ batch_size (int): batch size
+ data (pd.DataFrame): User id, item and rating dataframe
+
+ Returns:
+ (list, list): real labels of the predicted items, predicted scores of the predicted items
+ """
+ data = data.values
+ start = 0
+ labels = [0] * data.shape[0]
+ scores = [0] * data.shape[0]
+ while start < data.shape[0]:
+ (
+ labels[start : start + batch_size],
+ scores[start : start + batch_size],
+ ) = self._return_scores(
+ feed_dict=self._get_feed_dict(
+ data=data, start=start, end=start + batch_size
+ )
+ )
+ start += batch_size
+
+ return labels, scores
+
+ def recommend_k_items(self, batch_size, data, top_k=10, remove_seen=True):
+ """Recommend top K items method for RippleNet.
+
+ Args:
+ batch_size (int): batch size
+ data (pd.DataFrame): User id, item and rating dataframe
+ top_k (int): number of items to recommend
+ remove_seen (bool): whether items seen by a user in the training data should be removed from the test set
+
+ Returns:
+ (pd.DataFrame): top K items by score per user
+ """
+ if remove_seen:
+ log.info("Removing seen items")
+ train_data = pd.DataFrame(self.train_data).iloc[:, 0:2]
+ train_data.columns = list(data.columns[0:2])
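+ # anti-join: keep only the user-item pairs that do not appear in the training data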
+ seen_items = data.merge(
+ train_data.iloc[:, 0:2],
+ on=list(data.columns[0:2]),
+ indicator=True,
+ how="left",
+ )
+ data = seen_items[seen_items["_merge"] == "left_only"].drop(
+ columns=["_merge"]
+ )
+ data_np = data.values
+ start = 0
+ labels = [0] * data_np.shape[0]
+ scores = [0] * data_np.shape[0]
+ while start < data_np.shape[0]:
+ (
+ labels[start : start + batch_size],
+ scores[start : start + batch_size],
+ ) = self._return_scores(
+ feed_dict=self._get_feed_dict(
+ data=data_np, start=start, end=start + batch_size
+ )
+ )
+ start += batch_size
+
+ data["scores"] = scores
+ top_k_items = (
+ data.groupby(data.columns[0], as_index=False)
+ .apply(lambda x: x.nlargest(top_k, "scores"))
+ .reset_index(drop=True)
+ )
+ # Add ranks
+ top_k_items["rank"] = (
+ top_k_items.groupby(data.columns[0], sort=False).cumcount() + 1
+ )
+
+ return top_k_items
diff --git a/reco_utils/recommender/ripplenet/preprocess.py b/reco_utils/recommender/ripplenet/preprocess.py
new file mode 100644
index 0000000000..450738e2ac
--- /dev/null
+++ b/reco_utils/recommender/ripplenet/preprocess.py
@@ -0,0 +1,164 @@
+# This code is modified from RippleNet
+# Online code of RippleNet: https://github.com/hwwang55/RippleNet
+
+import argparse
+import numpy as np
+import pandas as pd
+import logging
+
+logging.basicConfig(level=logging.INFO)
+log = logging.getLogger(__name__)
+
+
+def read_item_index_to_entity_id_file(item_to_entity):
+ """Standarize indexes for items and entities
+
+ Args:
+ item_to_entity (pd.DataFrame): KG dataframe with original item and entity IDs
+
+ Returns:
+ item_index_old2new (dictionary): dictionary conversion from original item ID to internal item ID
+ entity_id2index (dictionary): dictionary conversion from original entity ID to internal entity ID
+ """
+ item_index_old2new = dict()
+ entity_id2index = dict()
+ i = 0
+ for index, row in item_to_entity.iterrows():
+ item_index = str(row[0])
+ satori_id = str(row[1])
+ item_index_old2new[item_index] = i
+ entity_id2index[satori_id] = i
+ i += 1
+ return item_index_old2new, entity_id2index
+
+
+def convert_rating(ratings, item_index_old2new, threshold, remove_negative_ratings=True, seed=14):
+ """Apply item standarization to ratings dataset.
+ Use rating threshold to determite positive ratings
+
+ Args:
+ ratings (pd.DataFrame): ratings with columns ["UserId", "ItemId", "Rating"]
+ item_index_old2new (dictionary): dictionary, conversion from original item ID to internal item ID
+ threshold (int): minimum value for the rating to be considered positive
+ remove_negative_ratings (bool): whether the train/test set should exclude items rated below the threshold,
+ as the original paper proposes
+
+ Returns:
+ ratings_final (pd.DataFrame): ratings converted with columns userID,
+ internal item ID and binary rating (1, 0)
+ """
+ item_set = set(item_index_old2new.values())
+ user_pos_ratings = dict()
+ user_neg_ratings = dict()
+
+ for index, row in ratings.iterrows():
+ item_index_old = str(int(row[1]))
+ if (
+ item_index_old not in item_index_old2new
+ ): # the item is not in the final item set
+ continue
+ item_index = item_index_old2new[item_index_old]
+
+ user_index_old = int(row[0])
+
+ rating = float(row[2])
+ if rating >= threshold:
+ if user_index_old not in user_pos_ratings:
+ user_pos_ratings[user_index_old] = set()
+ user_pos_ratings[user_index_old].add((item_index, rating))
+ else:
+ if user_index_old not in user_neg_ratings:
+ user_neg_ratings[user_index_old] = set()
+ user_neg_ratings[user_index_old].add((item_index, rating))
+
+ log.info("converting rating file ...")
+ writer = []
+ user_cnt = 0
+ user_index_old2new = dict()
+ for user_index_old, pos_item_set in user_pos_ratings.items():
+ if user_index_old not in user_index_old2new:
+ user_index_old2new[user_index_old] = user_cnt
+ user_cnt += 1
+ user_index = user_index_old2new[user_index_old]
+ for item, original_rating in pos_item_set:
+ writer.append(
+ {
+ "user_index": user_index,
+ "item": item,
+ "rating": 1,
+ "original_rating": original_rating,
+ }
+ )
+ pos_item_set = set(i[0] for i in pos_item_set)
+ unwatched_set = item_set - pos_item_set
+ if user_index_old in user_neg_ratings:
+ negative_set = dict(list(user_neg_ratings[user_index_old]))
+ if remove_negative_ratings:
+ unwatched_set -= set(negative_set.keys())
+ else:
+ negative_set = {}
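+ # sample unwatched items as negatives (label 0), equal in number to the user's positives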
+ np.random.seed(seed)
+ for item in np.random.choice(
+ list(unwatched_set), size=len(pos_item_set), replace=False
+ ):
+ if item in negative_set:
+ original_rating = negative_set[item]
+ else:
+ original_rating = 0
+ writer.append(
+ {
+ "user_index": user_index,
+ "item": item,
+ "rating": 0,
+ "original_rating": original_rating,
+ }
+ )
+ ratings_final = pd.DataFrame(writer)
+ log.info("number of users: %d" % user_cnt)
+ log.info("number of items: %d" % len(item_set))
+ return ratings_final
+
+
+def convert_kg(kg, entity_id2index):
+ """Apply entity standarization to KG dataset
+ Args:
+ kg (pd.DataFrame): knowledge graph with columns ["original_entity_id", "relation", "linked_entities_id"]
+ entity_id2index (dict): conversion from original entity ID to internal entity ID
+
+ Returns:
+ kg_final (pd.DataFrame): knowledge graph converted with columns head,
+ relation and tail, with internal entity IDs
+ """
+ log.info("converting kg file ...")
+ entity_cnt = len(entity_id2index)
+ relation_cnt = 0
+ relation_id2index = dict()
+
+ writer = []
+
+ for index, row in kg.iterrows():
+ head_old = str(int(row[0]))
+ relation_old = row[1]
+ tail_old = str(int(row[2]))
+
+ if head_old not in entity_id2index:
+ entity_id2index[head_old] = entity_cnt
+ entity_cnt += 1
+ head = entity_id2index[head_old]
+
+ if tail_old not in entity_id2index:
+ entity_id2index[tail_old] = entity_cnt
+ entity_cnt += 1
+ tail = entity_id2index[tail_old]
+
+ if relation_old not in relation_id2index:
+ relation_id2index[relation_old] = relation_cnt
+ relation_cnt += 1
+ relation = relation_id2index[relation_old]
+
+ writer.append({"head": head, "relation": relation, "tail": tail})
+
+ kg_final = pd.DataFrame(writer)
+ log.info("number of entities (containing items): %d" % entity_cnt)
+ log.info("number of relations: %d" % relation_cnt)
+ return kg_final