# Spark GraphFrames

GraphFrames is a Spark [package](https://spark-packages.org/package/graphframes/graphframes) that aims to use DataFrames for representing and working with graph structures. For the examples you will use the same datasets as in the previous notebook, where you worked with the GraphX API.

Because GraphFrames is still not part of the official Spark distribution, you need to include it as a package. The easiest way to do this is to add `graphframes:graphframes:0.6.0-spark2.3-s_2.11` to the `spark.jars.packages` configuration parameter in the `/usr/local/spark/spark-defaults.conf` file.

Then start a Spark session using the following code:

In [1]:
import findspark
findspark.init()

from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("Spark course - GraphFrames").\
    master("local[*]").enableHiveSupport().getOrCreate()

sc = spark.sparkContext

Check your notebook's output to see if the GraphFrames package has been successfully downloaded.

## Creating GraphFrames

To create a `graphframes.GraphFrame`, you need to give it a DataFrame with vertices and a DataFrame with edges (you can find the official GraphFrames documentation [here](https://graphframes.github.io/graphframes/docs/_site/api/python/graphframes.html#graphframes.GraphFrame)). The vertices DataFrame needs a column named `id` containing the vertex IDs, plus any other columns with additional information you need. The edges DataFrame needs `src` and `dst` columns (containing the IDs of the source and destination vertices) and, again, any other columns you may need.
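As a plain-Python illustration of these column conventions, the hypothetical helper `check_graph` below stores the family graph used in this notebook as lists of dicts and reports rows that lack `id`, `src` or `dst`. The extra check that every edge endpoint is a known vertex is not something the text above requires; it is only a sanity check.

```python
# Hypothetical helper (not part of GraphFrames): checks the column conventions
# described above on plain-Python rows.
def check_graph(vertices, edges):
    """Return a list of problems found in the vertex and edge rows."""
    problems = []
    if any("id" not in v for v in vertices):
        problems.append("every vertex needs an 'id'")
    if any("src" not in e or "dst" not in e for e in edges):
        problems.append("every edge needs 'src' and 'dst'")
    # Sanity check only: each edge end should refer to an existing vertex
    ids = {v.get("id") for v in vertices}
    for e in edges:
        for end in ("src", "dst"):
            if end in e and e[end] not in ids:
                problems.append(f"edge {e} refers to unknown vertex {e[end]}")
    return problems

toy_vertices = [
    {"id": 1, "name": "Homer", "age": 39},
    {"id": 2, "name": "Marge", "age": 39},
    {"id": 3, "name": "Bart", "age": 12},
    {"id": 4, "name": "Milhouse", "age": 12},
]
toy_edges = [
    {"src": 1, "dst": 2, "relation": "marriedTo"},
    {"src": 3, "dst": 1, "relation": "father"},
    {"src": 3, "dst": 2, "relation": "mother"},
    {"src": 4, "dst": 3, "relation": "friend"},
]
print(check_graph(toy_vertices, toy_edges))  # → []
```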

You will now recreate the Simpsons family graph, this time using GraphFrames to represent it. Here is the graph figure again:
![graph](graph.png)

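Create the DataFrames with the `spark.createDataFrame` method, passing it lists of `Row` objects; the `Row` field names become the column names. (Alternatively, pass lists of tuples and name the columns with `toDF`.) Note that in Spark 2.x, `Row` fields given as keyword arguments are sorted alphabetically, which is why the columns appear in the order `age`, `id`, `name` in the output below.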

In [2]:
from pyspark.sql import Row
vertices = spark.createDataFrame([
    Row(id=1, name="Homer", age=39),
    Row(id=2, name="Marge", age=39),
    Row(id=3, name="Bart", age=12),
    Row(id=4, name="Milhouse", age=12) ])
edges = spark.createDataFrame([
    Row(src=1, dst=2, relation="marriedTo"),
    Row(src=3, dst=1, relation="father"),
    Row(src=3, dst=2, relation="mother"),
    Row(src=4, dst=3, relation="friend") ])

Now create a `GraphFrame` using these two DataFrames.

In [3]:
from graphframes import GraphFrame
graph = GraphFrame(vertices, edges)

Access the graph's vertices and edges and `show` them.

In [4]:
graph.vertices.show()
graph.edges.show()

+---+---+--------+
|age| id|    name|
+---+---+--------+
| 39|  1|   Homer|
| 39|  2|   Marge|
| 12|  3|    Bart|
| 12|  4|Milhouse|
+---+---+--------+

+---+---------+---+
|dst| relation|src|
+---+---------+---+
|  2|marriedTo|  1|
|  1|   father|  3|
|  2|   mother|  3|
|  3|   friend|  4|
+---+---------+---+



Find all the triplets in the graph and `show` them.

In [5]:
graph.triplets.show()

+-----------------+-----------------+--------------+
|              src|             edge|           dst|
+-----------------+-----------------+--------------+
|    [12, 3, Bart]|   [1, father, 3]|[39, 1, Homer]|
|[12, 4, Milhouse]|   [3, friend, 4]| [12, 3, Bart]|
|   [39, 1, Homer]|[2, marriedTo, 1]|[39, 2, Marge]|
|    [12, 3, Bart]|   [2, mother, 3]|[39, 2, Marge]|
+-----------------+-----------------+--------------+
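Conceptually, each triplet is an edge row joined with the full rows of its source and destination vertices, which is why every row above holds three structs. The plain-Python sketch below (an illustration, not GraphFrames code) performs that join with dictionary lookups on the family graph and, like an inner join, skips edges whose endpoints are missing.

```python
# Vertex rows keyed by ID: id -> (name, age)
people = {1: ("Homer", 39), 2: ("Marge", 39), 3: ("Bart", 12), 4: ("Milhouse", 12)}
family_edges = [(1, 2, "marriedTo"), (3, 1, "father"), (3, 2, "mother"), (4, 3, "friend")]

def triplets(vertices, edges):
    """Yield (src_vertex, edge, dst_vertex); edges with an unknown end are skipped."""
    for src, dst, relation in edges:
        if src in vertices and dst in vertices:
            yield (src, *vertices[src]), (src, dst, relation), (dst, *vertices[dst])

for t in triplets(people, family_edges):
    print(t)
```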



## Filtering GraphFrames

Use GraphFrames' motif finding to find all triples of vertices where there is an edge from the first to the second and from the second to the third, but no edge in the opposite direction (from the third to the second or from the second to the first).

In [6]:
motif = graph.find("(a)-[]->(b); (b)-[]->(c); !(c)-[]->(b); !(b)-[]->(a)")
motif.show()

+-----------------+--------------+--------------+
|                a|             b|             c|
+-----------------+--------------+--------------+
|    [12, 3, Bart]|[39, 1, Homer]|[39, 2, Marge]|
|[12, 4, Milhouse]| [12, 3, Bart]|[39, 2, Marge]|
|[12, 4, Milhouse]| [12, 3, Bart]|[39, 1, Homer]|
+-----------------+--------------+--------------+



Filter the results further, keeping only the sets whose middle vertex has an age of less than 30.

In [7]:
motif.filter("b.age < 30").show()

+-----------------+-------------+--------------+
|                a|            b|             c|
+-----------------+-------------+--------------+
|[12, 4, Milhouse]|[12, 3, Bart]|[39, 2, Marge]|
|[12, 4, Milhouse]|[12, 3, Bart]|[39, 1, Homer]|
+-----------------+-------------+--------------+
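To make the motif semantics concrete, here is a brute-force plain-Python version of the same pattern, followed by the `b.age < 30` filter, over the family graph's edge list. It is only a sketch of what the pattern means; GraphFrames evaluates motifs as DataFrame joins rather than nested loops.

```python
# Plain-Python sketch of "(a)-[]->(b); (b)-[]->(c); !(c)-[]->(b); !(b)-[]->(a)"
ages = {1: 39, 2: 39, 3: 12, 4: 12}
edge_set = {(1, 2), (3, 1), (3, 2), (4, 3)}

def find_chains(edge_set):
    """Return every (a, b, c) with a->b and b->c but no c->b and no b->a."""
    matches = []
    for a, b in sorted(edge_set):
        for b2, c in sorted(edge_set):
            if b2 == b and (c, b) not in edge_set and (b, a) not in edge_set:
                matches.append((a, b, c))
    return matches

chains = find_chains(edge_set)
# Equivalent of motif.filter("b.age < 30")
young_middle = [(a, b, c) for a, b, c in chains if ages[b] < 30]
print(chains)        # → [(3, 1, 2), (4, 3, 1), (4, 3, 2)]
print(young_middle)  # → [(4, 3, 1), (4, 3, 2)]
```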



Find the vertex with the most edges going into it. Also find the name of the associated person.

In [8]:
import pyspark.sql.functions as F
graph.inDegrees.orderBy(F.col('inDegree').desc()).show()

+---+--------+
| id|inDegree|
+---+--------+
|  2|       2|
|  1|       1|
|  3|       1|
+---+--------+



In [9]:
graph.vertices.filter("id = 2").show()

+---+---+-----+
|age| id| name|
+---+---+-----+
| 39|  2|Marge|
+---+---+-----+



Find the vertex with the most outgoing edges and the name of the associated person (you can do it in two steps).

In [10]:
graph.outDegrees.orderBy(F.col('outDegree').desc()).show()

+---+---------+
| id|outDegree|
+---+---------+
|  3|        2|
|  1|        1|
|  4|        1|
+---+---------+



In [11]:
graph.vertices.filter(F.col('id') == 3).show()

+---+---+----+
|age| id|name|
+---+---+----+
| 12|  3|Bart|
+---+---+----+
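Both two-step lookups can be checked with a small plain-Python sketch that counts edge endpoints with `collections.Counter`. As in the `inDegrees` output above, a vertex with no incoming edges (Milhouse) simply does not appear in the in-degree counts.

```python
from collections import Counter

names = {1: "Homer", 2: "Marge", 3: "Bart", 4: "Milhouse"}
family_pairs = [(1, 2), (3, 1), (3, 2), (4, 3)]  # (src, dst)

in_degrees = Counter(dst for _, dst in family_pairs)
out_degrees = Counter(src for src, _ in family_pairs)

# most_common(1) returns [(vertex_id, degree)] for the highest count
top_in, _ = in_degrees.most_common(1)[0]
top_out, _ = out_degrees.most_common(1)[0]
print(names[top_in], names[top_out])  # → Marge Bart
```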



Create a subgraph containing only the two vertices you found and display its triplets (create a new `GraphFrame`). Does it have the expected number of edges?

In [12]:
newGraph = GraphFrame(graph.vertices.filter((F.col('id') == 3) | (F.col('id') == 2)), edges)
newGraph.triplets.show()

+-------------+--------------+--------------+
|          src|          edge|           dst|
+-------------+--------------+--------------+
|[12, 3, Bart]|[2, mother, 3]|[39, 2, Marge]|
+-------------+--------------+--------------+
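The `GraphFrame` above still holds all four rows of `edges`, because the constructor does not filter them; `newGraph.triplets` shows only one edge because triplets are built with inner joins against the reduced vertex set. A minimal plain-Python sketch of a proper induced subgraph, which keeps an edge only when both of its endpoints are kept:

```python
family_edges = [(1, 2, "marriedTo"), (3, 1, "father"), (3, 2, "mother"), (4, 3, "friend")]

def induced_subgraph(keep, edges):
    """Keep only the edges whose source and destination are both in `keep`."""
    return [e for e in edges if e[0] in keep and e[1] in keep]

sub_edges = induced_subgraph({2, 3}, family_edges)
print(sub_edges)  # → [(3, 2, 'mother')]
```

In the notebook, the equivalent edge filter would be `graph.edges.filter(F.col('src').isin(2, 3) & F.col('dst').isin(2, 3))`, passed to `GraphFrame` together with the filtered vertices.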



## Graph algorithms

To run graph algorithms, let's switch again to the “Human Wayfinding in Information Networks” dataset.

Load articles and links with the following code:

In [13]:
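# Keep non-empty, non-comment lines and give each article a numeric ID
articles = sc.textFile("../first-edition/ch09/articles.tsv").\
    filter(lambda line: line.strip() != "" and not line.startswith("#")).\
    zipWithIndex()
links = sc.textFile("../first-edition/ch09/links.tsv").\
    filter(lambda line: line.strip() != "" and not line.startswith("#"))

def split_link(line):
    # Each link line is "<source article>\t<target article>"
    fields = line.split('\t')
    return (fields[0], fields[1])

# Join on the source name to get its ID, then on the target name:
# (src_name, dst_name) -> (dst_name, src_id) -> (src_id, dst_id)
linkIndexes = links.map(split_link).join(articles).map(lambda x: x[1]).\
    join(articles).map(lambda x: x[1])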
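The two chained joins replace article names with IDs, first for the link source and then for the link target. A plain-Python sketch of the same mapping, using dictionary lookups instead of RDD joins and a few made-up lines in place of the real TSV files:

```python
# Made-up stand-ins for the contents of articles.tsv and links.tsv
article_lines = ["# comment", "", "14th_century", "Rainbow", "Zebra"]
link_lines = ["# comment", "14th_century\tRainbow", "Rainbow\tZebra"]

def is_data(line):
    return line.strip() != "" and not line.startswith("#")

# zipWithIndex equivalent: article name -> position among the kept lines
article_ids = {name: i for i, name in enumerate(line for line in article_lines if is_data(line))}

# The two joins: map the source name, then the target name, to its ID.
# Like an inner join, links naming an unknown article are dropped.
link_ids = [
    (article_ids[src], article_ids[dst])
    for src, dst in (line.split("\t") for line in link_lines if is_data(line))
    if src in article_ids and dst in article_ids
]
print(link_ids)  # → [(0, 1), (1, 2)]
```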

Unlike `Graph` in GraphX, `GraphFrame` has no `fromEdges` method, so you cannot create a `GraphFrame` from an edge DataFrame alone. First convert `linkIndexes` to a DataFrame, then create another DataFrame with all vertex IDs (you only need the `id` column), and use both to create a `GraphFrame`. (Hint: to build the list of all vertex IDs, try the DataFrame's `union` method.)

In [14]:
wedges = spark.createDataFrame(linkIndexes, ["src", "dst"])
# Vertex IDs: every article ID that appears at either end of a link
wvertices = wedges.select(wedges.src.alias('id')).\
    union(wedges.select(wedges.dst.alias('id'))).distinct()

In [15]:
wikigraph = GraphFrame(wvertices, wedges)
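The union-plus-distinct step amounts to taking the set of all IDs seen at either end of an edge, as in this plain-Python sketch with made-up pairs. One consequence worth keeping in mind: an article with no incoming or outgoing links cannot appear in a vertex set built this way.

```python
link_ids = [(0, 1), (1, 2), (2, 0), (3, 1)]  # made-up (src, dst) pairs

# union of both edge ends; a set already removes duplicates (the "distinct")
vertex_ids = {src for src, _ in link_ids} | {dst for _, dst in link_ids}
print(sorted(vertex_ids))  # → [0, 1, 2, 3]
```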

Check your results by comparing the number of vertices in the graph and the number of unique article IDs in `linkIndexes`.

In [16]:
print(wikigraph.vertices.count())
print(linkIndexes.map(lambda x: x[0]).union(linkIndexes.map(lambda x: x[1])).distinct().count())

4592
4592


## Shortest paths

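GraphFrames also has a shortest paths algorithm, which works the same way as the GraphX one.

Run the [shortest paths algorithm](http://graphframes.github.io/api/scala/index.html#org.graphframes.GraphFrame@shortestPaths:org.graphframes.lib.ShortestPaths) to find the shortest path between vertices with IDs 10 ("14th century") and 3425 ("Rainbow"). Keep in mind that the algorithm takes edge direction into account: for every vertex it computes the length of the shortest path from that vertex to each landmark. With 10 as the landmark, the value reported for vertex 3425 is therefore the distance from "Rainbow" to "14th century".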

In [17]:
shortest = wikigraph.shortestPaths([10])
shortest.show()

+----+---------+
|  id|distances|
+----+---------+
|  26|[10 -> 1]|
|  29|[10 -> 3]|
| 474|[10 -> 3]|
| 964|[10 -> 4]|
|1677|[10 -> 3]|
|1697|[10 -> 3]|
|1806|[10 -> 3]|
|1950|[10 -> 3]|
|2040|[10 -> 3]|
|2214|[10 -> 3]|
|2250|[10 -> 2]|
|2453|[10 -> 2]|
|2509|[10 -> 2]|
|2529|[10 -> 3]|
|2927|[10 -> 3]|
|3091|[10 -> 2]|
|3506|[10 -> 3]|
|3764|[10 -> 3]|
|4590|[10 -> 3]|
|  65|[10 -> 3]|
+----+---------+
only showing top 20 rows



In [18]:
shortest.filter(shortest.id == 3425).show()

+----+---------+
|  id|distances|
+----+---------+
|3425|[10 -> 2]|
+----+---------+



GraphFrames offers another algorithm for finding shortest paths between vertices: breadth-first search (the `bfs` method). Unlike the shortest paths algorithm, BFS also returns the paths themselves. Because the start and end vertices are given as expressions, you can also select them using other vertex columns.

Find the shortest path between the two vertices used above using BFS. (Hint: because of a bug in the Python GraphFrames interface, pass the "from" and "to" expressions as strings in SQL syntax.)

In [47]:
bfsres = wikigraph.bfs("id = 10", "id = 3425")

In [20]:
bfsres.show()

+----+----------+------+------------+------+------------+------+
|from|        e0|    v1|          e1|    v2|          e2|    to|
+----+----------+------+------------+------+------------+------+
|[10]|[10, 1316]|[1316]|[1316, 4413]|[4413]|[4413, 3425]|[3425]|
|[10]|  [10, 12]|  [12]|  [12, 4413]|[4413]|[4413, 3425]|[3425]|
|[10]|[10, 1385]|[1385]|[1385, 4413]|[4413]|[4413, 3425]|[3425]|
|[10]|[10, 4147]|[4147]|[4147, 2165]|[2165]|[2165, 3425]|[3425]|
|[10]|[10, 4147]|[4147]|[4147, 2466]|[2466]|[2466, 3425]|[3425]|
+----+----------+------+------------+------+------------+------+
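The BFS output above contains five different paths, each with three edges, from vertex 10 to vertex 3425. That is consistent with the distance of 2 that `shortestPaths` reported for vertex 3425, because that algorithm measured the distance from 3425 to 10, while `bfs` searched from 10 to 3425. The plain-Python sketch below, which assumes we want every minimum-length path, expands the search one level at a time and stops at the first level that reaches the target:

```python
from collections import defaultdict

def all_shortest_paths(edges, source, target):
    """Return every minimum-length directed path from source to target."""
    adjacency = defaultdict(list)
    for src, dst in edges:
        adjacency[src].append(dst)
    paths = [[source]]  # all paths of the current length
    visited = {source}
    while paths:
        found = [p for p in paths if p[-1] == target]
        if found:
            return found
        next_paths = [p + [n] for p in paths for n in adjacency[p[-1]] if n not in visited]
        # Mark a level as visited only after expanding it, so several
        # paths may reach the same vertex at the same depth.
        visited.update(p[-1] for p in next_paths)
        paths = next_paths
    return []

diamond = [("s", "a"), ("s", "b"), ("s", "c"), ("a", "t"), ("b", "t"), ("c", "d"), ("d", "t")]
print(all_shortest_paths(diamond, "s", "t"))  # → [['s', 'a', 't'], ['s', 'b', 't']]
```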



## Page rank

With GraphFrames you can run PageRank either until convergence (controlled by the tolerance parameter, as in GraphX) or for a fixed number of iterations; set exactly one of the `tol` and `maxIter` parameters.

Run it now with a tolerance (the `tol` parameter) of `0.01`.

In [21]:
pr = wikigraph.pageRank(tol=0.01)

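Examine the results. `pageRank` returns a new `GraphFrame` whose vertices have an additional `pagerank` column and whose edges have a `weight` column holding each edge's normalized weight (1 divided by the out-degree of the source vertex).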

In [22]:
pr.vertices.show()

+----+-------------------+
|  id|           pagerank|
+----+-------------------+
|  26|  2.038833865442534|
|  29|0.22148529739398057|
| 474|  2.278494867415354|
| 964|0.20566204895247073|
|1677|  1.154281917375591|
|1697| 0.2116434552836052|
|1806| 0.5611711984548275|
|1950| 0.1779560833591355|
|2040|0.18777907129744795|
|2214|  1.629150846099081|
|2250| 0.2032200347509332|
|2453|0.30756654346523216|
|2509|0.17133627281842367|
|2529|0.20786035245720705|
|2927|0.31619285501737315|
|3091| 0.2707326884211503|
|3506|  0.686664930156517|
|3764|  2.011237418669906|
|4590|0.19047425242692329|
|  65|0.17133627281842367|
+----+-------------------+
only showing top 20 rows



In [23]:
pr.edges.show()

+---+----+--------------------+
|src| dst|              weight|
+---+----+--------------------+
| 15|  19| 0.01098901098901099|
| 18|3758|0.037037037037037035|
| 24|1999|0.005988023952095809|
| 31|2053|               0.025|
| 32|2109|0.014285714285714285|
| 33|4234|0.017857142857142856|
| 37|3934|0.047619047619047616|
| 38| 594|0.010752688172043012|
| 38|3244|0.010752688172043012|
| 49|1086|0.047619047619047616|
| 51|1888|0.043478260869565216|
| 62|  15|0.022222222222222223|
| 74|1648| 0.07142857142857142|
| 97|2631| 0.02857142857142857|
|112|2605|              0.0625|
|125| 255| 0.02127659574468085|
|125|2444| 0.02127659574468085|
|128|1734|0.004716981132075...|
|130| 590|                 0.1|
|166|2417|0.015873015873015872|
+---+----+--------------------+
only showing top 20 rows



Which vertex has the highest rank? Which page does it correspond to (query the `articles` RDD)?

In [24]:
pr.vertices.orderBy(F.col('pagerank').desc()).show()

+----+------------------+
|  id|          pagerank|
+----+------------------+
|4297|42.941000938544825|
|1568| 28.83841348768163|
|1433| 28.52832394499783|
|4293| 27.95337082105467|
|1389| 21.95177052208199|
|1694|21.636324780265458|
|4542| 21.22829131687344|
|1385|20.022338324280703|
|2417| 19.85243244856757|
|2098|18.167974877814252|
|2226|17.476972659146497|
|2183| 16.72314650409159|
|3829|16.364722838771037|
| 894|16.053461713596366|
|3567|15.713729140392598|
|4148|15.661706730685271|
| 768|15.475989279371419|
|1101|14.666711238945627|
| 393|14.342080418591092|
| 128|14.287240443354486|
+----+------------------+
only showing top 20 rows



In [25]:
articles.filter(lambda x: x[1] == 4297).collect()

[('United_States', 4297)]
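A plain-Python power-iteration sketch of tolerance-based PageRank follows. It assumes the unnormalized update PR(v) = resetProb + (1 - resetProb) * sum of PR(u) / outDegree(u) over the in-neighbours u, with the default reset probability of 0.15, so the scores are not probabilities summing to 1. It is not the distributed implementation GraphFrames runs, and it simply drops the rank held by vertices without outgoing edges.

```python
from collections import Counter, defaultdict

def pagerank(edges, reset_prob=0.15, tol=1e-6, max_iter=1000):
    """Repeat PR(v) = reset_prob + (1 - reset_prob) * sum(PR(u) / outdeg(u))
    until the largest change in any score drops below `tol`."""
    vertices = {v for edge in edges for v in edge}
    out_degree = Counter(src for src, _ in edges)
    ranks = {v: 1.0 for v in vertices}
    for _ in range(max_iter):
        incoming = defaultdict(float)
        for src, dst in edges:
            # each edge carries weight 1 / outdeg(src)
            incoming[dst] += ranks[src] / out_degree[src]
        new_ranks = {v: reset_prob + (1 - reset_prob) * incoming[v] for v in vertices}
        delta = max(abs(new_ranks[v] - ranks[v]) for v in vertices)
        ranks = new_ranks
        if delta < tol:  # converged
            break
    return ranks

ranks = pagerank([(1, 0), (2, 0), (0, 1)])
print(max(ranks, key=ranks.get))  # → 0
```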

## Connected components

Connected components work much like the GraphX version: in Python you just call the `connectedComponents` method with no arguments. Before running the algorithm, set Spark's checkpoint directory to `/home/spark/checkpoint`, because the default GraphFrames implementation requires one.

Do that now and examine the results. Do they match those from the last notebook?

In [26]:
sc.setCheckpointDir("/home/spark/checkpoint")
cc = wikigraph.connectedComponents()

In [27]:
cc.select(cc['component']).distinct().count()

2

How many vertices are there in each connected component?

In [28]:
cc.groupBy(cc.component).count().show()

+---------+-----+
|component|count|
+---------+-----+
|        0| 4589|
|     1210|    3|
+---------+-----+
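Connected components ignore edge direction, so a union-find structure over the edge list is enough to reproduce the idea on a single machine. In this plain-Python sketch each component is labeled with its smallest vertex ID; that labeling is the sketch's own choice.

```python
def connected_components(vertices, edges):
    """Map each vertex to the smallest vertex ID in its (weakly) connected component."""
    vertices = list(vertices)
    parent = {v: v for v in vertices}

    def find(v):
        while parent[v] != v:
            parent[v] = parent[parent[v]]  # path halving
            v = parent[v]
        return v

    for src, dst in edges:  # edge direction is ignored
        root_a, root_b = find(src), find(dst)
        if root_a != root_b:
            # keep the smaller ID as the root so it can serve as the label
            parent[max(root_a, root_b)] = min(root_a, root_b)
    return {v: find(v) for v in vertices}

components = connected_components(range(6), [(1, 0), (2, 1), (4, 3)])
print(components)  # → {0: 0, 1: 0, 2: 0, 3: 3, 4: 3, 5: 5}
```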



## Strongly connected components

Strongly connected components also work much like the GraphX version. Try it with 100 iterations by passing `maxIter=100` to `stronglyConnectedComponents`.

In [29]:
scc = wikigraph.stronglyConnectedComponents(maxIter=100)

How many strongly connected components did the algorithm find?

In [30]:
scc.select(scc.component).distinct().count()

519

Which components contain the most vertices?

In [31]:
scc.groupBy(scc.component).count().orderBy(F.col('count').desc()).show()

+---------+-----+
|component|count|
+---------+-----+
|        6| 4051|
|     2488|    6|
|     1831|    3|
|     2142|    2|
|     1986|    2|
|     2251|    2|
|      892|    2|
|     1834|    2|
|     1950|    2|
|     1513|    2|
|      557|    2|
|     1976|    2|
|      477|    2|
|     2160|    2|
|      195|    2|
|     1111|    2|
|     4224|    2|
|     2474|    2|
|     2321|    2|
|     1202|    1|
+---------+-----+
only showing top 20 rows
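GraphFrames computes strongly connected components with an iterative distributed algorithm, which is why it takes `maxIter`. On a single machine the same partition can be found with Tarjan's algorithm, a different (depth-first, sequential) technique. The recursive plain-Python sketch below suits small graphs only; on a graph the size of the wikispeedia one, deep recursion could hit Python's recursion limit.

```python
def strongly_connected_components(vertices, edges):
    """Tarjan's algorithm: return a list of SCCs, each a set of vertices."""
    vertices = list(vertices)
    adjacency = {v: [] for v in vertices}
    for src, dst in edges:
        adjacency[src].append(dst)

    index, lowlink, on_stack = {}, {}, set()
    stack, components = [], []
    counter = 0

    def visit(v):
        nonlocal counter
        index[v] = lowlink[v] = counter
        counter += 1
        stack.append(v)
        on_stack.add(v)
        for w in adjacency[v]:
            if w not in index:
                visit(w)
                lowlink[v] = min(lowlink[v], lowlink[w])
            elif w in on_stack:
                lowlink[v] = min(lowlink[v], index[w])
        if lowlink[v] == index[v]:  # v is the root of an SCC
            component = set()
            while True:
                w = stack.pop()
                on_stack.discard(w)
                component.add(w)
                if w == v:
                    break
            components.append(component)

    for v in vertices:
        if v not in index:
            visit(v)
    return components

sccs = strongly_connected_components(range(1, 6), [(1, 2), (2, 3), (3, 1), (3, 4), (4, 5)])
print(sorted(sorted(c) for c in sccs))  # → [[1, 2, 3], [4], [5]]
```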



## Triangle count

The triangle count algorithm returns a DataFrame with the number of triangles passing through each vertex; edge direction is ignored. Run it on the wikispeedia graph and examine the results.

In [32]:
tcnt = wikigraph.triangleCount()

In [33]:
tcnt.show()

+-----+----+
|count|  id|
+-----+----+
|  449|  26|
|   12|  29|
|  481| 474|
|    3| 964|
|  641|1677|
|    9|1697|
|  174|1806|
|   49|1950|
|   11|2040|
|  460|2214|
|   16|2250|
|   34|2453|
|  328|2509|
|   21|2529|
|   48|2927|
|  132|3091|
|   71|3506|
|  278|3764|
|   16|4590|
|  130|  65|
+-----+----+
only showing top 20 rows
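GraphFrames' triangle count ignores edge direction, so the count for a vertex equals the number of pairs of its neighbours that are themselves connected. A plain-Python sketch over a made-up edge list, in which reciprocal links such as a->b and b->a collapse into one undirected edge:

```python
from collections import defaultdict
from itertools import combinations

def triangle_counts(edges):
    """Number of triangles through each vertex, treating edges as undirected."""
    neighbors = defaultdict(set)
    for src, dst in edges:
        if src != dst:  # ignore self-loops
            neighbors[src].add(dst)
            neighbors[dst].add(src)
    counts = {}
    for v, nbrs in neighbors.items():
        # every pair of connected neighbours closes a triangle through v
        counts[v] = sum(1 for a, b in combinations(nbrs, 2) if b in neighbors[a])
    return counts

print(triangle_counts([(1, 2), (2, 3), (3, 1), (3, 4), (4, 1)]))  # → {1: 2, 2: 1, 3: 2, 4: 1}
```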



## Label propagation

Label propagation is an algorithm for detecting communities in networks. The result of the algorithm is a DataFrame containing the original vertex columns plus one additional column called `label`, corresponding to the detected community.

Run it now on the wikispeedia graph with a maximum of 20 iterations and examine the results.

In [34]:
wikilabels = wikigraph.labelPropagation(maxIter=20)
wikilabels.show()

+----+-----+
|  id|label|
+----+-----+
|  26| 3021|
|  29| 3021|
| 474| 3021|
| 964| 3021|
|1677| 3021|
|1697| 3021|
|1806| 3021|
|1950| 3021|
|2040| 3021|
|2214| 3021|
|2250| 3021|
|2453| 3021|
|2509| 3021|
|2529| 3021|
|2927| 3021|
|3091| 3021|
|3506| 3021|
|3764| 3021|
|4590| 3021|
|  65| 3021|
+----+-----+
only showing top 20 rows



Find the number of vertices with each label and compare the results with the connected components above.

In [35]:
wikilabels.groupBy(wikilabels.label).count().orderBy(F.col('count').desc()).show()

+-----+-----+
|label|count|
+-----+-----+
| 3021| 4589|
| 3849|    2|
| 1600|    1|
+-----+-----+
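Per the GraphFrames documentation, every vertex starts in its own community and, at each step, adopts the most frequent label among its neighbours' labels. The plain-Python sketch below follows that description with synchronous updates; breaking ties by the smallest label is an assumption made here for determinism, not documented GraphFrames behaviour.

```python
from collections import Counter, defaultdict

def label_propagation(vertices, edges, max_iter=20):
    """Synchronous label propagation; each vertex starts with its own ID as label."""
    vertices = list(vertices)
    neighbors = defaultdict(list)
    for src, dst in edges:  # labels flow both ways along each edge
        neighbors[src].append(dst)
        neighbors[dst].append(src)
    labels = {v: v for v in vertices}
    for _ in range(max_iter):
        new_labels = {}
        for v in vertices:
            counts = Counter(labels[n] for n in neighbors[v])
            if not counts:  # an isolated vertex keeps its label
                new_labels[v] = labels[v]
                continue
            best = max(counts.values())
            # tie-break by the smallest label (this sketch's own choice)
            new_labels[v] = min(label for label, c in counts.items() if c == best)
        if new_labels == labels:
            break
        labels = new_labels
    return labels

# Two separate triangles: {1, 2, 3} and {4, 5, 6}
communities = label_propagation(range(1, 7), [(1, 2), (2, 3), (3, 1), (4, 5), (5, 6), (6, 4)])
print(communities)  # → {1: 1, 2: 1, 3: 1, 4: 4, 5: 4, 6: 4}
```

Synchronous updates do not always settle: on a single edge between two vertices the labels swap every round, which is one reason the algorithm is run with a maximum number of iterations.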



## Aggregating messages

`aggregateMessages` is a low-level primitive for implementing various graph algorithms. Its `sendToDst` and `sendToSrc` arguments are Column expressions that define the messages sent to the destination and/or source vertex of each triplet in the graph, and the `aggCol` expression aggregates all the messages each vertex receives.

Try using `aggregateMessages` on the Simpsons family graph to calculate the number of friends for each vertex (the number of edges with relation `friend` going into the vertex). You need to provide `sendToDst` and `aggCol` "functions" as Column definitions. To construct these Column definitions, you can use Spark SQL functions and the `src`, `edge`, `dst`, and `msg` fields of `graphframes.lib.AggregateMessages`.

The result of `aggregateMessages` will be a DataFrame with two columns (vertex `id` and the aggregation result), with generic column names. You will need to join this DataFrame with the original vertices to add the new column to the data (you should rename the generic column names first). 

In [44]:
from graphframes.lib import AggregateMessages as AM
msgToDst = F.when(AM.edge["relation"] == "friend", 1).otherwise(0)
aggres = graph.aggregateMessages(sendToDst=msgToDst, aggCol=F.sum(AM.msg)).toDF("aggid", "friends")
aggres.show()

+-----+-------+
|aggid|friends|
+-----+-------+
|    1|      0|
|    3|      1|
|    2|      0|
+-----+-------+



In [51]:
graph.vertices.join(aggres, F.col('id') == F.col("aggid")).select("id", 'name', 'age', 'friends').show()

+---+-----+---+-------+
| id| name|age|friends|
+---+-----+---+-------+
|  1|Homer| 39|      0|
|  3| Bart| 12|      1|
|  2|Marge| 39|      0|
+---+-----+---+-------+
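A plain-Python sketch of what the call above computes: one message per edge goes to the edge's destination, and the messages are then summed per vertex. The helper `aggregate_messages` is hypothetical and only mirrors the `sendToDst`/`aggCol` idea.

```python
from collections import defaultdict

names = {1: "Homer", 2: "Marge", 3: "Bart", 4: "Milhouse"}
family_edges = [(1, 2, "marriedTo"), (3, 1, "father"), (3, 2, "mother"), (4, 3, "friend")]

def aggregate_messages(edges, send_to_dst, agg):
    """Send one message per edge to its destination, then aggregate per vertex."""
    inbox = defaultdict(list)
    for edge in edges:
        inbox[edge[1]].append(send_to_dst(edge))
    return {v: agg(msgs) for v, msgs in inbox.items()}

friends = aggregate_messages(
    family_edges,
    send_to_dst=lambda e: 1 if e[2] == "friend" else 0,  # like F.when(...).otherwise(0)
    agg=sum,                                              # like F.sum(AM.msg)
)
print(friends)  # → {2: 0, 1: 0, 3: 1}

# Vertices that received no message get a default of 0
with_defaults = {names[v]: friends.get(v, 0) for v in names}
print(with_defaults)  # → {'Homer': 0, 'Marge': 0, 'Bart': 1, 'Milhouse': 0}
```

Milhouse (ID 4) receives no messages, so he is absent from `friends`, just as he is missing from `aggres` and from the joined table above. In the notebook, passing `"left_outer"` as the `how` argument of `join` and replacing the resulting nulls, for example with `F.coalesce(F.col('friends'), F.lit(0))`, would keep him with 0 friends.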

