In [29]:
from neo4j import GraphDatabase
from sklearn.manifold import TSNE
import numpy as np
import altair as alt
import pandas as pd

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [30]:
class Neo4jConnection:
    
    def __init__(self, uri, user, pwd):
        self.__uri = uri
        self.__user = user
        self.__pwd = pwd
        self.__driver = None
        try:
            self.__driver = GraphDatabase.driver(self.__uri, auth=(self.__user, self.__pwd))
        except Exception as e:
            print("Failed to create the driver:", e)
        
    def close(self):
        if self.__driver is not None:
            self.__driver.close()
        
    def query(self, query, parameters=None, db=None):
        assert self.__driver is not None, "Driver not initialized!"
        session = None
        response = None
        try: 
            session = self.__driver.session(database=db) if db is not None else self.__driver.session() 
            response = list(session.run(query, parameters))
        except Exception as e:
            print("Query failed:", e)
        finally: 
            if session is not None:
                session.close()
        return response

In [31]:
uri = 'neo4j://localhost:7687'
pwd = 'Covid-19KG'

conn = Neo4jConnection(uri=uri, user="neo4j", pwd=pwd)


In [32]:
conn.query('MATCH (n:State) RETURN COUNT(n) AS count')

[<Record count=51>]

In [33]:
query = '''MATCH (s:State)
           RETURN s.code AS state, s.embedding as embedding_state
'''

df = pd.DataFrame([dict(_) for _ in conn.query(query)])
df.head()

Unnamed: 0,state,embedding_state
0,AR,"[-0.012031819315977998, -0.01122531718244025, ..."
1,DC,"[-0.009264488459663477, -0.0088018267925782, -..."
2,DE,"[-0.024627740929851226, -0.023031867375497145,..."
3,FL,"[-0.01164724698277438, -0.011000245209659055, ..."
4,GA,"[-0.01542598052334032, -0.014219058976192005, ..."


In [34]:
len(df.embedding_state[0])

64

In [35]:
X_embedded = TSNE(n_components=2, random_state=6).fit_transform(list(df.embedding_state))

places = df.state
tsne_df = pd.DataFrame(data = {
    "state": places,
    "x": [value[0] for value in X_embedded],
    "y": [value[1] for value in X_embedded]
})
tsne_df.head()



Unnamed: 0,state,x,y
0,AR,-31.454567,7.59303
1,DC,-50.578136,17.923983
2,DE,-49.716061,53.160572
3,FL,-56.712261,69.686607
4,GA,-23.678928,78.954315


In [36]:
tsne_df.shape

(51, 3)

In [37]:
alt.Chart(tsne_df).mark_circle(size=60).encode(
    x='x',
    y='y',
    color='state',
    tooltip=['state']
).properties(width=700, height=400)