In [1]:
from graphframes import *
from pyspark import SparkContext

In [2]:
from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()

In [3]:
spark

In [4]:
from pyspark.sql.types import * 
from graphframes import *

In [5]:
def create_transport_graph(): 
    node_fields = [
        StructField("id", StringType(), True),
        StructField("latitude", FloatType(), True),
        StructField("longitude", FloatType(), True),
        StructField("population", IntegerType(), True)
    ]
    nodes = spark.read.csv("data/transport/transport-nodes.csv", header=True,
                           schema=StructType(node_fields))
    rels = spark.read.csv("data/transport/transport-relationships.csv", header=True)
    reversed_rels = (rels.withColumn("newSrc", rels.dst)
                     .withColumn("newDst", rels.src)
                     .drop("dst", "src")
                     .withColumnRenamed("newSrc", "src")
                     .withColumnRenamed("newDst", "dst")
                     .select("src", "dst", "relationship", "cost"))
    relationships = rels.union(reversed_rels) 
    return GraphFrame(nodes, relationships)

g = create_transport_graph()

#### BFS 

In [6]:
(g.vertices
     .filter("population > 100000 and population < 300000")
     .sort("population")
     .show())

+----------+--------+---------+----------+
|        id|latitude|longitude|population|
+----------+--------+---------+----------+
|Colchester|51.88921|  0.90421|    104390|
|   Ipswich|52.05917|  1.15545|    133384|
+----------+--------+---------+----------+



In [7]:
from_expr = "id='Den Haag'"
to_expr = "population > 100000 and population < 300000 and id <> 'Den Haag'"
result = g.bfs(from_expr, to_expr)
    
print(result.columns)

['from', 'e0', 'v1', 'e1', 'v2', 'e2', 'to']


In [8]:
result.show()

