# In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# In [None]:
class FCMF(nn.Module):
    '''
    Base class for Fully-Connected Matrix Factorization networks.

    Input construction follows https://arxiv.org/pdf/1511.06443.pdf: for a
    (user, item) pair the network input is the concatenation of
        - the user latent vector (size D),
        - the item latent vector (size D),
        - the D_ row-wise dot products between the user's and the item's
          (D_ x K) latent-feature matrices.
    This vector goes through the fully-connected layers (ReLU between them),
    and the last layer has a single output unit.
    '''

    def __init__(self, N, M, D, D_, K, layers):
        '''
        variable definitions taken from paper: https://arxiv.org/pdf/1511.06443.pdf

        @param N:  Number of users
        @param M:  Number of items
        @param D:  size of latent-feature vectors
        @param D_: num rows in latent-features matrices
        @param K:  num cols in latent-feature matrices

        @param layers: sequence of hidden layer sizes; does not include input or output
        '''
        # nn.Module must be initialised before any submodule/parameter is assigned.
        super().__init__()
        self.N, self.M, self.D, self.D_, self.K = N, M, D, D_, K

        # nn.Parameter registers the latent factors with the module, so they are
        # returned by parameters() (seen by optimizers), saved in state_dict()
        # and moved by .to(device). Plain tensors with requires_grad=True are
        # silently left out of all of these.
        # Uniform [0, 1) initialisation, as in the original draft.
        self.userLatentVectors = nn.Parameter(torch.rand(N, D))
        self.itemLatentVectors = nn.Parameter(torch.rand(M, D))

        self.userLatentMatrices = nn.Parameter(torch.rand(N, D_, K))
        self.itemLatentMatrices = nn.Parameter(torch.rand(M, D_, K))

        hidden = list(layers)
        # Input width: user vector (D) + item vector (D) + D_ dot products.
        linear_inputs = [2 * D + D_] + hidden
        linear_outputs = hidden + [1]

        self.layers = nn.ModuleList(
            [nn.Linear(i, o) for i, o in zip(linear_inputs, linear_outputs)]
        )

    def forward(self, x):
        '''
        Predict one score per (user index, item index) pair.

        @param x: integer tensor of shape (B, 2) whose rows are
                  (user index, item index). A single pair of shape (2,)
                  is also accepted and treated as B == 1.
        @return:  tensor of shape (B,) holding the raw output of the last
                  layer (no output activation is applied).

        WARNING:
            - forward does not account for users/items outside the training
              data: an index outside [0, N) / [0, M) is not handled here
              (out-of-range positive indices make the tensor lookup fail,
              negative ones wrap around)
            - mitigations include returning smart averages
        '''
        if x.dim() == 1:
            x = x.unsqueeze(0)
        userIndex = x[:, 0].long()
        itemIndex = x[:, 1].long()

        # (B, D_, K) each: latent-feature matrices of the selected users/items.
        userMats = self.userLatentMatrices[userIndex]
        itemMats = self.itemLatentMatrices[itemIndex]
        # Row-wise dot products: for each row d, sum_k U'[d, k] * V'[d, k] -> (B, D_).
        latentDotProducts = (userMats * itemMats).sum(dim=-1)

        # Concatenate (not stack): the pieces have different lengths (D, D, D_).
        h = torch.cat([
            self.userLatentVectors[userIndex],
            self.itemLatentVectors[itemIndex],
            latentDotProducts,
        ], dim=1)

        for layer in self.layers[:-1]:
            h = F.relu(layer(h))

        # TODO: should last layer go through a sigmoid?
        return self.layers[-1](h).squeeze(-1)