# Modification-friendly implementation of FAS via a Jupyter notebook

A Jupyter notebook makes modifying the code and experimenting interactively more convenient than running a single monolithic Python script.

## Environment Setup

In [None]:
%pip install torch torchvision torchaudio
%pip install torch-geometric
%pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv
%pip install ogb
%pip install PyMetis

In [None]:
import os

# Base directory: its `preprocessed/` subfolder holds the coarsened data
# (C1.pt, G1.pt, X1.pt, C2C1.pt, G2.pt, X2.pt) written and read by later cells.
PATH = os.getcwd()

## Data Preprocessing

- This section generates the coarsened data: the coarsened features, the coarsened edge index, and the coarsening (partition) matrix. The generated data will be stored in the `preprocessed` folder.
- You must run this before initializing the training.
- Two sequential coarsening steps are provided, each with a coarsening ratio of 0.1 (G -> G1 -> G2).

In [None]:
## METIS, cpu-based implementation.

import torch
import pymetis

def metis_coarsen_normalized(edge_index: torch.Tensor, r: float, num_nodes=None):
    """
    Partition an undirected, unweighted graph with METIS (via pymetis) and build the
    normalized partition matrix C together with the edge index of the coarsened graph,
    i.e. the sparsity pattern of C^T A C (which supernodes are connected).

    Column j of C is the indicator vector of cluster j divided by sqrt(|cluster j|), so
    the non-empty columns are orthonormal (C^T C = I when no cluster is empty). A cluster
    that METIS leaves empty keeps an all-zero column instead of producing 0/0 = NaN.

    Args:
        edge_index (torch.Tensor): A 2 x E integer tensor where each column [i, j] is an
            edge between nodes i and j. Direction is ignored: the adjacency handed to METIS
            is symmetrized, and self-loops are skipped because METIS's graph format does
            not allow them.
        r (float): Coarsening ratio; the number of clusters is k = int(n * r), clamped
            to [1, n].
        num_nodes (int, optional): Number of nodes n. If omitted, n is inferred as
            edge_index.max() + 1, which silently drops isolated nodes whose ids are larger
            than every id that appears in an edge. Pass it explicitly when the graph can
            contain such nodes (e.g. a coarsened graph whose last supernodes have no
            inter-cluster edges).

    Returns:
        partition_matrix (torch.Tensor): A dense float32 (n x k) matrix. Row i has a single
            nonzero entry, 1 / sqrt(size of node i's cluster), in the column of that cluster.
            It is dense, so its memory grows as n * k.
        coarse_edge_index (torch.Tensor): A 2 x E_coarse long tensor of supernode pairs
            (u, v) with u < v: no self-loops, no duplicates, each undirected edge stored
            once. Add the reversed pairs if the consumer expects both directions.

    Raises:
        ValueError: If the graph has no nodes, if edge_index is empty and num_nodes is not
            given, or if num_nodes is not larger than the largest node id in edge_index.
    """
    # Work on CPU with int64 ids so they can be used for tensor indexing below.
    edge_index_cpu = edge_index.cpu().long()

    max_id = int(edge_index_cpu.max().item()) if edge_index_cpu.numel() > 0 else -1
    if num_nodes is None:
        if max_id < 0:
            raise ValueError("edge_index is empty; pass num_nodes explicitly")
        n = max_id + 1
    else:
        n = int(num_nodes)
        if n <= max_id:
            raise ValueError(f"num_nodes={n} but edge_index references node {max_id}")
    if n < 1:
        raise ValueError("the graph must contain at least one node")

    # Number of clusters (partitions), clamped to [1, n].
    k = min(max(1, int(n * r)), n)

    # Symmetric adjacency list for pymetis; self-loops are skipped (see docstring).
    adjacency = [set() for _ in range(n)]
    for i, j in zip(edge_index_cpu[0].tolist(), edge_index_cpu[1].tolist()):
        if i != j:
            adjacency[i].add(j)
            adjacency[j].add(i)
    adjacency = [list(neighbors) for neighbors in adjacency]

    # pymetis.part_graph returns (edgecut, parts); parts[i] is the cluster id of node i.
    _, parts = pymetis.part_graph(k, adjacency)
    parts_tensor = torch.tensor(parts, dtype=torch.long)

    # One-hot cluster assignment, shape (n, k).
    partition_matrix = torch.zeros(n, k, dtype=torch.float32)
    partition_matrix.scatter_(1, parts_tensor.unsqueeze(1), 1.0)

    # Divide column j by sqrt(|cluster j|). Clamping the count to 1 only affects empty
    # clusters, whose all-zero column then stays zero instead of becoming NaN.
    cluster_counts = partition_matrix.sum(dim=0).clamp(min=1.0)
    partition_matrix = partition_matrix / cluster_counts.sqrt().unsqueeze(0)

    # --- Coarsened graph's edge index ---
    # "Lift" each original edge (i, j) to (parts[i], parts[j]).
    coarse_u = parts_tensor[edge_index_cpu[0]]
    coarse_v = parts_tensor[edge_index_cpu[1]]

    # Drop intra-cluster edges (they would become supernode self-loops).
    keep = coarse_u != coarse_v
    coarse_u, coarse_v = coarse_u[keep], coarse_v[keep]

    # Canonicalize each pair as (min, max) so (u, v) and (v, u) deduplicate together.
    coarse_edges = torch.stack(
        (torch.minimum(coarse_u, coarse_v), torch.maximum(coarse_u, coarse_v)), dim=0
    )
    coarse_edge_index = torch.unique(coarse_edges, dim=1)

    return partition_matrix, coarse_edge_index

