In [3]:
import pickle
import sys
import torch
from torch_geometric.data import Data
from torch_geometric.nn import LightGCN
from torch_geometric.utils import from_networkx

sys.path.append("../complexity_hunters/")  # to make utils importable
sys.path.append(".")  # to make utils importable
sys.path.append("..")  # to make utils importable

import utils.data_worker
import utils.consts

from graph.graph import build_graph
import igraph

from complexity_hunters.extra_metrics import sets_iou

# build_graph()

graph = pickle.load(open("../data/graph.pkl", "rb"))

print("Converting graph to PyTorch Geometric format")
data = from_networkx(graph)
data.edge_index = data.edge_index.long()

# Define LightGCN model
class RecommendationModel(torch.nn.Module):
    def __init__(self, num_users, num_questions, embedding_dim=64, num_layers=3):
        super().__init__()
        self.model = LightGCN(num_nodes=num_users + num_questions, num_layers=num_layers, embedding_dim=embedding_dim)
        self.user_embeddings = torch.nn.Embedding(num_users, embedding_dim)
        self.question_embeddings = torch.nn.Embedding(num_questions, embedding_dim)

    def forward(self, edge_index):
        x = torch.cat([self.user_embeddings.weight, self.question_embeddings.weight])
        return self.model(x, edge_index)

user_nodes = [node for node in graph.nodes if graph.nodes[node]["type"] == "user"]
question_nodes = [node for node in graph.nodes if graph.nodes[node]["type"] == "question"]

user_mapping = {node: idx for idx, node in enumerate(user_nodes)}
question_mapping = {node: idx + len(user_nodes) for idx, node in enumerate(question_nodes)}

edge_index = data.edge_index.clone()
for idx in range(edge_index.shape[1]):
    src, dst = edge_index[:, idx]
    if src in user_mapping and dst in question_mapping:
        edge_index[0, idx] = user_mapping[src]
        edge_index[1, idx] = question_mapping[dst]

num_users = len(user_nodes)
num_questions = len(question_nodes)
embedding_dim = 64
model = RecommendationModel(num_users, num_questions, embedding_dim)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.BCEWithLogitsLoss()

labels = torch.ones(edge_index.shape[1])

print("Training LightGCN")
for epoch in range(10):
    model.train()
    optimizer.zero_grad()
    outputs = model(edge_index)
    loss = criterion(outputs.squeeze(), labels)
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch + 1}, Loss: {loss.item()}")


print("Making recommendations")
brand_new_question = posts[posts.PostTypeId == 1].sample()
brand_new_question_id = brand_new_question["Id"].values[0]
brand_new_question_idx = question_mapping[brand_new_question_id]

with torch.no_grad():
    model.eval()
    question_embedding = model.question_embeddings(torch.tensor([brand_new_question_idx]))
    user_embeddings = model.user_embeddings.weight
    scores = torch.matmul(user_embeddings, question_embedding.T).squeeze()

top_k = 5
recommended_users = scores.topk(top_k).indices
recommended_user_ids = [user_nodes[idx] for idx in recommended_users]
print(f"Recommended users for question {brand_new_question_id}: {recommended_user_ids}")


INFO: Loading dataset ../data/Posts.xml...
INFO: Loading dataset ../data/Badges.xml...


extracting user pairs based on tags: 100%|██████████████████| 2013/2013 [00:01<00:00, 1605.45it/s]


INFO: Dumped graph into ./data/graph.pkl
Converting graph to PyTorch Geometric format...
Training LightGCN...


RuntimeError: scatter(): Expected dtype int64 for index