# Create the data based on the following description



In [1]:
from tqdm import tqdm
import numpy as np
import pandas as pd
from sklearn.metrics import pairwise_distances
from scipy.sparse import csr_matrix
import random


import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from loss import nll

In [30]:

def build_item_attribute_matrix(item_attribute_dict, attribute_list=None):
    items = list(item_attribute_dict.keys())
    attributes = attribute_list if attribute_list else sorted({a for attrs in item_attribute_dict.values() for a in attrs})
    item_idx = {item: i for i, item in enumerate(items)}
    attr_idx = {attr: i for i, attr in enumerate(attributes)}
    matrix = np.zeros((len(items), len(attributes)))
    
    for item, attrs in item_attribute_dict.items():
        for attr in attrs:
            matrix[item_idx[item], attr_idx[attr]] = 1
    return matrix, items, attributes


def jaccard_similarity(matrix):
    A = matrix.astype(bool).astype(int)  # Ensure binary matrix
    n = A.shape[1]
    sim = np.zeros((n, n))

    for i in range(n):
        for j in range(n):
            if i == j:
                sim[i, j] = 1.0
            else:
                inter = np.logical_and(A[:, i], A[:, j]).sum()
                union = np.logical_or(A[:, i], A[:, j]).sum()
                sim[i, j] = inter / union if union != 0 else 0.0
    return sim


def generate_similarity_based_user_attributes(attr_sim, num_users, top_k=5, diversity_rounds=2):
    num_attributes = attr_sim.shape[0]
    user_attr_matrix = np.zeros((num_users, num_attributes))
    
    for u in range(num_users):
        selected = set()
        for _ in range(diversity_rounds):
            remaining_attrs = [a for a in range(num_attributes) if a not in selected]
            if not remaining_attrs:
                break
            seed = random.choice(remaining_attrs)
            sims = attr_sim[seed]
            probs = sims / sims.sum()
            non_zero_indices = np.where(probs > 0)[0]
            sample_size = min(top_k, len(non_zero_indices))
            chosen = np.random.choice(non_zero_indices, size=sample_size, replace=False)
            selected.update(chosen)
        user_attr_matrix[u, list(selected)] = 1
    return user_attr_matrix

def compute_attribute_overlap(matrix):
    I = matrix
    overlaps = {}
    for a in range(I.shape[1]):
        for b in range(a+1, I.shape[1]):
            Ia, Ib = I[:, a], I[:, b]
            inter = np.logical_and(Ia, Ib).sum()
            if Ia.sum() == 0 or Ib.sum() == 0: continue
            rmax = max(inter / Ia.sum(), inter / Ib.sum())
            rmin = min(inter / Ia.sum(), inter / Ib.sum())
            overlaps[(a, b)] = (rmax, rmin)
    return overlaps

def generate_set_theoretic_rules(overlaps, num_users, k=2, l=1, tau1=0.8, tau2=0.2, tau3=0.85):
    user_rules = []
    conjunctions = [pair for pair, (rmax, rmin) in overlaps.items() if rmax < tau1 and rmin > tau2]
    negations = [pair for pair, (rmax, rmin) in overlaps.items() if rmin > tau3]

    for _ in range(num_users):
        num_and_clauses = random.randint(2, 4)
        num_not_clauses = random.randint(0, 1)
        
        P = random.sample(conjunctions, min(num_and_clauses, len(conjunctions)))
        N = random.sample(negations, min(num_not_clauses, len(negations)))
        clause = {"and": P, "not": N}
        user_rules.append(clause)
    return user_rules, conjunctions, negations

def evaluate_dnf_rule(rule, item_attrs):
    # example rule: {'and': [(0, 1), (2, 3)], 'not': [(4, 5)]}
    # need to validate each and and each not rule
    # for this example, we check of 0 and 1, 2 and 3 are present, and 4 is present but 5 is not present.
    for and_clause in rule['and']:
        if all(item_attrs[a] for a in and_clause):
            return 1
    
    for present, not_present in rule['not']:
        if item_attrs[present] and not item_attrs[not_present]:
            return 1

    return 0


def generate_user_item_matrix_dot(Mua, Mia, tau_user=None):
    scores = Mua @ Mia.T
    Mui_dot = np.zeros_like(scores)
    for u in range(scores.shape[0]):
        threshold = tau_user[u] if tau_user is not None else np.percentile(scores[u], 75)
        Mui_dot[u] = (scores[u] > threshold).astype(int)
    return Mui_dot

def generate_user_item_matrix_set(user_rules, Mia):
    num_users = len(user_rules)
    num_items = Mia.shape[0]
    Mui_set = np.zeros((num_users, num_items))
    for u in range(num_users):
        for i in range(num_items):
            Mui_set[u, i] = evaluate_dnf_rule(user_rules[u], Mia[i])
    return Mui_set

def combine_preferences(Mui_dot, Mui_set, p=0.5):
    U = Mui_dot.shape[0]
    z = np.random.binomial(1, p, U)
    return np.array([Mui_set[u] if z[u] else Mui_dot[u] for u in range(U)])

def apply_noise(matrix, noise_fraction=0.1):
    noisy_matrix = matrix.copy()
    indices = np.argwhere(matrix == 1)
    num_flips = int(noise_fraction * len(indices))
    to_flip = indices[np.random.choice(len(indices), num_flips, replace=False)]
    for i, j in to_flip:
        noisy_matrix[i, j] = 0
    return noisy_matrix


Step 1: Load Attribute Item matrix.
Step 2: Get similarity (Jaccard) between attributes
Step 3: 

In [31]:
# Step 1: Example item–attribute dict
# Load item-attribute dict from MovieLens data
item_attr_dict = {}
with open('data/item_tag_dict.txt', 'r') as f:
    for line in f:
        movie_id, tags = line.strip().split(': ')
        item_attr_dict[movie_id] = tags.split(', ')

print(f"Loaded {len(item_attr_dict)} movies with their genres")
print("\nExample items:")
for movie_id, tags in list(item_attr_dict.items())[:5]:
    print(f"Movie {movie_id}: {tags}")

Mia, items, attrs = build_item_attribute_matrix(item_attr_dict)
attr_sim = jaccard_similarity(Mia)

Loaded 3883 movies with their genres

Example items:
Movie 1: ['animation', "children's", 'comedy']
Movie 2: ['adventure', "children's", 'fantasy']
Movie 3: ['comedy', 'romance']
Movie 4: ['comedy', 'drama']
Movie 5: ['comedy']


In [32]:
# Step 2: User–attribute preferences
Mua = generate_similarity_based_user_attributes(attr_sim, num_users=200, diversity_rounds=5)


In [33]:
# Step 3: Set-theoretic rules
overlaps = compute_attribute_overlap(Mia)
user_rules, conjunctions, negations = generate_set_theoretic_rules(overlaps, num_users=200, k=2, l=0, tau1=0.3, tau2=0.01, tau3=0.05)

In [67]:
# print a few user rules in attribute names parse the rules to attribute names

for rule in user_rules[:5]:
    and_clauses = [f"({attrs[a]} & {attrs[b]})" for (a, b) in rule['and']]
    not_clauses = [f"({attrs[present]} & not {attrs[not_present]})" for (present, not_present) in rule['not']]
    all_clauses = and_clauses + not_clauses
    print(" or ".join(all_clauses))

