# Neo4j sample with graphdatascience package <a target="_blank" href="https://colab.research.google.com/github/yWorks/yfiles-jupyter-graphs/blob/main/examples/27_neo4j-sample-gds_example.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This example notebook requires the following packages (install them with `pip` or whichever package manager you use):

In [None]:
# Install the yFiles graph widget plus the Neo4j driver and GDS client into this kernel's environment.
%pip install yfiles_jupyter_graphs --quiet
%pip install neo4j graphdatascience --quiet

You can also open this notebook in Google Colab when Google Colab's custom widget manager is enabled:

In [None]:
# Enable Colab's custom widget manager so the yFiles widget can render there.
# Outside Colab the import fails with ImportError and there is nothing to enable;
# any other error is a real problem and is no longer silently swallowed.
try:
  from google.colab import output
  output.enable_custom_widget_manager()
except ImportError:
  pass

<a target="_blank" href="https://colab.research.google.com/github/yWorks/yfiles-jupyter-graphs/blob/main/examples/27_neo4j-sample-gds_example.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

The queries in this notebook are written for Neo4j's airport (air routes) example dataset. If your database has the Graph Data Science (GDS) library enabled and contains that dataset, just fill in your credentials in the next cell and give it a try!

Insert your database connection credentials here.

In [None]:
import os

from neo4j import basic_auth

# Connection settings are read from the environment when available, so real
# credentials never have to be typed into (and saved with) the notebook.
# The fallbacks are placeholders: replace them or set the variables.
db = os.environ.get("NEO4J_URI", "neo4j+s://yourdatabaseid.databases.neo4j.io")
auth = basic_auth(
  os.environ.get("NEO4J_USERNAME", "username"),
  os.environ.get("NEO4J_PASSWORD", "password"),
)


In [None]:
from neo4j import GraphDatabase
from yfiles_jupyter_graphs import GraphWidget

# Open a driver for the configured database; `driver` stays open for later cells.
driver = GraphDatabase.driver(db, auth=auth)

# Fetch at most 25 paths starting at an Airport, skipping the HAS_ROUTE
# relationships, and collect them into a neo4j Graph object.
with driver.session(database="neo4j") as session:
  query_result = session.run("MATCH triple=(a:Airport)-[r]->() WHERE type(r) <> 'HAS_ROUTE'  RETURN triple LIMIT 25")
  graph = query_result.graph()

GraphWidget(graph=graph)

In [None]:
# Reuse the `driver` opened in the previous cell: creating a second driver here
# would leave the first one (and its connection pool) open and unreachable.
with driver.session(database="neo4j") as session:
  graph = session.run("MATCH p=()-[r]->() WHERE type(r) <> 'HAS_ROUTE'  RETURN p LIMIT 25").graph()

from yfiles_jupyter_graphs import GraphWidget

# Per-label node appearance. "label" names the node property shown as its text.
styles = {
    "Airport": {"color":"#6C7400", "shape":"ellipse", "label":"iata"},
    "City": {"color":"#005977", "shape":"rectangle", "label":"name"},
    "Region": {"color":"#386664", "shape":"rectangle", "label":"name"},
    "Country": {"color":"#498381", "shape":"octagon", "label":"code"},
    "Continent": {"color":"#254241", "shape":"hexagon", "label":"name"}
}

def airport_node_label(node):
    """Return the display text for a widget node.

    Uses the property named by the node's style entry (e.g. "iata" for an
    Airport). Nodes with an unstyled label, and styled nodes that lack the
    configured property, fall back to the node's label instead of raising
    KeyError.
    """
    properties = node["properties"]
    key = styles.get(properties["label"], {"label": "label"})["label"]
    return properties.get(key, properties["label"])

w = GraphWidget(graph = graph)
# IN_COUNTRY relationships are drawn orange, all others black.
w.set_edge_color_mapping(lambda edge : "orange" if edge["properties"]["label"] == "IN_COUNTRY" else "black")
w.set_node_styles_mapping(lambda node : styles.get(node["properties"]["label"], {}))
w.set_node_label_mapping(airport_node_label)

display(w)

In [None]:
def createAirportGraph(neo4jGraph, styles=None):
  """Build a styled GraphWidget for a graph from the airport dataset.

  Nodes are colored and shaped per label using `styles`. IN_COUNTRY edges
  are drawn orange and all other edges black. Each node's text is the
  property named by its style entry.

  Args:
    neo4jGraph: graph passed straight to GraphWidget, e.g. the result of
      `session.run(...).graph()`.
    styles: optional mapping of node label -> style dict with "color",
      "shape" and "label" (the name of the property to display). Defaults to
      the airport dataset styles below.

  Returns:
    The configured GraphWidget (rendered when it is a cell's last expression).
  """
  if styles is None:
    styles = {
      "Airport": {"color":"#6C7400", "shape":"ellipse", "label":"iata"},
      "City": {"color":"#005977", "shape":"rectangle", "label":"name"},
      "Region": {"color":"#386664", "shape":"rectangle", "label":"name"},
      "Country": {"color":"#498381", "shape":"octagon", "label":"code"},
      "Continent": {"color":"#254241", "shape":"hexagon", "label":"name"}
    }

  def node_label(node):
    # Fall back to the node's label when the configured property is missing,
    # instead of raising KeyError inside the widget's mapping callback.
    properties = node["properties"]
    key = styles.get(properties["label"], {"label": "label"})["label"]
    return properties.get(key, properties["label"])

  w = GraphWidget(graph = neo4jGraph)

  w.set_edge_color_mapping(lambda edge : "orange" if edge["properties"]["label"] == "IN_COUNTRY" else "black")
  w.set_node_styles_mapping(lambda node : styles.get(node["properties"]["label"], {}))
  w.set_node_label_mapping(node_label)

  return w

