In [1]:
import time

import torch
import torch.nn as nn

In [2]:
# Sanity check: report whether PyTorch can use a CUDA device (displays a bool).
torch.cuda.is_available()

True

In [3]:
def make_mlp(layer_sizes, output_activation=None, activation=nn.ReLU):
    """
    Assemble an nn.Sequential of Linear layers from consecutive size pairs.

    Every Linear except the final one is followed by a fresh `activation()`;
    the final Linear is followed by `output_activation()` only when
    `output_activation` is not None.

    Example : [13, 64, 16] -> Linear(13, 64) + ReLU + Linear(64, 16)

    Args:
        layer_sizes: sequence of ints [in, hidden..., out].
        output_activation: optional module class/factory appended after the
            final Linear (default None = no output activation).
        activation: module class/factory inserted between Linears
            (default nn.ReLU).

    Returns:
        nn.Sequential; empty when fewer than two sizes are given.
    """
    size_pairs = list(zip(layer_sizes[:-1], layer_sizes[1:]))
    final_idx = len(size_pairs) - 1
    modules = []
    for idx, (fan_in, fan_out) in enumerate(size_pairs):
        modules.append(nn.Linear(fan_in, fan_out))
        if idx != final_idx:
            # hidden layer: always followed by the hidden activation
            modules.append(activation())
        elif output_activation is not None:
            # output layer: optional activation (e.g. Sigmoid for probabilities)
            modules.append(output_activation())
    return nn.Sequential(*modules)

class DLRM(nn.Module):
    """
    DLRM-style model:

        dense features  -> bottom MLP -> dense embedding (D)
        sparse features -> per-field embeddings (D each)
        [dense embedding + sparse embeddings] -> pairwise dot-product interactions
        [dense embedding || interactions] -> top MLP -> output_activation

    By default the top MLP ends in ``nn.Sigmoid``, so ``forward`` returns
    probabilities in (0, 1), not logits. Pass ``output_activation=None`` to get
    raw logits (e.g. for training with ``nn.BCEWithLogitsLoss``).

    Args:
        num_dense_features: number of dense (float) input columns.
        sparse_feature_sizes: number of rows of each categorical field's
            embedding table; one table is created per entry.
        embedding_dim: embedding width D shared by every table and by the
            bottom MLP output.
        bottom_mlp_sizes: layer sizes after the input for the bottom MLP; the
            last entry of [num_dense_features, *bottom_mlp_sizes] must equal
            embedding_dim.
        top_mlp_sizes: layer sizes after the input for the top MLP; the last
            entry is the output width.
        output_activation: module class/factory applied after the top MLP's
            final Linear, or None for no activation. Defaults to nn.Sigmoid.

    Raises:
        ValueError: if the bottom MLP output width differs from embedding_dim.
    """
    def __init__(
        self,
        num_dense_features,
        sparse_feature_sizes,
        embedding_dim=16,
        bottom_mlp_sizes=(64, 16),
        top_mlp_sizes=(64, 32, 1),
        output_activation=nn.Sigmoid,
    ):
        super().__init__()

        self.num_dense_features = num_dense_features
        self.sparse_feature_sizes = list(sparse_feature_sizes)
        self.embedding_dim = embedding_dim

        # One embedding table per sparse (categorical) field.
        self.embeddings = nn.ModuleList(
            [
                nn.Embedding(num_embeddings=n, embedding_dim=embedding_dim)
                for n in self.sparse_feature_sizes
            ]
        )

        # Bottom MLP: num_dense_features -> ... -> embedding_dim. Its output is
        # stacked with the sparse embeddings, so the widths must match.
        bottom_layers = [num_dense_features] + list(bottom_mlp_sizes)
        if bottom_layers[-1] != embedding_dim:
            raise ValueError(
                f"Last bottom MLP size ({bottom_layers[-1]}) must equal embedding_dim ({embedding_dim})"
            )
        self.bottom_mlp = make_mlp(bottom_layers)

        # Interaction sizes: n_f vectors (dense embedding + one per sparse
        # field); each unordered pair (i < j) yields one dot product.
        self.n_sparse = len(self.sparse_feature_sizes)
        n_f = self.n_sparse + 1
        self.n_int = n_f * (n_f - 1) // 2

        # Strict upper-triangle index pairs (offset=1 excludes the diagonal) of
        # the n_f x n_f interaction matrix, e.g. n_f=3 -> rows [0,0,1], cols
        # [1,2,2]. Computed once instead of on every forward call; registered as
        # non-persistent buffers so .to(device) moves them with the model while
        # state_dict() stays unchanged.
        li, lj = torch.triu_indices(n_f, n_f, offset=1)
        self.register_buffer("_triu_rows", li, persistent=False)
        self.register_buffer("_triu_cols", lj, persistent=False)

        # Top MLP: (dense embedding + interactions) -> ... -> output width.
        top_input_dim = embedding_dim + self.n_int
        top_layers = [top_input_dim] + list(top_mlp_sizes)
        self.top_mlp = make_mlp(top_layers, output_activation)

    def forward(self, dense_x, sparse_x):
        """
        Score a batch.

        Args:
            dense_x: float tensor [B, num_dense_features].
            sparse_x: long index tensor [B, n_sparse]; column i holds row
                indices into the i-th embedding table.

        Returns:
            Tensor [B, top_mlp_sizes[-1]] ([B, 1] with the default sizes):
            probabilities with the default Sigmoid output activation, raw
            logits when the model was built with output_activation=None.

        Raises:
            ValueError: if sparse_x is not 2-D, has the wrong number of fields,
                or its batch size differs from dense_x.
        """
        if sparse_x.dim() != 2:
            raise ValueError(
                f"Expected sparse_x of shape [B, {self.n_sparse}], got {tuple(sparse_x.shape)}"
            )
        if sparse_x.size(1) != self.n_sparse:
            raise ValueError(
                f"Expected {self.n_sparse} sparse fields, got {sparse_x.size(1)}"
            )
        if sparse_x.size(0) != dense_x.size(0):
            raise ValueError(
                f"Batch size mismatch: dense_x has {dense_x.size(0)} rows, "
                f"sparse_x has {sparse_x.size(0)}"
            )

        # Bottom MLP on dense features (an nn.Sequential), e.g.
        # 13 -> Linear(13, 64) -> ReLU -> Linear(64, 16) -> 16.
        z0 = self.bottom_mlp(dense_x)  # [B, D]

        # Embedding lookup: column i of sparse_x indexes the i-th table.
        emb_list = [emb(sparse_x[:, i]) for i, emb in enumerate(self.embeddings)]  # n_sparse x [B, D]

        # Stack n_f = n_sparse + 1 tensors of [B, D] -> [B, n_f, D].
        z = torch.stack([z0] + emb_list, dim=1)

        # Batched Gram matrix: zz[b, i, j] = <z[b, i], z[b, j]> -> [B, n_f, n_f].
        zz = torch.bmm(z, z.transpose(1, 2))

        # zz is symmetric, so keep only the strict upper triangle: each unordered
        # pair once, no self-dot products -> [B, n_int].
        interactions = zz[:, self._triu_rows, self._triu_cols]

        # Concatenate dense embedding with interaction vector -> [B, D + n_int].
        top_input = torch.cat([z0, interactions], dim=1)

        return self.top_mlp(top_input)

