In [1]:
from pyspark.sql.types import StructType, StructField, IntegerType
from pyspark.sql.functions import *

In [2]:
from pyspark.sql import SparkSession
spark = SparkSession.builder.enableHiveSupport().master("local[2]").getOrCreate()

In [3]:
graph_schema = StructType([
    StructField("to", IntegerType(), False),
    StructField("from", IntegerType(), False)
])

In [4]:
dist_schema = StructType([
    StructField("vertex", IntegerType(), False),
    StructField("distance", IntegerType(), False)
])

In [None]:
def parse_edge(s):
  user, follower = s.split("\t")
  return (int(user), int(follower))

def step(item):
  prev_v, prev_d, next_v = item[0], item[1][0], item[1][1]
  return (next_v, prev_d + 1)

def complete(item):
  v, old_d, new_d = item[0], item[1][0], item[1][1]
  return (v, old_d if old_d is not None else new_d)

n = 400  # number of partitions
edges = sc.textFile("/data/twitter/twitter_sample.txt").map(parse_edge).cache()
forward_edges = edges.map(lambda e: (e[1], e[0])).partitionBy(n).persist()

x = 12
d = 0
distances = sc.parallelize([(x, d)]).partitionBy(n)
while True:
  candidates = distances.join(forward_edges, n).map(step)
  new_distances = distances.fullOuterJoin(candidates, n).map(complete, True).persist()
  count = new_distances.filter(lambda i: i[1] == d + 1).count()
  if count > 0:
    d += 1
    distances = new_distances
    print("d = ", d, "count = ", count)
  else:
    break

In [22]:
edges = spark.read.csv("/data/twitter/twitter_sample2.txt", sep="\t", schema=graph_schema).cache().alias("e")

In [12]:
edges.show(2)

+---+----+
| to|from|
+---+----+
| 12|  18|
| 12|  41|
+---+----+
only showing top 2 rows



In [30]:
distances = edges.where("from=12").withColumn("distance",lit(0)).withColumnRenamed("to","to_d").withColumnRenamed("from","from_d")

In [33]:
new_distances = distances.join(edges,distances["to_d"] == edges["from"])

In [34]:
new_distances.show(2)

+----+------+--------+---+----+
|to_d|from_d|distance| to|from|
+----+------+--------+---+----+
| 380|    12|       0|690| 380|
| 380|    12|       0| 31| 380|
+----+------+--------+---+----+
only showing top 2 rows



In [5]:
def shortest_path(v_from, v_to, dataset_path=None):
    edges = spark.read.csv(dataset_path, sep="\t", schema=graph_schema)
    path_length = None
    
    # Here you should implement the BFS algorithm. It should return the length
    # of the minimal path (single integer) between v_from and v_to
    
    return path_length

In [6]:
print shortest_path(12, 34, "/data/twitter/twitter_sample2.txt")

None