(adventure & comedy) or (animation & comedy) or (adventure & horror) or (comedy & sci-fi)
(crime & horror) or (thriller & war) or (action & western)
(crime & horror) or (adventure & horror) or (adventure & comedy) or (romance & not thriller)
(comedy & mystery) or (adventure & mystery)
(crime & mystery) or (comedy & fantasy) or (children's & sci-fi) or (action & thriller)


In [34]:
user_rules[12]

{'and': [(14, 15), (1, 5), (4, 5), (1, 12)], 'not': []}

In [35]:
# Step 4: User–item matrices
Mui_dot = generate_user_item_matrix_dot(Mua, Mia)
Mui_set = generate_user_item_matrix_set(user_rules, Mia)

# rank of Mui_dot
rank_Mui_dot = np.linalg.matrix_rank(Mui_dot)
print(f"Rank of Mui_dot: {rank_Mui_dot}")

# rank of Mui_set
rank_Mui_set = np.linalg.matrix_rank(Mui_set)
print(f"Rank of Mui_set: {rank_Mui_set}")


Rank of Mui_dot: 130
Rank of Mui_set: 188


In [36]:

def calculate_sparsity(matrix):
    """
    Calculate the sparsity of a matrix (percentage of zero elements).
    """
    total_elements = matrix.size
    zero_elements = np.count_nonzero(matrix == 0)
    sparsity = (zero_elements / total_elements) * 100
    return sparsity

# Calculate sparsity for both matrices
sparsity_dot = calculate_sparsity(Mui_dot)
sparsity_set = calculate_sparsity(Mui_set)

print(f"Sparsity of Mui_dot: {sparsity_dot:.2f}%")
print(f"Sparsity of Mui_set: {sparsity_set:.2f}%")
print(f"Number of non-zero elements in Mui_dot: {np.count_nonzero(Mui_dot)}")
print(f"Number of non-zero elements in Mui_set: {np.count_nonzero(Mui_set)}")
print(f"Total elements in matrix: {Mui_dot.size}")

Sparsity of Mui_dot: 90.26%
Sparsity of Mui_set: 93.79%
Number of non-zero elements in Mui_dot: 75663
Number of non-zero elements in Mui_set: 48243
Total elements in matrix: 776600


In [37]:
# Step 5: Combine preferences
Mui = combine_preferences(Mui_dot, Mui_set, p=0.6)

# Step 6: Inject noise
Mui_noisy = apply_noise(Mui, noise_fraction=0.1)

In [63]:
# sparsity of Mui, Mui_dot, Mui_set, Mui_noisy
print(f"Sparsity of Mui: {calculate_sparsity(Mui):.2f}%")
print(f"Sparsity of Mui_dot: {calculate_sparsity(Mui_dot):.2f}%")
print(f"Sparsity of Mui_set: {calculate_sparsity(Mui_set):.2f}%")
print(f"Sparsity of Mui_noisy: {calculate_sparsity(Mui_noisy):.2f}%")


Sparsity of Mui: 92.36%
Sparsity of Mui_dot: 90.26%
Sparsity of Mui_set: 93.79%
Sparsity of Mui_noisy: 93.12%


In [69]:
# check the rank of Mui, Mui_dot, Mui_set, Mui_noisy
print(f"Rank of Mui: {np.linalg.matrix_rank(Mui)}")
print(f"Rank of Mui_dot: {np.linalg.matrix_rank(Mui_dot)}")
print(f"Rank of Mui_set: {np.linalg.matrix_rank(Mui_set)}")
print(f"Rank of Mui_noisy: {np.linalg.matrix_rank(Mui_noisy)}")



Rank of Mui: 194
Rank of Mui_dot: 130
Rank of Mui_set: 188
Rank of Mui_noisy: 200


In [70]:
Mui_dot.shape

(200, 3883)

In [38]:

def train_test_split(matrix, test_ratio=0.2, seed=42):
    np.random.seed(seed)
    train = matrix.copy()
    test = np.zeros_like(matrix)

    for u in range(matrix.shape[0]):
        pos_items = np.where(matrix[u] == 1)[0]
        if len(pos_items) == 0:
            continue
        test_size = max(1, int(test_ratio * len(pos_items)))
        test_items = np.random.choice(pos_items, size=test_size, replace=False)
        train[u, test_items] = 0
        test[u, test_items] = 1

    return train, test

class TrainingDataset(Dataset):
    def __init__(self, user_item_matrix, num_negatives=1):
        self.num_users, self.num_items = user_item_matrix.shape
        self.positives = [(u, i) for u in range(self.num_users) for i in range(self.num_items) if user_item_matrix[u, i] == 1]
        self.negatives = []
        self.num_negatives = num_negatives
        self.user_item_matrix = user_item_matrix

    def __len__(self):
        return len(self.positives)

    def __getitem__(self, idx):
        u, i_pos = self.positives[idx]
        # positive example
        i_neg = np.random.randint(0, self.num_items)
        while self.user_item_matrix[u, i_neg] == 1:
            i_neg = np.random.randint(0, self.num_items)

        # randomly choose to return positive or negative sample
        if np.random.rand() < 0.5:
            return (
                torch.tensor(u, dtype=torch.long),
                torch.tensor(i_pos, dtype=torch.long),
                torch.tensor(1.0, dtype=torch.float)
            )
        else:
            return (
            torch.tensor(u, dtype=torch.long),
            torch.tensor(i_neg, dtype=torch.long),
            torch.tensor(0.0, dtype=torch.float)
            )


    
def evaluate_precision_hit(model, test_matrix, train_matrix, k=10):
    model.eval()
    num_users, num_items = test_matrix.shape
    precision_scores, hit_scores = [], []

    with torch.no_grad():
        for u in range(num_users):
            train_items = set(np.where(train_matrix[u] == 1)[0])
            test_items = set(np.where(test_matrix[u] == 1)[0])
            if not test_items:
                continue

            candidates = [i for i in range(num_items) if i not in train_items]
            user_tensor = torch.tensor([u] * len(candidates))
            item_tensor = torch.tensor(candidates)
            scores = model(user_tensor, item_tensor).numpy()

            top_k_indices = np.argsort(scores)[-k:][::-1]
            top_k_items = [candidates[i] for i in top_k_indices]

            hits = sum([1 for item in top_k_items if item in test_items])
            precision = hits / k
            hit = 1.0 if hits > 0 else 0.0

            precision_scores.append(precision)
            hit_scores.append(hit)

    avg_precision = np.mean(precision_scores)
    avg_hit = np.mean(hit_scores)
    print(f"Precision@{k}: {avg_precision:.4f}, Hit@{k}: {avg_hit:.4f}")


In [39]:
class MFModel(nn.Module):
    def __init__(self, num_users, num_items, embedding_dim=32):
        super(MFModel, self).__init__()
        self.user_emb = nn.Embedding(num_users, embedding_dim)
        self.item_emb = nn.Embedding(num_items, embedding_dim)
    
    def forward(self, user_idx, item_idx):
        u = self.user_emb(user_idx)
        i = self.item_emb(item_idx)
        dot = (u * i).sum(dim=1)
        return torch.sigmoid(dot)  # since preference is binary

In [40]:
def train(model, dataloader, epochs=5, loss_type='BCE', lr=0.01):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    
    if loss_type == 'BCE':
        loss_fn = nn.BCELoss()
    elif loss_type == 'MSE':
        loss_fn = nn.MSELoss()
    elif loss_type == 'NLL':
        loss_fn = nll
    else:
        raise ValueError(f"Invalid loss type: {loss_type}")
    
    for epoch in tqdm(range(epochs)):
        total_loss = 0.0
        model.train()
        for users, items, labels in dataloader:
            preds = model(users, items)
            if loss_type == 'BCE' or loss_type == 'MSE':
                loss = loss_fn(preds, labels)
            elif loss_type == 'NLL':
                # check if labels are 1 or 0
                pred_pos = preds[labels == 1]
                pred_neg = preds[labels == 0]
                loss = nll(pos=pred_pos, neg=pred_neg)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch + 1}, Loss: {total_loss:.4f}")
        evaluate_precision_hit(model, test_matrix, train_matrix, k=10)


In [56]:
train_matrix, test_matrix = train_test_split(Mui_set, test_ratio=0.2)

train_dataset = TrainingDataset(train_matrix, num_negatives=4)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

num_users, num_items = train_matrix.shape
model = MFModel(num_users, num_items, embedding_dim=64)


train(model, train_loader, loss_type='BCE', epochs=60, lr=0.005)
evaluate_precision_hit(model, test_matrix, train_matrix, k=10)

  0%|          | 0/60 [00:00<?, ?it/s]

Epoch 1, Loss: 1116.2058


  2%|▏         | 1/60 [00:01<01:41,  1.73s/it]

Precision@10: 0.0140, Hit@10: 0.1050
Epoch 2, Loss: 832.7717


  3%|▎         | 2/60 [00:03<01:36,  1.66s/it]

Precision@10: 0.0145, Hit@10: 0.1100
Epoch 3, Loss: 652.9417


  5%|▌         | 3/60 [00:04<01:31,  1.60s/it]

Precision@10: 0.0160, Hit@10: 0.1100
Epoch 4, Loss: 497.7637


  7%|▋         | 4/60 [00:06<01:27,  1.55s/it]

Precision@10: 0.0185, Hit@10: 0.1200
Epoch 5, Loss: 410.3877


  8%|▊         | 5/60 [00:07<01:24,  1.54s/it]

Precision@10: 0.0250, Hit@10: 0.1350
Epoch 6, Loss: 337.5972


 10%|█         | 6/60 [00:09<01:23,  1.54s/it]

Precision@10: 0.0400, Hit@10: 0.1950
Epoch 7, Loss: 287.9047


 12%|█▏        | 7/60 [00:11<01:23,  1.58s/it]

Precision@10: 0.0485, Hit@10: 0.2050
Epoch 8, Loss: 240.0403


 13%|█▎        | 8/60 [00:12<01:23,  1.60s/it]

