In [37]:
import numpy as np
import pickle
import random
from typing import List, Dict, Tuple
from collections import defaultdict
from torch.utils.data import Dataset, DataLoader

In [2]:
def load_ctx(dataset_name: str, data_dir: str = "../data") -> Dict[Tuple[int, int], np.ndarray]:
    """Load the pickled context dictionary of a dataset.

    Reads ``{data_dir}/{dataset_name}_ctx.dat`` and casts every value of the
    unpickled mapping to a NumPy array.

    Args:
        dataset_name: File-name prefix of the dataset, e.g. ``"video_games"``.
        data_dir: Directory that holds the data files. Defaults to ``"../data"``,
            the location that used to be hard-coded.

    Returns:
        The unpickled mapping (mutated in place) with ``np.ndarray`` values.
        The sequence builders in this notebook index it with
        ``(user_id, item_id)`` tuples.
    """
    # SECURITY: pickle.load can execute arbitrary code; only open trusted files.
    with open(f"{data_dir}/{dataset_name}_ctx.dat", "rb") as rf:
        ctx = pickle.load(rf)

    # Cast context values from list to numpy array. Only values of existing
    # keys are replaced, so the dict size never changes during iteration.
    for key, value in ctx.items():
        ctx[key] = np.array(value)

    return ctx

In [3]:
def load_attrs(dataset_name: str, data_dir: str = "../data") -> np.ndarray:
    """Load the pickled item-attribute matrix and prepend a zero row.

    Reads ``{data_dir}/{dataset_name}_attrs.dat``. The extra first row makes
    row index 0 an all-zero vector, which the rest of the notebook uses for
    the ``<pad>`` item.

    Args:
        dataset_name: File-name prefix of the dataset, e.g. ``"video_games"``.
        data_dir: Directory that holds the data files. Defaults to ``"../data"``,
            the location that used to be hard-coded.

    Returns:
        A 2-D array with one more row than the stored matrix; row 0 is zeros.
    """
    # SECURITY: pickle.load can execute arbitrary code; only open trusted files.
    with open(f"{data_dir}/{dataset_name}_attrs.dat", "rb") as rf:
        # asarray is a no-op for ndarrays and also accepts a list-of-lists payload
        attrs = np.asarray(pickle.load(rf))

    # Add zero row for <pad> item
    pad_row = np.zeros((1, attrs.shape[1]))
    return np.concatenate((pad_row, attrs), axis=0)

In [4]:
def load_profiles(
    dataset_name: str,
    data_dir: str = "../data"
) -> Tuple[List[int], List[int], Dict[int, List[int]]]:
    """Read ``user_id item_id`` interaction pairs and group them per user.

    Reads ``{data_dir}/{dataset_name}.txt``. Every non-blank line must contain
    exactly two whitespace-separated integers. Blank lines (for example an
    empty last line) are skipped instead of raising ``ValueError``.

    Args:
        dataset_name: File-name prefix of the dataset, e.g. ``"video_games"``.
        data_dir: Directory that holds the data files. Defaults to ``"../data"``,
            the location that used to be hard-coded.

    Returns:
        A tuple ``(user_ids, item_ids, profiles)``: the distinct user ids, the
        distinct item ids (both as lists in set order) and a ``defaultdict``
        mapping each user id to its item ids in file order.

    Raises:
        ValueError: If a non-blank line does not hold exactly two integers.
    """
    user_ids, item_ids = set(), set()
    profiles = defaultdict(list)

    with open(f"{data_dir}/{dataset_name}.txt", "r") as df:
        for line in df:
            # split() without a separator handles tabs and repeated spaces
            fields = line.split()
            if not fields:
                continue
            user_id, item_id = map(int, fields)
            user_ids.add(user_id)
            item_ids.add(item_id)
            profiles[user_id].append(item_id)

    return list(user_ids), list(item_ids), profiles

