In [86]:
import itertools
import operator

import numpy as np
import torch
import yaml

from data_processing import DataProcessing
from model import MatrixFactorizationWithBias

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [25]:
# Read the hyper-parameters recorded by the wandb run that produced the model.
# A context manager is used so the file handle is closed deterministically
# instead of being left open for the lifetime of the kernel.
with open("./wandb/run-20240208_203625-d6efp63d/files/config.yaml") as config_file:
    wandb_config = yaml.safe_load(config_file)

# Each entry of the config is a mapping whose "value" key holds the setting.
data_dir = wandb_config["data_dir"]["value"]
batch_size = wandb_config["batch_size"]["value"]
model_dir = wandb_config["model_dir"]["value"]
embedding_dim = wandb_config["embedding_dim"]["value"]


In [20]:
# Display the dataset directory read from the wandb run config.
data_dir

'data/small_data'

In [24]:
# Build the dataset wrapper from the configured directory and batch size, then
# take the size of each of its four lookup tables (users, user attributes,
# items, item attributes). These sizes drive the embedding-table offsets used
# in the analysis cells.
dataset = DataProcessing(data_dir, batch_size)
n_users, n_user_attrs, n_items, n_item_attrs = map(
    len,
    (
        dataset.user2id,
        dataset.user_attribute2id,
        dataset.item2id,
        dataset.item_attribute2id,
    ),
)

In [137]:
# Display the four vocabulary sizes taken from the dataset lookups.
n_users, n_user_attrs, n_items, n_item_attrs

(54, 6, 209, 10)

In [26]:
# The first table is sized n_users + n_item_attrs and the second
# n_items + n_user_attrs: item attributes occupy the rows after the real users,
# and user attributes occupy the rows after the real items.
model = MatrixFactorizationWithBias(n_users + n_item_attrs, n_items + n_user_attrs, embedding_dim).to(device)

# map_location makes the checkpoint loadable on a CPU-only machine even if it
# was saved from a GPU; without it torch.load tries to restore tensors onto the
# device they were saved from.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a
# trusted source.
state_dict = torch.load('model/small_data_model/20240208203623/model.pth', map_location=device)
model.load_state_dict(state_dict)

<All keys matched successfully>

In [42]:
# Sanity-check a forward pass on the first attribute rows of both tables:
# index n_users is the first row after the real users, and index n_items is
# the first row after the real items (54 and 209 in the recorded run).
# The index tensors are created on `device` so the call also works when the
# model lives on the GPU.
model(torch.tensor([n_users], device=device), torch.tensor([n_items], device=device))

tensor([-1.1799], grad_fn=<AddBackward0>)

## Pairwise affinity

In [126]:
# Cosine similarity between every user attribute and every item attribute.
# User attributes are stored in model.item_embeddings at offset n_items;
# item attributes are stored in model.user_embeddings at offset n_users.
# Keys of similarity_dict are (user_attr_index, item_attr_index).
similarity_dict = {}
with torch.no_grad():  # inference only: skip building autograd graphs
    # Item-attribute embeddings do not depend on the user attribute, so look
    # each one up once instead of once per (u_attr, i_attr) pair.
    item_attr_embeddings = [
        model.user_embeddings(torch.tensor([n_users + i_attr], device=device))
        for i_attr in range(n_item_attrs)
    ]
    for u_attr in range(n_user_attrs):
        user_attr_embedding = model.item_embeddings(torch.tensor([n_items + u_attr], device=device))
        for i_attr, item_attr_embedding in enumerate(item_attr_embeddings):
            similarity = torch.nn.functional.cosine_similarity(item_attr_embedding, user_attr_embedding)
            similarity_dict[(u_attr, i_attr)] = similarity.item()

In [127]:
# Order the (user_attr, item_attr) pairs from least to most similar.
sorted_similarity_dict = dict(
    sorted(similarity_dict.items(), key=lambda pair_and_score: pair_and_score[1])
)

## Intersection similarity

In [81]:
# Cosine similarity between every unordered pair of user attributes and every
# unordered pair of item attributes. Each pair is represented by the mean of
# its two embeddings. Keys are ((i, j), (k, m)) with i < j and k < m.
# itertools.combinations yields strictly increasing index tuples in the same
# order as the nested range loops it replaces, so no equal-index guard is
# needed.
intersection_similarity_dict = {}
with torch.no_grad():  # inference only: skip building autograd graphs
    # The item-attribute pair means are independent of the user-attribute
    # pair, so compute them once up front.
    item_attr_pair_means = {}
    for item_pair in itertools.combinations(range(n_item_attrs), 2):
        first, second = (
            model.user_embeddings(torch.tensor([attr + n_users], device=device))
            for attr in item_pair
        )
        item_attr_pair_means[item_pair] = (first + second) / 2

    for user_pair in itertools.combinations(range(n_user_attrs), 2):
        first, second = (
            model.item_embeddings(torch.tensor([attr + n_items], device=device))
            for attr in user_pair
        )
        user_attr_pair_mean = (first + second) / 2

        for item_pair, item_attr_pair_mean in item_attr_pair_means.items():
            similarity = torch.nn.functional.cosine_similarity(user_attr_pair_mean, item_attr_pair_mean)
            intersection_similarity_dict[(user_pair, item_pair)] = similarity.item()

In [112]:
# Rank the attribute-pair combinations from most to least similar.
sorted_intersection_similarity_dict = dict(
    sorted(
        intersection_similarity_dict.items(),
        key=lambda combo_and_score: combo_and_score[1],
        reverse=True,
    )
)

