# GNN for Recommendation: LightGCN vs Feature GNN vs Hybrid Fusion

This notebook compares three ways to learn recommendation embeddings on a small user-item graph:

1) **LightGCN** (graph-only, no features, no nonlinearities)
2) **Feature-aware GNN** (GraphSAGE) using node features + nonlinearities
3) **Hybrid**: LightGCN embeddings fused with node features downstream


## Synthetic graph used throughout

**Users:** 0..4  
**Items:** 5..10  

Edges (user → item):
- U0 → I5, I6
- U1 → I6
- U2 → I7
- U3 → I8
- U4 → I9, I10

We make the graph **undirected** for message passing by adding reverse edges.

In [None]:
import torch
from torch import nn
import torch.nn.functional as F

from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing, SAGEConv

torch.manual_seed(7)

num_users = 5
num_items = 6
num_nodes = num_users + num_items

# Directed edges user->item (positives)
edge_ui = torch.tensor([
    [0, 0, 1, 2, 3, 4, 4],
    [5, 6, 6, 7, 8, 9, 10]
], dtype=torch.long)

# Make undirected for message passing
edge_index = torch.cat([edge_ui, edge_ui.flip(0)], dim=1)

data = Data(edge_index=edge_index, num_nodes=num_nodes)

edge_ui, edge_index.shape

## Shared utility: link prediction loss

We score a (u, i) pair by dot product of embeddings:

`score(u, i) = <e_u, e_i>`

Training uses a simple logistic loss:
- positives = observed user→item edges
- negatives = random user–item pairs
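
As a small worked example of this loss, the sketch below evaluates it on made-up scores and compares it with PyTorch's built-in binary cross-entropy on logits. The score values are hypothetical, and `F.logsigmoid` is used here for numerical stability; that choice is an assumption of this sketch.

```python
import torch
import torch.nn.functional as F

# Hypothetical dot-product scores for three positive and three negative pairs.
pos_scores = torch.tensor([2.0, 0.5, -1.0])
neg_scores = torch.tensor([-1.5, 0.0, 3.0])

# Logistic loss: reward high scores on positives and low scores on negatives.
per_pair = -(F.logsigmoid(pos_scores) + F.logsigmoid(-neg_scores))
loss = per_pair.mean()

# The same quantity expressed as binary cross-entropy with labels 1 and 0.
logits = torch.cat([pos_scores, neg_scores])
labels = torch.cat([torch.ones(3), torch.zeros(3)])
bce = F.binary_cross_entropy_with_logits(logits, labels)

# Each pair contributes one positive and one negative term, so the mean over
# pairs is twice the mean over all six individual terms.
print(loss.item(), 2 * bce.item())
```

The two printed values agree up to floating-point error, so the pairwise form differs from standard BCE only by a constant factor of 2 when positives and negatives are equal in number.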

In [None]:
def dot_score(emb, src, dst):
    return (emb[src] * emb[dst]).sum(dim=-1)

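def sample_negatives(n_samples, num_users, num_nodes):
    # Uniform random (user, item) pairs. Observed positives are not excluded,
    # so a sampled "negative" can occasionally be a true interaction.
    u = torch.randint(0, num_users, (n_samples,))
    v = torch.randint(num_users, num_nodes, (n_samples,))
    return u, v

def bce_link_loss(emb, pos_src, pos_dst, neg_src, neg_dst):
    pos = dot_score(emb, pos_src, pos_dst)
    neg = dot_score(emb, neg_src, neg_dst)
    # logsigmoid is numerically stable, so no epsilon is needed inside the log.
    # Assumes equally many positives and negatives (they are added pairwise).
    return -(F.logsigmoid(pos) + F.logsigmoid(-neg)).mean()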
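
Because negatives are drawn uniformly, a sampled pair can coincide with an observed interaction, which matters on a graph this small. One possible remedy is rejection sampling, sketched below. The function name `sample_negatives_filtered` and the `max_tries` parameter are hypothetical and not part of the notebook.

```python
import torch

def sample_negatives_filtered(n_samples, num_users, num_nodes, pos_edges, max_tries=100):
    """Sample (user, item) pairs that do not appear in pos_edges."""
    seen = set(zip(pos_edges[0].tolist(), pos_edges[1].tolist()))
    users, items = [], []
    tries = 0
    # Redraw until enough unseen pairs are collected or the try budget runs out.
    while len(users) < n_samples and tries < max_tries * n_samples:
        u = torch.randint(0, num_users, (1,)).item()
        v = torch.randint(num_users, num_nodes, (1,)).item()
        tries += 1
        if (u, v) not in seen:
            users.append(u)
            items.append(v)
    return (torch.tensor(users, dtype=torch.long),
            torch.tensor(items, dtype=torch.long))

# Same positives as the synthetic graph: 5 users (0..4), 6 items (5..10).
pos_edges = torch.tensor([
    [0, 0, 1, 2, 3, 4, 4],
    [5, 6, 6, 7, 8, 9, 10],
])
neg_u, neg_v = sample_negatives_filtered(7, 5, 11, pos_edges)
print(list(zip(neg_u.tolist(), neg_v.tolist())))
```

If the try budget is exhausted, the function returns fewer pairs than requested, so a caller pairing positives with negatives one-to-one would need to check the length.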

## 1) LightGCN (graph-only)

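**LightGCN design constraints**:
- No node features: nodes start as trainable embeddings only
- No nonlinear transforms: propagation is pure neighbor aggregation

With no features to lean on, the model can learn only from the interaction graph, which isolates the collaborative filtering signal.

The `LightGCN` class in this notebook is a deliberately minimal variant: a single propagation step with mean aggregation. The published LightGCN uses symmetric degree normalization and combines the embeddings from all propagation layers.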
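
For comparison, here is a plain-PyTorch sketch of LightGCN-style propagation with symmetric normalization `1 / sqrt(deg(u) * deg(i))` and uniform averaging over layers, including layer 0. The function name `lightgcn_propagate` and the `num_layers` default are assumptions made for illustration; this is not the notebook's implementation.

```python
import torch

def lightgcn_propagate(emb0, edge_index, num_layers=2):
    """Symmetric-normalized propagation followed by uniform layer averaging."""
    src, dst = edge_index
    num_nodes = emb0.size(0)
    # Node degree, counted over incoming edges of the undirected graph.
    deg = emb0.new_zeros(num_nodes).index_add(0, dst, emb0.new_ones(dst.numel()))
    # Per-edge weight 1 / sqrt(deg(src) * deg(dst)); clamp guards isolated nodes.
    norm = (deg[src].clamp(min=1) * deg[dst].clamp(min=1)).rsqrt()

    layers = [emb0]
    x = emb0
    for _ in range(num_layers):
        msg = x[src] * norm.unsqueeze(-1)                # weighted neighbor messages
        x = torch.zeros_like(x).index_add(0, dst, msg)   # sum messages per target node
        layers.append(x)
    # Final embedding: mean of layer 0 .. layer K, with no learned weights.
    return torch.stack(layers).mean(dim=0)

# Same synthetic graph as above, made undirected by adding reverse edges.
edge_ui = torch.tensor([[0, 0, 1, 2, 3, 4, 4], [5, 6, 6, 7, 8, 9, 10]])
edge_index = torch.cat([edge_ui, edge_ui.flip(0)], dim=1)
emb0 = torch.randn(11, 4)
final = lightgcn_propagate(emb0, edge_index, num_layers=2)
print(final.shape)  # torch.Size([11, 4])
```

Averaging over layers keeps each node's own embedding (layer 0) in the result, whereas a single mean-aggregation step replaces it entirely with neighbor information.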

In [None]:
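class LightGCN(MessagePassing):
    """Minimal LightGCN-style encoder: one parameter-free propagation step."""
    def __init__(self, num_nodes, emb_dim=16):
        super().__init__(aggr='mean')  # average incoming neighbor messages
        self.embedding = nn.Embedding(num_nodes, emb_dim)

    def forward(self, edge_index):
        x = self.embedding.weight  # no features, just embeddings
        x = self.propagate(edge_index, x=x)  # each node becomes the mean of its neighbors
        return x

    def message(self, x_j):
        # Pass neighbor embeddings through unchanged: no weights, no nonlinearity
        return x_j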

In [None]:
def train_lightgcn(edge_index, pos_edges, epochs=200, emb_dim=16, lr=0.01):
    model = LightGCN(num_nodes=num_nodes, emb_dim=emb_dim)
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    pos_src, pos_dst = pos_edges[0], pos_edges[1]

    for _ in range(epochs):
        opt.zero_grad()
        emb = model(edge_index)
        neg_src, neg_dst = sample_negatives(pos_src.numel(), num_users, num_nodes)
        loss = bce_link_loss(emb, pos_src, pos_dst, neg_src, neg_dst)
        loss.backward()
        opt.step()

    return model, model(edge_index).detach()

pos_edges = edge_ui  # directed positives (user->item)

light_model, light_emb = train_lightgcn(edge_index, pos_edges, epochs=200, emb_dim=16, lr=0.01)
light_emb[:2], light_emb[num_users:num_users+2]

## 2) Feature-aware GNN (GraphSAGE)

To add features:
- Provide a feature matrix `x` for all nodes
- Use a feature-aware GNN layer (GraphSAGE here)
- Include nonlinearities (ReLU)

Here the features are random vectors, so they carry no real signal; this section demonstrates the mechanics only.
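
To make the feature-aware layer less of a black box, the sketch below writes a mean-aggregator GraphSAGE-style layer in plain PyTorch: each node combines a linear transform of its own features with a linear transform of the mean of its neighbors' features. The class name `MeanSAGELayer` is hypothetical, and this is an approximation of the idea rather than the source of `SAGEConv`, whose parameter layout may differ.

```python
import torch
from torch import nn

class MeanSAGELayer(nn.Module):
    """out_i = W_self x_i + W_neigh mean_{j in N(i)} x_j + b"""
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.lin_self = nn.Linear(in_dim, out_dim, bias=False)
        self.lin_neigh = nn.Linear(in_dim, out_dim)

    def forward(self, x, edge_index):
        src, dst = edge_index
        # Sum neighbor features into each target node, then divide by degree.
        summed = torch.zeros_like(x).index_add(0, dst, x[src])
        deg = x.new_zeros(x.size(0)).index_add(0, dst, x.new_ones(dst.numel()))
        neigh_mean = summed / deg.clamp(min=1).unsqueeze(-1)
        return self.lin_self(x) + self.lin_neigh(neigh_mean)

# Same synthetic graph as above, with random 8-dimensional node features.
edge_ui = torch.tensor([[0, 0, 1, 2, 3, 4, 4], [5, 6, 6, 7, 8, 9, 10]])
edge_index = torch.cat([edge_ui, edge_ui.flip(0)], dim=1)
x = torch.randn(11, 8)
layer = MeanSAGELayer(8, 16)
h = torch.relu(layer(x, edge_index))  # the nonlinearity that LightGCN omits
print(h.shape)  # torch.Size([11, 16])
```

Unlike LightGCN-style propagation, this layer has learned weight matrices, so it can reshape the feature space rather than only smooth embeddings over the graph.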

In [None]:
class FeatureGNN(nn.Module):
    def __init__(self, in_dim=8, hidden_dim=16):
        super().__init__()
        self.conv1 = SAGEConv(in_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, hidden_dim)

    def forward(self, x, edge_index):
        h = self.conv1(x, edge_index)
        h = F.relu(h)
        h = self.conv2(h, edge_index)
        return h

In [None]:
def train_feature_gnn(edge_index, pos_edges, epochs=200, in_dim=8, hidden_dim=16, lr=0.01):
    model = FeatureGNN(in_dim=in_dim, hidden_dim=hidden_dim)
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    # Fake node features (stand-in for text/image/metadata embeddings)
    x = torch.randn(num_nodes, in_dim)

    pos_src, pos_dst = pos_edges[0], pos_edges[1]

    for _ in range(epochs):
        opt.zero_grad()
        emb = model(x, edge_index)
        neg_src, neg_dst = sample_negatives(pos_src.numel(), num_users, num_nodes)
        loss = bce_link_loss(emb, pos_src, pos_dst, neg_src, neg_dst)
        loss.backward()
        opt.step()

    return model, x.detach(), model(x, edge_index).detach()

sage_model, node_x, sage_emb = train_feature_gnn(edge_index, pos_edges, epochs=200, in_dim=8, hidden_dim=16, lr=0.01)
sage_emb[:2], sage_emb[num_users:num_users+2]

## 3) Hybrid fusion (graph embeddings + features downstream)

Common large-scale pattern:
1) Learn graph embeddings from interactions (LightGCN)
2) Compute side features separately (content/metadata)
3) Fuse downstream (concat + MLP here)

This keeps graph training simple while still using rich features.

In [None]:
class HybridRanker(nn.Module):
    def __init__(self, graph_dim=16, feat_dim=8, hidden=32):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(graph_dim + feat_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, graph_dim)  # project back to embedding space
        )

    def forward(self, graph_emb, node_x):
        z = torch.cat([graph_emb, node_x], dim=-1)
        return self.mlp(z)

In [None]:
def train_hybrid(edge_index, pos_edges, epochs=200, graph_dim=16, feat_dim=8, lr=0.01):
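    # Pretrain LightGCN embeddings (fixed graph encoder). Pretraining uses its own
    # fixed settings; the `epochs` and `lr` arguments apply only to the ranker.
    _, graph_emb = train_lightgcn(edge_index, pos_edges, epochs=200, emb_dim=graph_dim, lr=0.01)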

    # Fake features (stand-in for content/metadata)
    node_x = torch.randn(num_nodes, feat_dim)

    ranker = HybridRanker(graph_dim=graph_dim, feat_dim=feat_dim, hidden=32)
    opt = torch.optim.Adam(ranker.parameters(), lr=lr)

    pos_src, pos_dst = pos_edges[0], pos_edges[1]

    for _ in range(epochs):
        opt.zero_grad()
        fused_emb = ranker(graph_emb, node_x)
        neg_src, neg_dst = sample_negatives(pos_src.numel(), num_users, num_nodes)
        loss = bce_link_loss(fused_emb, pos_src, pos_dst, neg_src, neg_dst)
        loss.backward()
        opt.step()

    return graph_emb.detach(), node_x.detach(), ranker, ranker(graph_emb, node_x).detach()

hy_graph_emb, hy_x, hy_ranker, hy_emb = train_hybrid(edge_index, pos_edges, epochs=200, graph_dim=16, feat_dim=8, lr=0.01)
hy_emb[:2], hy_emb[num_users:num_users+2]

## Edge weights (mechanics only)

Edges can carry weights (e.g., interaction strength, time decay).
This notebook attaches an `edge_weight` tensor as a placeholder; a full weighted
propagation implementation is intentionally omitted here.
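
For readers who want to see one possible shape of weighted propagation, here is a minimal plain-PyTorch sketch of a weighted mean over neighbors. The function name `weighted_mean_propagate` and the toy graph are assumptions for illustration only; other normalizations, such as weighted symmetric normalization, are equally plausible.

```python
import torch

def weighted_mean_propagate(x, edge_index, edge_weight):
    """Each node receives the edge-weighted mean of its neighbors' vectors."""
    src, dst = edge_index
    msg = x[src] * edge_weight.unsqueeze(-1)                         # scale each message
    summed = torch.zeros_like(x).index_add(0, dst, msg)              # weighted sum per node
    total = x.new_zeros(x.size(0)).index_add(0, dst, edge_weight)    # weight sum per node
    return summed / total.clamp(min=1e-12).unsqueeze(-1)

# Toy graph: node 0 is linked to nodes 1 and 2, with weights 3 and 1.
# Both directions of an edge share the same weight.
edge_index = torch.tensor([[1, 2, 0, 0],
                           [0, 0, 1, 2]])
edge_weight = torch.tensor([3.0, 1.0, 3.0, 1.0])
x = torch.tensor([[2.0], [4.0], [8.0]])

out = weighted_mean_propagate(x, edge_index, edge_weight)
print(out)  # node 0: (3*4 + 1*8) / 4 = 5; nodes 1 and 2: 2
```

With all weights equal to 1, this reduces to the unweighted mean aggregation used in the notebook's `LightGCN` class.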

In [None]:
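# One weight per user-item interaction, shared by both directions of the edge
w_ui = torch.ones(edge_ui.size(1))
w_ui[:3] *= 2.0  # pretend the first few interactions are "stronger"
edge_weight = torch.cat([w_ui, w_ui])  # same order as edge_index
edge_weight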

## Takeaways

- **LightGCN**: graph-only collaborative filtering; scalable; no features by design
- **Feature GNN (GraphSAGE/GAT)**: integrates features; more expressive; heavier compute
- **Hybrid**: graph embeddings + content/metadata fused downstream; common at scale