In [25]:
# runs in jupyter container
import os
import torch
import pandas as pd
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn

DATA_DIR = os.getenv("MOVIELENS_DATA_DIR", "/mnt/Project-37")
print("Data dir:", DATA_DIR)

Data dir: /mnt/Project-37


In [33]:
# runs in jupyter container
eval_path = os.path.join(DATA_DIR, "evaluation", "movielens_192m_eval.txt")
eval_df = pd.read_csv(eval_path, sep="\t", names=["userId", "itemId"])
print("Before filtering:", eval_df.shape)

eval_df = eval_df[(eval_df["userId"] < 10000) & (eval_df["itemId"] < 10000)]
print("After filtering:", eval_df.shape)

Before filtering: (33600781, 2)
After filtering: (172949, 2)


In [34]:
# runs in jupyter container
class EvalDataset(Dataset):
    def __init__(self, dataframe):
        self.users = torch.tensor(dataframe['userId'].values, dtype=torch.long)
        self.items = torch.tensor(dataframe['itemId'].values, dtype=torch.long)

    def __len__(self):
        return len(self.users)

    def __getitem__(self, idx):
        return self.users[idx], self.items[idx]

eval_dataset = EvalDataset(eval_df)
eval_loader = DataLoader(eval_dataset, batch_size=128, shuffle=False)


In [35]:
# runs in jupyter container
class PointWiseFeedForward(nn.Module):
    def __init__(self, conv_dims, dropout_rate):
        super().__init__()
        self.conv1 = nn.Conv1d(conv_dims, conv_dims, kernel_size=1)
        self.dropout1 = nn.Dropout(p=dropout_rate)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv1d(conv_dims, conv_dims, kernel_size=1)
        self.dropout2 = nn.Dropout(p=dropout_rate)

    def forward(self, inputs):
        outputs = self.dropout2(self.conv2(self.relu(self.dropout1(self.conv1(inputs.transpose(-1, -2))))))
        outputs = outputs.transpose(-1, -2)
        outputs += inputs
        return outputs

In [36]:
# runs in jupyter container
class TransformerBlock(nn.Module):
    def __init__(self, seq_max_len, embedding_dim, num_heads, dropout_rate):
        super().__init__()
        self.seq_max_len = seq_max_len
        self.embedding_dim = embedding_dim
        self.mha = nn.MultiheadAttention(embedding_dim, num_heads, dropout_rate, batch_first=True)
        self.ffn = PointWiseFeedForward(embedding_dim, dropout_rate)
        self.layernorm = nn.LayerNorm(embedding_dim)

    def forward(self, input):
        mask = ~torch.tril(torch.ones((input.shape[1], input.shape[1]), dtype=torch.bool, device=input.device))
        query = self.layernorm(input)
        output, _ = self.mha(query, input, input, attn_mask=mask)
        output += query
        output = self.ffn(self.layernorm(output))
        return output

