In [None]:
import os
import numpy as np
import pandas as pd
import scipy.io

# Dataset selection and location of its raw files.
dataset_name = 'pubmed'
data_dir = os.path.join('../datasets/raw', dataset_name)

# The .mat bundle provides the 'all_x'/'tx' matrices and the 'test_idx' ids
# read below (the label matrices 'all_y'/'ty' are read in later cells).
mat_basename = 'ind.{}.mat'.format(dataset_name)
fn = os.path.join(data_dir, mat_basename)
data = scipy.io.loadmat(fn)

# Row counts of the two feature matrices.
n_train, n_test = (data[name].shape[0] for name in ('all_x', 'tx'))

# squeeze() drops the singleton dimensions of the stored test id array.
test_indices = np.squeeze(data['test_idx'])

print('num train: {} num test: {}'.format(n_train, n_test))

In [None]:
# Test node ids as a set, for fast membership checks in the cells below.
testIds = {node_id for node_id in test_indices}

# The graph loaded next supplies the neighbor lists used later to drop any
# test node that is not linked to a retained train node.
def load_graph(fn):
    """Read an adjacency-list CSV file into a dict of neighbor lists.

    Every non-blank line is a comma-separated list of integer node ids: the
    first id is the node itself and the remaining ids are its neighbors.
    Blank lines and trailing commas are tolerated.

    Args:
        fn: Path of the CSV file to read.

    Returns:
        dict mapping each node id (int) to the list of its neighbor ids in
        file order. If a node id appears on several lines, the last one wins.

    Raises:
        ValueError: If a field is not an integer (including an empty field
            that is not at the end of the line).
    """
    graph = {}

    with open(fn) as in_csv:
        for line in in_csv:
            # Drop surrounding whitespace and a dangling trailing comma so
            # that neither reaches int('') below.
            row = line.strip().rstrip(',')
            if not row:
                continue  # blank line, e.g. at the end of the file

            nodeIDs = [int(t) for t in row.split(',')]
            graph[nodeIDs[0]] = nodeIDs[1:]
    return graph

# The adjacency list lives next to the .mat file; gp maps node id -> neighbor ids.
graph_basename = 'ind.{}.graph.csv'.format(dataset_name)
graph_fn = os.path.join(data_dir, graph_basename)
gp = load_graph(graph_fn)

In [None]:
# Split the non-test graph keys by id range in a single pass each:
#   train_indices       -> ids below n_train (the rows available in all_x)
#   extra_train_indices -> ids at or above n_train (no row in all_x)
# Both lists keep the iteration order of gp.
non_test_keys = [key for key in gp if key not in testIds]
train_indices = [key for key in non_test_keys if key < n_train]
extra_train_indices = [key for key in non_test_keys if key >= n_train]

# Graph should contain the same number of train data as all_x
assert len(train_indices) == n_train, (
    'graph has {} train nodes but all_x has {} rows'.format(len(train_indices), n_train))

In [None]:
#####################################################################################
# A train node is kept only when at least one of its neighbors is neither a
# test node nor an extra node (id >= n_train). A set is used so that each
# membership check is constant time.
excluded_neighbors = set(testIds).union(extra_train_indices)

valid_train_nodes = [
    nodeId for nodeId in train_indices
    if any(v not in excluded_neighbors for v in gp[nodeId])
]

print('total train: {}'.format(n_train))
print('total valid train: {}'.format(len(valid_train_nodes)))

# figure out the trainId to keep (starting from 0): the position of every
# retained node inside train_indices
valid_train_node_set = set(valid_train_nodes)
valid_train_ids = [idx for idx, trainId in enumerate(train_indices)
                   if trainId in valid_train_node_set]
assert(len(valid_train_nodes) == len(valid_train_ids))

# filter the train data and labels
# NOTE(review): rows are selected by position in train_indices, which assumes
# all_x / all_y are stored in the same order as train_indices -- confirm.
train_data = data['all_x'][valid_train_ids, :]
train_labels = data['all_y'][valid_train_ids, :]
assert(train_data.shape[0] == train_labels.shape[0] == len(valid_train_ids))

In [None]:
#####################################################################################
# Keep a test node only when it links to at least one retained train node.
# Neighbor ids stored in gp are global node ids, so they are matched against
# valid_train_nodes (global ids) rather than valid_train_ids (row positions);
# this is the same filter the test graph construction applies later.
valid_train_node_set = set(valid_train_nodes)

valid_test_nodes = [
    testId for testId in test_indices
    if any(v in valid_train_node_set for v in gp[testId])
]

print('total test: {}'.format(n_test))
print('total valid test: {}'.format(len(valid_test_nodes)))

# make sure there is no duplication
valid_test_node_set = set(valid_test_nodes)
assert(len(valid_test_nodes) == len(valid_test_node_set))

# figure out the testId to keep (starting from 0): the position of every
# retained node inside test_indices
valid_test_ids = [idx for idx, testId in enumerate(test_indices)
                  if testId in valid_test_node_set]
assert(len(valid_test_nodes) == len(valid_test_ids))

# filter the test data and labels
test_data = data['tx'][valid_test_ids, :]
test_labels = data['ty'][valid_test_ids, :]
assert(test_data.shape[0] == test_labels.shape[0] == len(valid_test_ids))

