In [1]:
import pandas as pd
import sys;sys.path.append('..')
import warnings; warnings.simplefilter("ignore")
from surprise import SVD, Reader, Dataset
from collections import defaultdict
import numpy as np

# One row per explicit rating by a US user, with user demographics and book
# metadata joined on. The first CSV column, "Unnamed: 0", is presumably an index
# saved by an earlier to_csv call, so it is dropped.
df = pd.read_csv('../data/explicit_usa.csv').drop(columns="Unnamed: 0")
df.head()

Unnamed: 0,User-ID,ISBN,Book-Rating,Age,city,state,country,Book-Title,Book-Author,Year-Of-Publication,Publisher,Image-URL-S,Image-URL-M,Image-URL-L
0,2954,60973129,8.0,71.0,wichita,kansas,usa,Decision in Normandy,Carlo D'Este,1991,HarperPerennial,http://images.amazon.com/images/P/0060973129.0...,http://images.amazon.com/images/P/0060973129.0...,http://images.amazon.com/images/P/0060973129.0...
1,35704,374157065,6.0,53.0,kansas city,missouri,usa,Flu: The Story of the Great Influenza Pandemic...,Gina Bari Kolata,1999,Farrar Straus Giroux,http://images.amazon.com/images/P/0374157065.0...,http://images.amazon.com/images/P/0374157065.0...,http://images.amazon.com/images/P/0374157065.0...
2,110912,374157065,10.0,36.0,milpitas,california,usa,Flu: The Story of the Great Influenza Pandemic...,Gina Bari Kolata,1999,Farrar Straus Giroux,http://images.amazon.com/images/P/0374157065.0...,http://images.amazon.com/images/P/0374157065.0...,http://images.amazon.com/images/P/0374157065.0...
3,157969,374157065,8.0,30.0,denver,colorado,usa,Flu: The Story of the Great Influenza Pandemic...,Gina Bari Kolata,1999,Farrar Straus Giroux,http://images.amazon.com/images/P/0374157065.0...,http://images.amazon.com/images/P/0374157065.0...,http://images.amazon.com/images/P/0374157065.0...
4,192665,374157065,8.0,47.0,vacaville,california,usa,Flu: The Story of the Great Influenza Pandemic...,Gina Bari Kolata,1999,Farrar Straus Giroux,http://images.amazon.com/images/P/0374157065.0...,http://images.amazon.com/images/P/0374157065.0...,http://images.amazon.com/images/P/0374157065.0...


In [2]:
import joblib

# The file holds a (train_df, test_df) pair of rating frames.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code. Only load split files produced by this project.
TRAIN_TEST_SPLIT_PATH = "../data/train_test_split.pkl"
train_df, test_df = joblib.load(TRAIN_TEST_SPLIT_PATH)

In [3]:
# Row counts of the two splits.
n_train_rows = len(train_df)
n_test_rows = len(test_df)
print(f"trainデータサイズ：　{n_train_rows}")
print(f"testデータサイズ：　{n_test_rows}")

trainデータサイズ：　17180
testデータサイズ：　8545


In [45]:
from sklearn.metrics import mean_squared_error
from surprise.model_selection import cross_validate

