In [1]:
!pip install pyvis

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pyvis
  Downloading pyvis-0.3.2-py3-none-any.whl (756 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m756.0/756.0 kB[0m [31m9.1 MB/s[0m eta [36m0:00:00[0m
Collecting jedi>=0.16 (from ipython>=5.3.0->pyvis)
  Downloading jedi-0.18.2-py2.py3-none-any.whl (1.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m40.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: jedi, pyvis
Successfully installed jedi-0.18.2 pyvis-0.3.2


In [2]:
import pandas as pd
import os
import numpy as np
import random
import matplotlib.pyplot as plt
import scipy.sparse as sp
from nltk.corpus import wordnet as wn
from sklearn.feature_extraction.text import TfidfVectorizer
from scipy.spatial.distance import cosine
from collections import defaultdict
import torch
import math
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from pyvis.network import Network
import time

In [3]:
from google.colab import drive
drive.mount('/content/drive/')

Mounted at /content/drive/


## Word-Word & Word-doc Weight



In [4]:
synthea_df = pd.read_csv('/content/drive/MyDrive/data/synthea.csv')

In [5]:
size = len(synthea_df)
train_size = int(0.75*size)
val_size = int(0.1*size)
test_size = size - train_size - val_size
print(size, train_size, test_size, val_size)

724 543 109 72


In [6]:
shuffled_id = np.arange(size)
random.shuffle(shuffled_id)

In [7]:
shuffled_document = synthea_df['Document'][shuffled_id]
shuffled_labels = synthea_df['Condition'].values[shuffled_id]

In [8]:
print(len(shuffled_document))
print(len(shuffled_labels))

724
724


In [9]:
labels_list = list(synthea_df['Condition'].unique())

In [10]:
labels_list

[10509002, 15777000, 444814009, 65363002]

In [11]:
# build vocab
word_freq = defaultdict(lambda : 0)
word_set = set()
for doc_words in shuffled_document:
    words = doc_words.split()
    for word in words:
        word_set.add(word)
        word_freq[word] += 1

vocab = list(word_set)
vocab_size = len(vocab)
word_embeddings_dim = vocab_size

In [12]:
print(words)

['73761001', '824184']


In [13]:
print(vocab[:5])
print(word_freq['239981'])
print(vocab_size)

['46706006', '997488', '755621000000101', '834101', '757594']
3
53


In [14]:
# each word in what doc
word_doc_list = defaultdict(lambda : [])
for index, doc in enumerate(shuffled_document):
    appeared = set()
    doc_words = doc.split()
    for word in doc_words:
        if word in appeared:
            continue
        word_doc_list[word].append(index)
        appeared.add(word)
print(word_doc_list)       

defaultdict(<function <lambda> at 0x7f603a7dea70>, {'311791003': [0, 5, 19, 20, 24, 29, 38, 47, 48, 50, 65, 67, 73, 77, 79, 84, 88, 97, 109, 113, 120, 121, 148, 153, 154, 155, 159, 162, 163, 184, 187, 191, 199, 200, 210, 214, 217, 225, 226, 228, 236, 244, 245, 247, 250, 253, 262, 269, 276, 285, 289, 292, 297, 298, 311, 313, 318, 329, 331, 337, 342, 345, 348, 349, 364, 367, 377, 395, 413, 419, 423, 425, 426, 433, 439, 444, 462, 463, 464, 467, 472, 486, 487, 496, 500, 505, 506, 507, 510, 513, 514, 515, 525, 532, 541, 542, 561, 576, 582, 583, 588, 593, 600, 606, 613, 618, 627, 633, 635, 656, 657, 664, 668, 681, 686, 687, 691, 697, 698, 706, 717, 721], '282464': [0, 5, 19, 20, 24, 48, 50, 97, 109, 113, 155, 186, 191, 194, 210, 217, 225, 228, 236, 244, 247, 248, 253, 262, 269, 289, 297, 298, 313, 318, 364, 367, 413, 419, 425, 439, 444, 462, 467, 487, 507, 513, 514, 515, 541, 561, 576, 613, 627, 635, 657, 664, 697, 698, 717, 721], '392151': [0, 5, 19, 20, 38, 47, 48, 50, 65, 67, 88, 109, 148

In [15]:
# each word in how many docs
word_doc_freq = {}
for word, doc_list in word_doc_list.items():
    word_doc_freq[word] = len(doc_list)
    
print(word_doc_freq)

{'311791003': 122, '282464': 56, '392151': 77, '73761001': 175, '834060': 252, '23426006': 357, '608680': 357, '824184': 95, '287664005': 29, '1049630': 6, '46706006': 48, '834101': 74, '757594': 8, '749762': 2, '1367439': 6, '1020137': 27, '309097': 23, '755621000000101': 33, '646250': 12, '1366342': 19, '65200003': 13, '807283': 28, '198405': 58, '748856': 9, '76601001': 3, '831533': 6, '751905': 6, '106258': 2, '389128': 1, '749785': 8, '22523008': 7, '1359133': 1, '1605257': 1, '68254000': 4, '395142003': 6, '727316': 7, '665078': 2, '1014676': 3, '749882': 3, '239981': 3, '301807007': 3, '748962': 5, '1856546': 2, '197378': 2, '748879': 2, '1536586': 3, '1000158': 1, '169553002': 1, '1111011': 1, '997501': 1, '399208008': 1, '269911007': 1, '997488': 1}


In [16]:
# map each word to an index
word_id_map = {}
for index, word in enumerate(vocab):
    word_id_map[word] = index

In [17]:
#ecoding using tfidf
tfidf_vec = TfidfVectorizer(max_features=1000)
tfidf_matrix = tfidf_vec.fit_transform(vocab)
tfidf_matrix_array = tfidf_matrix.toarray()

In [18]:
tfidf_matrix_array

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 1., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]])

In [19]:
word_id_map.items()

dict_items([('46706006', 0), ('997488', 1), ('755621000000101', 2), ('834101', 3), ('757594', 4), ('1856546', 5), ('831533', 6), ('997501', 7), ('22523008', 8), ('301807007', 9), ('834060', 10), ('748962', 11), ('65200003', 12), ('76601001', 13), ('824184', 14), ('749762', 15), ('807283', 16), ('198405', 17), ('727316', 18), ('311791003', 19), ('197378', 20), ('1014676', 21), ('1111011', 22), ('1536586', 23), ('1359133', 24), ('646250', 25), ('309097', 26), ('282464', 27), ('608680', 28), ('748879', 29), ('1020137', 30), ('665078', 31), ('749882', 32), ('1049630', 33), ('239981', 34), ('399208008', 35), ('392151', 36), ('748856', 37), ('1367439', 38), ('287664005', 39), ('68254000', 40), ('73761001', 41), ('751905', 42), ('106258', 43), ('1605257', 44), ('395142003', 45), ('269911007', 46), ('1366342', 47), ('1000158', 48), ('23426006', 49), ('749785', 50), ('169553002', 51), ('389128', 52)])

In [20]:
word_vector_map = {}
for word, idd in word_id_map.items():
    word_vector_map[word] = tfidf_matrix_array[idd]

In [21]:
word_vector_map['239981']

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0.])

In [22]:
word_vectors = np.random.uniform(-0.01, 0.01,(vocab_size, word_embeddings_dim))
print(word_vectors.shape)
for i in range(len(vocab)):
    word = vocab[i]
    if word in word_vector_map:
        vector = word_vector_map[word]
        word_vectors[i] = vector
print(vocab[0])
print(word_vector_map['395142003'])
print(word_vectors[0])

(53, 53)
46706006
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0.]


In [23]:
#train data feature
# create document vector by summing word_vector of all words in a document
# normalize it
data_x = []

for i, doc in enumerate(shuffled_document[:train_size+val_size]):
    doc_vec = np.array([0.0 for k in range(word_embeddings_dim)])
    doc_words = doc.split()
    doc_len = len(doc_words)
    for word in doc_words:
        if word in word_vector_map:
            word_vector = word_vector_map[word]
            doc_vec = doc_vec + np.array(word_vector)          
    data_x.append(doc_vec)
print(len(data_x))
X = np.vstack([np.array(data_x)/doc_len, np.array(word_vectors)])  
print(len(X))        
print(X[0])

615
668
[0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.33333333 0.         0.         0.         0.33333333
 0.         0.33333333 0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.        ]


In [24]:
# train data labels
y = []
for label in shuffled_labels[:train_size+val_size]:
    y.append(labels_list.index(label))
y.extend([-1 for _ in range(vocab_size)])
print(y[0])

3


In [25]:
# test data feature
data_tx = []
for i, doc in enumerate(shuffled_document[-test_size:]):
    doc_vec = np.array([0.0 for k in range(word_embeddings_dim)])
    doc_words = doc.split()
    doc_len = len(doc_words)
    for word in doc_words:
        if word in word_vector_map:
            word_vector = word_vector_map[word]
            doc_vec = doc_vec + np.array(word_vector)
    data_tx.append(doc_vec)
tx = np.array(data_tx)/doc_len

In [26]:
len(shuffled_labels)

724

In [27]:
#test data label
ty = []
ty = [labels_list.index(label) for label in shuffled_labels[-test_size:]]
print(tx.shape, len(ty))

(109, 53) 109


In [28]:
print(test_size)
print(tx.shape, len(ty))
print('------')
print(train_size + val_size+ vocab_size)
print(X.shape, len(y))

109
(109, 53) 109
------
668
(668, 53) 668


In [29]:
# word co-occurence with context windows
window_size = 3
windows = []

In [30]:
for doc_words in shuffled_document:
    words = doc_words.split()
    length = len(words)
    if length <= window_size:
        windows.append(words)
    else:
        for j in range(length - window_size + 1):
            window = words[j: j + window_size]
            windows.append(window)

In [31]:
print(windows)

[['311791003', '282464', '392151'], ['73761001', '834060'], ['23426006', '608680'], ['73761001', '834060'], ['23426006', '73761001', '834060'], ['73761001', '834060', '608680'], ['311791003', '282464', '392151'], ['73761001', '824184'], ['23426006', '834060', '608680'], ['287664005', '1049630'], ['46706006', '834101', '757594'], ['73761001', '834060'], ['73761001', '834060'], ['23426006', '834101', '749762'], ['834101', '749762', '608680'], ['749762', '608680', '1367439'], ['73761001', '834060'], ['23426006', '608680'], ['23426006', '1020137', '608680'], ['73761001', '834060'], ['23426006', '608680'], ['23426006', '608680'], ['311791003', '282464', '392151'], ['311791003', '282464', '309097'], ['282464', '309097', '392151'], ['287664005', '755621000000101', '646250'], ['23426006', '608680'], ['73761001', '834060'], ['311791003', '282464'], ['23426006', '834060', '608680'], ['23426006', '608680'], ['23426006', '608680'], ['287664005', '755621000000101', '1366342'], ['311791003', '834101

In [32]:
word_window_freq = defaultdict(lambda :0)
for window in windows:
    appeared = set()
    for word in window:
        if word in appeared:
            continue
        word_window_freq[word]+=1
        appeared.add(word)

In [33]:
word_doc_freq

{'311791003': 122,
 '282464': 56,
 '392151': 77,
 '73761001': 175,
 '834060': 252,
 '23426006': 357,
 '608680': 357,
 '824184': 95,
 '287664005': 29,
 '1049630': 6,
 '46706006': 48,
 '834101': 74,
 '757594': 8,
 '749762': 2,
 '1367439': 6,
 '1020137': 27,
 '309097': 23,
 '755621000000101': 33,
 '646250': 12,
 '1366342': 19,
 '65200003': 13,
 '807283': 28,
 '198405': 58,
 '748856': 9,
 '76601001': 3,
 '831533': 6,
 '751905': 6,
 '106258': 2,
 '389128': 1,
 '749785': 8,
 '22523008': 7,
 '1359133': 1,
 '1605257': 1,
 '68254000': 4,
 '395142003': 6,
 '727316': 7,
 '665078': 2,
 '1014676': 3,
 '749882': 3,
 '239981': 3,
 '301807007': 3,
 '748962': 5,
 '1856546': 2,
 '197378': 2,
 '748879': 2,
 '1536586': 3,
 '1000158': 1,
 '169553002': 1,
 '1111011': 1,
 '997501': 1,
 '399208008': 1,
 '269911007': 1,
 '997488': 1}

In [34]:
# the difference between word_doc_freq is that a document may be in multiple windows so the count is usually more
word_window_freq

defaultdict(<function __main__.<lambda>()>,
            {'311791003': 124,
             '282464': 80,
             '392151': 102,
             '73761001': 198,
             '834060': 317,
             '23426006': 358,
             '608680': 366,
             '824184': 97,
             '287664005': 30,
             '1049630': 6,
             '46706006': 58,
             '834101': 95,
             '757594': 8,
             '749762': 5,
             '1367439': 8,
             '1020137': 33,
             '309097': 41,
             '755621000000101': 41,
             '646250': 12,
             '1366342': 28,
             '65200003': 13,
             '807283': 33,
             '198405': 58,
             '748856': 12,
             '76601001': 4,
             '831533': 6,
             '751905': 8,
             '106258': 5,
             '389128': 1,
             '749785': 11,
             '22523008': 9,
             '1359133': 1,
             '1605257': 2,
             '68254000': 4,
          

In [35]:
word_pair_count = defaultdict(lambda :0)
for window in windows:
    for i in range(1, len(window)):
        for j in range(0,i):
            word_i = window[i]
            word_j = window[j]
            word_i_id  = word_id_map[word_i]
            word_j_id =  word_id_map[word_j]
            if word_i_id == word_j_id:
                continue
            word_pair_str = str(word_i_id) + ',' + str(word_j_id)
            word_pair_count[word_pair_str]+=1
            word_pair_str = str(word_j_id) + ',' + str(word_i_id)
            word_pair_count[word_pair_str]+=1

In [36]:
print(word_pair_count)
print(len(word_pair_count))

defaultdict(<function <lambda> at 0x7f603a7dee60>, {'27,19': 51, '19,27': 51, '36,19': 69, '19,36': 69, '36,27': 57, '27,36': 57, '10,41': 117, '41,10': 117, '28,49': 314, '49,28': 314, '41,49': 34, '49,41': 34, '10,49': 113, '49,10': 113, '28,41': 30, '41,28': 30, '28,10': 120, '10,28': 120, '14,41': 44, '41,14': 44, '33,39': 2, '39,33': 2, '3,0': 6, '0,3': 6, '4,0': 4, '0,4': 4, '4,3': 1, '3,4': 1, '3,49': 18, '49,3': 18, '15,49': 1, '49,15': 1, '15,3': 2, '3,15': 2, '28,3': 19, '3,28': 19, '28,15': 2, '15,28': 2, '38,15': 1, '15,38': 1, '38,28': 1, '28,38': 1, '30,49': 20, '49,30': 20, '28,30': 23, '30,28': 23, '26,19': 19, '19,26': 19, '26,27': 22, '27,26': 22, '36,26': 18, '26,36': 18, '2,39': 5, '39,2': 5, '25,39': 2, '39,25': 2, '25,2': 11, '2,25': 11, '47,39': 2, '39,47': 2, '47,2': 14, '2,47': 14, '3,19': 1, '19,3': 1, '38,19': 1, '19,38': 1, '38,3': 2, '3,38': 2, '14,3': 4, '3,14': 4, '14,38': 3, '38,14': 3, '38,12': 1, '12,38': 1, '14,12': 3, '12,14': 3, '16,0': 22, '0,16': 

In [37]:
# Word-Word node pmi as weights
row = []
col = []
weight = []
num_window = len(windows)

for key in word_pair_count:
    temp = key.split(',')
    i = int(temp[0])
    j = int(temp[1])
    count = word_pair_count[key]
    word_freq_i = word_window_freq[vocab[i]]
    word_freq_j = word_window_freq[vocab[j]]
    pmi = np.log((1.0 * count / num_window) /(1.0 * word_freq_i * word_freq_j/(num_window * num_window)))
    if pmi <= 0:
        continue
    row.append(train_size+val_size +i)
    col.append(train_size+val_size + j)
    weight.append(pmi)


In [38]:
print(row)
print(len(row))

[642, 634, 651, 634, 651, 642, 625, 656, 643, 664, 629, 656, 648, 654, 619, 615, 619, 618, 630, 618, 653, 630, 645, 664, 643, 645, 641, 634, 641, 642, 651, 641, 617, 654, 640, 654, 640, 617, 662, 654, 662, 617, 653, 618, 629, 653, 653, 627, 629, 627, 631, 615, 625, 618, 625, 654, 632, 634, 632, 651, 619, 628, 618, 656, 632, 642, 653, 617, 653, 662, 621, 654, 618, 627, 625, 627, 631, 625, 632, 641, 657, 654, 629, 654, 629, 657, 629, 662, 629, 634, 658, 618, 658, 625, 667, 625, 667, 658, 665, 654, 629, 615, 629, 640, 625, 617, 643, 623, 652, 625, 631, 619, 652, 627, 631, 627, 631, 652, 639, 654, 629, 639, 659, 627, 652, 659, 629, 659, 656, 655, 625, 615, 629, 621, 651, 660, 642, 660, 633, 642, 633, 641, 646, 641, 646, 633, 632, 633, 632, 646, 636, 660, 651, 636, 642, 636, 623, 656, 652, 654, 647, 654, 618, 623, 615, 655, 649, 656, 629, 649, 615, 627, 647, 615, 647, 631, 624, 656, 624, 617, 662, 624, 657, 664, 643, 657, 657, 645, 640, 628, 629, 628, 615, 654, 626, 617, 625, 626, 662, 626,

In [39]:
print(col)
print(len(col))

[634, 642, 634, 651, 642, 651, 656, 625, 664, 643, 656, 629, 654, 648, 615, 619, 618, 619, 618, 630, 630, 653, 664, 645, 645, 643, 634, 641, 642, 641, 641, 651, 654, 617, 654, 640, 617, 640, 654, 662, 617, 662, 618, 653, 653, 629, 627, 653, 627, 629, 615, 631, 618, 625, 654, 625, 634, 632, 651, 632, 628, 619, 656, 618, 642, 632, 617, 653, 662, 653, 654, 621, 627, 618, 627, 625, 625, 631, 641, 632, 654, 657, 654, 629, 657, 629, 662, 629, 634, 629, 618, 658, 625, 658, 625, 667, 658, 667, 654, 665, 615, 629, 640, 629, 617, 625, 623, 643, 625, 652, 619, 631, 627, 652, 627, 631, 652, 631, 654, 639, 639, 629, 627, 659, 659, 652, 659, 629, 655, 656, 615, 625, 621, 629, 660, 651, 660, 642, 642, 633, 641, 633, 641, 646, 633, 646, 633, 632, 646, 632, 660, 636, 636, 651, 636, 642, 656, 623, 654, 652, 654, 647, 623, 618, 655, 615, 656, 649, 649, 629, 627, 615, 615, 647, 631, 647, 656, 624, 617, 624, 624, 662, 664, 657, 657, 643, 645, 657, 628, 640, 628, 629, 654, 615, 617, 626, 626, 625, 626, 662,

In [40]:
print(weight)
print(len(weight))

[1.4922469393772866, 1.4922469393772866, 1.5515816326398308, 1.5515816326398308, 1.7987813268082766, 1.7987813268082766, 0.4777346371578195, 0.4777346371578195, 0.728956173038066, 0.728956173038066, 0.6839411316522225, 0.6839411316522225, 2.2629198366016143, 2.2629198366016143, 2.009139315825515, 2.009139315825515, 0.1294110736515025, 0.1294110736515025, 1.2925618834571835, 1.2925618834571835, 3.073850052817943, 3.073850052817943, 0.38142123261868977, 0.38142123261868977, 0.4990828279931826, 0.4990828279931826, 1.1733148537889753, 1.1733148537889753, 1.7581732589120058, 1.7581732589120058, 1.3145563848394652, 1.3145563848394652, 1.2573979709995164, 1.2573979709995164, 1.569772656041669, 1.569772656041669, 2.9621460632379417, 2.9621460632379417, 0.7224747956544655, 0.7224747956544655, 2.356010259667626, 2.356010259667626, 0.8225582542114478, 0.8225582542114478, 1.2071892754167703, 1.2071892754167703, 2.118338607790507, 2.118338607790507, 0.7216814596350696, 0.7216814596350696, 2.2968213

In [41]:
# doc word frequency
doc_word_freq = defaultdict(lambda : 0)
for doc_id, doc in enumerate(shuffled_document):
    words = doc_words.split()
    for word in words:
        word_id = word_id_map[word]
        doc_word_str = str(doc_id) + ',' + str(word_id)
        doc_word_freq[doc_word_str]+=1

In [42]:
doc_word_freq

defaultdict(<function __main__.<lambda>()>,
            {'0,41': 1,
             '0,14': 1,
             '1,41': 1,
             '1,14': 1,
             '2,41': 1,
             '2,14': 1,
             '3,41': 1,
             '3,14': 1,
             '4,41': 1,
             '4,14': 1,
             '5,41': 1,
             '5,14': 1,
             '6,41': 1,
             '6,14': 1,
             '7,41': 1,
             '7,14': 1,
             '8,41': 1,
             '8,14': 1,
             '9,41': 1,
             '9,14': 1,
             '10,41': 1,
             '10,14': 1,
             '11,41': 1,
             '11,14': 1,
             '12,41': 1,
             '12,14': 1,
             '13,41': 1,
             '13,14': 1,
             '14,41': 1,
             '14,14': 1,
             '15,41': 1,
             '15,14': 1,
             '16,41': 1,
             '16,14': 1,
             '17,41': 1,
             '17,14': 1,
             '18,41': 1,
             '18,14': 1,
             '19,41': 1,
 

In [43]:
# Word-Doc node train weights
for i, doc_words in enumerate(shuffled_document):
    words = doc_words.split()
    doc_word_set = set()
    for word in words:
        if word in doc_word_set:
            continue
        j = word_id_map[word]

        key = str(i) + ',' + str(j)
        freq = doc_word_freq[key]
        if i < train_size + val_size:
            row.append(i)
        else:
            row.append(i + vocab_size)
        col.append(train_size+val_size+ j)
        idf = np.log(1.0 * len(shuffled_document) / word_doc_freq[vocab[j]])
        weight.append(freq * idf)
        doc_word_set.add(word)

In [44]:
print(len(row))
print(len(col))
print(len(weight))

2307
2307
2307


In [45]:
node_size = train_size + val_size + vocab_size + test_size
adj = sp.csr_matrix((weight, (row, col)), shape=(node_size, node_size))
#  Ensures that the adjacency matrix is symmetric
adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
print(len(row),len(col), len(weight), max(row), node_size)

2307 2307 2307 776 777


## Train on Graph Convolution Network (GCN)

In [46]:
def normalize(mx):
    """Row-normalize sparse matrix"""
    rowsum = np.array(mx.sum(1))
    r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    mx = r_mat_inv.dot(mx)
    return mx

def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse tensor."""
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    return torch.sparse.FloatTensor(indices, values, shape)
    
def accuracy(output, labels):
    preds = output.max(1)[1].type_as(labels)
    correct = preds.eq(labels).double()
    correct = correct.sum()
    return correct / len(labels)

In [47]:
labels = y + ty
features_merge = np.vstack([X, tx]).astype('float32')
# feature sparse csr
features = sp.csr_matrix(features_merge, dtype=np.float32)
# adjancy sparce csr (weight in what row and column)
adj = sp.csr_matrix((weight, (row, col)), shape=(node_size, node_size))
# build symmetric adjacency matrix
adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
# normalize featrue matrix
features = normalize(features)
# add identity matrix and then normalize it
adj = normalize(adj + sp.eye(adj.shape[0]))

In [48]:
perm = np.arange(0, node_size)
idx_train = perm[:train_size]
idx_val = perm[train_size:train_size + val_size]
idx_test= perm[train_size + val_size + vocab_size : ]

In [49]:
features = torch.FloatTensor(np.array(features.todense()))
labels = torch.LongTensor(labels)
adj = sparse_mx_to_torch_sparse_tensor(adj)
idx_train = torch.LongTensor(idx_train)
idx_val = torch.LongTensor(idx_val)
idx_test = torch.LongTensor(idx_test)

In [50]:
class GraphConvolution(Module):

    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
    # initialized the weight and bias with values drawn from a uniform distribution
    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)
    # input (matrix of node features)
    # adj (adjacecny matrix)
    def forward(self, input, adj):
        support = torch.mm(input, self.weight)
        output = torch.spmm(adj, support)
        if self.bias is not None:
            return output + self.bias
        else:
            return output
    # shape of input feature
    # shape of output feature
    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'

In [51]:
class GCN(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout):
        super(GCN, self).__init__()

        self.gc1 = GraphConvolution(nfeat, nhid)
        self.gc2 = GraphConvolution(nhid, nclass)
        self.dropout = dropout

    def forward(self, x, adj):
        x = F.relu(self.gc1(x, adj))
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.gc2(x, adj)
        return F.log_softmax(x, dim=1)

In [52]:
lr = 0.01
weight_decay = 5e-4
dropout = 0.5
hidden_size =16
fastmode = False
epochs = 200

In [53]:
# Model and optimizer
model = GCN(nfeat=features.shape[1],
            nhid=hidden_size,
            nclass=labels.max().item() + 1,
            dropout=dropout)
optimizer = optim.Adam(model.parameters(),
                       lr=lr, weight_decay=weight_decay)

In [54]:
fastmode = False
def train(epoch):
    model.train()
    optimizer.zero_grad()
    output = model(features, adj)
    loss_train = F.nll_loss(output[idx_train], labels[idx_train])
    acc_train = accuracy(output[idx_train], labels[idx_train])
    loss_train.backward()
    optimizer.step()

    if not fastmode:
        # Evaluate validation set performance separately,
        # deactivates dropout during validation run.
        model.eval()
        output = model(features, adj)
    loss_val = F.nll_loss(output[idx_val], labels[idx_val])
    acc_val = accuracy(output[idx_val], labels[idx_val])
    print('Epoch: {:04d}'.format(epoch+1),
          'loss_train: {:.4f}'.format(loss_train.item()),
          'acc_train: {:.4f}'.format(acc_train.item()),
          'loss_val: {:.4f}'.format(loss_val.item()),
          'acc_val: {:.4f}'.format(acc_val.item()))


In [55]:
epochs

200

In [56]:
for epoch in range(epochs):
    train(epoch)

Epoch: 0001 loss_train: 1.4716 acc_train: 0.1565 loss_val: 1.4057 acc_val: 0.0833
Epoch: 0002 loss_train: 1.4045 acc_train: 0.1971 loss_val: 1.3741 acc_val: 0.1250
Epoch: 0003 loss_train: 1.4125 acc_train: 0.2762 loss_val: 1.3425 acc_val: 0.5000
Epoch: 0004 loss_train: 1.3326 acc_train: 0.3333 loss_val: 1.3126 acc_val: 0.6944
Epoch: 0005 loss_train: 1.3201 acc_train: 0.4401 loss_val: 1.2832 acc_val: 0.7083
Epoch: 0006 loss_train: 1.2903 acc_train: 0.5009 loss_val: 1.2547 acc_val: 0.6806
Epoch: 0007 loss_train: 1.3107 acc_train: 0.5488 loss_val: 1.2277 acc_val: 0.6806
Epoch: 0008 loss_train: 1.2178 acc_train: 0.5948 loss_val: 1.2015 acc_val: 0.6528
Epoch: 0009 loss_train: 1.2534 acc_train: 0.5672 loss_val: 1.1751 acc_val: 0.6528
Epoch: 0010 loss_train: 1.2221 acc_train: 0.5709 loss_val: 1.1485 acc_val: 0.6528
Epoch: 0011 loss_train: 1.2008 acc_train: 0.5930 loss_val: 1.1218 acc_val: 0.6528
Epoch: 0012 loss_train: 1.1580 acc_train: 0.5893 loss_val: 1.0951 acc_val: 0.6528
Epoch: 0013 loss

In [57]:
def test():
    model.eval()
    output = model(features, adj)

    loss_test = F.nll_loss(output[idx_test], labels[idx_test])
    acc_test = accuracy(output[idx_test], labels[idx_test])
    
    print("Test set results:",
          "loss= {:.4f}".format(loss_test.item()),
          "accuracy= {:.4f}".format(acc_test.item()))

In [58]:
test()

Test set results: loss= 0.0848 accuracy= 0.9908
