In [1]:
# !pip install torch_geometric

In [2]:
import torch
from torch import nn
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import bipartite_subgraph

In [3]:
# -----------------------------
# Fake bipartite user-item graph
# -----------------------------

num_users = 5
num_items = 6

# User nodes: 0..4
# Item nodes: 5..10
edge_index = torch.tensor([
    [0, 0, 1, 2, 3, 4, 4],
    [5, 6, 6, 7, 8, 9, 10]
], dtype=torch.long)

edge_index = torch.cat([
    edge_index,
    edge_index.flip(0)  # make graph undirected
], dim=1)

num_nodes = num_users + num_items

data = Data(edge_index=edge_index, num_nodes=num_nodes)

In [4]:
# -----------------------------
# LightGCN-style model
# -----------------------------

class LightGCN(MessagePassing):
    def __init__(self, num_nodes, emb_dim=16):
        super().__init__(aggr='mean')
        self.embedding = nn.Embedding(num_nodes, emb_dim)

    def forward(self, edge_index):
        x = self.embedding.weight
        x = self.propagate(edge_index, x=x)
        return x

    def message(self, x_j):
        return x_j

In [5]:
# -----------------------------
# Training setup
# -----------------------------

model = LightGCN(num_nodes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

def link_score(emb, u, v):
    return (emb[u] * emb[v]).sum(dim=-1)

# Fake positive edges
pos_edges = edge_index[:, :7]

# Fake negative samples
neg_u = torch.randint(0, num_users, (7,))
neg_v = torch.randint(num_users, num_nodes, (7,))

In [6]:
# -----------------------------
# Training loop (toy)
# -----------------------------

for epoch in range(200):
    optimizer.zero_grad()
    emb = model(edge_index)

    pos_score = link_score(emb, pos_edges[0], pos_edges[1])
    neg_score = link_score(emb, neg_u, neg_v)

    loss = -(torch.log(torch.sigmoid(pos_score)) +
             torch.log(torch.sigmoid(-neg_score))).mean()

    loss.backward()
    optimizer.step()

In [7]:
# -----------------------------
# Inspect learned embeddings
# -----------------------------

user_emb = emb[:num_users]
item_emb = emb[num_users:]

user_emb, item_emb


(tensor([[-1.9455e-01,  3.1994e-02,  6.6894e-01, -7.0011e-01, -6.2982e-01,
           5.1913e-01, -1.6569e+00, -6.2002e-02, -6.7799e-01, -5.8978e-02,
           1.2717e+00, -6.2968e-02,  7.0329e-01,  8.6945e-01,  7.6488e-01,
          -2.7270e-01],
         [ 8.8332e-01,  1.1308e+00,  3.1472e+00, -6.4049e-01, -3.2743e-01,
           1.8817e-01, -1.9712e+00, -4.2489e-01, -1.6407e-01, -1.1748e+00,
           5.9092e-01, -4.0980e-01,  3.5567e-01,  1.3121e+00,  1.2401e+00,
          -3.8750e-01],
         [ 7.9444e-01,  8.5395e-01,  1.4312e-01, -3.0750e-01,  6.1006e-01,
           1.0399e-01,  5.2376e-01, -4.6007e-01, -1.5739e+00, -1.2141e+00,
           1.2020e+00,  1.7073e+00,  1.5505e-01,  9.0475e-01, -3.1685e+00,
           1.3775e-01],
         [-2.7334e-02, -8.4439e-01, -6.4368e-02,  3.4073e-01, -1.5428e-01,
           5.8621e-01,  1.8120e-01,  1.0208e+00,  2.0293e-01,  1.3506e+00,
          -1.6467e+00, -4.4674e-01,  7.3005e-01,  1.0035e-01, -1.5520e-01,
           3.3186e-01],
    