class MF:
    """Matrix-factorisation rating model (Surprise ``SVD``) for the book ratings.

    Workflow:
        1. ``cross_validation(**params)`` configures the SVD model and returns its
           k-fold validation RMSE on the training split.
        2. ``test()`` refits that model on the whole training split and returns
           RMSE on the held-out test split.
    """

    # Hyperparameters used by ``cross_validation`` when a key is not passed.
    DEFAULT_PARAMS = {
        "n_factors": 200,
        "lr_all": 0.005,
        "n_epochs": 200,
        "reg_all": 0.4,
    }

    def __init__(self, train_df, test_df):
        """Wrap the training and test rating frames.

        Args:
            train_df: training ratings; must contain the columns "User-ID",
                "ISBN" and "Book-Rating".
            test_df: held-out ratings with those same columns; used only by
                ``test``.
        """
        # Ratings are declared to lie on a 1-10 scale.
        reader = Reader(rating_scale=(1, 10))
        self.data_train = Dataset.load_from_df(
            train_df[["User-ID", "ISBN", "Book-Rating"]], reader
        )
        self.test_df = test_df

    def cross_validation(self, **kwargs):
        """Configure an SVD model and return its mean cross-validated RMSE.

        Keyword Args:
            n_factors, lr_all, n_epochs, reg_all: SVD hyperparameters. Missing
                keys fall back to ``DEFAULT_PARAMS``.
            n_splits: number of CV folds (default 5).
            n_jobs: parallel jobs for ``cross_validate`` (default -1, all CPUs).
            random_state: optional int seed. When given, it seeds both the SVD
                initialisation and the fold split, so the score is reproducible.
                When omitted, both stay unseeded.
            Unknown keys are ignored.

        Returns:
            Mean RMSE over the validation folds.

        Side effects:
            Stores the hyperparameters on ``self`` and the configured model in
            ``self.mf``; ``test`` refits that model.
        """
        self.n_factors = kwargs.get("n_factors", self.DEFAULT_PARAMS["n_factors"])
        self.lr_all = kwargs.get("lr_all", self.DEFAULT_PARAMS["lr_all"])
        self.n_epochs = kwargs.get("n_epochs", self.DEFAULT_PARAMS["n_epochs"])
        self.reg_all = kwargs.get("reg_all", self.DEFAULT_PARAMS["reg_all"])
        n_splits = kwargs.get("n_splits", 5)
        n_jobs = kwargs.get("n_jobs", -1)
        random_state = kwargs.get("random_state")

        self.mf = SVD(
            n_factors=self.n_factors,
            lr_all=self.lr_all,
            n_epochs=self.n_epochs,
            reg_all=self.reg_all,
            random_state=random_state,
        )

        if random_state is None:
            # An int makes surprise build its own, unseeded, shuffled KFold.
            cv = n_splits
        else:
            from surprise.model_selection import KFold

            cv = KFold(n_splits=n_splits, random_state=random_state)

        result = cross_validate(
            self.mf, self.data_train, measures=["rmse"], cv=cv, n_jobs=n_jobs
        )
        return result["test_rmse"].mean()

    def test(self):
        """Refit on the full training split and return RMSE on the test split.

        Only test rows that the model can score are used: both the user and the
        book must appear in the training data, and the (user, book) pair must
        not itself be a training rating. Cold-start rows are skipped. The
        scored rows, with their predictions in a "rating_pred" column, are kept
        in ``self.test_predictions``.

        Returns:
            Test RMSE.

        Raises:
            AttributeError: if ``cross_validation`` has not configured a model yet.
            ValueError: if no test row can be scored.
        """
        if not hasattr(self, "mf"):
            raise AttributeError(
                "MF.test() needs a configured model; call cross_validation() first"
            )

        trainset = self.data_train.build_full_trainset()
        model = self.mf.fit(trainset)

        # Raw ids seen in training, and the (user, book) pairs already rated there.
        known_users = {trainset.to_raw_uid(inner) for inner in trainset.all_users()}
        known_items = {trainset.to_raw_iid(inner) for inner in trainset.all_items()}
        train_pairs = {
            (trainset.to_raw_uid(u), trainset.to_raw_iid(i))
            for u, i, _ in trainset.all_ratings()
        }

        key_cols = ["User-ID", "ISBN"]
        # Score exactly the test rows an anti-testset over the trainset would
        # contain, without materialising every unrated (user, book) pair,
        # which costs O(users * books).
        scorable = np.array(
            [
                uid in known_users
                and iid in known_items
                and (uid, iid) not in train_pairs
                for uid, iid in self.test_df[key_cols].itertuples(index=False, name=None)
            ],
            dtype=bool,
        )
        book_rating_predict = self.test_df.loc[scorable].copy()
        if book_rating_predict.empty:
            raise ValueError(
                "no test rows whose user and book both appear in the training data"
            )

        book_rating_predict["rating_pred"] = [
            model.predict(uid, iid).est
            for uid, iid in book_rating_predict[key_cols].itertuples(
                index=False, name=None
            )
        ]
        self.test_predictions = book_rating_predict

        return self._calc_rmse(
            book_rating_predict["Book-Rating"].to_list(),
            book_rating_predict["rating_pred"].to_list(),
        )

    def _calc_rmse(self, true_rating, pred_rating):
        """Return the root-mean-squared error between two rating sequences.

        Args:
            true_rating: observed ratings.
            pred_rating: predicted ratings, same length as ``true_rating``.

        Raises:
            ValueError: if the inputs are empty or differ in shape.
        """
        true_arr = np.asarray(true_rating, dtype=float)
        pred_arr = np.asarray(pred_rating, dtype=float)
        # Explicit check: numpy would silently broadcast a length-1 input.
        if true_arr.size == 0 or true_arr.shape != pred_arr.shape:
            raise ValueError(
                "need two non-empty rating sequences of equal length, got "
                f"{true_arr.shape} and {pred_arr.shape}"
            )
        return np.sqrt(np.mean((true_arr - pred_arr) ** 2))
        
        
        
    

In [46]:
# Wrap the loaded splits; the training split becomes a Surprise Dataset.
mf = MF(train_df=train_df, test_df=test_df)

In [47]:
# 5-fold CV RMSE on the training split. The hyperparameters are spelled out
# (they equal MF.cross_validation's defaults) so the configuration is visible here.
val_score = mf.cross_validation(
    n_factors=200,
    lr_all=0.005,
    n_epochs=200,
    reg_all=0.4,
)

In [49]:
# ".3" formats with three significant digits.
print("validation RMSE = " + format(val_score, ".3"))

validation RMSE = 1.51


In [50]:
# Refit the SVD configured by cross_validation() on the full training split,
# then compute RMSE on the test split.
test_score = mf.test()

In [51]:
# ".3" formats with three significant digits.
print("test RMSE = " + format(test_score, ".3"))

test RMSE = 1.61