In [6]:
def one_out_idx(profile: List[int], mode: str) -> int:
    """Return the index of the held-out target item for ``mode``.

    Args:
        profile: Item ids of one user, in interaction order.
        mode: One of ``"train"``, ``"val"`` or ``"test"``.

    Returns:
        ``max(1, len - 3)`` for train (needs more than 1 item),
        ``max(2, len - 2)`` for val (needs more than 2 items),
        ``len - 1`` for test (needs more than 3 items),
        or ``-1`` when the profile is too short for the mode.

    Raises:
        ValueError: If ``mode`` is not one of the three supported values.
    """
    size = len(profile)

    if mode == "train":
        return max(1, size - 3) if size > 1 else -1
    if mode == "val":
        return max(2, size - 2) if size > 2 else -1
    if mode == "test":
        return size - 1 if size > 3 else -1

    raise ValueError(f"Invalid mode: {mode}")

In [7]:
def pad_profile(profile: List[int], max_len: int, mode: str) -> List[int]:
    """Return the profile positions that serve as model input for ``mode``.

    The trailing ``n_excluded`` positions are left out and at most ``max_len``
    of the remaining positions (the most recent ones) are kept.

    Args:
        profile: Item ids of one user, in interaction order.
        max_len: Maximum number of positions to return.
        mode: One of ``"train"``, ``"val"`` or ``"test"``.

    Returns:
        A list of consecutive indices into ``profile``; empty when the profile
        is too short for the mode (train needs > 1, val > 2, test > 3 items).

    Raises:
        ValueError: If ``mode`` is not one of the three supported values.
    """
    if mode not in ("train", "val", "test"):
        raise ValueError(f"Invalid mode: {mode}")

    size = len(profile)

    if mode == "train":
        min_size, n_excluded = 2, 3
    elif mode == "val":
        # A 3-item profile only drops its last item
        min_size, n_excluded = 3, (1 if size == 3 else 2)
    else:
        # NOTE(review): test mode drops the last two positions, so the item just
        # before the test target is not part of the input — confirm this is intended.
        min_size, n_excluded = 4, 2

    if size < min_size:
        return []

    stop = max(1, size - n_excluded)
    first = max(0, size - n_excluded - max_len)
    return list(range(first, stop))

In [104]:
def sample_negatives(profile: List[int], n_items: int, n: int) -> List[int]:
    """Draw ``n`` distinct item ids that do not occur in ``profile``.

    Ids are drawn uniformly from ``1 .. n_items - 1`` by rejection sampling
    with the global ``random`` module, so results are reproducible under
    ``random.seed``.

    Args:
        profile: Item ids that must not be sampled.
        n_items: Exclusive upper bound of valid item ids (id 0 is never drawn).
        n: Number of negatives to draw.

    Returns:
        A list of ``n`` distinct ids in draw order; empty when ``n <= 0``.

    Raises:
        ValueError: If fewer than ``n`` ids are available. Without this check
            the sampling loop would never terminate.
    """
    excluded = set(profile)

    # Count the candidate ids that survive the exclusion before sampling
    n_blocked = sum(1 for item_id in excluded if 1 <= item_id < n_items)
    n_available = max(0, (n_items - 1) - n_blocked)
    if n > n_available:
        raise ValueError(
            f"Cannot sample {n} negatives: only {n_available} candidate items available"
        )

    sample = []
    chosen = set()  # O(1) duplicate check; `sample` keeps the draw order

    while len(sample) < n:
        item_id = random.randint(1, n_items - 1)

        if item_id not in chosen and item_id not in excluded:
            chosen.add(item_id)
            sample.append(item_id)

    return sample

