## Set Parameters

In [None]:
# Upper bound on the number of outgoing nearest-neighbour edges kept per city
# (compared against the per-source ranking when the edge list is built).
NN = 7
# Run Spark SQL shuffles (joins, window functions) with 4 partitions.
spark.conf.set("spark.sql.shuffle.partitions", "4")

## Imports

In [None]:
from graphframes import GraphFrame

In [None]:
from pyspark.sql.types import *
from pyspark.sql.functions import *
import math

## Load and adjust basic DataFrame

In [None]:
# Explicit schema for the world-cities CSV. Each add() call gives the column
# name, its Spark type and whether nulls are allowed; the order must match the
# column order of the file.
citiesDFSchema = (
  StructType()
  .add("city", StringType(), False)
  .add("city_ascii", StringType(), True)
  .add("latitude", FloatType(), False)
  .add("longitude", FloatType(), False)
  .add("country", StringType(), True)
  .add("iso2", StringType(), True)
  .add("iso3", StringType(), True)
  .add("admin_name", StringType(), True)
  .add("capital", StringType(), True)
  .add("population", FloatType(), True)
  .add("id", LongType(), True)
)

In [None]:
# Load the world-cities CSV and reduce it to the vertex table of the graph:
# one row per city that has a capital status, with the city name as vertex id.
citiesDF = (
  spark.read.format("csv")
  .option("sep", ",")
  .option("header", "true")
  # No "inferSchema" option: the explicit schema below is what the reader uses,
  # so asking for inference as well was redundant.
  .schema(citiesDFSchema)
  .load("/Users/petergerngross/Programming/Data/csv/Graph/simplemaps_worldcities_basicv1/worldcities.csv")
  # Keep only cities flagged as primary, admin or minor capitals. Rows with
  # any other (or missing) capital value are dropped, exactly as the former
  # capNum != 1 filter did.
  .where(col("capital").isin("primary", "admin", "minor"))
  # NOTE(review): the city name becomes the vertex id; names are presumably
  # not unique world-wide -- confirm this is acceptable for the graph.
  .select(col("city").alias("id"), "latitude", "longitude", "iso2")
)

citiesDF.printSchema()

In [None]:
# Preview the first 10 rows of the cities (vertex) DataFrame.
citiesDF.show(10)

In [None]:
# Number of cities that remain after the capital-status filter.
citiesDF.count()

## Function for calculating distances from geographic coordinates

In [None]:
def deg2rad(degrees: float) -> float:
  """Convert an angle from degrees to radians."""
  return math.radians(degrees)


def geoDistFlat(phi1: float, lambda1: float, phi2: float, lambda2: float) -> float:
  """Return the flat-surface distance between two geographic points.

  The latitude and longitude differences (in degrees) are scaled by the
  factors k1 and k2, which depend on the mean latitude of the two points,
  and then combined with Pythagoras.

  Args:
    phi1: Latitude of the first point in degrees.
    lambda1: Longitude of the first point in degrees.
    phi2: Latitude of the second point in degrees.
    lambda2: Longitude of the second point in degrees.

  Returns:
    The distance as a float, in the unit of the k1/k2 coefficients
    (presumably kilometres per degree -- confirm against the formula source).
  """
  phiMRad = deg2rad((phi1 + phi2) / 2)
  k1 = 111.13209 - 0.56605 * math.cos(2 * phiMRad) + 0.00120 * math.cos(4 * phiMRad)
  k2 = 111.41513 * math.cos(phiMRad) - 0.09455 * math.cos(3 * phiMRad) + 0.00012 * math.cos(5 * phiMRad)
  deltaPhi = phi2 - phi1
  # Wrap the longitude difference into [-180, 180) so that two points on
  # either side of the antimeridian (e.g. 179.5 and -179.5) are 1 degree
  # apart instead of 359.
  deltaLambda = (lambda2 - lambda1 + 180) % 360 - 180
  return math.hypot(k1 * deltaPhi, k2 * deltaLambda)

### Register the function in Spark

In [None]:
# Wrap geoDistFlat as a Spark UDF. The return type is declared explicitly:
# without it udf() falls back to a string return type, so the distances would
# be handed back as text and only become numbers again through an implicit
# cast in later expressions.
geoDistUDF = udf(geoDistFlat, DoubleType())

## Create Edges DataFrame

In [None]:
# Source-side view of the cities: vertex id and coordinates only, renamed so
# the columns stay distinguishable after the cross join (iso2 is left out).
cities1DF = citiesDF.select(
  col("id").alias("src"),
  col("latitude").alias("latSrc"),
  col("longitude").alias("longSrc"),
)

# Destination-side view: the same rows with the three columns renamed in order.
cities2DF = cities1DF.toDF("dst", "latDst", "longDst")