Precision@10: 0.0615, Hit@10: 0.2250
Epoch 9, Loss: 207.9246


 15%|█▌        | 9/60 [00:14<01:24,  1.65s/it]

Precision@10: 0.0775, Hit@10: 0.2500
Epoch 10, Loss: 182.3389


 17%|█▋        | 10/60 [00:16<01:25,  1.70s/it]

Precision@10: 0.0970, Hit@10: 0.3050
Epoch 11, Loss: 159.4759


 18%|█▊        | 11/60 [00:18<01:27,  1.78s/it]

Precision@10: 0.1060, Hit@10: 0.3250
Epoch 12, Loss: 145.3022


 20%|██        | 12/60 [00:20<01:28,  1.85s/it]

Precision@10: 0.1265, Hit@10: 0.3800
Epoch 13, Loss: 122.1113


 22%|██▏       | 13/60 [00:22<01:28,  1.88s/it]

Precision@10: 0.1410, Hit@10: 0.4300
Epoch 14, Loss: 110.8111


 23%|██▎       | 14/60 [00:24<01:27,  1.91s/it]

Precision@10: 0.1580, Hit@10: 0.4750
Epoch 15, Loss: 95.5508


 25%|██▌       | 15/60 [00:26<01:27,  1.95s/it]

Precision@10: 0.1700, Hit@10: 0.5100
Epoch 16, Loss: 84.7659


 27%|██▋       | 16/60 [00:28<01:25,  1.95s/it]

Precision@10: 0.1910, Hit@10: 0.5650
Epoch 17, Loss: 76.8398


 28%|██▊       | 17/60 [00:30<01:25,  1.98s/it]

Precision@10: 0.2090, Hit@10: 0.5600
Epoch 18, Loss: 69.0342


 30%|███       | 18/60 [00:32<01:23,  1.98s/it]

Precision@10: 0.2265, Hit@10: 0.5750
Epoch 19, Loss: 63.5414


 32%|███▏      | 19/60 [00:34<01:25,  2.08s/it]

Precision@10: 0.2495, Hit@10: 0.5950
Epoch 20, Loss: 57.5298


 33%|███▎      | 20/60 [00:36<01:23,  2.09s/it]

Precision@10: 0.2635, Hit@10: 0.6100
Epoch 21, Loss: 53.2403


 35%|███▌      | 21/60 [00:38<01:24,  2.16s/it]

Precision@10: 0.2785, Hit@10: 0.6250
Epoch 22, Loss: 49.6508


 37%|███▋      | 22/60 [00:41<01:25,  2.24s/it]

Precision@10: 0.3015, Hit@10: 0.6700
Epoch 23, Loss: 46.3506


 38%|███▊      | 23/60 [00:43<01:23,  2.26s/it]

Precision@10: 0.3230, Hit@10: 0.6750
Epoch 24, Loss: 42.9560


 40%|████      | 24/60 [00:46<01:22,  2.29s/it]

Precision@10: 0.3295, Hit@10: 0.6750
Epoch 25, Loss: 40.6196


 42%|████▏     | 25/60 [00:48<01:22,  2.35s/it]

Precision@10: 0.3530, Hit@10: 0.6850
Epoch 26, Loss: 39.0005


 43%|████▎     | 26/60 [00:50<01:20,  2.36s/it]

Precision@10: 0.3735, Hit@10: 0.7300
Epoch 27, Loss: 34.6355


 45%|████▌     | 27/60 [00:53<01:19,  2.40s/it]

Precision@10: 0.3875, Hit@10: 0.7400
Epoch 28, Loss: 33.4016


 47%|████▋     | 28/60 [00:55<01:17,  2.41s/it]

Precision@10: 0.3980, Hit@10: 0.7650
Epoch 29, Loss: 31.0260


 48%|████▊     | 29/60 [00:58<01:15,  2.42s/it]

Precision@10: 0.4095, Hit@10: 0.7750
Epoch 30, Loss: 28.8266


 50%|█████     | 30/60 [01:00<01:12,  2.43s/it]

Precision@10: 0.4225, Hit@10: 0.7650
Epoch 31, Loss: 26.2909


 52%|█████▏    | 31/60 [01:03<01:11,  2.45s/it]

Precision@10: 0.4350, Hit@10: 0.7650
Epoch 32, Loss: 24.0544


 53%|█████▎    | 32/60 [01:05<01:09,  2.47s/it]

Precision@10: 0.4455, Hit@10: 0.7700
Epoch 33, Loss: 23.9475


 55%|█████▌    | 33/60 [01:07<01:04,  2.39s/it]

Precision@10: 0.4515, Hit@10: 0.7900
Epoch 34, Loss: 22.6617


 57%|█████▋    | 34/60 [01:10<01:01,  2.35s/it]

Precision@10: 0.4645, Hit@10: 0.8100
Epoch 35, Loss: 21.3647


 58%|█████▊    | 35/60 [01:12<00:57,  2.32s/it]

Precision@10: 0.4740, Hit@10: 0.8000
Epoch 36, Loss: 21.0515


 60%|██████    | 36/60 [01:14<00:54,  2.25s/it]

Precision@10: 0.4875, Hit@10: 0.8150
Epoch 37, Loss: 19.3491


 62%|██████▏   | 37/60 [01:16<00:50,  2.21s/it]

Precision@10: 0.4940, Hit@10: 0.8300
Epoch 38, Loss: 17.1660


 63%|██████▎   | 38/60 [01:18<00:47,  2.17s/it]

Precision@10: 0.5070, Hit@10: 0.8350
Epoch 39, Loss: 16.5165


 65%|██████▌   | 39/60 [01:20<00:44,  2.14s/it]

Precision@10: 0.5080, Hit@10: 0.8500
Epoch 40, Loss: 14.9572


 67%|██████▋   | 40/60 [01:22<00:41,  2.09s/it]

Precision@10: 0.5100, Hit@10: 0.8450
Epoch 41, Loss: 15.7382


 68%|██████▊   | 41/60 [01:24<00:38,  2.04s/it]

Precision@10: 0.5110, Hit@10: 0.8400
Epoch 42, Loss: 14.4908


 70%|███████   | 42/60 [01:26<00:35,  2.00s/it]

Precision@10: 0.5105, Hit@10: 0.8350
Epoch 43, Loss: 13.8201


 72%|███████▏  | 43/60 [01:28<00:33,  1.97s/it]

Precision@10: 0.5275, Hit@10: 0.8500
Epoch 44, Loss: 12.9276


 73%|███████▎  | 44/60 [01:30<00:31,  1.95s/it]

Precision@10: 0.5395, Hit@10: 0.8550
Epoch 45, Loss: 12.9565


 75%|███████▌  | 45/60 [01:32<00:28,  1.92s/it]

Precision@10: 0.5345, Hit@10: 0.8550
Epoch 46, Loss: 12.3030


 77%|███████▋  | 46/60 [01:34<00:26,  1.91s/it]

Precision@10: 0.5490, Hit@10: 0.8700
Epoch 47, Loss: 12.0391


 78%|███████▊  | 47/60 [01:36<00:24,  1.90s/it]

Precision@10: 0.5480, Hit@10: 0.8750
Epoch 48, Loss: 11.2744


 80%|████████  | 48/60 [01:38<00:23,  1.92s/it]

Precision@10: 0.5480, Hit@10: 0.8850
Epoch 49, Loss: 10.3618


 82%|████████▏ | 49/60 [01:39<00:21,  1.93s/it]

Precision@10: 0.5540, Hit@10: 0.8700
Epoch 50, Loss: 11.5526


 83%|████████▎ | 50/60 [01:41<00:19,  1.96s/it]

Precision@10: 0.5605, Hit@10: 0.8600
Epoch 51, Loss: 10.7261


 85%|████████▌ | 51/60 [01:44<00:17,  1.98s/it]

Precision@10: 0.5630, Hit@10: 0.8700
Epoch 52, Loss: 9.5171


 87%|████████▋ | 52/60 [01:46<00:15,  1.99s/it]

Precision@10: 0.5630, Hit@10: 0.8750
Epoch 53, Loss: 9.0884


 88%|████████▊ | 53/60 [01:48<00:13,  1.99s/it]

Precision@10: 0.5625, Hit@10: 0.8700
Epoch 54, Loss: 9.5650


 90%|█████████ | 54/60 [01:50<00:12,  2.04s/it]

Precision@10: 0.5660, Hit@10: 0.8900
Epoch 55, Loss: 8.9387


 92%|█████████▏| 55/60 [01:52<00:10,  2.02s/it]

