In [1]:
import torch
import torch.nn as nn
from common import MLP, CIN

In [None]:
class xDeepFM(nn.Module):
    """xDeepFM model: first-order term, optional FM second-order term, MLP and CIN.

    The outputs of the enabled components are concatenated along ``dim=1`` and
    projected to a single logit per row by a final linear layer.

    NOTE(review): ``FirstOrder`` and ``SecondOrder`` are not imported in the
    visible part of this file; they must be in scope when the model is built.
    """

    def __init__(self, params, get_embeddings=True, use_batchnorm=True, use_dropout=True, use_fm_second_order=False):
        """Build the sub-modules and the final linear layer.

        Args:
            params: Mapping read here for the keys ``'device'``,
                ``'field_size'``, ``'embedding_size'``, ``'split_half'``,
                ``'cin_hidden_dims'`` and ``'hidden_dims'``. It is also passed
                unchanged to ``FirstOrder``, ``SecondOrder``, ``MLP`` and ``CIN``.
            get_embeddings: Forwarded to ``SecondOrder``.
            use_batchnorm: Forwarded to ``MLP``.
            use_dropout: Forwarded to ``MLP``.
            use_fm_second_order: When true, the second-order output is included
                in the concatenation fed to the final linear layer.
        """
        # Imported locally: `reduce` is not a builtin in Python 3 and the file
        # does not import it at module level.
        from functools import reduce

        super().__init__()
        self.device = params['device']
        self.mlp_input_dim = params['field_size'] * params['embedding_size']
        self.use_fm_second_order = use_fm_second_order

        self.first_order = FirstOrder(params)
        self.second_order = SecondOrder(params, get_embeddings=get_embeddings)
        self.mlp = MLP(params, use_batchnorm=use_batchnorm, use_dropout=use_dropout)
        self.cin = CIN(params)

        cin_hidden_dims = params['cin_hidden_dims']
        if params['split_half']:
            # NOTE(review): this fold halves the running total again at every
            # step, so for three or more layers it is not a plain sum of
            # halves. Kept as-is because the true width depends on CIN, which
            # is defined elsewhere -- confirm against CIN's output size.
            cin_output_size = reduce(lambda acc, dim: acc // 2 + dim // 2, cin_hidden_dims)
        else:
            cin_output_size = sum(cin_hidden_dims)

        # Width of the concatenated tensor built in forward(): field_size for
        # the first-order part, embedding_size for the optional second-order
        # part, then the last MLP hidden width and the CIN width.
        concat_size = params['field_size'] + params['hidden_dims'][-1] + cin_output_size
        if self.use_fm_second_order:
            concat_size += params['embedding_size']
        self.concat_layer = nn.Linear(concat_size, 1).to(self.device)

    def forward(self, features):
        """Compute logits for a batch.

        Args:
            features: Mapping with the keys ``'feature_idx'`` and
                ``'feature_values'``; both are passed to the first-order and
                second-order sub-modules.

        Returns:
            The output of the final linear layer applied to the concatenated
            component outputs (one value per row).
        """
        feature_idx = features['feature_idx']
        feature_values = features['feature_values']

        first_order = self.first_order(feature_values, feature_idx)
        second_order, embeddings = self.second_order(feature_values, feature_idx)

        # Flatten each row's embeddings into a single vector for the MLP.
        mlp_input = embeddings.reshape(embeddings.shape[0], self.mlp_input_dim)
        mlp_out = self.mlp(mlp_input)

        cin_out = self.cin(embeddings)

        if self.use_fm_second_order:
            parts = [first_order, second_order, mlp_out, cin_out]
        else:
            parts = [first_order, mlp_out, cin_out]
        logits = self.concat_layer(torch.cat(parts, dim=1))

        return logits