In [None]:
# default_exp models.hofm

# HOFM
> A PyTorch implementation of Higher-Order Factorization Machines.

In [None]:
#hide
from nbdev.showdoc import *
from fastcore.nb_imports import *
from fastcore.test import *

In [None]:
#export
import torch

from recohut.layers.common import FeaturesEmbedding, FeaturesLinear

In [None]:
#export
class FactorizationMachine(torch.nn.Module):
    """Second-order (pairwise) interaction term of a factorization machine.

    Uses the identity ``0.5 * ((sum_i v_i)^2 - sum_i v_i^2)``, applied
    independently to every embedding coordinate, so no explicit loop over
    field pairs is needed.
    """

    def __init__(self, reduce_sum=True):
        """
        :param reduce_sum: when true, ``forward`` additionally sums over the
            embedding axis and returns one value per sample.
        """
        super().__init__()
        self.reduce_sum = reduce_sum

    def forward(self, x):
        """
        :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)``
        :return: tensor of size ``(batch_size, 1)`` if ``reduce_sum`` is set,
            otherwise ``(batch_size, embed_dim)``
        """
        # (sum over fields)^2 minus sum over fields of squares, per coordinate.
        interaction = x.sum(dim=1).pow(2) - x.pow(2).sum(dim=1)
        if not self.reduce_sum:
            return 0.5 * interaction
        # Collapse the embedding axis but keep it as a size-1 dimension.
        return 0.5 * interaction.sum(dim=1, keepdim=True)

class AnovaKernel(torch.nn.Module):
    """ANOVA kernel of a fixed degree, computed by dynamic programming.

    For every sample and every embedding coordinate, ``forward`` returns the
    sum, over all strictly increasing index tuples ``i_1 < ... < i_order`` of
    fields, of the product ``x[i_1] * ... * x[i_order]``.
    """

    def __init__(self, order, reduce_sum=True):
        """
        :param order: degree of the kernel (number of fields multiplied
            together in each term); must be at least 1
        :param reduce_sum: when true, ``forward`` additionally sums over the
            embedding axis and returns one value per sample
        :raises ValueError: if ``order`` is smaller than 1
        """
        super().__init__()
        # Without this check an order below 1 would skip the loop in
        # ``forward`` entirely and fail there on the undefined table.
        if order < 1:
            raise ValueError(f'invalid order: {order}')
        self.order = order
        self.reduce_sum = reduce_sum

    def forward(self, x):
        """
        :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)``
        :return: tensor of size ``(batch_size, 1)`` if ``reduce_sum`` is set,
            otherwise ``(batch_size, embed_dim)``
        """
        batch_size, num_fields, embed_dim = x.shape
        table_shape = (batch_size, num_fields + 1, embed_dim)
        # Work in at least float32 (as before), but keep wider input dtypes
        # such as float64 instead of silently truncating them to float32.
        dtype = torch.promote_types(x.dtype, torch.float32)
        # Allocate directly on the input's device to avoid a CPU round-trip.
        a_prev = torch.ones(table_shape, dtype=dtype, device=x.device)
        for t in range(self.order):
            a = torch.zeros(table_shape, dtype=dtype, device=x.device)
            # Extend every degree-t term ending before field j with field j;
            # positions <= t stay zero because they cannot hold t+1 fields.
            a[:, t+1:, :] += x[:, t:, :] * a_prev[:, t:-1, :]
            # Prefix sum: a[:, j] becomes the degree-(t+1) kernel over the
            # first j fields.
            a = torch.cumsum(a, dim=1)
            a_prev = a
        # The last table row covers all fields.
        if self.reduce_sum:
            return torch.sum(a[:, -1, :], dim=-1, keepdim=True)
        else:
            return a[:, -1, :]

class HOFM(torch.nn.Module):
    """
    Higher-Order Factorization Machine, implemented in PyTorch.

    The prediction is the sigmoid of a sum of terms: a linear term, a
    second-order factorization-machine term (when ``order >= 2``) and one
    ANOVA-kernel term for each degree from 3 up to ``order``. Every
    interaction degree reads its own ``embed_dim``-wide slice of a single
    shared embedding table.

    Reference:
        M Blondel, et al. Higher-Order Factorization Machines, 2016.
    """

    def __init__(self, field_dims, order, embed_dim):
        """
        :param field_dims: forwarded unchanged to ``FeaturesLinear`` and
            ``FeaturesEmbedding`` (presumably the per-field cardinalities;
            confirm against those layers)
        :param order: highest interaction degree to model; must be >= 1
        :param embed_dim: embedding width used by each interaction degree
        :raises ValueError: if ``order`` is smaller than 1
        """
        super().__init__()
        if order < 1:
            raise ValueError(f'invalid order: {order}')
        self.order = order
        self.embed_dim = embed_dim
        self.linear = FeaturesLinear(field_dims)
        if order < 2:
            # Purely linear model: no embedding or interaction modules.
            return
        # One embed_dim-wide slice per interaction degree 2..order.
        total_width = embed_dim * (order - 1)
        self.embedding = FeaturesEmbedding(field_dims, total_width)
        self.fm = FactorizationMachine(reduce_sum=True)
        if order >= 3:
            anova_terms = torch.nn.ModuleList()
            for degree in range(3, order + 1):
                anova_terms.append(AnovaKernel(order=degree, reduce_sum=True))
            self.kernels = anova_terms

    def forward(self, x):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        :return: sigmoid of the summed linear and interaction terms
        """
        logit = self.linear(x).squeeze(1)
        if self.order < 2:
            return torch.sigmoid(logit)
        width = self.embed_dim
        embedded = self.embedding(x)
        # The first slice feeds the pairwise FM term.
        logit += self.fm(embedded[:, :, :width]).squeeze(1)
        # Slice k + 1 feeds the k-th kernel, i.e. the degree-(k + 3) term.
        for k in range(self.order - 2):
            start = (k + 1) * width
            stop = (k + 2) * width
            logit += self.kernels[k](embedded[:, :, start:stop]).squeeze(1)
        return torch.sigmoid(logit)

> **References:**
- M. Blondel, et al. Higher-Order Factorization Machines, 2016.
- https://github.com/rixwew/pytorch-fm/blob/master/torchfm/model/hofm.py

In [None]:
#hide
# Reproducibility footer: (re)load the `watermark` IPython extension, then print
# its report for this notebook. The flags on the second line select what is
# reported; their meanings are defined by the watermark extension itself.
%reload_ext watermark
%watermark -a "Sparsh A." -m -iv -u -t -d -p recohut