In [5]:
# mapping of covid locations (fips + GPS) to the closest weather station

In [97]:
import configparser
import os
import pyspark
from pyspark import SparkConf
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
import pyspark.sql.types as T
from pyspark.sql.functions import udf, col, lit
from pyspark.sql.types import MapType, StringType, FloatType
from pyspark.sql import DataFrame
from collections import OrderedDict
import pandas as pd
import numpy as np

In [7]:
config = configparser.ConfigParser()
config.read("capstone.cfg")

os.chdir(config["PATH"]["project"])
project_path = config["PATH"]["project"]


In [8]:
def create_spark_session():
    spark = SparkSession \
        .builder \
        .appName("covid_DB") \
        .getOrCreate()
    
    return spark

In [10]:
spark = create_spark_session()

In [83]:
# Load relevant stations with weather element
path = os.path.join(project_path, "OUT_DATA", "filtered_stations")
selected_stations = spark.read.parquet(path)

In [84]:
selected_stations.printSchema()

root
 |-- station_id: string (nullable = true)
 |-- measured: string (nullable = true)



In [85]:
selected_stations.select("measured").distinct().show()

+--------+
|measured|
+--------+
|    TMIN|
|    SNOW|
|    AWND|
|    PRCP|
+--------+



In [86]:
# load all stations, with GPS location
raw_stations = spark.read.csv( os.path.join(project_path, "DATA", "WEATHER", "ghcnd-stations.txt"))

In [87]:
# parse raw stations into columns
@udf(MapType( StringType(), StringType()))
def ParseStationsUDF(line):
    return{
        "station_id": line[0:11],
        "latitude" : line[13:20], 
        "longitude" : line[21:30], 
        "elevation" : line[31:38], 
        "state" : line[38:40], 
        "station_name" : line[41:]
        
    }

fields = OrderedDict( [
        ( "station_id" , "string"),
        ( "latitude" , "float"), 
        ("longitude" , "float"), 
        ("elevation" , "float"),
        ("state" , "string"), 
        ("station_name" , "string")
] )

#exprs = [ f"parsed['{field}'].cast({fld_type}) as {field}" for field, fld_type in fields.items() ]
exprs = [ f"CAST(parsed['{field}'] AS {fld_type}) AS {field}" for field, fld_type in fields.items() ]

df_stations = raw_stations.withColumn("parsed", ParseStationsUDF("_c0")).selectExpr( *exprs)

In [88]:
df_stations.printSchema()

root
 |-- station_id: string (nullable = true)
 |-- latitude: float (nullable = true)
 |-- longitude: float (nullable = true)
 |-- elevation: float (nullable = true)
 |-- state: string (nullable = true)
 |-- station_name: string (nullable = true)



In [89]:
df_stations = df_stations.join(selected_stations, ["station_id"])

In [90]:
df_stations.count()

23556

In [91]:
df_stations.printSchema()

root
 |-- station_id: string (nullable = true)
 |-- latitude: float (nullable = true)
 |-- longitude: float (nullable = true)
 |-- elevation: float (nullable = true)
 |-- state: string (nullable = true)
 |-- station_name: string (nullable = true)
 |-- measured: string (nullable = true)



In [92]:
df_stations.show(10)

+-----------+--------+---------+---------+-----+--------------------+--------+
| station_id|latitude|longitude|elevation|state|        station_name|measured|
+-----------+--------+---------+---------+-----+--------------------+--------+
|AQC00914000| 14.3167|-170.7667|    408.4|   AS|AASUFOU          ...|    PRCP|
|AQC00914141| 14.2667|-170.6167|      4.6|   AS|FAGAITUA         ...|    PRCP|
|AQC00914594| 14.3333|-170.7667|     42.4|   AS|MALAELOA         ...|    PRCP|
|AQW00061705| 14.3306|-170.7136|      3.7|   AS|PAGO PAGO WSO AP ...|    AWND|
|AQW00061705| 14.3306|-170.7136|      3.7|   AS|PAGO PAGO WSO AP ...|    TMIN|
|AQW00061705| 14.3306|-170.7136|      3.7|   AS|PAGO PAGO WSO AP ...|    PRCP|
|CQC00914080| 15.2136| 145.7497|    252.1|   MP|CAPITOL HILL 1   ...|    TMIN|
|CQC00914080| 15.2136| 145.7497|    252.1|   MP|CAPITOL HILL 1   ...|    PRCP|
|CQC00914801| 14.1717| 145.2428|    179.2|   MP|ROTA AP          ...|    TMIN|
|CQC00914801| 14.1717| 145.2428|    179.2|   MP|ROTA

In [164]:
# Load NYT locations (FIPS + GPS)
path = os.path.join(project_path, "OUT_DATA", "nyt_locations_geography")
df_locations = spark.read.parquet(path)

