by Patrick Pho (Phuong Pho) and Alexander V. Mantzaris
This repo is the official implementation of Regularized Simple Graph Convolution (SGC) from our paper, "Regularized Simple Graph Convolution (SGC) for improved interpretability of large datasets".
We incorporate a flexible regularization scheme into SGC that applies shrinkage to three different aspects of the model's weight vectors.
If you find this repo useful, please cite:
@article{pho2020regularized,
title={Regularized Simple Graph Convolution (SGC) for improved interpretability of large datasets},
author={Pho, Phuong and Mantzaris, Alexander V},
journal={Journal of Big Data},
volume={7},
number={1},
pages={1--17},
year={2020},
publisher={Springer}
}
The dependencies can be installed via:
pip install -r requirement.txt
For GPU machines, please refer to the official instructions to install suitable versions of pytorch and dgl.
Two synthetic datasets discussed in our paper can be found in data/.
An example of incorporating regularization through the --L1, --L2 and --L3 options:
python main.py --dataset cora --L1 0.5 --L2 1 --L3 2
Use --L3-ortho to impose an orthogonality constraint between the weight vectors:
python main.py --dataset cora --L1 0.5 --L2 1 --L3 2 --L3-ortho
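One common way to express an orthogonality constraint as a penalty is to sum the squared inner products of every pair of weight vectors. The sketch below illustrates that idea in plain Python. It is an assumption about the general technique, not the code used in main.py, and the function names and example vectors are hypothetical; the penalty in this repo may be defined differently.

```python
from itertools import combinations


def dot(u, v):
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def orthogonality_penalty(weight_vectors):
    """Sum of squared inner products over all distinct pairs of vectors.

    The value is zero exactly when every pair is orthogonal, so adding it
    to a training loss discourages overlapping weight vectors.
    """
    return sum(dot(u, v) ** 2 for u, v in combinations(weight_vectors, 2))


# Hypothetical weight vectors, one per class.
orthogonal = [[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]]
overlapping = [[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]]

print(orthogonality_penalty(orthogonal))   # 0.0
print(orthogonality_penalty(overlapping))  # 1.0
```

In a training loop this quantity would typically be computed with tensor operations and scaled by a coefficient before being added to the loss.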
Use --save-trained to save the trained model for inference. The trained model is saved in ./checkpoints:
python main.py --dataset cora --L1 0.5 --L2 1 --L3 2 --L3-ortho --save-trained
Other useful options for training:
- --early-stop: turn on early stopping to reduce overfitting. The default metric is loss.
- --hist-print: print the training history every t epochs.
- --plot: plot option to use with the synthetic datasets.
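For readers unfamiliar with early stopping on a loss metric, here is a minimal sketch of the general idea: training halts once the monitored loss has failed to improve for a fixed number of epochs. The class name and the `patience` and `min_delta` parameters are illustrative assumptions, not options exposed by main.py, and the loss values are made up.

```python
class EarlyStopping:
    """Stop training once the monitored loss stops improving.

    `patience` and `min_delta` are illustrative parameters, not options
    taken from main.py.
    """

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.bad_epochs = 0

    def step(self, loss):
        """Record one epoch's loss; return True when training should stop."""
        if loss < self.best_loss - self.min_delta:
            self.best_loss = loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Made-up validation losses: improvement stalls after the third epoch.
losses = [0.90, 0.70, 0.60, 0.61, 0.62, 0.60, 0.59]
stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, loss in enumerate(losses):
    if stopper.step(loss):
        stopped_at = epoch
        break
```

With these values the loop stops at epoch index 5, because the loss has not dropped below its best value of 0.60 for three consecutive epochs.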
We provide predict.py for users to make predictions on a custom dataset. Before running it, you will need to:
- Import your dataset as .dgl in ./data. Note that you need to include a pred_mask boolean tensor in order to make predictions for unlabeled nodes. Without pred_mask, predictions are made on the test_mask nodes.
- Train the model on your custom dataset and save it. The model's name is pre-defined as `{dataset}+_sgc+k_{k_value}+L1_{L1_value}+L2_{L2_value}+L3_{L3_value}.pt`
You can then obtain the predictions for unlabeled nodes by running:
python predict.py --modelpath ./checkpoints/model_name.pt
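The pred_mask fallback described above can be sketched as follows. This is an illustration under stated assumptions rather than the code in predict.py: a plain dictionary of boolean lists stands in for the node data of a DGL graph, and the helper names and scores are hypothetical.

```python
def select_prediction_mask(node_data):
    """Return the mask of nodes to predict on.

    Prefers "pred_mask" and falls back to "test_mask" when it is absent.
    """
    if "pred_mask" in node_data:
        return node_data["pred_mask"]
    return node_data["test_mask"]


def predict_masked(scores, mask):
    """Return {node_id: predicted_class} for the nodes selected by mask."""
    return {
        node: max(range(len(row)), key=row.__getitem__)
        for node, (row, keep) in enumerate(zip(scores, mask))
        if keep
    }


# Made-up per-node class scores for a four-node graph.
scores = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7], [0.6, 0.4]]

with_pred = {"test_mask": [True, False, False, False],
             "pred_mask": [False, False, True, True]}
without_pred = {"test_mask": [True, False, False, False]}

print(predict_masked(scores, select_prediction_mask(with_pred)))     # {2: 1, 3: 0}
print(predict_masked(scores, select_prediction_mask(without_pred)))  # {0: 1}
```

When pred_mask is present only the nodes it marks receive predictions; otherwise the test_mask nodes are used.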