### Imports

In [None]:
import json
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from py2neo import Graph, Node, Relationship
from sklearn.metrics import roc_curve, auc
from sklearn.model_selection import cross_val_score, cross_val_predict
from sklearn.tree import DecisionTreeClassifier

### Connect to graph

In [None]:
# Open a Bolt connection to the Neo4j server and report its size.
# NOTE(review): host and credentials are hard-coded; consider reading them from
# configuration or environment variables before sharing this notebook.
graph = Graph("bolt://neo4j-quanta:7687", auth=('neo4j','myneo'))

# Read the counts once. In py2neo v4 `Database.primitive_counts` is a property that
# (presumably on every access) runs a JMX query, so reading it twice meant two round
# trips and node/relationship figures possibly taken at different moments.
primitive_counts = graph.database.primitive_counts
print("Connected to graph database with {:,} nodes and {:,} relationships!".format(
    primitive_counts['NumberOfNodeIdsInUse'],
    primitive_counts['NumberOfRelationshipIdsInUse']))

### Build dataset

In [None]:
# Study window: papers are drawn from [start_year, end_year - years_to_use], so each
# one has `years_to_use` follow-up calendar years of columns available up to end_year.
years_to_use = 3
start_year = 2009
end_year = 2013

print("Getting dataset...", end=" ")
calendar_years = range(start_year, end_year + 1)

# c<year>: NULL for calendar years before the paper's publication year, otherwise the
# number of incoming CITES relationships from Quanta nodes whose `year` is <year>.
cites_str = ',\n    '.join(
    'CASE WHEN {} < q.year THEN NULL ELSE SIZE((q)<-[:CITES]-(:Quanta {{year: {}}})) END as c{}'.format(yr, yr, yr)
    for yr in calendar_years
)
# tspr<year>: the paper's `tspr<year>` node property, one column per calendar year.
tspr_str = ',\n    '.join('q.tspr{} as tspr{}'.format(yr, yr) for yr in calendar_years)

# NOTE(review): LIMIT without ORDER BY returns an arbitrary subset of the matching papers.
query_template = """
MATCH (q:Quanta)
WHERE 
    (q.doctype='Journal') AND 
    (q.lang='en') AND 
    EXISTS(q.fos) AND 
    (q.year>={} AND q.year <= {}) 
RETURN
    q.year as year,
    q.title as title,
    q.id as id,
    {},
    {}
LIMIT 1000
"""
last_publication_year = end_year - years_to_use
query = query_template.format(start_year, last_publication_year, tspr_str, cites_str)
print(query)

query_start_time = time.time()
df = graph.run(query).to_data_frame()
elapsed_minutes = (time.time() - query_start_time) / 60
print("Done ({:.2f} minutes).".format(elapsed_minutes))

In [None]:
# Re-index the per-calendar-year columns of `df` relative to each paper's publication
# year: c<i> / p<i> hold the c<year+i> / tspr<year+i> value, i.e. i years after publication.
# .copy() makes df_new an independent frame; assigning new columns into a plain column
# slice of df triggers pandas' SettingWithCopyWarning and may not behave as intended.
df_new = df[['year', 'title', 'id']].copy()

# Convert the rows once to plain dicts. Dict lookups per row are much cheaper than a
# DataFrame.apply(axis=1) pass per output column, and an empty df simply yields empty
# columns (apply on an empty frame returns a DataFrame, which cannot be assigned).
records = df.to_dict('records')
for i in range(years_to_use + 1):
    # int() guards against the year arriving as a float (e.g. 2009.0 -> 'c2009.0').
    df_new['c{}'.format(i)] = [rec['c{}'.format(int(rec['year']) + i)] for rec in records]
    df_new['p{}'.format(i)] = [rec['tspr{}'.format(int(rec['year']) + i)] for rec in records]

In [None]:
def balanced_subsample(x, y, subsample_size=1.0):
    """Undersample every class to the size of the smallest class.

    Args:
        x: Samples indexed along the first axis. Any array-like is accepted;
            pandas objects are converted with ``np.asarray`` so rows are
            selected positionally.
        y: One class label per row of ``x``. Labels must be numeric or boolean,
            because they are returned as floats.
        subsample_size: If < 1, keep only this fraction (rounded down) of the
            smallest class size for every class; values >= 1 keep the full size.

    Returns:
        ``(xs, ys)`` as NumPy arrays. ``xs`` holds the selected rows grouped by
        class in ``np.unique(y)`` order. ``ys`` is a float array repeating each
        class label. Classes larger than the target size are shuffled with the
        global ``np.random`` state before truncation. Classes that already have
        exactly the target size keep their original order. Empty input returns
        ``(x[:0], empty float array)``.
    """
    # np.random.shuffle on a DataFrame swaps x[i], which is a *column* lookup and
    # raises KeyError. Working on ndarrays gives positional row semantics.
    x_arr = np.asarray(x)
    y_arr = np.asarray(y)

    classes = np.unique(y_arr)
    if classes.size == 0:
        # Nothing to balance; avoid min() of an empty sequence / np.concatenate([]).
        return x_arr[:0], np.empty(0)

    # Boolean-mask indexing returns copies, so shuffling below never mutates the caller's x.
    class_xs = [(label, x_arr[y_arr == label]) for label in classes]

    use_elems = min(elems.shape[0] for _, elems in class_xs)
    if subsample_size < 1:
        use_elems = int(use_elems * subsample_size)

    xs = []
    ys = []
    for label, elems in class_xs:
        if len(elems) > use_elems:
            np.random.shuffle(elems)  # shuffles rows in place along the first axis
        xs.append(elems[:use_elems])
        ys.append(np.full(use_elems, label, dtype=float))
    return np.concatenate(xs), np.concatenate(ys)

