# Helper functions

In [None]:
"""

based on: https://github.com/datasciencecampus/UNGP-AIS-ETL/blob/feature/StructuredStreaming/spark_etl_bench/src/utils.py 

"""
from ._poly import poly_container

from datetime import datetime
from typing import Set, Dict, List, Tuple, Optional
import logging

import h3
import h3.api.numpy_int as h3int

import pandas as pd

from pyspark.sql import SparkSession
from pyspark.sql.dataframe import DataFrame
from pyspark.sql import functions as F

from shapely.geometry import shape
from sedona.sql.types import GeometryType
from pyspark.sql.types import StructField, StructType

# from sedona.register import SedonaRegistrator

# Configure logging when this module is imported (timestamp, level and message,
# INFO threshold) and create the logger named after this module.
logging.basicConfig(format="%(asctime)s %(levelname)s:%(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)


def apply_small_filter(
    spark: SparkSession,
    ais_df: DataFrame,
    filter_criteria,
    column_filter: str
) -> DataFrame:
    """Filter ``ais_df`` with a broadcast inner join. Efficient for small filters.

    Parameters
    ----------
    spark: SparkSession
        Session used to turn ``filter_criteria`` into a Spark dataframe.

    ais_df: Spark DataFrame
        The dataframe should contain the 'column_filter' column.

    filter_criteria: Pandas DataFrame, or list / tuple / set of values
        If a list, tuple or set - the values to keep in 'column_filter'.
        If a Pandas DataFrame, it should contain the 'column_filter' column.
        Its other columns will be included in the output.

    column_filter: str
        Name of the column to apply the filter on.

    Returns
    -------
    Spark dataframe with the filter applied (inner join on 'column_filter').

    Raises
    ------
    ValueError
        If 'column_filter' is missing from ``ais_df`` or from the Pandas
        ``filter_criteria``, or if a list-like ``filter_criteria`` is empty.
    TypeError
        If ``filter_criteria`` is neither a Pandas DataFrame nor a
        list / tuple / set.

    """
    if column_filter not in ais_df.columns:
        raise ValueError("column_filter should be in ais_df columns")

    if isinstance(filter_criteria, pd.DataFrame):
        if column_filter not in filter_criteria.columns:
            raise ValueError("filter criteria dataframe should have column_filter")
        df_filter = spark.createDataFrame(filter_criteria)

    elif isinstance(filter_criteria, (list, tuple, set, frozenset)):
        # One single-column row per value so Spark can infer the column type.
        rows = [[value] for value in filter_criteria]
        if not rows:
            # Spark cannot infer a schema from no rows; fail with a clear
            # message instead of an obscure schema-inference error.
            raise ValueError(
                f"filter_criteria for column '{column_filter}' should not be empty"
            )
        df_filter = spark.createDataFrame(rows, schema=[column_filter])

    else:
        # Previously this fell through and failed later with an unbound
        # local variable; report the real problem instead.
        raise TypeError(
            "filter_criteria should be a list, tuple, set or pandas DataFrame, "
            f"got {type(filter_criteria).__name__}"
        )

    # Broadcast the (small) filter table so the join avoids shuffling ais_df.
    return ais_df.join(F.broadcast(df_filter), on=column_filter)

def apply_geo_filter(
    spark: SparkSession,
    df: DataFrame,
    polygon: dict
) -> DataFrame:
    """Keep only the rows of ``df`` whose position lies within ``polygon``.

    A wrapper around the Sedona ``ST_Within`` SQL predicate. This function
    does not register the Sedona SQL functions itself; ``ST_Point`` and
    ``ST_Within`` must already be available on ``spark``.

    Parameters
    ----------
    spark: SparkSession
        Session used to build the polygon dataframe and run the SQL query.

    df: Spark DataFrame
        The dataframe should contain latitude, longitude columns. Rows with
        a null latitude or longitude are dropped.

    polygon: dict
        GeoJson representation of polygon.

    Returns
    -------
    Spark dataframe with filter applied. The helper columns 'point' and
    'polygon' created for the query are dropped from the result.

    Notes
    -----
    The temporary views "temp" and "temp2" are created or replaced in the
    session as a side effect.

    """
    # Single-row dataframe holding the polygon as a Sedona geometry.
    schema = StructType([StructField("geom", GeometryType(), False)])
    polygon_df = spark.createDataFrame([[shape(polygon)]], schema)
    polygon_df.createOrReplaceTempView("temp2")

    # Drop rows with missing latitude or longitude before building points.
    located_df = df.filter(df.latitude.isNotNull() & df.longitude.isNotNull())
    located_df.createOrReplaceTempView("temp")

    # The polygon is read from the "temp2" view; nothing is interpolated
    # into the SQL text.
    query = """
            select *
            from
                ( select *, ST_Point(longitude,latitude) as point,
                            (select geom from temp2
                            ) as polygon
                 from temp
                )
            where ST_Within(point, polygon)
           """
    return spark.sql(query).drop('point', 'polygon')

def get_ais(
    spark: SparkSession,
    start_date: datetime,
    end_date: Optional[datetime] = None,
    h3_list: Optional[List[int]] = None,
    polygon_hex_df: Optional[pd.DataFrame] = None,
    mmsi_list: Optional[List[int]] = None,
    message_type: Optional[List[int]] = [1,2,3,4,18,19,27],
    columns: Optional[List[str]] = ["*"],
    polygon: Optional[Dict] = None,
    polygon_hex_resolution: Optional[int] = 8
) -> DataFrame:
    """Read the daily AIS parquet partitions and apply the requested filters.

    Note that the default message types are the position message types.

    Parameters
    ----------
    spark: SparkSession

    start_date: datetime
        the start date filter to apply

    end_date: datetime, default None
        the end date filter to apply (inclusive). If None, only start_date
        is read.

    h3_list: list of int, default None
        h3 indices must be in int format and must have the same resolution.
        if None then it is not applied.

    polygon_hex_df: pandas dataframe, from polygon_to_hex_df function
        Dataframe with the following columns (minimum columns to contain) :
        - hex_id: the h3 hex ids (64-bit ints)
        - polygon_name: the name of the polygon
        - hex_resolution: the resolution of the hex (should be the same for all)
        The hex_ids should be contained in only one polygon_name, otherwise the
        resulting dataframe will contain duplicate entries. If several
        resolutions are present, the most frequent one is used to choose the
        AIS hex column.

    mmsi_list: list of int, default None
        the list of mmsi filter to apply. if None, then it is not applied

    message_type: list of int, default [1,2,3,4,18,19,27] <- position messages
        the list of message types to retain. if not supplied then the default
        message type filter is applied. Use ["*"] (or None) to get all message
        types.

    columns: list of str, default ["*"]
        the list of columns to retain. if not supplied (or None), all columns
        are returned. When polygon_hex_df is given, its columns are appended
        to the requested ones.

    polygon: Optional[Dict] = None
        GeoJson representation of polygon. If supplied, then the hex
        approximation of the polygon is calculated using
        poly_container(polygon, hex_resolution, overfill=True). The AIS data
        is filtered according to those hexes first and then according to the
        polygon itself using apply_geo_filter (Sedona ST_Within on
        ST_Point(longitude, latitude)).

    polygon_hex_resolution: int = 8
        The resolution of the hexagons to fill the input polygon with. Default is 8, a hex with an avg area of 0.737 sq km.
        A polygon with an area of 100 sq. km will contain ~136 resolution 8 hexes. The same 100 sq. km polygon
        can be approximated by ~949 hexes using resolution 9. Note that the higher the resolution, the higher
        the polygon area covered by the hexes. However, a small increase in resolution dramatically increases
        the number of hexes. See https://h3geo.org/docs/core-library/restable/ for a table of hex resolutions.

    Returns
    -------
    Spark dataframe with the filters applied.

    Raises
    ------
    ValueError
        If the date range is empty (end_date earlier than start_date).
    Exception
        If h3_list mixes resolutions, or polygon_hex_df lacks a required column.

    Notes
    -----
    The filters are applied one after another, so the result is their
    intersection. For example, if both polygon and h3_list are provided where
    h3_list is a list of hexes fully contained within the polygon, the filtered
    AIS data will only contain those within the hexes. Data within the polygon
    but outside the hexes will not be included.

    """
    # The list defaults in the signature are shared between calls; they are
    # only read here, never mutated.
    selected_columns = ["*"] if columns is None else list(columns)

    # One partition path per calendar day in [start_date, end_date].
    basepath = "s3a://ungp-ais-data-historical-backup/exact-earth-data/transformed/prod/"
    if end_date is None:
        end_date = start_date
    paths = [
        f"{basepath}year={day.year}/month={day.month:02d}/day={day.day:02d}"
        for day in pd.date_range(start_date, end_date)
    ]
    if not paths:
        raise ValueError("end_date should not be earlier than start_date")

    filtered_ais_df = spark.read.parquet(*paths)

    # Apply mmsi filter
    if mmsi_list is not None:
        filtered_ais_df = apply_small_filter(
            spark,
            filtered_ais_df,
            mmsi_list,
            "mmsi"
        )

    # Apply message type filter. Compare by value: an identity check against
    # a fresh ["*"] literal is always true and would disable the wildcard.
    if message_type is not None and list(message_type) != ["*"]:
        filtered_ais_df = apply_small_filter(
            spark,
            filtered_ais_df,
            message_type,
            "message_type"
        )

    # Apply h3 filter
    if h3_list is not None:
        resolutions = {h3int.h3_get_resolution(hex_id) for hex_id in h3_list}
        if len(resolutions) != 1:
            raise Exception("h3_list should contain h3 indices with the same resolution.")
        (h3_resolution,) = resolutions

        filtered_ais_df = apply_small_filter(
            spark,
            filtered_ais_df,
            h3_list,
            f"H3_int_index_{h3_resolution}"
        )

    # Apply polygon_hex_df filter
    if polygon_hex_df is not None:
        required_columns = {"hex_id", "polygon_name", "hex_resolution"}
        if not required_columns <= set(polygon_hex_df.columns):
            raise Exception("polygon_hex_df should contain columns 'hex_id','polygon_name','hex_resolution'")

        if polygon_hex_df.hex_resolution.nunique() > 1:
            logging.getLogger(__name__).warning(
                "polygon_hex_df contains several hex resolutions; using the most frequent one"
            )

        # Rename hex_id to the AIS hex column of the matching resolution so
        # the join key has the same name on both sides.
        column_filter = f"H3_int_index_{polygon_hex_df.hex_resolution.mode()[0]}"
        poly_copy = polygon_hex_df.rename(columns={'hex_id': column_filter})

        filtered_ais_df = apply_small_filter(
            spark,
            filtered_ais_df,
            poly_copy,
            column_filter
        )

        if selected_columns != ["*"]:
            # Keep the polygon columns in the output; dict.fromkeys removes
            # duplicates while preserving the requested column order.
            selected_columns = list(
                dict.fromkeys(selected_columns + poly_copy.columns.tolist())
            )

    # Apply polygon filter: coarse hex pre-filter, then exact geometry test.
    if polygon is not None:
        polygon_hexes = poly_container(polygon, hex_resolution=polygon_hex_resolution, overfill=True)

        filtered_ais_df = apply_small_filter(
            spark,
            filtered_ais_df,
            polygon_hexes,
            f"H3_int_index_{polygon_hex_resolution}"
        )

        filtered_ais_df = apply_geo_filter(
            spark,
            filtered_ais_df,
            polygon
        )

    # Apply column filter
    return filtered_ais_df.select(selected_columns)
