# A* algorithm using Spark
__`MIDS w261: Machine Learning at Scale | UC Berkeley School of Information | Fall 2019`__

## VERY IMPORTANT NOTE: 
Unlike the regular assignments, this assignment is much more open-ended. It is up to you to think through the problem and decide what is applicable and which intermediate tasks are suitable for solving it. There are 4 tasks listed below, but these are only a guide. Your grade will be based on the quality of your approach, your reasoning, etc., and not as much on the accuracy of your results. So please provide justification for your design choices. And as always, have fun!!

# Background

[Dijkstra's video](https://www.youtube.com/watch?v=pVfj6mxhdMw)   
[A* video](https://www.youtube.com/watch?v=eSOJ3ARN5FM&feature=youtu.be)

In our distributed SSSP algorithm, we must visit all nodes in the graph in order to find the shortest path. This is also largely true of the single-core Dijkstra's algorithm, which explores outward from the source in every direction rather than toward the target. On a very large graph, this can be very resource-intensive. A* addresses this problem by assigning each node a heuristic: a 'best guess' of the distance from that node to the destination. The heuristic is used to prioritize which paths are explored, and the algorithm terminates when the target node is selected for expansion. This means that not all nodes necessarily have to be visited. The key to A* is choosing a good heuristic. If we underestimate the distance, the algorithm may end up visiting all nodes after all. If we overestimate the distance, we could end up with a poor (non-optimal) solution.
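
The trade-off can be stated more precisely. A* expands nodes in increasing order of $f(n)$ below, where $g(n)$ is the cost of the best path found so far from the start node to $n$, $h(n)$ is the heuristic, $h^*(n)$ is the true (unknown) remaining distance to the target, and $w(u, v)$ is the weight of edge $(u, v)$. A heuristic satisfying the second line (admissible) never overestimates, which is the standard condition for A* to return an optimal path when nodes may be re-opened; the third line (consistent) additionally guarantees that no node ever needs to be expanded twice. Setting $h(n) = 0$ for every node satisfies both and reduces A* to Dijkstra's algorithm.

$$
\begin{aligned}
f(n) &= g(n) + h(n) \\
h(n) &\le h^*(n) && \text{for every node } n \quad \text{(admissible)} \\
h(u) &\le w(u, v) + h(v) && \text{for every edge } (u, v) \quad \text{(consistent)}
\end{aligned}
$$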



## 2GB Wikipedia dataset:
https://www.dropbox.com/sh/2c0k5adwz36lkcw/AAAAKsjQfF9uHfv-X9mCqr9wa?dl=0.

## DATA Format:   
`node_id \t {neighbor_id:count, neighbor_id:count,...}`    
Here, count is the number of times the link appears on the page. We'll treat this as the "weight" of the edge between the two pages. 
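
A minimal parsing sketch for one record, assuming (as in the toy file and `parseData` below) that the adjacency list is a valid Python dict literal with string node ids. The helper name `parse_line` is hypothetical.

```python
import ast

def parse_line(line):
    """Split one 'node_id<TAB>{neighbor_id: count, ...}' record into (node_id, dict)."""
    node_id, adjacency = line.rstrip("\n").split("\t", 1)
    # The adjacency list is a Python dict literal, so ast.literal_eval parses it
    # without executing arbitrary code (unlike eval).
    return node_id, ast.literal_eval(adjacency)

print(parse_line("2\t{'3': 2, '4': 3, '5': 5}"))
# ('2', {'3': 2, '4': 3, '5': 5})
```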

The challenge in this case is to find a good heuristic "distance" metric between pages. What does distance mean in this setting, where weights are the number of links? Is there a trade-off between calculating a "good" heuristic and visiting all nodes?
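
As one hypothetical answer to the first question, link counts could be turned into edge costs so that heavily linked pages count as "closer", e.g. cost $= 1/\text{count}$. The helper `counts_to_costs` below is illustrative only, not part of the assignment.

```python
def counts_to_costs(neighbors):
    """Hypothetical edge-cost transform: map each neighbor's link count c to a cost of 1 / c.

    More links means a cheaper edge, and every cost lies in (0, 1].
    """
    return {nbr: 1.0 / count for nbr, count in neighbors.items()}

print(counts_to_costs({"3": 2, "4": 4, "5": 1}))
# {'3': 0.5, '4': 0.25, '5': 1.0}
```

Under this choice every edge costs at least $1/c_{\max}$, where $c_{\max}$ is the largest count in the graph, so (minimum number of hops to the target) $/\, c_{\max}$ would be an admissible heuristic. Computing those hop counts, however, requires a BFS from the target over the whole graph, roughly the same kind of full pass that SSSP makes. That is one concrete form of the trade-off asked about above.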


# Task 1
* Come up with a good representative toy example. 
* Come up with your heuristic.
* Hand calculate A*

### 1.1 Come up with a good representative toy example

In [5]:
# imports
import re
import ast
import time
import numpy as np
import pandas as pd
import seaborn as sns
import networkx as nx
import matplotlib.pyplot as plt

In [6]:
%reload_ext autoreload
%autoreload 2

In [7]:
# store path to notebook
PWD = !pwd
PWD = PWD[0]

In [8]:
# start Spark Session
from pyspark.sql import SparkSession
app_name = "Astar_notebook"
master = "local[*]"
spark = SparkSession\
        .builder\
        .appName(app_name)\
        .master(master)\
        .getOrCreate()
sc = spark.sparkContext


In [14]:
# 1. Toy example - weighted directed toy example
!cat data/directed_toy.txt

1	{'2': 1, '3': 2}
2	{'3': 2, '4': 3, '5': 5}
3	{'4': 2}
4	{'5': 3}


In [15]:
toyRDD = sc.textFile('data/directed_toy.txt')

### 1.2. Come up with your heuristic

> Heuristic distances from each node to the end node are marked in **red**, and the edge weights are marked in **blue**. The graph is first drawn on a grid; using each node's grid coordinates, its heuristic is the Manhattan distance to the end node. 

> <img src="Toy_Example2.jpg">
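
A minimal sketch of the Manhattan distance used for the red heuristic values, assuming each node sits at integer grid coordinates $(x, y)$ as in the figure. The coordinates in the example call are hypothetical; the actual positions are only in the image.

```python
def manhattan(p, q):
    """Manhattan (L1) distance between grid points p = (x1, y1) and q = (x2, y2)."""
    return abs(p[0] - q[0]) + abs(p[1] - q[1])

# Hypothetical coordinates, for illustration only
print(manhattan((1, 2), (4, 7)))  # 3 + 5 = 8
```

Evaluated from each node's position to the end node's position, this yields the red values used in the hand calculation: $h = 8, 7, 9, 5, 0$ for Nodes 1 to 5.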

### 1.3 Hand calculate A*

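- Start with Node 1: $f = g + h = 0 + 8 = 8$.
- Node 1 connects to Nodes 2 and 3:
    - Node 2: $f = g + h = 1 + 7 = 8 \text{ (via Node 1)}$
    - Node 3: $f = g + h = 2 + 9 = 11 \text{ (via Node 1)}$
- Since $8 < 11$, we expand Node 2 next. Node 2 connects to Nodes 3, 4, and 5:
    - Node 3: $f = g + h = 3 + 9 = 12 \text{ (via Node 2)}$. This is worse than $f = 11$ via Node 1, so Node 3 keeps its entry via Node 1.
    - Node 4: $f = g + h = 4 + 5 = 9 \text{ (via Node 2)}$
    - Node 5: $f = g + h = 6 + 0 = 6 \text{ (via Node 2)}$
- The open nodes are now Node 3 ($f = 11$), Node 4 ($f = 9$), and Node 5 ($f = 6$). Since $6$ is the smallest, we expand Node 5 next. Node 5 is our end node, so the calculation stops here.
- Therefore, the shortest path from Node 1 to Node 5 is 1 - 2 - 5, with total cost $g = 6$.
- Note that this heuristic overestimates on the toy graph (e.g. $h = 8$ at Node 1, while the true distance from Node 1 to Node 5 is 6), so A* has no optimality guarantee here. The result is still optimal in this case: the alternative routes 1 - 2 - 4 - 5 and 1 - 3 - 4 - 5 both cost 7, and 1 - 2 - 3 - 4 - 5 costs 8.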
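
The hand calculation can be checked with a minimal single-machine A* sketch using Python's `heapq` as the open set. The graph is copied from `data/directed_toy.txt` and the heuristic from the red values above; the function name `astar` and its return format are our own choices. Nodes are re-opened whenever a cheaper route is found, and the search stops when the target is popped, as in the hand calculation.

```python
import heapq

# Toy graph from data/directed_toy.txt: node -> {neighbor: edge weight}
GRAPH = {
    "1": {"2": 1, "3": 2},
    "2": {"3": 2, "4": 3, "5": 5},
    "3": {"4": 2},
    "4": {"5": 3},
}
# Heuristic values (red numbers) used in the hand calculation
H = {"1": 8, "2": 7, "3": 9, "4": 5, "5": 0}

def astar(graph, h, source, target):
    """Return (path, cost, expansion_order) for a single-machine A* search."""
    g = {source: 0}                     # best known cost from source
    parent = {source: None}
    open_heap = [(h[source], source)]   # entries are (f = g + h, node)
    order = []
    while open_heap:
        f, node = heapq.heappop(open_heap)
        if f > g[node] + h[node]:
            continue                    # stale entry: a cheaper route was found later
        order.append(node)
        if node == target:              # stop when the target is selected for expansion
            path = []
            while node is not None:
                path.append(node)
                node = parent[node]
            return path[::-1], g[target], order
        for nbr, w in graph.get(node, {}).items():
            new_g = g[node] + w
            if new_g < g.get(nbr, float("inf")):
                g[nbr] = new_g
                parent[nbr] = node
                heapq.heappush(open_heap, (new_g + h[nbr], nbr))
    return None, float("inf"), order

print(astar(GRAPH, H, "1", "5"))
# (['1', '2', '5'], 6, ['1', '2', '5'])
```

With `{n: 0 for n in H}` as the heuristic (plain Dijkstra), the same call returns the same path and cost but expands `['1', '2', '3', '4', '5']`, i.e. every node, which illustrates the savings A* is after.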

# Task 2
Implement Parallel A*. You can use the provided code as a starting point, but you don't have to. You are free to use RDDs, DataFrames, or GraphFrames. 

In [16]:
from pyspark.accumulators import AccumulatorParam

# Spark only implements Accumulator parameter for numeric types.
# This class extends Accumulator support to the string type.
class StringAccumulatorParam(AccumulatorParam):
    def zero(self, value):
        return value
    def addInPlace(self, val1, val2):
        return val1 + val2
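
One possible way to parallelize A* (a sketch, not the assignment's prescribed solution) is to keep the frontier-expansion structure of the SSSP code and add branch-and-bound pruning. Each superstep expands every queued (Q) node whose $f = g + h$ is still below the cost of the best path to the target found so far, and the loop stops when no such node remains. The plain-Python functions below mimic `flatMap(expand)` followed by `reduceByKey(merge)` so the logic can be tested without a cluster; in Spark, `best` would be re-broadcast each iteration and the frontier counted with an accumulator or `filter(...).count()`. The names `expand`, `merge`, `expandable`, and `parallel_astar` are hypothetical. With an admissible heuristic the pruning cannot discard the optimal path; the toy heuristic overestimates, so optimality is not guaranteed there.

```python
import math

def expand(record, h, best):
    """Mapper: expand a queued (Q) node only if f = g + h can still beat `best`."""
    key, (status, neighbors, g, path) = record
    if status == "Q" and g + h.get(key, 0) < best:   # missing h defaults to 0
        for nbr, w in neighbors.items():
            yield nbr, ("Q", {}, g + w, path + [nbr])
        status = "V"                                  # expanded nodes become visited
    yield key, (status, neighbors, g, path)           # re-emit the node's own record

def merge(a, b):
    """Reducer: keep the shorter distance (ties favor the non-Q record) and the
    adjacency list from whichever record carries it."""
    best = min(a, b, key=lambda rec: (rec[2], rec[0] == "Q"))
    return (best[0], a[1] or b[1], best[2], best[3])

def reduce_by_key(pairs, fn):
    """Local stand-in for RDD.reduceByKey."""
    out = {}
    for key, value in pairs:
        out[key] = fn(out[key], value) if key in out else value
    return out

def expandable(state, h, best):
    """Queued nodes whose f = g + h is still below the best known target cost."""
    return [n for n, (s, _adj, g, _path) in state.items()
            if s == "Q" and g + h.get(n, 0) < best]

def parallel_astar(graph, h, source, target, max_iter=100):
    """Return (path, cost, supersteps, expansions)."""
    state = {n: ("U", adj, math.inf, []) for n, adj in graph.items()}
    state[source] = ("Q", graph.get(source, {}), 0, [source])
    best, expansions, step = math.inf, 0, 0
    while step < max_iter:
        frontier = expandable(state, h, best)
        if not frontier:
            break                                     # nothing can improve on `best`
        step += 1
        expansions += len(frontier)
        pairs = [kv for rec in state.items() for kv in expand(rec, h, best)]
        state = reduce_by_key(pairs, merge)
        if target in state:                           # any discovered path is an upper bound
            best = min(best, state[target][2])
    path = state[target][3] if target in state else None
    return path, best, step, expansions

GRAPH = {"1": {"2": 1, "3": 2}, "2": {"3": 2, "4": 3, "5": 5},
         "3": {"4": 2}, "4": {"5": 3}}
H = {"1": 8, "2": 7, "3": 9, "4": 5, "5": 0}
print(parallel_astar(GRAPH, H, "1", "5"))   # (['1', '2', '5'], 6, 2, 3)
print(parallel_astar(GRAPH, {}, "1", "5"))  # (['1', '2', '5'], 6, 3, 4)
```

The second call uses a zero heuristic, so only the `best` bound prunes; it needs one more superstep and one more expansion (Node 4) than the run with the toy heuristic.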

# Task 3
Run your implementation on your toy examples. Note: You may want to come up with several toy examples to test how your algorithm performs.
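
For small toy graphs, a brute-force reference is a useful sanity check: enumerate every simple path, take the cheapest, and test whether a heuristic ever overestimates. This is a sketch with hypothetical helper names; it is exponential in graph size, so it is only meant for toy examples.

```python
import math

def simple_paths(graph, node, target, path=(), cost=0):
    """Yield (cost, path) for every simple path from node to target."""
    path = path + (node,)
    if node == target:
        yield cost, list(path)
        return
    for nbr, w in graph.get(node, {}).items():
        if nbr not in path:                      # skip cycles
            yield from simple_paths(graph, nbr, target, path, cost + w)

def shortest_cost(graph, source, target):
    """True shortest-path cost by exhaustive enumeration (inf if unreachable)."""
    return min((c for c, _ in simple_paths(graph, source, target)), default=math.inf)

def is_admissible(graph, h, target):
    """True if h never overestimates the true remaining distance to target."""
    return all(h[n] <= shortest_cost(graph, n, target) for n in h)

GRAPH = {"1": {"2": 1, "3": 2}, "2": {"3": 2, "4": 3, "5": 5},
         "3": {"4": 2}, "4": {"5": 3}}
H = {"1": 8, "2": 7, "3": 9, "4": 5, "5": 0}
print(min(simple_paths(GRAPH, "1", "5")))   # (6, ['1', '2', '5'])
print(is_admissible(GRAPH, H, "5"))          # False: h(1) = 8 > 6
```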

# Task 4
Finally, run your algorithm on the full dataset and compare your results to the SSSP implementation in terms of runtime performance, as well as accuracy. Discuss your findings, tradeoffs, challenges, etc.
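
A small sketch of helpers for the comparison, assuming the SSSP result is taken from the `Path:` line printed by `sssp.py` (built in `terminate()` as the path string followed by `" distance: "` and the distance). Accuracy is measured here as the relative excess cost of the A* path over the SSSP optimum, and runtime with `time.perf_counter`; all function names are hypothetical.

```python
import math
import time

def parse_sssp_path(path_line):
    """Split e.g. '1 -> 2 -> 5 distance: 6' into (['1', '2', '5'], 6.0)."""
    nodes, distance = path_line.rsplit(" distance: ", 1)
    return nodes.split(" -> "), float(distance)

def suboptimality(astar_cost, optimal_cost):
    """Relative excess cost of an A* path over the SSSP optimum (0.0 means optimal)."""
    if astar_cost == optimal_cost:
        return 0.0
    if optimal_cost == 0:
        return math.inf
    return (astar_cost - optimal_cost) / optimal_cost

def timed(fn, *args):
    """Run fn(*args) and return (result, elapsed wall-clock seconds)."""
    start = time.perf_counter()
    result = fn(*args)
    return result, time.perf_counter() - start

print(parse_sssp_path("1 -> 2 -> 5 distance: 6"))   # (['1', '2', '5'], 6.0)
print(suboptimality(6, 6))                           # 0.0
```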

## Standalone SSSP implementation

In [2]:
%%writefile sssp.py
from __future__ import print_function
import ast
import sys
from pyspark.accumulators import AccumulatorParam
from pyspark.sql import SparkSession

# Spark only implements Accumulator parameter for numeric types.
# This class extends Accumulator support to the string type.
class StringAccumulatorParam(AccumulatorParam):
    def zero(self, value):
        return value
    def addInPlace(self, val1, val2):
        return val1 + val2


###################################
# INITIALIZE GRAPH
###################################

def parseData(line):
    line = line.strip()
    key, value = line.split("\t")
    key = str(key)
  
    if key == startNode.value:
        return (key, ("Q",ast.literal_eval(value),0,key))
    else:
        return (key, ("U",ast.literal_eval(value),float("inf"),""))
    
    
    
###################################
# MAPPER
###################################   

def expandFrontier(row):
    key = row[0]
    status = row[1][0]
    neighbors = row[1][1]
    distance = row[1][2]
    path = row[1][3]
  
    if status == "Q":
        # put neighbors in Q mode; their distance is this node's distance plus the edge weight
        for neighbor in neighbors:
            yield neighbor, ("Q", {}, distance + int(neighbors[neighbor]), str(path)+" -> "+str(neighbor))
      
        # Update status of the expanded node to Visited (unvisited U nodes stay U)
        status = "V"
      
    yield key, (status, neighbors, distance, path)


###################################
# REDUCER
###################################

def restoreGraph(a,b):
    
    # Keep the record with the shorter distance; on ties prefer the non-Q record,
    # so a visited node is not re-queued without an improvement.
    # Selecting on (distance, is_Q) avoids comparing the neighbor dicts, which can
    # raise TypeError in Python 3 when two Q records for the same node are merged.
    best = min(a, b, key=lambda rec: (rec[2], rec[0] == "Q"))
    status, _, distance, path = best

    # If a Q message beat a visited node's distance, the node is back in Q state
    # (the magic for weighted graphs). Only the node's own record carries its
    # adjacency list; Q messages carry {}.
    neighbors = a[1] or b[1]

    return (status, neighbors, distance, path)  

  
###################################
# ACCUMULATORS
###################################  
  
def terminate(row):
    if row[1][0] == "V" and row[0] == targetNode.value:  
        targetAccum.add(1)
        pathAccum.add(str(row[1][3])+" distance: "+str(row[1][2]))
    if row[1][0] == "Q":
        statusAccum.add(1)

    
    
if __name__ == "__main__":
  
    if len(sys.argv) != 5:
        print("Usage: SSSP <file> <startNode> <targetNode> <isWeighted: 0|1>", file=sys.stderr)
        sys.exit(-1)

    
    app_name = "graphs-intro"
    master = "local[*]"
  
    spark = SparkSession \
          .builder \
          .appName(app_name) \
          .master(master) \
          .getOrCreate()
  
    sc = spark.sparkContext
  
    # remember to broadcast global variables:
    dataFile = sc.textFile(sys.argv[1])
    startNode = sc.broadcast(sys.argv[2])
    targetNode = sc.broadcast(sys.argv[3])
    weighted = sys.argv[4]
  
    rdd = dataFile.map(parseData).cache()

    notconverged = True
    iteration = 0
    while notconverged:
        iteration = iteration + 1
        targetAccum = sc.accumulator(0)
        statusAccum = sc.accumulator(0)
        pathAccum = sc.accumulator("", StringAccumulatorParam())

        rdd = rdd.flatMap(expandFrontier).reduceByKey(restoreGraph)

        rdd.foreach(terminate)

        if weighted == "1":
            if statusAccum.value == 0: # no more nodes in Q status
                notconverged = False
        else:
            if targetAccum.value == 1: # reached target node
                notconverged = False

        print("-"*50)  
        print ("After Iteration "+str(iteration))
        print("Node id, (Status, {out_nodes},distance,path)")

        for i in rdd.collect():
            print(i)

        print("Num nodes in Q status: ",statusAccum.value)
        #print("Target node in V status: ",targetAccum.value)  # we only care about this in unweighted graphs, where reaching the target node terminates the algorithm
        print("-"*50)    
    

    print("Num nodes in Q status: ",statusAccum.value)
    #print("Target node in V status: ",targetAccum.value)
    print("Iterations: ", iteration)
    print("Path: ",pathAccum.value)
    print("="*20)

    spark.stop()

Writing sssp.py


## Going further
Alternative Search algorithms
* Ripple Search [Brand et al., 2012]
* Fringe Search