Precision@10: 0.5695, Hit@10: 0.8800
Epoch 56, Loss: 7.9246


 93%|█████████▎| 56/60 [01:54<00:08,  2.01s/it]

Precision@10: 0.5665, Hit@10: 0.8800
Epoch 57, Loss: 7.6194


 95%|█████████▌| 57/60 [01:56<00:06,  2.01s/it]

Precision@10: 0.5715, Hit@10: 0.8750
Epoch 58, Loss: 7.9499


 97%|█████████▋| 58/60 [01:58<00:03,  2.00s/it]

Precision@10: 0.5700, Hit@10: 0.8850
Epoch 59, Loss: 8.0889


 98%|█████████▊| 59/60 [02:00<00:02,  2.02s/it]

Precision@10: 0.5815, Hit@10: 0.8950
Epoch 60, Loss: 6.7164


100%|██████████| 60/60 [02:02<00:00,  2.04s/it]

Precision@10: 0.5915, Hit@10: 0.8900





Precision@10: 0.5915, Hit@10: 0.8900


In [57]:
for k in [1, 5, 10]:
    evaluate_precision_hit(model, test_matrix, train_matrix, k=k)

Precision@1: 0.7250, Hit@1: 0.7250
Precision@5: 0.6370, Hit@5: 0.8550
Precision@10: 0.5915, Hit@10: 0.8900


In [58]:
train_matrix, test_matrix = train_test_split(Mui_dot, test_ratio=0.2)

train_dataset = TrainingDataset(train_matrix, num_negatives=4)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

num_users, num_items = train_matrix.shape
model = MFModel(num_users, num_items, embedding_dim=64)


train(model, train_loader, loss_type='BCE', epochs=60, lr=0.005)
evaluate_precision_hit(model, test_matrix, train_matrix, k=10)

  0%|          | 0/60 [00:00<?, ?it/s]

Epoch 1, Loss: 1702.1675


  2%|▏         | 1/60 [00:02<02:17,  2.33s/it]

Precision@10: 0.0205, Hit@10: 0.1800
Epoch 2, Loss: 1205.4785


  3%|▎         | 2/60 [00:04<02:09,  2.23s/it]

Precision@10: 0.0210, Hit@10: 0.1650
Epoch 3, Loss: 813.1495


  5%|▌         | 3/60 [00:06<02:03,  2.17s/it]

Precision@10: 0.0660, Hit@10: 0.4450
Epoch 4, Loss: 472.1925


  7%|▋         | 4/60 [00:08<02:00,  2.16s/it]

Precision@10: 0.2145, Hit@10: 0.7950
Epoch 5, Loss: 269.2223


  8%|▊         | 5/60 [00:11<02:07,  2.31s/it]

Precision@10: 0.3710, Hit@10: 0.9350
Epoch 6, Loss: 172.7213


 10%|█         | 6/60 [00:13<02:00,  2.24s/it]

Precision@10: 0.4325, Hit@10: 0.9400
Epoch 7, Loss: 125.6383


 12%|█▏        | 7/60 [00:15<02:02,  2.31s/it]

Precision@10: 0.4745, Hit@10: 0.9700
Epoch 8, Loss: 98.9466


 13%|█▎        | 8/60 [00:17<01:57,  2.25s/it]

Precision@10: 0.4885, Hit@10: 0.9800
Epoch 9, Loss: 88.8875


 15%|█▌        | 9/60 [00:20<01:54,  2.25s/it]

Precision@10: 0.5100, Hit@10: 0.9750
Epoch 10, Loss: 81.4239


 17%|█▋        | 10/60 [00:22<01:57,  2.34s/it]

Precision@10: 0.5145, Hit@10: 0.9750
Epoch 11, Loss: 75.9306


 18%|█▊        | 11/60 [00:25<01:53,  2.31s/it]

Precision@10: 0.5260, Hit@10: 0.9750
Epoch 12, Loss: 71.1281


 20%|██        | 12/60 [00:27<01:52,  2.33s/it]

Precision@10: 0.5485, Hit@10: 0.9800
Epoch 13, Loss: 64.7656


 22%|██▏       | 13/60 [00:29<01:50,  2.36s/it]

Precision@10: 0.5965, Hit@10: 0.9900
Epoch 14, Loss: 59.5694


 23%|██▎       | 14/60 [00:32<01:49,  2.38s/it]

Precision@10: 0.6195, Hit@10: 0.9850
Epoch 15, Loss: 58.0995


 25%|██▌       | 15/60 [00:34<01:42,  2.28s/it]

Precision@10: 0.6585, Hit@10: 1.0000
Epoch 16, Loss: 57.2237


 27%|██▋       | 16/60 [00:36<01:43,  2.34s/it]

Precision@10: 0.6920, Hit@10: 0.9950
Epoch 17, Loss: 50.2826


 28%|██▊       | 17/60 [00:39<01:41,  2.37s/it]

Precision@10: 0.7235, Hit@10: 1.0000
Epoch 18, Loss: 48.0875


 30%|███       | 18/60 [00:41<01:40,  2.39s/it]

Precision@10: 0.7430, Hit@10: 1.0000
Epoch 19, Loss: 46.6355


 32%|███▏      | 19/60 [00:43<01:36,  2.36s/it]

Precision@10: 0.7650, Hit@10: 0.9950
Epoch 20, Loss: 45.3006


 33%|███▎      | 20/60 [00:46<01:36,  2.41s/it]

Precision@10: 0.8000, Hit@10: 1.0000
Epoch 21, Loss: 42.2354


 35%|███▌      | 21/60 [00:48<01:35,  2.44s/it]

Precision@10: 0.8175, Hit@10: 1.0000
Epoch 22, Loss: 41.8227


 37%|███▋      | 22/60 [00:51<01:29,  2.36s/it]

Precision@10: 0.8440, Hit@10: 1.0000
Epoch 23, Loss: 38.6956


 38%|███▊      | 23/60 [00:53<01:25,  2.30s/it]

Precision@10: 0.8475, Hit@10: 1.0000
Epoch 24, Loss: 37.8123


 40%|████      | 24/60 [00:55<01:24,  2.34s/it]

Precision@10: 0.8695, Hit@10: 1.0000
Epoch 25, Loss: 35.8258


 42%|████▏     | 25/60 [00:58<01:21,  2.33s/it]

Precision@10: 0.8785, Hit@10: 1.0000
Epoch 26, Loss: 34.0460


 43%|████▎     | 26/60 [01:00<01:18,  2.30s/it]

Precision@10: 0.8975, Hit@10: 1.0000
Epoch 27, Loss: 34.0362


 45%|████▌     | 27/60 [01:02<01:14,  2.26s/it]

Precision@10: 0.9025, Hit@10: 1.0000
Epoch 28, Loss: 32.5738


 47%|████▋     | 28/60 [01:04<01:14,  2.34s/it]

Precision@10: 0.9100, Hit@10: 1.0000
Epoch 29, Loss: 32.5317


 48%|████▊     | 29/60 [01:07<01:10,  2.28s/it]

Precision@10: 0.9185, Hit@10: 1.0000
Epoch 30, Loss: 32.5863


 50%|█████     | 30/60 [01:09<01:11,  2.37s/it]

Precision@10: 0.9215, Hit@10: 1.0000
Epoch 31, Loss: 29.9073


 52%|█████▏    | 31/60 [01:11<01:06,  2.30s/it]

Precision@10: 0.9305, Hit@10: 1.0000
Epoch 32, Loss: 28.8893


 53%|█████▎    | 32/60 [01:13<01:03,  2.26s/it]

Precision@10: 0.9395, Hit@10: 1.0000
Epoch 33, Loss: 28.1028


 55%|█████▌    | 33/60 [01:16<01:02,  2.32s/it]

Precision@10: 0.9320, Hit@10: 1.0000
Epoch 34, Loss: 28.0944


 57%|█████▋    | 34/60 [01:19<01:02,  2.42s/it]

Precision@10: 0.9420, Hit@10: 1.0000
Epoch 35, Loss: 27.7336


 58%|█████▊    | 35/60 [01:21<00:59,  2.37s/it]

Precision@10: 0.9465, Hit@10: 1.0000
Epoch 36, Loss: 25.9016


 60%|██████    | 36/60 [01:23<00:58,  2.44s/it]

Precision@10: 0.9435, Hit@10: 1.0000
Epoch 37, Loss: 26.1441


 62%|██████▏   | 37/60 [01:26<00:54,  2.37s/it]

Precision@10: 0.9485, Hit@10: 1.0000
Epoch 38, Loss: 26.3417


 63%|██████▎   | 38/60 [01:28<00:51,  2.36s/it]