In [None]:
from dataset import load_nc_dataset

dataset = load_nc_dataset('ogbn-arxiv')
# Raw node features and edge list of the original (finest) graph G.
node_feat, edge_index = dataset.graph['node_feat'], dataset.graph['edge_index']

### First step coarsening: G -> G1, r = 0.1

In [None]:
# torch.save does not create missing directories, so make sure the output folder exists.
out_dir = os.path.join(PATH, 'preprocessed')
os.makedirs(out_dir, exist_ok=True)

# G -> G1 with coarsening ratio 0.1: normalized partition matrix C1 and G1's edge index.
c1, g1 = metis_coarsen_normalized(edge_index, 0.1)

torch.save(c1, os.path.join(out_dir, 'C1.pt'))
torch.save(g1, os.path.join(out_dir, 'G1.pt'))

# Coarsened features X1 = C1^T X. Start from the dataset's raw features rather than the
# `node_feat` variable so that re-running this cell does not coarsen already-coarsened data.
node_feat = c1.T @ dataset.graph['node_feat']
torch.save(node_feat, os.path.join(out_dir, 'X1.pt'))

### Second step coarsening: G1 -> G2, r = 0.1

In [None]:
# torch.save does not create missing directories, so make sure the output folder exists.
out_dir = os.path.join(PATH, 'preprocessed')
os.makedirs(out_dir, exist_ok=True)

# G1 -> G2 with coarsening ratio 0.1 (about 1% of the original nodes overall).
c2, g2 = metis_coarsen_normalized(g1, 0.1)

# With no explicit node count, metis_coarsen_normalized infers G1's size from its largest
# edge endpoint. Supernodes of G1 with the highest ids and no inter-cluster edges would then
# be missing from C2. Fail with a clear message instead of an opaque shape error in C1 @ C2.
if c2.shape[0] != c1.shape[1]:
    raise ValueError(
        f"G1 has {c1.shape[1]} supernodes but the second partition covers only "
        f"{c2.shape[0]}; the highest-index supernodes of G1 have no edges in G1"
    )

# C1 @ C2 (saved as C2C1.pt) maps original nodes directly to G2 supernodes.
torch.save(c1 @ c2, os.path.join(out_dir, 'C2C1.pt'))
torch.save(g2, os.path.join(out_dir, 'G2.pt'))

# X2 = C2^T X1. X1 is read back from disk so that re-running this cell is safe.
node_feat = c2.T @ torch.load(os.path.join(out_dir, 'X1.pt'))
torch.save(node_feat, os.path.join(out_dir, 'X2.pt'))

## FAS Training

- Main training code for FAS. Please make sure you have generated the preprocessed data first.
- If you want to do plain full-graph training instead, run only a **Level 1 training** cell with its number of training iterations (the `range(50)` loop) set to 500, then run the testing code.

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
import torch.nn as nn