In [None]:
# Candidate edges: every ordered pair of differently named cities together
# with their distance, rounded to a whole number and stored as an int.
cityConnectsDF = (
  cities1DF.crossJoin(cities2DF)
  .filter(col("src") != col("dst"))
  .select(
    "src",
    "dst",
    round(
      geoDistUDF(col("latSrc"), col("longSrc"), col("latDst"), col("longDst"))
    ).cast("int").alias("cityDistance"),
  )
)

In [None]:
# Print the column names and types of the candidate edges (src, dst, cityDistance).
cityConnectsDF.printSchema()

In [None]:
from pyspark.sql.window import Window
from pyspark.sql.functions import col, rank, row_number

# Per source city, order the candidate edges by distance. The destination name
# is a secondary sort key so that equal distances are ordered deterministically.
window = Window.partitionBy(cityConnectsDF["src"])\
  .orderBy(cityConnectsDF["cityDistance"], cityConnectsDF["dst"])

# Keep at most NN outgoing edges per city. row_number() numbers the rows
# 1, 2, 3, ... without gaps or repeats; rank() gives tied distances the same
# rank, which let more than NN edges through whenever cities were equally far
# away (likely, since distances are rounded to whole numbers).
cityConnectsNNDF = (
  cityConnectsDF
  .withColumn("rank", row_number().over(window))
  .filter(col("rank") <= NN)
  .drop("rank")
)

In [None]:
# Number of edges that remain after the nearest-neighbour filter.
cityConnectsNNDF.count()

In [None]:
# Spot check: show the edges that start or end at a city named Berlin.
cityConnectsNNDF.where("src = 'Berlin' OR dst = 'Berlin'").show()

## Create and modify GraphFrames

In [None]:
# Build the graph: citiesDF supplies the vertices (column "id"),
# cityConnectsNNDF the directed edges (columns "src", "dst", "cityDistance").
cityGraph = GraphFrame(citiesDF, cityConnectsNNDF)
# Cache the graph; it is queried several times in the cells below.
cityGraph.cache()

In [None]:
# Report the size of the full graph (vertex count first, then edge count).
print(
  f"The cityGraph has {cityGraph.vertices.count()} nodes"
  f" and {cityGraph.edges.count()} edges."
)

In [None]:
# Restrict the graph to vertices whose iso2 code is DE, AT or CH, then report
# the size of the resulting subgraph (vertex count first, then edge count).
citySubgraph = cityGraph.filterVertices("iso2 IN ('DE','AT','CH')")
print(
  f"The citySubgraph has {citySubgraph.vertices.count()} nodes"
  f" and {citySubgraph.edges.count()} edges."
)

In [None]:
# Show the number of outgoing edges of Frankfurt and Stralsund within the subgraph.
citySubgraph.outDegrees.where("id IN ('Frankfurt','Stralsund')").show()

## Shortest Paths with BFS

In [None]:
# Breadth-first search inside the DE/AT/CH subgraph, from the vertex with id
# 'Frankfurt' to the vertex with id 'Basel'.
# NOTE(review): BFS presumably minimises the number of hops, not the summed
# cityDistance -- confirm this is the intended notion of "shortest".
paths = citySubgraph.bfs("id = 'Frankfurt'", "id = 'Basel'")

In [None]:
# Print the columns of the BFS result; the cell below reads the vertex
# columns (from, v1, ..., to) and the edge columns (e0, e1, ...).
paths.printSchema()

In [None]:
# Render each BFS result row as "from, v1, ..., to" plus the summed edge
# distance. The vertex and edge columns are taken from the result itself
# instead of being hard-coded for exactly four hops, so the cell also works
# when the path found has a different length.
if "from" not in paths.columns:
  # NOTE(review): assumes bfs() returns a frame without the from/to columns
  # when no path exists -- verify against the GraphFrames version in use.
  print("No path found.")
else:
  # Intermediate vertices (v1, v2, ...) and edges (e0, e1, ...), each ordered
  # by its numeric suffix.
  innerVertexCols = sorted(
    [c for c in paths.columns if c.startswith("v") and c[1:].isdigit()],
    key=lambda c: int(c[1:]),
  )
  edgeCols = sorted(
    [c for c in paths.columns if c.startswith("e") and c[1:].isdigit()],
    key=lambda c: int(c[1:]),
  )
  vertexCols = ["from"] + innerVertexCols + ["to"]

  # Add the distances up with an explicit loop: the star import from
  # pyspark.sql.functions replaces Python's built-in sum() in this notebook.
  totalDistance = lit(0)
  for edgeCol in edgeCols:
    totalDistance = totalDistance + col(edgeCol + ".cityDistance")

  paths.select(
    concat_ws(", ", *[col(vertexCol + ".id") for vertexCol in vertexCols]).alias("path"),
    totalDistance.alias("totalDistance"),
  ).show(3, False)