Precision@10: 0.9495, Hit@10: 1.0000
Epoch 39, Loss: 23.5848


 65%|██████▌   | 39/60 [01:30<00:48,  2.32s/it]

Precision@10: 0.9500, Hit@10: 1.0000
Epoch 40, Loss: 26.7916


 67%|██████▋   | 40/60 [01:33<00:46,  2.33s/it]

Precision@10: 0.9500, Hit@10: 1.0000
Epoch 41, Loss: 24.3986


 68%|██████▊   | 41/60 [01:35<00:45,  2.39s/it]

Precision@10: 0.9595, Hit@10: 1.0000
Epoch 42, Loss: 24.8844


 70%|███████   | 42/60 [01:38<00:43,  2.43s/it]

Precision@10: 0.9530, Hit@10: 1.0000
Epoch 43, Loss: 24.9878


 72%|███████▏  | 43/60 [01:40<00:40,  2.38s/it]

Precision@10: 0.9565, Hit@10: 1.0000
Epoch 44, Loss: 25.4858


 73%|███████▎  | 44/60 [01:43<00:39,  2.45s/it]

Precision@10: 0.9575, Hit@10: 1.0000
Epoch 45, Loss: 24.9379


 75%|███████▌  | 45/60 [01:45<00:37,  2.52s/it]

Precision@10: 0.9550, Hit@10: 1.0000
Epoch 46, Loss: 24.2371


 77%|███████▋  | 46/60 [01:48<00:34,  2.48s/it]

Precision@10: 0.9535, Hit@10: 1.0000
Epoch 47, Loss: 24.3298


 78%|███████▊  | 47/60 [01:50<00:32,  2.50s/it]

Precision@10: 0.9585, Hit@10: 1.0000
Epoch 48, Loss: 22.3502


 80%|████████  | 48/60 [01:53<00:30,  2.51s/it]

Precision@10: 0.9645, Hit@10: 1.0000
Epoch 49, Loss: 22.7647


 82%|████████▏ | 49/60 [01:55<00:27,  2.48s/it]

Precision@10: 0.9685, Hit@10: 1.0000
Epoch 50, Loss: 23.4127


 83%|████████▎ | 50/60 [01:57<00:24,  2.43s/it]

Precision@10: 0.9705, Hit@10: 1.0000
Epoch 51, Loss: 22.9395


 85%|████████▌ | 51/60 [02:00<00:22,  2.46s/it]

Precision@10: 0.9635, Hit@10: 1.0000
Epoch 52, Loss: 21.2622


 87%|████████▋ | 52/60 [02:02<00:19,  2.38s/it]

Precision@10: 0.9670, Hit@10: 1.0000
Epoch 53, Loss: 22.1987


 88%|████████▊ | 53/60 [02:04<00:16,  2.32s/it]

Precision@10: 0.9630, Hit@10: 1.0000
Epoch 54, Loss: 22.3381


 90%|█████████ | 54/60 [02:06<00:13,  2.28s/it]

Precision@10: 0.9630, Hit@10: 1.0000
Epoch 55, Loss: 21.9852


 92%|█████████▏| 55/60 [02:09<00:11,  2.38s/it]

Precision@10: 0.9595, Hit@10: 1.0000
Epoch 56, Loss: 22.1762


 93%|█████████▎| 56/60 [02:11<00:09,  2.32s/it]

Precision@10: 0.9580, Hit@10: 1.0000
Epoch 57, Loss: 20.9357


 95%|█████████▌| 57/60 [02:14<00:06,  2.30s/it]

Precision@10: 0.9595, Hit@10: 1.0000
Epoch 58, Loss: 20.6624


 97%|█████████▋| 58/60 [02:16<00:04,  2.29s/it]

Precision@10: 0.9635, Hit@10: 1.0000
Epoch 59, Loss: 22.2179


 98%|█████████▊| 59/60 [02:19<00:02,  2.46s/it]

Precision@10: 0.9670, Hit@10: 1.0000
Epoch 60, Loss: 19.0214


100%|██████████| 60/60 [02:21<00:00,  2.36s/it]

Precision@10: 0.9650, Hit@10: 1.0000





Precision@10: 0.9650, Hit@10: 1.0000


In [59]:
for k in [1, 5, 10]:
    evaluate_precision_hit(model, test_matrix, train_matrix, k=k)

Precision@1: 0.9900, Hit@1: 0.9900
Precision@5: 0.9780, Hit@5: 1.0000
Precision@10: 0.9650, Hit@10: 1.0000


In [53]:
# let's create the box intersection model
import torch
import torch.nn as nn
from box.box_wrapper import BoxTensor

class BoxIntersectionModel(nn.Module):
    def __init__(self,
                 num_users,
                 num_items,
                 embedding_dim=32,
                 intersection_temp=0.01,
                 volume_temp= 0.1):
        
        super().__init__()
        self.user_emb = nn.Embedding(num_users, 2 * embedding_dim)  # z and Z for each user
        self.item_emb = nn.Embedding(num_items, 2 * embedding_dim)  # z and Z for each item
        self.embedding_dim = embedding_dim
        self.intersection_temp = intersection_temp
        self.volume_temp = volume_temp

    def forward(self, user_idx, item_idx):
        # Get user and item box parameters
        user_box_params = self.user_emb(user_idx).view(-1, 2, self.embedding_dim)
        item_box_params = self.item_emb(item_idx).view(-1, 2, self.embedding_dim)
        user_box = BoxTensor(user_box_params)
        item_box = BoxTensor(item_box_params)
        # Compute intersection volume (as a score)
        intersection_vol = user_box.gumbel_intersection_log_volume(item_box, 
                                                                 volume_temp=self.volume_temp,
                                                                 intersection_temp=self.intersection_temp)
        item_vol = item_box.log_soft_volume_adjusted(volume_temp=self.volume_temp,
                                                      intersection_temp=self.intersection_temp)
        conditional_prob = intersection_vol - item_vol
        assert (conditional_prob <= 0).all(), "Log probability can not be positive"
        return conditional_prob

In [54]:
train_matrix, test_matrix = train_test_split(Mui_set, test_ratio=0.2)

train_dataset = TrainingDataset(train_matrix, num_negatives=5)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

num_users, num_items = train_matrix.shape
dim = 64
model = BoxIntersectionModel(num_users, num_items, embedding_dim=dim)
print(f"Dimension: {dim}")
train(model, train_loader, loss_type='NLL',epochs= 10 * int(np.log2(dim)) , lr=0.01)
evaluate_precision_hit(model, test_matrix, train_matrix, k=10)

Dimension: 64


  0%|          | 0/60 [00:00<?, ?it/s]

Epoch 1, Loss: 54814.9089


  2%|▏         | 1/60 [00:04<04:04,  4.14s/it]

Precision@10: 0.0960, Hit@10: 0.5000
Epoch 2, Loss: 24915.1379


  3%|▎         | 2/60 [00:08<04:07,  4.27s/it]

Precision@10: 0.0945, Hit@10: 0.5200
Epoch 3, Loss: 21316.0685


  5%|▌         | 3/60 [00:12<03:55,  4.12s/it]

Precision@10: 0.1065, Hit@10: 0.5500
Epoch 4, Loss: 19114.1754


  7%|▋         | 4/60 [00:16<03:39,  3.93s/it]

Precision@10: 0.1125, Hit@10: 0.5450
Epoch 5, Loss: 17398.0447


  8%|▊         | 5/60 [00:19<03:32,  3.86s/it]

Precision@10: 0.1465, Hit@10: 0.6100
Epoch 6, Loss: 15775.1014


 10%|█         | 6/60 [00:23<03:31,  3.91s/it]

Precision@10: 0.1575, Hit@10: 0.6150
Epoch 7, Loss: 14223.9639


 12%|█▏        | 7/60 [00:27<03:23,  3.84s/it]

Precision@10: 0.1910, Hit@10: 0.6600
Epoch 8, Loss: 12975.5691


 13%|█▎        | 8/60 [00:31<03:16,  3.78s/it]

Precision@10: 0.2275, Hit@10: 0.6950
Epoch 9, Loss: 11757.5804


 15%|█▌        | 9/60 [00:35<03:13,  3.80s/it]

Precision@10: 0.2705, Hit@10: 0.7250
Epoch 10, Loss: 10918.2717


 17%|█▋        | 10/60 [00:38<03:11,  3.83s/it]

Precision@10: 0.2940, Hit@10: 0.7350
Epoch 11, Loss: 10094.5109


 18%|█▊        | 11/60 [00:42<03:06,  3.81s/it]

