In [1]:
# ! pip install holidays

In [2]:
import os

import findspark

# Locate the Spark installation: prefer SPARK_HOME from the environment so the
# notebook runs on other machines, falling back to the path it was written against.
findspark.init(os.environ.get("SPARK_HOME", "/opt/spark"))

In [3]:
from pyspark.sql import SparkSession, functions as F
from pyspark.sql.types import *
from datetime import date
import holidays

In [4]:
# Local Spark session on all cores, with the driver heap raised to 4g.
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName("Spark ML Custom Transformation")
    .config("spark.driver.memory", "4g")
    .getOrCreate()
)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/08/01 11:55:42 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [5]:
# Sample rows: (date, store, item, label), one row per day for store 1 / item 1.
data = [
    ('2013-01-01', 1, 1, 13),
    ('2013-01-02', 1, 1, 11),
    ('2013-01-03', 1, 1, 14),
    ('2013-01-04', 1, 1, 10),
    ('2013-01-05', 1, 1, 10),
]
cols = ['date', 'store', 'item', 'label']

# Parse the ISO date strings into a DateType column (four-letter year pattern).
df = spark.createDataFrame(data, cols).withColumn('date', F.to_date('date', 'yyyy-MM-dd'))

df.printSchema()

root
 |-- date: date (nullable = true)
 |-- store: long (nullable = true)
 |-- item: long (nullable = true)
 |-- label: long (nullable = true)



In [6]:
# Preview the input rows (first 20, long values truncated).
df.show(n=20, truncate=True)

                                                                                

+----------+-----+----+-----+
|      date|store|item|label|
+----------+-----+----+-----+
|2013-01-01|    1|   1|   13|
|2013-01-02|    1|   1|   11|
|2013-01-03|    1|   1|   14|
|2013-01-04|    1|   1|   10|
|2013-01-05|    1|   1|   10|
+----------+-----+----+-----+



In [7]:
from pyspark.ml import Transformer
from pyspark.sql.functions import lit, udf
from pyspark.ml.param.shared import HasInputCols, HasOutputCol
from datetime import date

# @udf(IntegerType())
# def is_holiday(date_str: date, country_code: str='TR'):
#     date_str = str(date_str)
#     country_holidays = holidays.CountryHoliday(country_code) 
#     date_obj = date.fromisoformat(date_str)
#     if date_obj in country_holidays:
#         return 1
#     else:
#         return 0
        
class AddDateFeaturesTransformer(Transformer, HasInputCols, HasOutputCol):
    """Append calendar features derived from a single date column.

    ``outputCols`` names the four new columns, in this order:

    0. year of the date (``F.year``)
    1. month of the date (``F.month``)
    2. day of week (``F.dayofweek``: 1 = Sunday ... 7 = Saturday)
    3. holiday flag: 1 if the date is a public holiday for ``country_code``,
       0 if not, null when the date is null

    Extra entries in ``outputCols`` beyond the first four are ignored.

    NOTE(review): the mixins are ``HasInputCols``/``HasOutputCol``, but the
    class stores a single ``inputCol`` and a list ``outputCols`` as plain
    attributes, so the mixin Params are never set. Consider switching to
    ``HasInputCol``/``HasOutputCols`` with proper Param handling.
    """

    # Country used when ``country_code`` is not given (the previous hard-coded value).
    DEFAULT_COUNTRY_CODE = 'TR'

    def __init__(self, inputCol=None, outputCols=None, country_code=None):
        """
        :param inputCol: name of the date column to derive features from.
        :param outputCols: names for the year, month, day-of-week and holiday
            columns (at least four entries).
        :param country_code: code passed to ``holidays.country_holidays``;
            defaults to ``DEFAULT_COUNTRY_CODE`` when None/empty.
        """
        super().__init__()
        self.inputCol = inputCol
        self.outputCols = outputCols
        self.country_code = country_code

    @staticmethod
    def _to_date(value):
        """Return ``value`` unchanged if it is already a date, else parse it as ISO ``YYYY-MM-DD``."""
        if isinstance(value, date):
            return value
        return date.fromisoformat(str(value))

    def is_holiday(self, date_str: "date | str | None", country_code: str = 'TR') -> "int | None":
        """Return 1 if ``date_str`` is a public holiday in ``country_code``, else 0.

        ``date_str`` may be a ``datetime.date`` or an ISO ``YYYY-MM-DD`` string.
        Returns None for a None input instead of raising.
        """
        if date_str is None:
            return None
        calendar = holidays.country_holidays(country_code)
        return int(AddDateFeaturesTransformer._to_date(date_str) in calendar)

    def _transform(self, df):
        """Return ``df`` with the year, month, day-of-week and holiday columns appended.

        :raises ValueError: if ``inputCol`` is unset or fewer than four
            ``outputCols`` are given.
        """
        if not self.inputCol:
            raise ValueError("inputCol must be set")
        if self.outputCols is None or len(self.outputCols) < 4:
            raise ValueError(
                "outputCols must name 4 columns: year, month, day-of-week, holiday flag"
            )
        year_col, month_col, dow_col, holiday_col = self.outputCols[:4]

        # Previously the UDF always used 'TR'; honour the configured country.
        country_code = self.country_code or self.DEFAULT_COUNTRY_CODE
        to_date = AddDateFeaturesTransformer._to_date
        # Empty when the UDF is serialized; filled lazily on the executor so the
        # holiday calendar is built once per deserialized UDF, not once per row.
        calendar_cache = {}

        def holiday_flag(value):
            # Null dates stay null instead of failing the task.
            if value is None:
                return None
            calendar = calendar_cache.get(country_code)
            if calendar is None:
                calendar = calendar_cache[country_code] = holidays.country_holidays(country_code)
            return int(to_date(value) in calendar)

        # The closure captures only plain values, not the transformer itself.
        holiday_udf = udf(holiday_flag, IntegerType())
        date_col = F.col(self.inputCol)

        return (
            df.withColumn(year_col, F.year(date_col))
            .withColumn(month_col, F.month(date_col))
            .withColumn(dow_col, F.dayofweek(date_col))
            .withColumn(holiday_col, holiday_udf(date_col))
        )

In [8]:
# Derive year / month / day-of-week / holiday features (country code 'TR') from `date`.
sct = AddDateFeaturesTransformer(
    inputCol='date',
    outputCols=['y', 'm', 'd', 'is_holiday'],
    country_code='TR',
)

In [9]:
# Apply the transformer and preview the first five rows with the new columns.
sct.transform(df).show(n=5)



+----------+-----+----+-----+----+---+---+----------+
|      date|store|item|label|   y|  m|  d|is_holiday|
+----------+-----+----+-----+----+---+---+----------+
|2013-01-01|    1|   1|   13|2013|  1|  3|         1|
|2013-01-02|    1|   1|   11|2013|  1|  4|         0|
|2013-01-03|    1|   1|   14|2013|  1|  5|         0|
|2013-01-04|    1|   1|   10|2013|  1|  6|         0|
|2013-01-05|    1|   1|   10|2013|  1|  7|         0|
+----------+-----+----+-----+----+---+---+----------+



                                                                                