In [58]:
# Synthetic benchmark whose feature layout mirrors the Criteo Kaggle dataset:
# 13 dense features and 26 categorical (sparse) fields.
# ---- config ----
SEED = 0
num_dense_features = 13
# Per-field embedding-table sizes (rows). Together these are ~33.8M rows, i.e.
# roughly 2.2 GB of fp32 embedding weights at embedding_dim=16.
sparse_feature_sizes = [1460, 583, 10131227, 2202608, 305, 24, 12517, 633, 3,
                        93145, 5683, 8351593, 3194, 27, 14992, 5461306, 10, 5652,
                        2173, 4, 7046547, 18, 15, 286181, 105, 142572]
assert len(sparse_feature_sizes) == 26, "expected 26 sparse fields"
embedding_dim = 16

# ---- reproducibility ----
# Seed before building the model so both parameter initialisation and the
# synthetic inputs below are identical across Restart & Run All.
torch.manual_seed(SEED)

# ---- device ----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ---- Model ----
model = DLRM(
    num_dense_features=num_dense_features,
    sparse_feature_sizes=sparse_feature_sizes,
    embedding_dim=embedding_dim,
    bottom_mlp_sizes=(512, 256, 64, 16),
    top_mlp_sizes=(512, 256, 1),
).to(device)

# ---- inference-only settings ----
batch_size = 65536
num_samples = 6_000_000
# Ceiling division: the last batch may be partial, but is never empty.
num_batches = (num_samples + batch_size - 1) // batch_size

model.eval()
preds = []
if device.type == "cuda":
    torch.cuda.synchronize()
# NOTE: the timed region also includes synthetic-input generation and the
# per-batch device-to-host copy, not only the model forward pass.
start = time.perf_counter()
with torch.no_grad():
    for b in range(num_batches):
        current_bs = min(batch_size, num_samples - b * batch_size)  # always >= 1 here
        # random dense and sparse inputs
        dense_x = torch.randn(current_bs, num_dense_features, device=device)
        sparse_x = torch.stack(
            [torch.randint(0, n, (current_bs,), device=device) for n in sparse_feature_sizes],
            dim=1,
        )

        # DLRM's top MLP ends in nn.Sigmoid, so the model output is already a
        # probability; no extra sigmoid is applied here.
        probs = model(dense_x, sparse_x).squeeze(1)
        preds.append(probs.cpu())

        if b < 3:
            print(f"Batch {b+1}/{num_batches} example probs: {probs[:3].tolist()}")

batches_run = len(preds)
if device.type == "cuda":
    torch.cuda.synchronize()
elapsed = time.perf_counter() - start
print(f"Inference time: {elapsed:.3f}s for {batches_run} batches (batch_size={batch_size}).")

# Guard against num_samples == 0, where torch.cat([]) would raise.
preds = torch.cat(preds) if preds else torch.empty(0)
print(f"Finished inference for {preds.numel()} samples (batch_size={batch_size}).")

Using device: cuda
Batch 1/92 example probs: [0.5394262075424194, 0.45091965794563293, 0.5583493113517761]
Batch 2/92 example probs: [0.5113988518714905, 0.5717505216598511, 0.5469425916671753]
Batch 3/92 example probs: [0.4991457760334015, 0.42554280161857605, 0.44580867886543274]
Inference time: 0.296s for 92 batches (batch_size=65536).
Finished inference for 6000000 samples (batch_size=65536).