+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|                from|                  e0|                  v1|                  e1|                  v2|                  e2|                  to|
+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|[Den Haag, 52.078...|[Den Haag, Hoek v...|[Hoek van Holland...|[Hoek van Holland...|[Felixstowe, 51.9...|[Felixstowe, Ipsw...|[Ipswich, 52.0591...|
|[Den Haag, 52.078...|[Den Haag, Hoek v...|[Hoek van Holland...|[Hoek van Holland...|[Felixstowe, 51.9...|[Felixstowe, Ipsw...|[Ipswich, 52.0591...|
|[Den Haag, 52.078...|[Den Haag, Hoek v...|[Hoek van Holland...|[Hoek van Holland...|[Felixstowe, 51.9...|[Felixstowe, Ipsw...|[Ipswich, 52.0591...|
|[Den Haag, 52.078...|[Den Haag, Hoek v...|[Hoek van Holland...|[Hoek van Holland...|[Felixstowe, 51.9...|

In [9]:
columns = [column for column in result.columns if not column.startswith("e")] 
result.select(columns).show()

+--------------------+--------------------+--------------------+--------------------+
|                from|                  v1|                  v2|                  to|
+--------------------+--------------------+--------------------+--------------------+
|[Den Haag, 52.078...|[Hoek van Holland...|[Felixstowe, 51.9...|[Ipswich, 52.0591...|
|[Den Haag, 52.078...|[Hoek van Holland...|[Felixstowe, 51.9...|[Ipswich, 52.0591...|
|[Den Haag, 52.078...|[Hoek van Holland...|[Felixstowe, 51.9...|[Ipswich, 52.0591...|
|[Den Haag, 52.078...|[Hoek van Holland...|[Felixstowe, 51.9...|[Ipswich, 52.0591...|
|[Den Haag, 52.078...|[Hoek van Holland...|[Felixstowe, 51.9...|[Ipswich, 52.0591...|
|[Den Haag, 52.078...|[Hoek van Holland...|[Felixstowe, 51.9...|[Ipswich, 52.0591...|
|[Den Haag, 52.078...|[Hoek van Holland...|[Felixstowe, 51.9...|[Ipswich, 52.0591...|
|[Den Haag, 52.078...|[Hoek van Holland...|[Felixstowe, 51.9...|[Ipswich, 52.0591...|
+--------------------+--------------------+-----------

#### Shortest Path (weighted)
* Message passing via AggregateMessages

In [10]:
from graphframes.lib import AggregateMessages as AM 
from pyspark.sql import functions as F

In [11]:
add_path_udf = F.udf(lambda path, id: path + [id], ArrayType(StringType()))
sc = spark.sparkContext

In [12]:
def shortest_path(g, origin, destination, column_name="cost"):
    if g.vertices.filter(g.vertices.id == destination).count() == 0:
        return
    
    vertices = (g.vertices.withColumn("visited", F.lit(False))
                .withColumn("distance", F.when(g.vertices["id"] == origin, 0)
                .otherwise(float("inf")))
                .withColumn("path", F.array()))
    cached_vertices = AM.getCachedDataFrame(vertices)
    g2 = GraphFrame(cached_vertices, g.edges)

    print("hello")
    while g2.vertices.filter('visited == False').first(): 
        current_node_id = g2.vertices.filter('visited == False').sort("distance").first().id
        msg_distance = AM.edge[column_name] + AM.src['distance']
        msg_path = add_path_udf(AM.src["path"], AM.src["id"])
        msg_for_dst = F.when(AM.src['id'] == current_node_id, F.struct(msg_distance, msg_path))
        new_distances = g2.aggregateMessages(F.min(AM.msg).alias("aggMess"), sendToDst=msg_for_dst)
    
        new_visited_col = F.when(
                g2.vertices.visited | (g2.vertices.id == current_node_id),
                                                    True).otherwise(False)
        
        new_distance_col = F.when(new_distances["aggMess"].isNotNull() & \
                                  (new_distances.aggMess["col1"] < g2.vertices.distance),\
                                  new_distances.aggMess["col1"]).otherwise(g2.vertices.distance)
            
        new_path_col = F.when(new_distances["aggMess"].isNotNull() &
                           (new_distances.aggMess["col1"]
                           < g2.vertices.distance), new_distances.aggMess["col2"]
                           .cast("array<string>")).otherwise(g2.vertices.path)
        
        new_vertices = (g2.vertices.join(new_distances, on="id", how="left_outer")
                            .drop(new_distances["id"])
                            .withColumn("visited", new_visited_col)
                            .withColumn("newDistance", new_distance_col)
                            .withColumn("newPath", new_path_col)
                            .drop("aggMess", "distance", "path")
                            .withColumnRenamed('newDistance', 'distance')
                            .withColumnRenamed('newPath', 'path'))
        cached_new_vertices = AM.getCachedDataFrame(new_vertices)
        g2 = GraphFrame(cached_new_vertices, g2.edges)
        
        print(current_node_id, destination, g2.vertices.filter(g2.vertices.id == destination).first().visited)
        if g2.vertices.filter(g2.vertices.id == destination).first().visited:
            return (g2.vertices.filter(g2.vertices.id == destination)
                    .withColumn("newPath", add_path_udf("path", "id"))
                    .drop("visited", "path")
                    .withColumnRenamed("newPath", "path"))

In [13]:
result = shortest_path(g, origin = "Amsterdam", destination = "Colchester", column_name="cost")
if result:
    result.select("id", "distance", "path").show(truncate=False)

hello
Amsterdam Colchester False
Utrecht Colchester False
Den Haag Colchester False
Gouda Colchester False
Rotterdam Colchester False
Hoek van Holland Colchester False
Felixstowe Colchester False
Ipswich Colchester False
Colchester Colchester True
+----------+--------+------------------------------------------------------------------------+
|id        |distance|path                                                                    |
+----------+--------+------------------------------------------------------------------------+
|Colchester|347.0   |[Amsterdam, Den Haag, Hoek van Holland, Felixstowe, Ipswich, Colchester]|
+----------+--------+------------------------------------------------------------------------+



#### Shortest Path Landmark (unweighted)

In [14]:
result = g.shortestPaths(["Colchester", "Immingham", "Hoek van Holland"])
result.sort(["id"]).select("id", "distances").show(truncate=False)

+----------------+--------------------------------------------------------+
|id              |distances                                               |
+----------------+--------------------------------------------------------+
|Amsterdam       |[Immingham -> 1, Hoek van Holland -> 2, Colchester -> 4]|
|Colchester      |[Colchester -> 0, Immingham -> 3, Hoek van Holland -> 3]|
|Den Haag        |[Hoek van Holland -> 1, Immingham -> 2, Colchester -> 4]|
|Doncaster       |[Immingham -> 1, Colchester -> 2, Hoek van Holland -> 4]|
|Felixstowe      |[Hoek van Holland -> 1, Colchester -> 2, Immingham -> 4]|
|Gouda           |[Hoek van Holland -> 2, Immingham -> 3, Colchester -> 5]|
|Hoek van Holland|[Hoek van Holland -> 0, Immingham -> 3, Colchester -> 3]|
|Immingham       |[Immingham -> 0, Colchester -> 3, Hoek van Holland -> 3]|
|Ipswich         |[Colchester -> 1, Hoek van Holland -> 2, Immingham -> 4]|
|London          |[Colchester -> 1, Immingham -> 2, Hoek van Holland -> 4]|
|Rotterdam  

#### Single Source Shortest Path

In [15]:
def sssp(g, origin, column_name="cost"): 
    vertices = g.vertices \
            .withColumn("visited", F.lit(False)) \
            .withColumn("distance",
                F.when(g.vertices["id"] == origin, 0).otherwise(float("inf"))) \
            .withColumn("path", F.array())
    cached_vertices = AM.getCachedDataFrame(vertices)
    g2 = GraphFrame(cached_vertices, g.edges)
    while g2.vertices.filter('visited == False').first():
        current_node_id = g2.vertices.filter('visited == False').sort("distance").first().id
        
        msg_distance = AM.edge[column_name] + AM.src['distance']
        msg_path = add_path_udf(AM.src["path"], AM.src["id"])
        msg_for_dst = F.when(AM.src['id'] == current_node_id, F.struct(msg_distance, msg_path))
        new_distances = g2.aggregateMessages(F.min(AM.msg).alias("aggMess"), sendToDst=msg_for_dst)
        
        new_visited_col = F.when(g2.vertices.visited | (g2.vertices.id == current_node_id), True).otherwise(False)
        
        new_distance_col = F.when(new_distances["aggMess"].isNotNull() &
                                      (new_distances.aggMess["col1"] <
                                      g2.vertices.distance),
                                      new_distances.aggMess["col1"]) \
                                      .otherwise(g2.vertices.distance)
        new_path_col = F.when(new_distances["aggMess"].isNotNull() &
                                  (new_distances.aggMess["col1"] <
                                  g2.vertices.distance),
                                  new_distances.aggMess["col2"]
                                  .cast("array<string>")) \
                                  .otherwise(g2.vertices.path)
        new_vertices = g2.vertices.join(new_distances, on="id",
                                            how="left_outer") \
                .drop(new_distances["id"]) \
                .withColumn("visited", new_visited_col) \
                .withColumn("newDistance", new_distance_col) \
                .withColumn("newPath", new_path_col) \
                .drop("aggMess", "distance", "path") \
                .withColumnRenamed('newDistance', 'distance') \
                .withColumnRenamed('newPath', 'path')
        cached_new_vertices = AM.getCachedDataFrame(new_vertices)
        g2 = GraphFrame(cached_new_vertices, g2.edges)
        
    return g2.vertices \
            .withColumn("newPath", add_path_udf("path", "id")) \
            .drop("visited", "path") \
            .withColumnRenamed("newPath", "path")

In [16]:
via_udf = F.udf(lambda path: path[1:-1], ArrayType(StringType()))

result = sssp(g, "Amsterdam", "cost")

(result
 .withColumn("via", via_udf("path"))
 .select("id", "distance", "via")
 .sort("distance")
 .show(truncate=False))

+----------------+--------+-------------------------------------------------------------+
|id              |distance|via                                                          |
+----------------+--------+-------------------------------------------------------------+
|Amsterdam       |0.0     |[]                                                           |
|Utrecht         |46.0    |[]                                                           |
|Den Haag        |59.0    |[]                                                           |
|Gouda           |81.0    |[Utrecht]                                                    |
|Rotterdam       |85.0    |[Den Haag]                                                   |
|Hoek van Holland|86.0    |[Den Haag]                                                   |
|Felixstowe      |293.0   |[Den Haag, Hoek van Holland]                                 |
|Ipswich         |315.0   |[Den Haag, Hoek van Holland, Felixstowe]                     |
|Colcheste