<a href="https://colab.research.google.com/github/sabrinabenb/Graph-UNet/blob/main/Graph_Unet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%capture
# Install PyTorch Geometric and its compiled extensions into the runtime
# (this notebook is set up for Colab). %%capture hides this cell's pip output.
import os
import torch
# Export the installed torch version so the shell can expand ${TORCH} in the
# wheel-index URLs below and pick extension wheels built for that version.
os.environ['TORCH'] = torch.__version__
# Setting this at runtime only affects child processes launched from here
# (the pip commands); the running kernel's warning filters were initialised
# at startup, so kernel warnings are still shown.
os.environ['PYTHONWARNINGS'] = "ignore"
!pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
# NOTE(review): installs PyG from the repository's default branch, unpinned,
# so a later re-run may pull a different version. Consider pinning a release.
!pip install git+https://github.com/pyg-team/pytorch_geometric.git

In [None]:
from typing import Callable, List, Union

import torch
import torch.nn.functional as F
from torch import Tensor
from torch_sparse import spspmm

from torch_geometric.typing import OptTensor, PairTensor
from torch_geometric.utils import (
    add_self_loops,
    remove_self_loops,
    sort_edge_index,
)
from torch_geometric.utils.repeat import repeat
%matplotlib inline
import networkx as nx
import matplotlib.pyplot as plt
def visualize_graph(G, color):
    """Draw the networkx graph ``G`` and display it immediately.

    Nodes are placed with a spring layout seeded for repeatable positions and
    coloured by ``color`` using the "Set2" colormap. Node labels and axis
    ticks are hidden.
    """
    _fig, ax = plt.subplots(figsize=(7, 7))
    ax.set_xticks([])
    ax.set_yticks([])
    layout = nx.spring_layout(G, seed=42)
    nx.draw_networkx(G, pos=layout, ax=ax, with_labels=False,
                     node_color=color, cmap="Set2")
    plt.show()

In [None]:
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

# Load Cora with row-normalised node features.
dataset = Planetoid(root='data/Planetoid', name='Cora',
                    transform=NormalizeFeatures())

# Print a short summary of the dataset, one item per line.
print(*[
    f'Dataset: {dataset}:',
    '======================',
    f'Number of graphs: {len(dataset)}',
    f'Number of features: {dataset.num_features}',
    f'Number of classes: {dataset.num_classes}',
], sep='\n')

data = dataset[0]
g = dataset[0]  # Get the first graph object (same index as `data`).
print(data)


Dataset: Cora():
Number of graphs: 1
Number of features: 1433
Number of classes: 7
Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])


In [None]:
from typing import Callable, List, Union

import torch
from torch import Tensor

from torch_geometric.nn import GCNConv, TopKPooling
from torch_geometric.nn.resolver import activation_resolver
from torch_geometric.typing import OptTensor, PairTensor
from torch_geometric.utils import (
    add_self_loops,
    remove_self_loops,
    to_torch_csr_tensor,
)
from torch_geometric.utils.repeat import repeat


class GraphUNet(torch.nn.Module):
    r"""The Graph U-Net model from the `"Graph U-Nets"
    <https://arxiv.org/abs/1905.05178>`_ paper which implements a U-Net like
    architecture with graph pooling and unpooling operations.

    Args:
        in_channels (int): Size of each input sample.
        hidden_channels (int): Size of each hidden sample.
        out_channels (int): Size of each output sample.
        depth (int): The depth of the U-Net architecture, i.e. the number of
            pooling/unpooling levels (must be at least 1).
        pool_ratios (float or [float], optional): Graph pooling ratio for each
            depth; expanded to :obj:`depth` entries via :func:`repeat`.
            (default: :obj:`0.5`)
        sum_res (bool, optional): If set to :obj:`False`, will use
            concatenation for integration of skip connections instead of
            summation. (default: :obj:`True`)
        act (str or Callable, optional): The nonlinearity to use, resolved
            via :func:`activation_resolver`. (default: :obj:`"relu"`)
    """
    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        depth: int,
        pool_ratios: Union[float, List[float]] = 0.5,
        sum_res: bool = True,
        act: Union[str, Callable] = 'relu',
    ):
        super().__init__()
        assert depth >= 1
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.depth = depth
        self.pool_ratios = repeat(pool_ratios, depth)
        self.act = activation_resolver(act)
        self.sum_res = sum_res

        channels = hidden_channels

        # Encoder: an input GCN layer, then one (TopK pool -> GCN) pair per
        # level. Modules are created in this exact order on purpose; changing
        # it would change how the RNG is consumed during initialisation.
        self.down_convs = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        self.down_convs.append(GCNConv(in_channels, channels, improved=True))
        for i in range(depth):
            self.pools.append(TopKPooling(channels, self.pool_ratios[i]))
            self.down_convs.append(GCNConv(channels, channels, improved=True))

        # Decoder input width: skip connections are either added (same
        # width) or concatenated (double width).
        in_channels = channels if sum_res else 2 * channels

        # Decoder: depth - 1 hidden GCN layers plus a final layer that
        # projects to out_channels.
        self.up_convs = torch.nn.ModuleList()
        for _ in range(depth - 1):
            self.up_convs.append(GCNConv(in_channels, channels, improved=True))
        self.up_convs.append(GCNConv(in_channels, out_channels, improved=True))

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        for conv in self.down_convs:
            conv.reset_parameters()
        for pool in self.pools:
            pool.reset_parameters()
        for conv in self.up_convs:
            conv.reset_parameters()


    def forward(
        self,
        x: Tensor,
        edge_index: Tensor,
        batch: OptTensor = None,
        edge_weight: OptTensor = None,
    ) -> Tensor:
        r"""Runs the encoder (pooling) and decoder (unpooling) passes.

        Args:
            x (torch.Tensor): Node feature matrix; its first dimension is the
                number of nodes.
            edge_index (torch.Tensor): Graph connectivity, one column per
                edge.
            batch (torch.Tensor, optional): Batch vector assigning each node
                to a graph. Defaults to all zeros (a single graph).
            edge_weight (torch.Tensor, optional): One-dimensional edge
                weights, one per column of :obj:`edge_index`. Defaults to all
                ones.

        Returns:
            torch.Tensor: Output node features with :obj:`out_channels`
            columns and one row per input node.
        """
        if batch is None:
            # Treat all nodes as belonging to a single graph.
            batch = edge_index.new_zeros(x.size(0))

        if edge_weight is None:
            # Unweighted graph: every edge gets weight 1.
            edge_weight = x.new_ones(edge_index.size(1))
        assert edge_weight.dim() == 1
        assert edge_weight.size(0) == edge_index.size(1)

        x = self.down_convs[0](x, edge_index, edge_weight)
        x = self.act(x)

        # Per-level state saved for the decoder's skip connections.
        xs = [x]
        edge_indices = [edge_index]
        edge_weights = [edge_weight]
        perms = []

        for i in range(1, self.depth + 1):
            # Connect 2-hop neighbours before pooling (see augment_adj).
            edge_index, edge_weight = self.augment_adj(edge_index, edge_weight,
                                                       x.size(0))
            x, edge_index, edge_weight, batch, perm, _ = self.pools[i - 1](
                x, edge_index, edge_weight, batch)

            x = self.down_convs[i](x, edge_index, edge_weight)
            x = self.act(x)

            # The deepest level's output is where the decoder starts; it is
            # never used as a skip connection, so it is not stored.
            if i < self.depth:
                xs += [x]
                edge_indices += [edge_index]
                edge_weights += [edge_weight]
            perms += [perm]

        for i in range(self.depth):
            j = self.depth - 1 - i

            res = xs[j]
            edge_index = edge_indices[j]
            edge_weight = edge_weights[j]
            perm = perms[j]

            # Unpool: write the coarse features back at the rows selected by
            # pooling level j; all other rows stay zero.
            up = torch.zeros_like(res)
            up[perm] = x
            x = res + up if self.sum_res else torch.cat((res, up), dim=-1)

            x = self.up_convs[i](x, edge_index, edge_weight)
            # No activation after the final (output) layer.
            x = self.act(x) if i < self.depth - 1 else x

        return x


    def augment_adj(self, edge_index: Tensor, edge_weight: Tensor,
                    num_nodes: int) -> PairTensor:
        r"""Returns the 2-hop augmented graph as ``(edge_index, edge_weight)``.

        Any existing self-loops are replaced by freshly added ones. The
        resulting sparse adjacency matrix is squared, and self-loops are
        removed from the product.
        """
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
        edge_index, edge_weight = add_self_loops(edge_index, edge_weight,
                                                 num_nodes=num_nodes)
        adj = to_torch_csr_tensor(edge_index, edge_weight,
                                  size=(num_nodes, num_nodes))
        adj = (adj @ adj).to_sparse_coo()
        edge_index, edge_weight = adj.indices(), adj.values()
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
        return edge_index, edge_weight

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.hidden_channels}, {self.out_channels}, '
                f'depth={self.depth}, pool_ratios={self.pool_ratios})')

## Depth 4

Train the Graph U-Net on Cora with `depth=4` for 200 epochs.

In [None]:
import os.path as osp

import torch
import torch.nn.functional as F
from torch_geometric.utils import dropout_edge
class Net(torch.nn.Module):
    """Cora node classifier wrapping :class:`GraphUNet`.

    Reads the notebook-level ``data`` and ``dataset`` objects instead of
    taking them as arguments, so the dataset cell must have been run first.

    Args:
        depth (int, optional): Depth of the GraphUNet. Defaults to 4, the
            depth studied in this cell.
    """
    def __init__(self, depth: int = 4):
        super().__init__()
        # The first pooling level uses ratio 2000 / num_nodes (keeps roughly
        # 2000 nodes); 0.5 is the ratio given for the following level(s).
        pool_ratios = [2000 / data.num_nodes, 0.5]
        self.unet = GraphUNet(dataset.num_features, 32, dataset.num_classes,
                              depth=depth, pool_ratios=pool_ratios)

    def forward(self):
        """Return per-node log-probabilities over the dataset's classes."""
        # Randomly drop undirected edges (p=0.2) and input features (p=0.92).
        # Both are disabled when the module is in eval mode.
        edge_index, _ = dropout_edge(data.edge_index, p=0.2,
                                     force_undirected=True,
                                     training=self.training)
        x = F.dropout(data.x, p=0.92, training=self.training)

        x = self.unet(x, edge_index)
        return F.log_softmax(x, dim=1)


SEED = 42

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Seed before the model is built so weight initialisation (and the dropout
# draws made during training) repeat from run to run. This makes the depth
# experiments comparable; some CUDA kernels may still be nondeterministic.
torch.manual_seed(SEED)
model, data = Net().to(device), data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.001)


def train():
    """Run one full-batch optimisation step on the training nodes."""
    model.train()
    optimizer.zero_grad()
    log_probs = model()
    train_mask = data.train_mask
    loss = F.nll_loss(log_probs[train_mask], data.y[train_mask])
    loss.backward()
    optimizer.step()


@torch.no_grad()
def test():
    """Return the model's accuracy on the train, val and test masks, in that order."""
    model.eval()
    log_probs = model()
    results = []
    for _split, mask in data('train_mask', 'val_mask', 'test_mask'):
        hits = (log_probs[mask].argmax(1) == data.y[mask]).sum().item()
        results.append(hits / mask.sum().item())
    return results


EPOCHS = 200

# Model selection: keep the test accuracy from the epoch with the best
# validation accuracy seen so far.
best_val_acc = test_acc = 0
for epoch in range(1, EPOCHS + 1):
    train()
    train_acc, val_acc, tmp_test_acc = test()
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        test_acc = tmp_test_acc
    # Log this epoch's validation accuracy next to the running best. Printing
    # only the best value under a "Val" label made validation look frozen.
    print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, '
          f'Val: {val_acc:.4f}, Best val: {best_val_acc:.4f}, '
          f'Test @ best val: {test_acc:.4f}')

  adj = torch.sparse_csr_tensor(


Epoch: 001, Train: 0.2000, Val: 0.3200, Test: 0.3270
Epoch: 002, Train: 0.2786, Val: 0.3200, Test: 0.3270
Epoch: 003, Train: 0.1929, Val: 0.3200, Test: 0.3270
Epoch: 004, Train: 0.2214, Val: 0.3200, Test: 0.3270
Epoch: 005, Train: 0.2357, Val: 0.3200, Test: 0.3270
Epoch: 006, Train: 0.3143, Val: 0.3200, Test: 0.3270
Epoch: 007, Train: 0.4357, Val: 0.3200, Test: 0.3270
Epoch: 008, Train: 0.4857, Val: 0.3200, Test: 0.3270
Epoch: 009, Train: 0.4571, Val: 0.3200, Test: 0.3270
Epoch: 010, Train: 0.2357, Val: 0.3200, Test: 0.3270
Epoch: 011, Train: 0.2286, Val: 0.3200, Test: 0.3270
Epoch: 012, Train: 0.2071, Val: 0.3200, Test: 0.3270
Epoch: 013, Train: 0.2500, Val: 0.3200, Test: 0.3270
Epoch: 014, Train: 0.2714, Val: 0.3200, Test: 0.3270
Epoch: 015, Train: 0.2643, Val: 0.3200, Test: 0.3270
Epoch: 016, Train: 0.3143, Val: 0.3200, Test: 0.3270
Epoch: 017, Train: 0.4643, Val: 0.3200, Test: 0.3270
Epoch: 018, Train: 0.6286, Val: 0.4380, Test: 0.4510
Epoch: 019, Train: 0.6429, Val: 0.4500, Test: 

## Depth 5

Same experiment with `depth=5`.

In [None]:
import os.path as osp

import torch
import torch.nn.functional as F
from torch_geometric.utils import dropout_edge
class Net(torch.nn.Module):
    """Cora node classifier wrapping :class:`GraphUNet`.

    Reads the notebook-level ``data`` and ``dataset`` objects instead of
    taking them as arguments, so the dataset cell must have been run first.

    Args:
        depth (int, optional): Depth of the GraphUNet. Defaults to 5, the
            depth studied in this cell.
    """
    def __init__(self, depth: int = 5):
        super().__init__()
        # The first pooling level uses ratio 2000 / num_nodes (keeps roughly
        # 2000 nodes); 0.5 is the ratio given for the following level(s).
        pool_ratios = [2000 / data.num_nodes, 0.5]
        self.unet = GraphUNet(dataset.num_features, 32, dataset.num_classes,
                              depth=depth, pool_ratios=pool_ratios)

    def forward(self):
        """Return per-node log-probabilities over the dataset's classes."""
        # Randomly drop undirected edges (p=0.2) and input features (p=0.92).
        # Both are disabled when the module is in eval mode.
        edge_index, _ = dropout_edge(data.edge_index, p=0.2,
                                     force_undirected=True,
                                     training=self.training)
        x = F.dropout(data.x, p=0.92, training=self.training)

        x = self.unet(x, edge_index)
        return F.log_softmax(x, dim=1)


SEED = 42

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Seed before the model is built so weight initialisation (and the dropout
# draws made during training) repeat from run to run. This makes the depth
# experiments comparable; some CUDA kernels may still be nondeterministic.
torch.manual_seed(SEED)
model, data = Net().to(device), data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.001)


def train():
    """Run one full-batch optimisation step on the training nodes."""
    model.train()
    optimizer.zero_grad()
    log_probs = model()
    train_mask = data.train_mask
    loss = F.nll_loss(log_probs[train_mask], data.y[train_mask])
    loss.backward()
    optimizer.step()


@torch.no_grad()
def test():
    """Return the model's accuracy on the train, val and test masks, in that order."""
    model.eval()
    log_probs = model()
    results = []
    for _split, mask in data('train_mask', 'val_mask', 'test_mask'):
        hits = (log_probs[mask].argmax(1) == data.y[mask]).sum().item()
        results.append(hits / mask.sum().item())
    return results


EPOCHS = 200

# Model selection: keep the test accuracy from the epoch with the best
# validation accuracy seen so far.
best_val_acc = test_acc = 0
for epoch in range(1, EPOCHS + 1):
    train()
    train_acc, val_acc, tmp_test_acc = test()
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        test_acc = tmp_test_acc
    # Log this epoch's validation accuracy next to the running best. Printing
    # only the best value under a "Val" label made validation look frozen.
    print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, '
          f'Val: {val_acc:.4f}, Best val: {best_val_acc:.4f}, '
          f'Test @ best val: {test_acc:.4f}')

Epoch: 001, Train: 0.1786, Val: 0.1240, Test: 0.1280
Epoch: 002, Train: 0.1571, Val: 0.1240, Test: 0.1280
Epoch: 003, Train: 0.1643, Val: 0.1240, Test: 0.1280
Epoch: 004, Train: 0.2786, Val: 0.1500, Test: 0.1770
Epoch: 005, Train: 0.3786, Val: 0.2080, Test: 0.2260
Epoch: 006, Train: 0.4357, Val: 0.2320, Test: 0.2530
Epoch: 007, Train: 0.4643, Val: 0.2320, Test: 0.2530
Epoch: 008, Train: 0.5143, Val: 0.2680, Test: 0.2940
Epoch: 009, Train: 0.5143, Val: 0.2980, Test: 0.3170
Epoch: 010, Train: 0.5571, Val: 0.3440, Test: 0.3490
Epoch: 011, Train: 0.6071, Val: 0.3700, Test: 0.3810
Epoch: 012, Train: 0.5929, Val: 0.3700, Test: 0.3810
Epoch: 013, Train: 0.4143, Val: 0.3700, Test: 0.3810
Epoch: 014, Train: 0.4429, Val: 0.3700, Test: 0.3810
Epoch: 015, Train: 0.5571, Val: 0.3700, Test: 0.3810
Epoch: 016, Train: 0.4929, Val: 0.3700, Test: 0.3810
Epoch: 017, Train: 0.5286, Val: 0.3700, Test: 0.3810
Epoch: 018, Train: 0.5000, Val: 0.3700, Test: 0.3810
Epoch: 019, Train: 0.5643, Val: 0.3700, Test: 

## Depth 13

Same experiment with `depth=13`.

In [None]:
import os.path as osp

import torch
import torch.nn.functional as F
from torch_geometric.utils import dropout_edge
class Net(torch.nn.Module):
    """Cora node classifier wrapping :class:`GraphUNet`.

    Reads the notebook-level ``data`` and ``dataset`` objects instead of
    taking them as arguments, so the dataset cell must have been run first.

    Args:
        depth (int, optional): Depth of the GraphUNet. Defaults to 13, the
            depth studied in this cell.
    """
    def __init__(self, depth: int = 13):
        super().__init__()
        # The first pooling level uses ratio 2000 / num_nodes (keeps roughly
        # 2000 nodes); 0.5 is the ratio given for the following level(s).
        pool_ratios = [2000 / data.num_nodes, 0.5]
        self.unet = GraphUNet(dataset.num_features, 32, dataset.num_classes,
                              depth=depth, pool_ratios=pool_ratios)

    def forward(self):
        """Return per-node log-probabilities over the dataset's classes."""
        # Randomly drop undirected edges (p=0.2) and input features (p=0.92).
        # Both are disabled when the module is in eval mode.
        edge_index, _ = dropout_edge(data.edge_index, p=0.2,
                                     force_undirected=True,
                                     training=self.training)
        x = F.dropout(data.x, p=0.92, training=self.training)

        x = self.unet(x, edge_index)
        return F.log_softmax(x, dim=1)


SEED = 42

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Seed before the model is built so weight initialisation (and the dropout
# draws made during training) repeat from run to run. This makes the depth
# experiments comparable; some CUDA kernels may still be nondeterministic.
torch.manual_seed(SEED)
model, data = Net().to(device), data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.001)


def train():
    """Run one full-batch optimisation step on the training nodes."""
    model.train()
    optimizer.zero_grad()
    log_probs = model()
    train_mask = data.train_mask
    loss = F.nll_loss(log_probs[train_mask], data.y[train_mask])
    loss.backward()
    optimizer.step()


@torch.no_grad()
def test():
    """Return the model's accuracy on the train, val and test masks, in that order."""
    model.eval()
    log_probs = model()
    results = []
    for _split, mask in data('train_mask', 'val_mask', 'test_mask'):
        hits = (log_probs[mask].argmax(1) == data.y[mask]).sum().item()
        results.append(hits / mask.sum().item())
    return results


EPOCHS = 200

# Model selection: keep the test accuracy from the epoch with the best
# validation accuracy seen so far.
best_val_acc = test_acc = 0
for epoch in range(1, EPOCHS + 1):
    train()
    train_acc, val_acc, tmp_test_acc = test()
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        test_acc = tmp_test_acc
    # Log this epoch's validation accuracy next to the running best. Printing
    # only the best value under a "Val" label made validation look frozen.
    print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, '
          f'Val: {val_acc:.4f}, Best val: {best_val_acc:.4f}, '
          f'Test @ best val: {test_acc:.4f}')

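## Depth 50

Same experiment with `depth=50`. In the recorded output below, train/val/test accuracy stays at its epoch-1 value (0.1429 / 0.1220 / 0.1300) for every logged epoch, i.e. the model does not improve at this depth.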

In [None]:
import os.path as osp

import torch
import torch.nn.functional as F
from torch_geometric.utils import dropout_edge
class Net(torch.nn.Module):
    """Cora node classifier wrapping :class:`GraphUNet`.

    Reads the notebook-level ``data`` and ``dataset`` objects instead of
    taking them as arguments, so the dataset cell must have been run first.

    Args:
        depth (int, optional): Depth of the GraphUNet. Defaults to 50, the
            depth studied in this cell.
    """
    def __init__(self, depth: int = 50):
        super().__init__()
        # The first pooling level uses ratio 2000 / num_nodes (keeps roughly
        # 2000 nodes); 0.5 is the ratio given for the following level(s).
        pool_ratios = [2000 / data.num_nodes, 0.5]
        self.unet = GraphUNet(dataset.num_features, 32, dataset.num_classes,
                              depth=depth, pool_ratios=pool_ratios)

    def forward(self):
        """Return per-node log-probabilities over the dataset's classes."""
        # Randomly drop undirected edges (p=0.2) and input features (p=0.92).
        # Both are disabled when the module is in eval mode.
        edge_index, _ = dropout_edge(data.edge_index, p=0.2,
                                     force_undirected=True,
                                     training=self.training)
        x = F.dropout(data.x, p=0.92, training=self.training)

        x = self.unet(x, edge_index)
        return F.log_softmax(x, dim=1)


SEED = 42

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Seed before the model is built so weight initialisation (and the dropout
# draws made during training) repeat from run to run. This makes the depth
# experiments comparable; some CUDA kernels may still be nondeterministic.
torch.manual_seed(SEED)
model, data = Net().to(device), data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.001)


def train():
    """Run one full-batch optimisation step on the training nodes."""
    model.train()
    optimizer.zero_grad()
    log_probs = model()
    train_mask = data.train_mask
    loss = F.nll_loss(log_probs[train_mask], data.y[train_mask])
    loss.backward()
    optimizer.step()


@torch.no_grad()
def test():
    """Return the model's accuracy on the train, val and test masks, in that order."""
    model.eval()
    log_probs = model()
    results = []
    for _split, mask in data('train_mask', 'val_mask', 'test_mask'):
        hits = (log_probs[mask].argmax(1) == data.y[mask]).sum().item()
        results.append(hits / mask.sum().item())
    return results


EPOCHS = 200

# Model selection: keep the test accuracy from the epoch with the best
# validation accuracy seen so far.
best_val_acc = test_acc = 0
for epoch in range(1, EPOCHS + 1):
    train()
    train_acc, val_acc, tmp_test_acc = test()
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        test_acc = tmp_test_acc
    # Log this epoch's validation accuracy next to the running best. Printing
    # only the best value under a "Val" label made validation look frozen.
    print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, '
          f'Val: {val_acc:.4f}, Best val: {best_val_acc:.4f}, '
          f'Test @ best val: {test_acc:.4f}')

Epoch: 001, Train: 0.1429, Val: 0.1220, Test: 0.1300
Epoch: 002, Train: 0.1429, Val: 0.1220, Test: 0.1300
Epoch: 003, Train: 0.1429, Val: 0.1220, Test: 0.1300
Epoch: 004, Train: 0.1429, Val: 0.1220, Test: 0.1300
Epoch: 005, Train: 0.1429, Val: 0.1220, Test: 0.1300
Epoch: 006, Train: 0.1429, Val: 0.1220, Test: 0.1300
Epoch: 007, Train: 0.1429, Val: 0.1220, Test: 0.1300
Epoch: 008, Train: 0.1429, Val: 0.1220, Test: 0.1300
Epoch: 009, Train: 0.1429, Val: 0.1220, Test: 0.1300
Epoch: 010, Train: 0.1429, Val: 0.1220, Test: 0.1300
Epoch: 011, Train: 0.1429, Val: 0.1220, Test: 0.1300
Epoch: 012, Train: 0.1429, Val: 0.1220, Test: 0.1300
Epoch: 013, Train: 0.1429, Val: 0.1220, Test: 0.1300
Epoch: 014, Train: 0.1429, Val: 0.1220, Test: 0.1300
Epoch: 015, Train: 0.1429, Val: 0.1220, Test: 0.1300
Epoch: 016, Train: 0.1429, Val: 0.1220, Test: 0.1300
Epoch: 017, Train: 0.1429, Val: 0.1220, Test: 0.1300
Epoch: 018, Train: 0.1429, Val: 0.1220, Test: 0.1300
Epoch: 019, Train: 0.1429, Val: 0.1220, Test: 

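## Depth 100

Same experiment with `depth=100`. As with depth 50, the recorded output below shows train/val/test accuracy fixed at its epoch-1 value (0.1429 / 0.1220 / 0.1300) for every logged epoch.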

In [None]:
import os.path as osp

import torch
import torch.nn.functional as F
from torch_geometric.utils import dropout_edge
class Net(torch.nn.Module):
    """Cora node classifier wrapping :class:`GraphUNet`.

    Reads the notebook-level ``data`` and ``dataset`` objects instead of
    taking them as arguments, so the dataset cell must have been run first.

    Args:
        depth (int, optional): Depth of the GraphUNet. Defaults to 100, the
            depth studied in this cell.
    """
    def __init__(self, depth: int = 100):
        super().__init__()
        # The first pooling level uses ratio 2000 / num_nodes (keeps roughly
        # 2000 nodes); 0.5 is the ratio given for the following level(s).
        pool_ratios = [2000 / data.num_nodes, 0.5]
        self.unet = GraphUNet(dataset.num_features, 32, dataset.num_classes,
                              depth=depth, pool_ratios=pool_ratios)

    def forward(self):
        """Return per-node log-probabilities over the dataset's classes."""
        # Randomly drop undirected edges (p=0.2) and input features (p=0.92).
        # Both are disabled when the module is in eval mode.
        edge_index, _ = dropout_edge(data.edge_index, p=0.2,
                                     force_undirected=True,
                                     training=self.training)
        x = F.dropout(data.x, p=0.92, training=self.training)

        x = self.unet(x, edge_index)
        return F.log_softmax(x, dim=1)


SEED = 42

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Seed before the model is built so weight initialisation (and the dropout
# draws made during training) repeat from run to run. This makes the depth
# experiments comparable; some CUDA kernels may still be nondeterministic.
torch.manual_seed(SEED)
model, data = Net().to(device), data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.001)


def train():
    """Run one full-batch optimisation step on the training nodes."""
    model.train()
    optimizer.zero_grad()
    log_probs = model()
    train_mask = data.train_mask
    loss = F.nll_loss(log_probs[train_mask], data.y[train_mask])
    loss.backward()
    optimizer.step()


@torch.no_grad()
def test():
    """Return the model's accuracy on the train, val and test masks, in that order."""
    model.eval()
    log_probs = model()
    results = []
    for _split, mask in data('train_mask', 'val_mask', 'test_mask'):
        hits = (log_probs[mask].argmax(1) == data.y[mask]).sum().item()
        results.append(hits / mask.sum().item())
    return results


EPOCHS = 200

# Model selection: keep the test accuracy from the epoch with the best
# validation accuracy seen so far.
best_val_acc = test_acc = 0
for epoch in range(1, EPOCHS + 1):
    train()
    train_acc, val_acc, tmp_test_acc = test()
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        test_acc = tmp_test_acc
    # Log this epoch's validation accuracy next to the running best. Printing
    # only the best value under a "Val" label made validation look frozen.
    print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, '
          f'Val: {val_acc:.4f}, Best val: {best_val_acc:.4f}, '
          f'Test @ best val: {test_acc:.4f}')

  adj = torch.sparse_csr_tensor(


Epoch: 001, Train: 0.1429, Val: 0.1220, Test: 0.1300
Epoch: 002, Train: 0.1429, Val: 0.1220, Test: 0.1300
Epoch: 003, Train: 0.1429, Val: 0.1220, Test: 0.1300
Epoch: 004, Train: 0.1429, Val: 0.1220, Test: 0.1300
Epoch: 005, Train: 0.1429, Val: 0.1220, Test: 0.1300
Epoch: 006, Train: 0.1429, Val: 0.1220, Test: 0.1300
Epoch: 007, Train: 0.1429, Val: 0.1220, Test: 0.1300
Epoch: 008, Train: 0.1429, Val: 0.1220, Test: 0.1300
Epoch: 009, Train: 0.1429, Val: 0.1220, Test: 0.1300
Epoch: 010, Train: 0.1429, Val: 0.1220, Test: 0.1300
Epoch: 011, Train: 0.1429, Val: 0.1220, Test: 0.1300
Epoch: 012, Train: 0.1429, Val: 0.1220, Test: 0.1300
Epoch: 013, Train: 0.1429, Val: 0.1220, Test: 0.1300
Epoch: 014, Train: 0.1429, Val: 0.1220, Test: 0.1300
Epoch: 015, Train: 0.1429, Val: 0.1220, Test: 0.1300
Epoch: 016, Train: 0.1429, Val: 0.1220, Test: 0.1300
Epoch: 017, Train: 0.1429, Val: 0.1220, Test: 0.1300
Epoch: 018, Train: 0.1429, Val: 0.1220, Test: 0.1300
Epoch: 019, Train: 0.1429, Val: 0.1220, Test: 