In [None]:
from graphdatascience import GraphDataScience

# Use Neo4j URI and credentials according to your setup
gds = GraphDataScience(db, auth=auth)

# Drop a projection left over from an earlier run so this cell can be re-run:
# projecting onto an existing graph name fails.
if gds.graph.exists("air-routes")["exists"]:
  gds.graph.drop(gds.graph.get("air-routes"))

G_air, results = gds.graph.project("air-routes", "Airport", 
                                   {"HAS_ROUTE": {"orientation": "NATURAL", "properties": ["distance"]}})

source_id = gds.find_node_id(["Airport"], {"iata": "STR"})
target_id = gds.find_node_id(["Airport"], {"iata": "ART"})
result = gds.shortestPath.dijkstra.stream(G_air, sourceNode = source_id, targetNode = target_id, relationshipWeightProperty = "distance")
if result.empty:
  raise ValueError("No route found from STR to ART in the 'air-routes' projection")

# Node ids on the (single) returned path, matched against id(...) below;
# converted to plain ints so the driver receives native Python values.
path_ids = [int(node_id) for node_id in result.nodeIds.iloc[0]]

with driver.session(database="neo4j") as session:
  graph = session.run("""
          MATCH (sco:Country)<-[scor:IN_COUNTRY]-(sr:Region)<-[srr:IN_REGION]-(sc:City)<-[scr:IN_CITY]-
              (s:Airport)-[r:HAS_ROUTE]->(t:Airport)
              -[tcr:IN_CITY]->(tc:City)-[trr:IN_REGION]->(tr:Region)-[tcor:IN_COUNTRY]->(tco:Country) 
          WHERE id(s) in $ids AND id(t) in $ids 
          RETURN * LIMIT 100
      """, {"ids":path_ids}).graph()


createAirportGraph(graph)

In [None]:
import math
from graphdatascience import GraphDataScience
from matplotlib import cm
import matplotlib.pyplot as plt
from matplotlib.colors import to_hex

gds = GraphDataScience(db, auth=auth)

# Replace any existing "air-routes" projection with an UNDIRECTED one for
# PageRank. Skipping the drop when the graph is absent lets this cell run
# on its own as well.
if gds.graph.exists("air-routes")["exists"]:
  gds.graph.drop(gds.graph.get("air-routes"))

G_air, results = gds.graph.project("air-routes", "Airport", 
                                   {"HAS_ROUTE": {"orientation": "UNDIRECTED"}})

# Store PageRank on the projection, then stream it back as a table.
gds.pageRank.mutate(G_air, mutateProperty='pageRank')
pageRankTable = gds.graph.streamNodeProperties(G_air, ['pageRank'])

# Only airports whose PageRank exceeds this value are displayed.
cutoff = 1.5

filteredTable = pageRankTable[pageRankTable.propertyValue > cutoff]
if filteredTable.empty:
  raise ValueError(f"No airport has a PageRank above cutoff={cutoff}; lower the cutoff.")
ranks = dict(zip(filteredTable.nodeId, filteredTable.propertyValue))

with driver.session(database="neo4j") as session:
  graph = session.run("""
      MATCH (s)
        WHERE id(s) in $ids 
      RETURN s
    """, {"ids":filteredTable.nodeId.tolist()}).graph()


w = GraphWidget(graph = graph)

w.set_node_label_mapping(lambda node : node["properties"]["iata"])


def custom_scale_mapping(node):
    # NOTE(review): assumes the widget's node id equals the GDS node id used as
    # the key of `ranks`; confirm for the installed yfiles/neo4j versions.
    rank = ranks[node.get('id')]
    return 1 + math.sqrt(rank - cutoff + 1) if rank > cutoff else 1
w.node_scale_factor_mapping = custom_scale_mapping

maxRank = filteredTable.propertyValue.max()
# matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap
# returns the same colormap resampled to `lut` entries.
viridis = plt.get_cmap('viridis', round(maxRank))

def color_mapping(node):
    rank = ranks[node.get('id')]
    return to_hex(viridis(rank/maxRank))
w.node_color_mapping = color_mapping

def location_mapping(node): 
    # NOTE(review): `location` is presumably a point indexable as
    # [longitude, latitude]; the second coordinate is negated, presumably so
    # that north is drawn at the top.
    location = node["properties"]["location"]
    return [location[0] * 150, location[1] * -200]
w.node_position_mapping = location_mapping
w.graph_layout = "organic_edge_router"

display(w)