Precision@10: 0.3405, Hit@10: 0.7750
Epoch 12, Loss: 9512.2625


 20%|██        | 12/60 [00:47<03:18,  4.13s/it]

Precision@10: 0.3730, Hit@10: 0.7900
Epoch 13, Loss: 8805.0569


 22%|██▏       | 13/60 [00:53<03:34,  4.56s/it]

Precision@10: 0.4040, Hit@10: 0.8200
Epoch 14, Loss: 8175.1623


 23%|██▎       | 14/60 [00:59<03:51,  5.03s/it]

Precision@10: 0.4255, Hit@10: 0.8300
Epoch 15, Loss: 7694.1161


 25%|██▌       | 15/60 [01:04<03:56,  5.26s/it]

Precision@10: 0.4545, Hit@10: 0.8450
Epoch 16, Loss: 7431.8689


 27%|██▋       | 16/60 [01:10<03:59,  5.44s/it]

Precision@10: 0.4770, Hit@10: 0.8500
Epoch 17, Loss: 6918.0623


 28%|██▊       | 17/60 [01:16<03:57,  5.53s/it]

Precision@10: 0.4930, Hit@10: 0.8400
Epoch 18, Loss: 6453.3119


 30%|███       | 18/60 [01:22<03:57,  5.65s/it]

Precision@10: 0.5090, Hit@10: 0.8450
Epoch 19, Loss: 6539.6733


 32%|███▏      | 19/60 [01:27<03:48,  5.57s/it]

Precision@10: 0.5180, Hit@10: 0.8650
Epoch 20, Loss: 6114.9664


 33%|███▎      | 20/60 [01:33<03:38,  5.46s/it]

Precision@10: 0.5395, Hit@10: 0.8650
Epoch 21, Loss: 5850.4320


 35%|███▌      | 21/60 [01:38<03:33,  5.46s/it]

Precision@10: 0.5520, Hit@10: 0.9000
Epoch 22, Loss: 5532.3292


 37%|███▋      | 22/60 [01:43<03:19,  5.25s/it]

Precision@10: 0.5765, Hit@10: 0.9100
Epoch 23, Loss: 5402.4907


 38%|███▊      | 23/60 [01:48<03:08,  5.11s/it]

Precision@10: 0.5935, Hit@10: 0.9100
Epoch 24, Loss: 5313.2006


 40%|████      | 24/60 [01:52<02:57,  4.94s/it]

Precision@10: 0.5970, Hit@10: 0.9150
Epoch 25, Loss: 4985.6069


 42%|████▏     | 25/60 [01:57<02:47,  4.78s/it]

Precision@10: 0.6075, Hit@10: 0.9250
Epoch 26, Loss: 4873.8665


 43%|████▎     | 26/60 [02:01<02:37,  4.64s/it]

Precision@10: 0.6195, Hit@10: 0.9300
Epoch 27, Loss: 4903.1063


 45%|████▌     | 27/60 [02:05<02:30,  4.57s/it]

Precision@10: 0.6230, Hit@10: 0.9250
Epoch 28, Loss: 4466.3322


 47%|████▋     | 28/60 [02:10<02:24,  4.52s/it]

Precision@10: 0.6255, Hit@10: 0.9250
Epoch 29, Loss: 4477.8775


 48%|████▊     | 29/60 [02:14<02:18,  4.46s/it]

Precision@10: 0.6390, Hit@10: 0.9400
Epoch 30, Loss: 4328.6484


 50%|█████     | 30/60 [02:18<02:13,  4.44s/it]

Precision@10: 0.6400, Hit@10: 0.9400
Epoch 31, Loss: 4339.9168


 52%|█████▏    | 31/60 [02:23<02:07,  4.40s/it]

Precision@10: 0.6395, Hit@10: 0.9300
Epoch 32, Loss: 4214.1159


 53%|█████▎    | 32/60 [02:27<02:04,  4.43s/it]

Precision@10: 0.6600, Hit@10: 0.9400
Epoch 33, Loss: 4163.5623


 55%|█████▌    | 33/60 [02:32<02:05,  4.66s/it]

Precision@10: 0.6660, Hit@10: 0.9350
Epoch 34, Loss: 3784.4100


 57%|█████▋    | 34/60 [02:37<02:03,  4.75s/it]

Precision@10: 0.6795, Hit@10: 0.9500
Epoch 35, Loss: 3653.8826


 58%|█████▊    | 35/60 [02:42<02:00,  4.80s/it]

Precision@10: 0.6780, Hit@10: 0.9400
Epoch 36, Loss: 3654.1532


 60%|██████    | 36/60 [02:47<01:54,  4.76s/it]

Precision@10: 0.6850, Hit@10: 0.9550
Epoch 37, Loss: 3429.3321


 62%|██████▏   | 37/60 [02:52<01:51,  4.85s/it]

Precision@10: 0.6885, Hit@10: 0.9550
Epoch 38, Loss: 3740.4648


 63%|██████▎   | 38/60 [02:57<01:46,  4.82s/it]

Precision@10: 0.6975, Hit@10: 0.9550
Epoch 39, Loss: 3669.5952


 65%|██████▌   | 39/60 [03:01<01:40,  4.78s/it]

Precision@10: 0.6970, Hit@10: 0.9550
Epoch 40, Loss: 3237.2054


 67%|██████▋   | 40/60 [03:06<01:35,  4.76s/it]

Precision@10: 0.7025, Hit@10: 0.9600
Epoch 41, Loss: 3105.9227


 68%|██████▊   | 41/60 [03:11<01:28,  4.68s/it]

Precision@10: 0.7050, Hit@10: 0.9500
Epoch 42, Loss: 3087.6830


 70%|███████   | 42/60 [03:15<01:22,  4.60s/it]

Precision@10: 0.7000, Hit@10: 0.9650
Epoch 43, Loss: 2973.5181


 72%|███████▏  | 43/60 [03:19<01:16,  4.50s/it]

Precision@10: 0.7085, Hit@10: 0.9700
Epoch 44, Loss: 2919.1186


 73%|███████▎  | 44/60 [03:24<01:10,  4.42s/it]

Precision@10: 0.7070, Hit@10: 0.9650
Epoch 45, Loss: 2842.9050


 75%|███████▌  | 45/60 [03:28<01:05,  4.35s/it]

Precision@10: 0.7190, Hit@10: 0.9700
Epoch 46, Loss: 2841.5039


 77%|███████▋  | 46/60 [03:32<01:00,  4.33s/it]

Precision@10: 0.7155, Hit@10: 0.9750
Epoch 47, Loss: 2826.8619


 78%|███████▊  | 47/60 [03:36<00:55,  4.30s/it]

Precision@10: 0.7195, Hit@10: 0.9750
Epoch 48, Loss: 2822.2274


 80%|████████  | 48/60 [03:40<00:51,  4.28s/it]

Precision@10: 0.7245, Hit@10: 0.9750
Epoch 49, Loss: 2611.1425


 82%|████████▏ | 49/60 [03:45<00:47,  4.31s/it]

Precision@10: 0.7255, Hit@10: 0.9700
Epoch 50, Loss: 2561.8815


 83%|████████▎ | 50/60 [03:49<00:43,  4.32s/it]

Precision@10: 0.7210, Hit@10: 0.9700
Epoch 51, Loss: 2643.4655


 85%|████████▌ | 51/60 [03:54<00:38,  4.32s/it]

Precision@10: 0.7260, Hit@10: 0.9650
Epoch 52, Loss: 2607.6245


 87%|████████▋ | 52/60 [03:58<00:34,  4.32s/it]

Precision@10: 0.7230, Hit@10: 0.9700
Epoch 53, Loss: 2336.7957


 88%|████████▊ | 53/60 [04:02<00:30,  4.41s/it]

Precision@10: 0.7265, Hit@10: 0.9750
Epoch 54, Loss: 2487.1995


 90%|█████████ | 54/60 [04:07<00:26,  4.45s/it]

Precision@10: 0.7330, Hit@10: 0.9750
Epoch 55, Loss: 2574.1385


 92%|█████████▏| 55/60 [04:12<00:22,  4.51s/it]

Precision@10: 0.7285, Hit@10: 0.9700
Epoch 56, Loss: 2616.3247


 93%|█████████▎| 56/60 [04:16<00:18,  4.52s/it]

Precision@10: 0.7265, Hit@10: 0.9700
Epoch 57, Loss: 2244.5125


 95%|█████████▌| 57/60 [04:21<00:13,  4.50s/it]

Precision@10: 0.7235, Hit@10: 0.9700
Epoch 58, Loss: 2258.4345


 97%|█████████▋| 58/60 [04:25<00:09,  4.52s/it]

