In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
class FCMF(nn.Module):
    '''
    Base class for Fully-Connected Matrix Factorization networks.

    Every user and every item owns a latent vector (length D) and a latent
    matrix (D_ x K). For a (user, item) pair the network input is the
    concatenation of the two latent vectors with the D_ row-wise dot products
    of the two latent matrices; a stack of fully-connected layers (ReLU between
    hidden layers) maps that input to a single predicted rating.
    '''

    def __init__(self, N, M, D, D_, K, layers):
        '''
        variable definitions taken from paper: https://arxiv.org/pdf/1511.06443.pdf

        @param N:  Number of users
        @param M:  Number of items
        @param D:  size of latent-feature vectors
        @param D_: num rows in latent-features matrices
        @param K:  num cols in latent-feature matrices

        @param layers: sequence (list or tuple) of hidden layer sizes;
                       does not include input or output
        '''

        assert (min(N, M, D, D_, K) > 0), "Params must be nonzero and positive"
        assert (len(layers) > 0),         "Must have nonzero hidden layers"

        ########################################################################

        super(FCMF, self).__init__()

        self.N, self.M, self.D, self.D_, self.K = N, M, D, D_, K

        # The latents must be nn.Parameter so nn.Module registers them. Plain
        # tensors created with requires_grad=True are NOT returned by
        # self.parameters(), so an optimizer built from model.parameters()
        # would never update the embeddings.
        self.userLatentVectors = nn.Parameter(torch.rand(N, D))
        self.itemLatentVectors = nn.Parameter(torch.rand(M, D))

        self.userLatentMatrices = nn.Parameter(torch.rand(N, D_, K))
        self.itemLatentMatrices = nn.Parameter(torch.rand(M, D_, K))

        # Copy into a list so tuples work too and the caller's list is never aliased.
        hidden = list(layers)

        # Input width: two latent vectors (D each) + D_ latent-matrix dot products.
        linear_inputs = [2 * D + D_] + hidden
        linear_outputs = hidden + [1]

        self.layers = nn.ModuleList([nn.Linear(i, o) for (i, o) in zip(linear_inputs, linear_outputs)])

    def forward(self, x):
        '''
        Predict one rating per (user, item) pair.

        @param x: tensor of size (X, 2): (user index, item index); values are
                  cast with .long(), so float-typed indices are accepted
        @return:  tensor of size (X, 1), the raw output of the last linear layer

        WARNING:
            - forward currently does not account for user/items outside of training data
            - mitigations include returning smart averages
        '''
        userIndices, itemIndices = x[:, 0].long(), x[:, 1].long()

        # (X, D_, K) * (X, D_, K), summed over K -> (X, D_):
        # one dot product per row of the latent matrices.
        userLatMats = self.userLatentMatrices[userIndices]
        itemLatMats = self.itemLatentMatrices[itemIndices]
        latentDotProducts = torch.sum(userLatMats * itemLatMats, dim=-1)

        # (X, 2*D + D_), matching the width of the first linear layer.
        x = torch.hstack([
            self.userLatentVectors[userIndices],
            self.itemLatentVectors[itemIndices],
            latentDotProducts
        ])

        for layer in self.layers[:-1]:
            x = F.relu(layer(x))

        # TODO: should last layer go through a sigmoid?
        return self.layers[-1](x)
        

In [3]:
def getBatches(mat, usersPerBatch=100):
    '''
    Yield (batch_x, batch_y) pairs that together cover every cell of a ratings matrix.

    Rows of `mat` are users and columns are items. Users are processed in
    consecutive chunks of at most `usersPerBatch` rows. Within a batch, cells
    appear in row-major order (user-major, then item).

    @param mat:           2-D array-like (numpy array or torch tensor) of shape (N, M)
    @param usersPerBatch: maximum number of users (rows) per batch; must be > 0

    @yield batch_x: float tensor (batchSize, 2) of (absolute user index, item index)
    @yield batch_y: float tensor (batchSize, 1) of the corresponding ratings,
                    where batchSize = min(N - start, usersPerBatch) * M

    @raises ValueError: if usersPerBatch is not positive
    '''
    if usersPerBatch <= 0:
        raise ValueError("usersPerBatch must be positive")

    N, M = mat.shape
    itemIds = torch.arange(M)

    for start in range(0, N, usersPerBatch):
        stop = min(start + usersPerBatch, N)
        batchUsers = stop - start

        # Absolute user ids (not batch-relative) so the model indexes the
        # correct latent rows: [start]*M, [start+1]*M, ...
        userIds = torch.arange(start, stop).repeat_interleave(M)
        batch_x = torch.stack([userIds, itemIds.repeat(batchUsers)], dim=1).float()
        batch_y = torch.as_tensor(mat[start:stop], dtype=torch.float32).reshape(-1, 1)

        yield (batch_x, batch_y)
        
    

def trainEpoch(opt, criterion, model, mat):
    '''
    Run one training epoch over every (user, item) cell of `mat`.

    Per-batch losses from getBatches(mat) are summed into one scalar, which is
    back-propagated once and applied with a single opt.step(). This amounts to
    one full-batch gradient step per epoch.

    @param opt:       torch optimizer built over model's parameters
    @param criterion: loss callable invoked as criterion(prediction, target)
    @param model:     module mapping a (X, 2) index tensor to (X, 1) predictions (e.g. FCMF)
    @param mat:       2-D ratings matrix, users x items

    @return: the summed loss as a Python float

    @raises ValueError: if `mat` yields no batches (e.g. it has no rows)
    '''
    opt.zero_grad()

    loss = None
    for batch_x, batch_y in getBatches(mat):
        pred_y = model(batch_x)
        # PyTorch losses take (input, target), in that order.
        batchLoss = criterion(pred_y, batch_y)
        loss = batchLoss if loss is None else loss + batchLoss

    if loss is None:
        raise ValueError("mat produced no batches; nothing to train on")

    loss.backward()
    # Step the optimizer that was passed in, not a notebook-global one.
    opt.step()
    return loss.item()

In [4]:
# Fix the seeds so the toy matrix below and FCMF's torch.rand initialisation
# are reproducible under Restart & Run All.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

numUsers = 5
numItems = 5

# Integer ratings in [0, 5]: randint's upper bound (6) is exclusive.
testMatrix = np.random.randint(6, size=(numUsers, numItems))
testMatrix

array([[2, 5, 2, 3, 2],
       [4, 0, 3, 2, 0],
       [1, 1, 3, 1, 1],
       [1, 2, 5, 5, 4],
       [2, 0, 5, 5, 0]])

In [5]:
# Toy model: latent vectors of size 2, 2x1 latent matrices, one hidden layer of 5 units.
fc3 = FCMF(
    N=numUsers,
    M=numItems,
    D=2,
    D_=2,
    K=1,
    layers=[5],
)

In [6]:
# The paper (arXiv:1511.06443) uses RMSE as its objective with the RMSProp optimizer.
# nn.MSELoss is used here instead: RMSE = sqrt(MSE) is monotone in MSE, so the
# two objectives share the same minimiser.
LEARNING_RATE = 0.01

criterion = nn.MSELoss()
optimizer = optim.RMSprop(fc3.parameters(), lr=LEARNING_RATE)

In [7]:
# Run one training epoch of the toy model on the toy ratings matrix.
trainEpoch(opt=optimizer, criterion=criterion, model=fc3, mat=testMatrix)

  Variable._execution_engine.run_backward(
