In [1]:
from nrms import NRMS

model = NRMS(
    vocab_size=50000
)

In [13]:
import torch

# Settings
B = 2       # batch size (users)
N = 5       # max clicked articles per user
D = 7       # number of total candidate/news documents
L = 10      # max token length per article
vocab_size = 10000  # vocabulary size

PAD_TOKEN_ID = 0

def generate_padded_articles(batch_size, num_articles, max_len, vocab_size, pad_token_id=0):
    """Generate a batch of padded articles and corresponding masks."""
    token_ids = torch.full((batch_size, num_articles, max_len), pad_token_id, dtype=torch.long)
    token_mask = torch.ones((batch_size, num_articles, max_len), dtype=torch.bool)

    for b in range(batch_size):
        for n in range(num_articles):
            real_len = torch.randint(3, max_len + 1, (1,)).item()
            tokens = torch.randint(1, vocab_size, (real_len,))
            token_ids[b, n, :real_len] = tokens
            token_mask[b, n, :real_len] = False  # False means "not masked" (valid token)

    return token_ids, token_mask

def generate_padded_docs(num_docs, max_len, vocab_size, pad_token_id=0):
    """Generate a list of padded news documents and masks."""
    token_ids = torch.full((num_docs, max_len), pad_token_id, dtype=torch.long)
    token_mask = torch.ones((num_docs, max_len), dtype=torch.bool)

    for d in range(num_docs):
        real_len = torch.randint(5, max_len + 1, (1,)).item()
        tokens = torch.randint(1, vocab_size, (real_len,))
        token_ids[d, :real_len] = tokens
        token_mask[d, :real_len] = False

    return token_ids, token_mask


# Generate clicked history (B, N, L) and mask (B, N, L)
clicked_token_ids, clicked_token_mask = generate_padded_articles(B, N, L, vocab_size, PAD_TOKEN_ID)

# Generate all documents (D, L) and mask (D, L)
all_doc_token_ids, all_doc_token_mask = generate_padded_docs(D, L, vocab_size, PAD_TOKEN_ID)

# Model init & forward
model = NRMS(vocab_size=vocab_size)
scores = model(
    clicked_token_ids=clicked_token_ids,
    clicked_token_mask=clicked_token_mask,
    all_doc_token_ids=all_doc_token_ids,
    all_doc_token_mask=all_doc_token_mask,
)

print(f'{clicked_token_ids=}')
print(f'{clicked_token_ids.shape=}')
print(f'{clicked_token_mask=}')
print(f'{clicked_token_mask.shape=}')
print(f'{all_doc_token_ids=}')
print(f'{all_doc_token_ids.shape=}')
print(f'{all_doc_token_mask=}')
print(f'{all_doc_token_mask.shape=}')
print(f'{scores.shape=}')
scores

clicked_token_ids=tensor([[[5959, 3717, 7365, 3245,    0,    0,    0,    0,    0,    0],
         [3218, 8146, 3787, 3019, 7701,    0,    0,    0,    0,    0],
         [3905,  347, 9549, 1592, 2368, 8427, 8634,    0,    0,    0],
         [1660,  814, 3688,  990, 4592, 5597, 9580, 8579,    0,    0],
         [2561, 5567, 1911, 1961,    0,    0,    0,    0,    0,    0]],

        [[4245, 4673, 6280, 7812, 1732, 4769, 8557, 2882,    0,    0],
         [2436,  120, 9091, 4894,    0,    0,    0,    0,    0,    0],
         [7259, 4478, 7088, 1982, 8416, 4651, 7126, 3171, 5298,    0],
         [4641, 9458,   79, 7506,  644, 3165, 3375, 2877, 4051,  402],
         [4155, 9618, 9704,    0,    0,    0,    0,    0,    0,    0]]])
clicked_token_ids.shape=torch.Size([2, 5, 10])
clicked_token_mask=tensor([[[False, False, False, False,  True,  True,  True,  True,  True,  True],
         [False, False, False, False, False,  True,  True,  True,  True,  True],
         [False, False, False, False, Fa

tensor([[0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1., 0.]], grad_fn=<SoftmaxBackward0>)