<a href="https://colab.research.google.com/github/vent0906/ww/blob/main/GAT_HAN_Tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Graph Attention Networks (GAT) and Heterogeneous Graph Attention Networks (HAN) Tutorial

This notebook demonstrates:
1. Single-head and Multi-head GAT layer implementation
2. Full GAT model on the Cora dataset
3. HAN model with node-level and semantic-level attention on the ACM dataset
4. Training and visualization steps

---

**Please ensure the following packages are installed in your environment** (run the commands below in a notebook cell; drop the leading `!` when running them in a terminal):
```bash
!pip install dgl -f https://data.dgl.ai/wheels/repo.html
!pip install torchdata
```


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import GATConv

# Single-head GAT Layer
class GATLayer(nn.Module):
    """One attention head of a graph attention layer, bound to a fixed graph.

    The graph ``g`` is stored at construction time and reused by every
    forward pass. Node features are projected by ``fc``; each edge is scored
    by ``attn_fc`` on the concatenation of its source and destination
    projections; incoming messages are combined with softmax-normalised
    scores.

    Args:
        g: Graph object exposing ``ndata``, ``apply_edges`` and
            ``update_all`` (presumably a ``dgl.DGLGraph`` -- confirm with
            the caller).
        in_dim: Size of the input node features.
        out_dim: Size of the projected node features produced by this head.
    """

    def __init__(self, g, in_dim, out_dim):
        super().__init__()
        self.g = g
        # Shared, bias-free projection applied to every node.
        self.fc = nn.Linear(in_dim, out_dim, bias=False)
        # Maps a concatenated (source, destination) pair to one raw score.
        self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise both weight matrices with Xavier-normal and ReLU gain."""
        relu_gain = nn.init.calculate_gain('relu')
        # Order matters for reproducibility under a fixed seed: fc first.
        for linear in (self.fc, self.attn_fc):
            nn.init.xavier_normal_(linear.weight, gain=relu_gain)

    def edge_attention(self, edges):
        """Compute the unnormalised attention score ``'e'`` for each edge."""
        src_z = edges.src['z']
        dst_z = edges.dst['z']
        pair = torch.cat((src_z, dst_z), dim=1)
        raw_score = self.attn_fc(pair)
        return {'e': F.leaky_relu(raw_score)}

    def message_func(self, edges):
        """Send the source projection and the edge score to the destination."""
        return dict(z=edges.src['z'], e=edges.data['e'])

    def reduce_func(self, nodes):
        """Softmax the received scores and return the weighted message sum."""
        inbox = nodes.mailbox
        # dim=1 is the axis over which each node's received messages are
        # normalised and summed.
        weights = F.softmax(inbox['e'], dim=1)
        aggregated = (weights * inbox['z']).sum(dim=1)
        return {'h': aggregated}

    def forward(self, h):
        """Run one attention head over the stored graph.

        Args:
            h: Input node features accepted by ``self.fc``.

        Returns:
            The aggregated node features, removed from ``g.ndata['h']``.

        Note:
            The projected features stay in ``g.ndata['z']`` after the call;
            only ``'h'`` is popped. The edge scores ``'e'`` are not removed
            here either.
        """
        graph = self.g
        graph.ndata['z'] = self.fc(h)
        graph.apply_edges(self.edge_attention)
        graph.update_all(self.message_func, self.reduce_func)
        return graph.ndata.pop('h')

# Multi-head wrapper
class MultiHeadGATLayer(nn.Module):
    """Runs several independent ``GATLayer`` heads and merges their outputs.

    Args:
        g: Graph passed unchanged to every head.
        in_dim: Input feature size of each head.
        out_dim: Output feature size of each head.
        num_heads: Number of heads to create.
        merge: ``'cat'`` concatenates head outputs along dim 1; any other
            value averages them element-wise.
    """

    def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):
        super().__init__()
        heads = []
        for _ in range(num_heads):
            heads.append(GATLayer(g, in_dim, out_dim))
        self.heads = nn.ModuleList(heads)
        self.merge = merge

    def forward(self, h):
        """Apply every head to ``h`` and combine the results per ``merge``."""
        per_head = [head(h) for head in self.heads]
        if self.merge == 'cat':
            return torch.cat(per_head, dim=1)
        # Every non-'cat' mode falls through to the mean over heads.
        return torch.stack(per_head).mean(dim=0)

# Full GAT model
class GAT(nn.Module):
    """Two-layer graph attention network.

    The first layer uses ``num_heads`` heads with the default merge mode of
    ``MultiHeadGATLayer``; the second layer uses a single head and maps to
    ``out_dim``.

    Args:
        g: Graph shared by both layers.
        in_dim: Input node feature size.
        hidden_dim: Per-head output size of the first layer.
        out_dim: Output size of the second layer.
        num_heads: Number of heads in the first layer.
    """

    def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):
        super().__init__()
        self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)
        # The second layer's input width is the first layer's per-head width
        # times its head count.
        second_in = hidden_dim * num_heads
        self.layer2 = MultiHeadGATLayer(g, second_in, out_dim, 1)

    def forward(self, h):
        """Return ``layer2(elu(layer1(h)))``."""
        hidden = F.elu(self.layer1(h))
        return self.layer2(hidden)


In [None]:
# Semantic-level attention
class SemanticAttention(nn.Module):
    """Attention over axis 1 of the input, with one weight shared by all rows.

    A small MLP (Linear -> Tanh -> bias-free Linear to one unit) scores every
    slice along axis 1; the scores are averaged over axis 0, softmaxed, and
    used to take a weighted sum over axis 1.

    Args:
        in_size: Size of the last dimension of the input.
        hidden_size: Width of the hidden projection (default 128).
    """

    def __init__(self, in_size, hidden_size=128):
        super().__init__()
        stages = [
            nn.Linear(in_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1, bias=False),
        ]
        self.project = nn.Sequential(*stages)

    def forward(self, z):
        """Fuse the axis-1 slices of ``z`` into one tensor.

        Args:
            z: Tensor whose last dimension is ``in_size``. Assumed to be
                (nodes, meta-paths, features) -- confirm with the caller.

        Returns:
            ``z`` summed over axis 1 after weighting by the softmaxed scores.
        """
        # Average the per-row scores so each axis-1 slice gets one score.
        path_scores = self.project(z).mean(0)
        path_weights = torch.softmax(path_scores, dim=0)
        # Repeat the shared weights along a new leading axis of size
        # z.shape[0] so they line up with z.
        tiled = path_weights.expand((z.shape[0],) + path_weights.shape)
        weighted = tiled * z
        return weighted.sum(1)

# HAN Layer
class HANLayer(nn.Module):
    """One HAN layer: a ``GATConv`` per meta-path, fused by semantic attention.

    Args:
        num_meta_paths: Number of meta-path graphs; one ``GATConv`` is
            created for each.
        in_size: Input feature size of every ``GATConv``.
        out_size: Per-head output size of every ``GATConv``.
        num_heads: Number of attention heads in every ``GATConv``.
        dropout: Passed as both the fourth and fifth positional arguments of
            ``GATConv`` (feature and attention dropout in DGL's signature).
    """

    def __init__(self, num_meta_paths, in_size, out_size, num_heads, dropout):
        super().__init__()
        convs = nn.ModuleList()
        for _ in range(num_meta_paths):
            convs.append(
                GATConv(in_size, out_size, num_heads, dropout, dropout, activation=F.elu)
            )
        self.gat_layers = convs
        # The fusion input width is the flattened multi-head output.
        self.semantic_attention = SemanticAttention(out_size * num_heads)

    def forward(self, gs, h):
        """Apply each GAT to its graph and fuse the per-path embeddings.

        Args:
            gs: Indexable collection of graphs; ``gs[i]`` goes to the i-th
                GAT layer.
            h: Node features shared by all meta-paths.

        Returns:
            The output of ``self.semantic_attention`` on the per-path
            embeddings stacked along dim 1.
        """
        per_path = []
        for idx, conv in enumerate(self.gat_layers):
            # Merge all dimensions after the first into one feature axis.
            per_path.append(conv(gs[idx], h).flatten(1))
        stacked = torch.stack(per_path, dim=1)
        return self.semantic_attention(stacked)

# Full HAN model
class HAN(nn.Module):
    """Stack of ``HANLayer`` modules followed by a linear prediction head.

    Args:
        num_meta_paths: Number of meta-paths handled by every layer.
        in_size: Input feature size of the first layer.
        hidden_size: Per-head output size of every layer.
        out_size: Output size of the prediction head.
        num_heads: Sequence with one head count per layer; its length sets
            the number of layers.
        dropout: Dropout value forwarded to every ``HANLayer``.
    """

    def __init__(self, num_meta_paths, in_size, hidden_size, out_size, num_heads, dropout):
        super().__init__()
        first = HANLayer(num_meta_paths, in_size, hidden_size, num_heads[0], dropout)
        self.layers = nn.ModuleList([first])
        for depth in range(1, len(num_heads)):
            # Each later layer consumes the previous layer's output, whose
            # width is hidden_size times that layer's head count.
            prev_width = hidden_size * num_heads[depth - 1]
            self.layers.append(
                HANLayer(num_meta_paths, prev_width, hidden_size, num_heads[depth], dropout)
            )
        self.predict = nn.Linear(hidden_size * num_heads[-1], out_size)

    def forward(self, g, h, is_training=True):
        """Run all layers, then optionally the prediction head.

        Args:
            g: Passed unchanged as the first argument of every layer.
            h: Input node features.
            is_training: When truthy, return ``self.predict`` applied to the
                final embeddings; otherwise return the embeddings. This flag
                only selects the return value here.

        Returns:
            Predictions or final-layer embeddings, depending on
            ``is_training``.
        """
        for han_layer in self.layers:
            h = han_layer(g, h)
        if is_training:
            return self.predict(h)
        return h
