In [1]:
# Load NextItemPredTransformer

from NextItemPredTransformer import NextItemPredTransformer
from NextItemPredTransformer import ModelDimensions
import torch
import torch.nn.functional as F


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Seed the global torch RNG so the random embedding tables below are identical
# on every Restart & Run All
SEED = 0
torch.manual_seed(SEED)
# Vocab size includes the SOS and EOS tokens
vocab_size = 100
# Width of each pre-trained embedding vector (shared by the item and user tables)
embedding_dim = 64
# Random stand-in for a pre-trained item embedding matrix, values drawn from [0, 1)
item_embedding_matrix = torch.rand(vocab_size, embedding_dim)
# Random stand-in for a pre-trained user embedding matrix
# NOTE(review): this table is also sized by vocab_size rather than by a user count — confirm that is intended
user_embedding_matrix = torch.rand(vocab_size, embedding_dim)

In [3]:
# Bundle the model configuration, then build the transformer from it
dims = ModelDimensions(
    # Vocabulary size and the embedding tables created earlier in the notebook
    vocab_size=vocab_size,
    pre_trained_item_embeddings=item_embedding_matrix,
    pre_trained_user_embeddings=user_embedding_matrix,
    use_concat_user_embedding=True,
    # Sequence length and network size
    # NOTE(review): 102 presumably equals max_seq_len (100) plus the SOS and EOS tokens — confirm
    model_input_length=102,
    model_hidden_dim=40,
    n_attention_heads=8,
    n_decoder_layers=3,
)

model = NextItemPredTransformer(dims)

In [4]:
# Training hyper-parameters: mini-batch size, learning rate and number of epochs
batch_size = 32
lr = 0.001
n_epochs = 10
# Cross-entropy loss, applied in the training loop to the selected outputs and the true item ids
criterion = torch.nn.CrossEntropyLoss()
# AdamW over every model parameter, using the learning rate above
optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr)


In [5]:

# Synthetic interaction data for ten users, with ids 0 through 9
user_ids = [*range(10)]
# One item history per user: ten integer ids sampled from [0, 100)
# NOTE(review): the upper bound 100 matches vocab_size today — keep the two in sync
user_items_vectors = [torch.randint(low=0, high=100, size=(10,)) for _ in user_ids]
# One rating-time vector per user: ten floats sampled from [0, 1)
user_rating_times_vectors = [torch.rand(10) for _ in user_ids]
# Sequence-length setting passed to NextItemPredDataset
max_seq_len = 100

In [6]:
from NextItemPredDataset import NextItemPredDataset

# Build the dataset from the per-user ids, item histories and rating times, together with max_seq_len
dataset = NextItemPredDataset(
    user_ids,
    user_items_vectors,
    user_rating_times_vectors,
    max_seq_len,
)

# Serve mini-batches in a fixed order (no shuffling) with no worker subprocesses
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    num_workers=0,
    shuffle=False,
)


  items[i, 1 : len(user_items) + 1] = torch.tensor(user_items)
  times[i, 1 : len(user_rating_times_vectors[i]) + 1] = torch.tensor(


In [8]:
# Training loop: for each batch, score only the output at that row's prediction index
for epoch in range(n_epochs):
    for batch in dataloader:
        # Unpack under a batch-local name. Reusing `user_ids` here would overwrite the
        # notebook-level list that the dataset cell reads, so re-running that cell after
        # training would build the dataset from a tensor instead of the original list.
        batch_user_ids, items, times, pred_index, true_item_id = batch
        # These are the actual times that we want to predict the next item for
        # (simulates user recommendation time).
        # NOTE(review): if pred_index is a 1-D tensor of length batch_size, this indexing
        # returns a (batch_size, batch_size) tensor rather than one time per row —
        # confirm this is the shape the model expects.
        pred_times = times[:, pred_index]
        # Zero the parameter gradients
        optimizer.zero_grad()
        # Forward pass
        outputs = model(items, batch_user_ids, times, pred_times)
        # For every row i, select outputs[i][pred_index[i]] along dim 1. The index is
        # broadcast over vocab_size entries of the last dim so gather picks a whole vector per row.
        gather_index = pred_index.unsqueeze(1).unsqueeze(2).expand(-1, -1, vocab_size)
        relevant_outputs = torch.gather(outputs, 1, gather_index)
        # Drop the singleton dim 1 left by gather
        relevant_outputs = torch.squeeze(relevant_outputs, dim=1)
        loss = criterion(relevant_outputs, true_item_id)
        # Backward and optimize
        loss.backward()
        optimizer.step()

    # Reports the loss of the last batch processed in this epoch
    print(f"Epoch [{epoch+1}/{n_epochs}], Loss: {loss.item():.4f}")

KeyboardInterrupt: 