In [2]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.7.0-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.7.0-py3-none-any.whl (1.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m24.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.7.0


In [3]:
import torch
from torch import nn
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import bipartite_subgraph

In [4]:
# -----------------------------
# Fake bipartite user-item graph
# -----------------------------

num_users = 5
num_items = 6

# User nodes: 0..4
# Item nodes: 5..10
edge_index = torch.tensor([
    [0, 0, 1, 2, 3, 4, 4],
    [5, 6, 6, 7, 8, 9, 10]
], dtype=torch.long)

edge_index = torch.cat([
    edge_index,
    edge_index.flip(0)  # make graph undirected
], dim=1)

num_nodes = num_users + num_items

data = Data(edge_index=edge_index, num_nodes=num_nodes)

In [5]:
# -----------------------------
# LightGCN-style model
# -----------------------------

class LightGCN(MessagePassing):
    def __init__(self, num_nodes, emb_dim=16):
        super().__init__(aggr='mean')
        self.embedding = nn.Embedding(num_nodes, emb_dim)

    def forward(self, edge_index):
        x = self.embedding.weight
        x = self.propagate(edge_index, x=x)
        return x

    def message(self, x_j):
        return x_j

In [6]:
# -----------------------------
# Training setup
# -----------------------------

model = LightGCN(num_nodes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

def link_score(emb, u, v):
    return (emb[u] * emb[v]).sum(dim=-1)

# Fake positive edges
pos_edges = edge_index[:, :7]

# Fake negative samples
neg_u = torch.randint(0, num_users, (7,))
neg_v = torch.randint(num_users, num_nodes, (7,))

In [7]:
# -----------------------------
# Training loop (toy)
# -----------------------------

for epoch in range(200):
    optimizer.zero_grad()
    emb = model(edge_index)

    pos_score = link_score(emb, pos_edges[0], pos_edges[1])
    neg_score = link_score(emb, neg_u, neg_v)

    loss = -(torch.log(torch.sigmoid(pos_score)) +
             torch.log(torch.sigmoid(-neg_score))).mean()

    loss.backward()
    optimizer.step()

In [8]:
# -----------------------------
# Inspect learned embeddings
# -----------------------------

user_emb = emb[:num_users]
item_emb = emb[num_users:]

user_emb, item_emb


(tensor([[-2.5379e-01,  9.1925e-01,  1.9571e-01,  9.5120e-01, -1.4217e+00,
          -1.0167e+00,  6.9940e-01, -3.3470e-01,  5.8302e-01,  1.0590e+00,
           1.3386e-01, -1.0035e+00, -5.9297e-01,  3.3305e-01,  9.7317e-01,
          -7.2101e-01],
         [ 4.0034e-01,  9.4504e-01,  6.4802e-01, -3.4993e-01, -1.6369e+00,
          -8.4376e-01,  1.8318e-01,  1.1974e-03,  6.0001e-01,  1.2351e+00,
           1.1668e-01, -2.4145e+00,  1.4162e-01,  2.9610e-01,  2.3693e+00,
          -1.9178e+00],
         [-1.8660e+00,  8.3005e-01, -7.1758e-01,  7.0469e-01,  1.0254e-01,
          -1.0156e+00,  1.4306e-01, -1.2017e+00,  2.7094e-01, -1.6761e-01,
           4.8004e-01, -1.7187e+00, -1.4504e+00, -5.5880e-01,  1.0909e+00,
          -9.3800e-01],
         [ 1.5826e+00,  2.0897e+00, -3.1421e-01, -5.6868e-01,  1.8863e-01,
          -2.8943e-03, -1.3105e+00, -1.0523e-01,  4.3304e-02,  1.1022e+00,
           2.4489e+00,  2.4218e-02, -6.5123e-01, -2.0600e+00,  9.0842e-01,
          -4.4030e-01],
    