In [1]:
!pip install pyspark



In [2]:
import random
SEED = 42  # seeds the stdlib RNG used for graph generation and the start-node choice
random.seed(SEED)

In [3]:
from pyspark.sql import SparkSession

# Get the active SparkSession, or create one with the default configuration.
spark = SparkSession.builder.getOrCreate()

# Low-level RDD entry point; the graph cells below work with RDDs directly.
sc = spark.sparkContext
sc

In [4]:
# Random directed test graph. To sanity-check results by hand, substitute a small
# fixed graph instead, e.g. edges = [(1, 2, 10), (2, 3, 3), (2, 4, 24), (3, 2, 1)].
N = 10  # node ids are drawn from 1..N
C = 50  # edge weights are drawn from 1..C
# N*N candidate (src, dst, weight) edges; parallel (duplicate) edges are allowed.
candidate_edges = [(random.randint(1, N), random.randint(1, N), random.randint(1, C)) for _ in range(N * N)]
# Drop self-loops with a filter. Calling list.remove() while iterating over the
# same list skips the element after each removal, so consecutive self-loops survived.
edges = [(src, dst, weight) for (src, dst, weight) in candidate_edges if src != dst]
# Every node that is an endpoint of a remaining edge. Comprehensions keep the loop
# variables out of the notebook namespace (no stray globals such as `_`).
nodes = {node for (src, dst, _) in edges for node in (src, dst)}
nodes_rdd = sc.parallelize(sorted(nodes))
edges_rdd = sc.parallelize(edges)

In [5]:
# Start from the source node of a random edge, so the start has at least one outgoing edge.
# Index rather than `start_node, *_ = ...`: that form assigns the global `_`,
# shadowing IPython's last-output variable.
start_node = random.choice(edges)[0]
# Initial distances: 0 for the start node, infinity for every other node.
start_values = [(node, 0 if node == start_node else float("inf")) for node in nodes_rdd.collect()]
start_values_rdd = sc.parallelize(start_values)

In [6]:
def to_send(edges, nodes_values):
  """Build the relaxation messages for one round.

  The current node values are collected to the driver and broadcast; then every
  edge (src, dst, weight) emits the message (dst, value[src] + weight).

  Args:
    edges: RDD of (src, dst, weight) tuples.
    nodes_values: RDD of (node, value) pairs. It must contain every edge source,
      otherwise evaluating the returned RDD raises KeyError.

  Returns:
    Lazy RDD of (dst, candidate_value) pairs.
  """
  # Take the SparkContext from the input RDD instead of the notebook-global `sc`,
  # so this function does not silently depend on hidden kernel state.
  # NOTE(review): a new broadcast is created on every call and never unpersisted;
  # it cannot be released here because the returned RDD is lazy and still reads it.
  current = edges.context.broadcast(nodes_values.collectAsMap())
  return edges.map(lambda edge: (edge[1], current.value[edge[0]] + edge[2]))

In [7]:
# One manual message-building step, kept for inspection; single_source_shortest_path repeats it.
msg = to_send(edges=edges_rdd, nodes_values=start_values_rdd)

In [8]:
def send_msgs(msgs, nodes_values):
  """Merge incoming messages into node values, keeping the smallest per node.

  Args:
    msgs: RDD of (node, candidate_value) pairs, e.g. the output of to_send.
    nodes_values: RDD of (node, current_value) pairs.

  Returns:
    RDD with one (node, value) entry per key of nodes_values: the minimum of the
    node's current value and every message addressed to it. Messages for keys
    absent from nodes_values are dropped by the left outer join.
  """
  def keep_smallest(joined_row):
    node, candidates = joined_row
    # candidates is (current_value, message); message is None when no message arrived.
    return node, min(c for c in candidates if c is not None)

  joined = nodes_values.leftOuterJoin(msgs)
  # A node with k messages yields k joined rows; reduceByKey folds them into one.
  return joined.map(keep_smallest).reduceByKey(min)

In [9]:
# Node values after one manual round: start values merged with the messages in `msg`.
iter1_values = send_msgs(msgs=msg, nodes_values=start_values_rdd)

In [10]:
def single_source_shortest_path(edges, nodes_values, max_iter=100):
  """Run relaxation rounds (to_send + send_msgs) until the node values stop changing.

  Args:
    edges: RDD of (src, dst, weight) tuples.
    nodes_values: RDD of (node, value) pairs holding the initial distances
      (in this notebook: 0 for the start node, float("inf") elsewhere).
    max_iter: maximum number of relaxation rounds to run.

  Returns:
    RDD of (node, value) pairs. If the values are still changing after
    `max_iter` rounds, a RuntimeWarning is issued and the latest values are
    returned.
  """
  import warnings

  current = nodes_values
  current_values = current.collectAsMap()
  for _ in range(max_iter):
    # Cache each round's result: otherwise every collect() re-evaluates the whole,
    # ever-growing lineage back to the first round (quadratic work overall).
    updated = send_msgs(to_send(edges, current), current).cache()
    updated_values = updated.collectAsMap()
    if current is not nodes_values:
      # `updated` has just been computed, so the previous round's cached copy can
      # be released. The caller's input RDD is never touched.
      current.unpersist()
    # Compare as dicts: collect() order after a shuffle is not guaranteed, so
    # comparing collected lists could fail to detect convergence.
    if updated_values == current_values:
      return updated
    current, current_values = updated, updated_values
  warnings.warn(
    f"single_source_shortest_path did not converge within {max_iter} iterations",
    RuntimeWarning,
  )
  return current

In [11]:
# Shortest distances from the start node over the random graph.
out = single_source_shortest_path(
  edges=edges_rdd,
  nodes_values=start_values_rdd,
  max_iter=100,
)

In [12]:
# Sort by node id: collect() order after a shuffle is not guaranteed, so sorting
# keeps the displayed result stable across runs.
sorted(out.collect())

[(1, 6),
 (2, 3),
 (3, 16),
 (4, 0),
 (5, 6),
 (6, 11),
 (7, 4),
 (8, 13),
 (9, 7),
 (10, 9)]