In [1]:
!pip install torch



In [2]:
import sys
import logging

import torch
import pandas as pd


class TorchEASE:
    def __init__(
        self, train, user_col="user_id", item_col="item_id", score_col=None, reg=0.05
    ):
        """

        :param train: Training DataFrame of user, item, score(optional) values
        :param user_col: Column name for users
        :param item_col: Column name for items
        :param score_col: Column name for scores. Implicit feedback otherwise
        :param reg: Regularization parameter
        """
        logging.basicConfig(
            format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
            level=logging.INFO,
            datefmt="%Y-%m-%d %H:%M:%S",
            stream=sys.stdout,
        )

        self.logger = logging.getLogger("notebook")
        self.logger.info("Building user + item lookup")
        # How much regularization do you need?
        self.reg = reg

        self.user_col = user_col
        self.item_col = item_col

        self.user_id_col = user_col
        self.item_id_col = item_col

        self.user_lookup = self.generate_labels(train, self.user_col)
        self.item_lookup = self.generate_labels(train, self.item_col)

        self.item_map = {}
        self.logger.info("Building item hashmap")
        print(f'type(self.item_lookup) = {type(self.item_lookup)}')
        for index, _item_id in enumerate(self.item_lookup):
            self.item_map[_item_id] = index

        train = pd.merge(train, self.user_lookup, on=[self.user_col])
        train = pd.merge(train, self.item_lookup, on=[self.item_col])
        self.logger.info("User + item lookup complete")
        self.indices = torch.LongTensor(
            train[[self.user_id_col, self.item_id_col]].values
        )

        if not score_col:
            # Implicit values only
            self.values = torch.ones(self.indices.shape[0])
        else:
            self.values = torch.FloatTensor(train[score_col])
        # TODO: Is Sparse the best implementation?

        self.sparse = torch.sparse.FloatTensor(self.indices.t(), self.values)

        self.logger.info("Sparse data built")

    def generate_labels(self, df, col):
        dist_labels = df[[col]].drop_duplicates()
        dist_labels[col] = dist_labels[col].astype("category").cat.codes

        return dist_labels

    def fit(self):
        self.logger.info("Building G Matrix")
        G = self.sparse.to_dense().t() @ self.sparse.to_dense()
        G += torch.eye(G.shape[0]) * self.reg

        P = G.inverse()

        self.logger.info("Building B matrix")
        B = P / (-1 * P.diag())
        # Set diagonals to 0. TODO: Use .fill_diag_
        B = B + torch.eye(B.shape[0])

        # Predictions for user `_u` will be self.sparse.to_dense()[_u]@self.B
        self.B = B

        return

    def predict_all(self, pred_df, k=5, remove_owned=True):
        """
        :param pred_df: DataFrame of users that need predictions
        :param k: Number of items to recommend to each user
        :param remove_owned: Do you want previously interacted items included?
        :return: DataFrame of users + their predictions in sorted order
        """
        pred_df = pred_df[[self.user_col]].drop_duplicates()
        n_orig = pred_df.shape[0]

        # Alert to number of dropped users in prediction set
        pred_df = pd.merge(pred_df, self.user_lookup, on=[self.user_col])
        n_curr = pred_df.shape[0]
        if n_orig - n_curr:
            self.logger.info(
                "Number of unknown users from prediction data = %i" % (n_orig - n_curr)
            )

        _output_preds = []
        # Select only user_ids in our user data
        _user_tensor = self.sparse.to_dense().index_select(
            dim=0, index=torch.LongTensor(pred_df[self.user_id_col])
        )

        # Make our (raw) predictions
        _preds_tensor = _user_tensor @ self.B
        self.logger.info("Predictions are made")
        if remove_owned:
            # Discount these items by a large factor (much faster than list comp.)
            self.logger.info("Removing owned items")
            _preds_tensor += -1.0 * _user_tensor

        self.logger.info("TopK selected per user")
        for _preds in _preds_tensor:
            # Very quick to use .topk() vs. argmax()
            _output_preds.append(
                [self.item_map[_id] for _id in _preds.topk(k).indices.tolist()]
            )

        pred_df["predicted_items"] = _output_preds
        self.logger.info("Predictions are returned to user")
        return pred_df


In [3]:
import os
import requests

In [4]:
url = 'https://files.grouplens.org/datasets/movielens/ml-25m.zip'

!pip install -q wget
import wget
filename = wget.download(url)

import shutil
destination_path = './'

shutil.unpack_archive(filename)

In [5]:
import pandas as pd

movies = pd.read_csv('~/Downloads/ml-25m/movies.csv')
movies.head()

Unnamed: 0,movieId,title,genres
0,1,Toy Story (1995),Adventure|Animation|Children|Comedy|Fantasy
1,2,Jumanji (1995),Adventure|Children|Fantasy
2,3,Grumpier Old Men (1995),Comedy|Romance
3,4,Waiting to Exhale (1995),Comedy|Drama|Romance
4,5,Father of the Bride Part II (1995),Comedy


In [6]:
ratings = pd.read_csv('~/Downloads/ml-25m/ratings.csv')
ratings.head()

Unnamed: 0,userId,movieId,rating,timestamp
0,1,296,5.0,1147880044
1,1,306,3.5,1147868817
2,1,307,5.0,1147868828
3,1,665,5.0,1147878820
4,1,899,3.5,1147868510


In [7]:
te = TorchEASE(ratings, user_col="userId", item_col="movieId", score_col='rating')

2022-04-22 10:23:47 [INFO] notebook - Building user + item lookup
2022-04-22 10:23:48 [INFO] notebook - Building item hashmap
type(self.item_lookup) = <class 'pandas.core.frame.DataFrame'>
2022-04-22 10:23:55 [INFO] notebook - User + item lookup complete
2022-04-22 10:23:58 [INFO] notebook - Sparse data built


In [None]:
%%time
te.fit()

2022-04-22 10:29:45 [INFO] notebook - Building G Matrix


In [None]:
%%time
out = te.predict_all(ratings, k=3)

In [None]:
out.head()

In [None]:
from pathlib import Path  
filepath = Path('~/Downloads/TorchEase_Movie_recommendations.csv')  
filepath.parent.mkdir(parents=True, exist_ok=True)  
out.to_csv(filepath)