In [165]:
df_locations = df_locations.where( ( ~ F.isnan("latitude") ) | (~ F.isnan("longitude")) )

In [166]:
df_locations.printSchema()

root
 |-- fips: string (nullable = true)
 |-- county: string (nullable = true)
 |-- latitude: double (nullable = true)
 |-- longitude: double (nullable = true)
 |-- state: string (nullable = true)



In [120]:
def precompute_distance(l_ref : DataFrame) -> DataFrame:
    l_ref = l_ref.withColumnRenamed("latitude", "latitude_degrees")
    l_ref = l_ref.withColumnRenamed( "longitude", "longitude_degrees") 
    @udf( FloatType())
    def degree_to_radian(x):
        return  x* np.pi / 180.
    l_ref = l_ref.withColumn("latitude", degree_to_radian("latitude_degrees") )
    l_ref = l_ref.withColumn("longitude", degree_to_radian("longitude_degrees") )
    l_ref = l_ref.withColumn("cos_latitude", F.cos("latitude") )  
    print(type(l_ref))
    return l_ref

In [121]:
df_stations_precompute =precompute_distance(df_stations)

<class 'pyspark.sql.dataframe.DataFrame'>


In [101]:
df_stations_precompute.count()

23556

In [123]:
df_stations_precompute.printSchema()

root
 |-- station_id: string (nullable = true)
 |-- latitude_degrees: float (nullable = true)
 |-- longitude_degrees: float (nullable = true)
 |-- elevation: float (nullable = true)
 |-- state: string (nullable = true)
 |-- station_name: string (nullable = true)
 |-- measured: string (nullable = true)
 |-- latitude: float (nullable = true)
 |-- longitude: float (nullable = true)
 |-- cos_latitude: double (nullable = true)



In [124]:
df_stations_precompute.agg({"latitude" : "min", "longitude" : "min", "latitude_degrees" : "min", "longitude_degrees" : "min" } ).collect()

[Row(min(longitude_degrees)=-170.76669311523438, min(latitude_degrees)=13.389399528503418, min(latitude)=0.2336890995502472, min(longitude)=-2.980441093444824)]

In [125]:
df_stations_precompute.agg({"latitude" : "max", "longitude" : "max", "latitude_degrees" : "max", "longitude_degrees" : "max" } ).collect()

[Row(max(longitude_degrees)=145.74969482421875, max(latitude_degrees)=71.32140350341797, max(latitude)=1.2447932958602905, max(longitude)=2.5438120365142822)]

In [167]:
df_locations_precompute = precompute_distance(df_locations)

<class 'pyspark.sql.dataframe.DataFrame'>


In [168]:
df_locations_precompute.agg({"latitude" : "min", "longitude" : "min", "latitude_degrees" : "min", "longitude_degrees" : "min" } ).collect()

[Row(min(longitude_degrees)=-164.1889190673828, min(latitude_degrees)=13.444, min(latitude)=0.23464205861091614, min(longitude)=-2.8656373023986816)]

In [169]:
df_locations_precompute.agg({"latitude" : "max", "longitude" : "max", "latitude_degrees" : "max", "longitude_degrees" : "max" } ).collect()

[Row(max(longitude_degrees)=178.33880615234375, max(latitude_degrees)=69.4493408203125, max(latitude)=1.212119698524475, max(longitude)=3.1125993728637695)]

# closest station for all fips

In [193]:
sub_fips = df_locations_precompute.limit(100)\
    .select("fips", "latitude", "longitude", "cos_latitude")\
    .withColumnRenamed("latitude", "latitude_fips")\
    .withColumnRenamed("longitude", "longitude_fips")\
    .withColumnRenamed("cos_latitude", "cos_latitude_fips")

In [194]:
sub_stations = df_stations_precompute.limit(100)\
    .select("station_id", "measured","latitude", "longitude", "cos_latitude")\
    .withColumnRenamed("latitude", "latitude_station")\
    .withColumnRenamed("longitude", "longitude_station")\
    .withColumnRenamed("cos_latitude", "cos_latitude_station")

In [195]:
fips_cross_stations = sub_fips.crossJoin(sub_stations)

In [191]:
fips_cross_stations.count()

10000

In [196]:
fips_cross_stations.printSchema()

root
 |-- fips: string (nullable = true)
 |-- latitude_fips: float (nullable = true)
 |-- longitude_fips: float (nullable = true)
 |-- cos_latitude_fips: double (nullable = true)
 |-- station_id: string (nullable = true)
 |-- measured: string (nullable = true)
 |-- latitude_station: float (nullable = true)
 |-- longitude_station: float (nullable = true)
 |-- cos_latitude_station: double (nullable = true)