In [None]:
# Sanity-check the dataset dimensions as (rows, columns). DataFrame.size is
# rows * columns (a cell count), which hides how many papers were loaded.
df_new.shape

In [None]:
# Features: citation counts c<i> and scores p<i> for offsets 0 .. years_to_use-1
# (years after publication).
# Label: True when the score at offset years_to_use is at or above the sample's
# 90th percentile.
# Offset years_to_use is deliberately left out of the features. The label is a
# threshold on p<years_to_use> itself, so including that column (and the same-year
# citation count c<years_to_use>) leaks the target into X and inflates any score.
# NOTE(review): c<i>/p<i> may contain NaN if a node lacks a tspr property; older
# scikit-learn trees reject NaN input. Confirm whether such rows need dropping.
cs = ['c{}'.format(x) for x in range(years_to_use)]
ps = ['p{}'.format(x) for x in range(years_to_use)]
feature_cols = cs + ps
X = df_new[feature_cols]

label_col = 'p{}'.format(years_to_use)
y = df_new[label_col] >= df_new[label_col].quantile(0.9)

In [None]:
# Undersample each class of (X, y) down to the size of the smallest class.
X_bal, y_bal = balanced_subsample(x=X, y=y, subsample_size=1.0)


In [None]:
# Mean 10-fold cross-validated score (accuracy, the classifier default) of an
# unpruned decision tree on the class-balanced sample.
# Earlier experiments tried other setups: a label from p6, and features restricted
# to citations only (c1-c3), scores only (p1-p3), or both.
clf = DecisionTreeClassifier(random_state=0)
fold_scores = cross_val_score(clf, X_bal, y_bal, cv=10)
fold_scores.mean()

In [None]:
# Number of labelled rows before class balancing.
len(y)

In [None]:
def perf_measure(y_actual, y_hat):
    """Count confusion-matrix cells for binary 0/1 labels.

    Elements are paired positionally, so pandas Series with any index work. A
    prediction equal to 1 counts as a true positive when the actual value is 1,
    otherwise as a false positive. A prediction equal to 0 counts as a true
    negative when the actual value is 0, otherwise as a false negative.
    Predictions that are neither 0 nor 1 are ignored. Booleans compare as 1/0.

    Args:
        y_actual: Sized iterable of ground-truth labels.
        y_hat: Sized iterable of predicted labels, same length as ``y_actual``.

    Returns:
        Tuple ``(TP, FP, TN, FN)`` of ints.

    Raises:
        ValueError: If the two inputs have different lengths.
    """
    if len(y_actual) != len(y_hat):
        raise ValueError(
            "y_actual and y_hat differ in length: {} vs {}".format(len(y_actual), len(y_hat)))

    tp = fp = tn = fn = 0
    # zip pairs values by position; indexing a Series with y_actual[i] is label-based
    # and picks the wrong element (or raises KeyError) for a non-range index.
    for actual, predicted in zip(y_actual, y_hat):
        if predicted == 1:
            if actual == 1:
                tp += 1
            else:
                fp += 1
        elif predicted == 0:
            if actual == 0:
                tn += 1
            else:
                fn += 1
    return tp, fp, tn, fn

In [None]:
# ROC curve of the decision tree on the class-balanced sample.
# The previous version was copied from a multi-class example. It referenced names never
# defined in this notebook (n_classes, y_test, y_score) and plotted class index 2 of a
# binary problem. Here every balanced sample is scored out-of-fold with the same 10-fold
# CV used for cross_val_score, and a single binary ROC curve is drawn.
# predict_proba columns follow clf.classes_ (sorted), so column 1 is the larger label,
# presumably 1.0 = "top 10%" when y_bal comes from balanced_subsample.
y_score = cross_val_predict(clf, X_bal, y_bal, cv=10, method='predict_proba')[:, 1]
fpr, tpr, _ = roc_curve(y_bal, y_score)
roc_auc = auc(fpr, tpr)

plt.figure()
lw = 2
plt.plot(fpr, tpr, color='darkorange',
         lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)
plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')  # chance diagonal
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic example')
plt.legend(loc="lower right")
plt.show()