### Breadth-first search in Spark SQL
In the first course of the specialization, you implemented BFS (breadth-first search) using the RDD API. In this assignment, you will implement the same algorithm using the DataFrame API.

The point of this assignment is to see in practice how fast Spark SQL is and why it is now the default API in Spark.

Your goal is to compute the length of the shortest path between two vertices. This time, however, your implementation will be tested on a larger dataset. **Note that the answer differs from the earlier assignment's, because this graph is denser.**

It is instructive to recall how the algorithm was implemented in Spark Core:

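```python
# Assumes `sc` is an existing SparkContext.
def parse_edge(s):
  # Each line is "user<TAB>follower".
  user, follower = s.split("\t")
  return (int(user), int(follower))

def step(item):
  # (vertex, (distance, neighbour)) -> (neighbour, distance + 1)
  prev_v, prev_d, next_v = item[0], item[1][0], item[1][1]
  return (next_v, prev_d + 1)

def complete(item):
  # Keep the known distance; otherwise take the newly proposed one.
  v, old_d, new_d = item[0], item[1][0], item[1][1]
  return (v, old_d if old_d is not None else new_d)

n = 400  # number of partitions
edges = sc.textFile("/data/twitter/twitter_sample.txt").map(parse_edge).cache()
# Key each edge by the follower, so BFS moves from a follower to the user they follow.
forward_edges = edges.map(lambda e: (e[1], e[0])).partitionBy(n).persist()
x = 12  # start vertex
d = 0
distances = sc.parallelize([(x, d)]).partitionBy(n)
while True:
  candidates = distances.join(forward_edges, n).map(step)
  # The second argument of map() is preservesPartitioning=True.
  new_distances = distances.fullOuterJoin(candidates, n).map(complete, True).persist()
  # Vertices reached for the first time, at distance d + 1.
  count = new_distances.filter(lambda i: i[1] == d + 1).count()
  if count > 0:
    d += 1
    distances = new_distances
    print("d = ", d, "count = ", count)
  else:
    break
```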
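To make the loop concrete, here is a plain-Python sketch that replays it on a made-up five-edge graph. The graph and the helpers `group_by_key`, `rdd_join` and `rdd_full_outer_join` are hypothetical: they only imitate pair-RDD join semantics (one output row per matching pair of values, `None` for a missing side), not Spark's distributed execution. `step` and `complete` are copied unchanged.

```python
from collections import defaultdict

# Hypothetical toy graph as (user, follower) pairs, i.e. the output of parse_edge.
edges = [(1, 12), (2, 12), (3, 1), (3, 2), (34, 3)]
forward_edges = [(e[1], e[0]) for e in edges]  # follower -> user, as in the RDD code


def group_by_key(pairs):
    """Collect the values of a list of (key, value) pairs per key."""
    groups = defaultdict(list)
    for k, v in pairs:
        groups[k].append(v)
    return groups


def rdd_join(left, right):
    """Inner join with pair-RDD semantics: one output row per matching pair."""
    right_groups = group_by_key(right)
    return [(k, (v, w)) for k, v in left for w in right_groups.get(k, [])]


def rdd_full_outer_join(left, right):
    """Full outer join with pair-RDD semantics; a missing side becomes None."""
    left_groups, right_groups = group_by_key(left), group_by_key(right)
    return [(k, (v, w))
            for k in set(left_groups) | set(right_groups)
            for v in left_groups.get(k, [None])
            for w in right_groups.get(k, [None])]


def step(item):
    prev_v, prev_d, next_v = item[0], item[1][0], item[1][1]
    return (next_v, prev_d + 1)


def complete(item):
    v, old_d, new_d = item[0], item[1][0], item[1][1]
    return (v, old_d if old_d is not None else new_d)


x, d = 12, 0
distances = [(x, d)]
while True:
    candidates = [step(i) for i in rdd_join(distances, forward_edges)]
    new_distances = [complete(i) for i in rdd_full_outer_join(distances, candidates)]
    count = sum(1 for _, dist in new_distances if dist == d + 1)
    if count > 0:
        d += 1
        distances = new_distances
        print("d =", d, "count =", count, "rows =", len(distances))
    else:
        break
```

On this graph the loop stops with `d = 3` and vertex 34 at distance 3, but the distance table ends up with 9 rows for 5 vertices: vertex 3 is reachable via both 1 and 2, and because a full outer join emits one row per matching pair, its duplicates multiply on every iteration. Keeping a single candidate per vertex before the join (for example, its minimum distance) avoids this growth.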
    
Your goal is to implement the same algorithm using Spark SQL. **Keep in mind that you should avoid UDFs. If you are stuck, take a look at the `pyspark.sql.functions` module; you will definitely need it.**

Your task is to find the shortest path between vertices 12 and 34. If there are several shortest paths, any one of them will do, since they all have the same length. The output is a single number: the length of the shortest path.
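When checking a Spark SQL result on a small input, a single-machine reference can help. The sketch below is an ordinary queue-based BFS over the same `user<TAB>follower` lines, following edges from follower to user as the RDD code does. `bfs_shortest_path` and the five-line `sample` are hypothetical, not part of the assignment. The search stops as soon as the target is dequeued, so it returns the length of a shortest path, or `None` when the target is unreachable.

```python
from collections import defaultdict, deque


def parse_edge(s):
    """Parse a 'user<TAB>follower' line into an (int, int) pair."""
    user, follower = s.split("\t")
    return int(user), int(follower)


def bfs_shortest_path(lines, v_from, v_to):
    """Length of the shortest follower -> user path, or None if there is none."""
    neighbours = defaultdict(list)
    for user, follower in map(parse_edge, lines):
        neighbours[follower].append(user)  # same direction as forward_edges

    dist = {v_from: 0}
    queue = deque([v_from])
    while queue:
        v = queue.popleft()
        if v == v_to:
            return dist[v]
        for u in neighbours[v]:
            if u not in dist:  # the first visit is along a shortest path
                dist[u] = dist[v] + 1
                queue.append(u)
    return None


sample = ["1\t12", "2\t12", "3\t1", "3\t2", "34\t3"]
print(bfs_shortest_path(sample, 12, 34))  # prints 3
```

An open file object should also work as `lines` for a small local file, because `int()` ignores the trailing newline of each line.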

**The result on the sample dataset:**

`8`

```python
from pyspark.sql.types import StructType, StructField, IntegerType
# Note: `import *` also shadows Python builtins such as min, max and sum.
from pyspark.sql.functions import *
```

```python
from pyspark.sql import SparkSession

# master("local") runs Spark inside this process with a single worker thread.
spark = SparkSession.builder.enableHiveSupport().master("local").getOrCreate()
```

```python
# Each input line is "user<TAB>follower": to_v is the user, from_v the follower.
graph_schema = StructType([
    StructField("to_v", IntegerType(), False),
    StructField("from_v", IntegerType(), False)
])
```

```python
dist_schema = StructType([
    StructField("vertex", IntegerType(), False),
    StructField("distance", IntegerType(), False)
])
```

```python
# -1 disables automatic broadcast joins, so every join in the BFS is a shuffle join.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
```

```python
spark.read.csv("/data/twitter/twitter_sample2.txt", sep="\t", schema=graph_schema).show()
```

```
+----+------+
|to_v|from_v|
+----+------+
|  12|    18|
|  12|    41|
|  12|    57|
|  12|    62|
|  12|   235|
|  12|   278|
|  12|   291|
|  12|   338|
|  12|   456|
|  12|   614|
|  12|   648|
|  12|   711|
|  12|   875|
|  12|   988|
|  12|  1469|
|  12|  1507|
|  12|  1688|
|  12|  1974|
|  12|  2167|
|  12|  2241|
+----+------+
only showing top 20 rows
```



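```python
def shortest_path(v_from, v_to, dataset_path=None):
    # Each line is "user<TAB>follower"; as in the RDD version, BFS moves
    # from a follower (from_v) to the user they follow (to_v).
    edges = spark.read.csv(dataset_path, sep="\t", schema=graph_schema)
    edges.cache()

    if v_from == v_to:
        return 0

    distances = spark.createDataFrame([(v_from, 0)], dist_schema)
    path_length = 0

    # BFS: returns the length of the minimal path (a single integer) between
    # v_from and v_to, or None if v_to cannot be reached.
    while True:
        # Expand only the current frontier (vertices at distance path_length);
        # distinct() keeps one row per reached vertex, so the full outer join
        # below cannot duplicate rows.
        frontier = distances.where(distances.distance == path_length)
        candidates = (frontier
                      .join(edges, frontier.vertex == edges.from_v)
                      .select(edges.to_v.alias("vertex"),
                              (frontier.distance + 1).alias("distance"))
                      .distinct())

        # Keep the old distance of visited vertices and take the candidate
        # distance for new ones (the DataFrame analogue of `complete`).
        new_distances = (distances
                         .join(candidates, on="vertex", how="full_outer")
                         .select("vertex",
                                 coalesce(distances.distance,
                                          candidates.distance).alias("distance"))
                        ).persist()

        # Number of vertices reached for the first time in this step.
        count = new_distances.where(new_distances.distance == path_length + 1).count()
        if count == 0:
            return None  # the frontier is empty: v_to is unreachable

        path_length += 1
        distances.unpersist()  # the previous table is no longer needed
        distances = new_distances  # keep all visited vertices, as in the RDD version

        if new_distances.where(new_distances.vertex == v_to).count() > 0:
            return path_length
```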

```python
print(shortest_path(12, 34, "/data/twitter/twitter_sample2.txt"))
```