Official implementation of "Towards Graph Foundation Models: Learning Generalities Across Graphs via Task-Trees" accepted at ICML 2025.
Authors: Zehong Wang, Zheyuan Zhang, Tianyi Ma, Nitesh V Chawla, Chuxu Zhang, Yanfang Ye
Graph-structured data is everywhere---from social networks to molecular structures---but building general-purpose models for graphs has been difficult due to the wide variety of graph types and tasks. Inspired by the success of foundation models in text and vision, this work introduces a new approach to generalize across different graph tasks using a concept called "task-trees."
- Task-Trees: A unified structure that captures the essential parts of a graph relevant to a specific task and unifies different types of graph tasks (node, edge, graph-level) into a common format.
- Theoretical Foundation: We provide theoretical analysis on the stability, transferability, and generalization properties of task-trees.
- GIT Model: A graph foundation model pretrained on task-trees from diverse graphs, demonstrating strong performance across 30+ datasets in five domains.
- Multiple Learning Paradigms: Support for fine-tuning, few-shot learning, in-context learning, and zero-shot generalization.
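As a rough illustration of the task-tree idea (hypothetical code, not the repository's API): node-, edge-, and graph-level tasks can all be cast as sets of rooted trees built from the ego-networks of one or more root nodes, which is what gives the three task levels a common format.

```python
from collections import deque

def ego_tree(adj, root, depth):
    """BFS tree of `depth` hops rooted at `root` -- a simple stand-in
    for the task-tree induced by a GNN's computation tree."""
    tree = {root: []}
    frontier, seen = deque([(root, 0)]), {root}
    while frontier:
        node, d = frontier.popleft()
        if d == depth:
            continue
        for nb in adj.get(node, []):
            if nb not in seen:
                seen.add(nb)
                tree.setdefault(node, []).append(nb)
                frontier.append((nb, d + 1))
    return tree

def task_trees(adj, task, targets, depth=2):
    """Unify node/edge/graph tasks as rooted trees:
    node  -> one tree per target node
    edge  -> one tree per endpoint of each target edge
    graph -> one tree per node (pooled downstream)"""
    if task == "node":
        roots = targets
    elif task == "edge":
        roots = [v for e in targets for v in e]
    else:  # "graph": targets unused, every node is a root
        roots = list(adj)
    return [ego_tree(adj, r, depth) for r in roots]

adj = {0: [1, 2], 1: [0, 3], 2: [0], 3: [1]}
node_trees = task_trees(adj, "node", [0], depth=1)  # one tree rooted at node 0
```

Function and variable names here are illustrative; see the paper for the formal definition of task-trees.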
GIT employs a three-stage training paradigm:
- Pretraining: Learning generalizable patterns from diverse graphs via task-tree reconstruction
- Supervised Fine-Tuning (SFT): Domain-specific adaptation with supervised data
- Task Fine-Tuning: Final adaptation to downstream tasks with minimal data
- Python 3.8+
- PyTorch 2.1.0+
- PyTorch Geometric 2.5.3
- CUDA 11.8 (recommended)
```bash
# Clone the repository
git clone https://github.com/YourUsername/GIT.git
cd GIT

# Create a virtual environment (optional but recommended)
conda env create -f environment.yml
conda activate GIT
```

Download the data from Google Drive (TODO: add the link) and place it in the `cache_data/` directory.
Train a foundation model from scratch on diverse graphs:
```bash
# Basic pretraining command
python pretrain.py \
    --pretrain_dataset default \
    --lr 0.0000001 \
    --epochs 10 \
    --feat_p 0.2 \
    --edge_p 0.2 \
    --align_reg_lambda 10.0 \
    --hidden_dim 512 \
    --num_layers 4 \
    --backbone gcn \
    --seed 42
```

Key Hyperparameters:
- `--pretrain_dataset`: Pretraining data source
- `--lr`: Learning rate for pretraining
- `--epochs`: Number of pretraining epochs
- `--feat_p`: Feature masking probability
- `--edge_p`: Edge dropout probability
- `--align_reg_lambda`: Weight for alignment regularization
- `--hidden_dim`: Hidden dimension of the GNN encoder
- `--num_layers`: Number of GNN layers
- `--backbone`: GNN backbone architecture (`gcn`, `gat`, `gin`, `sage`)
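The `--feat_p` and `--edge_p` flags control standard graph augmentations. A minimal sketch of what such masking typically looks like (illustrative only; the repository's actual implementation operates on tensors and may differ):

```python
import random

def mask_features(x, feat_p, rng):
    """Zero out each feature entry independently with probability `feat_p`."""
    return [[0.0 if rng.random() < feat_p else v for v in row] for row in x]

def drop_edges(edges, edge_p, rng):
    """Keep each edge independently with probability 1 - `edge_p`."""
    return [e for e in edges if rng.random() >= edge_p]

rng = random.Random(42)
x = [[1.0, 2.0], [3.0, 4.0]]          # node feature matrix
edges = [(0, 1), (1, 0)]              # edge list
x_aug = mask_features(x, 0.2, rng)    # corrupted features for reconstruction
edges_aug = drop_edges(edges, 0.2, rng)
```

The pretraining objective then reconstructs task-tree representations from these corrupted views.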
Perform domain-specific adaptation:
```bash
# SFT on citation domain
python sft.py \
    --dataset cora \
    --pretrain_dataset default \
    --use_params
```

```bash
# Fine-tune on node classification
python finetune.py \
    --dataset cora \
    --setting base \
    --pt_data default \
    --use_params
```

```bash
# Few-shot learning (1-shot for Cora, 5-shot for Arxiv)
python finetune.py \
    --dataset cora \
    --setting few_shot \
    --pt_data default \
    --n_train 1 \
    --n_way 5 \
    --n_shot 1 \
    --n_query 15 \
    --n_task 500 \
    --use_params
```

```bash
# Zero-shot transfer without any training
python finetune.py \
    --dataset cora \
    --setting zero_shot \
    --pt_data default \
    --n_way 5 \
    --n_shot 5 \
    --n_query 15 \
    --n_task 500 \
    --use_params
```

```bash
# In-context learning with examples in the context
python finetune.py \
    --dataset cora \
    --setting in_context \
    --pt_data default \
    --n_way 5 \
    --n_shot 5 \
    --n_query 15 \
    --n_task 500 \
    --use_params
```

GIT supports over 30 datasets across 5 domains:
Citation networks:
- Node Classification: Cora, CiteSeer, PubMed, DBLP, Arxiv, Arxiv23
- Link Prediction: Cora, CiteSeer, PubMed, DBLP, Arxiv, Arxiv23

E-commerce networks:
- Node Classification: BookHis, BookChild, EleComp, ElePhoto, SportsFit, AmazonRatings, Products
- Link Prediction: BookHis, BookChild, EleComp, ElePhoto, SportsFit, AmazonRatings

Knowledge and temporal graphs:
- Edge Classification: WN18RR, FB15K237, NELL995, CODEX-S, CODEX-M, CODEX-L, GDELT, ICEWS18-19, Enron, Googlemap_CT

Molecular graphs:
- Graph Classification: BBBP, BACE, ToxCast, Tox21, CYP450, ChemHIV, MUV, ChemPCBA
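The few-shot, zero-shot, and in-context flags above (`--n_way`, `--n_shot`, `--n_query`, `--n_task`) describe standard episodic evaluation. A hypothetical sketch of how such episodes are typically sampled (names are illustrative, not the repository's API):

```python
import random

def sample_episode(labels, n_way, n_shot, n_query, rng):
    """Sample one N-way episode: `n_shot` support and `n_query` query
    examples for each of `n_way` randomly chosen classes."""
    classes = rng.sample(sorted(labels), n_way)
    support, query = [], []
    for c in classes:
        picks = rng.sample(labels[c], n_shot + n_query)
        support += [(i, c) for i in picks[:n_shot]]
        query += [(i, c) for i in picks[n_shot:]]
    return support, query

# Toy label index: class -> node ids (enough examples per class for 1-shot/3-query)
labels = {c: list(range(c * 10, c * 10 + 5)) for c in range(7)}
rng = random.Random(0)
episodes = [sample_episode(labels, n_way=5, n_shot=1, n_query=3, rng=rng)
            for _ in range(500)]  # --n_task controls the number of episodes
```

Reported accuracy is then averaged over all sampled episodes.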
We use Weights & Biases for experiment tracking. To enable logging:
- Install wandb: `pip install wandb`
- Login: `wandb login`
- Run experiments with the `--debug` flag removed (or set to `False`)

To disable wandb logging, add the `--debug` flag to your command.
If you find this work useful, please cite our paper:
```bibtex
@inproceedings{wang2025towards,
  title={Towards Graph Foundation Models: Learning Generalities Across Graphs via Task-Trees},
  author={Zehong Wang and Zheyuan Zhang and Tianyi Ma and Nitesh V Chawla and Chuxu Zhang and Yanfang Ye},
  booktitle={Forty-second International Conference on Machine Learning},
  year={2025},
  url={https://openreview.net/forum?id=BSqf2k01ag}
}
```

For questions or feedback, please contact:
- Zehong Wang: zwang43@nd.edu
- Open an issue on GitHub
This repository is built upon the following excellent codebases:
- PyTorch Geometric: Graph neural network library
- OGB: Open Graph Benchmark
- OFA: One for All graph pretraining framework
We thank the authors for their great work and open-sourcing their code!