# Link prediction exercise: the PubMed-Diabetes citation dataset

In this example, we build a model to predict citation links in the PubMed-Diabetes dataset. The problem is treated as a supervised link prediction problem on a homogeneous citation network, where nodes represent papers (with attributes such as TF/IDF-weighted word vectors and a categorical subject) and links correspond to paper-paper citations.

To address this problem, we build a GraphSAGE [1] model that takes pairs of papers (paper1, paper2) corresponding to possible citation links and outputs a pair of node embeddings, one for each paper in the pair. These embeddings are then fed into a link classification layer, which applies a binary operator to the two node embeddings (e.g., elementwise multiplication) to construct the link embedding of the (paper1, paper2) pair. Next, a dense layer with a sigmoid activation function outputs, for each candidate link, the predicted probability that the link actually exists in the network.
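
As a minimal sketch of the link classification step, assuming the node embeddings have already been computed, the NumPy code below combines each pair of embeddings with the elementwise (Hadamard) product and passes the result through a single sigmoid unit. The function names, the toy embeddings, and the weights `w` and `b` are hypothetical placeholders for what the trained GraphSAGE model and dense layer would produce; this is not StellarGraph's `link_classification` implementation.

```python
import numpy as np

def hadamard_link_embedding(z_u: np.ndarray, z_v: np.ndarray) -> np.ndarray:
    """Combine two batches of node embeddings (shape [n, d]) into link
    embeddings (shape [n, d]) by elementwise multiplication."""
    return z_u * z_v

def sigmoid(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-x))

def link_probability(z_u, z_v, w, b):
    """Dense layer with one sigmoid output applied to each link embedding.
    w (shape [d]) and b (scalar) stand in for hypothetical trained parameters."""
    h = hadamard_link_embedding(z_u, z_v)  # [n, d] link embeddings
    return sigmoid(h @ w + b)              # [n] probabilities, one per candidate pair

# Toy example: two candidate (paper1, paper2) pairs with 3-dimensional embeddings.
z_paper1 = np.array([[1.0, 0.0, 2.0], [0.5, 0.5, 0.5]])
z_paper2 = np.array([[1.0, 1.0, 1.0], [-1.0, 0.0, 1.0]])
w = np.array([1.0, 1.0, 1.0])
b = 0.0
probs = link_probability(z_paper1, z_paper2, w, b)
```

Other binary operators (for example averaging or concatenating the two embeddings) slot into the same place as `hadamard_link_embedding`; elementwise multiplication is shown only because the text uses it as its example.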

The entire model is trained end-to-end by minimizing a loss function of choice (e.g., binary cross-entropy) with stochastic gradient descent (SGD) updates of the model parameters, computed on the 'training' node pairs fed into the model.
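
To make the training objective concrete, here is a hedged NumPy sketch of binary cross-entropy and plain mini-batch SGD for the sigmoid output layer alone, on synthetic link embeddings. In the real model the gradients also flow back into the GraphSAGE layers and Keras computes them automatically; the data, learning rate, batch size, and helper names here are illustrative assumptions.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def binary_cross_entropy(y_true, y_prob, eps=1e-7):
    """Mean BCE; probabilities are clipped away from 0 and 1 so log() stays finite."""
    y_prob = np.clip(y_prob, eps, 1.0 - eps)
    return float(-np.mean(y_true * np.log(y_prob) + (1.0 - y_true) * np.log(1.0 - y_prob)))

def sgd_step(h, y, w, b, lr):
    """One SGD update of a sigmoid output layer on a mini-batch of link embeddings h."""
    p = sigmoid(h @ w + b)
    grad_z = (p - y) / len(y)  # d(mean BCE)/d(logit) for a sigmoid output
    return w - lr * (h.T @ grad_z), b - lr * grad_z.sum()

# Synthetic, linearly separable "link embeddings" and 0/1 labels (illustration only).
rng = np.random.default_rng(0)
h_all = rng.normal(size=(256, 4))
y_all = (h_all @ np.array([2.0, -1.0, 0.5, 0.0]) > 0).astype(float)

w, b = np.zeros(4), 0.0
loss_before = binary_cross_entropy(y_all, sigmoid(h_all @ w + b))  # log(2) at initialization
for _ in range(300):
    batch = rng.choice(len(y_all), size=32, replace=False)  # random mini-batch
    w, b = sgd_step(h_all[batch], y_all[batch], w, b, lr=0.5)
loss_after = binary_cross_entropy(y_all, sigmoid(h_all @ w + b))
```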

**References**

[1] Inductive Representation Learning on Large Graphs, W. L. Hamilton, R. Ying, and J. Leskovec, NIPS 2017

Copyright © 2010-2019 Commonwealth Scientific and Industrial Research Organisation (CSIRO). All Rights Reserved.

```python
import itertools
import os
import keras
import networkx as nx
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import stellargraph as sg
from stellargraph.data import EdgeSplitter
from stellargraph.mapper import GraphSAGELinkGenerator
from stellargraph.layer import GraphSAGE, link_classification
from sklearn import preprocessing, feature_extraction, model_selection, metrics

%matplotlib inline
```



### Loading the PubMed Diabetes network data

**Downloading the dataset:**
    
The dataset used in this demo can be downloaded from https://linqs-data.soe.ucsc.edu/public/Pubmed-Diabetes.tgz

The following is the description of the dataset:

> The Pubmed Diabetes dataset consists of 19717 scientific publications from PubMed database 
> pertaining to diabetes classified into one of three classes. The citation network consists 
> of 44338 links. Each publication in the dataset is described by a TF/IDF weighted word 
> vector from a dictionary which consists of 500 unique words.

Download the `Pubmed-Diabetes.tgz` archive and extract it to a location on your computer.

Set the `data_dir` variable to point to the directory containing the extracted `Pubmed-Diabetes.DIRECTED.cites.tab` and `Pubmed-Diabetes.NODE.paper.tab` files.
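
If you prefer to fetch the archive from Python, the following standard-library sketch downloads it with `urllib.request.urlretrieve` and extracts it with `tarfile`. The helper name `download_and_extract` and any destination directory you pass are hypothetical; because the sketch assumes nothing about the archive's internal folder layout, it returns the paths of the extracted files so you can locate the `.tab` files and set `data_dir` accordingly.

```python
import os
import tarfile
import urllib.request

PUBMED_URL = "https://linqs-data.soe.ucsc.edu/public/Pubmed-Diabetes.tgz"

def download_and_extract(url: str, dest_dir: str) -> list:
    """Download a gzipped tar archive into dest_dir (skipped if already present),
    extract its regular files there, and return their paths."""
    os.makedirs(dest_dir, exist_ok=True)
    archive_path = os.path.join(dest_dir, os.path.basename(url))
    if not os.path.exists(archive_path):
        urllib.request.urlretrieve(url, archive_path)
    with tarfile.open(archive_path, "r:gz") as tar:
        # Only regular files; tarfile creates any parent directories they need.
        members = [m for m in tar.getmembers() if m.isfile()]
        tar.extractall(dest_dir, members=members)
    return [os.path.join(dest_dir, m.name) for m in members]
```

For example, `download_and_extract(PUBMED_URL, os.path.expanduser("~/data/pubmed"))` would download the archive once and unpack it under that (hypothetical) directory.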

```python
data_dir = "../data/pubmed/"
```

Load the citation edges as an edge list:

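```python
# Skip the two header lines; columns 1 and 3 hold the "paper:<id>" strings
# for the citing (source) and cited (target) papers.
edgelist = pd.read_csv(
    os.path.join(os.path.expanduser(data_dir), "Pubmed-Diabetes.DIRECTED.cites.tab"),
    sep="\t", skiprows=2, header=None, usecols=[1, 3], names=["source", "target"],
)
# Remove exactly the leading "paper:" prefix. str.lstrip("paper:") would strip
# any leading p/a/e/r/: characters instead of the prefix.
for col in ["source", "target"]:
    edgelist[col] = edgelist[col].str.replace(r"^paper:", "", regex=True)
```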
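
A note on stripping the `paper:` prefix: `str.lstrip("paper:")` treats its argument as a set of characters, not as a prefix, so it only works on these IDs because they are numeric. The short check below uses a made-up `"paper:apple"` string to show the difference, plus the column-wise pandas form with an anchored regular expression.

```python
import pandas as pd

numeric_id = "paper:19127292".lstrip("paper:")  # digits are not in the set {p, a, e, r, :}
word_id = "paper:apple".lstrip("paper:")        # every leading p/a/e/r/: is stripped
exact = "paper:apple".removeprefix("paper:")    # str.removeprefix needs Python 3.9+

# Column-wise version: the "^" anchor removes only a leading "paper:".
ids = pd.Series(["paper:19127292", "paper:apple"]).str.replace(r"^paper:", "", regex=True)
```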

```python
edgelist.head()
```

|   | source | target |
|---|--------|--------|
| 0 | 19127292 | 17363749 |
| 1 | 19668377 | 17293876 |
| 2 | 1313726 | 3002783 |
| 3 | 19110882 | 14578298 |
| 4 | 18606979 | 10333910 |
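
The edge list is not yet a graph. A minimal sketch of turning it into one with NetworkX (imported in the first cell) follows; it treats citations as undirected, which is a modeling assumption rather than something the dataset prescribes, so a pair of reciprocal citations collapses into a single edge. It also adds papers that appear in no citation as isolated nodes. The helper name and toy IDs are hypothetical; with the real data, `node_ids` could be `node_features.index` once the features are loaded.

```python
import networkx as nx
import pandas as pd

def build_citation_graph(edges: pd.DataFrame, node_ids) -> nx.Graph:
    """Build an undirected graph from a source/target edge list and make sure
    every ID in node_ids is present, even if it has no citations."""
    g = nx.from_pandas_edgelist(edges, source="source", target="target")
    g.add_nodes_from(node_ids)
    return g

# Toy edge list with string IDs, like the cleaned edgelist; ("2", "1") duplicates ("1", "2").
toy_edges = pd.DataFrame({"source": ["1", "2", "3", "2"], "target": ["2", "3", "1", "1"]})
g = build_citation_graph(toy_edges, ["1", "2", "3", "4"])
```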


Load the node features and subject labels:

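```python
node_attributes = []
node_index = []
with open(os.path.join(os.path.expanduser(data_dir), "Pubmed-Diabetes.NODE.paper.tab")) as fp:
    # Skip the two header lines. Each remaining line is tab-separated: the paper ID,
    # then "key=value" attributes; the last field is not parsed as an attribute.
    for line in itertools.islice(fp, 2, None):
        line_res = line.split("\t")
        row = {k: float(v) for k, v in (x.split("=") for x in line_res[1:-1])}
        node_attributes.append(row)
        node_index.append(line_res[0])

node_data = pd.DataFrame(node_attributes, index=node_index)
# Words missing from a paper's TF/IDF vector become NaN; treat them as 0.
node_data.fillna(0, inplace=True)
# "label" holds the subject class: keep it separately and drop it from the features.
node_subjects = node_data["label"].astype(int)
node_features = node_data.drop("label", axis=1)
```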
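
To illustrate the line format the loop above expects, the sketch below parses one line of the node file with a hypothetical helper, `parse_node_line`. The sample line is invented for illustration (the `label` value and the final `summary=...` field are made up); as in the loader, the last tab-separated field is skipped. Splitting on the first `=` only is a small defensive choice in case a value ever contains `=`.

```python
def parse_node_line(line: str):
    """Split one data line of the node file into (paper_id, {attribute: value}).

    Assumes tab-separated fields: the paper ID first, then "key=value" pairs,
    and a final field that, as in the loader, is not treated as an attribute."""
    fields = line.rstrip("\n").split("\t")
    paper_id = fields[0]
    attrs = {}
    for field in fields[1:-1]:
        key, value = field.split("=", 1)  # split on the first "=" only
        attrs[key] = float(value)
    return paper_id, attrs

# Hypothetical sample line (values invented for illustration).
sample = "12187484\tlabel=1\tw-rat=0.093935\tw-common=0.028698\tsummary=w-rat,w-common\n"
paper_id, attrs = parse_node_line(sample)
```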

```python
node_features.head()
```

| | w-rat | w-common | w-use | w-examin | w-pathogenesi | w-retinopathi | w-mous | w-studi | w-anim | w-model | ... | w-kidney | w-urinari | w-myocardi | w-meal | w-ica | w-locus | w-tcell | w-depress | w-bone | w-mutat |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 12187484 | 0.093935 | 0.028698 | 0.01176 | 0.019375 | 0.063161 | 0.170891 | 0.067702 | 0.017555 | 0.098402 | 0.062691 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 2344352 | 0.023618 | 0.0 | 0.014784 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.030926 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 14654069 | 0.102263 | 0.0 | 0.010669 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.044636 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 16443886 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.038715 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 2684155 | 0.030616 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.080179 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |


<h2 style="color:red;"> Exercise:</h2>

<p style="color:red; font-size:12pt">
Using the loaded Pubmed-Diabetes dataset, create positive and negative node pairs, then build, train, and test a link prediction model.
</p>
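
One way to start the exercise is sketched below in plain NetworkX and NumPy: hold out a random fraction of the citation edges as positive test pairs, sample an equal number of node pairs that are not edges as negatives, and label them 1 and 0. StellarGraph's `EdgeSplitter` (imported in the first cell) is the library's own tool for this kind of split; the helpers `split_test_edges` and `sample_negative_pairs` are hypothetical stand-ins, not its implementation, and unlike a careful split they make no attempt to keep the training graph connected. A 10-node cycle graph stands in for the citation graph.

```python
import numpy as np
import networkx as nx

def split_test_edges(g: nx.Graph, test_fraction: float, rng: np.random.Generator):
    """Remove a random fraction of edges from a copy of g; return (train graph, held-out edges)."""
    edges = list(g.edges())
    n_test = int(round(test_fraction * len(edges)))
    order = rng.permutation(len(edges))
    test_edges = [edges[i] for i in order[:n_test]]
    g_train = g.copy()
    g_train.remove_edges_from(test_edges)
    return g_train, test_edges

def sample_negative_pairs(g: nx.Graph, n: int, rng: np.random.Generator):
    """Uniformly sample n distinct unordered node pairs that are not edges of g."""
    nodes = list(g.nodes())
    n_nodes = len(nodes)
    # Upper bound on available non-edges (exact when g has no self-loops).
    if n > n_nodes * (n_nodes - 1) // 2 - g.number_of_edges():
        raise ValueError("not enough non-edges to sample from")
    negatives = {}  # dict keeps insertion order, so results are reproducible
    while len(negatives) < n:
        i, j = rng.choice(n_nodes, size=2, replace=False)  # two distinct nodes: no self-loops
        i, j = min(i, j), max(i, j)  # canonical order so (u, v) and (v, u) count once
        if not g.has_edge(nodes[i], nodes[j]):
            negatives[(nodes[i], nodes[j])] = None
    return list(negatives)

rng = np.random.default_rng(42)
g_full = nx.cycle_graph(10)  # toy stand-in for the citation graph
g_train, test_pos = split_test_edges(g_full, 0.2, rng)
# Negatives come from the FULL graph, so none of them is a held-out citation.
test_neg = sample_negative_pairs(g_full, len(test_pos), rng)
test_pairs = test_pos + test_neg
test_labels = np.array([1] * len(test_pos) + [0] * len(test_neg))
```

Training pairs can be built the same way from the edges of `g_train`; drawing the training negatives from `g_full` as well avoids labeling a held-out citation as a negative example.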