class GCN(torch.nn.Module):
    """Multi-layer GCN node classifier that returns per-node log-probabilities.

    Layers are ``in_channels -> hidden_channels -> ... -> out_channels`` with ReLU and
    dropout between consecutive layers; ``num_layers == 1`` is a single
    ``in_channels -> out_channels`` GCNConv.

    ``cached=False`` matters here: the same model is run on the original graph and on its
    coarsened versions, so GCNConv must not reuse a normalized adjacency cached from an
    earlier call on a different graph.

    Raises:
        ValueError: If ``num_layers < 1``.
    """

    def __init__(self, in_channels, out_channels, hidden_channels=256, num_layers=3,
                 dropout=0.5):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        self.convs = torch.nn.ModuleList()
        if num_layers == 1:
            self.convs.append(GCNConv(in_channels, out_channels, cached=False))
        else:
            self.convs.append(GCNConv(in_channels, hidden_channels, cached=False))
            for _ in range(num_layers - 2):
                self.convs.append(
                    GCNConv(hidden_channels, hidden_channels, cached=False))
            self.convs.append(GCNConv(hidden_channels, out_channels, cached=False))

        # Dropout probability applied after every hidden layer.
        self.dropout = dropout

    def reset_parameters(self):
        """Re-initialize the weights of every GCNConv layer."""
        for conv in self.convs:
            conv.reset_parameters()

    def forward(self, x, adj_t):
        """Return log-softmax class scores, one row per node of ``x``.

        ``adj_t`` is passed unchanged to every GCNConv (the notebook passes an edge index).
        """
        for conv in self.convs[:-1]:
            x = conv(x, adj_t)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, adj_t)
        return x.log_softmax(dim=-1)

In [None]:
import random
import numpy as np
import gc

def one_hot(x, class_count):
    """Turn integer class ids ``x`` into one-hot rows of width ``class_count``.

    Rows are taken from an identity matrix built with the default float dtype on
    ``x.device``; ordinary tensor indexing is used, so a negative id counts from the
    last class.
    """
    identity = torch.eye(class_count, device=x.device)
    return identity[x]

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU, current CUDA device, all CUDA devices) RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_cuda(seed)

def create_mask(indices, n):
    """Return a float vector of length ``n`` that is 1.0 at ``indices`` and 0.0 elsewhere.

    The vector is allocated on the same device as ``indices``.
    """
    flags = torch.zeros(n, dtype=torch.float, device=indices.device)
    flags[indices] = 1.0
    return flags

# Run on the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## loss functions
# NLLLoss takes log-probabilities, which is what GCN.forward returns (log_softmax).
loss_nll = nn.NLLLoss()
loss_kl = nn.KLDivLoss(reduction='batchmean')

from dataset import load_nc_dataset
# Loaded again here (the preprocessing section loads it too), presumably so the
# training section can be run without re-running the preprocessing cells.
dataset = load_nc_dataset('ogbn-arxiv')

In [None]:
# Seed before building the model so its initial weights are reproducible.
set_seed(2703)

model = GCN(dataset.num_features, dataset.num_classes).to(device)
model.reset_parameters()

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Train/test node indices of the original graph.
splits = dataset.get_idx_split()
train_idx, test_idx = splits['train'].to(device), splits['test'].to(device)

# Tensors of the original (finest) graph; kept untouched and reused at every level.
x0 = dataset.graph['node_feat'].to(device)
idx0 = dataset.graph['edge_index'].to(device)
y0 = dataset.label.to(device)

# Ids of every node that is not a training node (they receive model-inferred labels later).
complem_idx = torch.arange(x0.shape[0], device=device)
complem_idx = complem_idx[~torch.isin(complem_idx, train_idx)]

### Level 1 training on the original graph

In [None]:
# Start from fresh copies of the original-graph tensors.
x, idx, y = x0.clone(), idx0.clone(), y0.clone()

for _ in range(50):
    model.train()
    optimizer.zero_grad()
    # Forward pass on the whole graph; supervise only the training nodes.
    out = model(x, idx)[train_idx]
    loss = loss_nll(out, y[train_idx])
    loss.backward()
    optimizer.step()

### Level 2 training on G1 (r=0.1)

In [None]:
# Normalized partition matrix of the first coarsening (original graph -> G1).
c = torch.load(PATH + '/preprocessed/C1.pt', map_location=device)

