In [1]:
import torch
import torch.nn as nn

In [2]:
# Is a CUDA device visible to PyTorch? The bare last expression is shown as the cell output.
cuda_available = torch.cuda.is_available()
cuda_available

True

In [8]:
def make_mlp(layer_sizes, activation=nn.ReLU, output_activation=None):
    """
    Assemble a feed-forward MLP as an ``nn.Sequential``.

    Each pair of consecutive entries in ``layer_sizes`` becomes one ``nn.Linear``.
    ``activation()`` follows every Linear except the last one; the last Linear is
    followed by ``output_activation()`` only when that argument is not None.

    Example : [13, 64, 16] -> Linear(13, 64) + ReLU + Linear(64, 16)

    Args:
        layer_sizes: sequence of layer widths, input width first. With fewer
            than two entries no layers are created (empty ``nn.Sequential``).
        activation: module class, instantiated with no arguments, used after
            each hidden Linear.
        output_activation: optional module class appended after the final Linear.
    """
    pairs = list(zip(layer_sizes[:-1], layer_sizes[1:]))
    last_idx = len(pairs) - 1
    modules = []
    for idx, (fan_in, fan_out) in enumerate(pairs):
        modules.append(nn.Linear(fan_in, fan_out))
        if idx != last_idx:
            modules.append(activation())
        elif output_activation is not None:
            modules.append(output_activation())
    return nn.Sequential(*modules)

class DLRM(nn.Module):
    """
    Minimal DLRM-style model:
    
        dense features -> bottom MLP -> dense embedding (D)
        sparse features -> per-field embeddings (D)
        [dense embedding + sparse embeddings] -> pairwise dot-product interactions
        [dense embedding || interactions] -> top MLP -> logit

    Args:
        num_dense_features: width of the dense input ``dense_x``.
        sparse_feature_sizes: vocabulary size of each categorical field; one
            ``nn.Embedding`` table of width ``embedding_dim`` is created per entry.
        embedding_dim: D, shared by every embedding table and by the bottom-MLP
            output so that all vectors can be dotted with each other.
        bottom_mlp_sizes: widths of the bottom MLP after the input layer; the
            final width of ``[num_dense_features, *bottom_mlp_sizes]`` must equal
            ``embedding_dim``.
        top_mlp_sizes: widths of the top MLP after its input layer; the last
            entry is the number of output logits (1 with the default).

    Raises:
        ValueError: if the bottom MLP does not end at ``embedding_dim``.
    """
    def __init__(
        self,
        num_dense_features,
        sparse_feature_sizes,
        embedding_dim=16,
        bottom_mlp_sizes=(64, 16),
        top_mlp_sizes=(64, 32, 1),
    ):
        super().__init__()
        
        self.num_dense_features = num_dense_features
        self.sparse_feature_sizes = list(sparse_feature_sizes)
        self.embedding_dim = embedding_dim
        
        # embedding tables for each sparse feature
        self.embeddings = nn.ModuleList(
            [
                nn.Embedding(num_embeddings=n, embedding_dim=embedding_dim)
                for n in self.sparse_feature_sizes
            ]
        )
        
        # Bottom MLP: num_dense_features -> ... -> embedding_dim
        bottom_layers = [num_dense_features] + list(bottom_mlp_sizes)
        if bottom_layers[-1] != embedding_dim:
            raise ValueError(
                f"Last bottom MLP size ({bottom_layers[-1]}) must equal embedding_dim ({embedding_dim})"
            )
        self.bottom_mlp = make_mlp(bottom_layers)
        
        # Interaction sizes: n_f vectors (1 dense + n_sparse) give n_f choose 2 unique pairs
        self.n_sparse = len(self.sparse_feature_sizes)
        n_f = self.n_sparse + 1
        self.n_int = n_f * (n_f - 1) // 2
        
        # Top MLP: (dense embedding + interactions) -> ... -> logits
        top_input_dim = embedding_dim + self.n_int
        top_layers = [top_input_dim] + list(top_mlp_sizes)
        self.top_mlp = make_mlp(top_layers)
        
    def forward(self, dense_x, sparse_x):
        """
        dense_x  : [B, num_dense_features] (float)
        sparse_x : [B, n_sparse]           (long indices per field)
        Returns  : [B, 1] logits with the default top_mlp_sizes
                   (in general the width is the last entry of top_mlp_sizes)

        Raises:
            ValueError: if sparse_x is not 2-D, has the wrong number of fields,
                or its batch size differs from dense_x.
        """
        if sparse_x.dim() != 2:
            raise ValueError(
                f"Expected sparse_x to be 2-D [B, n_sparse], got shape {tuple(sparse_x.shape)}"
            )
        if sparse_x.size(1) != self.n_sparse:
            raise ValueError(
                f"Expected {self.n_sparse} sparse fields, got {sparse_x.size(1)}"
            )
        if sparse_x.size(0) != dense_x.size(0):
            raise ValueError(
                f"Batch size mismatch: dense_x has {dense_x.size(0)} rows, sparse_x has {sparse_x.size(0)}"
            )
            
        # Bottom MLP on dense features
        z0 = self.bottom_mlp(dense_x)
        
        # Column i of sparse_x indexes embedding table i; each lookup is [B, D]
        emb_list = [emb(sparse_x[:, i]) for i, emb in enumerate(self.embeddings)]
            
        # Stack dense + sparse embeddings: [B, n_f, D]
        z = torch.stack([z0] + emb_list, dim=1)
        
        # All pairwise dot products in one batched matmul: [B, n_f, n_f]
        zz = torch.bmm(z, z.transpose(1, 2))
        
        # Strict upper triangle (i < j): each unordered pair once, no self-dots -> [B, n_int].
        # Indices are built on zz's device so indexing a CUDA tensor needs no host->device copy.
        n_f = zz.size(1)
        li, lj = torch.triu_indices(n_f, n_f, offset=1, device=zz.device)
        interactions = zz[:, li, lj]
        
        # Concatenate dense embedding with interaction vector
        top_input = torch.cat([z0, interactions], dim=1)    # [B, D + n_int]
        
        # Top MLP to get logits
        return self.top_mlp(top_input)

