diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 001d983bc..5b46c2f0f 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -115,7 +115,7 @@ steps: <<: *timeout key: "test-notebooks" depends_on: "runner-3_6" - parallelism: 51 + parallelism: 52 command: ".buildkite/steps/test-demo-notebooks.sh" plugins: <<: *plugins diff --git a/demos/embeddings/deep-graph-infomax-embeddings.ipynb b/demos/embeddings/deep-graph-infomax-embeddings.ipynb index 3395a0044..352d874ca 100644 --- a/demos/embeddings/deep-graph-infomax-embeddings.ipynb +++ b/demos/embeddings/deep-graph-infomax-embeddings.ipynb @@ -25,7 +25,9 @@ "source": [ "This demo demonstrates how to perform unsupervised training of a GCN, GAT, APPNP, or GraphSAGE model using the Deep Graph Infomax algorithm (https://arxiv.org/pdf/1809.10341.pdf) on the CORA dataset. \n", "\n", - "As with all StellarGraph workflows: first we load the dataset, next we create our data generators, and then we train our model. We then take the embeddings created through unsupervised training and predict the node classes using logistic regression." + "As with all StellarGraph workflows: first we load the dataset, next we create our data generators, and then we train our model. We then take the embeddings created through unsupervised training and predict the node classes using logistic regression.\n", + "\n", + "> See [the GCN + Deep Graph Infomax fine-tuning demo](../node-classification/gcn-deep-graph-infomax-fine-tuning-node-classification.ipynb) for semi-supervised training using Deep Graph Infomax, by fine-tuning the base model for node classification using labelled data." ] }, { @@ -603,6 +605,17 @@ ")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Conclusion\n", + "\n", + "This notebook demonstrated how to use the Deep Graph Infomax algorithm to train other algorithms to yield useful embedding vectors for nodes, without supervision. 
To validate the quality of these vectors, it used logistic regression to perform a supervised node classification task.\n", + "\n", + "See [the GCN + Deep Graph Infomax fine-tuning demo](../node-classification/gcn-deep-graph-infomax-fine-tuning-node-classification.ipynb) for semi-supervised training using Deep Graph Infomax, by fine-tuning the base model for node classification using labelled data." + ] + }, { "cell_type": "markdown", "metadata": { diff --git a/demos/node-classification/README.md b/demos/node-classification/README.md index 3d5826d7f..00d954341 100644 --- a/demos/node-classification/README.md +++ b/demos/node-classification/README.md @@ -11,6 +11,7 @@ These demos are displayed with detailed descriptions in the documentation: https | [Node classification with Cluster-GCN](https://stellargraph.readthedocs.io/en/stable/demos/node-classification/cluster-gcn-node-classification.html) | [source](cluster-gcn-node-classification.ipynb) | | [Node classification with directed GraphSAGE](https://stellargraph.readthedocs.io/en/stable/demos/node-classification/directed-graphsage-node-classification.html) | [source](directed-graphsage-node-classification.ipynb) | | [Node classification with Graph ATtention Network (GAT)](https://stellargraph.readthedocs.io/en/stable/demos/node-classification/gat-node-classification.html) | [source](gat-node-classification.ipynb) | +| [Semi-supervised node classification via GCN, Deep Graph Infomax and fine-tuning](https://stellargraph.readthedocs.io/en/stable/demos/node-classification/gcn-deep-graph-infomax-fine-tuning-node-classification.html) | [source](gcn-deep-graph-infomax-fine-tuning-node-classification.ipynb) | | [Node classification with Graph Convolutional Network (GCN)](https://stellargraph.readthedocs.io/en/stable/demos/node-classification/gcn-node-classification.html) | [source](gcn-node-classification.ipynb) | | [(Moved) Node classification with Graph Convolutional Network 
(GCN)](https://stellargraph.readthedocs.io/en/stable/demos/node-classification/gcn/gcn-cora-node-classification-example.html) | [source](gcn/gcn-cora-node-classification-example.ipynb) | | [Inductive node classification and representation learning using GraphSAGE](https://stellargraph.readthedocs.io/en/stable/demos/node-classification/graphsage-inductive-node-classification.html) | [source](graphsage-inductive-node-classification.ipynb) | diff --git a/demos/node-classification/gcn-deep-graph-infomax-fine-tuning-node-classification.ipynb b/demos/node-classification/gcn-deep-graph-infomax-fine-tuning-node-classification.ipynb new file mode 100644 index 000000000..46181ea56 --- /dev/null +++ b/demos/node-classification/gcn-deep-graph-infomax-fine-tuning-node-classification.ipynb @@ -0,0 +1,858 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Semi-supervised node classification via GCN, Deep Graph Infomax and fine-tuning" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "nbsphinx": "hidden", + "tags": [ + "CloudRunner" + ] + }, + "source": [ + "
Run the latest release of this notebook:
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This demo demonstrates how to perform semi-supervised node classification, using [the Deep Graph Infomax algorithm](https://arxiv.org/pdf/1809.10341.pdf) and GCN on the Cora dataset. It uses very few labelled training examples, demonstrating the benefits of pre-training a model with Deep Graph Infomax for data scarce environments.\n", + "\n", + "> Other related demos:\n", + "> - [the GCN node classification demo](gcn-node-classification.ipynb) describes the node classification task in more detail, in a supervised context\n", + "> - [the Deep Graph Infomax embeddings demo](../embeddings/deep-graph-infomax-embeddings.ipynb) describes using Deep Graph Infomax in more detail, including applying to algorithms beyond GCN.\n", + "\n", + "This follows the usual StellarGraph workflow:\n", + "\n", + "1. load the dataset\n", + "2. create our data generators\n", + "3. train our model\n", + "\n", + "We do step 3 three times:\n", + "\n", + "1. Pre-train a GCN model using Deep Graph Infomax, without any labelled data\n", + "2. Fine-tune that GCN model using the small training set\n", + "3. Train a fresh GCN model from scratch with the training set (no pre-training)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "nbsphinx": "hidden", + "tags": [ + "CloudRunner" + ] + }, + "outputs": [], + "source": [ + "# install StellarGraph if running on Google Colab\n", + "import sys\n", + "if 'google.colab' in sys.modules:\n", + " %pip install -q stellargraph[demos]==1.1.0b" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "nbsphinx": "hidden", + "tags": [ + "VersionCheck" + ] + }, + "outputs": [], + "source": [ + "# verify that we're using the correct version of StellarGraph for this notebook\n", + "import stellargraph as sg\n", + "\n", + "try:\n", + " sg.utils.validate_notebook_version(\"1.1.0b\")\n", + "except AttributeError:\n", + " raise ValueError(\n", + " f\"This notebook requires StellarGraph version 1.1.0b, but a different version {sg.__version__} is installed. Please see .\"\n", + " ) from None" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import stellargraph as sg\n", + "from stellargraph.mapper import CorruptedGenerator, FullBatchNodeGenerator\n", + "from stellargraph.layer import GCN, DeepGraphInfomax\n", + "\n", + "import pandas as pd\n", + "from sklearn import model_selection, preprocessing\n", + "from IPython.display import display, HTML\n", + "\n", + "import tensorflow as tf\n", + "from tensorflow.keras import Model, layers, optimizers, callbacks" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Loading the graph" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [ + "DataLoadingLinks" + ] + }, + "source": [ + "(See [the \"Loading from Pandas\" demo](../basics/loading-pandas.ipynb) for details on how data can be loaded.)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "tags": [ + "DataLoading" + ] + }, + "outputs": [ + { + "data": { + "text/html": [ + "The Cora dataset consists of 2708 scientific publications classified 
into one of seven classes. The citation network consists of 5429 links. Each publication in the dataset is described by a 0/1-valued word vector indicating the absence/presence of the corresponding word from the dictionary. The dictionary consists of 1433 unique words." + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "dataset = sg.datasets.Cora()\n", + "display(HTML(dataset.description))\n", + "G, node_classes = dataset.load()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "StellarGraph: Undirected multigraph\n", + " Nodes: 2708, Edges: 5429\n", + "\n", + " Node types:\n", + " paper: [2708]\n", + " Features: float32 vector, length 1433\n", + " Edge types: paper-cites->paper\n", + "\n", + " Edge types:\n", + " paper-cites->paper: [5429]\n", + " Weights: all 1 (default)\n" + ] + } + ], + "source": [ + "print(G.info())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Data Generators\n", + "\n", + "Now we create the data generators using `CorruptedGenerator` ([docs](https://stellargraph.readthedocs.io/en/stable/api.html#stellargraph.mapper.CorruptedGenerator)). `CorruptedGenerator` returns shuffled node features along with the regular node features and we train our model to discriminate between the two. 
\n", + "\n", + "Note that:\n", + "\n", + "- We typically pass all nodes to `corrupted_generator.flow` because this is an unsupervised task\n", + "- We don't pass `targets` to `corrupted_generator.flow` because these are binary labels (true nodes, false nodes) that are created by `CorruptedGenerator`" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using GCN (local pooling) filters...\n" + ] + } + ], + "source": [ + "fullbatch_generator = FullBatchNodeGenerator(G)\n", + "\n", + "corrupted_generator = CorruptedGenerator(fullbatch_generator)\n", + "gen = corrupted_generator.flow(G.nodes())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model pre-training with Deep Graph Infomax\n", + "\n", + "We create and train our `GCN` ([docs](https://stellargraph.readthedocs.io/en/stable/api.html#stellargraph.layer.GCN)) and `DeepGraphInfomax` ([docs](https://stellargraph.readthedocs.io/en/stable/api.html#stellargraph.layer.DeepGraphInfomax)) models. Note that the loss used here must always be `tf.nn.sigmoid_cross_entropy_with_logits`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "def make_gcn_model():\n", + " # function because we want to create a second one with the same parameters later\n", + " return GCN(\n", + " layer_sizes=[16, 16],\n", + " activations=[\"relu\", \"relu\"],\n", + " generator=fullbatch_generator,\n", + " dropout=0.4,\n", + " )\n", + "\n", + "\n", + "pretrained_gcn_model = make_gcn_model()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "infomax = DeepGraphInfomax(pretrained_gcn_model, corrupted_generator)\n", + "x_in, x_out = infomax.in_out_tensors()\n", + "\n", + "dgi_model = Model(inputs=x_in, outputs=x_out)\n", + "dgi_model.compile(\n", + " loss=tf.nn.sigmoid_cross_entropy_with_logits, optimizer=optimizers.Adam(lr=1e-3)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "tags": [ + "parameters" + ] + }, + "outputs": [], + "source": [ + "epochs = 500" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "dgi_es = callbacks.EarlyStopping(monitor=\"loss\", patience=50, restore_best_weights=True)\n", + "dgi_history = dgi_model.fit(gen, epochs=epochs, verbose=0, callbacks=[dgi_es])" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "sg.utils.plot_history(dgi_history)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Node classification\n", + "\n", + "We've now initialised the weights of the model to capture useful properties of the graph structure and node structure. We can now further train the model to perform a node classification prediction task. To emphasise the value of the unsupervised weights, we will use a very small amount of labelled data for training.\n", + "\n", + "> See [the GCN node classification demo](gcn-node-classification.ipynb) for more details on this task.\n", + "\n", + "### Data preparation\n", + "\n", + "The Cora dataset labels academic papers into one of 7 subjects:" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
<div>\n", + "<table>\n", + "  <thead>\n", + "    <tr><th></th><th>subject</th></tr>\n", + "  </thead>\n", + "  <tbody>\n", + "    <tr><th>Neural_Networks</th><td>818</td></tr>\n", + "    <tr><th>Probabilistic_Methods</th><td>426</td></tr>\n", + "    <tr><th>Genetic_Algorithms</th><td>418</td></tr>\n", + "    <tr><th>Theory</th><td>351</td></tr>\n", + "    <tr><th>Case_Based</th><td>298</td></tr>\n", + "    <tr><th>Reinforcement_Learning</th><td>217</td></tr>\n", + "    <tr><th>Rule_Learning</th><td>180</td></tr>\n", + "  </tbody>\n", + "</table>\n", + "</div>
" + ], + "text/plain": [ + " subject\n", + "Neural_Networks 818\n", + "Probabilistic_Methods 426\n", + "Genetic_Algorithms 418\n", + "Theory 351\n", + "Case_Based 298\n", + "Reinforcement_Learning 217\n", + "Rule_Learning 180" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "node_classes.value_counts().to_frame()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To simulate a data-poor environment, we will split the data into a train set of size 8, along with test and validation sets." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "train_classes, test_classes = model_selection.train_test_split(\n", + " node_classes, train_size=8, stratify=node_classes, random_state=1\n", + ")\n", + "val_classes, test_classes = model_selection.train_test_split(\n", + " test_classes, train_size=500, stratify=test_classes\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The train set has only one or two observations of each class." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
<div>\n", + "<table>\n", + "  <thead>\n", + "    <tr><th></th><th>subject</th></tr>\n", + "  </thead>\n", + "  <tbody>\n", + "    <tr><th>Neural_Networks</th><td>2</td></tr>\n", + "    <tr><th>Probabilistic_Methods</th><td>1</td></tr>\n", + "    <tr><th>Rule_Learning</th><td>1</td></tr>\n", + "    <tr><th>Reinforcement_Learning</th><td>1</td></tr>\n", + "    <tr><th>Genetic_Algorithms</th><td>1</td></tr>\n", + "    <tr><th>Theory</th><td>1</td></tr>\n", + "    <tr><th>Case_Based</th><td>1</td></tr>\n", + "  </tbody>\n", + "</table>\n", + "</div>
" + ], + "text/plain": [ + " subject\n", + "Neural_Networks 2\n", + "Probabilistic_Methods 1\n", + "Rule_Learning 1\n", + "Reinforcement_Learning 1\n", + "Genetic_Algorithms 1\n", + "Theory 1\n", + "Case_Based 1" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_classes.value_counts().to_frame()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For a categorical task, the categories need to be one hot encoded." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "target_encoding = preprocessing.LabelBinarizer()\n", + "\n", + "train_targets = target_encoding.fit_transform(train_classes)\n", + "val_targets = target_encoding.transform(val_classes)\n", + "test_targets = target_encoding.transform(test_classes)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "train_gen = fullbatch_generator.flow(train_classes.index, train_targets)\n", + "test_gen = fullbatch_generator.flow(test_classes.index, test_targets)\n", + "val_gen = fullbatch_generator.flow(val_classes.index, val_targets)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Fine-tuning model\n", + "\n", + "We now have the required pieces to finalise our GCN model for node classification:\n", + "\n", + "- a GCN model with weights pre-trained with Deep Graph Infomax to capture the graph structure\n", + "- a small train set\n", + "\n", + "We use the same GCN model as before but train it for a supervised categorical prediction task. See [the fully-supervised GCN node classification](gcn-node-classification.ipynb) demo for more details." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "pretrained_x_in, pretrained_x_out = pretrained_gcn_model.in_out_tensors()\n", + "\n", + "pretrained_predictions = tf.keras.layers.Dense(\n", + " units=train_targets.shape[1], activation=\"softmax\"\n", + ")(pretrained_x_out)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "pretrained_model = Model(inputs=pretrained_x_in, outputs=pretrained_predictions)\n", + "pretrained_model.compile(\n", + " optimizer=optimizers.Adam(lr=0.01), loss=\"categorical_crossentropy\", metrics=[\"acc\"],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "prediction_es = callbacks.EarlyStopping(\n", + " monitor=\"val_acc\", patience=50, restore_best_weights=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "pretrained_history = pretrained_model.fit(\n", + " train_gen,\n", + " epochs=epochs,\n", + " verbose=0,\n", + " validation_data=val_gen,\n", + " callbacks=[prediction_es],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "sg.utils.plot_history(pretrained_history)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We've now fine-tuned our model for node classification. Observe that the accuracy in the first few epochs was very poor, but it quickly improved. (The train accuracy plot is quantised because the training set is so small.)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1/1 [==============================] - 0s 920us/step - loss: 1.5896 - acc: 0.5632\n", + "{'loss': 1.5896106958389282, 'acc': 0.5631818175315857}\n" + ] + } + ], + "source": [ + "pretrained_test_metrics = dict(\n", + " zip(pretrained_model.metrics_names, pretrained_model.evaluate(test_gen))\n", + ")\n", + "print(pretrained_test_metrics)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model without Deep Graph Infomax pre-training\n", + "\n", + "Let's also train an equivalent GCN model in a fully supervised manner, starting with the " + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "direct_gcn_model = make_gcn_model()\n", + "direct_x_in, direct_x_out = direct_gcn_model.in_out_tensors()\n", + "direct_predictions = tf.keras.layers.Dense(\n", + " units=train_targets.shape[1], activation=\"softmax\"\n", + ")(direct_x_out)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "direct_model = Model(inputs=direct_x_in, outputs=direct_predictions)\n", + "direct_model.compile(\n", + " optimizer=optimizers.Adam(lr=0.01), loss=\"categorical_crossentropy\", metrics=[\"acc\"],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "direct_history = 
direct_model.fit(\n", + " train_gen,\n", + " epochs=epochs,\n", + " verbose=0,\n", + " validation_data=val_gen,\n", + " callbacks=[prediction_es],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "sg.utils.plot_history(direct_history)" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1/1 [==============================] - 0s 946us/step - loss: 2.0196 - acc: 0.4559\n", + "{'loss': 2.0196211338043213, 'acc': 0.4559091031551361}\n" + ] + } + ], + "source": [ + "direct_test_metrics = dict(\n", + " zip(direct_model.metrics_names, direct_model.evaluate(test_gen))\n", + ")\n", + "print(direct_test_metrics)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Comparison of model performance\n", + "\n", + "The following table shows the performance of the two models, for comparison." + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
<div>\n", + "<table>\n", + "  <thead>\n", + "    <tr><th></th><th>loss</th><th>acc</th></tr>\n", + "  </thead>\n", + "  <tbody>\n", + "    <tr><th>with DGI pre-training</th><td>1.59</td><td>0.563</td></tr>\n", + "    <tr><th>without pre-training</th><td>2.02</td><td>0.456</td></tr>\n", + "  </tbody>\n", + "</table>\n", + "</div>
" + ], + "text/plain": [ + " loss acc\n", + "with DGI pre-training 1.59 0.563\n", + "without pre-training 2.02 0.456" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pd.DataFrame(\n", + " [pretrained_test_metrics, direct_test_metrics],\n", + " index=[\"with DGI pre-training\", \"without pre-training\"],\n", + ").round(3)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Conclusion\n", + "\n", + "In this demo, we performed semi-supervised node classification on the Cora dataset. This example had extreme data scarcity: only 8 labelled training examples, with one or two from each of the 7 classes. We used Deep Graph Infomax to train a GCN model on the whole Cora graph, without labels. We then further trained this GCN model in the normal manner, to fine-tuned its weights on the small set of labelled data. The GCN model pre-trained with Deep Graph Infomax outperforms a GCN model without any such pre-training." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "nbsphinx": "hidden", + "tags": [ + "CloudRunner" + ] + }, + "source": [ + "
Run the latest release of this notebook:
" + ] + } + ], + "metadata": { + "celltoolbar": "Tags", + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/demos/node-classification/gcn-node-classification.ipynb b/demos/node-classification/gcn-node-classification.ipynb index 654ddac7b..b34291b04 100644 --- a/demos/node-classification/gcn-node-classification.ipynb +++ b/demos/node-classification/gcn-node-classification.ipynb @@ -1333,6 +1333,8 @@ "2. built a TensorFlow Keras model and data generator with [the StellarGraph library](https://github.com/stellargraph/stellargraph) \n", "3. trained and evaluated it using TensorFlow and other libraries\n", "\n", + "For problems with only small amounts of labelled data, model performance can be improved by semi-supervised training. See [the GCN + Deep Graph Infomax fine-tuning demo](gcn-deep-graph-infomax-fine-tuning-node-classification.ipynb) for more details on how to do this.\n", + "\n", "StellarGraph includes [other algorithms for node classification](README.md) and [algorithms and demos for other tasks](../README.md). Most can be applied with the same basic structure as this GCN demo." 
] }, diff --git a/docs/demos/node-classification/gcn-deep-graph-infomax-fine-tuning-node-classification.nblink b/docs/demos/node-classification/gcn-deep-graph-infomax-fine-tuning-node-classification.nblink new file mode 100644 index 000000000..c89920100 --- /dev/null +++ b/docs/demos/node-classification/gcn-deep-graph-infomax-fine-tuning-node-classification.nblink @@ -0,0 +1,3 @@ +{ + "path": "../../../demos/node-classification/gcn-deep-graph-infomax-fine-tuning-node-classification.ipynb" +} diff --git a/docs/demos/node-classification/index.rst b/docs/demos/node-classification/index.rst index c342b5cbf..8b55b76f1 100644 --- a/docs/demos/node-classification/index.rst +++ b/docs/demos/node-classification/index.rst @@ -119,6 +119,14 @@ This table lists all node classification demos, including the algorithms trained - yes - - yes + * - :doc:`GCN, Deep Graph Infomax and fine-tuning ` + - GCN, DeepGraphInfomax, semi-supervised training + - yes + - + - + - + - + - yes See :doc:`the root README <../../README>` or each algorithm's documentation for the relevant citation(s). See :doc:`the demo index <../index>` for more tasks, and a summary of each algorithm.
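
Note on the pre-training loss: the new notebook compiles the Deep Graph Infomax model with `tf.nn.sigmoid_cross_entropy_with_logits`, and says `CorruptedGenerator` supplies binary labels (true nodes, false nodes). The sketch below is a pure-Python illustration of that loss, not StellarGraph or TensorFlow code. It uses the numerically stable form `max(x, 0) - x * z + log(1 + exp(-|x|))` given in the TensorFlow documentation. The helper `dgi_style_loss`, the example logits, and the choice to average over nodes are assumptions made for illustration only.

```python
import math


def sigmoid_cross_entropy_with_logits(label, logit):
    """Binary cross-entropy on a raw logit, in the numerically stable form
    max(x, 0) - x * z + log(1 + exp(-|x|))."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def dgi_style_loss(true_logits, corrupted_logits):
    """Mean loss when true nodes are labelled 1 and corrupted nodes 0.

    Hypothetical helper: averaging over nodes is an assumption for this sketch.
    """
    losses = [sigmoid_cross_entropy_with_logits(1.0, x) for x in true_logits]
    losses += [sigmoid_cross_entropy_with_logits(0.0, x) for x in corrupted_logits]
    return sum(losses) / len(losses)


# Made-up discriminator scores, for illustration only.
good = dgi_style_loss(true_logits=[4.0, 3.0], corrupted_logits=[-4.0, -3.0])
bad = dgi_style_loss(true_logits=[-4.0, -3.0], corrupted_logits=[4.0, 3.0])
print(f"well-separated: {good:.4f}, inverted: {bad:.4f}")
```

A discriminator that scores true nodes high and corrupted nodes low gets a much smaller loss than one that does the opposite, which is the signal the pre-training step relies on.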