Precision@10: 0.7290, Hit@10: 0.9750
Epoch 59, Loss: 2248.6849


 98%|█████████▊| 59/60 [04:30<00:04,  4.51s/it]

Precision@10: 0.7280, Hit@10: 0.9750
Epoch 60, Loss: 2116.4752


100%|██████████| 60/60 [04:34<00:00,  4.58s/it]

Precision@10: 0.7285, Hit@10: 0.9800





Precision@10: 0.7285, Hit@10: 0.9800


In [55]:
for k in [1, 5, 10]:
    evaluate_precision_hit(model, test_matrix, train_matrix, k=k)

Precision@1: 0.8050, Hit@1: 0.8050
Precision@5: 0.7820, Hit@5: 0.9550
Precision@10: 0.7285, Hit@10: 0.9800


In [60]:
train_matrix, test_matrix = train_test_split(Mui_dot, test_ratio=0.2)

train_dataset = TrainingDataset(train_matrix, num_negatives=4)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

num_users, num_items = train_matrix.shape
dim = 32
model = BoxIntersectionModel(num_users, num_items, embedding_dim=dim)
print(f"Dimension: {dim}")
train(model, train_loader, loss_type='NLL',epochs= 10 * int(np.log2(dim)) , lr=0.01)
evaluate_precision_hit(model, test_matrix, train_matrix, k=10)

Dimension: 32


  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 1, Loss: 36016.9535


  2%|▏         | 1/50 [00:03<03:09,  3.87s/it]

Precision@10: 0.5210, Hit@10: 0.9950
Epoch 2, Loss: 13519.7088


  4%|▍         | 2/50 [00:07<03:00,  3.77s/it]

Precision@10: 0.5865, Hit@10: 0.9850
Epoch 3, Loss: 12526.4601


  6%|▌         | 3/50 [00:11<02:52,  3.67s/it]

Precision@10: 0.6215, Hit@10: 0.9850
Epoch 4, Loss: 12011.4959


  8%|▊         | 4/50 [00:14<02:52,  3.74s/it]

Precision@10: 0.6430, Hit@10: 0.9900
Epoch 5, Loss: 11945.2149


 10%|█         | 5/50 [00:18<02:49,  3.77s/it]

Precision@10: 0.6565, Hit@10: 0.9850
Epoch 6, Loss: 11447.4005


 12%|█▏        | 6/50 [00:22<02:48,  3.84s/it]

Precision@10: 0.6495, Hit@10: 0.9900
Epoch 7, Loss: 10655.6615


 14%|█▍        | 7/50 [00:26<02:42,  3.79s/it]

Precision@10: 0.6900, Hit@10: 0.9850
Epoch 8, Loss: 10569.9004


 16%|█▌        | 8/50 [00:30<02:40,  3.83s/it]

Precision@10: 0.7060, Hit@10: 1.0000
Epoch 9, Loss: 11274.7553


 18%|█▊        | 9/50 [00:34<02:38,  3.87s/it]

Precision@10: 0.7105, Hit@10: 0.9950
Epoch 10, Loss: 9982.0330


 20%|██        | 10/50 [00:38<02:41,  4.05s/it]

Precision@10: 0.7300, Hit@10: 0.9950
Epoch 11, Loss: 9888.0536


 22%|██▏       | 11/50 [00:42<02:38,  4.05s/it]

Precision@10: 0.7570, Hit@10: 1.0000
Epoch 12, Loss: 11263.4409


 24%|██▍       | 12/50 [00:46<02:34,  4.06s/it]

Precision@10: 0.7575, Hit@10: 1.0000
Epoch 13, Loss: 10361.2657


 26%|██▌       | 13/50 [00:51<02:31,  4.09s/it]

Precision@10: 0.7740, Hit@10: 0.9950
Epoch 14, Loss: 9149.7611


 28%|██▊       | 14/50 [00:55<02:27,  4.08s/it]

Precision@10: 0.7785, Hit@10: 1.0000
Epoch 15, Loss: 9082.1118


 30%|███       | 15/50 [00:59<02:23,  4.10s/it]

Precision@10: 0.7885, Hit@10: 1.0000
Epoch 16, Loss: 9796.3705


 32%|███▏      | 16/50 [01:03<02:22,  4.20s/it]

Precision@10: 0.7895, Hit@10: 1.0000
Epoch 17, Loss: 9971.4696


 34%|███▍      | 17/50 [01:07<02:19,  4.22s/it]

Precision@10: 0.8090, Hit@10: 1.0000
Epoch 18, Loss: 11194.9810


 36%|███▌      | 18/50 [01:12<02:18,  4.33s/it]

Precision@10: 0.8240, Hit@10: 1.0000
Epoch 19, Loss: 10577.1107


 38%|███▊      | 19/50 [01:16<02:11,  4.23s/it]

Precision@10: 0.8185, Hit@10: 1.0000
Epoch 20, Loss: 10159.3020


 40%|████      | 20/50 [01:21<02:09,  4.32s/it]

Precision@10: 0.8325, Hit@10: 1.0000
Epoch 21, Loss: 9152.4944


 42%|████▏     | 21/50 [01:25<02:09,  4.45s/it]

Precision@10: 0.8465, Hit@10: 1.0000
Epoch 22, Loss: 10929.6186


 44%|████▍     | 22/50 [01:31<02:11,  4.70s/it]

Precision@10: 0.8445, Hit@10: 1.0000
Epoch 23, Loss: 11155.8868


 46%|████▌     | 23/50 [01:35<02:00,  4.47s/it]

Precision@10: 0.8550, Hit@10: 1.0000
Epoch 24, Loss: 10597.2289


 48%|████▊     | 24/50 [01:38<01:51,  4.30s/it]

Precision@10: 0.8595, Hit@10: 1.0000
Epoch 25, Loss: 10715.9061


 50%|█████     | 25/50 [01:44<01:55,  4.62s/it]

Precision@10: 0.8640, Hit@10: 1.0000
Epoch 26, Loss: 9668.5044


 52%|█████▏    | 26/50 [01:48<01:44,  4.35s/it]

Precision@10: 0.8695, Hit@10: 1.0000
Epoch 27, Loss: 11578.3771


 54%|█████▍    | 27/50 [01:52<01:38,  4.27s/it]

Precision@10: 0.8760, Hit@10: 1.0000
Epoch 28, Loss: 10172.4328


 56%|█████▌    | 28/50 [01:56<01:32,  4.19s/it]

Precision@10: 0.8775, Hit@10: 1.0000
Epoch 29, Loss: 11990.8292


 58%|█████▊    | 29/50 [02:00<01:26,  4.11s/it]

Precision@10: 0.8850, Hit@10: 1.0000
Epoch 30, Loss: 10722.9280


 60%|██████    | 30/50 [02:03<01:19,  3.98s/it]

Precision@10: 0.8830, Hit@10: 1.0000
Epoch 31, Loss: 11366.1552


 62%|██████▏   | 31/50 [02:07<01:13,  3.84s/it]

Precision@10: 0.8845, Hit@10: 1.0000
Epoch 32, Loss: 13026.6001


 64%|██████▍   | 32/50 [02:11<01:09,  3.84s/it]

Precision@10: 0.8920, Hit@10: 1.0000
Epoch 33, Loss: 13818.3231


 66%|██████▌   | 33/50 [02:15<01:06,  3.88s/it]

Precision@10: 0.8985, Hit@10: 0.9950
Epoch 34, Loss: 11057.3901


 68%|██████▊   | 34/50 [02:18<01:02,  3.88s/it]

Precision@10: 0.9040, Hit@10: 1.0000
Epoch 35, Loss: 13639.8666


 70%|███████   | 35/50 [02:22<00:58,  3.92s/it]

Precision@10: 0.9035, Hit@10: 1.0000
Epoch 36, Loss: 12387.3855


 72%|███████▏  | 36/50 [02:27<00:55,  3.96s/it]

Precision@10: 0.9010, Hit@10: 1.0000
Epoch 37, Loss: 11432.3166


 74%|███████▍  | 37/50 [02:31<00:52,  4.01s/it]

Precision@10: 0.9080, Hit@10: 1.0000
Epoch 38, Loss: 13447.7545


 76%|███████▌  | 38/50 [02:35<00:48,  4.05s/it]

Precision@10: 0.9045, Hit@10: 1.0000
Epoch 39, Loss: 11063.1122


 78%|███████▊  | 39/50 [02:39<00:44,  4.06s/it]

Precision@10: 0.9095, Hit@10: 1.0000
Epoch 40, Loss: 11894.9658


 80%|████████  | 40/50 [02:43<00:40,  4.07s/it]