In [148]:
def get_train_sequences(
    user_id: int,
    profile: List[int],
    seq_len: int,
    attrs: np.ndarray,
    ctx: Dict[Tuple[int, int], np.ndarray]
) -> Tuple[np.ndarray, ...]:
    """Build the right-aligned, zero-padded training arrays for one user.

    Args:
        user_id: Id used as first component of the ``ctx`` keys.
        profile: Item ids of the user, in interaction order.
        seq_len: Length of the input sequence.
        attrs: Item-attribute matrix indexed by item id.
        ctx: Mapping from ``(user_id, item_id)`` to a context vector.

    Returns:
        ``(p_x, o_x, p_q, o_q, y_true, mask)`` where
        ``p_x`` (``seq_len``) holds the input item ids,
        ``o_x`` (``2 * seq_len``) holds the next-item positives in its first
        half and the sampled negatives in its second half,
        ``p_q`` / ``o_q`` hold the matching attribute+context feature rows,
        ``y_true`` is 1 on the filled first-half slots and
        ``mask`` is 1 wherever ``o_x`` is non-zero.
    """
    feat_dim = attrs.shape[1] + next(iter(ctx.values())).shape[0]
    out_len = seq_len * 2

    p_x = np.zeros(seq_len, dtype=np.int32)
    o_x = np.zeros(out_len, dtype=np.int32)
    p_q = np.zeros((seq_len, feat_dim))
    o_q = np.zeros((out_len, feat_dim))

    positions = pad_profile(profile, seq_len, "train")
    negatives = sample_negatives(profile, attrs.shape[0], len(positions))

    # Slots left of `offset` stay zero, i.e. the sequence is padded on the left
    offset = seq_len - len(positions)

    for slot, (pos, neg_item) in enumerate(zip(positions, negatives), start=offset):
        cur_item = profile[pos]
        next_item = profile[pos + 1]

        p_x[slot] = cur_item
        o_x[slot] = next_item
        o_x[seq_len + slot] = neg_item

        p_q[slot] = np.concatenate((attrs[cur_item], ctx[(user_id, cur_item)]))

        next_ctx = ctx[(user_id, next_item)]
        o_q[slot] = np.concatenate((attrs[next_item], next_ctx))
        # Assign same context to negative sample as to positive sample
        o_q[seq_len + slot] = np.concatenate((attrs[neg_item], next_ctx))

    # Positives occupy the same first-half slots as the inputs in p_x
    y_true = np.zeros(out_len, dtype=np.int32)
    y_true[:seq_len] = p_x > 0
    mask = (o_x > 0).astype(np.int32)

    return p_x, o_x, p_q, o_q, y_true, mask

In [149]:
def get_test_sequences(
        user_id: int,
        profile: List[int],
        profile_seq_len: int,
        target_seq_len: int,
        attrs: np.ndarray,
        ctx: Dict[Tuple[int, int], np.ndarray],
        mode: str
    ) -> Tuple[np.ndarray, ...]:
    """Build the evaluation arrays (one positive plus sampled negatives) for one user.

    Args:
        user_id: Id used as first component of the ``ctx`` keys.
        profile: Item ids of the user, in interaction order.
        profile_seq_len: Length of the right-aligned, zero-padded input sequence.
        target_seq_len: Number of negatives to sample.
        attrs: Item-attribute matrix indexed by item id.
        ctx: Mapping from ``(user_id, item_id)`` to a context vector.
        mode: Evaluation mode forwarded to ``one_out_idx`` and ``pad_profile``.

    Returns:
        ``(p_x, o_x, p_q, o_q, y_true, mask)`` where ``o_x[0]`` is the held-out
        item and ``o_x[1:]`` are the negatives, ``p_q`` / ``o_q`` hold the
        matching attribute+context feature rows, ``y_true`` is 1 only at
        position 0 and ``mask`` is all ones.

    Raises:
        ValueError: If ``mode`` is invalid or the profile is too short for it.
            Previously the ``-1`` sentinel was used as an index and silently
            selected the last profile item as target.
    """
    q_len = attrs.shape[1] + next(iter(ctx.values())).shape[0]

    p_x = np.zeros(profile_seq_len, dtype=np.int32)
    o_x = np.zeros(target_seq_len + 1, dtype=np.int32)
    p_q = np.zeros((profile_seq_len, q_len))
    o_q = np.zeros((target_seq_len + 1, q_len))

    one_out = one_out_idx(profile, mode)
    if one_out == -1:
        raise ValueError(
            f"Profile of user {user_id} has {len(profile)} item(s), too few for mode '{mode}'"
        )

    target_item = profile[one_out]
    target_ctx = ctx[(user_id, target_item)]
    o_x[0] = target_item
    o_q[0] = np.concatenate((attrs[target_item], target_ctx))

    padded_idxs = pad_profile(profile, profile_seq_len, mode)
    neg_samples = sample_negatives(profile, attrs.shape[0], target_seq_len)

    # Slots left of `shift` stay zero, i.e. the input is padded on the left
    shift = profile_seq_len - len(padded_idxs)

    for i, pi in enumerate(padded_idxs):
        a = attrs[profile[pi]]
        c = ctx[(user_id, profile[pi])]
        p_x[shift + i] = profile[pi]
        p_q[shift + i] = np.concatenate((a, c))

    for i, oi in enumerate(neg_samples, start=1):
        o_x[i] = oi
        # Assign same context to negatives as to one-out positive
        o_q[i] = np.concatenate((attrs[oi], target_ctx))

    y_true = np.zeros(target_seq_len + 1, dtype=np.int32)
    y_true[0] = 1
    mask = np.ones(target_seq_len + 1, dtype=np.int32)

    return p_x, o_x, p_q, o_q, y_true, mask

