In [1]:
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

from fastai.vision.all import *
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, Dataset


In [3]:
from fastai.collab import *
from fastai.tabular.all import *
path = untar_data(URLs.ML_100k)

In [4]:
ratings = pd.read_csv(path/'u.data', delimiter='\t', header=None,
                      names=['user', 'movie', 'rating', 'timestamp'])
ratings.head()

Unnamed: 0,user,movie,rating,timestamp
0,196,242,3,881250949
1,186,302,3,891717742
2,22,377,1,878887116
3,244,51,2,880606923
4,166,346,1,886397596


In [5]:
movies = pd.read_csv(path/'u.item', delimiter='|', encoding='latin-1'
                     ,usecols=(0,1), names=('movie', 'title'), header=None)
movies.head()

Unnamed: 0,movie,title
0,1,Toy Story (1995)
1,2,GoldenEye (1995)
2,3,Four Rooms (1995)
3,4,Get Shorty (1995)
4,5,Copycat (1995)


In [9]:
ratings = ratings.merge(movies)
ratings.head()

Unnamed: 0,user,movie,rating,timestamp,title
0,196,242,3,881250949,Kolya (1996)
1,63,242,3,875747190,Kolya (1996)
2,226,242,5,883888671,Kolya (1996)
3,154,242,3,879138235,Kolya (1996)
4,306,242,5,876503793,Kolya (1996)


In [7]:
dls = CollabDataLoaders.from_df(ratings, item_name='title', bs=64)
dls.show_batch()

Unnamed: 0,user,title,rating
0,399,"Shawshank Redemption, The (1994)",3
1,385,True Lies (1994),1
2,624,Bed of Roses (1996),4
3,276,"Shining, The (1980)",5
4,87,"Santa Clause, The (1994)",4
5,661,Alien (1979),4
6,642,Richie Rich (1994),4
7,632,Executive Decision (1996),2
8,416,Nell (1994),5
9,643,When a Man Loves a Woman (1994),3


In [10]:
n_users = len(dls.classes['user'])
n_movies = len(dls.classes['title'])
n_factors=5

In [11]:
user_factors = torch.randn(n_users, n_factors)
movie_factors = torch.randn(n_movies, n_factors)

In [13]:
user_factors[0]
movie_factors[0]

tensor([-2.0546,  0.9548, -1.0165,  0.6747, -0.0802])

the tric for finding the user or movie out of all users and movies is to just multiply their matrices by one hot encoding vector which is 1 at the index of the user or movie we want

In [21]:
one_hot_3 = one_hot(3, n_users).float()


In [22]:
user_factors.t() @ one_hot_3

tensor([ 0.7067, -0.5306, -0.4217,  0.2815,  1.8563])

In [23]:
user_factors[3]

tensor([ 0.7067, -0.5306, -0.4217,  0.2815,  1.8563])

In [48]:
class DotProduct(Module):
       def __init__(self, n_users, n_movies, n_factors, y_range=(0, 5.5)):
              self.user_factors = Embedding(n_users, n_factors)
              self.movie_factors = Embedding(n_movies, n_factors)
              self.y_range = y_range


       def forward(self, x):
              users = self.user_factors(x[:,0])
              movies = self.movie_factors(x[:,1])
              return sigmoid_range((users * movies).sum(dim=1), *self.y_range)

In [49]:
x, y = dls.one_batch()
x.shape, y.shape

(torch.Size([64, 2]), torch.Size([64, 1]))

In [50]:
model = DotProduct(n_users,n_movies, 50)
learn = Learner(dls, model, loss_func=MSELossFlat())

In [51]:
learn.fit_one_cycle(5, 5e-3)

epoch,train_loss,valid_loss,time
0,0.92708,1.000413,00:10
1,0.648457,0.950271,00:09
2,0.441373,0.949463,00:09
3,0.35632,0.950903,00:08
4,0.350133,0.949383,00:08


In [52]:
class DotProduct(Module):
       def __init__(self, n_users, n_movies, n_factors, y_range=(0, 5.5)):
              self.user_factors = Embedding(n_users, n_factors)
              self.user_bias = Embedding(n_users, 1)
              self.movie_factors = Embedding(n_movies, n_factors)
              self.movie_bias = Embedding(n_movies, 1)
              self.y_range = y_range


       def forward(self, x):
              users = self.user_factors(x[:,0])
              movies = self.movie_factors(x[:,1])
              res = (users * movies).sum(dim=1, keepdim=True)
              res += self.user_bias(x[:,0]) + self.movie_bias(x[:,1])
              return sigmoid_range(res, *self.y_range)

In [53]:
model = DotProduct(n_users,n_movies, 50)
learn = Learner(dls, model, loss_func=MSELossFlat())
learn.fit_one_cycle(5, 5e-3)

epoch,train_loss,valid_loss,time
0,0.866512,0.949807,00:11
1,0.568673,0.910155,00:10
2,0.40278,0.935953,00:11
3,0.323955,0.944626,00:11
4,0.285354,0.9452,00:10