Precision@10: 0.9060, Hit@10: 1.0000
Epoch 41, Loss: 12636.3654


 82%|████████▏ | 41/50 [02:47<00:36,  4.03s/it]

Precision@10: 0.9080, Hit@10: 1.0000
Epoch 42, Loss: 13939.9683


 84%|████████▍ | 42/50 [02:51<00:32,  4.00s/it]

Precision@10: 0.9085, Hit@10: 1.0000
Epoch 43, Loss: 13742.8106


 86%|████████▌ | 43/50 [02:55<00:28,  4.00s/it]

Precision@10: 0.9085, Hit@10: 1.0000
Epoch 44, Loss: 11226.9004


 88%|████████▊ | 44/50 [02:59<00:23,  3.96s/it]

Precision@10: 0.9060, Hit@10: 1.0000
Epoch 45, Loss: 14420.8570


 90%|█████████ | 45/50 [03:02<00:19,  3.90s/it]

Precision@10: 0.9115, Hit@10: 1.0000
Epoch 46, Loss: 13873.8010


 92%|█████████▏| 46/50 [03:06<00:15,  3.92s/it]

Precision@10: 0.9075, Hit@10: 1.0000
Epoch 47, Loss: 14953.8138


 94%|█████████▍| 47/50 [03:10<00:11,  3.89s/it]

Precision@10: 0.9055, Hit@10: 1.0000
Epoch 48, Loss: 14616.8826


 96%|█████████▌| 48/50 [03:14<00:07,  3.90s/it]

Precision@10: 0.9020, Hit@10: 1.0000
Epoch 49, Loss: 13419.4977


 98%|█████████▊| 49/50 [03:18<00:04,  4.01s/it]

Precision@10: 0.9090, Hit@10: 1.0000
Epoch 50, Loss: 14402.4844


100%|██████████| 50/50 [03:22<00:00,  4.06s/it]

Precision@10: 0.9150, Hit@10: 1.0000





Precision@10: 0.9150, Hit@10: 1.0000


In [61]:
for k in [1, 5, 10]:
    evaluate_precision_hit(model, test_matrix, train_matrix, k=k)

Precision@1: 0.9250, Hit@1: 0.9250
Precision@5: 0.9130, Hit@5: 0.9950
Precision@10: 0.9150, Hit@10: 1.0000


## Results

- Num Items 3883
- Num Users 200
- Num attributes 18
#### M_dot
- Mua @ Mia.T > tau (Mua: Generated, Mia: Real data)
- Mia is from real data.
- Mua is created by the common understanding 
    - "If you like this genre you might also like similar genre"
#### User rules

<style>
  .genre { 
    display: inline-block; 
    padding: 2px 6px; 
    margin: 1px; 
    border-radius: 6px; 
    font-size: 90%;
    font-weight: bold;
    color: white;
  }
  .adventure { background-color: #1f77b4; }
  .comedy { background-color: #ff7f0e; }
  .animation { background-color: #2ca02c; }
  .horror { background-color: #d62728; }
  .sci-fi { background-color: #9467bd; }
  .crime { background-color: #8c564b; }
  .thriller { background-color: #e377c2; }
  .war { background-color: #7f7f7f; }
  .action { background-color: #bcbd22; }
  .western { background-color: #17becf; }
  .romance { background-color: #ff9896; }
  .mystery { background-color: #c5b0d5; }
  .fantasy { background-color: #aec7e8; }
  .children { background-color: #98df8a; }
</style>

<p><strong>User 1:</strong><br>
(<span class="genre adventure">Adventure</span> ∧ <span class="genre comedy">Comedy</span>) ∨ 
(<span class="genre animation">Animation</span> ∧ <span class="genre comedy">Comedy</span>) ∨ 
(<span class="genre adventure">Adventure</span> ∧ <span class="genre horror">Horror</span>) ∨ 
(<span class="genre comedy">Comedy</span> ∧ <span class="genre sci-fi">Sci-Fi</span>)
</p>

<p><strong>User 2:</strong><br>
(<span class="genre crime">Crime</span> ∧ <span class="genre horror">Horror</span>) ∨ 
(<span class="genre thriller">Thriller</span> ∧ <span class="genre war">War</span>) ∨ 
(<span class="genre action">Action</span> ∧ <span class="genre western">Western</span>)
</p>

<p><strong>User 3:</strong><br>
(<span class="genre crime">Crime</span> ∧ <span class="genre horror">Horror</span>) ∨ 
(<span class="genre adventure">Adventure</span> ∧ <span class="genre horror">Horror</span>) ∨ 
(<span class="genre adventure">Adventure</span> ∧ <span class="genre comedy">Comedy</span>) ∨ 
(<span class="genre romance">Romance</span> ∧ ¬<span class="genre thriller">Thriller</span>)
</p>

<p><strong>User 4:</strong><br>
(<span class="genre comedy">Comedy</span> ∧ <span class="genre mystery">Mystery</span>) ∨ 
(<span class="genre adventure">Adventure</span> ∧ <span class="genre mystery">Mystery</span>)
</p>

<p><strong>User 5:</strong><br>
(<span class="genre crime">Crime</span> ∧ <span class="genre mystery">Mystery</span>) ∨ 
(<span class="genre comedy">Comedy</span> ∧ <span class="genre fantasy">Fantasy</span>) ∨ 
(<span class="genre children">Children's</span> ∧ <span class="genre sci-fi">Sci-Fi</span>) ∨ 
(<span class="genre action">Action</span> ∧ <span class="genre thriller">Thriller</span>)
</p>



Hypothesis
- M_set is not created by multiplication of matrices, so must be higher M_dot
- For M_set boxes would do better than vectors
- For M_dot vectors would do better than boxes (This we do not know for certain)
    - We do not know which angle we should be hitting for understanding the limitation with boxes.
    - Boxes might not be capturing low rank matrices, even if they do, it will be an overfitting to data.

<table border="1" cellpadding="5" cellspacing="0">
  <thead>
    <tr>
      <th rowspan="2">Matrix type</th>
      <th rowspan="2">Matrix rank</th>
      <th rowspan="2">Matrix sparsity</th>
      <th colspan="2">Vector</th>
      <th colspan="2">Box</th>
    </tr>
    <tr>
      <th>Precision@10</th>
      <th>Hits@10</th>
      <th>Precision@10</th>
      <th>Hits@10</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>M_dot</td>
      <td>130</td>
      <td>90.3</td>
      <td>96.5</td>
      <td>100.0</td>
      <td>91.5</td>
      <td>100.0</td>
    </tr>
    <tr>
      <td>M_set</td>
      <td>188</td>
      <td>93.8</td>
      <td>59.15</td>
      <td>89.0</td>
      <td>72.9</td>
      <td>98.0</td>
    </tr>
    <tr>
      <td>M_combined</td>
      <td>92.4</td>
      <td>194</td>
      <td>--</td>
      <td>--</td>
      <td>--</td>
      <td>--</td>
    </tr>
  </tbody>
</table>





Todo

[ ] M_combined: Right now just a convex sum. Evidently rank increases which is bad? Also need to run these experiments

[ ] Dataset size is small -- 1. Run k-fold cross validation 2. need to scale up.

[ ] Hyperparam search (Some manual tuning is done to make the training work)

[ ] Effect of dimension (current dimension size 64)

[ ] Disjunction = 0

[ ] how to combine -- disjucntion? or convex combination ?

[ ] more on the probing combination.. 

[ ] Size of the dataset might affect generalization... also the dimension of boxes? does it over fit.

[ ] training data size, rank, other meta parameters like the dimensions. 

[ ] Similar for the M_set.

[ ] 


[ ] Todo list


[ ] Check all the functionalities.
- Does the data generation agree with what I have? \\
- What is the rank / sparsity of these metrices?
- Is set is of higher rank?
- Also, study the spectrum of the PCA / SVD to come to the conclusion for ranks.
        Because the matrices might all be high rank because of the noisy or incompelteness. But they might be very close to a low rank, and we are after that approximation. Spectrum analysis would catch that


[ ] What is the best Eval in such scenario?
- Rec evals are more discrete so any poor model might also get hits@10 correct, or might just not get it at all. 
- Maybe RMSE is the way to go. Need to understand more on this.

[ ] check the dot product, does normalization help? Exhaustive search space.

[ ] How to densify the set-aspects.

[ ] Users are defined by latent factors.
 - Should those be boxes?
 - if so then what are the latent factors that are more favourable to boxes


[ ] 
- 

Possibilites
- Test bed for the latent box
- But if the higher dimensions are getting it, then honestly need to understand what is wrong.