In [150]:
def get_sequences(
    user_id: int,
    profile: List[int],
    profile_seq_len: int,
    target_seq_len: int,
    attrs: np.ndarray,
    ctx: Dict[Tuple[int, int], np.ndarray],
    mode: str
) -> Tuple[np.ndarray, ...]:
    """Dispatch to the train or evaluation sequence builder.

    ``mode == "train"`` calls ``get_train_sequences`` (which ignores
    ``target_seq_len``); every other mode is forwarded to
    ``get_test_sequences`` together with ``mode``.

    Returns:
        The ``(p_x, o_x, p_q, o_q, y_true, mask)`` tuple of the chosen builder.
    """
    if mode != "train":
        return get_test_sequences(
            user_id, profile, profile_seq_len, target_seq_len, attrs, ctx, mode
        )

    return get_train_sequences(user_id, profile, profile_seq_len, attrs, ctx)

In [151]:
class CARCADataset(Dataset):
    """Map-style dataset that yields one sequence tuple per eligible user.

    A user is eligible when ``one_out_idx`` returns something other than ``-1``
    for their profile under ``mode``.

    NOTE(review): the ``user_ids`` constructor argument is accepted but never
    read; the user list is always derived from ``profiles``.
    """

    def __init__(
        self,
        user_ids: List[int],
        item_ids: List[int],
        profiles: Dict[int, List[int]],
        attrs: np.ndarray,
        ctx: Dict[Tuple[int, int], np.ndarray],
        profile_seq_len: int,
        target_seq_len: int,
        mode: str
    ):
        super().__init__()

        # Only users whose profile is long enough for `mode` are indexable
        self.user_ids = self.valid_user_ids(profiles, mode)

        self.mode = mode
        self.profile_seq_len = profile_seq_len
        self.target_seq_len = target_seq_len
        self.profiles = profiles
        self.item_ids = item_ids
        self.attrs = attrs
        self.ctx = ctx

    def __len__(self) -> int:
        """Number of eligible users."""
        return len(self.user_ids)

    def __getitem__(self, idx) -> Tuple[np.ndarray, ...]:
        """Return the arrays built by ``get_sequences`` for the ``idx``-th eligible user."""
        uid = self.user_ids[idx]

        return get_sequences(
            uid,
            self.profiles[uid],
            self.profile_seq_len,
            self.target_seq_len,
            self.attrs,
            self.ctx,
            self.mode
        )

    def valid_user_ids(self, profiles: Dict[int, List[int]], mode: str) -> List[int]:
        """Return, in ``profiles`` iteration order, the users with a valid one-out index."""
        eligible = []
        for uid, profile in profiles.items():
            if one_out_idx(profile, mode) != -1:
                eligible.append(uid)
        return eligible

In [17]:
# Load the "video_games" item attributes, interaction contexts and user profiles
attrs, ctx = load_attrs("video_games"), load_ctx("video_games")
(user_ids, item_ids, profiles) = load_profiles("video_games")

In [34]:
ds_name = "video_games"
records = []

# Parse every non-blank line of the *_ts.txt file into a tuple of ints.
# split() without a separator accepts tabs and repeated spaces; blank lines
# (for example an empty last line) are skipped instead of raising ValueError.
with open(f"../data/{ds_name}_ts.txt", "r") as df:
    for line in df:
        fields = line.split()
        if not fields:
            continue
        vals = tuple(map(int, fields))
        records.append(vals)

In [35]:
# Order the records by their third field (presumably a timestamp, given the
# `_ts` file suffix — confirm). The sort is stable, so ties keep file order.
records_sorted = list(records)
records_sorted.sort(key=lambda record: record[2])

In [36]:
# Write the first two fields of every record, one record per line,
# newline-separated and without a trailing newline after the last record
with open(f"../data/{ds_name}_sorted.txt", "w") as f:
    lines = (" ".join(map(str, record[:2])) for record in records_sorted)
    f.write("\n".join(lines))