In [37]:
# runs in jupyter container
class SSEPT(nn.Module):
    def __init__(self, user_num, item_num, **kwargs):
        super().__init__()
        self.item_num = item_num
        self.user_num = user_num
        self.seq_max_len = kwargs.get("seq_max_len", 50)
        self.num_blocks = kwargs.get("num_blocks", 2)
        self.embedding_dim = kwargs.get("embedding_dim", 64)
        self.attention_num_heads = kwargs.get("attention_num_heads", 2)
        self.dropout_rate = kwargs.get("dropout_rate", 0.2)

        self.user_embedding_layer = nn.Embedding(self.user_num + 1, self.embedding_dim, padding_idx=0)
        self.item_embedding_layer = nn.Embedding(self.item_num + 1, self.embedding_dim, padding_idx=0)
        self.positional_embedding_layer = nn.Embedding(self.seq_max_len, self.embedding_dim * 2)

        self.encoderlayers = nn.ModuleList([
            TransformerBlock(self.seq_max_len, self.embedding_dim * 2, self.attention_num_heads, self.dropout_rate)
            for _ in range(self.num_blocks)
        ])

        self.lastlayernorm = nn.LayerNorm(self.embedding_dim * 2)

    def embedding_all(self, user, input_seq):
        item_emb = self.item_embedding_layer(input_seq) * (self.embedding_dim ** 0.5)
        user_emb = self.user_embedding_layer(user) * (self.embedding_dim ** 0.5)
        user_emb_exp = torch.tile(user_emb.unsqueeze(1), [1, input_seq.size(1), 1])
        seq_emb = torch.cat([item_emb, user_emb_exp], dim=2)

        positions = torch.tile(torch.arange(input_seq.size(1)), [input_seq.size(0), 1]).to(input_seq.device)
        pos_emb = self.positional_embedding_layer(positions)
        seq_emb += pos_emb

        timeline_mask = (input_seq == 0)
        seq_emb *= (~timeline_mask).unsqueeze(-1)
        return seq_emb, user_emb, timeline_mask

    def Encoder(self, seq_emb, mask):
        for block in self.encoderlayers:
            seq_emb = block(seq_emb)
            seq_emb *= (~mask).unsqueeze(-1)
        return seq_emb

    def predict(self, user, input_seq, item_indices):
        seq_emb, user_emb, mask = self.embedding_all(user, input_seq)
        seq_output = self.Encoder(seq_emb, mask)
        log_feats = self.lastlayernorm(seq_output)
        final_feat = log_feats[:, -1, :]
        item_embs = self.item_embedding_layer(item_indices)
        item_embs = item_embs * (self.embedding_dim ** 0.5)
        user_embs = self.user_embedding_layer(user) * (self.embedding_dim ** 0.5)
        pair_emb = torch.cat([item_embs, user_embs], dim=1)
        logits = pair_emb.matmul(final_feat.unsqueeze(-1)).squeeze(-1)
        return logits

In [38]:
# runs in jupyter container
model = SSEPT(
    user_num=10_000,
    item_num=10_000,
    seq_max_len=200,
    embedding_dim=50,
    attention_num_heads=1,
    num_blocks=6,
    dropout_rate=0.2
)
model.load_state_dict(torch.load("models/SSE_PT.pth", map_location="cpu"))
model.eval()
print("Model loaded.")

Model loaded.


In [None]:
# runs in jupyter container
dummy_seq = torch.ones(128, 50, dtype=torch.long)
results = []

with torch.no_grad():
    for i, (users, items) in enumerate(eval_loader):
        if users.shape[0] != dummy_seq.shape[0]:
            seq = dummy_seq[:users.shape[0], :]
        else:
            seq = dummy_seq
        logits = model.predict(users, seq, items)
        results.extend(logits.cpu().numpy())
        if i % 10 == 0:
            print(f"Processed batch {i}, output shape: {logits.shape}")

print("Done. Avg predicted score:", np.mean(results))

Processed batch 0, output shape: torch.Size([128, 128])
Processed batch 10, output shape: torch.Size([128, 128])
Processed batch 20, output shape: torch.Size([128, 128])
Processed batch 30, output shape: torch.Size([128, 128])
Processed batch 40, output shape: torch.Size([128, 128])
Processed batch 50, output shape: torch.Size([128, 128])
Processed batch 60, output shape: torch.Size([128, 128])
Processed batch 70, output shape: torch.Size([128, 128])
Processed batch 80, output shape: torch.Size([128, 128])
Processed batch 90, output shape: torch.Size([128, 128])
Processed batch 100, output shape: torch.Size([128, 128])
Processed batch 110, output shape: torch.Size([128, 128])
Processed batch 120, output shape: torch.Size([128, 128])
Processed batch 130, output shape: torch.Size([128, 128])
Processed batch 140, output shape: torch.Size([128, 128])
Processed batch 150, output shape: torch.Size([128, 128])
Processed batch 160, output shape: torch.Size([128, 128])
Processed batch 170, outp