In [125]:
# Print, for each combination of interest, the fraction of combinations it
# outranks: 1 - position / total, where position 0 is the most similar entry.
sorted_list = list(sorted_intersection_similarity_dict)
for combo in (
    ((3, 4), (1, 6)),
    ((3, 4), (2, 7)),
    ((2, 3), (1, 2)),
):
    print(1 - sorted_list.index(combo) / len(sorted_list))

0.8325925925925926
0.7792592592592593
0.9259259259259259


In [101]:
# Number of ranked combinations currently held in sorted_list.
len(sorted_list)

675

In [132]:
# Cosine similarity between every unordered triple of user attributes and
# every unordered triple of item attributes. Each triple is represented by the
# mean of its three embeddings. Keys are ((i, j, k), (x, y, z)) with strictly
# increasing indices inside each tuple.
# itertools.combinations yields the tuples in the same order as the nested
# range loops it replaces and never repeats an index, so no equality guard is
# needed.
triple_intersection_similarity_dict = {}
with torch.no_grad():  # inference only: skip building autograd graphs
    # The item-attribute triple means are independent of the user-attribute
    # triple, so compute them once up front.
    item_attr_triple_means = {}
    for item_triple in itertools.combinations(range(n_item_attrs), 3):
        first, second, third = (
            model.user_embeddings(torch.tensor([attr + n_users], device=device))
            for attr in item_triple
        )
        item_attr_triple_means[item_triple] = (first + second + third) / 3

    for user_triple in itertools.combinations(range(n_user_attrs), 3):
        first, second, third = (
            model.item_embeddings(torch.tensor([attr + n_items], device=device))
            for attr in user_triple
        )
        user_attr_triple_mean = (first + second + third) / 3

        for item_triple, item_attr_triple_mean in item_attr_triple_means.items():
            similarity = torch.nn.functional.cosine_similarity(user_attr_triple_mean, item_attr_triple_mean)
            triple_intersection_similarity_dict[(user_triple, item_triple)] = similarity.item()

In [134]:
# Rank the attribute-triple combinations from most to least similar.
# NOTE(review): this rebinds sorted_intersection_similarity_dict, the same name
# used for the pair ranking, so the pair ranking is no longer reachable under
# that name after this cell runs.
sorted_intersection_similarity_dict = dict(
    sorted(
        triple_intersection_similarity_dict.items(),
        key=lambda combo_and_score: combo_and_score[1],
        reverse=True,
    )
)

In [136]:
# Print the fraction of combinations outranked by the combination of interest:
# 1 - position / total, where position 0 is the most similar entry.
sorted_list = list(sorted_intersection_similarity_dict)
combo = ((0, 1, 4), (4, 6, 8))
print(1 - sorted_list.index(combo) / len(sorted_list))

0.7666666666666666


In [62]:
# Re-sort the attribute-pair results from least to most similar.
# NOTE(review): rebinds sorted_intersection_similarity_dict once more, this time
# in ascending order.
sorted_intersection_similarity_dict = dict(
    sorted(
        intersection_similarity_dict.items(),
        key=lambda combo_and_score: combo_and_score[1],
    )
)

In [63]:
# List every ranked key in its current sort order.
# NOTE(review): the output recorded under this cell shows 3-tuple keys with
# permuted indices, which the loops in this notebook do not produce (they only
# emit strictly increasing index tuples). The recorded output appears to be
# stale — re-run with a fresh kernel before relying on it.
list(sorted_intersection_similarity_dict.keys())

[((1, 5, 0), (0, 7, 1)),
 ((1, 5, 0), (7, 0, 1)),
 ((5, 1, 0), (0, 7, 1)),
 ((5, 1, 0), (7, 0, 1)),
 ((0, 5, 1), (0, 7, 1)),
 ((0, 5, 1), (1, 7, 0)),
 ((0, 5, 1), (7, 0, 1)),
 ((0, 5, 1), (7, 1, 0)),
 ((1, 5, 0), (0, 1, 7)),
 ((1, 5, 0), (1, 0, 7)),
 ((1, 5, 0), (1, 7, 0)),
 ((1, 5, 0), (7, 1, 0)),
 ((5, 0, 1), (0, 7, 1)),
 ((5, 0, 1), (1, 7, 0)),
 ((5, 0, 1), (7, 0, 1)),
 ((5, 0, 1), (7, 1, 0)),
 ((5, 1, 0), (0, 1, 7)),
 ((5, 1, 0), (1, 0, 7)),
 ((5, 1, 0), (1, 7, 0)),
 ((5, 1, 0), (7, 1, 0)),
 ((0, 1, 5), (1, 7, 0)),
 ((0, 1, 5), (7, 1, 0)),
 ((0, 5, 1), (0, 1, 7)),
 ((0, 5, 1), (1, 0, 7)),
 ((1, 0, 5), (1, 7, 0)),
 ((1, 0, 5), (7, 1, 0)),
 ((5, 0, 1), (0, 1, 7)),
 ((5, 0, 1), (1, 0, 7)),
 ((0, 1, 5), (0, 1, 7)),
 ((0, 1, 5), (0, 7, 1)),
 ((0, 1, 5), (1, 0, 7)),
 ((0, 1, 5), (7, 0, 1)),
 ((1, 0, 5), (0, 1, 7)),
 ((1, 0, 5), (0, 7, 1)),
 ((1, 0, 5), (1, 0, 7)),
 ((1, 0, 5), (7, 0, 1)),
 ((0, 1, 4), (0, 7, 9)),
 ((0, 1, 4), (7, 0, 9)),
 ((0, 1, 4), (7, 9, 0)),
 ((0, 1, 4), (9, 7, 0)),


In [35]:
# Same value as the second size argument passed to MatrixFactorizationWithBias
# when the model is built: real items plus user attributes.
n_items + n_user_attrs

215