In [38]:
import torch

In [49]:
# Widen the print width so each 10-column tensor row fits on one line
torch.set_printoptions(linewidth=200)

In [47]:
# Random stand-in scores of shape (4, 10); unseeded, so values differ on every run
y = torch.rand((4, 10))

In [57]:
# Label tensor with the same shape as `y`, initially all zeros
y_true = torch.zeros((4, 10), dtype=torch.int32)

In [58]:
# Mark column 0 of every row as the single positive entry
y_true[:, 0] = 1

In [59]:
y_true

tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int32)

In [50]:
# Show the random score matrix
y

tensor([[0.4116, 0.5234, 0.5795, 0.5253, 0.3428, 0.1094, 0.3882, 0.4722, 0.2408, 0.7886],
        [0.1430, 0.1733, 0.9766, 0.0278, 0.3238, 0.8045, 0.5246, 0.5584, 0.6109, 0.4927],
        [0.7522, 0.5516, 0.9690, 0.3029, 0.2106, 0.6066, 0.9199, 0.3489, 0.9856, 0.2810],
        [0.8397, 0.9874, 0.0470, 0.1061, 0.7803, 0.7747, 0.8167, 0.2351, 0.7567, 0.9092]])

In [51]:
sorted, idxs = torch.sort(y)

In [52]:
sorted

tensor([[0.1094, 0.2408, 0.3428, 0.3882, 0.4116, 0.4722, 0.5234, 0.5253, 0.5795, 0.7886],
        [0.0278, 0.1430, 0.1733, 0.3238, 0.4927, 0.5246, 0.5584, 0.6109, 0.8045, 0.9766],
        [0.2106, 0.2810, 0.3029, 0.3489, 0.5516, 0.6066, 0.7522, 0.9199, 0.9690, 0.9856],
        [0.0470, 0.1061, 0.2351, 0.7567, 0.7747, 0.7803, 0.8167, 0.8397, 0.9092, 0.9874]])

In [85]:
# Second label tensor built exactly like `y_true` (1 in column 0, 0 elsewhere);
# no later visible cell reads it
y_true_sorted = torch.zeros((4, 10), dtype=torch.int32)

In [86]:
y_true_sorted[:, 0] = 1

In [87]:
y_true_sorted

tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int32)

In [84]:
# Show the per-row sort permutation returned by torch.sort
idxs

tensor([[5, 8, 4, 6, 0, 7, 1, 3, 2, 9],
        [3, 0, 1, 4, 9, 6, 7, 8, 5, 2],
        [4, 9, 3, 7, 1, 5, 0, 6, 2, 8],
        [2, 3, 7, 8, 5, 4, 6, 0, 9, 1]])

In [94]:
# Reorder each row of the labels with the sort permutation `idxs`
y_true_sort = torch.gather(y_true, -1, idxs)

In [95]:
# Cut-off used for the slice below
k = 5

In [97]:
# Keep the first k columns of the reordered labels.
# NOTE(review): `idxs` comes from torch.sort(y), which is ascending by default,
# so these are the labels of the k lowest-scored entries — confirm this is intended.
top_k = y_true_sort[:, :k]

In [102]:
top_k

tensor([[0, 0, 0, 0, 1],
        [0, 1, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0]], dtype=torch.int32)

In [144]:
# Column index of every non-zero entry (torch.nonzero yields one [row, col] pair per hit).
# The computation on `top_k` is left commented out; the live line feeds an all-zero
# (4, 5) tensor instead, which exercises the case where no row has a hit.
# ranks = torch.nonzero(top_k)[:, 1]
ranks = torch.nonzero(torch.zeros((4, 5), dtype=torch.int32))[:, 1]

In [146]:
ranks

tensor([], dtype=torch.int64)

In [147]:
# Logarithmic discount per hit: 1 / log2(rank + 2) with 0-based column ranks
# (looks like an NDCG@k term for a single relevant item — confirm)
scores = 1 / torch.log2(ranks + 2)

In [149]:
# Sum of the discounts divided by the number of rows of `top_k`;
# with an empty `ranks` tensor the sum is 0
(torch.sum(scores) / top_k.shape[0]).item()

0.0