In [197]:
@udf( FloatType())
def haversine(lat_1, long_1, cos_lat_1, lat_2, long_2, cos_lat_2):
        #print(l_ref.head())
        delta_lat = ( np.sin( (lat_1 - lat_2) * 0.5 ) )**2
        delta_long = ( np.sin( (long_1 - long_2) * 0.5) )**2
        a = delta_lat + delta_long \
                        * cos_lat_1 * cos_lat_2
        sqrt_a = np.sqrt(a)
        sqrt_1_a = np.sqrt(1.-a)
        return  np.arctan2( sqrt_a, sqrt_1_a )

In [200]:
fips_cross_stations_distance = fips_cross_stations.withColumn("distance", 
        haversine("latitude_fips", "longitude_fips", "cos_latitude_fips", "latitude_station", "longitude_station", "cos_latitude_station"))

In [202]:
fips_cross_stations_distance.show(10)

Py4JJavaError: An error occurred while calling o1393.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 201.0 failed 1 times, most recent failure: Lost task 0.0 in stage 201.0 (TID 4163, localhost, executor driver): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/home/user/anaconda3/envs/pyspark/lib/python3.7/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 372, in main
    process()
  File "/home/user/anaconda3/envs/pyspark/lib/python3.7/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 367, in process
    serializer.dump_stream(func(split_index, iterator), outfile)
  File "/home/user/anaconda3/envs/pyspark/lib/python3.7/site-packages/pyspark/python/lib/pyspark.zip/pyspark/serializers.py", line 342, in dump_stream
    self.serializer.dump_stream(self._batched(iterator), stream)
  File "/home/user/anaconda3/envs/pyspark/lib/python3.7/site-packages/pyspark/python/lib/pyspark.zip/pyspark/serializers.py", line 141, in dump_stream
    for obj in iterator:
  File "/home/user/anaconda3/envs/pyspark/lib/python3.7/site-packages/pyspark/python/lib/pyspark.zip/pyspark/serializers.py", line 331, in _batched
    for item in iterator:
  File "<string>", line 1, in <lambda>
  File "/home/user/anaconda3/envs/pyspark/lib/python3.7/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 80, in <lambda>
    return lambda *a: f(*a)
  File "/home/user/anaconda3/envs/pyspark/lib/python3.7/site-packages/pyspark/python/lib/pyspark.zip/pyspark/util.py", line 99, in wrapper
    return f(*args, **kwargs)
  File "<ipython-input-197-29e08f0fb3df>", line 7, in haversine
NameError: name 'cos_lat_2' is not defined

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:452)
	at org.apache.spark.sql.execution.python.PythonUDFRunner$$anon$1.read(PythonUDFRunner.scala:81)
	at org.apache.spark.sql.execution.python.PythonUDFRunner$$anon$1.read(PythonUDFRunner.scala:64)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:406)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:440)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage10.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$11$$anon$1.hasNext(WholeStageCodegenExec.scala:619)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:255)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:247)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:836)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:836)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:121)
	at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:402)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:408)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1887)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1875)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1874)
	at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1874)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
	at scala.Option.foreach(Option.scala:257)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:926)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2108)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2057)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2046)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:737)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2061)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2082)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2101)
	at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:365)
	at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:38)
	at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collectFromPlan(Dataset.scala:3384)
	at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2545)
	at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2545)
	at org.apache.spark.sql.Dataset$$anonfun$53.apply(Dataset.scala:3365)
	at org.apache.spark.sql.execution.SQLExecution$$anonfun$withNewExecutionId$1.apply(SQLExecution.scala:78)
	at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:125)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:73)
	at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3364)
	at org.apache.spark.sql.Dataset.head(Dataset.scala:2545)
	at org.apache.spark.sql.Dataset.take(Dataset.scala:2759)
	at org.apache.spark.sql.Dataset.getRows(Dataset.scala:255)
	at org.apache.spark.sql.Dataset.showString(Dataset.scala:292)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.lang.Thread.run(Thread.java:748)
Caused by: org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/home/user/anaconda3/envs/pyspark/lib/python3.7/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 372, in main
    process()
  File "/home/user/anaconda3/envs/pyspark/lib/python3.7/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 367, in process
    serializer.dump_stream(func(split_index, iterator), outfile)
  File "/home/user/anaconda3/envs/pyspark/lib/python3.7/site-packages/pyspark/python/lib/pyspark.zip/pyspark/serializers.py", line 342, in dump_stream
    self.serializer.dump_stream(self._batched(iterator), stream)
  File "/home/user/anaconda3/envs/pyspark/lib/python3.7/site-packages/pyspark/python/lib/pyspark.zip/pyspark/serializers.py", line 141, in dump_stream
    for obj in iterator:
  File "/home/user/anaconda3/envs/pyspark/lib/python3.7/site-packages/pyspark/python/lib/pyspark.zip/pyspark/serializers.py", line 331, in _batched
    for item in iterator:
  File "<string>", line 1, in <lambda>
  File "/home/user/anaconda3/envs/pyspark/lib/python3.7/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 80, in <lambda>
    return lambda *a: f(*a)
  File "/home/user/anaconda3/envs/pyspark/lib/python3.7/site-packages/pyspark/python/lib/pyspark.zip/pyspark/util.py", line 99, in wrapper
    return f(*args, **kwargs)
  File "<ipython-input-197-29e08f0fb3df>", line 7, in haversine