In [None]:
## data update
with torch.no_grad():
  model.eval()

  ## τ-correction
  residual = c.T @ model(x0, idx0)

  x = torch.load(PATH+'/preprocessed/X1.pt').to(device)
  idx = torch.load(PATH+'/preprocessed/G1.pt').to(device)

  residual = model(x, idx) - residual

  ## update the label on coarsened graph
  y = y0.clone()
  y[complem_idx] = model(x0, idx0).argmax(dim=1).long()[complem_idx] ## incorporate the inferred label
  y = one_hot(y, dataset.num_classes)
  y = c.T @ y
  y = y + residual ## incorporate τ-correction
  y = y.log_softmax(dim=-1)
  y = y.detach()

## training
for _ in range(100):
  model.train()
  optimizer.zero_grad()
  out = model(x, idx)
  loss = loss_nll(out, y.argmax(dim=1).long())
  # loss = loss_nll(out[mask],y[mask].argmax(dim=1).long()) 
  loss.backward()
  optimizer.step()


# NOTE: the raw string below is never executed. It stores an alternative version of the
# data-update/training code above that drops inferred labels and trains only on
# supernodes built from training nodes. Because it is the cell's last expression,
# Jupyter echoes it as the cell output.
# NOTE(review): in that variant `mask` is computed from `y` after the τ-correction and
# log_softmax, so no row is all-zero and the mask selects every supernode. To keep only
# supernodes that contain training nodes, take the mask from `c.T @ y` before adding
# the residual.
r'''
# If you want to discard the inferred label and dismiss all non-training nodes...

## data update
with torch.no_grad():
  model.eval()

  ## τ-correction
  residual = c.T @ model(x0, idx0)

  x = torch.load(PATH+'/preprocessed/X1.pt').to(device)
  idx = torch.load(PATH+'/preprocessed/G1.pt').to(device)

  residual = model(x, idx) - residual

  y = y0.clone()
  y = one_hot(y, dataset.num_classes)
  y[complem_idx, :] = torch.zeros_like(y[complem_idx, :])
  y = c.T @ y
  y = y + residual
  y = y.log_softmax(dim=-1)
  y = y.detach()

  mask = torch.nonzero(y.any(dim=1)).squeeze()

## training
for _ in range(100):
  model.train()
  optimizer.zero_grad()
  out = model(x, idx)
  loss = loss_nll(out[mask],y[mask].argmax(dim=1).long()) 
  loss.backward()
  optimizer.step()
'''


### Level 3 training on G2 (r=0.01)

In [None]:
# Composite partition matrix C1 @ C2 (original graph -> G2).
c = torch.load(PATH + '/preprocessed/C2C1.pt', map_location=device)

In [None]:
## data update
with torch.no_grad():
  model.eval()

  ## τ-correction
  residual = c.T @ model(x0, idx0)

  x = torch.load(PATH+'/preprocessed/X2.pt').to(device)
  idx = torch.load(PATH+'/preprocessed/G2.pt').to(device)

  residual = model(x, idx) - residual

  ## update the label on coarsened graph
  y = y0.clone()
  y[complem_idx] = model(x0, idx0).argmax(dim=1).long()[complem_idx] ## incorporate the inferred label
  y = one_hot(y, dataset.num_classes)
  y = c.T @ y
  y = y + residual ## incorporate τ-correction
  y = y.log_softmax(dim=-1)
  y = y.detach()

## training
for _ in range(200):
  model.train()
  optimizer.zero_grad()
  out = model(x, idx)
  loss = loss_nll(out, y.argmax(dim=1).long())
  # loss = loss_nll(out[mask],y[mask].argmax(dim=1).long()) 
  loss.backward()
  optimizer.step()


# NOTE: the raw string below is never executed. It stores an alternative version of the
# data-update/training code above that drops inferred labels and trains only on
# supernodes built from training nodes. Because it is the cell's last expression,
# Jupyter echoes it as the cell output.
# NOTE(review): in that variant `mask` is computed from `y` after the τ-correction and
# log_softmax, so no row is all-zero and the mask selects every supernode. To keep only
# supernodes that contain training nodes, take the mask from `c.T @ y` before adding
# the residual.
r'''
# If you want to discard the inferred label and dismiss all non-training nodes...

## data update
with torch.no_grad():
  model.eval()

  ## τ-correction
  residual = c.T @ model(x0, idx0)

  x = torch.load(PATH+'/preprocessed/X2.pt').to(device)
  idx = torch.load(PATH+'/preprocessed/G2.pt').to(device)

  residual = model(x, idx) - residual

  y = y0.clone()
  y = one_hot(y, dataset.num_classes)
  y[complem_idx, :] = torch.zeros_like(y[complem_idx, :])
  y = c.T @ y
  y = y + residual
  y = y.log_softmax(dim=-1)
  y = y.detach()

  mask = torch.nonzero(y.any(dim=1)).squeeze()

## training
for _ in range(100):
  model.train()
  optimizer.zero_grad()
  out = model(x, idx)
  loss = loss_nll(out[mask],y[mask].argmax(dim=1).long()) 
  loss.backward()
  optimizer.step()
'''

