# Import, preprocess, and store inductive data splits

## Import relevant packages

In [28]:
import bz2
import os
import pickle
import time

import torch
import torch_geometric
from torch_geometric.datasets import Coauthor, Planetoid, WikipediaNetwork
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.utils import subgraph

# Seed torch's global RNG once, up front; node-to-split assignment below draws from it.
SEED = 10
torch.manual_seed(SEED)

<torch._C.Generator at 0x104615cb0>

## Import homogeneous datasets

In [8]:
root = '../data'
wiki_datasets = ["chameleon","crocodile"]
planetoid_dataset = "PubMed"
coauthor_dataset = "CS"

# Take the single graph via `dataset[0]` instead of reading `dataset.data`:
# `.data` is the internal collated storage of an InMemoryDataset, and direct
# access to it is deprecated in recent PyG releases (it emits a warning).
wiki_chameleon = WikipediaNetwork(root=root, name=wiki_datasets[0])[0]
# PyG documents 'crocodile' as unavailable with the Geom-GCN preprocessing,
# hence geom_gcn_preprocess=False for this one.
wiki_crocodile = WikipediaNetwork(root=root, name=wiki_datasets[1], geom_gcn_preprocess=False)[0]
pubmed = Planetoid(root=root, name=planetoid_dataset)[0]
cs = Coauthor(root=root, name=coauthor_dataset)[0]


Downloading https://raw.githubusercontent.com/graphdml-uiuc-jlu/geom-gcn/f1fc0d14b3b019c562737240d06ec83b07d16a8f/new_data/chameleon/out1_node_feature_label.txt
Downloading https://raw.githubusercontent.com/graphdml-uiuc-jlu/geom-gcn/f1fc0d14b3b019c562737240d06ec83b07d16a8f/new_data/chameleon/out1_graph_edges.txt
Downloading https://raw.githubusercontent.com/graphdml-uiuc-jlu/geom-gcn/f1fc0d14b3b019c562737240d06ec83b07d16a8f/splits/chameleon_split_0.6_0.2_0.npz
Downloading https://raw.githubusercontent.com/graphdml-uiuc-jlu/geom-gcn/f1fc0d14b3b019c562737240d06ec83b07d16a8f/splits/chameleon_split_0.6_0.2_1.npz
Downloading https://raw.githubusercontent.com/graphdml-uiuc-jlu/geom-gcn/f1fc0d14b3b019c562737240d06ec83b07d16a8f/splits/chameleon_split_0.6_0.2_2.npz
Downloading https://raw.githubusercontent.com/graphdml-uiuc-jlu/geom-gcn/f1fc0d14b3b019c562737240d06ec83b07d16a8f/splits/chameleon_split_0.6_0.2_3.npz
Downloading https://raw.githubusercontent.com/graphdml-uiuc-jlu/geom-gcn/f1fc0d14

## Split each dataset and save

In [43]:
def inductive_split(data: 'torch_geometric.data.Data',
                    train_frac: float = 0.3,
                    test_start: float = 0.65,
                    out_dir: str = '../data',
                    prefix: str = ''):
    '''
    Randomly partition the nodes of ``data`` into disjoint train / validation /
    test sets and pickle the induced subgraph of each set (bz2-compressed), for
    inductive link prediction.

    Each node draws ``u ~ U[0, 1)`` from torch's global RNG (so the split
    depends on ``torch.manual_seed``):

    * ``u < train_frac``                -> train
    * ``train_frac <= u <= test_start`` -> validation
    * ``u > test_start``                -> test

    When splitting wikipedia data, use ``train_frac=0.8, test_start=0.9``.
    When splitting pubmed and cs, use ``train_frac=0.30, test_start=0.65``
    (the defaults).

    Args:
        data: graph to split; must provide ``num_nodes`` and ``subgraph(mask)``
            (e.g. a ``torch_geometric.data.Data``).
        train_frac: exclusive upper bound on ``u`` for train nodes.
        test_start: exclusive lower bound on ``u`` for test nodes.
        out_dir: directory the pickles are written to (created if missing).
        prefix: string prepended to every file name, e.g. ``'pubmed_'``.

    Returns:
        The three written file paths, in train, valid, test order. Files are
        named ``{prefix}{train|valid|test}_data_{YYYYmmdd-HHMMSS}.p``; if that
        timestamp is already taken in ``out_dir``, a ``-1``, ``-2``, ... suffix
        is appended rather than overwriting earlier output.

    Raises:
        ValueError: if not ``0 <= train_frac <= test_start <= 1``
            (otherwise train and test would overlap).
    '''
    if not 0.0 <= train_frac <= test_start <= 1.0:
        raise ValueError(
            f'need 0 <= train_frac <= test_start <= 1, got '
            f'train_frac={train_frac}, test_start={test_start}')

    rands = torch.rand(data.num_nodes)
    train_mask = rands < train_frac
    test_mask = rands > test_start
    # Validation = every node in neither train nor test. Vectorised: a per-node
    # `i not in mask.nonzero()` scan is O(num_nodes^2).
    val_mask = ~(train_mask | test_mask)

    # Data.subgraph keeps only the selected nodes, relabels edge_index to the new
    # node ids, and slices every node-level attribute (x, y, masks, ...) rather
    # than only x, so each saved graph is internally consistent.
    splits = {
        'train': data.subgraph(train_mask),
        'valid': data.subgraph(val_mask),
        'test': data.subgraph(test_mask),
    }

    def _path(split, stamp):
        # Full output path for one split at a given timestamp.
        return os.path.join(out_dir, f'{prefix}{split}_data_{stamp}.p')

    # Second-resolution timestamps collide when this runs twice within the same
    # second (e.g. in a loop over datasets); never overwrite earlier output.
    base_stamp = time.strftime("%Y%m%d-%H%M%S")
    stamp, attempt = base_stamp, 0
    while any(os.path.exists(_path(split, stamp)) for split in splits):
        attempt += 1
        stamp = f'{base_stamp}-{attempt}'

    os.makedirs(out_dir, exist_ok=True)
    paths = []
    for split, split_data in splits.items():
        path = _path(split, stamp)
        # `with` closes the BZ2 stream so the compressed payload is fully flushed.
        with bz2.BZ2File(path, 'wb') as fh:
            pickle.dump(split_data, fh)
        paths.append(path)
    return paths

In [44]:
wiki_data = [wiki_chameleon, wiki_crocodile]
other_data = [pubmed, cs]
# Only PubMed and CS are split here; the Wikipedia graphs in `wiki_data` are not
# split by this cell. The loop variable is deliberately not called `set`, which
# would shadow the builtin for the rest of the kernel session.
for graph in other_data:
    inductive_split(graph)