In [None]:
# Lookup table: global node id -> 0-based position in valid_train_nodes.
globalId2TrainID = {
    node_id: position for position, node_id in enumerate(valid_train_nodes)
}

In [None]:
from tqdm import tqdm 
# Build the set once; rebuilding it for every neighbor made the loops quadratic.
valid_train_node_set = set(valid_train_nodes)


def _train_neighbors(nodeId):
    """Return the retained train neighbors of `nodeId` as train-local ids.

    Neighbors that are not in valid_train_nodes are dropped; the remaining
    global ids are translated through globalId2TrainID. The node must exist
    in gp and must keep at least one neighbor.
    """
    assert(nodeId in gp)
    vertices = [globalId2TrainID[v] for v in gp[nodeId] if v in valid_train_node_set]
    assert(len(vertices) > 0)
    return vertices


# create train graph (keyed by train-local id)
train_graph = {
    globalId2TrainID[nodeId]: _train_neighbors(nodeId)
    for nodeId in tqdm(valid_train_nodes)
}

# create test graph (keyed by the index in valid_test_nodes, starting from 0)
test_graph = {
    testId: _train_neighbors(nodeId)
    for testId, nodeId in enumerate(valid_test_nodes)
}

assert(len(train_graph) == train_data.shape[0])
assert(len(test_graph) == test_data.shape[0])

print('train: {} test: {}'.format(train_data.shape[0], test_data.shape[0]))

In [None]:
# convert labels to a sparse matrix format
import sklearn.preprocessing
from scipy import sparse

# The argmax over axis 1 below treats every column of the label matrices as
# one class, so the class count is the matrix width. Deriving it from
# max - min of the observed train labels would under-count whenever the
# lowest class id is not 0 or the highest class is missing from the train split.
n_classes = train_labels.shape[1]
assert test_labels.shape[1] == n_classes, 'train and test label widths differ'

# Collapse each label row to its class id.
# Note: this rebinds train_labels / test_labels, so the cells that create them
# must be re-run before this cell is executed a second time.
train_labels = np.argmax(train_labels, axis=1)
test_labels = np.argmax(test_labels, axis=1)

# Re-encode the class ids with a binarizer fitted on every class id
# 0 .. n_classes - 1, then store the result as CSR matrices.
label_binarizer = sklearn.preprocessing.LabelBinarizer()
label_binarizer.fit(range(n_classes))

gnd_train = sparse.csr_matrix(label_binarizer.transform(train_labels))
gnd_test = sparse.csr_matrix(label_binarizer.transform(test_labels))

print(gnd_train.shape)
print(gnd_test.shape)

In [None]:
# create a connection matrix
def _connection_matrix(graph, n_rows, n_cols):
    """Build a binary (n_rows x n_cols) CSR matrix from a graph dict.

    Args:
        graph: dict mapping a row id to the list of column ids it connects to.
        n_rows: Number of rows of the result.
        n_cols: Number of columns of the result.

    Returns:
        scipy.sparse.csr_matrix of integer dtype with a 1 at every
        (row id, column id) pair present in `graph` and 0 elsewhere.
    """
    rows, cols = [], []
    for doc_id, neighbors in graph.items():
        # De-duplicate so that a repeated edge still yields a single 1
        # (duplicate coordinates would otherwise be summed).
        unique_neighbors = set(neighbors)
        rows.extend([doc_id] * len(unique_neighbors))
        cols.extend(unique_neighbors)

    values = np.ones(len(rows), dtype=int)
    coords = (np.asarray(rows, dtype=np.int64), np.asarray(cols, dtype=np.int64))
    return sparse.csr_matrix((values, coords), shape=(n_rows, n_cols))


# The matrices are assembled directly in sparse form; a dense
# n_train x n_train intermediate is never allocated.
# Note: n_train / n_test are rebound here to the filtered row counts.
n_train = train_data.shape[0]
train_connections = _connection_matrix(train_graph, n_train, n_train)

n_test = test_data.shape[0]
test_connections = _connection_matrix(test_graph, n_test, n_train)

In [None]:
save_dir = os.path.join('../datasets/clean', dataset_name)
# Create the output directory up front so the pickles can be written on a
# fresh checkout.
os.makedirs(save_dir, exist_ok=True)

# Column layout shared by the train and test frames ('doc_id' becomes the index).
_DOC_COLUMNS = ['doc_id', 'bow', 'label', 'neighbors']


def _doc_records(graph, bow, labels, connections):
    """Return one record per document id in `graph`.

    Each record holds the document id together with row `doc_id` of the
    bag-of-words matrix, of the label matrix and of the connection matrix.
    """
    return [
        {'doc_id': doc_id, 'bow': bow[doc_id],
         'label': labels[doc_id], 'neighbors': connections[doc_id]}
        for doc_id in graph
    ]


##########################################################################################
train = _doc_records(train_graph, train_data, gnd_train, train_connections)
train_df = pd.DataFrame(train, columns=_DOC_COLUMNS).set_index('doc_id')

fn = os.path.join(save_dir, '{}.train.pkl'.format(dataset_name))
train_df.to_pickle(fn)

##########################################################################################
test = _doc_records(test_graph, test_data, gnd_test, test_connections)
test_df = pd.DataFrame(test, columns=_DOC_COLUMNS).set_index('doc_id')

fn = os.path.join(save_dir, '{}.test.pkl'.format(dataset_name))
test_df.to_pickle(fn)