### Level 2 training on G1 (r=0.1)

In [None]:
# Normalized partition matrix of the first coarsening (original graph -> G1).
c = torch.load(PATH + '/preprocessed/C1.pt', map_location=device)

In [None]:
# Inferred labels are included to construct the training labels

## data update
with torch.no_grad():
    model.eval()

    # Prediction on the original graph. It feeds both the τ-correction and the inferred
    # labels, so the full-graph forward pass is run only once.
    fine_out = model(x0, idx0)

    x = torch.load(PATH+'/preprocessed/X1.pt').to(device)
    idx = torch.load(PATH+'/preprocessed/G1.pt').to(device)

    ## τ-correction: coarse-graph prediction minus the restricted fine-graph prediction
    residual = model(x, idx) - c.T @ fine_out

    ## update the label on coarsened graph
    y = y0.clone()
    y[complem_idx] = fine_out.argmax(dim=1)[complem_idx]  ## incorporate the inferred label
    del fine_out  # release the full-graph output before training
    y = one_hot(y, dataset.num_classes)
    y = c.T @ y
    y = y + residual  ## incorporate τ-correction
    y = y.log_softmax(dim=-1)

## training
# Hard targets for the supernodes are fixed during this stage, so compute them once.
target = y.argmax(dim=1)
for _ in range(100):
    model.train()
    optimizer.zero_grad()
    out = model(x, idx)
    loss = loss_nll(out, target)
    loss.backward()
    optimizer.step()


# NOTE: the raw string below is never executed. It stores an alternative version of the
# data-update/training code above that drops inferred labels and trains only on
# supernodes built from training nodes. Because it is the cell's last expression,
# Jupyter echoes it as the cell output.
# NOTE(review): in that variant `mask` is computed from `y` after the τ-correction and
# log_softmax, so no row is all-zero and the mask selects every supernode. To keep only
# supernodes that contain training nodes, take the mask from `c.T @ y` before adding
# the residual.
r'''
# If you want to discard the inferred label and dismiss all non-training nodes...

## data update
with torch.no_grad():
  model.eval()

  ## τ-correction
  residual = c.T @ model(x0, idx0)

  x = torch.load(PATH+'/preprocessed/X1.pt').to(device)
  idx = torch.load(PATH+'/preprocessed/G1.pt').to(device)

  residual = model(x, idx) - residual

  y = y0.clone()
  y = one_hot(y, dataset.num_classes)
  y[complem_idx, :] = torch.zeros_like(y[complem_idx, :]) ## enforce zero in non-training nodes
  y = c.T @ y
  y = y + residual
  y = y.log_softmax(dim=-1)
  y = y.detach()

  mask = torch.nonzero(y.any(dim=1)).squeeze() ## dismiss the zero vectors, only non-zero vectors are considered as training set

## training
for _ in range(100):
  model.train()
  optimizer.zero_grad()
  out = model(x, idx)
  loss = loss_nll(out[mask],y[mask].argmax(dim=1).long())
  loss.backward()
  optimizer.step()
'''

### Level 1 training on the original graph

In [None]:
# Back on the original graph: fresh copies of its tensors.
x, idx, y = x0.clone(), idx0.clone(), y0.clone()

## training
for _ in range(50):
    model.train()
    optimizer.zero_grad()
    # Full-graph forward pass; the loss only sees the training nodes.
    out = model(x, idx)[train_idx]
    loss = loss_nll(out, y[train_idx])
    loss.backward()
    optimizer.step()

### Testing

In [None]:
with torch.no_grad():
    model.eval()

    # Predicted class = index of the largest log-probability of each node.
    pred = model(x0, idx0).max(dim=1).indices
    # Fraction of test nodes whose prediction matches the ground-truth label.
    test_acc = (
        int(pred[test_idx].eq(y0[test_idx]).sum().item())
        / int(test_idx.shape[0])
    )
    print(f'The testing accuracy is {test_acc}')