# Graph Neural Networks (GNNs)

Graph Neural Networks are a class of deep learning models designed to perform inference on **graph-structured data**. They generalize traditional neural networks to handle non-Euclidean data where relationships between entities are explicitly defined by edges.

## Key Concepts

### 1. Graph Basics
- **Nodes/Vertices (V)**: Represent entities (e.g., users in social networks, atoms in molecules)
- **Edges (E)**: Represent relationships between nodes
- **Node Features**: Attributes associated with each node (e.g., user profiles, atom properties)
- **Adjacency Matrix (A)**: Matrix representation of edge connections

### 2. Core GNN Operations
#### Message Passing
GNNs operate via **message passing** where nodes:
1. **Aggregate** information from neighbors
2. **Update** their own representation based on aggregated messages

#### Common Variants
- **Graph Convolutional Networks (GCNs)**
- **Graph Attention Networks (GATs)**
- **GraphSAGE**

### A basic GNN (following the Message Passing framework) includes:

1. Message function: computes the message each neighbor sends to a node (often just the neighbor's features, possibly transformed).
2. Aggregation function: combines the messages from all neighbors into one vector (e.g., sum, mean, max).
3. Update function: updates the node's own features using the aggregated message.


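**General GNN (Message Passing Framework).** At layer $l$, each node $v$ aggregates the representations of its neighbors $N(v)$ into a message $m_v^{(l)}$, then updates its own representation $h_v^{(l)}$: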

  $m_v^{(l)} = \text{AGGREGATE}(\{h_u^{(l)} : u \in N(v)\})$

  $h_v^{(l+1)} = \text{UPDATE}(h_v^{(l)}, m_v^{(l)})$

We'll implement three typical aggregators:

- Mean
- Sum
- Max

We’ll:

1. Collect neighbor features explicitly
2. Apply a chosen aggregation function (mean, sum, or max)
3. Combine it with the node’s own features

This simulates message passing with a custom aggregation function, the defining feature of general GNNs.

In [22]:
import torch
import torch.nn as nn
import torch.nn.functional as F

#### Aggregation Function

In [23]:
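def aggregate_neighbors(X, A, mode="mean"):
    """Aggregate neighbor features for every node with the chosen reducer.

    X: node features [N, d]
    A: adjacency matrix [N, N]; A[i, j] != 0 means j is a neighbor of i
    """
    N = A.size(0)
    X_neighbors = torch.zeros_like(X)

    for i in range(N):
        # 1-D tensor of neighbor indices (possibly empty)
        neighbors = A[i].nonzero(as_tuple=True)[0]

        if neighbors.numel() == 0:
            # Isolated node: no incoming messages, use a zero vector
            agg = torch.zeros(X.shape[1], dtype=X.dtype, device=X.device)
        else:
            neigh_feats = X[neighbors]  # [num_neighbors, d]
            if mode == "mean":
                agg = neigh_feats.mean(dim=0)
            elif mode == "sum":
                agg = neigh_feats.sum(dim=0)
            elif mode == "max":
                agg = neigh_feats.max(dim=0).values  # element-wise max over neighbors
            else:
                raise ValueError(f"Unknown aggregation mode: {mode}")

        X_neighbors[i] = agg
    return X_neighbors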


#### GNN Layer with Custom Aggregation

In [24]:
class GeneralGNNLayer(nn.Module):
    def __init__(self, in_features, out_features, agg_mode="mean"):
        super(GeneralGNNLayer, self).__init__()
        self.W_self = nn.Parameter(torch.randn(in_features, out_features))
        self.W_neigh = nn.Parameter(torch.randn(in_features, out_features))
        self.agg_mode = agg_mode

    def forward(self, X, A):
        # Aggregate from neighbors using custom aggregator
        neighbor_agg = aggregate_neighbors(X, A, self.agg_mode)
        
        # Linear transform for self and neighbors
        h_self = X @ self.W_self
        h_neigh = neighbor_agg @ self.W_neigh

        # Combine and activate
        return F.relu(h_self + h_neigh)


#### Full 2-Layer GNN

In [25]:
class GeneralGNN(nn.Module):
    def __init__(self, in_features, hidden_dim, out_features, agg_mode="mean"):
        super(GeneralGNN, self).__init__()
        self.layer1 = GeneralGNNLayer(in_features, hidden_dim, agg_mode)
        self.layer2 = GeneralGNNLayer(hidden_dim, out_features, agg_mode)

    def forward(self, X, A):
        x = self.layer1(X, A)
        x = self.layer2(x, A)
        return x


In [26]:
X = torch.tensor([
    [1.0, 0.0],
    [0.0, 1.0],
    [1.0, 1.0],
    [0.0, 0.0]
])

A = torch.tensor([
    [0, 1, 0, 0],
    [1, 0, 1, 0],
    [0, 1, 0, 1],
    [0, 0, 1, 0]
], dtype=torch.float32)

labels = torch.tensor([0, 1, 0, 1])


In [28]:
model = GeneralGNN(in_features=2, hidden_dim=4, out_features=2, agg_mode="max")  # Try "sum" or "mean"
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(100):
    model.train()
    optimizer.zero_grad()
    out = model(X, A)
    loss = loss_fn(out, labels)
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        pred = out.argmax(dim=1)
        acc = (pred == labels).float().mean().item()
        print(f"Epoch {epoch} | Loss: {loss.item():.4f} | Accuracy: {acc:.4f}")


Epoch 0 | Loss: 1.0915 | Accuracy: 0.0000
Epoch 10 | Loss: 0.8451 | Accuracy: 0.5000
Epoch 20 | Loss: 0.7184 | Accuracy: 0.5000
Epoch 30 | Loss: 0.6838 | Accuracy: 0.5000
Epoch 40 | Loss: 0.6651 | Accuracy: 0.5000
Epoch 50 | Loss: 0.6329 | Accuracy: 0.5000
Epoch 60 | Loss: 0.5772 | Accuracy: 0.5000
Epoch 70 | Loss: 0.5117 | Accuracy: 1.0000
Epoch 80 | Loss: 0.4572 | Accuracy: 1.0000
Epoch 90 | Loss: 0.4174 | Accuracy: 1.0000


#### Aggregation as Matrix Multiplication

Multiplying the adjacency matrix by the feature matrix ($A \cdot X$) is how nodes aggregate information from their neighbors in a GNN.

#### Specifically:
- $A \in \mathbb{R}^{N \times N}$: adjacency matrix
- $A_{ij} = 1$ if node $i$ is connected to node $j$
- $X \in \mathbb{R}^{N \times d}$: node features (each row = node’s feature vector)
- $A \cdot X \rightarrow$ Each node’s new feature is the sum of its neighbors’ features

#### Why it matters:
- It enables message passing: collecting info from neighbors
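- When you multiply $A \cdot X$, you’re computing:

  $$
  (A X)_i = \sum_{j \in N(i)} X_j
  $$

- If $A$ is row-normalized ($D^{-1} A$, where $D$ is the diagonal degree matrix), the product computes the **mean** of the neighbor features; GCN uses the closely related symmetric normalization $D^{-1/2} A D^{-1/2}$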
- Matrix multiplication = fast, vectorized aggregation across all nodes.
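As a sketch of the aggregation rules above, the sum, mean, and max aggregators from `aggregate_neighbors` can be computed for all nodes at once with tensor operations instead of a Python loop. This assumes a dense 0/1 adjacency matrix and reuses the four-node path graph from this notebook; `vectorized_aggregate` is a hypothetical helper name.

```python
import torch

def vectorized_aggregate(X, A, mode="mean"):
    """Hypothetical helper: aggregate neighbor features for all nodes (dense 0/1 adjacency)."""
    if mode == "sum":
        return A @ X                                    # (A X)_i = sum of neighbor rows
    if mode == "mean":
        deg = A.sum(dim=1, keepdim=True).clamp(min=1)   # avoid division by zero for isolated nodes
        return (A @ X) / deg                            # row-normalized: D^{-1} A X
    if mode == "max":
        # Broadcast to [N, N, d]; entries for non-neighbors become -inf before the max
        expanded = X.unsqueeze(0).expand(A.size(0), -1, -1)
        masked = expanded.masked_fill(A.unsqueeze(-1) == 0, float("-inf"))
        out = masked.max(dim=1).values
        return torch.where(torch.isinf(out), torch.zeros_like(out), out)  # isolated nodes -> 0
    raise ValueError(f"Unknown aggregation mode: {mode}")

X = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0]])
A = torch.tensor([[0, 1, 0, 0],
                  [1, 0, 1, 0],
                  [0, 1, 0, 1],
                  [0, 0, 1, 0]], dtype=torch.float32)

print(vectorized_aggregate(X, A, "sum"))
print(vectorized_aggregate(X, A, "mean"))
print(vectorized_aggregate(X, A, "max"))
```

The sum and mean cases are single matrix products; the max case needs a masked broadcast because a maximum cannot be expressed as a matrix multiplication.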

### Simple GNN

In [None]:
class GNNLayer(nn.Module):
    def __init__(self, input_dim, output_dim):
        """Implementation of a basic GNN layer from scratch"""
        super().__init__()
        # Weight matrix for neighbor aggregation
        self.neighbor_weights = nn.Linear(input_dim, output_dim)
        # Weight matrix for self node features
        self.self_weights = nn.Linear(input_dim, output_dim)
        # Bias term
        self.bias = nn.Parameter(torch.zeros(output_dim))

    def forward(self, X, adj):
        """
        X: Node features [num_nodes, input_dim]
        adj: Normalized adjacency matrix [num_nodes, num_nodes]
        """
        # aggregate neighbour information
        neighbour_info = torch.matmul(adj, self.neighbor_weights(X))
        # include self info
        self_info = self.self_weights(X)
        # combine and add bias
        output = neighbour_info + self_info + self.bias
        return F.relu(output)


class GNN(nn.Module):
    """Two-layer GNN implementation"""
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.layer1 = GNNLayer(input_dim, hidden_dim)
        self.layer2 = GNNLayer(hidden_dim, output_dim)

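    def forward(self, x, adj):
        # Both layers need the (normalized) adjacency matrix for propagation
        x = self.layer1(x, adj)
        x = self.layer2(x, adj)

        # Log-probabilities: pair with nn.NLLLoss (not nn.CrossEntropyLoss)
        return F.log_softmax(x, dim=1)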

# GCN


In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F


#### Normalize Adjacency (with self-loop)

The equation represents a common operation in **Graph Neural Networks (GNNs)**, specifically for **normalizing the adjacency matrix** with self-loops:

$$
\widehat{A} = A + I, \quad \widehat{D}^{-1/2} \widehat{A} \widehat{D}^{-1/2}
$$

### Explanation:
1. **$\widehat{A} = A + I$**:  
   - Adds self-loops to the adjacency matrix $A$ (where $I$ is the identity matrix).  
   - Ensures each node includes its own features during aggregation.  

2. **$\widehat{D}^{-1/2} \widehat{A} \widehat{D}^{-1/2}$**:  
   - Normalizes $\widehat{A}$ using the degree matrix $\widehat{D}$ (diagonal matrix with $\widehat{D}_{ii} = \sum_j \widehat{A}_{ij}$).  
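   - Symmetric normalization rescales each edge $(i, j)$ by $1/\sqrt{\widehat{D}_{ii}\widehat{D}_{jj}}$, so high-degree nodes do not dominate and feature magnitudes stay stable as layers are stacked.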
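As a worked example, take the four-node path graph 0-1-2-3 used in this notebook's code cells. With self-loops added, the node degrees are $\widehat{D} = \mathrm{diag}(2, 3, 3, 2)$, and each nonzero entry $(i, j)$ of the normalized matrix equals $1/\sqrt{\widehat{D}_{ii}\widehat{D}_{jj}}$:

$$
\widehat{A} =
\begin{bmatrix}
1 & 1 & 0 & 0 \\
1 & 1 & 1 & 0 \\
0 & 1 & 1 & 1 \\
0 & 0 & 1 & 1
\end{bmatrix},
\qquad
\widehat{D}^{-1/2} \widehat{A} \widehat{D}^{-1/2} =
\begin{bmatrix}
\tfrac{1}{2} & \tfrac{1}{\sqrt{6}} & 0 & 0 \\
\tfrac{1}{\sqrt{6}} & \tfrac{1}{3} & \tfrac{1}{3} & 0 \\
0 & \tfrac{1}{3} & \tfrac{1}{3} & \tfrac{1}{\sqrt{6}} \\
0 & 0 & \tfrac{1}{\sqrt{6}} & \tfrac{1}{2}
\end{bmatrix}
$$

The edge between the two interior nodes (1 and 2, both of degree 3) receives weight $1/3$, while an edge touching a degree-2 endpoint receives $1/\sqrt{6} \approx 0.41$, so contributions between high-degree nodes are scaled down.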

## Mathematical Comparison

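### GCN (Kipf & Welling)

$ H^{(l+1)} = \sigma \left( \hat{D}^{-1/2} \hat{A} \hat{D}^{-1/2} H^{(l)} W^{(l)} \right) $

- $\hat{A} = A + I$: adds self-loops; $\hat{D}$ is the diagonal degree matrix of $\hat{A}$
- $\hat{D}^{-1/2} \hat{A} \hat{D}^{-1/2}$: symmetrically normalized adjacency (closely related to the normalized graph Laplacian), which smooths features over neighbors
- No separate message, aggregation, or update functions: all three are built into a single matrix product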

---

### General GNN (Message Passing Framework)

$ m_v^{(l)} = \text{AGGREGATE} \left( \{ h_u^{(l)} : u \in N(v) \} \right) $

$ h_v^{(l+1)} = \text{UPDATE} \left( h_v^{(l)}, m_v^{(l)} \right) $

- Modular and expressive
- Can support edge features, attention, direction
- Used in GAT, GraphSAGE, MPNN, etc.

In [29]:
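def normalize_adjacency(A):
    """Symmetric GCN normalization: D_hat^{-1/2} (A + I) D_hat^{-1/2}."""
    A_hat = A + torch.eye(A.size(0))                   # add self-loops
    D_hat = torch.diag(torch.pow(A_hat.sum(1), -0.5))  # diagonal D_hat^{-1/2} from degrees of A_hat
    return D_hat @ A_hat @ D_hat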

#### Define a Custom GCN Layer

In [30]:
class CustomGNNLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        # Manually defined weight matrix
        self.W = nn.Parameter(torch.randn(in_features, out_features))
        # Alternative: use nn.Linear (adds a bias and a scaled initialization)
        # self.linear = nn.Linear(in_features, out_features)

    def forward(self, X, A_hat):
        out = A_hat @ X               # Aggregate neighbor features
        out = out @ self.W            # Linear transformation
        # out = self.linear(out)      # Linear transformation (nn.Linear alternative)
        return F.relu(out)            # Non-linearity

#### GCN Model

In [31]:
class CustomGNN(nn.Module):
    def __init__(self, in_features, hidden_dim, out_features):
        super(CustomGNN, self).__init__()
        self.layer1 = CustomGNNLayer(in_features, hidden_dim)
        self.layer2 = CustomGNNLayer(hidden_dim, out_features)

    def forward(self, X, A):
        A_hat = normalize_adjacency(A)
        x = self.layer1(X, A_hat)
        x = self.layer2(x, A_hat)
        return x


#### Dummy Data
- Four nodes connected as a path (0-1-2-3)
- Each node has two features

In [15]:
X = torch.tensor([
    [1.0, 0.0],
    [0.0, 1.0],
    [1.0, 1.0],
    [0.0, 0.0]
])

A = torch.tensor([
    [0, 1, 0, 0],
    [1, 0, 1, 0],
    [0, 1, 0, 1],
    [0, 0, 1, 0]
], dtype=torch.float32)

labels = torch.tensor([0, 1, 0, 1])


#### Training loop

In [20]:
model = CustomGNN(in_features=2, hidden_dim=4, out_features=2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fun = nn.CrossEntropyLoss()

for epoch in range(100):
    model.train()
    optimizer.zero_grad()

    out = model(X, A)
    loss = loss_fun(out, labels)
    loss.backward()
    optimizer.step()

    if epoch % 25 == 0:
        print(f"Output Shape is: {out.shape} \n output values are \n {out}")
        pred = out.argmax(dim=1)
        print(f"Predictions: {pred}")
        acc = (pred == labels).float().mean().item()
        print(f"Epoch {epoch} | Loss: {loss.item():.4f} | Accuracy: {acc:.4f}")


Output Shape is: torch.Size([4, 2]) 
 output values are 
 tensor([[0.5163, 0.9618],
        [0.6368, 1.1569],
        [0.6303, 1.1506],
        [0.4792, 0.8618]], grad_fn=<ReluBackward0>)
Predictions: tensor([1, 1, 1, 1])
Epoch 0 | Loss: 0.7285 | Accuracy: 0.5000
Output Shape is: torch.Size([4, 2]) 
 output values are 
 tensor([[0.6887, 0.5920],
        [0.8574, 0.7606],
        [0.8507, 0.7609],
        [0.6507, 0.5935]], grad_fn=<ReluBackward0>)
Predictions: tensor([0, 0, 0, 0])
Epoch 25 | Loss: 0.6900 | Accuracy: 0.5000
Output Shape is: torch.Size([4, 2]) 
 output values are 
 tensor([[0.6651, 0.6086],
        [0.8028, 0.7836],
        [0.7899, 0.7843],
        [0.5918, 0.6125]], grad_fn=<ReluBackward0>)
Predictions: tensor([0, 0, 0, 1])
Epoch 50 | Loss: 0.6853 | Accuracy: 0.7500
Output Shape is: torch.Size([4, 2]) 
 output values are 
 tensor([[0.7297, 0.5841],
        [0.8397, 0.7725],
        [0.8150, 0.7784],
        [0.5898, 0.6175]], grad_fn=<ReluBackward0>)
Predictions: tenso

#### **GCN is a special case of a more general GNN framework.**


It's efficient and works well for many tasks, but has limitations:
- **Can’t handle edge features**
- **Assumes all neighbors are equally important**
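- **Over-smooths when stacked too deep**: repeated neighborhood averaging drives node representations toward each other until they are hard to tell apart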
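To make the over-smoothing point concrete, here is a small sketch that assumes pure propagation (no learned weights, no nonlinearities) with the row-normalized adjacency $\widehat{D}^{-1}\widehat{A}$, so each step replaces a node's features with its neighborhood average. On the connected four-node path graph from this notebook, the node features become nearly identical after enough steps. `feature_spread` is a hypothetical helper name.

```python
import torch

X = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0]])
A = torch.tensor([[0, 1, 0, 0],
                  [1, 0, 1, 0],
                  [0, 1, 0, 1],
                  [0, 0, 1, 0]], dtype=torch.float32)

A_hat = A + torch.eye(4)                      # add self-loops
P = A_hat / A_hat.sum(dim=1, keepdim=True)    # row-normalized: each step averages over the neighborhood

def feature_spread(H):
    """Hypothetical helper: largest per-feature gap between any two nodes (0 = identical rows)."""
    return (H.max(dim=0).values - H.min(dim=0).values).max().item()

H = X.clone()
spreads = {}
for step in range(51):
    if step in (0, 1, 2, 5, 10, 50):
        spreads[step] = feature_spread(H)     # H currently equals P^step X
    H = P @ H                                 # one propagation step, no weights or ReLU

for step, s in spreads.items():
    print(f"after {step:2d} steps: spread = {s:.6f}")
```

Because every propagated row is a convex combination of the previous rows, the spread can never grow; learned weights and nonlinearities slow this collapse in a real GCN but do not remove the tendency.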

### GNN vs. GCN — Key Differences

| Feature            | General GNN                          | GCN (Kipf & Welling)                       |
|---------------------|--------------------------------------|--------------------------------------------|
| Aggregation         | User-defined (sum, mean, max, attention, etc.) | Fixed: $\hat{D}^{-1/2} \hat{A} \hat{D}^{-1/2}$ with $\hat{A} = A + I$ |
| Message Function    | Customizable per edge or node        | Simplified: linear propagation $\hat{D}^{-1/2} \hat{A} \hat{D}^{-1/2} X W$ |
| Weight Sharing      | May differ per edge or node type     | Single global weight matrix per layer      |
| Edge Features       | Often supported                      | Not in vanilla GCN                         |
| Edge Direction      | Can be directional (e.g., via attention) | Treats edges as undirected                 |
| Learned Adjacency   | Optional (e.g., in attention-based GNNs) | Fixed; no learnable adjacency              |
| Expressive Power    | Flexible, modular                    | Simpler but efficient and robust           |
| Computational Cost  | Higher for complex GNNs              | Low: sparse matrix ops                     |





### Summary Table

| Model       | Best For                | Key Idea                       |
|-------------|-------------------------|--------------------------------|
| GCN         | Node/graph classification | Convolution with normalized adjacency |
| GraphSAGE   | Large, inductive graphs  | Neighborhood sampling          |
| GAT         | Attention on neighbors   | Learnable neighbor importance  |
| GIN         | Graph classification    | Strong expressive power        |
| MPNN        | Custom models           | Message-passing framework      |
| ST-GCN      | Action recognition      | Spatio-temporal structure      |
| R-GCN/HAN   | Heterogeneous graphs    | Node/edge types                |


### Concept Table

| Concept                   | Meaning                          | Why It Matters                  |
|---------------------------|----------------------------------|---------------------------------|
| $A X$ (`A @ X` in PyTorch) | Aggregate neighbor features      | Enables message passing         |
| $A + I$                   | Add self-loop                    | Include node's own features     |
| $D^{-1/2} A D^{-1/2}$     | Normalize                        | Balance influence of neighbors  |
| GCN layer                 | Convolution on graphs            | Simple, effective baseline      |
| Custom aggregation (sum/mean/max) | General GNN                | Adds flexibility (like GraphSAGE) |
| Stacking layers           | Increases receptive field        | Enables multi-hop propagation   |

# GAT (Graph Attention Network)
Introduced by Veličković et al., 2018, GAT uses attention mechanisms to weigh neighbor contributions dynamically.
GAT enhances message passing by **assigning different weights** to different neighbors, using **learned attention scores**.



### Core Equation:

$h_i = \sigma \left( \sum_{j \in \mathcal{N}(i)} \alpha_{ij} W h_j \right)$

#### Where:

- $W$: learnable weight matrix
- $\alpha_{ij}$: attention coefficient, i.e., how much node $i$ attends to its neighbor $j$
- $\sigma$: non-linearity (e.g., ELU or ReLU)
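The coefficients $\alpha_{ij}$ are not free parameters. In the original GAT formulation, an unnormalized score is computed from each pair of transformed features with a shared learnable vector $a$, then normalized with a softmax over the neighborhood (the same scoring used in this notebook's `GATLayer` cells):

$$
e_{ij} = \text{LeakyReLU}\left(a^\top \left[ W h_i \,\|\, W h_j \right]\right), \qquad
\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}(i)} \exp(e_{ik})}
$$

Here $\|$ denotes concatenation. Because of the softmax, the coefficients over each node's neighborhood are positive and sum to 1.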

### GAT (Node-level attention)

| Aspect              | Description                                      |
|---------------------|--------------------------------------------------|
| Intuition           | "Who among my neighbors should I pay more attention to?" |
| Formula             | For a node $v$, GAT computes: <br> $h'_v = \sigma \left( \sum_{u \in N(v)} \alpha_{vu} W h_u \right)$ |
| Details             | - $\alpha_{vu}$: attention weight between node $v$ and neighbor $u$ <br> - Attention is learned per edge using the node features. |

---

| Aspect              | GAT (Graph Attention Network)               |
|---------------------|---------------------------------------------|
| Introduced by       | Veličković et al., 2018                     |
| Goal                | Learn which neighbors are most important (attend over neighbors) |
| Input Graph         | Homogeneous (same node & edge types)        |
| Attention           | Over neighbor nodes (per node)              |
| Aggregation         | Weighted sum of neighbor features using learned attention |
| Adjacency Matrix    | Static (edges fixed, weights learned)       |
| Use Case            | Social networks, citation networks, molecular graphs |

In [32]:
import torch
import torch.nn as nn
import torch.nn.functional as F


## GAT Layer (Single-head, Single-layer)

In [38]:
class GATLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.W = nn.Parameter(torch.randn(in_features, out_features))
        self.a = nn.Parameter(torch.randn(2 * out_features, 1))
        self.leakyrelu = nn.LeakyReLU(0.2)

    def forward(self, X, A):
        N = X.size(0)

        # Linear transform of input features
        H = X @ self.W  # [N, out_features]

        # Prepare attention inputs (concat h_i || h_j for each edge)
        H_repeat_i = H.repeat(1, N).view(N * N, -1)
        H_repeat_j = H.repeat(N, 1)
        H_cat = torch.cat([H_repeat_i, H_repeat_j], dim=1)  # [N*N, 2*out_features]

        # Compute attention scores
        e = self.leakyrelu(H_cat @ self.a).view(N, N)  # [N, N]

        # Mask non-existent edges
        e = e.masked_fill(A == 0, float("-inf"))

        # Normalize with softmax
        alpha = F.softmax(e, dim=1)  # [N, N]

        # Compute new node representations
        H_prime = alpha @ H  # [N, out_features]
        return F.elu(H_prime)
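The concatenation trick above materializes an `[N*N, 2*out_features]` tensor. Splitting $a$ into halves $a_1, a_2$ gives $a^\top [h_i \,\|\, h_j] = a_1^\top h_i + a_2^\top h_j$, so the same score matrix can be computed with broadcasting. A minimal sketch, assuming dense tensors; `attention_scores_concat` and `attention_scores_broadcast` are hypothetical helper names:

```python
import torch
import torch.nn.functional as F

def attention_scores_concat(H, a):
    """Hypothetical helper: scores via explicit [h_i || h_j] pairs, as in GATLayer (before masking)."""
    N = H.size(0)
    H_i = H.repeat(1, N).view(N * N, -1)    # row i*N + j holds h_i
    H_j = H.repeat(N, 1)                    # row i*N + j holds h_j
    return F.leaky_relu(torch.cat([H_i, H_j], dim=1) @ a, 0.2).view(N, N)

def attention_scores_broadcast(H, a):
    """Hypothetical helper: same scores without materializing N*N concatenated pairs."""
    F_out = H.size(1)
    a1, a2 = a[:F_out], a[F_out:]           # split a into the h_i half and the h_j half
    s_i = H @ a1                            # [N, 1]: contribution of the attending node i
    s_j = H @ a2                            # [N, 1]: contribution of the neighbor j
    return F.leaky_relu(s_i + s_j.T, 0.2)   # broadcast to [N, N]

torch.manual_seed(0)
H = torch.randn(5, 3)                       # 5 nodes, 3 transformed features
a = torch.randn(6, 1)                       # attention vector of size 2 * 3
print(torch.allclose(attention_scores_concat(H, a), attention_scores_broadcast(H, a), atol=1e-5))
```

This split is a common way to implement GAT scoring in $O(N)$ extra memory for the scores' inputs; the masking and softmax that follow are unchanged.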

In [39]:
class GAT(nn.Module):
    def __init__(self, in_features, hidden_dim, out_features):
        super(GAT, self).__init__()
        self.gat1 = GATLayer(in_features, hidden_dim)
        self.gat2 = GATLayer(hidden_dim, out_features)

    def forward(self, X, A):
        x = self.gat1(X, A)
        x = self.gat2(x, A)
        return x


In [40]:
X = torch.tensor([
    [1.0, 0.0],
    [0.0, 1.0],
    [1.0, 1.0],
    [0.0, 0.0]
])

A = torch.tensor([
    [1, 1, 0, 0],  # self-loop included
    [1, 1, 1, 0],
    [0, 1, 1, 1],
    [0, 0, 1, 1]
], dtype=torch.float32)

labels = torch.tensor([0, 1, 0, 1])


In [42]:
model = GAT(in_features=2, hidden_dim=4, out_features=2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(100):
    model.train()
    optimizer.zero_grad()
    out = model(X, A)
    loss = loss_fn(out, labels)
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        pred = out.argmax(dim=1)
        acc = (pred == labels).float().mean().item()
        print(f"Epoch {epoch} | Loss: {loss.item():.4f} | Accuracy: {acc:.4f}")


Epoch 0 | Loss: 1.3686 | Accuracy: 0.5000
Epoch 10 | Loss: 1.1047 | Accuracy: 0.5000
Epoch 20 | Loss: 0.9026 | Accuracy: 0.5000
Epoch 30 | Loss: 0.7715 | Accuracy: 0.5000
Epoch 40 | Loss: 0.7103 | Accuracy: 0.5000
Epoch 50 | Loss: 0.6909 | Accuracy: 0.5000
Epoch 60 | Loss: 0.6906 | Accuracy: 0.5000
Epoch 70 | Loss: 0.6892 | Accuracy: 0.5000
Epoch 80 | Loss: 0.6881 | Accuracy: 0.5000
Epoch 90 | Loss: 0.6874 | Accuracy: 1.0000


### Transformer Attention: Quick Recap

In the Transformer architecture, attention is computed as:

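$\text{Attention}(Q, K, V) = \text{softmax} \left( \frac{QK^T}{\sqrt{d}} \right) V$

### Where:

- $Q$ = query matrix (one query vector per token)
- $K$ = key matrix (one key vector per token)
- $V$ = value matrix (one value vector per token)
- $d$ = dimension of the query and key vectors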

Each token generates its own $Q$, $K$, and $V$ using learnable linear projections.
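A minimal PyTorch sketch of the formula above, assuming single-head attention over one sequence of $n$ tokens with no batching or masking. The function is a hand-written illustration (not PyTorch's built-in), and the random projections `W_q`, `W_k`, `W_v` stand in for learned weights:

```python
import math
import torch

def scaled_dot_product_attention(Q, K, V):
    """Hand-written illustration: softmax(Q K^T / sqrt(d)) V for Q, K [n, d] and V [n, d_v]."""
    d = Q.size(-1)
    scores = Q @ K.T / math.sqrt(d)          # [n, n] pairwise query-key similarities
    weights = torch.softmax(scores, dim=-1)  # each row sums to 1
    return weights @ V, weights

torch.manual_seed(0)
n, d_model, d = 4, 8, 8
x = torch.randn(n, d_model)                  # token embeddings
W_q, W_k, W_v = (torch.randn(d_model, d) for _ in range(3))  # stand-ins for learned projections
out, weights = scaled_dot_product_attention(x @ W_q, x @ W_k, x @ W_v)
print(out.shape, weights.sum(dim=-1))
```

Unlike GAT, every token attends to every other token here; a graph-restricted variant would set `scores` to `-inf` wherever the adjacency matrix is zero before the softmax, exactly as `GATLayer` does.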

### Why GAT Doesn't Explicitly Use Q/K/V

In GAT (Graph Attention Network), the attention is simplified for computational and architectural efficiency:

- Each node uses the same transformed feature vector $h_i = W x_i$ as the basis for both query and key.
- The attention score is computed from the concatenated pair of transformed node features, where $\|$ denotes concatenation:

  $e_{ij} = \text{LeakyReLU} \left( a^\top [W x_i \,\|\, W x_j] \right)$

Instead of the Transformer's scaled dot product:

  $e_{ij} = Q_i K_j^\top / \sqrt{d}$

So:

- GAT collapses $Q$, $K$, $V$ into one space: $h_i = W x_i$
- The attention weights $\alpha_{ij}$ are learned per edge through this concatenation mechanism

### GAT vs Transformer Attention

| Component      | Transformer Attention         | GAT Attention                |
|----------------|--------------------------------|------------------------------|
| Query          | $Q = W_Q x$                   | Implicit via $W x_i$         |
| Key            | $K = W_K x$                   | Implicit via $W x_j$         |
| Value          | $V = W_V x$                   | Also $W x_j$ (same as key)   |
| Scoring        | Dot product: $Q K^T$          | Concatenation: $a^\top [W x_i \Vert W x_j]$ |
| Weighting      | Softmax of dot product        | Softmax of $e_{ij}$          |
| Aggregation    | Weighted sum of $V$           | Weighted sum of $W x_j$      |

So even though GAT doesn't explicitly call them $Q$/$K$/$V$, conceptually:
- Query = node $i$'s own transformed features
- Key = neighbors' transformed features
- Value = same as Key (neighbors' features)
- The attention weights are computed by concatenating the two and passing through a learned vector $a$, instead of dot product

# Multi-Head Attention in GAT

**Just like in Transformers:**
- Multiple attention heads capture different patterns or relationships.
- It improves stability and representation power.

**Two common ways to combine multi-head outputs:**
- Concatenation (used in hidden layers)
- Averaging (used in the output layer)
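In the original GAT paper (Veličković et al., 2018), with $K$ heads that each have their own $W^k$ and attention coefficients $\alpha_{ij}^k$, the two combination rules are written as:

$$
h_i' = \Big\Vert_{k=1}^{K} \sigma\left( \sum_{j \in \mathcal{N}(i)} \alpha_{ij}^{k} W^{k} h_j \right)
\quad \text{(concatenation, hidden layers)}
$$

$$
h_i' = \sigma\left( \frac{1}{K} \sum_{k=1}^{K} \sum_{j \in \mathcal{N}(i)} \alpha_{ij}^{k} W^{k} h_j \right)
\quad \text{(averaging, output layer)}
$$

Concatenation yields $K \cdot F'$ features per node, where $F'$ is the per-head output size. The `MultiHeadGATLayer` cell in this notebook applies ELU inside each head (via `GATLayer`) and then averages, a slight variation on the second formula, which applies the nonlinearity after averaging.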

### GAT Layer (reuse from before)

In [45]:
class GATLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super(GATLayer, self).__init__()
        self.W = nn.Parameter(torch.randn(in_features, out_features))
        self.a = nn.Parameter(torch.randn(2 * out_features, 1))
        self.leakyrelu = nn.LeakyReLU(0.2)

    def forward(self, X, A):
        N = X.size(0)
        H = X @ self.W  # Linear transformation

        H_repeat_i = H.repeat(1, N).view(N * N, -1)
        H_repeat_j = H.repeat(N, 1)
        H_cat = torch.cat([H_repeat_i, H_repeat_j], dim=1)

        e = self.leakyrelu(H_cat @ self.a).view(N, N)
        e = e.masked_fill(A == 0, float("-inf"))
        alpha = F.softmax(e, dim=1)
        H_prime = alpha @ H
        return F.elu(H_prime)


### Multi-head GAT layer

In [44]:
class MultiHeadGATLayer(nn.Module):
    def __init__(self, in_features, out_features, num_heads, merge='concat'):
        super(MultiHeadGATLayer, self).__init__()
        self.num_heads = num_heads
        self.merge = merge
        self.attn_heads = nn.ModuleList([
            GATLayer(in_features, out_features) for _ in range(num_heads)
        ])

    def forward(self, X, A):
        head_outputs = [attn(X, A) for attn in self.attn_heads]  # [head1, head2, ...]
        
        if self.merge == 'concat':
            return torch.cat(head_outputs, dim=1)  # [N, out_features * num_heads]
        elif self.merge == 'mean':
            return torch.mean(torch.stack(head_outputs), dim=0)  # [N, out_features]
        else:
            raise ValueError("Merge method must be 'concat' or 'mean'")


### Full Multi-Head GAT Model

In [47]:
class GAT(nn.Module):
    def __init__(self, in_features, hidden_dim, out_features, num_heads=4):
        super(GAT, self).__init__()
        self.gat1 = MultiHeadGATLayer(in_features, hidden_dim, num_heads=num_heads, merge='concat')
        self.gat2 = MultiHeadGATLayer(hidden_dim * num_heads, out_features, num_heads=1, merge='mean')

    def forward(self, X, A):
        x = self.gat1(X, A)
        x = self.gat2(x, A)
        return x


In [48]:
X = torch.tensor([
    [1.0, 0.0],
    [0.0, 1.0],
    [1.0, 1.0],
    [0.0, 0.0]
])

A = torch.tensor([
    [1, 1, 0, 0],
    [1, 1, 1, 0],
    [0, 1, 1, 1],
    [0, 0, 1, 1]
], dtype=torch.float32)

labels = torch.tensor([0, 1, 0, 1])


In [49]:
model = GAT(in_features=2, hidden_dim=4, out_features=2, num_heads=4)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(50):
    model.train()
    optimizer.zero_grad()
    out = model(X, A)
    loss = loss_fn(out, labels)
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        pred = out.argmax(dim=1)
        acc = (pred == labels).float().mean().item()
        print(f"Epoch {epoch} | Loss: {loss.item():.4f} | Accuracy: {acc:.4f}")


Epoch 0 | Loss: 0.8488 | Accuracy: 0.5000
Epoch 10 | Loss: 0.6743 | Accuracy: 0.5000
Epoch 20 | Loss: 0.6393 | Accuracy: 0.7500
Epoch 30 | Loss: 0.6212 | Accuracy: 0.5000
Epoch 40 | Loss: 0.6089 | Accuracy: 0.7500


# GTN

GTN (Graph Transformer Network) is a model for **heterogeneous graphs**, introduced by Yun et al. (NeurIPS 2019). It **learns meaningful meta-paths** (combinations of relations) automatically by **softly selecting and composing multiple edge types**.

This makes it especially powerful for **heterogeneous node classification**, where the importance of edge types (relations) is **not fixed** and must be learned.



## How GTN works

### Step 1: Soft Edge-Type Selection

Given multiple adjacency matrices $ A_r \in \mathbb{R}^{N \times N} $, GTN learns to softly combine them into a new adjacency matrix $ A^{(l)} $ by:

$ A^{(l)} = \sum_{r=1}^R \alpha_r^{(l)} A_r $

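- $ \alpha_r^{(l)} $: learnable edge-type weights, normalized with a softmax over $r$ so they are non-negative and sum to 1
- In the paper, this weighted sum is implemented as a 1x1 convolution over the stack of adjacency matrices $ A_r $ (one channel per edge type)

This soft selection picks (a mixture of) the edge types for one hop of the meta-path.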
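A minimal sketch of the soft selection, assuming dense adjacency matrices stacked into an `[R, N, N]` tensor. A softmax-weighted sum over the edge-type axis is what a 1x1 convolution with softmax-normalized weights computes; the class name `SoftEdgeSelect` and the zero initialization are assumptions, not details of the original implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SoftEdgeSelect(nn.Module):
    """Hypothetical helper: learns A^(l) = sum_r softmax(w)_r * A_r over R edge types."""
    def __init__(self, num_edge_types):
        super().__init__()
        # One logit per edge type; zeros give uniform weights at initialization (assumption)
        self.w = nn.Parameter(torch.zeros(num_edge_types))

    def forward(self, A_stack):
        # A_stack: [R, N, N], one adjacency matrix per edge type
        alpha = F.softmax(self.w, dim=0)                   # [R], non-negative, sums to 1
        return torch.einsum("r,rij->ij", alpha, A_stack)   # weighted sum over edge types

# Two hypothetical edge types on 3 nodes
A1 = torch.tensor([[0., 1., 0.], [1., 0., 0.], [0., 0., 0.]])
A2 = torch.tensor([[0., 0., 1.], [0., 0., 0.], [1., 0., 0.]])
select = SoftEdgeSelect(num_edge_types=2)
A_soft = select(torch.stack([A1, A2]))
print(A_soft)
```

During training, gradients flow into `w`, so the model can shift weight toward the edge types that help the downstream task.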

---

### Step 2: Learn Longer Meta-Paths

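To build longer paths like $ A \rightarrow B \rightarrow C $, multiply the soft-selected adjacency matrices of successive layers. Writing $ A_P^{(l)} $ for the meta-path adjacency after $ l $ layers:

$ A_P^{(1)} = A^{(1)}, \qquad A_P^{(l)} = A_P^{(l-1)} A^{(l)} $

This constructs paths of length $ l $. Stacking $ l $ such layers unrolls to:

$ A_P^{(l)} = A^{(1)} A^{(2)} \cdots A^{(l)} $

Each $ A^{(k)} $ is a soft combination of input edge types (Step 1), so the full path is learned compositionally. The original paper also normalizes each product by its degree matrix to keep the values on a stable scale.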
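Why multiplication builds meta-paths: for 0/1 adjacency matrices, entry $(i, j)$ of $A_1 A_2$ counts the paths $i \to k \to j$ that use an edge of type 1 followed by an edge of type 2. A toy sketch with hypothetical "writes" (author to paper) and "published in" (paper to venue) relations; all node roles and edges are made up for illustration:

```python
import torch

# Hypothetical heterogeneous graph in one shared index space:
# nodes 0-1 are authors, 2-3 are papers, 4 is a venue
N = 5
A_writes = torch.zeros(N, N)     # edge type 1: author -> paper
A_writes[0, 2] = 1
A_writes[1, 2] = 1
A_writes[1, 3] = 1

A_published = torch.zeros(N, N)  # edge type 2: paper -> venue
A_published[2, 4] = 1
A_published[3, 4] = 1

# Meta-path "author -writes-> paper -published_in-> venue"
A_meta = A_writes @ A_published
print(A_meta[0, 4].item(), A_meta[1, 4].item())  # → 1.0 2.0 (author 1 reaches venue 4 via two papers)
```

GTN replaces the hard choice of `A_writes` and `A_published` with soft combinations of all edge types, so the same product becomes a weighted count over many candidate meta-paths.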

---

### Step 3: Graph Convolution

After constructing the learned meta-path adjacency matrix $ A_P $ (the product of the soft-selected matrices), apply a standard GCN layer:

$ Z = \sigma(A_P X W) $

- $ X $: input node features
- $ W $: learnable weight matrix
- $ A_P $: normalized meta-path adjacency matrix
- $ \sigma $: activation function (e.g., ReLU)

This gives the final node embeddings, used for tasks like classification.

---

## Overall Equation

Let $A_r \in \mathbb{R}^{N \times N}$ for $r = 1, \ldots, R$ be the adjacency matrices of the $R$ edge types.

1. **Learned soft edge matrix** (one per layer $l$, with learnable logits $w^{(l)} \in \mathbb{R}^R$):

$ A^{(l)} = \sum_{r=1}^R \text{softmax}(w^{(l)})_r \, A_r $

2. **Meta-path composition** (length 2 for simplicity; two layers with separate weights):

$ A_P = A^{(1)} A^{(2)} $

3. **Feature propagation:**

$ Z = \text{ReLU}(A_P X W) $

You can extend this to multiple channels (several meta-paths learned in parallel) and more layers for longer compositions.
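Putting the three steps together, here is a minimal single-channel sketch for a length-2 meta-path. Several choices are assumptions rather than details taken from the GTN paper's code: dense tensors, row normalization of the product, a linear classifier on top, and appending the identity matrix as an extra "edge type" so that every row of the product is nonzero and shorter paths stay expressible. `TinyGTN` is a hypothetical name.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyGTN(nn.Module):
    """Hypothetical single-channel, length-2 meta-path GTN sketch (dense tensors)."""
    def __init__(self, num_edge_types, in_features, hidden_dim, num_classes):
        super().__init__()
        R = num_edge_types + 1                       # +1: identity appended as an extra edge type (assumption)
        self.w1 = nn.Parameter(torch.zeros(R))       # edge-type logits for the first hop
        self.w2 = nn.Parameter(torch.zeros(R))       # edge-type logits for the second hop
        self.W = nn.Linear(in_features, hidden_dim)  # GCN weight matrix W (plus a bias)
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, X, A_stack):
        N = X.size(0)
        A_all = torch.cat([A_stack, torch.eye(N).unsqueeze(0)], dim=0)  # [R, N, N]
        # Step 1: soft edge-type selection for each hop
        A1 = torch.einsum("r,rij->ij", F.softmax(self.w1, dim=0), A_all)
        A2 = torch.einsum("r,rij->ij", F.softmax(self.w2, dim=0), A_all)
        # Step 2: compose a length-2 meta-path and row-normalize it
        A_P = A1 @ A2
        A_P = A_P / A_P.sum(dim=1, keepdim=True).clamp(min=1e-9)
        # Step 3: GCN-style propagation Z = ReLU(A_P X W), then classify
        Z = F.relu(A_P @ self.W(X))
        return self.classifier(Z)

torch.manual_seed(0)
N, R = 6, 2
A_stack = (torch.rand(R, N, N) > 0.6).float()  # random toy edge types
X = torch.randn(N, 3)
model = TinyGTN(num_edge_types=R, in_features=3, hidden_dim=8, num_classes=2)
logits = model(X, A_stack)
print(logits.shape)
```

The logits can be trained with `nn.CrossEntropyLoss` exactly like the GCN and GAT models above; after training, `softmax(w1)` and `softmax(w2)` show which edge types the model chose for each hop.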

## Visualization of GTN Architecture

```text
Input:
   X  ---> Feature matrix
   A1, A2, ..., AR  ---> Multiple adjacency matrices (edge types)

1×1 Conv:
   Softly selects edge types (learns α1, α2, ..., αR)

Adj Multiplication:
   Builds meta-paths via matrix multiplication

GCN Layer:
   Learns node representations from meta-path graph

Classifier:
   Predicts node classes
```


### Why GTN is Powerful
- Learns meta-paths without human intuition
- Works on heterogeneous graphs with minimal assumptions
- Handles different semantics of edge types using attention-like soft selection
- Can model longer dependencies via adjacency matrix multiplication