NameError: name 'cos_lat_2' is not defined

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:452)
	at org.apache.spark.sql.execution.python.PythonUDFRunner$$anon$1.read(PythonUDFRunner.scala:81)
	at org.apache.spark.sql.execution.python.PythonUDFRunner$$anon$1.read(PythonUDFRunner.scala:64)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:406)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:440)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage10.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$11$$anon$1.hasNext(WholeStageCodegenExec.scala:619)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:255)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:247)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:836)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:836)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:121)
	at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:402)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:408)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	... 1 more


In [203]:
@udf( FloatType())
def toto(lat_1, lat_2):
    delta_lat = ( np.sin( (lat_1 - lat_2) * 0.5 ) )**2
    return delta_lat

In [204]:
fips_cross_stations_distance = fips_cross_stations.withColumn("distance", toto("latitude_fips", "latitude_station") )

In [229]:
fips_cross_stations_distance = fips_cross_stations.withColumn("delta_lat", F.pow( F.sin(0.5*col("latitude_fips") - col("latitude_station") ), 2 ))

In [230]:
fips_cross_stations_distance.show(10)

+-----+-------------+--------------+-------------------+-----------+--------+----------------+-----------------+--------------------+-------------------+
| fips|latitude_fips|longitude_fips|  cos_latitude_fips| station_id|measured|latitude_station|longitude_station|cos_latitude_station|          delta_lat|
+-----+-------------+--------------+-------------------+-----------+--------+----------------+-----------------+--------------------+-------------------+
|02198|    0.9718477|     -2.324122| 0.5637744149985534|AQC00914000|    PRCP|      0.24987355|        -2.980441|  0.9689436985050108|0.05469250905819356|
|02240|    1.1146545|    -2.4996367|0.44048764879312186|AQC00914000|    PRCP|      0.24987355|        -2.980441|  0.9689436985050108|0.09158657767490873|
|02261|    1.0669032|    -2.5544662| 0.4828384076125879|AQC00914000|    PRCP|      0.24987355|        -2.980441|  0.9689436985050108|0.07828389561501606|
|02090|    1.1288098|    -2.5577478| 0.4277359281117351|AQC00914000|    PRCP

In [231]:
type(F.pow)

function

In [240]:
def delta_coord(col1, col2):
    return F.pow( F.sin(0.5*col1 - col2 ), 2 )

In [250]:
def haversine( lat1, long1, cos_lat1, lat2, long2, cos_lat2):
    delta_lat = delta_coord(lat1, lat2)
    delta_long = delta_coord(long1, long2)
    a = delta_lat + delta_long * cos_lat1 * cos_lat2
    return  F.atan2( F.sqrt(a), F.sqrt( 1.-a ) )

In [251]:
df_toto = fips_cross_stations.withColumn("toto",                                  
        haversine(col("latitude_fips"), col("longitude_fips"),col("cos_latitude_fips"), 
                    col("latitude_station"), col("longitude_station"), col("cos_latitude_station")) )

In [252]:
df_toto.show(10)

+-----+-------------+--------------+-------------------+-----------+--------+----------------+-----------------+--------------------+------------------+
| fips|latitude_fips|longitude_fips|  cos_latitude_fips| station_id|measured|latitude_station|longitude_station|cos_latitude_station|              toto|
+-----+-------------+--------------+-------------------+-----------+--------+----------------+-----------------+--------------------+------------------+
|02198|    0.9718477|     -2.324122| 0.5637744149985534|AQC00914000|    PRCP|      0.24987355|        -2.980441|  0.9689436985050108| 0.853762920078671|
|02240|    1.1146545|    -2.4996367|0.44048764879312186|AQC00914000|    PRCP|      0.24987355|        -2.980441|  0.9689436985050108|0.7929826951260189|
|02261|    1.0669032|    -2.5544662| 0.4828384076125879|AQC00914000|    PRCP|      0.24987355|        -2.980441|  0.9689436985050108|0.8234070924966638|
|02090|    1.1288098|    -2.5577478| 0.4277359281117351|AQC00914000|    PRCP|     