# Spark Analysis of Total Performance Data: GPS and Sectional Data

Analyzing Total Performance Data (TPD) GPS data alongside Equibase (EQB) data can provide significant predictive insight into horse performance. Metrics derived from the granular detail of a horse’s movement during a race complement traditional EQB ratings and can help identify underrated or overrated horses.

## Roadmap for Maximizing the Predictive Capabilities of TPD GPS Data

### Key Ideas and Strategies

1. Derive Advanced Pace Metrics

Understanding how a horse’s speed changes over the course of a race can reveal its racing style and potential strengths or weaknesses:

	•	Early Pace: Average speed and acceleration during the first segment (e.g., first 20% of the race).
	•	Mid-Race Pace: Average speed and deceleration during the middle segments.
	•	Late Pace: Average speed and deceleration in the final segment.
	•	Sustained Speed: Identify segments where the horse maintains a steady speed.
	•	Peak Speed Timing: The point in the race where the horse reaches its peak speed.
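A minimal sketch of the pace metrics above for a single horse in a single race, in plain Python. It assumes each GPS sample carries `progress` (metres covered, ascending) and `speed` (m/s), which mirrors the `progress` and `speed` columns aggregated later from gpspoint, although the units are an assumption. The 20% early/late cut-offs are only the example split from the text. In Spark, the same bucketing could be a `when` expression on progress followed by a groupBy.

```python
from statistics import mean


def pace_metrics(progress_m, speed_ms, race_distance_m, early_frac=0.2, late_frac=0.2):
    """Early/mid/late average speed and where in the race peak speed occurs.

    progress_m: distance covered (metres) at each GPS sample, ascending (assumed units).
    speed_ms: instantaneous speed (m/s) at each sample.
    """
    early_end = early_frac * race_distance_m
    late_start = (1 - late_frac) * race_distance_m
    samples = list(zip(progress_m, speed_ms))

    early = [s for p, s in samples if p <= early_end]
    mid = [s for p, s in samples if early_end < p < late_start]
    late = [s for p, s in samples if p >= late_start]

    # Index of the fastest sample; its progress gives the peak-speed timing
    peak_idx = max(range(len(speed_ms)), key=speed_ms.__getitem__)
    return {
        "early_pace": mean(early) if early else None,
        "mid_pace": mean(mid) if mid else None,
        "late_pace": mean(late) if late else None,
        # Fraction of the race completed when peak speed was recorded
        "peak_speed_at": progress_m[peak_idx] / race_distance_m,
    }
```

For a 1,000 m race sampled at 0, 150, 400, 600, 850 and 1,000 m, the first two samples feed early pace, the middle two feed mid pace, and the last two feed late pace.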

2. Fatigue Factor

Calculate how much the horse slows down as the race progresses:

	•	Use metrics like:
		•	Percentage drop from peak speed to finish speed.
		•	Ratio of maximum acceleration to maximum deceleration.
		•	Change in stride frequency as the race progresses.
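A sketch of the three fatigue indicators. It assumes every series is in race order and that per-sample accelerations are already available. enriched_data has an `acceleration_m_s2` column, but its exact definition is not shown here. `fatigue_metrics` is a hypothetical helper name.

```python
def fatigue_metrics(speeds_ms, accels_ms2, stride_freqs=None):
    """Fatigue indicators for one horse in one race (all series in race order)."""
    peak = max(speeds_ms)
    finish = speeds_ms[-1]
    metrics = {
        # Percentage drop from peak speed to finish speed
        "peak_to_finish_drop_pct": 100.0 * (peak - finish) / peak,
    }

    max_accel = max(accels_ms2)
    max_decel = -min(accels_ms2)  # strongest deceleration expressed as a positive number
    # Ratio of maximum acceleration to maximum deceleration (None if the horse never slowed)
    metrics["accel_decel_ratio"] = max_accel / max_decel if max_decel > 0 else None

    if stride_freqs:
        # Change in stride frequency from the first to the last sample
        metrics["stride_freq_change"] = stride_freqs[-1] - stride_freqs[0]
    return metrics
```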

3. Sectional Efficiency

Quantify how efficiently the horse runs its sections:

	•	Compare actual times vs. expected times for each section based on the route characteristics.
	•	Efficiency Ratio: distance actually run ÷ official section distance.
		•	A ratio well above 1 might indicate the horse ran extra distance due to poor cornering or positioning.
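The efficiency ratio can be computed per section and for the whole race. This sketch assumes `sectionals_distance_ran` is the distance the horse actually covered in each section and that an official distance per section is available. The `official_m` input below is hypothetical, because the source of official section distances is not shown here.

```python
def sectional_efficiency(distance_ran_m, official_m):
    """Per-section and whole-race efficiency ratios (distance run / official distance)."""
    per_section = [
        # Ratio > 1 means the horse covered more ground than the official section length
        {"efficiency_ratio": ran / off, "extra_m": ran - off}
        for ran, off in zip(distance_ran_m, official_m)
    ]
    # Race-level ratio uses totals so long sections weigh more than short ones
    race_ratio = sum(distance_ran_m) / sum(official_m)
    return per_section, race_ratio
```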

4. Overlay TPD with EQB Ratings

Use EQB’s traditional metrics (e.g., speed ratings, form) to cross-reference with TPD data:

	•	Identify horses that consistently outperform EQB predictions.
	•	Investigate horses with high EQB ratings but poor TPD-based performance (e.g., poor fatigue factors or inefficient sectional running).
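One way to quantify "outperforming EQB" is to compare the position EQB's rating implies for a horse within its race against the position where it actually finished, then average that gap across the horse's races. This sketch assumes a hypothetical per-runner `eqb_rating` field where higher is better. To run the same comparison against GPS-derived metrics, a TPD-based rank could replace `finish`.

```python
from collections import defaultdict


def eqb_outperformance(races):
    """Average (EQB-implied position - actual finish) per horse across races.

    races: iterable of races; each race is a list of dicts with
    horse_id, eqb_rating (higher = better, assumed) and finish (1 = winner).
    Positive result = the horse usually finishes better than EQB implies.
    """
    gaps = defaultdict(list)
    for runners in races:
        # Rank runners within the race by EQB rating: position 1 = highest rating
        ordered = sorted(runners, key=lambda r: r["eqb_rating"], reverse=True)
        expected = {r["horse_id"]: pos for pos, r in enumerate(ordered, start=1)}
        for r in runners:
            gaps[r["horse_id"]].append(expected[r["horse_id"]] - r["finish"])
    return {horse: sum(g) / len(g) for horse, g in gaps.items()}
```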

5. Route Characteristics

If the routes table contains track-specific details (e.g., turn sharpness, surface type, gradient):

	•	Incorporate these into the analysis.
	•	Evaluate how specific horses handle different track conditions (e.g., wide turns, long stretches).
	•	Identify patterns like “performs better on flatter tracks” or “struggles on uphill finishes.”
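A simple starting point for route and condition effects is to average a per-race metric for each horse under each condition value. `surface` and `trk_cond` appear among enriched_data's categorical columns. The `efficiency_ratio` metric, the surface codes and the row layout used here are assumptions.

```python
from collections import defaultdict


def mean_metric_by_condition(rows, condition="surface", metric="efficiency_ratio"):
    """Average `metric` for each (horse_id, condition value) pair."""
    totals = defaultdict(lambda: [0.0, 0])  # key -> [running sum, count]
    for row in rows:
        key = (row["horse_id"], row[condition])
        totals[key][0] += row[metric]
        totals[key][1] += 1
    return {key: s / n for key, (s, n) in totals.items()}
```

Comparing a horse's averages across condition values (for example dirt vs. turf, or fast vs. sloppy via `trk_cond`) gives a first, small-sample signal for patterns like "performs better on turf".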

6. Horse vs. Peer Comparisons

Evaluate how each horse performs relative to its competition in the same race:

	•	Compare sectional times and speeds with other horses in the race.
	•	Rank horses based on performance within each race segment.
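Ranking horses within each race segment amounts to grouping by race and gate and ordering by sectional time. The sketch below does this in plain Python with ordinal ranks, where ties are broken by input order. In Spark, `rank()` over `Window.partitionBy(<race cols>, "gate_index").orderBy("sectionals_sectional_time")` would give SQL-style tied ranks instead. The field names here are illustrative.

```python
from collections import defaultdict


def rank_within_segments(rows):
    """Rank horses by sectional time within each (race, gate); rank 1 = fastest."""
    groups = defaultdict(list)
    for r in rows:
        groups[(r["race_id"], r["gate"])].append(r)

    ranks = {}
    for (race_id, gate), group in groups.items():
        # Lower sectional time = faster section
        for pos, r in enumerate(sorted(group, key=lambda r: r["sectional_time"]), start=1):
            ranks[(race_id, gate, r["horse"])] = pos
    return ranks
```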

7. Acceleration Profiles

	•	Plot acceleration over time to identify patterns (e.g., burst speed vs. steady acceleration).
    •	Highlight horses with exceptional closing speed (valuable in longer races) or fast starts (important in short sprints).
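Acceleration can be estimated from consecutive GPS samples by finite differences (Δv/Δt). This sketch assumes timestamps in seconds and speeds in m/s. It skips duplicate or out-of-order timestamps and assigns each estimate to the midpoint of its interval. The time windows used to compare a fast start with a strong close are an arbitrary choice left to the caller.

```python
def acceleration_profile(times_s, speeds_ms):
    """Finite-difference acceleration (m/s^2) between consecutive GPS samples."""
    profile = []
    for t0, t1, v0, v1 in zip(times_s, times_s[1:], speeds_ms, speeds_ms[1:]):
        dt = t1 - t0
        if dt <= 0:
            continue  # skip duplicate or out-of-order timestamps
        # Report the estimate at the midpoint of the interval
        profile.append(((t0 + t1) / 2, (v1 - v0) / dt))
    return profile


def mean_accel_between(profile, t_start, t_end):
    """Mean acceleration over profile points with t_start <= t < t_end (None if empty)."""
    values = [a for t, a in profile if t_start <= t < t_end]
    return sum(values) / len(values) if values else None
```

Plotting the `(time, acceleration)` pairs shows burst vs. steady patterns. `mean_accel_between` over the opening and closing seconds gives a rough fast-start vs. strong-close comparison.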

8. Cluster Analysis of Racing Styles

Use clustering techniques to group horses by similar racing profiles:

	•	Inputs: Early pace, mid-pace, late pace, fatigue factor, sectional efficiency.
	•	Output: Clusters representing different racing styles (e.g., “early speed burners,” “closers,” “steady sustainers”).
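The clustering step can be sketched with plain Lloyd's k-means. This version seeds the centroids with the first k points so that results are repeatable. The two-feature input (early pace, late pace) is illustrative. Real features should be standardized first so that no single metric dominates the distance. At scale, Spark MLlib's `KMeans` (in `pyspark.ml.clustering`) is a natural replacement, fed by the `VectorAssembler` and `StandardScaler` that the notebook already imports.

```python
import math


def kmeans(points, k, iterations=20):
    """Lloyd's k-means on a list of equal-length feature tuples.

    Centroids are seeded with the first k points so the result is repeatable.
    Returns (labels, centroids).
    """
    centroids = [list(p) for p in points[:k]]
    labels = [0] * len(points)
    for _ in range(iterations):
        # Assignment step: each point goes to its nearest centroid (Euclidean distance)
        labels = [
            min(range(k), key=lambda c: math.dist(p, centroids[c]))
            for p in points
        ]
        # Update step: move each centroid to the mean of its assigned points
        for c in range(k):
            members = [p for p, label in zip(points, labels) if label == c]
            if members:  # keep the old centroid if a cluster empties out
                centroids[c] = [sum(dim) / len(members) for dim in zip(*members)]
    return labels, centroids
```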

9. Historical Analysis

Identify trends over a horse’s career:

	•	Does the horse improve or decline over time?
	•	Are there patterns in performance tied to specific jockeys, trainers, or race conditions?
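A simple measure of improvement or decline is the least-squares slope of a per-race metric against the horse's race order. A positive slope means the metric rises over the career. Using race order rather than calendar dates is an assumption that ignores layoffs. Applying the same function to subsets of races (for example per jockey or trainer) gives a rough comparison of those patterns.

```python
def career_trend(values):
    """OLS slope of a per-race metric against race order (0, 1, 2, ...)."""
    n = len(values)
    if n < 2:
        return None  # a trend needs at least two races
    xs = range(n)
    x_bar = (n - 1) / 2
    y_bar = sum(values) / n
    num = sum((x - x_bar) * (y - y_bar) for x, y in zip(xs, values))
    den = sum((x - x_bar) ** 2 for x in xs)
    return num / den
```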



# Using Spark for Efficient Processing

Steps to Implement

	1.	Load the Data
	•	Load gpspoint, gps_aggregated_results, and routes into Spark DataFrames.
	2.	Segment the Race
	•	Divide each race into sections (e.g., by distance markers or time intervals).
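	•	Use window functions partitioned by race and horse (Window.partitionBy in the DataFrame API, or PARTITION BY in Spark SQL) so each horse in each race is processed separately.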
	3.	Derive Metrics
	•	Speed Metrics: Use window functions to calculate average, min, max speeds.
	•	Acceleration/Deceleration: Compute using differences in speed and timestamps.
	•	Efficiency: Compare distance run vs. official track distance.
	4.	Integrate with EQB Data
	•	Join with EQB tables on course_cd, race_date, and race_number (plus saddle_cloth_number or horse_id for horse-level rows).
	5.	Save Aggregated Data
	•	Save results into gps_aggregated_results and tpd_features.
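Steps 2 and 3 depend on deciding which gate interval each GPS sample belongs to. The notebook below does this by adding cumulative sectional times to the horse's earliest GPS timestamp. This plain-Python sketch shows the same idea with `bisect`, assuming times are offsets in seconds from the horse's first GPS fix. It uses half-open intervals (previous gate, current gate], so a sample that falls exactly on a boundary is counted once. The notebook's `>=` / `<=` filter, by contrast, is inclusive at both ends.

```python
from bisect import bisect_left
from itertools import accumulate


def assign_gates(gps_offsets_s, sectional_times_s):
    """Map each GPS sample to the gate interval it falls in.

    gps_offsets_s: seconds since the horse's first GPS fix.
    sectional_times_s: sectional time for each gate, in gate order.
    Returns a 1-based gate number per sample, or None if the sample is past the last gate.
    """
    gate_ends = list(accumulate(sectional_times_s))  # cumulative time at each gate
    result = []
    for t in gps_offsets_s:
        i = bisect_left(gate_ends, t)  # first gate whose cumulative time is >= t
        result.append(i + 1 if i < len(gate_ends) else None)
    return result
```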

Example Spark Code

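The notebook cells below form a high-level implementation. They load and inspect the data, then align GPS points with sectional gates so that pace metrics can be derived per section: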


In [9]:
import os
import logging
import pprint
from pyspark.sql.functions import (
    col, unix_timestamp, when, first, last, lag, udf, sum as spark_sum,
    mean as spark_mean, min as spark_min, max as spark_max, round as spark_round
)
from pyspark.sql.window import Window
from pyspark.sql.types import TimestampType
from datetime import timedelta

# Importing utility functions from your project structure
from src.data_preprocessing.data_prep2.data_healthcheck import (
    time_series_data_healthcheck, dataframe_summary
)
from src.data_preprocessing.data_prep1.data_loader import (
    load_data_from_postgresql, reload_parquet_files, 
    load_named_parquet_files, merge_results_sectionals
)
from src.data_preprocessing.data_prep1.sql_queries import sql_queries
from src.data_preprocessing.gps_aggregated.data_features_enhancements import (
    data_enhancements
)
from src.data_preprocessing.gps_aggregated.merge_gps_sectionals_agg import (
    merge_gps_sectionals
)
from src.data_preprocessing.data_prep1.data_utils import (
    save_parquet, gather_statistics, initialize_environment,
    load_config, initialize_logging, initialize_spark, drop_duplicates_with_tolerance,
    identify_and_impute_outliers, identify_and_remove_outliers, process_merged_results_sectionals,
    identify_missing_and_outliers
)
    
spark = None
try:
    spark, jdbc_url, jdbc_properties, queries, parquet_dir, log_file = initialize_environment()
except Exception as e:
    print(f"An error occurred during initialization: {e}")
    logging.error(f"An error occurred during initialization: {e}")

2024-12-22 21:24:51,295 - INFO - Environment setup initialized.


Initializing Spark session...
Spark session created successfully with Sedona and GeoTools integrated.


In [10]:
gpspoint = spark.read.parquet(os.path.join(parquet_dir, "gpspoint.parquet"))
sectionals = spark.read.parquet(os.path.join(parquet_dir, "merge_results_sectionals.parquet"))  

In [11]:
sectionals.count()

1108928

In [12]:
gpspoint.count()

20358278

In [79]:
enriched_data = spark.read.parquet(os.path.join(parquet_dir, "enriched_data.parquet"))

In [85]:
categorical_cols = [
        "course_cd", "equip", "surface", "trk_cond", "weather", 
        "race_type", "sex", "med", "stk_clm_md", "turf_mud_mark"
    ]

from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler, StandardScaler
    

In [86]:
indexers = [
    StringIndexer(inputCol=c, outputCol=c+"_index") for c in categorical_cols
]

In [87]:
categorical_cols = [
    "course_cd", "equip", "surface", "trk_cond", "weather", 
    "race_type", "sex", "med", "stk_clm_md", "turf_mud_mark"
]
indexers = [
    StringIndexer(inputCol=c, outputCol=c+"_index", handleInvalid="keep") 
    for c in categorical_cols
]
encoders = [
    OneHotEncoder(inputCols=[c+"_index"], outputCols=[c+"_ohe"]) 
    for c in categorical_cols
]


In [88]:
ohe_cols = [c+"_ohe" for c in categorical_cols]

In [89]:
print(ohe_cols)

['course_cd_ohe', 'equip_ohe', 'surface_ohe', 'trk_cond_ohe', 'weather_ohe', 'race_type_ohe', 'sex_ohe', 'med_ohe', 'stk_clm_md_ohe', 'turf_mud_mark_ohe']


In [81]:
enriched_data.select("distance_meters").show(10, truncate=False)

+---------------+
|distance_meters|
+---------------+
|1207.008       |
|1207.008       |
|1207.008       |
|1207.008       |
|1207.008       |
|1207.008       |
|1207.008       |
|1207.008       |
|1207.008       |
|1207.008       |
+---------------+
only showing top 10 rows



In [63]:
# Just do a distinct count on the 'horse_id' column
unique_horse_count = enriched_data.select("horse_id").distinct().count()

print(f"Number of unique horses: {unique_horse_count}")

Number of unique horses: 32104




In [64]:
from pyspark.sql.functions import avg, col

# 1) Count how many rows each horse has:
entries_per_horse_df = enriched_data.groupBy("horse_id").count()

# 2) Calculate the average of those counts:
mean_entries_df = entries_per_horse_df.agg(avg(col("count")).alias("mean_entries_per_horse"))

mean_entries_df.show()

+----------------------+
|mean_entries_per_horse|
+----------------------+
|     32.35369424370795|
+----------------------+



In [69]:
from pyspark.sql.functions import median, col, count

# 1) Count how many rows each horse has:
entries_per_horse_df = enriched_data.groupBy("horse_id").count()

# 2) Calculate the median of those counts:
median_entries_df = entries_per_horse_df.agg(median(col("count")).alias("median_entries_per_horse"))

median_entries_df.show()

+------------------------+
|median_entries_per_horse|
+------------------------+
|                    24.0|
+------------------------+



In [71]:
primary_keys = ["course_cd", "race_date", "race_number", "saddle_cloth_number", "sectionals_gate_name"]

duplicates = enriched_data.groupBy(*primary_keys) \
                   .agg(count("*").alias("cnt")) \
                   .filter(col("cnt") > 1)
dup_count = duplicates.count()

print(dup_count)

0


In [72]:
from pyspark.sql.functions import when, col, sum as F_sum, count

# 1. Create indicator columns for missingness
df_missing_indicators = enriched_data.select(
    "*",  # keep all original cols
    (col("acceleration_m_s2").isNull().cast("int")).alias("accel_missing"),
    (col("gps_section_avg_stride_freq").isNull().cast("int")).alias("stridefreq_missing"),
    (col("prev_speed").isNull().cast("int")).alias("prev_speed_missing"),
    (col("sectionals_distance_back").isNull().cast("int")).alias("dist_back_missing"),
    (col("sectionals_number_of_strides").isNull().cast("int")).alias("strides_missing"),
    (col("weather").isNull().cast("int")).alias("weather_missing")
)

# 2. Group by race or by horse (depending on your analysis needs)
#    Example: Group by a "race-level" ID (course_cd, race_date, race_number).
#    Then sum each missing indicator and also count total rows:

missing_by_race = (
    df_missing_indicators
    .groupBy("course_cd", "race_date", "race_number")
    .agg(
        F_sum("accel_missing").alias("accel_missing_sum"),
        F_sum("stridefreq_missing").alias("stridefreq_missing_sum"),
        F_sum("prev_speed_missing").alias("prev_speed_missing_sum"),
        F_sum("dist_back_missing").alias("dist_back_missing_sum"),
        F_sum("strides_missing").alias("strides_missing_sum"),
        F_sum("weather_missing").alias("weather_missing_sum"),
        count("*").alias("race_row_count")
    )
    .orderBy("accel_missing_sum", ascending=False)
)

missing_by_race.show(50, truncate=False)

# Similarly, you could group by (horse_id) or (course_cd, race_date, race_number, saddle_cloth_number).
# For instance, if you suspect some *horses* always lack data:

missing_by_horse = (
    df_missing_indicators
    .groupBy("horse_id")
    .agg(
        F_sum("accel_missing").alias("accel_missing_sum"),
        # ...
        count("*").alias("horse_row_count")
    )
    .orderBy("accel_missing_sum", ascending=False)
)
missing_by_horse.show(50, truncate=False)


+---------+----------+-----------+-----------------+----------------------+----------------------+---------------------+-------------------+-------------------+--------------+
|course_cd|race_date |race_number|accel_missing_sum|stridefreq_missing_sum|prev_speed_missing_sum|dist_back_missing_sum|strides_missing_sum|weather_missing_sum|race_row_count|
+---------+----------+-----------+-----------------+----------------------+----------------------+---------------------+-------------------+-------------------+--------------+
|TWO      |2022-11-06|7          |14               |14                    |14                    |0                    |0                  |0                  |252           |
|TWO      |2024-10-05|6          |13               |13                    |13                    |0                    |0                  |0                  |286           |
|TWO      |2024-10-05|4          |13               |13                    |13                    |0                    |

+--------+-----------------+---------------+
|horse_id|accel_missing_sum|horse_row_count|
+--------+-----------------+---------------+
|30899   |19               |243            |
|44568   |18               |234            |
|11793   |17               |256            |
|45136   |17               |226            |
|15091   |16               |221            |
|1258    |16               |217            |
|2329    |16               |203            |
|27865   |16               |218            |
|15485   |16               |224            |
|85831   |15               |207            |
|12045   |15               |201            |
|54019   |15               |193            |
|24952   |15               |206            |
|39507   |15               |231            |
|24980   |15               |192            |
|43730   |15               |208            |
|20094   |15               |192            |
|16016   |14               |219            |
|43295   |14               |203            |
|19642   |



In [76]:
from pyspark.sql.functions import when, col, sum as F_sum, count, lit
enriched_data = enriched_data.withColumn(
    "label",
    when(col("official_fin") == 1, lit(0))                          # Win
    .when(col("official_fin") == 2, lit(1))                         # Place
    .when(col("official_fin") == 3, lit(2))                         # Show
    .when(col("official_fin") == 4, lit(3))                         # Fourth
    .otherwise(lit(4))                                              # Outside top-4
)    


In [78]:
enriched_data.groupBy("label").count().show()

+-----+------+
|label| count|
+-----+------+
|    1|137802|
|    3|137777|
|    4|486781|
|    2|137838|
|    0|138485|
+-----+------+



In [6]:
total_sectionals_horses = sectionals.select("course_cd", "race_date", 
                                            "race_number", "saddle_cloth_number").distinct().count()

total_gps_horses = gpspoint.select("course_cd", "race_date", 
                                   "race_number", "saddle_cloth_number").distinct().count()

print("Distinct combos in sectionals:", total_sectionals_horses)
print("Distinct combos in gpspoint:", total_gps_horses)


Distinct combos in sectionals: 79098
Distinct combos in gpspoint: 139500


In [None]:
from pyspark.sql.functions import col, min as spark_min, sum as spark_sum, udf
from pyspark.sql.window import Window
from pyspark.sql.types import TimestampType
from datetime import timedelta

# UDF to add seconds (including fractional seconds)
def add_seconds(ts, seconds):
    if ts is None or seconds is None:
        return None
    return ts + timedelta(seconds=seconds)

add_seconds_udf = udf(add_seconds, TimestampType())

# Define your race identification columns
race_id_cols = ["course_cd", "race_date", "race_number", "saddle_cloth_number"]


In [None]:
######################################
# 1) Compute earliest GPS time per race and horse
######################################
gps_earliest_df = gpspoint.groupBy(*race_id_cols).agg(
    spark_min("time_stamp").alias("earliest_time_stamp_gps")
)

print("Earliest GPS time per race and horse computed.")
gps_earliest_df.printSchema()
gps_earliest_df.show(10, truncate=False)

In [None]:

######################################
# 2) Sort the sectional data by gate_index
######################################
# Sorting here only affects display order: the join in step 3 reshuffles the rows,
# and the window in step 4 is what guarantees gate order for the cumulative sums.
sectionals_sorted = sectionals.orderBy(*race_id_cols, "gate_index")

print("Sectionals sorted by gate_index.")
sectionals_sorted.printSchema()
sectionals_sorted.select(*race_id_cols, "gate_index").show(10, truncate=False)

In [None]:

######################################
# 3) Join the earliest GPS time with sectionals
######################################
sectionals_with_earliest = sectionals_sorted.join(
    gps_earliest_df,
    on=race_id_cols,
    how="left"
)

print("Joined sectionals with earliest GPS time:")
sectionals_with_earliest.printSchema()
#sectionals_with_earliest.show(10, truncate=False)
sectionals_with_earliest.select(*race_id_cols, "gate_index", "earliest_time_stamp_gps").show(10, truncate=False)


In [None]:
######################################
# 4) Compute sec_time_stamp in sectionals
#
# First, compute cumulative_sectional_time per race/horse ordered by gate_index.
######################################
window_spec = Window.partitionBy(*race_id_cols).orderBy("gate_index").rowsBetween(Window.unboundedPreceding, 0)
sectionals_with_cum = sectionals_with_earliest.withColumn(
    "cumulative_sectional_time",
    spark_sum("sectionals_sectional_time").over(window_spec)
)

# Now add earliest_time_stamp_gps to cumulative_sectional_time to get sec_time_stamp
sectionals_with_sec_time = sectionals_with_cum.withColumn(
    "sec_time_stamp",
    add_seconds_udf(col("earliest_time_stamp_gps"), col("cumulative_sectional_time"))
)

print("Computed sec_time_stamp by adding cumulative_sectional_time to earliest GPS timestamp.")
sectionals_with_sec_time.printSchema()
#sectionals_with_sec_time.show(10, truncate=False)
# The commented-out select below also works; the active one omits earliest_time_stamp_gps to save space
#sectionals_with_sec_time.select(*race_id_cols, "gate_index", "earliest_time_stamp_gps",
#                               "sec_time_stamp", "sectionals_sectional_time").show(10, truncate=False)
    
sectionals_with_sec_time.select(*race_id_cols, "gate_index", "sec_time_stamp", "sectionals_sectional_time").show(10, truncate=False)    

In [None]:




######################################
# 5) Create a temporary view to inspect the data
#
# We'll select a subset of columns that are essential for inspection:
# course_cd, race_date, race_number, saddle_cloth_number, sectionals_gate_name,
# sectionals_sectional_time, gate_index, sec_time_stamp
######################################
view_df = sectionals_with_sec_time.select(
    "course_cd", "race_date", "race_number", "saddle_cloth_number",
    "sectionals_gate_name", "sectionals_sectional_time", "gate_index", "sec_time_stamp"
).orderBy(*race_id_cols, "gate_index")

view_name = "sectionals_with_sec_time_view"
view_df.createOrReplaceTempView(view_name)

print(f"Temporary view '{view_name}' created. You can run SQL queries like:")
print(f"SELECT * FROM {view_name} WHERE course_cd='XYZ' AND race_number=123 AND saddle_cloth_number='A1' ORDER BY gate_index;")

######################################
# Verification Step:
# At this point, you can run:
#
# spark.sql(f"SELECT * FROM {view_name} WHERE course_cd='...' AND race_number=... AND saddle_cloth_number='...' ORDER BY gate_index").show(50, truncate=False)
#
# to verify that sec_time_stamp increments by sectional_time as expected.
######################################

In [None]:
spark.sql(f"""
    SELECT DISTINCT saddle_cloth_number, sectionals_gate_name, gate_index,
           sec_time_stamp, sectionals_sectional_time
    FROM {view_name}
    WHERE course_cd = 'LRL'
      AND race_date = '2022-07-30'
      AND race_number = 5
    ORDER BY saddle_cloth_number, gate_index
""").show(50, truncate=False)

In [None]:
spark.sql(f"SELECT course_cd, race_date, race_number, saddle_cloth_number, \
COUNT(*) AS num_gates \
FROM {view_name} \
GROUP BY course_cd, race_date, race_number, saddle_cloth_number \
ORDER BY course_cd, race_date, race_number, saddle_cloth_number").show(50, truncate=False)

In [None]:
################################################
#  6) To aggregate GPS data for the interval leading up to each gate, you need a 
# start_time and an end_time for that interval. A common approach:
#  1. start_time for gate i: sec_time_stamp from the previous gate 
#    (or the earliest GPS timestamp if it’s the first gate).
#  2. end_time for gate i: sec_time_stamp of gate i itself.
#
#  1A. Sort Each Horse’s Sectionals by gate_index
#  Step 4) above already defines a window ordered by gate_index:
# window_spec = Window.partitionBy(*race_id_cols).orderBy("gate_index").rowsBetween(Window.unboundedPreceding, 0)
# 
# 1B. Use a lag Window Function to Get start_time
#
#  For each row (gate i), define the “start” as the sec_time_stamp of the previous gate i-1:

from pyspark.sql.window import Window
from pyspark.sql.functions import lag, col, when

race_id_cols = ["course_cd", "race_date", "race_number", "saddle_cloth_number"]

# Window for each horse, ordered by gate_index
w = Window.partitionBy(*race_id_cols).orderBy("gate_index")

sectionals_intervals = sectionals_with_sec_time \
    .withColumn("start_time",
        lag("sec_time_stamp").over(w)
    )

# If start_time is null (i.e., this is the first gate), default to earliest_time_stamp_gps
# (or to this row's own sec_time_stamp, whichever logic you prefer):

sectionals_intervals = sectionals_intervals.withColumn(
    "start_time",
    when(col("start_time").isNull(), col("earliest_time_stamp_gps"))
    .otherwise(col("start_time"))
)

# The "end_time" is simply this row's sec_time_stamp
sectionals_intervals = sectionals_intervals.withColumn("end_time", col("sec_time_stamp"))

sectionals_intervals.select(
    *race_id_cols, "gate_index", "start_time", "end_time"
).show(10, truncate=False)

In [None]:
##############################################
# 	1.	Join on (course_cd, race_date, race_number, saddle_cloth_number) 
#       so we only consider the correct horse in the correct race.
# 	2.	Filter where gpspoint.time_stamp is between start_time and end_time.
##############################################

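interval_join = gpspoint.join(
    sectionals_intervals,
    on=race_id_cols,
    how="inner"  # a 'left' join behaves the same here: the time filter drops rows with null start/end times
).filter(
    (col("time_stamp") >= col("start_time")) &
    (col("time_stamp") <= col("end_time"))  # inclusive at both ends: a point exactly on a gate boundary lands in two intervals
)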

In [None]:
############################################################################################
# After the filter, you have all rows that satisfy:
# 	•	(course_cd, race_date, race_number, saddle_cloth_number) is the same
# 	•	time_stamp is in [start_time, end_time]
#
# Now you can groupBy the unique ID of each gate interval. Typically, that’s:
# 	•	(course_cd, race_date, race_number, saddle_cloth_number, gate_index)
#
# Then compute your aggregates:
# 	1.	avg_speed: average of gpspoint.speed
# 	2.	avg_stride_freq: average of gpspoint.stride_frequency
# 	3.	gps_first_progress: the first progress in that interval
# 	4.	gps_last_progress: the last progress in that interval
# 	5.	gps_first_longitude, gps_first_latitude if desired
# 	6.	gps_last_longitude, gps_last_latitude
# 	7.	Possibly, max_speed, min_speed, etc.
############################################################################################

from pyspark.sql.functions import (
    mean as spark_mean, min_by, max_by, count
)

# first()/last() after a groupBy depend on row order, which is not guaranteed
# after a shuffle; min_by/max_by (Spark 3.3+) return the value at the
# earliest/latest time_stamp in each interval instead.
aggregated = interval_join.groupBy(
    *race_id_cols, "gate_index"
).agg(
    spark_mean("speed").alias("gps_section_avg_speed"),
    spark_mean("stride_frequency").alias("gps_section_avg_stride_freq"),
    min_by("progress", "time_stamp").alias("gps_first_progress"),
    max_by("progress", "time_stamp").alias("gps_last_progress"),
    min_by("longitude", "time_stamp").alias("gps_first_longitude"),
    min_by("latitude", "time_stamp").alias("gps_first_latitude"),
    max_by("longitude", "time_stamp").alias("gps_last_longitude"),
    max_by("latitude", "time_stamp").alias("gps_last_latitude"),
    min_by("location", "time_stamp").alias("gps_first_location"),
    max_by("location", "time_stamp").alias("gps_last_location"),
    # add more aggregates if needed
    count("*").alias("gps_num_points")  # how many points fell in the interval
)

In [None]:
aggregated.count()

In [None]:
############################################################################################
#
# Rejoin Aggregates Back to Sectionals
#
############################################################################################
final_df = sectionals_intervals.join(
    aggregated,
    on=[*race_id_cols, "gate_index"],  # same grouping key
    how="left"  # or 'inner' if you only want intervals that had GPS data
)

In [None]:
final_df.count()

In [None]:
############################################################################################
#
# How Many Intervals Have No GPS Data?
#
############################################################################################

missing_count = final_df.filter("gps_num_points IS NULL").count()
print("Number of intervals with no GPS data:", missing_count)

In [None]:
final_df.count() - aggregated.count()

In [None]:
# To see a handful of intervals that have no GPS:
final_df.select(
    "course_cd", "race_date", "race_number", "saddle_cloth_number", "gate_index",
    "gps_num_points", "gps_section_avg_speed", "gps_section_avg_stride_freq"
).filter("gps_num_points IS NULL").show(20, truncate=False)

In [None]:
############################################################################################
# You expect one row in final_df for each (course_cd, race_date, race_number, 
# saddle_cloth_number, gate_index) if your aggregator and join logic are correct. 
# Confirm that no duplicates crept in:
############################################################################################

duplicates_df = final_df.groupBy(
    "course_cd", "race_date", "race_number", "saddle_cloth_number", "gate_index"
).count().filter("count > 1")

duplicates_df.show(20, truncate=False)

In [None]:
# It can be helpful to see how many GPS points are typically aggregated per gate interval. 
# For example:

from pyspark.sql.functions import col

final_df.select("gps_num_points") \
        .describe() \
        .show()

# Or a quick distribution:
final_df.groupBy("gps_num_points").count().orderBy(col("gps_num_points").asc()).show(50)

In [None]:
final_df.count()

In [None]:
############################################################################################
#
# To verify if the aggregator columns (gps_section_avg_speed, 
# gps_section_avg_stride_freq, etc.) are missing in intervals that do have gps_num_points:
#
############################################################################################

final_df.filter(col("gps_num_points").isNotNull()).select([
    "gps_section_avg_speed",
    "gps_section_avg_stride_freq",
    "gps_first_progress",
    "gps_last_progress"
]).summary().show()


In [None]:
from pyspark.sql.window import Window
from pyspark.sql.functions import lag, when, col

race_id_cols = ["course_cd", "race_date", "race_number", "saddle_cloth_number"]

# Sort by gate_index for each horse
w = Window.partitionBy(*race_id_cols).orderBy("gate_index")

sectionals_intervals = sectionals_with_sec_time \
    .withColumn("start_time", lag("sec_time_stamp").over(w)) \
    .withColumn(
        "start_time",
        when(col("start_time").isNull(), col("earliest_time_stamp_gps"))   # First gate starts at earliest GPS time
        .otherwise(col("start_time"))
    ) \
    .withColumn("end_time", col("sec_time_stamp"))

In [None]:
interval_join = gpspoint.join(
    sectionals_intervals,
    on=race_id_cols,  # (course_cd, race_date, race_number, saddle_cloth_number)
    how="inner"       # intervals without GPS points are restored later by the left join into final_df
).filter(
    (col("time_stamp") >= col("start_time")) &
    (col("time_stamp") <= col("end_time"))
)

In [None]:
from pyspark.sql.functions import (
    mean as spark_mean, min_by, max_by, count
)

# min_by/max_by (Spark 3.3+) pick the value at the earliest/latest time_stamp,
# unlike first()/last(), whose result depends on non-guaranteed row order.
aggregated = interval_join.groupBy(*race_id_cols, "gate_index").agg(
    spark_mean("speed").alias("gps_section_avg_speed"),
    spark_mean("stride_frequency").alias("gps_section_avg_stride_freq"),
    min_by("progress", "time_stamp").alias("gps_first_progress"),
    max_by("progress", "time_stamp").alias("gps_last_progress"),
    min_by("longitude", "time_stamp").alias("gps_first_longitude"),
    min_by("latitude", "time_stamp").alias("gps_first_latitude"),
    max_by("longitude", "time_stamp").alias("gps_last_longitude"),
    max_by("latitude", "time_stamp").alias("gps_last_latitude"),
    count("*").alias("gps_num_points")   # e.g. number of GPS rows in that interval
)

In [None]:
final_df = sectionals_intervals.join(
    aggregated,
    on=[*race_id_cols, "gate_index"],
    how="left"       # Use 'left' so intervals with no matches still appear
)

In [None]:
no_gps_df = final_df.filter(col("gps_num_points").isNull())
print("Number of intervals with no GPS:", no_gps_df.count())
no_gps_df.show(20, truncate=False)

In [None]:
final_df.filter("gps_num_points IS NULL").count()

In [None]:
missing_gps_count = final_df.filter(col("earliest_time_stamp_gps").isNull()).count()
print("Number of intervals with missing earliest_time_stamp_gps:", missing_gps_count)

final_df.filter(col("earliest_time_stamp_gps").isNull()) \
        .select(*race_id_cols) \
        .distinct() \
        .show(50, truncate=False)

In [None]:
final_df.groupBy("gps_num_points").count().orderBy("gps_num_points").show(50)

In [None]:
# If your gps_num_points is NULL or 0, you want to mark gps_coverage=False, otherwise True:
from pyspark.sql.functions import when, col, lit

final_df_with_flag = final_df.withColumn(
    "gps_coverage",
    when((col("gps_num_points").isNotNull()) & (col("gps_num_points") > 0), lit(True))
    .otherwise(lit(False))
)

In [None]:
# Check a small sample:

final_df_with_flag.select(
    "gps_num_points", "gps_coverage").show(20, truncate=False)

In [None]:
total_sectionals_horses = sectionals.select("course_cd", "race_date", 
                                    "race_number", "saddle_cloth_number").distinct().count()

total_gps_horses = gpspoint.select("course_cd", "race_date", \
                                "race_number", "saddle_cloth_number").distinct().count()

print("Distinct combos in sectionals:", total_sectionals_horses)
print("Distinct combos in gpspoint:", total_gps_horses)

In [None]:
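# merged_df is not defined in the cells shown above; it must exist before this cell runs.
merged_df.printSchema()

In [None]:
# routes is likewise not defined above; load it (see "Load the Data") before this cell runs.
routes.printSchema()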

In [None]:
def calculate_average_speed(enriched_df):
    """
    Adds per-section and cumulative average speeds for each horse in each race.

    sec_avg_spd: sectionals_distance_ran / sectionals_sectional_time for the current gate.
    cum_avg_spd: distance ran / sectional time, summed over all gates up to the current one.
    spd_chng:    sec_avg_spd - cum_avg_spd (positive = this section was faster than
                 the horse's running average so far).
    """
    # Define the race columns and ordering
    race_cols = ["course_cd", "race_date", "race_number", "horse_id"]

    # Window partitioned by race and horse, ordered by gate. With an orderBy and no
    # explicit frame, Spark sums from the first gate up to the current one
    # (rows tied on gate number are included together).
    window_spec = Window.partitionBy(*race_cols).orderBy("sectionals_gate_numeric")

    # Compute per-sectional average speed (sec_avg_spd)
    enriched_df = enriched_df.withColumn(
        "sec_avg_spd",
        col("sectionals_distance_ran") / col("sectionals_sectional_time")
    )

    # Compute cumulative distance and cumulative time up to the current gate
    enriched_df = enriched_df.withColumn(
        "cum_distance",
        spark_sum("sectionals_distance_ran").over(window_spec)
    ).withColumn(
        "cum_time",
        spark_sum("sectionals_sectional_time").over(window_spec)
    )

    # Compute cumulative average speed (cum_avg_spd)
    enriched_df = enriched_df.withColumn(
        "cum_avg_spd",
        col("cum_distance") / col("cum_time")
    )

    # Compute speed change (spd_chng) as difference between sec_avg_spd and cum_avg_spd
    enriched_df = enriched_df.withColumn(
        "spd_chng",
        col("sec_avg_spd") - col("cum_avg_spd")
    )

    # drop() returns a new DataFrame, so the result must be reassigned
    enriched_df = enriched_df.drop("cum_distance", "cum_time")

    # Round the derived speeds to 3 decimal places
    enriched_df = (
        enriched_df
        .withColumn("sec_avg_spd", spark_round(col("sec_avg_spd"), 3))
        .withColumn("cum_avg_spd", spark_round(col("cum_avg_spd"), 3))
        .withColumn("spd_chng", spark_round(col("spd_chng"), 3))
    )

    return enriched_df
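As a hand check of `calculate_average_speed`, here is the same arithmetic in plain Python for one horse's gates. The distances and times are made up for illustration. A positive `spd_chng` means that section was run faster than the horse's average up to that gate. Note one small difference: Spark's `round` uses HALF_UP, while Python's `round` rounds exact ties to even, so values sitting exactly on a tie could differ in the last digit.

```python
from itertools import accumulate


def speed_change(distances_m, times_s):
    """(sec_avg_spd, cum_avg_spd, spd_chng) per gate, rounded to 3 decimals."""
    cum_d = list(accumulate(distances_m))  # running distance up to each gate
    cum_t = list(accumulate(times_s))      # running time up to each gate
    rows = []
    for d, t, cd, ct in zip(distances_m, times_s, cum_d, cum_t):
        sec = d / t    # this section's average speed
        cum = cd / ct  # average speed from the start to this gate
        # Difference is taken before rounding, as in the Spark version
        rows.append((round(sec, 3), round(cum, 3), round(sec - cum, 3)))
    return rows
```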

In [None]:
merged_df = calculate_average_speed(merged_df)

In [None]:
merged_df.select("sectionals_gate_numeric","sec_avg_spd", "cum_avg_spd", "spd_chng").show(30)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from pyspark.sql.functions import col

def plot_distances(enriched_df, selected_race=None):
    """
    Plots distances to the running and winning lines for a selected race and horse.

    :param enriched_df: DataFrame with the run_ln_* / win_ln_* distance columns
                        produced by calculate_distance_to_lines.
    :param selected_race: Dictionary specifying the race and horse to plot;
                          defaults to the example race below.
    """
    if selected_race is None:
        selected_race = {
            "course_cd": "TGP",         # Example racecourse code
            "race_date": "2024-09-15",  # Example race date
            "race_number": 2,           # Example race number
            "horse_id": 898247          # Example horse ID
        }

    # Filter for the specific race and horse
    filtered_df = enriched_df.filter(
        (col("course_cd") == selected_race["course_cd"]) &
        (col("race_date") == selected_race["race_date"]) &
        (col("race_number") == selected_race["race_number"]) &
        (col("horse_id") == selected_race["horse_id"])
    )

    # Select relevant columns, sort by gate so the lines connect in race order, and convert to Pandas
    plot_df = filtered_df.select(
        "sectionals_gate_numeric",
        "run_ln_str_dist_run",
        "run_ln_end_dist_run",
        "win_ln_str_dist_win",
        "win_ln_end_dist_win"
    ).orderBy("sectionals_gate_numeric").toPandas()

    # Plotting
    plt.figure(figsize=(14, 7))

    # Distances from the first/last GPS point of each sectional to the running line
    plt.plot(plot_df['sectionals_gate_numeric'], plot_df['run_ln_str_dist_run'], marker='o', label='Running Line Start Distance', color='blue')
    plt.plot(plot_df['sectionals_gate_numeric'], plot_df['run_ln_end_dist_run'], marker='o', label='Running Line End Distance', color='blue', linestyle='--')

    # Distances from the first/last GPS point of each sectional to the winning line
    plt.plot(plot_df['sectionals_gate_numeric'], plot_df['win_ln_str_dist_win'], marker='s', label='Winning Line Start Distance', color='red')
    plt.plot(plot_df['sectionals_gate_numeric'], plot_df['win_ln_end_dist_win'], marker='s', label='Winning Line End Distance', color='red', linestyle='--')

    plt.xlabel('Gate Order')
    plt.ylabel('Distance to Line (m)')
    plt.title(f'Distances to Running and Winning Lines for Horse ID {selected_race["horse_id"]} in Race {selected_race["race_number"]} on {selected_race["race_date"]}')
    plt.legend()
    plt.show()

In [None]:
plot_distances(distance_pivot_df)

In [None]:
from pyspark.sql.functions import col, expr, first, when

def calculate_distance_to_lines(spark, df1, routes):
    """
    Calculates distances from the first and last GPS points of each sectional
    to the course's running line and winning line.

    Note: Sedona's ST_Distance is planar and returns values in the units of the
    geometries' coordinate reference system (degrees for raw lon/lat data).
    """
    # Convert route coordinates to geometry: hex strings are parsed as WKB, anything else as WKT
    routes = routes.withColumn(
        "coordinates_geom",
        when(col("coordinates").rlike("^[0-9A-F]+$"), expr("ST_GeomFromWKB(unhex(coordinates))"))
        .otherwise(expr("ST_GeomFromWKT(coordinates)"))
    )

    # Create temporary views
    df1.createOrReplaceTempView("df1")
    routes.createOrReplaceTempView("routes")

    # Distances from each sectional's start and end points to the course's running/winning lines
    distance_df = spark.sql("""
        SELECT 
            g.course_cd, 
            g.race_date, 
            g.race_number, 
            g.saddle_cloth_number,
            g.sectionals_gate_numeric,
            g.horse_id,
            r.line_type,
            ST_Distance(g.gps_first_location_geom, r.coordinates_geom) AS str_dist_to_line,
            ST_Distance(g.gps_last_location_geom, r.coordinates_geom) AS end_dist_to_line
        FROM 
            df1 g
        JOIN 
            routes r 
        ON 
            g.course_cd = r.course_cd
        WHERE 
            r.line_type IN ('RUNNING_LINE', 'WINNING_LINE')
    """)

    # Pivot by line_type. Within each line type, the *_to_run_line and *_to_win_line aliases
    # hold the same value (the distance to that line type); only the column names differ.
    # first() assumes a single geometry per (course_cd, line_type) in routes.
    distance_pivot_df = distance_df.groupBy(
        "course_cd", "race_date", "race_number", "saddle_cloth_number", "horse_id", "sectionals_gate_numeric"
    ).pivot("line_type", ["RUNNING_LINE", "WINNING_LINE"]).agg(
        first("str_dist_to_line").alias("str_dist_to_run_line"),
        first("end_dist_to_line").alias("end_dist_to_run_line"),
        first("str_dist_to_line").alias("str_dist_to_win_line"),
        first("end_dist_to_line").alias("end_dist_to_win_line")
    )

    # Map pivoted names to shorter names. run_ln_* columns are distances to the running line;
    # win_ln_* and win_line_* columns are distances to the winning line. The run_ln_*_win and
    # win_line_* columns duplicate run_ln_*_run and win_ln_*_win respectively.
    rename_mapping = {
        "RUNNING_LINE_str_dist_to_run_line": "run_ln_str_dist_run",
        "RUNNING_LINE_end_dist_to_run_line": "run_ln_end_dist_run",
        "RUNNING_LINE_str_dist_to_win_line": "run_ln_str_dist_win",
        "RUNNING_LINE_end_dist_to_win_line": "run_ln_end_dist_win",
        "WINNING_LINE_str_dist_to_run_line": "win_ln_str_dist_win",
        "WINNING_LINE_end_dist_to_run_line": "win_ln_end_dist_win",
        "WINNING_LINE_str_dist_to_win_line": "win_line_str_dist_win",
        "WINNING_LINE_end_dist_to_win_line": "win_line_end_dist_win",
    }

    # Rename columns using withColumnRenamed
    for old_name, new_name in rename_mapping.items():
        distance_pivot_df = distance_pivot_df.withColumnRenamed(old_name, new_name)

    # Print schema for debugging
    distance_pivot_df.printSchema()
    
    return distance_pivot_df
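
In Sedona, `ST_Distance` returns a planar (Cartesian) distance in the units of the input coordinates. If the GPS points and route lines are stored as longitude/latitude, the values behind the `(m)` axis labels in the plots are actually degrees. There are two common remedies. One is to project both geometries to a metric CRS (for example with `ST_Transform`) before measuring. The other is to use a spherical distance.

The plain-Python sketch below compares a planar degree distance with the haversine great-circle distance for two nearby points. The helper names are hypothetical, and the sketch assumes WGS84 lon/lat input and a spherical Earth.

```python
import math

# Mean Earth radius in metres; treating the Earth as a sphere is an approximation
EARTH_RADIUS_M = 6_371_008.8


def haversine_m(lon1, lat1, lon2, lat2):
    """Great-circle distance in metres between two lon/lat points given in degrees."""
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    d_phi = phi2 - phi1
    d_lambda = math.radians(lon2 - lon1)
    a = math.sin(d_phi / 2) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(d_lambda / 2) ** 2
    return 2 * EARTH_RADIUS_M * math.asin(math.sqrt(a))


def planar_distance_deg(lon1, lat1, lon2, lat2):
    """What a planar distance returns on raw lon/lat coordinates: a value in degrees."""
    return math.hypot(lon2 - lon1, lat2 - lat1)


# A 0.001-degree step north vs. east at latitude 42
north = (haversine_m(-84.0, 42.0, -84.0, 42.001), planar_distance_deg(-84.0, 42.0, -84.0, 42.001))
east = (haversine_m(-84.0, 42.0, -83.999, 42.0), planar_distance_deg(-84.0, 42.0, -83.999, 42.0))
print(north, east)
```

Both steps have the same planar value, but the eastward step is shorter in metres because a degree of longitude shrinks with latitude. This means planar distances on lon/lat data are not simply a scaled version of metres.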

In [None]:
distance_pivot_df = calculate_distance_to_lines(spark, merged_df, routes)

In [None]:
distance_pivot_df.select('sectionals_gate_numeric').orderBy('sectionals_gate_numeric').show()

In [None]:
from pyspark.sql.functions import col

# Define the race and horse you want to visualize
selected_race = {
    "course_cd": "TGP",               # Example racecourse code
    "race_date": "2024-09-15",        # Example race date
    "race_number": 2,                 # Example race number
    "horse_id": 898247                    # Example horse ID
}

# Apply the filter
filtered_df = distance_pivot_df.filter(
    (col("course_cd") == selected_race["course_cd"]) &
    (col("race_date") == selected_race["race_date"]) &
    (col("race_number") == selected_race["race_number"]) &
    (col("horse_id") == selected_race["horse_id"])
)

In [None]:
import pandas
import matplotlib.pyplot as plt
import seaborn as sns

# Select the relevant columns
plot_df = filtered_df.select(
    "sectionals_gate_numeric",
    "run_ln_str_dist_run",
    "run_ln_end_dist_run",
    "run_ln_str_dist_win",
    "run_ln_end_dist_win",
    "win_ln_str_dist_win",
    "win_ln_end_dist_win"
).toPandas()

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Optional: Enhance plot aesthetics
sns.set(style="whitegrid")

In [None]:
plt.figure(figsize=(12, 6))

# Plot start and end distances to the running line
plt.plot(plot_df['sectionals_gate_numeric'], plot_df['run_ln_str_dist_run'], marker='o', label='Running Line Start Distance')
plt.plot(plot_df['sectionals_gate_numeric'], plot_df['run_ln_end_dist_run'], marker='o', label='Running Line End Distance')

plt.xlabel('Sectional Gate Numeric')
plt.ylabel('Distance to Running Line (m)')
plt.title('Distances to Running Line per Sectional Gate')
plt.legend()
plt.show()

In [None]:
plt.figure(figsize=(12, 6))

# Plot start and end distances to the winning line
plt.plot(plot_df['sectionals_gate_numeric'], plot_df['win_ln_str_dist_win'], marker='o', label='Winning Line Start Distance')
plt.plot(plot_df['sectionals_gate_numeric'], plot_df['win_ln_end_dist_win'], marker='o', label='Winning Line End Distance')

plt.xlabel('Sectional Gate Numeric')
plt.ylabel('Distance to Winning Line (m)')
plt.title('Distances to Winning Line per Sectional Gate')
plt.legend()
plt.show()

In [None]:
def register_temp_views(merged_df, routes):
    """
    Registers DataFrames as temporary SQL views named 'merged_df' and 'routes'.

    :param merged_df: DataFrame containing the merged GPS data.
    :param routes: DataFrame containing the routes data.
    """
    merged_df.createOrReplaceTempView("merged_df")
    routes.createOrReplaceTempView("routes")
    print("DataFrames registered as temporary views: 'merged_df' and 'routes'.")

In [None]:
def calculate_distance_to_lines(spark):
    """
    Calculates distance from each GPS point to the running and winning lines.

    :param spark: SparkSession object.
    :return: DataFrame with distance metrics.
    """
    distance_df = spark.sql("""
        SELECT 
            g.course_cd, 
            g.race_date, 
            g.race_number, 
            g.saddle_cloth_number, 
            g.horse_id,
            r.line_type,
            ST_Distance(g.geometry, r.route_geometry) AS distance_to_line_m
        FROM 
            merged_df g
        JOIN 
            routes r 
        ON 
            g.course_cd = r.course_cd
        WHERE 
            r.line_type IN ('RUNNING_LINE', 'WINNING_LINE')
    """)

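    # Pivot the distance metrics into one column per line type.
    # Caution: the grouping has no per-point key, so first() keeps the distance of a
    # single, arbitrary GPS point per horse and race.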
    from pyspark.sql.functions import first
    distance_pivot_df = distance_df.groupBy(
        "course_cd", "race_date", "race_number", "saddle_cloth_number", "horse_id"
    ).pivot("line_type", ["RUNNING_LINE", "WINNING_LINE"]).agg(first("distance_to_line_m"))

    # Rename columns for clarity
    distance_pivot_df = distance_pivot_df.withColumnRenamed("RUNNING_LINE", "distance_to_running_line_m") \
                                         .withColumnRenamed("WINNING_LINE", "distance_to_winning_line_m")

    return distance_pivot_df

In [None]:
def integrate_distance_metrics(main_df, distance_pivot_df):
    """
    Joins the main GPS DataFrame with distance metrics.

    :param main_df: DataFrame containing main GPS data.
    :param distance_pivot_df: DataFrame containing distance metrics.
    :return: Enriched DataFrame with distance metrics.
    """
    enriched_df = main_df.join(
        distance_pivot_df,
        on=["course_cd", "race_date", "race_number", "saddle_cloth_number", "horse_id"],
        how="left"
    )
    return enriched_df

In [None]:
def calculate_and_integrate_distances(spark, parquet_dir):
    # Load data
    merged_df, routes = load_parquet_data(spark, parquet_dir)
    
    # Inspect data
    inspect_data(merged_df, routes)
    
    # Register temporary views
    register_temp_views(merged_df, routes)
    
    # Calculate distance metrics
    distance_pivot_df = calculate_distance_to_lines(spark)
    
    # Integrate with main DataFrame
    enriched_df = integrate_distance_metrics(merged_df, distance_pivot_df)
    
    # Show enriched data
    print("Enriched Data with Distance Metrics:")
    enriched_df.show(5, truncate=False)
    
    # Optionally, write enriched data back to Parquet
    enriched_df.write.mode("overwrite").parquet(os.path.join(parquet_dir, "enriched_gpspoint.parquet"))
    print("Enriched GPS data written to 'enriched_gpspoint.parquet'.")
    
    return enriched_df

In [None]:
def main():
    try:
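        # JAR paths: PostgreSQL JDBC driver and the Sedona shaded JAR
        jdbc_driver_path = "/home/exx/myCode/horse-racing/FoxRiverAIRacing/jdbc/postgresql-42.7.4.jar"
        sedona_jar_abs_path = "/home/exx/sedona/apache-sedona-1.7.0-bin/sedona-spark-shaded-3.4_2.12-1.7.0.jar"
        
        # Additional GeoTools and Sedona JAR files.
        # Note: these Sedona 1.2.0-incubating JARs are much older than the 1.7.0 shaded JAR above;
        # mixing Sedona versions on one classpath may cause class conflicts.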
        geotools_jar_paths = [
            "/home/exx/anaconda3/envs/mamba_env/envs/tf_310/lib/python3.10/site-packages/pyspark/jars/geotools-wrapper-1.1.0-25.2.jar",
            "/home/exx/anaconda3/envs/mamba_env/envs/tf_310/lib/python3.10/site-packages/pyspark/jars/sedona-python-adapter-3.0_2.12-1.2.0-incubating.jar",
            "/home/exx/anaconda3/envs/mamba_env/envs/tf_310/lib/python3.10/site-packages/pyspark/jars/sedona-viz-3.0_2.12-1.2.0-incubating.jar",
        ]
        
        # Initialize Spark session
        spark = initialize_spark(jdbc_driver_path, sedona_jar_abs_path, geotools_jar_paths)
        
        # Test Sedona integration
        test_sedona_integration(spark)
        
        # Define paths
        parquet_dir = "/home/exx/myCode/horse-racing/FoxRiverAIRacing/data/parquet/"
        
        # Create dummy Parquet files if needed
        # create_dummy_parquet_files(parquet_dir, spark)
        
        # Calculate and integrate distance and speed metrics
        enriched_df = calculate_and_integrate_metrics(spark, parquet_dir)
        
        # Additional metrics can be calculated here...
        
    except Exception as e:
        print(f"An error occurred during processing: {e}")
    finally:
        if 'spark' in locals():
            spark.stop()
            print("Spark session stopped.")

In [None]:
from pyspark.sql.window import Window
from pyspark.sql.functions import col, expr, lag, when

def calculate_instantaneous_speed(spark, enriched_df):
    """
    Calculates instantaneous speed between consecutive GPS points for each horse in each race.

    :param spark: SparkSession object (unused; kept for a consistent signature).
    :param enriched_df: DataFrame with 'geometry' and 'timestamp' columns.
    :return: DataFrame with distance_m, time_diff_s and speed_m_s columns.
    """
    # Partition by race as well as horse so that lag() never pairs the last point
    # of one race with the first point of the next.
    window_spec = Window.partitionBy(
        "course_cd", "race_date", "race_number", "horse_id"
    ).orderBy("timestamp")

    # Previous geometry and timestamp within the same race
    enriched_df = enriched_df.withColumn("prev_geom", lag("geometry").over(window_spec)) \
                             .withColumn("prev_timestamp", lag("timestamp").over(window_spec))

    # Distance between current and previous points (units follow the geometry's CRS)
    enriched_df = enriched_df.withColumn(
        "distance_m",
        when(col("prev_geom").isNotNull(), expr("ST_Distance(geometry, prev_geom)")).otherwise(0)
    )

    # Time difference in seconds. Casting a timestamp to double keeps fractional seconds,
    # whereas unix_timestamp() truncates to whole seconds.
    enriched_df = enriched_df.withColumn(
        "time_diff_s",
        when(col("prev_timestamp").isNotNull(),
             col("timestamp").cast("timestamp").cast("double")
             - col("prev_timestamp").cast("timestamp").cast("double")
        ).otherwise(0)
    )

    # Speed in m/s (0 for the first point of each race)
    enriched_df = enriched_df.withColumn(
        "speed_m_s",
        when(col("time_diff_s") > 0, col("distance_m") / col("time_diff_s")).otherwise(0)
    )

    return enriched_df
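
Which GPS fixes get differenced depends on how the window is partitioned and ordered. The plain-Python sketch below mirrors the speed logic under a few stated assumptions:

- Points are grouped per race. `race_key` is a hypothetical stand-in for course, date, race number and horse.
- Points are sorted by timestamp within each race.
- Positions are already in a metric projection, so `math.hypot` gives metres.

The sketch also keeps fractional seconds, which matters when fixes are less than a second apart.

```python
import math
from collections import defaultdict
from datetime import datetime, timedelta


def instantaneous_speeds(points):
    """points: iterable of (race_key, timestamp, x_m, y_m).

    Returns {race_key: [speed_m_s, ...]} in timestamp order. The first fix of
    each race gets 0.0, mirroring the .otherwise(0) in the Spark version.
    """
    by_race = defaultdict(list)
    for race_key, ts, x, y in points:
        by_race[race_key].append((ts, x, y))

    speeds = {}
    for race_key, fixes in by_race.items():
        fixes.sort(key=lambda f: f[0])  # order fixes by timestamp within the race
        race_speeds = [0.0]
        for (t0, x0, y0), (t1, x1, y1) in zip(fixes, fixes[1:]):
            dt = (t1 - t0).total_seconds()  # keeps sub-second precision
            dist = math.hypot(x1 - x0, y1 - y0)
            race_speeds.append(dist / dt if dt > 0 else 0.0)
        speeds[race_key] = race_speeds
    return speeds


start = datetime(2024, 9, 15, 13, 0, 0)
half = timedelta(milliseconds=500)
pts = [
    ("race_2", start, 0.0, 0.0),
    ("race_2", start + half, 8.0, 0.0),
    ("race_2", start + 2 * half, 16.0, 0.0),
    ("race_3", start + timedelta(hours=1), 500.0, 0.0),  # next race: never differenced with race_2
]
print(instantaneous_speeds(pts))
```

With whole-second timestamps, the two 0.5 s gaps above would become 0 s and 1 s. That would give speeds of 0 and 8 m/s instead of 16 m/s.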

In [None]:
def calculate_and_integrate_metrics(spark, parquet_dir):
    # Load data
    merged_df, routes = load_parquet_data(spark, parquet_dir)
    
    # Inspect data
    inspect_data(merged_df, routes)
    
    # Register temporary views
    register_temp_views(merged_df, routes)
    
    # Calculate distance metrics
    distance_pivot_df = calculate_distance_to_lines(spark)
    
    # Integrate with main DataFrame
    enriched_df = integrate_distance_metrics(merged_df, distance_pivot_df)
    
    # Calculate instantaneous speed
    enriched_df = calculate_instantaneous_speed(spark, enriched_df)
    
    # Show enriched data with speed metrics
    print("Enriched Data with Distance and Speed Metrics:")
    enriched_df.show(5, truncate=False)
    
    # Optionally, write enriched data back to Parquet
    enriched_df.write.mode("overwrite").parquet(os.path.join(parquet_dir, "enriched_gpspoint_with_speed.parquet"))
    print("Enriched GPS data with speed metrics written to 'enriched_gpspoint_with_speed.parquet'.")
    
    return enriched_df

In [None]:
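from pyspark.sql.functions import sum as spark_sum

def calculate_average_speed(enriched_df):
    """
    Calculates average speed for each horse in each race as total distance / total time.

    Note: this redefines the sectional-based calculate_average_speed above and returns
    one aggregated row per horse per race instead of adding per-gate columns.

    :param enriched_df: DataFrame enriched with distance and speed metrics.
    :return: DataFrame with average speed per horse per race.
    """
    average_speed_df = enriched_df.groupBy(
        "course_cd", "race_date", "race_number", "saddle_cloth_number", "horse_id"
    ).agg(
        # PySpark's sum aggregate, not Python's built-in sum
        (spark_sum("distance_m") / spark_sum("time_diff_s")).alias("average_speed_m_s")
    )

    return average_speed_df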
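
Once speeds are available, the roadmap's fatigue factor (percentage drop from peak speed to finish speed) and peak-speed timing can be derived from a horse's speed series. The plain-Python sketch below uses hypothetical helper names. It assumes the series is in race order and the samples are evenly spaced, so that sample index is a fair proxy for race progress.

```python
def fatigue_factor(speeds_m_s):
    """Percentage drop from peak speed to the final (finish) speed."""
    if not speeds_m_s:
        raise ValueError("speed series is empty")
    peak = max(speeds_m_s)
    finish = speeds_m_s[-1]
    return 0.0 if peak <= 0 else 100.0 * (peak - finish) / peak


def peak_speed_timing(speeds_m_s):
    """Position of the peak as a fraction of the race (0 = start, 1 = finish)."""
    if not speeds_m_s:
        raise ValueError("speed series is empty")
    if len(speeds_m_s) == 1:
        return 0.0
    return speeds_m_s.index(max(speeds_m_s)) / (len(speeds_m_s) - 1)


speeds = [15.0, 17.5, 18.0, 17.0, 16.2]
print(fatigue_factor(speeds), peak_speed_timing(speeds))
```

In this example the horse peaks mid-race and has lost about 10% of its peak speed by the finish. If samples are not evenly spaced, weighting the timing by cumulative distance would be more faithful than using the index.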

In [None]:
from pyspark.sql.functions import sum

def calculate_and_integrate_metrics(spark, parquet_dir):
    # Load data
    merged_df, routes = load_parquet_data(spark, parquet_dir)
    
    # Inspect data
    inspect_data(merged_df, routes)
    
    # Register temporary views
    register_temp_views(merged_df, routes)
    
    # Calculate distance metrics
    distance_pivot_df = calculate_distance_to_lines(spark)
    
    # Integrate with main DataFrame
    enriched_df = integrate_distance_metrics(merged_df, distance_pivot_df)
    
    # Calculate instantaneous speed
    enriched_df = calculate_instantaneous_speed(spark, enriched_df)
    
    # Calculate average speed
    average_speed_df = calculate_average_speed(enriched_df)
    
    # Show average speed data
    print("Average Speed per Horse per Race:")
    average_speed_df.show(5, truncate=False)
    
    # Optionally, write average speed data back to Parquet
    average_speed_df.write.mode("overwrite").parquet(os.path.join(parquet_dir, "average_speed.parquet"))
    print("Average speed data written to 'average_speed.parquet'.")
    
    return enriched_df, average_speed_df

In [None]:
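from pyspark.sql.window import Window
from pyspark.sql.functions import col, lag, when

def calculate_instantaneous_acceleration(enriched_df):
    """
    Calculates instantaneous acceleration between consecutive speed measurements.

    :param enriched_df: DataFrame with speed_m_s and time_diff_s (see calculate_instantaneous_speed).
    :return: DataFrame with acceleration metrics.
    """
    # Partition by race and horse so speeds from different races are never differenced
    window_spec = Window.partitionBy(
        "course_cd", "race_date", "race_number", "horse_id"
    ).orderBy("timestamp")

    # Lag the speed column to get the previous speed
    enriched_df = enriched_df.withColumn("prev_speed_m_s", lag("speed_m_s").over(window_spec))

    # Calculate acceleration (delta_speed / delta_time). The first point's speed is 0 by
    # construction, so the second point's value is effectively speed / time_diff_s.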
    enriched_df = enriched_df.withColumn("acceleration_m_s2", 
        when(col("time_diff_s") > 0, 
             (col("speed_m_s") - col("prev_speed_m_s")) / col("time_diff_s")
        ).otherwise(0)
    )
    
    return enriched_df
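
Each `speed_m_s` value is an average over the interval ending at its point. Two consecutive speeds therefore describe intervals whose midpoints are separated by half of each interval, not by `time_diff_s` alone. Dividing by `time_diff_s` is exact only when fixes are evenly spaced.

The plain-Python sketch below is a hypothetical helper and a variant of the Spark logic rather than a drop-in replacement. It makes two changes:

- It divides by the spacing between interval midpoints.
- It leaves the first two points undefined instead of differencing against the placeholder 0 speed.

```python
def accelerations_m_s2(times_s, speeds_m_s):
    """times_s: fix times in seconds; speeds_m_s[i]: average speed over [times_s[i-1], times_s[i]].

    speeds_m_s[0] is a placeholder (no previous fix), so the first two
    accelerations are None.
    """
    acc = [None] * min(2, len(times_s))
    for i in range(2, len(times_s)):
        mid_prev = (times_s[i - 2] + times_s[i - 1]) / 2  # midpoint of the previous interval
        mid_curr = (times_s[i - 1] + times_s[i]) / 2      # midpoint of the current interval
        dt = mid_curr - mid_prev
        acc.append((speeds_m_s[i] - speeds_m_s[i - 1]) / dt if dt > 0 else None)
    return acc


# Uneven spacing: 0.5 s, then 1.0 s between fixes
print(accelerations_m_s2([0.0, 0.5, 1.5], [0.0, 16.0, 17.0]))
```

Here the midpoint spacing is 0.75 s, whereas dividing by `time_diff_s` would use 1.0 s. For evenly spaced fixes the two approaches agree.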

In [None]:
def calculate_path_deviation(spark, enriched_df):
    """
    Adds a deviation_m column holding each row's distance to the running line.

    Currently this is a copy of distance_to_running_line_m, so it has the same
    granularity as that column.

    :param spark: SparkSession object (unused).
    :param enriched_df: DataFrame enriched with distance metrics.
    :return: DataFrame with a deviation_m column.
    """
    deviation_df = enriched_df.withColumn(
        "deviation_m",
        col("distance_to_running_line_m")
    )
    
    return deviation_df
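
`calculate_path_deviation` copies `distance_to_running_line_m`, and the pivot that produces that column keeps a single value per horse and race. A per-point deviation would instead be the shortest distance from each GPS fix to the running-line polyline, which is what a planar `ST_Distance(point, linestring)` computes.

The plain-Python sketch below shows that computation. It assumes coordinates are already projected to metres, and the helper names are hypothetical.

```python
import math


def point_segment_distance(p, a, b):
    """Shortest distance from point p to segment a-b (all (x, y) tuples in metres)."""
    (px, py), (ax, ay), (bx, by) = p, a, b
    dx, dy = bx - ax, by - ay
    seg_len_sq = dx * dx + dy * dy
    if seg_len_sq == 0:  # degenerate segment: a and b coincide
        return math.hypot(px - ax, py - ay)
    # Project p onto the segment's line and clamp the projection to the segment
    t = max(0.0, min(1.0, ((px - ax) * dx + (py - ay) * dy) / seg_len_sq))
    return math.hypot(px - (ax + t * dx), py - (ay + t * dy))


def distance_to_polyline(p, line):
    """Minimum distance from p to any segment of the polyline (needs >= 2 vertices)."""
    return min(point_segment_distance(p, a, b) for a, b in zip(line, line[1:]))


running_line = [(0.0, 0.0), (100.0, 0.0), (100.0, 100.0)]  # toy straight followed by a turn
print(distance_to_polyline((50.0, 3.0), running_line))
print(distance_to_polyline((103.0, 50.0), running_line))
```

Both points lie 3 m off the line: the first is 3 m off the straight, and the second is 3 m off the second leg.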

In [None]:
def calculate_and_integrate_metrics(spark, parquet_dir):
    # Load data
    merged_df, routes = load_parquet_data(spark, parquet_dir)
    
    # Inspect data
    inspect_data(merged_df, routes)
    
    # Register temporary views
    register_temp_views(merged_df, routes)
    
    # Calculate distance metrics
    distance_pivot_df = calculate_distance_to_lines(spark)
    
    # Integrate with main DataFrame
    enriched_df = integrate_distance_metrics(merged_df, distance_pivot_df)
    
    # Calculate instantaneous speed
    enriched_df = calculate_instantaneous_speed(spark, enriched_df)
    
    # Calculate acceleration
    enriched_df = calculate_instantaneous_acceleration(enriched_df)
    
    # Calculate average speed
    average_speed_df = calculate_average_speed(enriched_df)
    
    # Calculate path deviation
    deviation_df = calculate_path_deviation(spark, enriched_df)
    
    # Show metrics
    print("Enriched Data with All Metrics:")
    enriched_df.show(5, truncate=False)
    
    print("Average Speed per Horse per Race:")
    average_speed_df.show(5, truncate=False)
    
    print("Path Deviation per GPS Point:")
    deviation_df.show(5, truncate=False)
    
    # Optionally, write metrics to Parquet
    enriched_df.write.mode("overwrite").parquet(os.path.join(parquet_dir, "enriched_gpspoint_all_metrics.parquet"))
    average_speed_df.write.mode("overwrite").parquet(os.path.join(parquet_dir, "average_speed.parquet"))
    deviation_df.write.mode("overwrite").parquet(os.path.join(parquet_dir, "path_deviation.parquet"))
    
    print("All metrics data written to respective Parquet files.")
    
    return enriched_df, average_speed_df, deviation_df
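
The roadmap's early, mid-race and late pace can be computed from the same sectional distances and times used for `sec_avg_spd`. In the plain-Python sketch below (a hypothetical helper), each sectional is assigned to a phase according to the race fraction at its midpoint. The split uses the roadmap's example of the first 20% for early pace and, by assumption, the last 20% for late pace. Each phase's pace is reported as total distance over total time.

```python
def phase_paces(distances_m, times_s, early_frac=0.2, late_frac=0.2):
    """Average speed (m/s) in the early, middle and late phases of a race.

    Each sectional goes to one phase based on where its midpoint falls as a
    fraction of the total distance. A phase with no sectionals returns None.
    """
    total = sum(distances_m)
    buckets = {"early": [0.0, 0.0], "mid": [0.0, 0.0], "late": [0.0, 0.0]}  # [distance, time]
    covered = 0.0
    for d, t in zip(distances_m, times_s):
        midpoint_frac = (covered + d / 2) / total
        covered += d
        if midpoint_frac < early_frac:
            phase = "early"
        elif midpoint_frac > 1 - late_frac:
            phase = "late"
        else:
            phase = "mid"
        buckets[phase][0] += d
        buckets[phase][1] += t
    return {k: (dist / tm if tm > 0 else None) for k, (dist, tm) in buckets.items()}


print(phase_paces([200, 200, 200, 200, 200], [12.0, 11.8, 12.0, 12.2, 12.5]))
```

Comparing the `late` value with the `early` value offers one simple way to separate closers from horses that fade.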

In [None]:
def validate_metrics(enriched_df):
    """
    Validates the correctness of calculated metrics.

    :param enriched_df: DataFrame enriched with metrics.
    """
    enriched_df.select(
        "course_cd",
        "race_date",
        "race_number",
        "saddle_cloth_number",
        "horse_id",
        "timestamp",
        "distance_m",
        "time_diff_s",
        "speed_m_s",
        "acceleration_m_s2",
        "deviation_m"
    ).show(5, truncate=False)

In [None]:
from pyspark.sql.functions import abs as spark_abs, col

def handle_anomalies(enriched_df):
    """
    Identifies and removes anomalies in speed and acceleration metrics.

    :param enriched_df: DataFrame enriched with metrics.
    :return: Cleaned DataFrame.
    """
    # Example: remove entries with speed > 20 m/s (arbitrary threshold)
    cleaned_df = enriched_df.filter(col("speed_m_s") <= 20)

    # Example: remove entries whose acceleration or deceleration exceeds 10 m/s^2 in magnitude
    cleaned_df = cleaned_df.filter(spark_abs(col("acceleration_m_s2")) <= 10)

    return cleaned_df

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

def visualize_average_speed(average_speed_df):
    """
    Plots average speed per horse.

    :param average_speed_df: DataFrame with average speed metrics.
    """
    # Convert to Pandas DataFrame
    pandas_df = average_speed_df.toPandas()
    
    plt.figure(figsize=(10, 6))
    sns.barplot(x='horse_id', y='average_speed_m_s', hue='race_number', data=pandas_df)
    plt.title('Average Speed per Horse per Race')
    plt.xlabel('Horse ID')
    plt.ylabel('Average Speed (m/s)')
    plt.legend(title='Race Number')
    plt.show()

In [None]:
import os
from pyspark.sql.functions import col, expr

def visualize_horse_trajectory(spark, horse_id, parquet_dir):
    """
    Exports the trajectory of a specific horse as WKT for external visualization.

    :param spark: SparkSession object.
    :param horse_id: ID of the horse to visualize.
    :param parquet_dir: Directory where Parquet files are stored.
    """
    # Load enriched data
    enriched_df = spark.read.parquet(os.path.join(parquet_dir, "enriched_gpspoint_all_metrics.parquet"))

    # Filter for the specific horse, ordered by race and time
    horse_df = enriched_df.filter(col("horse_id") == horse_id) \
                          .orderBy("race_date", "race_number", "timestamp")

    # Count the points without collecting every geometry to the driver
    n_points = horse_df.count()
    print(f"Found {n_points} GPS points for horse {horse_id}.")

    # Spark's CSV writer does not support the binary-backed geometry type, so convert
    # geometries to WKT text first; the CSV can then be opened in a tool such as QGIS.
    horse_df.select(
        "race_date",
        "race_number",
        "timestamp",
        expr("ST_AsText(geometry)").alias("geometry_wkt")
    ).write.mode("overwrite").csv(
        os.path.join(parquet_dir, f"horse_{horse_id}_trajectory.csv"), header=True
    )
    print(f"Trajectory data for horse {horse_id} exported to CSV.")
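
External tools often handle a whole trajectory more easily as one LINESTRING than as a table of points. QGIS, for example, can load WKT geometries from a delimited text file. The plain-Python sketch below (a hypothetical helper) builds such a string from `(timestamp, lon, lat)` tuples, assuming WGS84 coordinates. Note that WKT lists coordinates as x y, so longitude comes before latitude.

```python
def trajectory_to_wkt(points):
    """points: iterable of (timestamp, lon, lat). Returns a WKT LINESTRING in time order."""
    ordered = sorted(points, key=lambda p: p[0])  # chronological order
    if len(ordered) < 2:
        raise ValueError("a LINESTRING needs at least two points")
    coords = ", ".join(f"{lon} {lat}" for _, lon, lat in ordered)
    return f"LINESTRING ({coords})"


print(trajectory_to_wkt([(2, -84.1, 42.2), (1, -84.0, 42.1), (3, -84.2, 42.3)]))
```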

In [None]:
def cache_dataframes(enriched_df, average_speed_df):
    """
    Caches DataFrames that are accessed multiple times.

    :param enriched_df: Enriched GPS DataFrame.
    :param average_speed_df: Average Speed DataFrame.
    """
    enriched_df.cache()
    average_speed_df.cache()
    print("DataFrames cached for improved performance.")

In [None]:
from pyspark.sql.functions import broadcast

def optimize_joins(spark, main_df, lookup_df, key_column="key_column"):
    """
    Joins main_df to a small lookup DataFrame, broadcasting the lookup side to avoid a shuffle.

    :param spark: SparkSession object (unused).
    :param main_df: Main DataFrame to join.
    :param lookup_df: Small lookup DataFrame to broadcast.
    :param key_column: Join key column; "key_column" is a placeholder name.
    :return: Joined DataFrame.
    """
    joined_df = main_df.join(broadcast(lookup_df), on=key_column, how="left")
    return joined_df

In [None]:
import os

def initialize_environment_secure():
    """
    Initializes the environment. The database password is read from the DB_PASSWORD
    environment variable; the other connection settings are read from config.ini.

    :return: Tuple (spark, jdbc_url, jdbc_properties, queries).
    """
    # Paths and configurations
    config_path = '/home/exx/myCode/horse-racing/FoxRiverAIRacing/config.ini'
    log_file = "/home/exx/myCode/horse-racing/FoxRiverAIRacing/logs/SparkPy_load.log"
    jdbc_driver_path = "/home/exx/myCode/horse-racing/FoxRiverAIRacing/jdbc/postgresql-42.7.4.jar"
    sedona_jar_abs_path = "/home/exx/sedona/apache-sedona-1.7.0-bin/sedona-spark-shaded-3.4_2.12-1.7.0.jar"
    
    # Paths to GeoTools JAR files
    geotools_jar_paths = [
        "/home/exx/anaconda3/envs/mamba_env/envs/tf_310/lib/python3.10/site-packages/pyspark/jars/geotools-wrapper-1.1.0-25.2.jar",
        "/home/exx/anaconda3/envs/mamba_env/envs/tf_310/lib/python3.10/site-packages/pyspark/jars/sedona-python-adapter-3.0_2.12-1.2.0-incubating.jar",
        "/home/exx/anaconda3/envs/mamba_env/envs/tf_310/lib/python3.10/site-packages/pyspark/jars/sedona-viz-3.0_2.12-1.2.0-incubating.jar",
    ]
    
    # Load configuration
    config = load_config(config_path)

    # Database credentials from config and environment variables
    db_host = config['database']['host']
    db_port = config['database']['port']
    db_name = config['database']['dbname']
    db_user = config['database']['user']
    db_password = os.getenv("DB_PASSWORD")  # Ensure DB_PASSWORD is set
    
    if not db_password:
        raise ValueError("Database password is missing. Set it in the DB_PASSWORD environment variable.")
    
    # JDBC URL and properties
    jdbc_url = f"jdbc:postgresql://{db_host}:{db_port}/{db_name}"
    jdbc_properties = {
        "user": db_user,
        "password": db_password,
        "driver": "org.postgresql.Driver"
    }
    
    # Initialize Spark session
    spark = initialize_spark(jdbc_driver_path, sedona_jar_abs_path, geotools_jar_paths)
    
    # Initialize logging
    initialize_logging(log_file)
    queries = sql_queries()
    
    return spark, jdbc_url, jdbc_properties, queries

In [None]:
import os
import logging
from configparser import ConfigParser
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, expr, first, lag, to_timestamp, unix_timestamp, when
from pyspark.sql.functions import sum as spark_sum
from pyspark.sql.window import Window
from sedona.register import SedonaRegistrator
from sedona.utils import SedonaKryoRegistrator, KryoSerializer

def load_config(config_path):
    config = ConfigParser()
    config.read(config_path)
    return config

def initialize_logging(log_file):
    logging.basicConfig(
        filename=log_file,
        filemode='a',
        format='%(asctime)s - %(levelname)s - %(message)s',
        level=logging.INFO
    )
    logging.info("Logging initialized.")

def initialize_spark(jdbc_driver_path, sedona_jar_path, geotools_jar_paths):
    spark = SparkSession.builder \
        .appName("Horse Racing Data Processing") \
        .master("local[16]") \
        .config("spark.driver.memory", "48g") \
        .config("spark.executor.memory", "32g") \
        .config("spark.executor.memoryOverhead", "8g") \
        .config("spark.default.parallelism", "256") \
        .config("spark.sql.shuffle.partitions", "256") \
        .config("spark.jars", f"{jdbc_driver_path},{sedona_jar_path},{','.join(geotools_jar_paths)}") \
        .config("spark.serializer", KryoSerializer.getName) \
        .config("spark.kryo.registrator", SedonaKryoRegistrator.getName) \
        .config("spark.kryo.buffer", "64k") \
        .config("spark.kryo.buffer.max", "512m") \
        .config("spark.sql.parquet.datetimeRebaseModeInWrite", "LEGACY") \
        .config("spark.sql.parquet.int96RebaseModeInWrite", "LEGACY") \
        .config("spark.local.dir", "/path/to/nvme/storage") \
        .config("spark.sql.execution.arrow.pyspark.enabled", "true") \
        .getOrCreate()
    
    # Register Sedona functions and types
    SedonaRegistrator.registerAll(spark)
    
    # Set log level to ERROR
    spark.sparkContext.setLogLevel("ERROR")
    print("Spark session created successfully with Sedona and GeoTools integrated.")
    return spark

def load_data_from_postgresql(spark, jdbc_url, table_name, jdbc_properties, output_path):
    try:
        df = spark.read \
            .format("jdbc") \
            .option("url", jdbc_url) \
            .option("dbtable", table_name) \
            .option("user", jdbc_properties["user"]) \
            .option("password", jdbc_properties["password"]) \
            .option("driver", jdbc_properties["driver"]) \
            .load()
        
        df.write.mode("overwrite").parquet(output_path)
        logging.info(f"Data from {table_name} loaded successfully to {output_path}.")
        print(f"Data from {table_name} loaded successfully to {output_path}.")
    except Exception as e:
        logging.error(f"Error loading data from PostgreSQL: {e}")
        print(f"Error loading data from PostgreSQL: {e}")

def load_parquet_data(spark, parquet_dir):
    gps_df = spark.read.parquet(os.path.join(parquet_dir, "gpspoint.parquet"))
    routes_df = spark.read.parquet(os.path.join(parquet_dir, "routes.parquet"))
    return gps_df, routes_df

def inspect_data(gps_df, routes_df):
    print("GPS Data Sample:")
    gps_df.show(5, truncate=False)
    
    print("Routes Data Sample:")
    routes_df.show(5, truncate=False)

def register_temp_views(gps_df, routes_df):
    gps_df.createOrReplaceTempView("gps_points")
    routes_df.createOrReplaceTempView("routes")
    print("DataFrames registered as temporary views: 'gps_points' and 'routes'.")

def calculate_distance_to_lines(spark):
    distance_df = spark.sql("""
        SELECT 
            g.course_cd, 
            g.race_date, 
            g.race_number, 
            g.saddle_cloth_number, 
            g.horse_id,
            r.line_type,
            ST_Distance(g.geometry, r.route_geometry) AS distance_to_line_m
        FROM 
            gps_points g
        JOIN 
            routes r 
        ON 
            g.course_cd = r.course_cd
        WHERE 
            r.line_type IN ('RUNNING_LINE', 'WINNING_LINE')
    """)

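    # Caution: the grouping has no per-point key, so first() keeps the distance of a
    # single, arbitrary GPS point per horse and race.
    from pyspark.sql.functions import first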
    distance_pivot_df = distance_df.groupBy(
        "course_cd", "race_date", "race_number", "saddle_cloth_number", "horse_id"
    ).pivot("line_type", ["RUNNING_LINE", "WINNING_LINE"]).agg(first("distance_to_line_m"))

    distance_pivot_df = distance_pivot_df.withColumnRenamed("RUNNING_LINE", "distance_to_running_line_m") \
                                         .withColumnRenamed("WINNING_LINE", "distance_to_winning_line_m")

    return distance_pivot_df

def integrate_distance_metrics(main_df, distance_pivot_df):
    enriched_df = main_df.join(
        distance_pivot_df,
        on=["course_cd", "race_date", "race_number", "saddle_cloth_number", "horse_id"],
        how="left"
    )
    return enriched_df

def calculate_instantaneous_speed(spark, enriched_df):
    # Partition by race and horse so lag() never pairs points from different races
    window_spec = Window.partitionBy(
        "course_cd", "race_date", "race_number", "horse_id"
    ).orderBy("timestamp")

    enriched_df = enriched_df.withColumn("prev_geom", lag("geometry").over(window_spec)) \
                             .withColumn("prev_timestamp", lag("timestamp").over(window_spec))

    enriched_df = enriched_df.withColumn(
        "distance_m",
        when(col("prev_geom").isNotNull(), expr("ST_Distance(geometry, prev_geom)")).otherwise(0)
    )

    # Casting timestamps to double keeps fractional seconds (unix_timestamp() truncates them)
    enriched_df = enriched_df.withColumn(
        "time_diff_s",
        when(col("prev_timestamp").isNotNull(),
             col("timestamp").cast("timestamp").cast("double")
             - col("prev_timestamp").cast("timestamp").cast("double")
        ).otherwise(0)
    )

    enriched_df = enriched_df.withColumn(
        "speed_m_s",
        when(col("time_diff_s") > 0, col("distance_m") / col("time_diff_s")).otherwise(0)
    )

    return enriched_df

def calculate_instantaneous_acceleration(enriched_df):
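    # Partition by race and horse so speeds from different races are never differenced
    window_spec = Window.partitionBy(
        "course_cd", "race_date", "race_number", "horse_id"
    ).orderBy("timestamp")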
    
    enriched_df = enriched_df.withColumn("prev_speed_m_s", lag("speed_m_s").over(window_spec))
    
    enriched_df = enriched_df.withColumn("acceleration_m_s2", 
        when(col("time_diff_s") > 0, 
             (col("speed_m_s") - col("prev_speed_m_s")) / col("time_diff_s")
        ).otherwise(0)
    )
    
    return enriched_df

def calculate_average_speed(enriched_df):
    average_speed_df = enriched_df.groupBy(
        "course_cd", "race_date", "race_number", "saddle_cloth_number", "horse_id"
    ).agg(
        (spark_sum("distance_m") / spark_sum("time_diff_s")).alias("average_speed_m_s")
    )
    
    return average_speed_df

def calculate_path_deviation(spark, enriched_df):
    deviation_df = enriched_df.withColumn(
        "deviation_m",
        col("distance_to_running_line_m")
    )
    
    return deviation_df

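def handle_anomalies(enriched_df):
    # Drop implausible fixes: speeds above 20 m/s (72 km/h) and accelerations
    # or decelerations larger than 10 m/s^2 in magnitude.
    cleaned_df = enriched_df.filter(col("speed_m_s") <= 20) \
                             .filter(expr("abs(acceleration_m_s2) <= 10"))
    return cleaned_df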

def validate_metrics(enriched_df):
    enriched_df.select(
        "course_cd",
        "race_date",
        "race_number",
        "saddle_cloth_number",
        "horse_id",
        "timestamp",
        "distance_m",
        "time_diff_s",
        "speed_m_s",
        "acceleration_m_s2",
        "deviation_m"
    ).show(5, truncate=False)

def visualize_average_speed(average_speed_df):
    import matplotlib.pyplot as plt
    import seaborn as sns
    
    pandas_df = average_speed_df.toPandas()
    
    plt.figure(figsize=(10, 6))
    sns.barplot(x='horse_id', y='average_speed_m_s', hue='race_number', data=pandas_df)
    plt.title('Average Speed per Horse per Race')
    plt.xlabel('Horse ID')
    plt.ylabel('Average Speed (m/s)')
    plt.legend(title='Race Number')
    plt.show()

def main():
    try:
        # Initialize environment securely
        spark, jdbc_url, jdbc_properties, parquet_dir, log_file = initialize_environment_secure()
        
        # Initialize Spark session
        spark = initialize_spark(jdbc_driver_path, sedona_jar_abs_path, geotools_jar_paths)
        
        # Test Sedona integration
        test_sedona_integration(spark)
        
        # Define paths
        parquet_dir = "/home/exx/myCode/horse-racing/FoxRiverAIRacing/data/parquet/"
        
        # Create dummy Parquet files if needed
        # create_dummy_parquet_files(parquet_dir, spark)
        
        # Calculate and integrate metrics
        enriched_df, average_speed_df, deviation_df = calculate_and_integrate_metrics(spark, parquet_dir)
        
        # Validate metrics
        validate_metrics(enriched_df)
        
        # Visualize average speed
        visualize_average_speed(average_speed_df)
        
    except Exception as e:
        print(f"An error occurred during processing: {e}")
    finally:
        if 'spark' in locals():
            spark.stop()
            print("Spark session stopped.")

In [None]:
main()
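The per-race partitioning and fractional-second timing used by calculate_instantaneous_speed() and calculate_instantaneous_acceleration() can be sanity-checked without a Spark cluster. The plain-Python sketch below is illustrative only: speed_and_acceleration() is a hypothetical helper, and it assumes each fix has already been projected to planar x/y coordinates in meters, which the source tables do not provide directly. It groups fixes by race and horse, orders them by timestamp, derives speed from consecutive fixes, and computes acceleration only when the previous fix also has a real speed.

```python
import math
from collections import defaultdict
from datetime import datetime, timedelta


def speed_and_acceleration(fixes):
    """Hypothetical helper: per-fix speed (m/s) and acceleration (m/s^2).

    fixes: iterable of (race_key, horse_id, timestamp, x_m, y_m) tuples,
    where x_m and y_m are assumed planar coordinates in meters.
    Returns {(race_key, horse_id): [(speed, acceleration), ...]} in time order.
    """
    # Group by race AND horse so fixes from different races are never paired
    groups = defaultdict(list)
    for race_key, horse_id, ts, x, y in fixes:
        groups[(race_key, horse_id)].append((ts, x, y))

    results = {}
    for key, pts in groups.items():
        pts.sort(key=lambda p: p[0])
        speeds, dts = [0.0], [0.0]  # the first fix has no predecessor
        for (t0, x0, y0), (t1, x1, y1) in zip(pts, pts[1:]):
            dt = (t1 - t0).total_seconds()  # keeps fractional seconds
            speeds.append(math.hypot(x1 - x0, y1 - y0) / dt if dt > 0 else 0.0)
            dts.append(dt)
        accels = [0.0] * len(speeds)
        for i in range(1, len(speeds)):
            # Only meaningful if the previous fix also has a real (non-placeholder) speed
            if dts[i] > 0 and dts[i - 1] > 0:
                accels[i] = (speeds[i] - speeds[i - 1]) / dts[i]
        results[key] = list(zip(speeds, accels))
    return results


t0 = datetime(2024, 1, 1, 12, 0, 0)
step = timedelta(milliseconds=250)
fixes = [  # made-up fixes for one horse in two races
    ("race1", "h1", t0, 0.0, 0.0),
    ("race1", "h1", t0 + step, 4.0, 0.0),
    ("race1", "h1", t0 + 2 * step, 8.25, 0.0),
    ("race2", "h1", t0 + timedelta(hours=1), 500.0, 0.0),
]
out = speed_and_acceleration(fixes)
print(out[("race1", "h1")])  # [(0.0, 0.0), (16.0, 0.0), (17.0, 4.0)]
print(out[("race2", "h1")])  # [(0.0, 0.0)]
```

The race2 entry shows why the partition must include the race: with a horse-only partition, its first fix would be paired with the last fix of race1. The 250 ms intervals show why timing should keep fractional seconds; truncating these timestamps to whole seconds would make both intervals zero.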

In [None]:
from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import (
    col, avg, max as F_max, min as F_min, count, stddev, row_number, lag, lead,
    when, lit, first, last, abs, sum as F_sum, udf
)
from pyspark.sql.types import DoubleType
import math
import logging

# Initialize Spark session
def initialize_spark():
    jdbc_driver_path = "/home/exx/myCode/horse-racing/FoxRiverAIRacing/jdbc/postgresql-42.7.4.jar"
    extra_class_path = jdbc_driver_path  # Ensure this is the correct path to your JDBC JAR
    spark = SparkSession.builder \
        .appName("GPS Sectionals Analysis - Enhanced Aggregation") \
        .config("spark.driver.extraClassPath", extra_class_path) \
        .config("spark.executor.extraClassPath", extra_class_path) \
        .config("spark.driver.memory", "64g") \
        .config("spark.executor.memory", "32g") \
        .config("spark.executor.memoryOverhead", "8g") \
        .config("spark.sql.debug.maxToStringFields", "1000") \
        .config("spark.sql.adaptive.enabled", "true") \
        .getOrCreate()
    spark.sparkContext.setLogLevel("ERROR")
    logging.info("Spark session created successfully.")
    return spark

# Define the haversine function and UDF
def define_haversine_udf():
    def haversine(lat1, lon1, lat2, lon2):
        # Check for None values
        if None in (lat1, lon1, lat2, lon2):
            return 0.0
        # Convert decimal degrees to radians
        lon1, lat1, lon2, lat2 = map(
            lambda x: math.radians(float(x)), [lon1, lat1, lon2, lat2]
        )
        # Haversine formula
        dlon = lon2 - lon1
        dlat = lat2 - lat1
        a = math.sin(dlat / 2) ** 2 + \
            math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2) ** 2
        c = 2 * math.asin(math.sqrt(a))
        # Radius of earth in meters
        r = 6371000
        return c * r
    return udf(haversine, DoubleType())

# Load data from the database
def load_data(spark, course):
    # Load sectionals data
    sectionals_df = spark.read.jdbc(
        url=jdbc_url,
        table="sectionals",
        properties=jdbc_properties
    ).filter(col("course_cd") == course).select(
        col("course_cd"),
        col("race_date"),
        col("race_number"),
        col("saddle_cloth_number"),
        col("gate_name"),
        col("length_to_finish"),
        col("sectional_time"),
        col("running_time"),
        col("distance_back"),
        col("distance_ran"),
        col("number_of_strides")
    )

    # Load gpspoint data
    gps_df = spark.read.jdbc(
        url=jdbc_url,
        table="gpspoint",
        properties=jdbc_properties
    ).filter(col("course_cd") == course).select(
        col("course_cd"),
        col("race_date"),
        col("race_number"),
        col("saddle_cloth_number"),
        "time_stamp",
        "longitude",
        "latitude",
        "progress",
        "speed",
        "stride_frequency"
    )

    # Load races data to get nominal race distance
    races_df = spark.read.jdbc(
        url=jdbc_url,
        table="races",
        properties=jdbc_properties
    ).filter(col("course_cd") == course).select(
        "course_cd",
        "race_date",
        "race_number",
        col("distance").alias("nominal_distance"),
        col("dist_unit").alias("nominal_dist_unit")
    )
    return sectionals_df, gps_df, races_df

# Rename columns for clarity
def rename_columns(sectionals_df, gps_df, races_df):
    sectionals_df = sectionals_df.select(
        col("course_cd").alias("s_course_cd"),
        col("race_date").alias("s_race_date"),
        col("race_number").alias("s_race_number"),
        col("saddle_cloth_number").alias("s_saddle_cloth_number"),
        "gate_name",
        "length_to_finish",
        "sectional_time",
        "running_time",
        "distance_back",
        "distance_ran",
        "number_of_strides"
    )

    gps_df = gps_df.select(
        col("course_cd").alias("g_course_cd"),
        col("race_date").alias("g_race_date"),
        col("race_number").alias("g_race_number"),
        col("saddle_cloth_number").alias("g_saddle_cloth_number"),
        "time_stamp",
        "longitude",
        "latitude",
        "progress",
        "speed",
        "stride_frequency"
    )

    races_df = races_df.select(
        col("course_cd").alias("r_course_cd"),
        col("race_date").alias("r_race_date"),
        col("race_number").alias("r_race_number"),
        col("nominal_distance"),
        col("nominal_dist_unit")
    )
    return sectionals_df, gps_df, races_df

# Join sectionals and GPS data
def join_sectionals_gps(sectionals_df, gps_df):
    gps_with_gates = sectionals_df.alias("s").join(
        gps_df.alias("g"),
        (col("s.s_course_cd") == col("g.g_course_cd")) &
        (col("s.s_race_date") == col("g.g_race_date")) &
        (col("s.s_race_number") == col("g.g_race_number")) &
        (col("s.s_saddle_cloth_number") == col("g.g_saddle_cloth_number")),
        how="inner"
    )
    return gps_with_gates

# Find the closest GPS point to each gate
def find_closest_gps_points(gps_with_gates):
    gps_with_gates = gps_with_gates.withColumn(
        "progress_diff",
        abs(col("g.progress") - col("s.length_to_finish"))
    )

    window_spec_gate = Window.partitionBy(
        "s.s_course_cd",
        "s.s_race_date",
        "s.s_race_number",
        "s.s_saddle_cloth_number",
        "gate_name"
    ).orderBy("progress_diff")

    gps_at_gates = gps_with_gates.withColumn(
        "row_number",
        row_number().over(window_spec_gate)
    ).filter(col("row_number") == 1).drop("row_number", "progress_diff")
    return gps_at_gates

# Calculate speed changes and acceleration
def calculate_speed_acceleration(gps_at_gates):
    # Order gates in race order (farthest from the finish first) instead of by
    # gate_name, whose string sort need not follow the order gates are passed.
    gate_order_window = Window.partitionBy(
        "s.s_course_cd",
        "s.s_race_date",
        "s.s_race_number",
        "s.s_saddle_cloth_number"
    ).orderBy(col("length_to_finish").desc())

    gps_at_gates = gps_at_gates.withColumn(
        "previous_speed",
        lag("g.speed").over(gate_order_window)
    ).withColumn(
        "speed_change",
        col("g.speed") - col("previous_speed")
    ).withColumn(
        # Relative speed change versus the previous gate (dimensionless),
        # not acceleration in m/s^2
        "acceleration",
        when(col("previous_speed").isNotNull(),
             col("speed_change") / col("previous_speed")
             ).otherwise(lit(0))
    )
    return gps_at_gates

# Identify fastest and slowest gates
def identify_fastest_slowest_gates(gps_at_gates):
    speed_window = Window.partitionBy(
        "s.s_course_cd",
        "s.s_race_date",
        "s.s_race_number",
        "s.s_saddle_cloth_number"
    )

    gps_at_gates = gps_at_gates.withColumn(
        "max_speed",
        F_max("g.speed").over(speed_window)
    ).withColumn(
        "min_speed",
        F_min("g.speed").over(speed_window)
    ).withColumn(
        "is_fastest_gate",
        when(col("g.speed") == col("max_speed"), lit(1)).otherwise(lit(0))
    ).withColumn(
        "is_slowest_gate",
        when(col("g.speed") == col("min_speed"), lit(1)).otherwise(lit(0))
    )
    return gps_at_gates

# Calculate fatigue factor
def calculate_fatigue_factor(gps_at_gates):
    # Extract finish speed
    finish_speed = gps_at_gates.filter(col("gate_name") == "Finish").select(
        col("s.s_course_cd").alias("course_cd"),
        col("s.s_race_date").alias("race_date"),
        col("s.s_race_number").alias("race_number"),
        col("s.s_saddle_cloth_number").alias("saddle_cloth_number"),
        col("g.speed").alias("finish_speed")
    )

    # Join finish speed back to gps_at_gates
    gps_at_gates = gps_at_gates.join(
        finish_speed,
        on=[
            gps_at_gates["s.s_course_cd"] == finish_speed["course_cd"],
            gps_at_gates["s.s_race_date"] == finish_speed["race_date"],
            gps_at_gates["s.s_race_number"] == finish_speed["race_number"],
            gps_at_gates["s.s_saddle_cloth_number"] == finish_speed["saddle_cloth_number"]
        ],
        how="left"
    )

    # Calculate fatigue factor
    gps_at_gates = gps_at_gates.withColumn(
        "fatigue_factor",
        (col("max_speed") - col("finish_speed")) / col("max_speed")
    )
    return gps_at_gates

# Prepare aggregated metrics per gate
def prepare_aggregated_metrics(gps_at_gates):
    aggregated_metrics = gps_at_gates.select(
        col("s.s_course_cd").alias("course_cd"),
        col("s.s_race_date").alias("race_date"),
        col("s.s_race_number").alias("race_number"),
        col("s.s_saddle_cloth_number").alias("saddle_cloth_number"),
        "gate_name",
        col("g.speed").alias("speed"),
        "acceleration",
        "fatigue_factor",
        "is_fastest_gate",
        "is_slowest_gate"
    )

    per_gate_metrics = aggregated_metrics.groupBy(
        "course_cd",
        "race_date",
        "race_number",
        "saddle_cloth_number",
        "gate_name"
    ).agg(
        avg("speed").alias("avg_speed"),
        avg("acceleration").alias("avg_acceleration"),
        F_max("speed").alias("max_speed"),
        F_min("speed").alias("min_speed"),
        F_max("fatigue_factor").alias("fatigue_factor"),
        F_max("is_fastest_gate").alias("is_fastest_gate"),
        F_max("is_slowest_gate").alias("is_slowest_gate")
    )
    return per_gate_metrics

# Calculate actual distance run and ground loss
def calculate_ground_loss(gps_df, races_df, haversine_udf):
    # Define window specification
    window_spec_time = Window.partitionBy(
        "g_course_cd", "g_race_date", "g_race_number", "g_saddle_cloth_number"
    ).orderBy("time_stamp")

    # Get previous latitude and longitude
    gps_df = gps_df.withColumn("prev_latitude", lag("latitude").over(window_spec_time))
    gps_df = gps_df.withColumn("prev_longitude", lag("longitude").over(window_spec_time))

    # Calculate segment distance using haversine formula
    gps_df = gps_df.withColumn(
        "segment_distance",
        haversine_udf(
            col("prev_latitude"),
            col("prev_longitude"),
            col("latitude"),
            col("longitude")
        )
    )

    # Fill null values
    gps_df = gps_df.fillna({"segment_distance": 0})

    # Calculate cumulative distance
    gps_df = gps_df.withColumn(
        "cumulative_distance",
        F_sum("segment_distance").over(window_spec_time)
    )

    # Get total distance per horse
    total_distance_df = gps_df.groupBy(
        "g_course_cd", "g_race_date", "g_race_number", "g_saddle_cloth_number"
    ).agg(
        F_max("cumulative_distance").alias("total_distance_run")
    )

    # Convert nominal distance to meters with proper scaling
    races_df = races_df.withColumn(
        "nominal_distance_meters",
        when(col("nominal_dist_unit") == 'F', (col("nominal_distance") / 100) * 201.168)
        .when(col("nominal_dist_unit") == 'Y', col("nominal_distance") * 0.9144)
        .otherwise(lit(None))
    )

    # Join with races_df, compute ground loss (GPS distance minus nominal distance),
    # and select the output columns
    distance_comparison_df = total_distance_df.join(
        races_df,
        (total_distance_df["g_course_cd"] == races_df["r_course_cd"]) &
        (total_distance_df["g_race_date"] == races_df["r_race_date"]) &
        (total_distance_df["g_race_number"] == races_df["r_race_number"]),
        how="left"
    ).withColumn(
        "ground_loss",
        col("total_distance_run") - col("nominal_distance_meters")
    ).select(
        col("g_course_cd").alias("course_cd"),
        col("g_race_date").alias("race_date"),
        col("g_race_number").alias("race_number"),
        col("g_saddle_cloth_number").alias("saddle_cloth_number"),
        "total_distance_run",
        "ground_loss"
    )

    return distance_comparison_df

# Integrate ground loss into final metrics
def integrate_ground_loss(per_gate_metrics, distance_comparison_df):
    final_metrics = per_gate_metrics.join(
        distance_comparison_df,
        on=["course_cd", "race_date", "race_number", "saddle_cloth_number"],
        how="left"
    )
    return final_metrics
    
# Write final metrics to the database
def write_to_database(final_metrics):
    # Ensure all necessary columns are included
    required_columns = [
        'course_cd', 'race_date', 'race_number', 'saddle_cloth_number',
        'gate_name', 'avg_speed', 'avg_acceleration', 'max_speed', 'min_speed',
        'fatigue_factor', 'is_fastest_gate', 'is_slowest_gate',
        'total_distance_run', 'ground_loss'
    ]

    # Check if all required columns are present
    missing_columns = [c for c in required_columns if c not in final_metrics.columns]
    if missing_columns:
        logging.error(f"Missing columns in final_metrics: {missing_columns}")
        return

    # Write final metrics to the database
    final_metrics.write.jdbc(
        url=jdbc_url,
        table="gps_aggregated_results_with_gates",
        mode="append",
        properties=jdbc_properties
    )
    
# Main processing function for each course
def process_course(course, spark, haversine_udf):
    print(f"Processing course: {course}")

    # Load data
    logging.info(f"Loading data from database for course: {course}")
    sectionals_df, gps_df, races_df = load_data(spark, course)

    # Rename columns
    sectionals_df, gps_df, races_df = rename_columns(sectionals_df, gps_df, races_df)

    # Join sectionals and GPS data
    gps_with_gates = join_sectionals_gps(sectionals_df, gps_df)

    # Find closest GPS points to gates
    gps_at_gates = find_closest_gps_points(gps_with_gates)

    # Calculate speed changes and acceleration
    gps_at_gates = calculate_speed_acceleration(gps_at_gates)

    # Identify fastest and slowest gates
    gps_at_gates = identify_fastest_slowest_gates(gps_at_gates)

    # Calculate fatigue factor
    gps_at_gates = calculate_fatigue_factor(gps_at_gates)

    # Prepare aggregated metrics
    per_gate_metrics = prepare_aggregated_metrics(gps_at_gates)

    # Calculate ground loss
    distance_comparison_df = calculate_ground_loss(gps_df, races_df, haversine_udf)

    # printSchema() prints to stdout and returns None, so log the schema string instead
    per_gate_metrics.printSchema()
    logging.info(per_gate_metrics.schema.simpleString())
    distance_comparison_df.printSchema()
    logging.info(distance_comparison_df.schema.simpleString())

    # Integrate ground loss into final metrics
    final_metrics = integrate_ground_loss(per_gate_metrics, distance_comparison_df)
    final_metrics.printSchema()
    logging.info(final_metrics.schema.simpleString())

    # Write final metrics to database
    write_to_database(final_metrics)

    print(f"Completed processing for course: {course}")

def main():
    # The root logger defaults to WARNING, so logging.info() output is dropped unless configured
    logging.basicConfig(level=logging.INFO)

    # Initialize Spark session (initialize_spark() already sets the log level to ERROR)
    spark = initialize_spark()
    # Define haversine UDF
    haversine_udf = define_haversine_udf()

    # List of courses
    courses = ['CNL', 'SAR', 'PIM', 'TSA', 'BEL', 'MVR', 'TWO', 'CLS', 'KEE', 'TAM',
               'TTP', 'TKD', 'ELP', 'PEN', 'HOU', 'DMR', 'TLS', 'AQU', 'MTH', 'TGP',
               'TGG', 'CBY', 'LRL', 'TED', 'IND', 'CTD', 'ASD', 'TCD', 'LAD', 'MED',
               'TOP', 'HOO']

    try:
        # Process each course; a failure on one course does not stop the others
        for course in courses:
            try:
                process_course(course, spark, haversine_udf)
            except Exception as e:
                logging.error(f"Error processing course {course}: {e}")
    finally:
        # Stop the Spark session even if an unexpected error escapes the loop
        spark.stop()

# if __name__ == "__main__":
main()
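find_closest_gps_points() keeps, for each horse and gate, the fix whose progress is nearest the gate's length_to_finish, and calculate_speed_acceleration() then compares consecutive gates. The plain-Python sketch below mirrors those two steps under stated assumptions: match_gates() is a hypothetical helper, progress and length_to_finish are assumed to share units, race order is taken as decreasing length_to_finish, and the gate names and speeds are made up.

```python
def match_gates(gps_points, gates):
    """Hypothetical helper mirroring the gate-matching and speed-change steps.

    gps_points: list of (progress, speed) for one horse in one race.
    gates: list of (gate_name, length_to_finish).
    Returns [(gate_name, length_to_finish, speed, relative_change)] in race order.
    """
    matched = []
    for gate_name, length_to_finish in gates:
        # Nearest fix by |progress - length_to_finish|
        _, speed = min(gps_points, key=lambda p: abs(p[0] - length_to_finish))
        matched.append((gate_name, length_to_finish, speed))

    # Assumed race order: the gate farthest from the finish comes first
    matched.sort(key=lambda g: g[1], reverse=True)

    rows, prev_speed = [], None
    for gate_name, length_to_finish, speed in matched:
        # Relative change vs. the previous gate (dimensionless); guards against
        # a missing or zero previous speed
        change = (speed - prev_speed) / prev_speed if prev_speed else 0.0
        rows.append((gate_name, length_to_finish, speed, change))
        prev_speed = speed
    return rows


gates = [("Finish", 0.0), ("4f", 800.0), ("6f", 1200.0), ("2f", 400.0)]
gps_points = [(1201.0, 16.0), (799.0, 17.0), (402.0, 17.0), (1.5, 12.75)]
print(sorted(name for name, _ in gates))  # ['2f', '4f', '6f', 'Finish']
rows = match_gates(gps_points, gates)
for row in rows:
    print(row)
```

Sorting the made-up gate names as strings puts "2f" first even though it is the second-to-last gate here, which is why ordering by gate_name can misplace gates. The relative change in the last column is dimensionless, matching what the Spark code stores in its acceleration column.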

1. Modular Functions

>	•	initialize_spark(): Sets up the Spark session with the necessary configurations.
>
>	•	define_haversine_udf(): Defines the Haversine function, which computes the great-circle distance between two GPS points, and wraps it as a Spark UDF.
>
>	•	load_data(): Loads the sectionals, gpspoint, and races data from the database for a given course.
>
>	•	rename_columns(): Aliases columns in the DataFrames for clarity and to avoid naming conflicts.
>
>	•	join_sectionals_gps(): Joins the sectionals and gpspoint DataFrames on the race and horse identifiers.
>
>	•	find_closest_gps_points(): For each horse and gate, keeps the GPS point whose progress is closest to the gate's length_to_finish.
>
>	•	calculate_speed_acceleration(): Calculates the speed change between consecutive gates and the change relative to the previous gate's speed (stored in the acceleration column).
>
>	•	identify_fastest_slowest_gates(): Flags the gates at which each horse recorded its highest and lowest speeds.
>
>	•	calculate_fatigue_factor(): Computes each horse's fatigue factor as (max_speed - finish_speed) / max_speed, using its maximum gate speed and its speed at the Finish gate.
>
>	•	prepare_aggregated_metrics(): Aggregates the metrics per horse and gate.
>
>	•	calculate_ground_loss(): Calculates the actual distance run by each horse from consecutive GPS points and computes the ground loss relative to the nominal race distance.
>
>	•	integrate_ground_loss(): Joins the ground-loss metric onto the per-gate metrics.
>
>	•	write_to_database(): Checks that all required columns are present, then appends the final metrics to the gps_aggregated_results_with_gates table.
>
>	•	process_course(): Orchestrates the processing steps for a single course.
>
>	•	main(): Initializes the Spark session, processes each course, and stops the Spark session.
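The ground-loss arithmetic in calculate_ground_loss() can be reproduced in plain Python for a quick check. In this sketch, haversine_m(), nominal_distance_m(), and ground_loss_m() are hypothetical helpers; the Earth radius and unit factors match the Spark code, the 'F' branch assumes distances are stored in hundredths of a furlong (as the /100 scaling implies), and the track points are made-up coordinates along a meridian.

```python
import math

EARTH_RADIUS_M = 6371000  # same radius as define_haversine_udf()
FURLONG_M = 201.168
YARD_M = 0.9144


def haversine_m(lat1, lon1, lat2, lon2):
    """Great-circle distance in meters between two lat/lon points (degrees)."""
    lat1, lon1, lat2, lon2 = map(math.radians, (lat1, lon1, lat2, lon2))
    a = (math.sin((lat2 - lat1) / 2) ** 2
         + math.cos(lat1) * math.cos(lat2) * math.sin((lon2 - lon1) / 2) ** 2)
    return 2 * EARTH_RADIUS_M * math.asin(math.sqrt(a))


def nominal_distance_m(distance, dist_unit):
    # Assumption: 'F' values are hundredths of a furlong (600 -> 6 furlongs)
    if dist_unit == "F":
        return distance / 100 * FURLONG_M
    if dist_unit == "Y":
        return distance * YARD_M
    return None  # other units stay unconverted, as in the Spark code


def ground_loss_m(track, distance, dist_unit):
    """track: (lat, lon) fixes in time order for one horse in one race."""
    run = sum(haversine_m(*p, *q) for p, q in zip(track, track[1:]))
    nominal = nominal_distance_m(distance, dist_unit)
    return run, (run - nominal if nominal is not None else None)


track = [(0.0, 0.0), (0.005, 0.0), (0.011, 0.0)]  # hypothetical fixes
run, loss = ground_loss_m(track, 600, "F")
print(round(run, 1), round(loss, 1))
```

A positive ground loss means the summed GPS path was longer than the nominal race distance. GPS jitter also tends to lengthen a summed path, so small positive values are expected even for a horse that saved ground.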


Next Steps

	1.	Visualization:
	•	Plot speed and stride frequency trends across gates for specific horses to validate the analysis.
	2.	Testing:
	•	Apply this to a few more courses to confirm that the calculations are meaningful and robust across different races.
	3.	Analysis:
	•	Compare fatigue factors or speed trends between winners and non-winners to derive insights about race dynamics.
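For the Analysis step, the fatigue factor computed in calculate_fatigue_factor(), (max_speed - finish_speed) / max_speed, can be averaged separately for winners and non-winners. The plain-Python sketch below assumes a finish position is available for each horse, which the pipeline above does not yet load; fatigue_factor() and compare_winners() are hypothetical helpers and the speeds are made-up values.

```python
from statistics import mean


def fatigue_factor(gate_speeds, finish_speed):
    """(max_speed - finish_speed) / max_speed; None if there is no positive speed."""
    max_speed = max(gate_speeds)
    return (max_speed - finish_speed) / max_speed if max_speed > 0 else None


def compare_winners(horses):
    """horses: dicts with 'gate_speeds', 'finish_speed', 'finish_pos' (hypothetical fields).
    Returns the mean fatigue factor for winners and non-winners."""
    groups = {"winners": [], "non_winners": []}
    for h in horses:
        ff = fatigue_factor(h["gate_speeds"], h["finish_speed"])
        if ff is None:
            continue
        groups["winners" if h["finish_pos"] == 1 else "non_winners"].append(ff)
    return {name: mean(vals) for name, vals in groups.items() if vals}


horses = [  # made-up gate speeds in m/s; the last gate is the finish
    {"gate_speeds": [15.0, 16.0, 15.0], "finish_speed": 15.0, "finish_pos": 1},
    {"gate_speeds": [15.5, 16.0, 14.0], "finish_speed": 14.0, "finish_pos": 2},
    {"gate_speeds": [16.0, 14.0, 12.0], "finish_speed": 12.0, "finish_pos": 5},
]
result = compare_winners(horses)
print(result)  # {'winners': 0.0625, 'non_winners': 0.1875}
```

On real data, the comparison is more informative per race or per distance band, because fatigue factors from sprints and routes are not directly comparable.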

Together, these per-gate metrics (speed changes, fatigue factor, and ground loss) are intended as features for improving the predictive capability of the analysis.


# Additional steps to derive fatigue factors, sectional efficiency, etc.
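One way to compute the Efficiency Ratio for sectional efficiency is the distance a horse actually ran in a section divided by the section's nominal length, so values above 1.0 indicate extra ground covered, for example from wide cornering. The sketch below assumes that the sectionals column distance_ran holds the distance covered since the previous gate and shares units with length_to_finish; neither is confirmed here, and efficiency_ratios(), race_length, and the sample numbers are hypothetical.

```python
def efficiency_ratios(race_length, sections):
    """Hypothetical helper: (gate_name, distance_ran / nominal_section_length) per section.

    race_length: nominal distance from the start to the finish.
    sections: list of (gate_name, length_to_finish, distance_ran) in race order.
    """
    ratios = []
    prev_length = race_length
    for gate_name, length_to_finish, distance_ran in sections:
        nominal = prev_length - length_to_finish  # nominal length of this section
        # > 1.0 means more ground covered than the nominal section length
        ratios.append((gate_name, distance_ran / nominal if nominal > 0 else None))
        prev_length = length_to_finish
    return ratios


sections = [("4f", 800.0, 404.0), ("2f", 400.0, 410.0), ("Finish", 0.0, 400.0)]
ratios = efficiency_ratios(1200.0, sections)
print(ratios)
```

In this made-up example the middle section, which would typically include a turn, shows the highest ratio.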

Innovative Applications

	1.	Predictive Fatigue Model: Train a model using TPD data to predict fatigue thresholds for horses.
	2.	Race Simulation: Use historical TPD data to simulate how horses might perform in upcoming races.
	3.	Dynamic Betting Insights: Provide real-time insights into how race conditions or competitor performance might influence outcomes.

Spark’s distributed computing will allow you to process the large dataset efficiently and scale as needed. By creatively combining TPD and EQB data, you can uncover insights that traditional analysis might overlook.
