In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

A second-order factorization machine (FM) is defined as:

$$
\begin{align}
\hat{y} = w_0 + \sum_{i=1}^n w_i x_i + \sum_{i=1}^n \sum_{j=i+1}^n \langle v_i, v_j \rangle x_i x_j
\end{align}
$$

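where $n$ is the number of features, $w_0$ is the global bias term, $w_i$ is the linear weight of the $i$-th feature, $v_i \in \mathbb{R}^k$ is the latent factor vector of the $i$-th feature ($k$ is the number of latent factors), and $\langle \cdot, \cdot \rangle$ is the inner product.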

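The pairwise interaction term can be computed in $O(kn)$ time instead of $O(kn^2)$ using the identity

$$
\sum_{i=1}^n \sum_{j=i+1}^n \langle v_i, v_j \rangle x_i x_j = \frac{1}{2} \sum_{f=1}^k \left[ \left( \sum_{i=1}^n v_{i,f} x_i \right)^2 - \sum_{i=1}^n v_{i,f}^2 x_i^2 \right]
$$

In matrix form, with $X$ the batch of inputs (one sample per row) and $V \in \mathbb{R}^{n \times k}$ the factor matrix, the interaction term is half the row-wise sum of $(XV)^{\circ 2} - (X^{\circ 2})(V^{\circ 2})$, where $^{\circ 2}$ denotes element-wise squaring.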

Useful references:
https://www.kaggle.com/code/gennadylaptev/factorization-machine-implemented-in-pytorch


In [4]:
#factorization machine in pytorch
class FM(nn.Module):
    """Second-order factorization machine.

    For each input row x it computes
        y_hat = w_0 + sum_i w_i x_i + sum_{i<j} <v_i, v_j> x_i x_j
    where the pairwise term is evaluated with the O(k*n) identity
        0.5 * sum_f [ (sum_i v_if x_i)^2 - sum_i v_if^2 x_i^2 ].
    """

    def __init__(self, n, k):
        """
        n: number of features
        k: number of latent factors
        """
        super().__init__()
        self.n = n
        self.k = k
        # First-order part: weight w_i per feature plus the global bias w_0.
        self.linear = nn.Linear(n, 1, bias=True)
        # Latent factor matrix V (n x k): row i is v_i, drawn from N(0, 1).
        self.v = nn.Parameter(torch.randn(n, k))

    def forward(self, x):
        """Score samples with the factorization machine.

        x: tensor of shape (..., n), typically (batch, n); dtype must match
           the module's parameters.
        Returns a tensor of shape (..., 1): linear part + pairwise interactions.
        """
        # w_0 + sum_i w_i x_i  -> (..., 1)
        linear = self.linear(x)
        # (sum_i v_if x_i)^2 for every latent factor f  -> (..., k)
        square_of_sum = (x @ self.v).pow(2)
        # sum_i v_if^2 x_i^2 for every latent factor f  -> (..., k)
        sum_of_square = x.pow(2) @ self.v.pow(2)
        # Reduce over the k factors (last dim), not the batch; keepdim so the
        # result lines up with the (..., 1) linear output.
        interaction = 0.5 * (square_of_sum - sum_of_square).sum(dim=-1, keepdim=True)
        return linear + interaction