In [11]:
# ---- config ----
SEED = 0
num_dense_features = 13
sparse_feature_sizes = [1000, 5000, 10000]    # 3 sparse fields (vocabulary size per field)
embedding_dim = 16
batch_size = 32

# Seed before building the model and sampling the dummy batch so weight init,
# data and the printed losses are reproducible on Restart & Run All.
# torch.manual_seed seeds the RNG on all devices (CPU and CUDA).
torch.manual_seed(SEED)

# ---- device ----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ---- Model ----
model = DLRM(
    num_dense_features=num_dense_features,
    sparse_feature_sizes=sparse_feature_sizes,
    embedding_dim=embedding_dim,
    # DLRM requires the bottom MLP to end at embedding_dim; derive it instead of repeating the number
    bottom_mlp_sizes=(64, embedding_dim),
    top_mlp_sizes=(64, 32, 1),    # last size 1 -> one logit per sample
).to(device)

# ---- dummy batch ----
# dense numerical features
dense_x = torch.randn(batch_size, num_dense_features, device=device)

# sparse categorical features: one ID per field per sample, each within its field's vocabulary
sparse_x = torch.stack(
    [torch.randint(0, n, (batch_size,), device=device) for n in sparse_feature_sizes],
    dim=1,
)    # [B, n_sparse]

# binary labels (e.g., click / no click)
targets = torch.randint(0, 2, (batch_size,), device=device).float()

# ---- training step example ----
criterion = nn.BCEWithLogitsLoss()    # expects raw logits, applies sigmoid internally
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

model.train()
for step in range(3):
    optimizer.zero_grad()
    logits = model(dense_x, sparse_x).squeeze(1)    # [B]
    loss = criterion(logits, targets)
    loss.backward()
    optimizer.step()
    print(f"Step {step} | loss = {loss.item():.4f}")
    
# ---- inference example ----
model.eval()
with torch.no_grad():
    logits = model(dense_x, sparse_x).squeeze(1)
    probs = torch.sigmoid(logits)
    print("Predicted CTR probabilities:", probs[:5])

Using device: cuda
Step 0 | loss = 0.7447
Step 1 | loss = 0.7320
Step 2 | loss = 0.7206
Predicted CTR probabilities: tensor([0.4885, 0.4028, 0.4506, 0.4592, 0.4927], device='cuda:0')
