In [None]:
from pyspark.sql.types import DateType, IntegerType
from pyspark.sql.functions import year, month, dayofweek, avg, coalesce, expr, when

In [None]:
def process_data(df):
  """
  Function that does pre-processing of the data: casts the column types and
  fills NULL demand values with the average demand observed for the same SKU,
  year, month and weekday.

  Parameters
  ----------
  df: pyspark.sql.dataframe.DataFrame
    Dataframe with raw information. The columns "date", "n_sku" and "demand"
    are read; any other column is discarded.

  Returns
  ----------
  df: pyspark.sql.dataframe.DataFrame
    Dataframe with the columns "date", "n_sku" and "demand", ordered by SKU
    and then date. A NULL demand is replaced by the (SKU, year, month, weekday)
    average; it stays NULL only when that whole group has no non-NULL demand.
  """
  # Change data types
  df = df\
    .withColumn("date", df['date'].cast(DateType()))\
    .withColumn("demand", df['demand'].cast(IntegerType()))

  # To fill NULLs, we'll calculate an average by year, month and weekday.
  # To do this, we start by calculating year, month and weekday columns
  df = df\
    .withColumn("year", year(df['date']))\
    .withColumn("month", month(df['date']))\
    .withColumn("weekday", dayofweek(df['date']))

  # Group by and calculate average (avg ignores NULL demands)
  df_avg = df.groupBy(
    "n_sku",
    "year",
    "month",
    "weekday"
  ).agg(
    avg("demand").alias("demand_avg")
  )

  # Merge base dataframe with df_avg
  # NOTE: this is an inner equi-join, so rows whose date is NULL (and hence
  # have NULL year/month/weekday keys) find no match and are dropped.
  df = df.join(df_avg, ['n_sku', 'month', 'year', 'weekday'])

  # Fill NULLs: keep the observed demand, fall back to the group average.
  # The average is a double, so the resulting column is widened accordingly.
  df = df.withColumn("demand", coalesce("demand", "demand_avg"))

  # Keep only relevant columns
  df = df.select(
    "date",
    "n_sku",
    "demand"
  )

  # Order registers by SKU and then, date.
  # The sort is done last because the join above shuffles the data: an
  # ordering applied before it is not guaranteed to survive in the result.
  df = df.orderBy(
    "n_sku",
    "date"
  )

  return df

In [None]:
def remove_outliers(df):
  """
  Function that replaces outliers, finding them statistically after
  de-seasonalizing (by weekday) and de-trending (by year and month) each
  SKU's time series.

  A point is flagged as an outlier when its normalised demand lies more than
  3 inter-quartile ranges above the third quartile or below the first quartile
  of its SKU. Flagged points are replaced by the average demand of the same
  SKU, year, month and weekday; all other points are left untouched.

  Parameters
  ----------
  df: pyspark.sql.dataframe.DataFrame
    Dataframe with outliers. The columns "date", "n_sku" and "demand" are
    read; any other column is discarded.

  Returns
  ----------
  df: pyspark.sql.dataframe.DataFrame
    Dataframe with the columns "date", "n_sku" and "demand", where the demand
    of every flagged outlier has been replaced by the group average.
  """
  # First, we de-seasonalize the series, dividing each value by the weekday average
  # To do this, we start by calculating the average per weekday
  df = df\
    .withColumn("year", year(df['date']))\
    .withColumn("month", month(df['date']))\
    .withColumn("weekday", dayofweek(df['date']))

  df_avg_weekday = df.groupBy(
    "n_sku",
    "weekday"
  ).agg(
    avg("demand").alias("avg_per_weekday")
  )

  # We de-seasonalize, dividing by this average.
  # The division is guarded: a zero average (e.g. a SKU that never sells on a
  # given weekday) yields NULL instead of a divide-by-zero error when ANSI
  # mode is enabled. NULL is also what non-ANSI mode returns for x/0, so the
  # result is unchanged there.
  df = df.join(df_avg_weekday, ["n_sku", "weekday"])
  df = df.withColumn(
    "demand_de_seas",
    when(df["avg_per_weekday"] != 0, df["demand"]/df["avg_per_weekday"])
  )

  # We now repeat the same, calculating the average per year and month of the de-seasonalized series
  df_avg_month = df.groupBy(
    "n_sku",
    "year",
    "month"
  ).agg(
    avg("demand_de_seas").alias("avg_per_month")
  )
  df = df.join(df_avg_month, ["n_sku", "year", "month"])
  # Same zero-divisor guard as above
  df = df.withColumn(
    "demand_de_trend",
    when(df["avg_per_month"] != 0, df["demand_de_seas"]/df["avg_per_month"])
  )

  # Calculate percentiles by SKU
  df_perc = df.groupBy(
    "n_sku"
  ).agg(
    expr('percentile(demand_de_trend, array(0.75))')[0].alias("Q3"),
    expr('percentile(demand_de_trend, array(0.25))')[0].alias("Q1")
  )
  # A point will be an outlier if
  # - It is greater than ThirdQuartile + 3*InterQuartileRange
  # - It is smaller than FirstQuartile - 3*InterQuartileRange
  df_perc = df_perc\
    .withColumn("HighOutlierThreshold", df_perc["Q3"] + 3*(df_perc['Q3'] - df_perc["Q1"]))\
    .withColumn("LowOutlierThreshold", df_perc["Q1"] - 3*(df_perc['Q3'] - df_perc["Q1"]))

  # Join information
  df = df.join(df_perc, ["n_sku"])

  # Create outlier indicator column.
  # A NULL normalised demand makes both comparisons NULL, so such rows fall
  # through to otherwise(0) and are never flagged.
  df = df.withColumn(
    "outlier_flag",
    when(df['demand_de_trend'] > df['HighOutlierThreshold'], 1)
    .when(df['demand_de_trend'] < df['LowOutlierThreshold'], 1)
    .otherwise(0)
  )

  # We will replace outliers with an average per weekday, year and month
  # NOTE(review): this average is computed over every row of the group,
  # including the flagged outliers themselves - confirm whether outliers
  # should be excluded from their own replacement value.
  df_avg = df.groupBy(
    "n_sku",
    "year",
    "month",
    "weekday"
  ).agg(
    avg("demand").alias("avg_per_month_weekday")
  )
  df = df.join(df_avg, ["n_sku", "year", "month", "weekday"])

  # Overwrite the demand of flagged rows with the group average, keep the rest
  df = df.withColumn(
    "demand",
    when(df['outlier_flag'] == 1, df['avg_per_month_weekday'])
    .otherwise(df['demand'])
  )

  # Select what we want
  df = df.select(
    "date",
    "n_sku",
    "demand"
  )

  return df