# Prediction with a Saved TensorFlow Classifier

In [1]:
import neo4j_bq as bq
import pandas as pd
import numpy as np
import neo4j_arrow as na
from google.cloud import bigquery
from tensorflow import keras
import tensorflow as tf
from graphdatascience import GraphDataScience

## Create a BigQuery Table to Hold Paper Subject Predictions

In [2]:
def bq_tbl_exists(client, table_ref):
    """Report whether ``table_ref`` can be fetched with ``client.get_table``.

    Returns False only when ``get_table`` raises ``NotFound``; any other
    exception it raises propagates to the caller.
    """
    from google.cloud.exceptions import NotFound

    try:
        client.get_table(table_ref)
    except NotFound:
        return False
    return True

In [3]:
# Fully-qualified (project.dataset.table) id of the predictions table, and
# the BigQuery client used to check for / create it.
table_id = "neo4jgraphconnectdemo2022.ogb_mag240m.paper_predictions3"
client = bigquery.Client()

In [4]:
# Every column of the predictions table is a REQUIRED integer.
schema = [
    bigquery.SchemaField(column, "INTEGER", mode="REQUIRED")
    for column in ("paper", "years", "flag", "predictedFlag")
]

# Reuse the predictions table if it already exists; otherwise create it.
if bq_tbl_exists(client, table_id):
    table = client.get_table(table_id)
    print(f"Found table {table.project}.{table.dataset_id}.{table.table_id}")
else:
    table = client.create_table(bigquery.Table(table_id, schema=schema))
    print(f"Created table {table.project}.{table.dataset_id}.{table.table_id}")

Created table neo4jgraphconnectdemo2022.ogb_mag240m.paper_predictions3


## Predict Paper Subjects (a.k.a. "Flags")

In [5]:
# The Neo4j password is the first line of pass.txt, surrounding whitespace stripped.
with open('pass.txt', 'r') as f:
    password = f.readline()
password = password.strip()

In [6]:
# Load the saved classifier. NOTE(review): predictions below call
# model.predict + argmax directly; probability_model (model followed by a
# Softmax layer, presumably turning logits into probabilities) is not used
# by predict_for_year.
model = keras.models.load_model('simple-paper-classifier-5')
probability_model = tf.keras.Sequential(
    [
        model,
        tf.keras.layers.Softmax(),
    ]
)

2022-06-11 15:57:17.549160: I tensorflow/core/common_runtime/process_util.cc:146] Creating new thread pool with default inter op setting: 2. Tune using inter_op_parallelism_threads for best performance.


In [7]:
# Neo4j connection settings -- replace with the actual URI and username.
# The password is read from pass.txt in the cell above.
CONNECTION_URI = "neo4j+s://demo2.graphconnect.app:7687"
USERNAME = "neo4j"

# Graph Data Science client instantiation
gds = GraphDataScience(CONNECTION_URI, auth=(USERNAME, password))

In [31]:
# Handle to the existing "gcdemo" projection; subgraphs are projected from it.
G = gds.graph.get("gcdemo")

In [9]:
def predict_for_year(year, concurrency=224):
    """Predict paper subjects ("flags") for one publication year and load them into BigQuery.

    Steps:
      1. Project a GDS subgraph named ``'proj'`` from ``G`` holding the
         ``Paper`` nodes whose ``years`` property equals ``year``.
      2. Read the ``graphEmbedding``, ``flag`` and ``years`` node properties
         of that subgraph over Arrow into a pandas DataFrame.
      3. Drop the subgraph projection -- always, even if step 2 fails.
      4. Run ``model`` on the stacked embeddings; the arg-max over axis 1
         becomes ``predictedFlag``.
      5. Load ``paper`` (renamed from ``nodeId``), ``flag``, ``years`` and
         ``predictedFlag`` into the BigQuery table ``table_id``.

    Relies on the notebook globals ``gds``, ``G``, ``password``, ``model``
    and ``table_id``.

    Args:
        year: Publication year to select; interpolated into the subgraph
            node filter.
        concurrency: Concurrency passed to the subgraph projection and the
            Arrow client. Defaults to 224, the previously hard-coded value.

    Returns:
        The number of papers predicted. Returns 0 without writing anything to
        BigQuery when the year has no papers.
    """
    # Sub-project this year's papers and read them. The projection name
    # 'proj' is fixed, so drop it in `finally` so a failed read does not
    # leave it behind for the next call.
    g, _ = gds.beta.graph.project.subgraph(
        'proj', G, f'n:Paper AND (n.years = {year})', '*', concurrency=concurrency
    )
    try:
        arrow_client = na.Neo4jArrowClient(
            'demo2.graphconnect.app', graph="proj", password=password,
            concurrency=concurrency,
        )
        dfs = [
            chunk.to_pandas()
            for chunk in arrow_client.read_nodes(["graphEmbedding", "flag", "years"])
        ]
    finally:
        g.drop()

    # pd.concat([]) and np.stack on an empty column both raise; a year with
    # no papers has nothing to predict or write.
    df = pd.concat(dfs, ignore_index=True) if dfs else pd.DataFrame()
    if df.empty:
        return 0

    # Predict paper subjects: one row of class scores per paper, arg-max
    # along axis 1 is the predicted flag.
    X = np.stack(df.graphEmbedding, axis=0)
    df["predictedFlag"] = np.argmax(model.predict(X, batch_size=100_000), axis=1)

    # Load the predictions into table_id.
    bq_client = bigquery.Client()
    predictions = df[["nodeId", "flag", "years", "predictedFlag"]].rename(
        columns={"nodeId": "paper"}
    )
    job = bq_client.load_table_from_dataframe(predictions, table_id)
    print(job.result())  # Wait for the job to complete.

    return df.shape[0]

In [10]:
# Predict every year from 1901 through 2021, keeping a running total of papers.
papers = 0
for i in range(1901, 2022):
    papers += predict_for_year(i)
    print(
        f'Finished year {i}. Predicted {papers} papers so far',
        '-------------------------------',
        sep='\n',
    )

LoadJob<project=neo4jgraphconnectdemo2022, location=US, id=631f559f-a36b-4c06-8cbd-0e1f4339b7f5>
Finished year 1901. Predicted 22578 papers so far
-------------------------------
LoadJob<project=neo4jgraphconnectdemo2022, location=US, id=07c8213f-2e9e-45a1-9e74-33a5bf0defe4>
Finished year 1902. Predicted 45854 papers so far
-------------------------------
LoadJob<project=neo4jgraphconnectdemo2022, location=US, id=b221560d-23de-48f3-833f-7de56697a5f0>
Finished year 1903. Predicted 70255 papers so far
-------------------------------
LoadJob<project=neo4jgraphconnectdemo2022, location=US, id=48c5df12-b3a7-406a-9cb9-57cd9e86d224>
Finished year 1904. Predicted 94543 papers so far
-------------------------------
LoadJob<project=neo4jgraphconnectdemo2022, location=US, id=b9db5262-b360-478c-8a5b-265d0f717d57>
Finished year 1905. Predicted 120390 papers so far
-------------------------------
LoadJob<project=neo4jgraphconnectdemo2022, location=US, id=37a4bef5-c731-40cb-a858-7e443cfe4c6b>
Finishe