# In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
class FeatureInteraction(nn.Module):
    """Pairwise (second-order) interaction layer over a flat feature vector.

    Every scalar in an input row is treated as one feature. The layer returns
    the products ``x[i] * x[j]`` for all index pairs in the upper triangle of
    the outer product ``x x^T``, in row-major order:

    * ``self_interaction=True``  -> pairs with ``i <= j`` (diagonal included),
      ``d * (d + 1) // 2`` outputs;
    * ``self_interaction=False`` -> pairs with ``i < j`` (diagonal excluded),
      ``d * (d - 1) // 2`` outputs.

    Args:
        self_interaction: Whether to keep the squared terms ``x[i] * x[i]``.
    """

    def __init__(self, self_interaction):
        super().__init__()
        self.self_interaction = self_interaction

    def output_dim(self, feature_dim):
        """Return how many interaction terms ``forward`` yields for ``feature_dim`` inputs."""
        if self.self_interaction:
            return feature_dim * (feature_dim + 1) // 2
        return feature_dim * (feature_dim - 1) // 2

    def forward(self, inputs):
        """Compute the pairwise products of the features in each row.

        Args:
            inputs: Tensor of shape ``(batch, feature_dim)``.

        Returns:
            Tensor of shape ``(batch, self.output_dim(feature_dim))``.
        """
        feature_dim = inputs.shape[1]
        # Offset 0 keeps the diagonal (self-interaction); offset 1 drops it.
        offset = 0 if self.self_interaction else 1
        rows, cols = torch.triu_indices(
            feature_dim, feature_dim, offset=offset, device=inputs.device
        )
        # Gather only the needed pairs instead of materialising the full
        # (batch, d, d) outer product and selecting it with a boolean mask.
        return inputs[:, rows] * inputs[:, cols]


class DLRM(nn.Module):
    """DLRM-style click/score model.

    Dense features pass through a bottom MLP. Sparse features are looked up in
    a single embedding table shared by all sparse fields. The bottom-MLP output
    and the flattened embeddings are concatenated, passed through
    ``FeatureInteraction``, concatenated again with the bottom-MLP output, and
    scored by a top MLP that ends in a single linear unit.

    Args:
        sparse_feature_number: Number of sparse (categorical) fields per example.
        dense_feature_number: Number of dense (continuous) inputs per example.
        num_embeddings: Rows in the shared embedding table.
        embed_dim: Embedding width.
        bottom_mlp_dims: Hidden widths of the bottom MLP; each entry adds a
            ``Linear`` + ``ReLU`` layer, and the last entry is its output width.
        top_mlp_dims: Hidden widths of the top MLP; each entry adds a
            ``Linear`` + ``ReLU`` layer, followed by a final ``Linear(..., 1)``.
        self_interaction: Forwarded to ``FeatureInteraction``.
    """

    def __init__(self, sparse_feature_number, dense_feature_number, num_embeddings,
                 embed_dim, bottom_mlp_dims, top_mlp_dims, self_interaction):
        super().__init__()
        self.embed_dim = embed_dim
        self.sparse_feature_number = sparse_feature_number
        self.bottom_mlp_output_dim = bottom_mlp_dims[-1]
        # One embedding table is shared by every sparse field.
        self.embedding = nn.Embedding(num_embeddings, embed_dim)
        self.layer_feature_interaction = FeatureInteraction(self_interaction)
        self.bottom_mlp = self._make_mlp(dense_feature_number, bottom_mlp_dims)
        feature_interaction_input_dim = (
            self.bottom_mlp_output_dim + sparse_feature_number * embed_dim
        )
        # Ask the interaction layer for its width so the top MLP stays in sync
        # with the self_interaction setting.
        interaction_output_dim = self.layer_feature_interaction.output_dim(
            feature_interaction_input_dim
        )
        input_dim_for_top_mlp = interaction_output_dim + self.bottom_mlp_output_dim
        self.top_mlp = self._make_mlp(input_dim_for_top_mlp, top_mlp_dims, out_dim=1)

    @staticmethod
    def _make_mlp(in_dim, hidden_dims, out_dim=None):
        """Build ``Linear+ReLU`` per entry of ``hidden_dims``, then an optional final ``Linear``.

        The layer order matches the original hand-written ``nn.Sequential``
        blocks, so state_dict keys are unchanged for two-entry dim lists.
        """
        layers = []
        for width in hidden_dims:
            layers.append(nn.Linear(in_dim, width))
            layers.append(nn.ReLU())
            in_dim = width
        if out_dim is not None:
            layers.append(nn.Linear(in_dim, out_dim))
        return nn.Sequential(*layers)

    def forward(self, x_sparse, x_dense):
        """Score a batch of examples.

        Args:
            x_sparse: Integer index tensor of shape
                ``(batch, sparse_feature_number)`` into the shared embedding table.
            x_dense: Float tensor of shape ``(batch, dense_feature_number)``.

        Returns:
            Tensor of shape ``(batch, 1)`` of unnormalised scores (no final
            activation is applied).
        """
        batch_size = x_sparse.shape[0]
        embed_x = self.embedding(x_sparse).reshape(batch_size, -1)
        bottom_mlp_output = self.bottom_mlp(x_dense)
        concat_first = torch.cat([bottom_mlp_output, embed_x], dim=-1)
        interaction = self.layer_feature_interaction(concat_first)
        concat_second = torch.cat([interaction, bottom_mlp_output], dim=-1)
        # top_mlp ends in Linear(..., 1), so this is already (batch, 1). The old
        # squeeze().unsqueeze(1) collapsed a batch of 1 to a 0-dim tensor and crashed.
        return self.top_mlp(concat_second)