In [None]:
# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#       jupytext_version: 1.5.1
#   kernelspec:
#     display_name: Python 3
#     language: python
#     name: python3
# ---

# flake8: noqa
# pylint: skip-file

# # Noob Pipeline Examples

# +
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from noob.utils.test import PySparkTest
from noob.utils.io import read_parquet
from warp.spark.date import create_calendar_table

spark = SparkSession.builder.getOrCreate()

# ## [Daily data generation](api/noob.data.daily_data.rst#noob.data.daily_data.base.BaseStockedDayGenerator)

# ### Basic Usage

from datetime import datetime
from pyspark.sql import functions as F
from noob.data.daily_data import BaseStockedDayGenerator
from os.path import dirname, abspath, join
from tests.data.test_daily_data import DailyDataTest as tables


data_path = join(dirname(dirname(abspath(''))), "data")
tables.create_test_data(spark)
sales = tables.sales
inventory = tables.inventory
calendar = tables.calendar
daily_data_path = join(data_path, "model-data", "daily-data")

stocked_day_generator = BaseStockedDayGenerator(sales=sales, inventory=inventory, calendar=calendar, target_date="2019-07-10")
stocked_day_generator.generate_daily_data()
stocked_day_generator.write(daily_data_path)

# ### Daily data with extra_info

from datetime import datetime

extra_info_df = tables.products
extra_info_df = extra_info_df.select("product_id", "product_description").distinct()
extra_info = [
    {
    "dataframe": extra_info_df,
    "name": "brands",
    "join_keys": ["product_id"],
    "join_columns": ["product_description"],
    "how": "left",
    "rename": False
    }
]

sales = tables.sales
inventory = tables.inventory
calendar = tables.calendar
inventory.printSchema()

from noob.data.daily_data import BaseStockedDayGenerator
b = BaseStockedDayGenerator(sales=sales, inventory=inventory, calendar=calendar, target_date="2019-07-10")
b.generate_daily_data(extra_data_config=extra_info)
b.inventory_sales.printSchema()

# ### Daily data with sku_family and multiplier

# Let's assume that the product 0 is sold in a pack of six and it has two different parents for
# two different stores e.g. 3 and 4. Both parents are sold as a single item. Thus, we need to
# set the multiplier as 6. The rest of the products do not have any parents. First, let's generate daily data
# without multiplier adjustment.
#

# +
sku_family = spark.createDataFrame(
            [(3, 2, 0),
             (4, 1, 0)],
            ["store_id", "parent_id", "product_id"])

sales = tables.sales
inventory = tables.inventory
calendar = tables.calendar

stocked_day_generator = BaseStockedDayGenerator(sales=sales, inventory=inventory, calendar=calendar, target_date="2019-07-10")
stocked_day_generator.generate_daily_data(sku_family=sku_family)

daily_wo_multiplier = stocked_day_generator.inventory_sales
# -

# Now, let's generate the daily data with multiplier adjustment and observe the difference. columns_to_multiply is given as
# ["sales_quantity", "inventory", "incoming_inventory"]

# +
sku_family = spark.createDataFrame(
        [(3, 2, 0, 6),
         (4, 1, 0, 6)],
        ["store_id", "parent_id", "product_id", "multiplier"]
)

stocked_day_generator = BaseStockedDayGenerator(sales=sales, inventory=inventory, calendar=calendar, target_date="2019-07-10")
stocked_day_generator.generate_daily_data(sku_family=sku_family, multiply_exclude_cols=("sales_revenue",))

daily_with_multiplier = stocked_day_generator.inventory_sales
# -

daily_with_multiplier.exceptAll(daily_wo_multiplier).orderBy("date", "store_id", "product_id").show(10)

daily_wo_multiplier.exceptAll(daily_with_multiplier).orderBy("date", "store_id", "product_id").show(10)

# Now assume that the parent is sold as six pack instead. Then the multiplier should be 1/6.

# +
sku_family = spark.createDataFrame(
    [(3, 2, 0, 1 / 6),
     (4, 1, 0, 1 / 6)],
    ["store_id", "parent_id", "product_id", "multiplier"]
)

stocked_day_generator = BaseStockedDayGenerator(sales=sales, inventory=inventory, calendar=calendar, target_date="2019-07-10")
stocked_day_generator.generate_daily_data(sku_family=sku_family, multiply_exclude_cols=("sales_revenue",))

daily_reverse = stocked_day_generator.inventory_sales
# -

daily_reverse.exceptAll(daily_wo_multiplier).show(10)


# ### Daily data with extra inventory column and store aggregated

sales = tables.sales
inventory = tables.inventory
calendar = tables.calendar

daily_data = BaseStockedDayGenerator(
    sales=sales,
    inventory=inventory,
    target_date="2019-07-10",
    calendar=calendar,
    update_range=None,
    groupby_columns=("year", "week", "date", "product_id"), # store dimension is removed
    inventory_columns=("inventory", "incoming_inventory", "outgoing_inventory"), # outgoing_inventory is added
    sales_columns=("sales_quantity", "sales_revenue"))

daily_data.generate_daily_data()

daily_data.inventory_sales.show(5)

# ### Daily data with different aggregation types

sales = tables.sales
inventory = tables.inventory
calendar = tables.calendar

daily_data = BaseStockedDayGenerator(
    sales=sales,
    inventory=inventory,
    target_date="2019-07-10",
    calendar=calendar,
    update_range=None,
    groupby_columns=("year", "week", "date", "product_id"),
    inventory_columns=({"agg_name": "min",
                        "col_name": "inventory",
                        "alias": "inventory"},
                       {"agg_name": "min",
                        "col_name": "incoming_inventory",
                        "alias": "incoming_inventory"},
                       {"agg_name": "min",
                        "col_name": "outgoing_inventory",
                        "alias": "outgoing_inventory"}),
    sales_columns=("sales_quantity", "sales_revenue")) # defaults to {"agg_name": "sum",
                                                       #              "col_name": "sales_quantity",
                                                       #              "alias": "sales_quantity"}

daily_data.generate_daily_data()

daily_data.inventory_sales.show(5)

# ### Daily data with Input/Output Control

sales = tables.sales
inventory = tables.inventory
calendar = tables.calendar
daily_data = BaseStockedDayGenerator(
    sales=sales, inventory=inventory,
    target_date="2019-07-10",
    calendar=calendar,
    update_range=None)
daily_data.generate_daily_data()
daily_data.inventory_sales.cache()

new_data = daily_data.inventory_sales \
    .filter(F.col("date") == "2019-07-10")
old_data = daily_data.inventory_sales \
    .filter(F.col("date") == "2019-07-03")

summary_new = daily_data.generate_summary_rowwise(new_data)
summary_old = daily_data.generate_summary_rowwise(old_data)
comparison = daily_data \
    .control(summary_old, summary_new, controller=None, logger=None, groupby_features=["store_id", "product_id"])
summary_old.show()
summary_new.show()

# Input output controls have passed and thus, we have an empty comparison dataframe.

comparison.show()

# ## [Outlier detection](api/noob.outlier.rst#noob.outlier.base.BaseOutlierDetection)
#

# ### Store-Product level

from noob.outlier import BaseOutlierDetection
from tests.outlier.test_outlier_detection import OutlierDetectionTest

OutlierDetectionTest.create_test_data(spark)

bod = BaseOutlierDetection(
    data = OutlierDetectionTest.data,
    groupby_features=["store_id", "product_id"],
    stdev_multip=1)
sales = bod.preprocess(
    calendar = OutlierDetectionTest.calendar,
    target_date="2020-01-31",
    lookback_days=10)
outliers, thresholds = bod.calculate(sales)

# ### Store-product level with rules & extra info

OutlierDetectionTest.rules.show()
OutlierDetectionTest.extra_info_df.show()
extra_info = [{
    "dataframe": OutlierDetectionTest.extra_info_df,
    "name": "brands",
    "join_keys": ["product_id"],
    "join_columns": ["brand"],
    "how": "left"
}]

bod = BaseOutlierDetection(
    data=OutlierDetectionTest.data,
    groupby_features=["store_id", "product_id"],
    stdev_multip=1,
    rules=OutlierDetectionTest.rules
)
sales = bod.preprocess(
    calendar = OutlierDetectionTest.calendar,
    target_date="2020-01-31",
    extra_info=extra_info,
    lookback_days=10)
outliers, thresholds = bod.calculate(sales)

# Assume that we have calculated these thresholds beforehand. We can give it to the class as a parameter and it will not calculate thresholds for the pairs that already have a threshold calculated.
# New pairs' thresholds will be calculated from a higher level (new_thresholds_level) in this case.

bod = BaseOutlierDetection(
    data=OutlierDetectionTest.data,
    groupby_features=["store_id", "product_id"],
    stdev_multip=1,
    recalculation_limit=8,
    current_thresholds=OutlierDetectionTest.thresholds,
    new_thresholds_level=["brand"]
)
sales = bod.preprocess(
    calendar = OutlierDetectionTest.calendar,
    target_date="2020-01-31",
    lookback_days=7,
    extra_info=extra_info)

outliers_new, thresholds_new = bod.calculate(sales)

# ### Product-Promo level

# Add artificial promo info to the data
# Sales increase by 1000% in summer
import pyspark.sql.functions as F
data = OutlierDetectionTest.data

data = data.withColumn("sales_quantity", F.when(F.col("date") > "2020-01-21", F.col("sales_quantity")*10).otherwise(F.col("sales_quantity")))
data = data.withColumn("promo", F.when(F.col("date") > "2020-01-21", F.lit(1)).otherwise(F.lit(0)))

bod = BaseOutlierDetection(
    data=data,
    groupby_features=["promo", "product_id"],
    stdev_multip=1)
sales = bod.preprocess(
    calendar=OutlierDetectionTest.calendar,
    target_date="2020-01-31",
    lookback_days=30)
outliers, thresholds = bod.calculate(sales)
outliers.show(5)
thresholds.show(5)
# ### Store level

bod = BaseOutlierDetection(
    data=data,
    groupby_features=["store_id"],
    stdev_multip=1)
sales = bod.preprocess(
    calendar=OutlierDetectionTest.calendar,
    target_date="2020-01-31",
    lookback_days=30)
outliers, thresholds = bod.calculate(sales)
outliers.show(5)
thresholds.show(5)
# ### Product level

bod = BaseOutlierDetection(
    data=data,
    groupby_features=["product_id"],
    stdev_multip=1)
sales = bod.preprocess(
    calendar=OutlierDetectionTest.calendar,
    target_date="2020-01-31",
    lookback_days=30)
outliers, thresholds = bod.calculate(sales)
outliers.show(5)
thresholds.show(5)

# ### Controls

# +
# Take an old run
bod = BaseOutlierDetection(
    data=data,
    groupby_features=["store_id", "product_id"],
    stdev_multip=1)
sales = bod.preprocess(
    calendar=OutlierDetectionTest.calendar,
    target_date="2020-01-15",
    lookback_days=10)
outliers_old, _ = bod.calculate(sales)

# Take a new run
bod = BaseOutlierDetection(
    data=data,
    groupby_features=["store_id", "product_id"],
    stdev_multip=1)
sales = bod.preprocess(
    calendar=OutlierDetectionTest.calendar,
    target_date="2020-01-31",
    lookback_days=10)
outliers_new, _ = bod.calculate(sales)

summary_old = bod.generate_summary(outliers_old, ["store_id", "product_id"])
summary_new = bod.generate_summary(outliers_new, ["store_id", "product_id"])
rowwise_summary_old = bod.generate_summary_rowwise(outliers_old, ["product_id"])
rowwise_summary_new = bod.generate_summary_rowwise(outliers_new, ["product_id"])

# +
# Control general sum and row count of two results, also check if duplicate exists
try:
    bod.control(summary_old, summary_new, controller={"row_count": 0.5, "outlier_sum": 0.5}, logger=bod.logger, check_duplicates=True)
except:
    print("Test 1 failed")

# Control product based outlier sum of two results
try:
    bod.control(rowwise_summary_old, rowwise_summary_new, controller={"outlier_sum": 0.5}, logger=bod.logger, groupby_features=["product_id"])
except:
    print("Test 2 failed")
# -

# ## [Enriched data](api/noob.outlier.rst#noob.preprocess.enrich.base.BaseEnrich)
#

# ### Daily enrich store product level

from noob.preprocessing.enrich import BaseEnrich

scope = spark.createDataFrame([(1, 1)], ["product_id", "store_id"])
outlier_df = spark.createDataFrame(
    [(1, 1, "2018-03-08", 1)],
    ["store_id", "product_id", "date", "outlier"],
)
extra_info = [
    {
        "name": "outlier_product_store",
        "dataframe": outlier_df,
        "join_columns": ["outlier"],
        "join_keys": ["store_id", "product_id", "date"],
        "how": "left",
        "default_value": 0
    }
]
calendar = create_calendar_table().select(
    F.col("date").cast("string"),
    F.col("iso_8601_year").cast("string"),
    F.col("iso_8601_week").alias("week").cast("string"),
    F.col("month").cast("string"),
    F.col("day_of_week").cast("string"),
    F.col("iso_8601_year").alias("year").cast("string"),
    F.col("week_start_date").cast("string"),
    F.col("week_end_date").cast("string"),
)
daily_data = spark.createDataFrame(
    [
        (2018, 10, 1, 1, "2018-03-08", 1, 1, 1, 1),
        (2018, 10, 1, 1, "2018-03-09", 1, 1, 1, 1),
    ],
    [
        "year",
        "week",
        "store_id",
        "product_id",
        "date",
        "inventory",
        "sales_quantity",
        "sales_revenue",
        "usable",
    ],
)
aggs = [
    {
        "groupby_features": [
            "store_id",
            "product_id"
        ],
        "aggs": [
            {
                "agg_name": "agg",
                "params": {
                    "agg_name": "sum",
                    "col_name": "sales_quantity",
                    "alias": "sales_quantity"
                }
            },
            {
                "agg_name": "agg",
                "params": {
                    "agg_name": "sum",
                    "col_name": "sales_revenue",
                    "alias": "sales_revenue"
                }
            },
            {
                "agg_name": "agg",
                "params": {
                    "agg_name": "sum",
                    "col_name": "inventory",
                    "alias": "inventory"
                }
            },
            {
                "agg_name": "agg",
                "params": {
                    "agg_name": "sum",
                    "col_name": "usable",
                    "alias": "usable"
                }
            },
            {
                "agg_name": "agg",
                "params": {
                    "agg_name": "sum",
                    "col_name": "outlier",
                    "alias": "outlier"
                }
            },
            {
                "agg_name": "cond_feature",
                "params": {
                    "col_name": "outlier",
                    "operation": "==",
                    "val": 1,
                    "value_for_true_case": 0,
                    "value_for_false_case": "F.col('sales_quantity')",
                    "agg_name": "sum",
                    "alias": "replaced_sales"
                }
            }
        ]
    }
]
lags = [
    # Store-product sales lag
    {
        "lags": [1, 7],
        "features": ["sales_quantity", "sales_revenue"],
        "groupby_features": ["store_id", "product_id"],
        "prefix": "",
        "operation": "lag"
    },
    # Product sales lag
    {
        "lags": [1],
        "features": ["sales_quantity"],
        "groupby_features": ["product_id"],
        "prefix": "product",
        "operation": "lag"
    },
    # lags of MA values
    {
        "lags": [3],
        "features": ["avg_-14_-1_sales_quantity", "avg_-7_0_sales_quantity"],
        "groupby_features": ["product_id", "store_id"],
        "prefix": "",
        "operation": "lag"
    }
]

ma_calculations = [
    {
        "ma_ranges": [(-14, -1), (-7, 0)],
        "features": ["sales_quantity"],
        "groupby_features": ["store_id", "product_id"],
        "prefix": "",
        "operation": "avg"
    },
    {
        "ma_ranges": [(-14, -1), (-7, 0)],
        "features": ["sales_quantity"],
        "groupby_features": ["product_id"],
        "prefix": "product",
        "operation": "avg"
    }
]

enrich = BaseEnrich(
    calendar,
    start_date="2018-03-05",
    target_date="2018-03-12",
    mode="create",
    period="day",
    partition_cols=["year", "week"],
    scope_df=scope
)

enrich.enrich(daily_data, 0, daily_extra_info=extra_info,
              aggs=aggs, ma_calculations=ma_calculations, lags=lags)

enrich.data.show()

# #### Daily enrich summary generation & control

enrich_summary_metadata = enrich.generate_summary(
    enrich.data,
    groupby_columns_list=[["store_id", "product_id", "day_index"]],
    negative_count_cols=["sales_quantity", "sales_revenue", "inventory"],
    null_count_cols=["sales_quantity", "sales_revenue", "inventory"]
)

# passing case
prev_summary = enrich_summary_metadata
compared = BaseEnrich.control(prev_summary, enrich_summary_metadata,
    duplicate_check_key="|".join(["store_id", "product_id", "day_index"]))

# failing case
try:
    prev_summary = enrich_summary_metadata.withColumn("row_count",
        F.col("row_count") * F.lit(1000))
    compared = BaseEnrich.control(prev_summary, enrich_summary_metadata)
except:
    print("IO control failed!")

# #### Daily enrich summary generation & control without specifying params

enrich_summary_metadata = enrich.generate_summary(enrich.data)

# passing case
prev_summary = enrich_summary_metadata
compared = BaseEnrich.control(prev_summary, enrich_summary_metadata)

# failing case
try:
    prev_summary = enrich_summary_metadata.withColumn("row_count",
        F.col("row_count") * F.lit(1000))
    compared = BaseEnrich.control(prev_summary, enrich_summary_metadata)
except:
    print("IO control failed!")

# #### Daily enrich summary generation rowwise & control rowwise

rowwise_summary_agg_config = {
    "groupby_columns": ["product_id", "store_id",
                        "{}_index".format(enrich.period)],
    "aggs": [
        {
            "agg_name": "agg",
            "params": {
                "agg_name": "sum",
                "col_name": "sales_quantity",
                "alias": "sales_quantity"
            }
        },
        {
            "agg_name": "agg",
            "params": {
                "agg_name": "sum",
                "col_name": "inventory",
                "alias": "inventory"
            }
        }
    ]
}
enrich_summary_rowwise = enrich.generate_summary_rowwise(
    enrich.data, agg_config=rowwise_summary_agg_config)

# passing case
prev_summary = enrich_summary_rowwise
compared = BaseEnrich.control_rowwise(
    prev_summary,
    enrich_summary_rowwise,
    key_columns=rowwise_summary_agg_config["groupby_columns"],
    column_tolerances={"inventory": 1, "sales_quantity": 1}
)

# failing case
try:
    prev_summary = enrich_summary_rowwise.withColumn("sales_quantity",
        F.col("sales_quantity") * F.lit(1000))
    compared = BaseEnrich.control_rowwise(
        prev_summary,
        enrich_summary_rowwise,
        key_columns=rowwise_summary_agg_config["groupby_columns"],
        column_tolerances={"sales_quantity": 0.5}
    )
except:
    print("IO control failed")

# ### Daily enrich product level

from noob.preprocessing.enrich import BaseEnrich

# +
aggs[0]["groupby_features"] = ["product_id"]
aggs[0]["aggs"].append({
    "agg_name": "agg",
    "params": {
        "agg_name": "countDistinct",
        "col_name": "store_id",
        "alias": "store_count"
    }
})

enrich = BaseEnrich(
    calendar,
    start_date="2018-03-05",
    target_date="2018-03-12",
    mode="create",
    period="day",
    partition_cols=["year", "week"],
    scope_df=scope
)
# -

enrich.enrich(daily_data, 0, daily_extra_info=extra_info, aggs=aggs)

enrich.data.show()

# ### Daily enrich product level with extra_info having join_expr and select_expr

from noob.preprocessing.enrich import BaseEnrich

# +
aggs[0]["groupby_features"] = ["product_id"]

daily_extra_info = [
    {
        "name": "outlier_product_store",
        "dataframe": outlier_df,
        "join_columns": ["outlier"],
        "join_keys": ["store_id", "product_id", "date"],
        "how": "left",
        "default_value": 0,
        "join_expr": lambda left_df, right_df: (
            (left_df.store_id == right_df.store_id) &
            (left_df.product_id == right_df.product_id) &
            (left_df.date == right_df.date)
        ),
        "select_expr": lambda left_df, right_df: (
            [left_df["*"], right_df.outlier]
        )
    }
]

enrich = BaseEnrich(
    calendar,
    start_date="2018-03-05",
    target_date="2018-03-12",
    mode="create",
    period="day",
    partition_cols=["year", "week"],
    scope_df=scope
)
# -

enrich.enrich(daily_data, 0, daily_extra_info=daily_extra_info, aggs=aggs)

enrich.data.show()

# ### Weekly enrich store-product level

from noob.preprocessing.enrich import BaseEnrich

# We need to remove the following aggregation since we don't want to aggregate stores.

aggs[0]['aggs'][-1]

del aggs[0]['aggs'][-1]

# We also need to change the groupby features

aggs[0]["groupby_features"] = ["store_id", "product_id"]

# And set the period as week

enrich = BaseEnrich(
    calendar,
    start_date="2018-03-05",
    target_date="2018-03-12",
    mode="create",
    period="week",
    partition_cols=["year", "week"],
    scope_df=scope
)

enrich.enrich(daily_data, 0, daily_extra_info=extra_info, aggs=aggs)

enrich.data.show()

# ### Weekly enrich product level

from noob.preprocessing.enrich import BaseEnrich

aggs[0]["groupby_features"] = ["product_id"]
aggs[0]["aggs"].append({
    "agg_name": "agg",
    "params": {
        "agg_name": "countDistinct",
        "col_name": "store_id",
        "alias": "store_count"
    }
})

enrich = BaseEnrich(
    calendar,
    start_date="2018-03-05",
    target_date="2018-03-12",
    mode="create",
    period="week",
    partition_cols=["year", "week"],
    scope_df=scope
)

enrich.enrich(daily_data, 0, daily_extra_info=extra_info, aggs=aggs)

enrich.data.show()

# ## [Lost sales](api/noob.lostsales.rst#noob.lostsales.base.StoreShareLostSales)

# #### Case 1: Default lost sales run

from noob.lostsales import StoreShareLostSales

# Load example dataframes
from tests.lostsales.test_lostsales import LostSalesTest as tables
tables.create_test_data(spark)

lostsales = StoreShareLostSales(
    calendar=tables.calendar,
    agg_cols_day_effect=["cluster"],
    agg_cols_store_share=["product_id"],
    lookback_days=30
)

results = lostsales.calculate(
    df=tables.data,
    scope=tables.scope,
    target_date="2010-07-18",
    life_start_table=tables.life_start_table,
    extra_info=tables.extra_info
    )

results.show()

# #### Case 2: Lost sales run with external store shares

# +
# Calculate shares with BaseShareCalculation module

from noob.postprocessing.breakdown import BaseShareCalculation
from noob.utils.features import union
from noob.utils.date import get_period_calendar

daily_calendar, _ = get_period_calendar(tables.calendar, "day")
df = tables.data.join(daily_calendar.select("date", "day_index"), on="date", how="left")

# We want to calculate lost sales for the week '2010-07-12' to '2010-07-18'.
# That's why we will filter out the data after that week and calculate
# store shares for the following day, which is day_index == 3370.
df = df.filter("date < '2010-07-12'")
future = df.select("store_id", "product_id").distinct().crossJoin(
    spark.createDataFrame([(199,)], ["day_index"]))
df = union([df, future], col_handling="pad")

share_calculator = BaseShareCalculation(
    target_date="2010-07-12",
    calendar=tables.calendar,
    period="day",
    share_level="store_id",
    forecast_level=["product_id"],
    demand_col="sales_quantity"
)
shares = share_calculator.calculate_level_shares(df, [(366,1)])
shares = shares.select("store_id", "product_id", "store_id_share_0")
# -

shares.show(5)

lostsales = StoreShareLostSales(
    calendar=tables.calendar,
    agg_cols_day_effect=["cluster"],
    agg_cols_store_share=["product_id"]
)
results_1 = lostsales.calculate(
    df=tables.data,
    target_date="2010-07-18",
    scope=tables.scope,
    store_shares=shares,
    extra_info=tables.extra_info,
    store_share_col="store_id_share_0")

results_1.show(5)

# +
# I/O Control step

# First, generate summary of the previous run and current run
summary = lostsales.generate_summary(results)
summary_1 = lostsales.generate_summary(results_1)

# Now, run the control step
try:
    lostsales.control(summary_1, summary, tables.scope)
except:
    print('Control failed')
# -

# #### Case 3: Smartlag lost sales

from noob.lostsales.smartlag import SmartlagLostSales

sls = SmartlagLostSales(
    target_date="2010-07-18",
    calendar=tables.calendar,
    groupby_cols_list=[["store_id", "product_id"], ["product_id"], ["cluster"]],
    window=7,
    minimum_periods=3)

df = sls.preprocess(
    df=tables.data,
    scope=tables.scope,
    life_start_table=tables.life_start_table,
    extra_info=tables.extra_info,
    lookback_days=30,
    lookforward_days=7)

df = sls.predict(df)

df = sls.calculate(df)

df.show()

# #### Case 4: ml lost sales

from noob.utils.date import get_period_calendar
from noob.lostsales.ml_lib import MLLostSales
from tests.lostsales.test_mllib_lostsales import MLLostSalesTest as tables
tables.create_test_data(spark)

sls = MLLostSales(
    target_date="2010-07-18",
    calendar=tables.calendar)

df = sls.preprocess(
    df=tables.data,
    scope=tables.data.select("store_id", "product_id").distinct(),
    extra_info=tables.extra_info,
    lookback_days=60,
    lookforward_days=7,
)

df = sls.enrich(df, 119)

df.columns

# +
import pyspark.sql.functions as F
import xgboost as xgb
from noob.forecasting.ml_lib import MachineLearningModel

df = df.withColumn("forecast_group_id", F.lit(1))
forecast_group_column = "forecast_group_id"

cols_to_use = [
 'year',
 'week',
 'sp_avg_-375_-351_in_stock_sales',
 'sp_avg_-368_-362_in_stock_sales',
 'sp_avg_-14_-7_in_stock_sales',
 'sp_avg_7_14_in_stock_sales',
 'p_avg_-375_-351_in_stock_sales',
 'p_avg_-368_-362_in_stock_sales',
 'p_avg_-14_-7_in_stock_sales',
 'p_avg_7_14_in_stock_sales',
 'lag_-21_in_stock_sales',
 'lag_-14_in_stock_sales',
 'lag_-7_in_stock_sales',
 'lag_-1_in_stock_sales',
 'lag_1_in_stock_sales',
 'lag_7_in_stock_sales',
 'lag_14_in_stock_sales',
 'lag_21_in_stock_sales',
 'product_lag_-21_in_stock_sales',
 'product_lag_-14_in_stock_sales',
 'product_lag_-7_in_stock_sales',
 'product_lag_-1_in_stock_sales',
 'product_lag_in_stock_sales',
 'product_lag_1_in_stock_sales',
 'product_lag_7_in_stock_sales',
 'product_lag_14_in_stock_sales',
 'product_lag_21_in_stock_sales']

param = {}
param['eta'] = 0.1
param['grow_policy'] = 'lossguide'
param['colsample_bytree'] = 0.8
param['min_child_weight'] = 1
param['booster'] = 'gbtree'
param['objective'] = 'reg:squarederror'
param['n_estimators'] = 100

output_columns = [
    "store_id", "product_id", "date", "sales_quantity"
]
xgbm = MachineLearningModel(model=xgb.XGBRegressor(**param))
params = {
    "categorical_cols": [],
    "numerical_cols": [],
    "cols_to_use": cols_to_use,
    "target_col": 'in_stock_sales',
    "group_by_cols": [forecast_group_column],
    "use_eval_set": 0.5,
    "summary_output": output_columns
}
# -

df = sls.predict(df, xgbm, params, 7)

ls_weekly = sls.calculate(df)

ls_weekly.show(5, False)

# ## [Volatility estimation](api/noob.volatility_estimation.rst#noob.volatility_estimation.base.VolatilityEstimator)

from noob.volatility_estimation import VolatilityEstimator
from tests.volatiliy_estimation.test_volatility_estimation import (
    VolatiliyEstimationTest)
from pyspark.sql.types import IntegerType, DoubleType, TimestampType
vol_est = VolatilityEstimator(aggregation_levels=["h1", "h2", "h3"],
                              min_observation=3,
                              mode="from_sales",
                              window_length=5,
                              weight="base")
VolatiliyEstimationTest._create_test_data(spark)
df = VolatiliyEstimationTest.df
# mode: from_sales
estimation_df_a = vol_est.calculate_abc(df)
estimation_df_a.cache()
estimation_df_a.show(3)
# mode: from_forecast
vol_est = VolatilityEstimator(aggregation_levels=["h1", "h2", "h3"],
                              min_observation=3,
                              mode="from_forecast",
                              window_length=5,
                              weight="base")
estimation_df_b = vol_est.calculate_abc(df)
estimation_df_b.cache()
estimation_df_b.show(3)
# IO control
# compare 1: dup or null check
summ_a = vol_est.generate_summary(estimation_df_a)
summ_b = vol_est.generate_summary(estimation_df_b)
# compare 2: rowwise check
summ_row_a = vol_est.generate_summary_rowwise(estimation_df_a)

vol_est.control_io(summ_a, summ_b)
vol_est.control_io(summ_row_a, summ_row_a, groupby_features=['level'])

# ## Forecasting

# ## [LLTD](api/noob.forecasting.rst#noob.forecasting.lltd.LLTD)
# #### Daily Product Store Level

# +
from dateutil.parser import parse
from dateutil.relativedelta import relativedelta

from noob.utils.date import get_period_calendar
from noob.utils import forecast_preprocess as fp
from noob.forecasting.lltd import LLTD
from tests.forecasting.test_lltd import LLTDTest
from tests.forecasting.test_smartlag import SmartlagTest
from warp.spark.date import create_calendar_table
# -

# Preprocess

window_size = 7
hierarchy = ["product_id", "store_id"]

calendar = create_calendar_table(start_date="2019-01-01",
                                 end_date="2019-06-01",
                                 week_start_day=1)
calendar = calendar.withColumnRenamed("iso_8601_week", "week")
calendar, _ = get_period_calendar(calendar, period="day")
calendar.show(2)

SmartlagTest.create_test_data(spark)
enriched_df = SmartlagTest.enriched_df
enriched_df.show(2)
# +
# Optional scope filter:
# enriched_df = fp.scope_preprocess(enriched_df, scope_data, ["product_id", "store_id"])
# -

# MULTIPLIER PREPROCESS
enriched_df = fp.multiplier_preprocess(enriched_df)

enriched_df.show(2)

# SPLIT ENRICHED DATA
input_df, future_df = fp.split_train_test(
    input_data=enriched_df,
    forecast_start_date="2019-05-01",
    prediction_length=7,
    calendar_df=calendar,
    period="day",
    max_lookback=None,
)

input_df.show(1)

future_df.show(1)

future_df = fp.future_preprocess(future_df, hierarchy, 1)
future_df.show(1)

# STOCK OUT CORRECTION
input_df = fp.stock_out_preprocess(input_df, hierarchy)
input_df.show(1)
input_df.count()

# REMOVE OUTLIERS
input_df = fp.outlier_preprocess(input_df, "outlier", "0", hierarchy, window_size)
input_df.show(3)
input_df.count()

# ### Model
model = LLTD(["product_id", "store_id"])
base_fc = model.train(
    input_df,
    window_size=7,
    min_window_size=7,
    weight_mode="base")
preds = model.predict(base_fc, future_df)

# #### Weekly Product Level

window_size = 2
hierarchy = ["product_id"]

SmartlagTest.create_test_data(spark)
enriched_df = SmartlagTest.enriched_df
groupby_cols = ["product_id", "year", "week"]
enriched_df = enriched_df.groupBy(groupby_cols).agg(
    F.sum("sales_quantity").alias("sales_quantity"),
    F.sum("sales_revenue").alias("sales_revenue"),
    F.sum("replaced_sales").alias("replaced_sales"),
    F.sum("inventory").alias("inventory"),
    F.sum("outlier").alias("outlier"),
    F.count("store_id").alias("store_count"),
    F.sum("usable").alias("usable"))

calendar = create_calendar_table(start_date="2019-01-01",
                                 end_date="2019-06-01",
                                 week_start_day=1)
calendar = calendar.withColumnRenamed("iso_8601_week", "week").drop("week_index")
calendar, _ = get_period_calendar(calendar, period="week")

c = calendar.select(F.col("iso_8601_year").alias("year"),
                    F.col("week"),
                    F.col("week_index"),
                    F.col("week_start_date").alias("start_date"),
                    F.col("week_end_date").alias("end_date")).distinct()
enriched_df = enriched_df.join(c, on=["year", "week"])
# We will need a column named date
enriched_df = enriched_df.withColumn("date", F.col("start_date"))
# +
# Optional Scope Filter
# enriched_df = scope_preprocess(enriched_df, scope_data, ["product_id", "store_id"])
# -

# MULTIPLIER PREPROCESS
enriched_df = fp.multiplier_preprocess(enriched_df)

# SPLIT ENRICHED DATA
input_df, future_df = fp.split_train_test(
    input_data=enriched_df,
    forecast_start_date="2019-04-29",
    prediction_length=4,
    calendar_df=calendar,
    period="week",
    max_lookback=None,
)

future_df = future_df.select(
    *hierarchy,
    F.col("start_date").alias("forecast_start_date"),
    F.col("end_date").alias("forecast_end_date"),
    F.col("future_multiplier"),
    F.lit(1).alias("prediction_size"),
)

future_df.show(1)

# STOCK OUT CORRECTION
input_df = fp.stock_out_preprocess(input_df, hierarchy)

# REMOVE OUTLIERS
input_df = fp.outlier_preprocess(input_df, "outlier", "0", hierarchy, window_size)

# ### Model
model = LLTD(["product_id"])
base_fc = model.train(
    input_df,
    window_size=2,
    min_window_size=2,
    weight_mode="base")
preds = model.predict(base_fc, future_df)
base_fc.show(2)
preds.show(2)

# ## [RollingLLTD](api/noob.forecasting.rst#noob.forecasting.rolling_lltd.RollingLLTD)

from noob.forecasting.rolling_lltd import RollingLLTD, get_bootstrap_period

LLTDTest.create_test_data(spark)

train_data = LLTDTest.train.drop("future_multiplier")
scope = train_data.select("product_id", "store_id").distinct()
future = get_bootstrap_period("2019-04-04", "2019-04-11", spark)
future = future.crossJoin(scope)

model = RollingLLTD(
    aggregation_level=["product_id", "store_id"],
    time_index="day_index"
)
base_fc = model.train(
    input_data=train_data,
    window_size=7,
    min_window_size=7,
    max_lookback=10,
    future_data=future
)

base_fc.show(2)


preds = model.predict(
    trained_model=base_fc,
    future_data=future,
)

preds.show(2)

# ## [Ensemble Forecasting](api/noob.forecasting.ensemble.rst#noob.forecasting.ensemble.non_negative_least_squares.NonNegativeLeastSquares)

# +
from noob.forecasting.ensemble.non_negative_least_squares import NonNegativeLeastSquares
from tests.ensemble.NonNegativeLeastSquares.test_nnls import NNLSTest

ensemble = NonNegativeLeastSquares(
    group_by_cols=["attr_1", "attr_2"],
    forecast_cols=[
        "actual",
        "lltd_49_base",
        "lltd_28_exponential",
        "lltd_105_base",
    ],
    level=["date", "store_id", "product_id"],
)
NNLSTest.create_test_data(spark)
train = NNLSTest.train
predict = NNLSTest.test

train.show(2)
predict.show(2)

weight_frame = ensemble.train_weights(train, "lsq")
ensemble_forecasts = ensemble.assemble_models(predict, weight_frame)
weight_frame.show(2, truncate=False)
ensemble_forecasts.show(2, truncate=False)
# -

# ## [Smart Selection](api/noob.forecasting.smart_selection.rst#noob.forecasting.smart_selection.base.BaseSmartSelect)

# ### Smart Selection with time based features

# +
from noob.forecasting.smart_selection import BaseSmartSelect
from tests.smartselect.test_smartselect import SmartSelectTest

SmartSelectTest.create_test_data(spark)

error = SmartSelectTest.train
error_ = SmartSelectTest.test
smart = BaseSmartSelect(
    ["attr"])
error_ranked = smart.calculate_metrics(error,metric_scale=False)
smart_select_forecasts = smart.get_smart_selection(error_ranked, error_)
smart_select_forecasts.show(2)
# -

# #### Prepare train data

columns = [
    "store_id", "product_id", "date", "actual", "prediction",
    "forecast_horizon", "error", "abs_error", "model_id", "time_attr_0",
    "time_attr_1"
]
data = [
    (1, 1, "2020-02-01", 5, 5, 1, 0, 0, "a", 0, 0),
    (1, 1, "2020-02-01", 5, 0, 1, 5, 5, "b", 0, 0),
    (1, 1, "2020-02-01", 5, 2, 1, 3, 3, "c", 0, 0),
    (1, 1, "2020-02-01", 5, 2, 1, 3, 3, "d", 0, 0),

    (1, 1, "2020-02-02", 5, 0, 1, 5, 5, "a", 1, 0),
    (1, 1, "2020-02-02", 5, 5, 1, 0, 0, "b", 1, 0),
    (1, 1, "2020-02-02", 5, 4, 1, 1, 1, "c", 1, 0),
    (1, 1, "2020-02-02", 5, 3, 1, 2, 2, "d", 1, 0),

    (1, 1, "2020-02-03", 5, 3, 1, 2, 2, "a", 0, 1),
    (1, 1, "2020-02-03", 5, 2, 1, 3, 3, "b", 0, 1),
    (1, 1, "2020-02-03", 5, 5, 1, 0, 0, "c", 0, 1),
    (1, 1, "2020-02-03", 5, 1, 1, 4, 4, "d", 0, 1),

    (1, 1, "2020-02-04", 5, 2, 1, 3, 3, "a", 1, 1),
    (1, 1, "2020-02-04", 5, 0, 1, 5, 5, "b", 1, 1),
    (1, 1, "2020-02-04", 5, 3, 1, 2, 2, "c", 1, 1),
    (1, 1, "2020-02-04", 5, 5, 1, 0, 0, "d", 1, 1),
]
train = spark.createDataFrame(data, columns)
train.show()

# #### Prepare test data

columns = ["store_id", "product_id", "date", "prediction", "model_id",
           "time_attr_0", "time_attr_1"]
data = [
    (1, 1, "2020-02-09", 1, "a", 0, 0),
    (1, 1, "2020-02-09", 3, "b", 0, 0),
    (1, 1, "2020-02-09", 8, "c", 0, 0),
    (1, 1, "2020-02-09", 4, "d", 0, 0),

    (1, 1, "2020-02-10", 1, "a", 1, 0),
    (1, 1, "2020-02-10", 3, "b", 1, 0),
    (1, 1, "2020-02-10", 8, "c", 1, 0),
    (1, 1, "2020-02-10", 4, "d", 1, 0),

    (1, 1, "2020-02-11", 1, "a", 0, 1),
    (1, 1, "2020-02-11", 3, "b", 0, 1),
    (1, 1, "2020-02-11", 8, "c", 0, 1),
    (1, 1, "2020-02-11", 4, "d", 0, 1),

    (1, 1, "2020-02-12", 1, "a", 1, 1),
    (1, 1, "2020-02-12", 3, "b", 1, 1),
    (1, 1, "2020-02-12", 8, "c", 1, 1),
    (1, 1, "2020-02-12", 4, "d", 1, 1),
]
test = spark.createDataFrame(data, columns)
test.show()

# #### Create smart selection model

smart_selection = BaseSmartSelect(group_by_cols=["time_attr_0", "time_attr_1"])

# #### Find model ranks

metrics = smart_selection.calculate_metrics(train)
metrics.show()

# #### Make prediction

result = smart_selection.get_smart_selection(metrics, test)
result.show()

# ## [Smartlag](api/noob.forecasting.rst#noob.forecasting.smartlag.BaseSmartlag)

# #### Basic usage

# +
from noob.forecasting.smartlag import BaseSmartlag
from tests.forecasting.test_smartlag import SmartlagTest
SmartlagTest.create_test_data(spark)

enriched_df = SmartlagTest.enriched_df
calendar = SmartlagTest.calendar

smartlag = BaseSmartlag(
    enriched_df=enriched_df,
    calendar=calendar,
    forecast_start_date='2019-05-01',
    lookback_weeks=10,
    forecast_horizon_weeks=4,
    y_column='sales_quantity',
    fixed_features=["product_id", "store_id"],
    groupby_cols_list=[["product_id", "store_id",
                        "day_of_week"], ["product_id", "store_id"]],
    extra_output_columns=[])

smartlag.predict(window=7,
                 minimum_periods=3)
smartlag.result.cache()

smartlag.result.show(5)
# -

# #### Smartlag with lostsales

# +
from noob.forecasting.smartlag import BaseSmartlag

SmartlagTest.create_test_data(spark)

enriched_df = SmartlagTest.enriched_df
calendar = SmartlagTest.calendar
lostsales = SmartlagTest.lost_sales

smartlag = BaseSmartlag(
    enriched_df=enriched_df,
    calendar=calendar,
    forecast_start_date="2019-05-01",
    lookback_weeks=10,
    forecast_horizon_weeks=4,
    y_column="sales_quantity",
    fixed_features=["product_id", "store_id"],
    groupby_cols_list=[["product_id", "store_id", "day_of_week", "lostsales"],
                       ["product_id", "store_id"]],
    extra_output_columns=[])

smartlag.predict(window=7,
                 minimum_periods=3,
                 lost_sales_df=lostsales)
smartlag.result.cache()

smartlag.result.show(5)
# -

# ## [Rolling Smartlag](api/noob.forecasting.rst#noob.forecasting.rolling_smartlag.RollingSmartlag)

# ### Rolling Smartlag without lostsales

# +
from noob.utils.io import read_parquet
from noob.forecasting.rolling_smartlag import RollingSmartlag
from tests.forecasting.test_rolling_smartlag import RollingSmartlagTest

SmartlagTest.create_test_data(spark)
enriched_df = SmartlagTest.enriched_df
calendar = SmartlagTest.calendar

rolling_smartlag = RollingSmartlag(
    enriched_df=enriched_df,
    test_df=enriched_df,
    calendar=calendar,
    forecast_start_date="2019-05-01",
    lookback_weeks=5,
    forecast_horizon_weeks=2,
    y_column="sales_quantity",
    fixed_features=["product_id", "store_id"],
    groupby_cols_list=[["product_id", "store_id", "day_of_week"],
                       ["product_id", "store_id"],
                       ["product_id", "day_of_week"],
                       ["product_id"]],
    extra_output_columns=[])

rolling_smartlag.predict(window=7,
                 minimum_periods=3)
rolling_smartlag.result.show()
# -

# ## [Machine learning forecasting](api/noob.forecasting.ml_lib.rst#noob.forecasting.ml_lib.base.MachineLearningModel)

# +
# Trained Model

import shap
import matplotlib.pyplot as plt
import lightgbm as lgb
from noob.forecasting.ml_lib import MachineLearningModel
from tests.forecasting.test_ml_lib import MachineLearningModelTest
import pyspark.sql.functions as F

MachineLearningModelTest.create_test_data(spark)
train_data = MachineLearningModelTest.train_df
test_data = MachineLearningModelTest.test_df

# +
cols_to_use = [
    "product_id", "week", "day_of_week", "avg_sales_lag_1", "avg_sales_lag_2",
    "store_type"]

numerical_cols = [
    "product_id", "week", "day_of_week", "avg_sales_lag_1", "avg_sales_lag_2"]

categorical_cols = ["store_type"]

group_by_cols = ["forecast_group_id"]

lgbm = MachineLearningModel(
    model=lgb.LGBMRegressor(
        min_data=1,
        n_jobs=-1,
        seed=0,
        early_stopping_rounds=3,
        verbose=1))
params = {
    "categorical_cols": categorical_cols,
    "numerical_cols": numerical_cols,
    "cols_to_use":cols_to_use,
    "target_col": 'avg_sales',
    "group_by_cols": group_by_cols,
    "use_eval_set": 0.1,
    "fit_params": {
        "verbose": 0
    }
}

# +
# Train and predict separately

model = lgbm.train(train_data, **params)
prediction = lgbm.predict(model, test_data, **params)
# -

# #### Train and predict separately

model = lgbm.train(train_data, **params)
prediction = lgbm.predict(model, test_data, **params)

model.show()

prediction.show(2, truncate=False)

# #### Train and predict together

models, preds = lgbm.fit_predict(train_data, test_data, **params)
# -

models[0].show()

# #### Train with different model parameters for each group

params = {
    **params,
    "model_params_by_group": {
        (23,): {
            "early_stopping_rounds": 5,
        },
    }
}
model = lgbm.train(train_data, **params)

model.show()

# #### Print feature importances of the model

# +
from noob.utils.analysis import get_feature_importances

df = get_feature_importances(model_row=models[0].collect()[0])
df
# -

# #### Print descriptive tree plots and SHAP plots of the model

# +
from noob.utils.analysis import get_shap_plots

pdf, explainer, shap_values, force_plot, evals_result = get_shap_plots(
    model_row=models[0].collect()[0],
    test_data=test_data,
    start=0,
    end=5,
    plot_trees=[0, -10, -5, -1])

# +
# See the force plot of your choice by using returned explainer, shap_values and pdf

example_row = 0

shap.force_plot(
        explainer.expected_value, shap_values[example_row, :],
        pdf.iloc[example_row, :])

# +
# See the force plot returned by the function call

force_plot

# +
# See the loss plot by number of estimators

plt.plot(evals_result["valid_0"]["l2"])
# -

# #### Rolling ML

# +
params.update({"period_col": "week_index", "summary_output": ["store_type", "product_id", "week_index"]})

preds = lgbm.rolling_fit_predict(
    train_data=train_data,
    test_data=test_data,
    horizon=4,  # Generate forecasts for 4 periods in each iteration
    step=1,  # Generate forecasts in each period
    test_periods=None,  # None for generating forecasts for each period in the test data
    **params)
# -

preds.show(2, truncate=False)

# ## [Long term forecasting](api/noob.forecasting.long_term.rst#noob.forecasting.long_term.base.BaseLongTermForecaster)
#

# +
import lightgbm as lgb
from noob.forecasting.ml_lib import MachineLearningModel
from noob.forecasting.long_term import BaseLongTermForecaster
from tests.forecasting.test_long_term import BaseLongTermForecasterTest
from noob.forecasting.lltd import LLTD
from warp.spark.date import create_calendar_table
from noob.utils import forecast_preprocess as fp

import pyspark.sql.functions as F
# -

BaseLongTermForecasterTest.create_test_data(spark)
train_data = BaseLongTermForecasterTest.train_data
validation_data = BaseLongTermForecasterTest.validation_data
test_data =  BaseLongTermForecasterTest.test_data

train_data.show(1)
validation_data.show(1)
test_data.show(1)

data = train_data.unionByName(validation_data).unionByName(test_data)

calendar = create_calendar_table(week_start_day=1)
calendar = calendar.select(
    "date",
    F.col("week_start_date").alias("start_date"),
    F.col("week_end_date").alias("end_date"))

data = data.join(calendar, on=['date'], how='left')
data = data.withColumn('sales_quantity', F.col('avg_sales'))
data = data.withColumn('inventory', F.lit(0))

data = fp.multiplier_preprocess(data)
data = fp.stock_out_preprocess(data, ['product_id', 'store_type'])
data = fp.outlier_preprocess(data, 'outlier', '0', ['product_id', 'store_type'], 2)
data = data.withColumn('prediction_size', F.lit(1))
data = data.withColumn('forecast_start_date', F.col('start_date'))
data = data.withColumn('forecast_end_date', F.col('end_date'))

train_data = data.filter(F.col("date") < "2020-02-24")
validation_data = data.filter(
    (F.col("date") >= "2020-02-24") &
    (F.col("date") < "2020-03-16"))
test_data = data.filter(F.col("date") >= "2020-03-16")

train_data.show(1)

test_data.show(1)

# +
lgbm = MachineLearningModel(
    model=lgb.LGBMRegressor(
        n_jobs=-1,
        num_leaves=127,
        learning_rate=0.05,
        n_estimators=10,
        min_child_samples=1000,
        feature_fraction=0.8,
        bagging_fraction=0.8,
        early_stopping_rounds=3,
        verbose=0
    )
)

lltd = LLTD(['product_id', 'store_type'])
# -

lltd_params = {
    "window_size": 1,
    "min_window_size": 1,
    "weight_mode": 'base',
    "summary_output": ['avg_sales', 'week_index'],
    "replace_sales_quantity": True
}

ml_params = {
    "categorical_cols": ['store_type'],
    "numerical_cols": [
        "product_id", "week", "day_of_week", "avg_sales_lag_1",
        "avg_sales_lag_2"],
    "cols_to_use": [
        "product_id", "week", "day_of_week", "avg_sales_lag_1",
        "avg_sales_lag_2", "store_type"],
    "target_col": 'avg_sales',
    "group_by_cols": ['forecast_group_id'],
    "use_eval_set": 0.5,
    "cols_to_pop": [
        ['avg_sales_lag_1'],
        ['avg_sales_lag_2']],
    "run_for": [12, 13, 14],
    "period_col": 'week_index',
    "summary_output": ['product_id', 'store_type', 'week_index'],
    "fit_params": {
        "verbose": 0
    }
}

models_params = [
    (lltd, lltd_params),
    (lgbm, ml_params)
]

forecaster = BaseLongTermForecaster(
    models_params=models_params,
    target_col='avg_sales',
    eval_based_on=['product_id', 'store_type', 'week_index'],
    base_forecast_id='LGBMRegressor',
    period_col='week_index')

models, predictions = forecaster.run(train_data, validation_data, test_data, 4)

predictions.groupby('model_id').count().show()


# ## [Prophet forecasting](api/noob.forecasting.long_term.rst#noob.forecasting.prophet.base.ProphetForecastSpark)

# +
from noob.forecasting.prophet import ProphetForecast, ProphetForecastSpark
from tests.forecasting.test_prophet import (
    create_train_test_data as create_train_test_data_prophet)

# ### train, test inputs

# +
train_data, test_data = create_train_test_data_prophet(spark)

train_data.orderBy("ds").show(2)
test_data.orderBy("ds", ascending=False).show(2)

# ### create internal prophet model

# +
prophet = ProphetForecast(
    model_params=None, # use defaults
    period="day",
    coef_columns=["yearly"],
    normalize_coef_columns=True
)

# ### create prophet spark wrapper

# +
prophet_model = ProphetForecastSpark(prophet=prophet)

# ### train & test

# +
prophet_args = {
    "group_by_cols": ["product_id"],
    "cols_to_use": ["ds", "y"],
    "target_col": "y",
    "prediction_col": "yhat",
    "additional_predictions": ["yearly"],
    "remove_null_targets": True
}
models, predictions = prophet_model.fit_predict(
    train_data, test_data, **prophet_args)

# ### trained models

# +
models.cache()
models.show()

# ### predictions

# +
predictions.show(2)

# ### fit-predict straight (if your dataset contains both train & test records and you want to obtain predictions for training set as well)

# +
_, input_data = create_train_test_data_prophet(spark)

prophet_args = {
    "group_by_cols": ["product_id"],
    "cols_to_use": ["ds", "y"],
    "target_col": "y",
    "prediction_col": "yhat",
    "additional_predictions": ["yearly"],
    "remove_null_targets": True
}
predictions = prophet_model.fit_predict_straight(input_data, **prophet_args)

# ### predictions

# +
predictions.show(2)


# ## [Store / Product Share Calculation](api/noob.postprocessing.breakdown.rst#noob.postprocessing.breakdown.base.BaseShareCalculation)

from noob.postprocessing.breakdown import BaseShareCalculation
from noob.utils.breakdown import clean_inactives, clean_demand, calculate_time_shares
import pyspark.sql.functions as F
from noob.utils.date import get_period_calendar
from tests.postprocessing.test_share_calculation import ShareCalculationTest as tables
tables.create_test_data(spark=spark, calendar=True)

# ### Calculate store shares

# +
calendar = tables.calendar
full_data = tables.data.select("store_id", "product_id", "week_index", "sales_quantity")
full_data = full_data.withColumn("hierarchy5", F.lit(1))

forecast = full_data.filter(F.col("week_index") > 494)
data = full_data.filter(F.col("week_index") <= 494)
forecast = forecast.groupby("product_id", "week_index", "hierarchy5").agg(F.sum(F.col("sales_quantity")).alias("prediction"))
# -

forecast.select("week_index").drop_duplicates().show()

c, _ = get_period_calendar(calendar, "week")
c.filter(F.col("date") == "2019-06-26").show()

target_date = "2019-06-26"
period = "week"  # data is a weekly demand data (a property of the dataframe)
share_level = "store_id"  # store share calculation
forecast_level = ["product_id"]  # store shares are calculated for product_id level
min_threshold = 10  # If a product is not sold at more than 10 stores, filter it out
min_reliability_ratio = 0.3  # A product must at least 30% present, otherwise filter it out
quantile_threshold = 0.97  # Demand smoothing first threshold
quantile_reliable_threshold = 0.8  # Demand smoothing second threshold
stdev_multip = 1.5  # Demand smoothing outlier std. dev. multiplier
demand_col = "sales_quantity"  # Enter the column name of demand in the data

share_calculator = BaseShareCalculation(
    target_date=target_date,
    calendar=calendar,
    period=period,
    share_level=share_level,
    forecast_level=forecast_level,
    demand_col=demand_col
)

common_offsets = [(53, 1), (5, 1) ]  # Last year & last 4 weeks are calculated without any changes for all the forecast horizon
sliding_offsets = [(53, 49)]  # Last year next 4 weeks, slide with the forecast horizon

shares = share_calculator.calculate_level_shares(
    data=full_data,
    common_offsets=common_offsets,
    sliding_offsets=sliding_offsets
)

shares.orderBy("week_index", "store_id", "product_id").show()

# ### Example usage with an aggregated forecast

broken_fc = forecast.join(shares, on=["product_id", "week_index"], how="left")

broken_fc.show()

broken_fc = broken_fc.withColumn("prediction", F.col("sales_quantity") * F.col("store_id_share_0"))

broken_fc.show(2)

# ### Calculate time shares

# + load example dataframes
from tests.utils.test_breakdown import BreakdownUtilsTest as tables
tables.create_test_data(spark)
# -

tables.daily_data.printSchema()

shares = calculate_time_shares(
    data=tables.daily_data,
    calendar=tables.calendar,
    period="day",
    wide_period="week",
    hierarchy=["store_id", "product_id"],
    demand_col="sales_quantity"
)
shares.show(2)

"""
full_data = clean_inactives(
    data=full_data,
    period=period,
    share_level=share_level,
    forecast_level=forecast_level,
    min_threshold=min_threshold,
    min_reliability_ratio=min_reliability_ratio)
full_data = clean_demand(
    main_data=full_data,
    share_level=share_level,
    forecast_level=forecast_level,
    cleaning_level=["hierarchy5"],
    quantile_threshold=quantile_threshold,
    quantile_reliable_threshold=quantile_reliable_threshold,
    stdev_multip=stdev_multip,
    period=period,
    demand_col=demand_col)
"""

# ## [Seasonality Calculation](api/noob.seasonality.rst#noob.seasonality.base.Seasonality)

# ### Coefficients calculation
# Data Preparation

from noob.seasonality import Seasonality
from pyspark.sql.types import FloatType
from tests.seasonality.test_seasonality import SeasonalityTest as tables
import os
import tempfile

tables.create_test_data(spark)
sales_df = tables.sales_df
holidays = tables.holidays
calendar = tables.calendar
aggregation_levels = tables.aggregation_levels
promo = tables.promo
join_cols = ['product_id']

# Initialization

seasonality = Seasonality(sales_df=sales_df,
                          aggregation_levels=aggregation_levels,
                          join_cols=join_cols,
                          calendar=calendar)

# no holiday, no muslim effect, single method, 20 week future coefs

coefs, _ = seasonality.calculate_coef(methods=['prophet'],
                                      seasonality_period=20)

coefs.show(3)

# Multiple methods

coefs, _ = seasonality.calculate_coef(methods=['prophet', 'dummy'])

coefs.show(2)

# with muslim effect (method defaults to prophet)

coefs, _ = seasonality.calculate_coef(muslim=True)

coefs.show(2)

# with holiday effects

coefs, effects = seasonality.calculate_coef(df_extra=[holidays, promo])

# holiday and muslim effects, (effects will be a tuple)
# all model - holiday - hijri effect selection combinations are possible

coefs, effects = seasonality.calculate_coef(
    df_extra=[holidays, promo],
    muslim=True)

# combined
coefs, effects = seasonality.calculate_coef(
    methods=[{'method': 'prophet',
              'name': 'prophet0', # name key defaults to method value
              'params': {}}, # params defaults to empty dict
             {'method': 'prophet',
              'name': 'prophet_1',
              'params': {'yearly_seasonality': 5}},
             {'method': 'prophet',
              'name': 'prophet_2',
              'params': {'yearly_seasonality': 5, 'mcmc_samples': 0}},
             {'method': 'dummy'}],
    df_extra=[holidays],
    holiday_params={'yearly_seasonality': 10,
                    'seasonality_prior_scale': 1,
                    'holidays_prior_scale': 10},
    muslim=True,
    muslim_params={'yearly_seasonality': 10,
                   'seasonality_prior_scale': 1},
    seasonality_period=20)

coefs.cache()
coefs.show(5)

# ### Smart selection

final_coef, lltd_output, lltd_summary, agg_level_summary = \
    seasonality.smart_selection(coefs)
final_coef.cache()

agg_level_summary.show()

final_coef.show(2)

lltd_output.show(1)

lltd_summary.show()

# ### Running seasonality calculations with updated holidays
with tempfile.TemporaryDirectory() as output_path:
    join_cols = ['product_id']
    # calculating seasonality coefficients in a normal fashion
    seasonality = Seasonality(sales_df=sales_df,
                            aggregation_levels=aggregation_levels,
                            join_cols=join_cols,
                            calendar=calendar)
    coefs, effects = seasonality.calculate_coef(df_extra=[holidays, promo])
    
    coefs.write.parquet(os.path.join(output_path, "coefs"))
    coefs = spark.read.parquet(os.path.join(output_path, "coefs"))

    effects.write.parquet(os.path.join(output_path, "effects"))
    effects = spark.read.parquet(os.path.join(output_path, "effects"))
    
    final_coef, lltd_output, lltd_summary, agg_level_summary = \
        seasonality.smart_selection(coefs)

    final_coef.write.parquet(os.path.join(output_path, "final_coef"))
    final_coef = spark.read.parquet(os.path.join(output_path, "final_coef"))
    
    lltd_summary.write.parquet(os.path.join(output_path, "lltd_summary"))
    lltd_summary = spark.read.parquet(os.path.join(output_path, "lltd_summary"))
    
    agg_level_summary.write.parquet(os.path.join(output_path, "agg_level_summary"))
    agg_level_summary = spark.read.parquet(os.path.join(output_path, "agg_level_summary"))

    # assumption is that user stores coefs, effects, lltd_summary and
    # agg_level_summary dataframes since these will be needed to update final
    # coefficients with new holidays

    new_promo = promo.unionByName(
        spark.createDataFrame(
            [(2020, 7, 2, "low_promo")],
            ["year", "week", "product_id", "promo"]))
    # let's assume above holidays are recently announced and we want to update
    # our final predictions accordingly

    new_final_coef = seasonality.update_coef(
        coefs=coefs,
        effects=effects,
        df_extra=[holidays, new_promo],
        lltd_summary=lltd_summary,
        agg_level_summary=agg_level_summary)
    new_final_coef.cache()

    # to see whether it worked or not
    filter_cond = "product_id == 2 and year == 2020 and week == 7"

    # observe that for this product_id, AggLevel2=1 is satisfactory agg_level
    agg_level_summary.filter('product_id == 2').show()
    # observe promo and holiday effects
    effects.filter('target_column = 1 and level = "AggLevel2"').show()

    final_coef.filter(filter_cond).show()
    new_final_coef.filter(filter_cond).show()

    # ### I/O Control

    summary_prev = seasonality.generate_summary(final_coef)
    summary_prev.cache()
    summary = seasonality.generate_summary(new_final_coef)
    summary.cache()

    seasonality.control(summary_prev, summary)

# ## [Seasonality Calculation v2](api/noob.seasonality.rst#noob.seasonality.seasonality.Seasonality)

# ### Coefficients calculation

# #### Data Preparation

from noob.seasonality.seasonality import Seasonality as SeasonalityV2
from tests.seasonality.test_seasonality_new import SeasonalityTest as tables
import os
import tempfile

tables.create_test_data(spark)
sales_df = tables.sales_df
holidays = tables.holidays
agg_holidays = tables.agg_holidays
calendar = tables.calendar
aggregation_levels = tables.aggregation_levels
join_cols = ['product_id']
holiday_names = ["NewYearsDay", "high_promo", "low_promo"]

# Initialization
# +

seasonality = SeasonalityV2(
    join_cols=join_cols,
    aggregation_levels=aggregation_levels,
    calendar=calendar,
    target_date=tables.target_date, # last positive sales date
    period="week",
)

# no holiday, no hijri effect, single method, 20 week future coefs
# +

coefs, _ = seasonality.calculate_coef(
    sales_df,
    methods=['prophet'],
    seasonality_period=20
)

coefs.show(3)

# Multiple methods
# +

coefs, _ = seasonality.calculate_coef(sales_df, methods=['prophet', 'dummy'])

coefs.show(2)

# with hijri effect (method defaults to prophet)
# +

coefs, _ = seasonality.calculate_coef(sales_df, hijri=True)

coefs.show(2)

# with holiday effects
# +

coefs, effects = seasonality.calculate_coef(
    sales_df, agg_holidays=agg_holidays, holiday_names=holiday_names)

coefs.show(2)

# holiday and hijri effects
# +

coefs, effects = seasonality.calculate_coef(
    sales_df,
    agg_holidays=agg_holidays,
    holiday_names=holiday_names,
    hijri=True
)

effects.show()

# granular level effect cleaning - re-calculate agggregation level - notice the product_id in effects
# +

coefs, effects = seasonality.calculate_coef(
    sales_df,
    holidays=holidays,
    agg_holidays=agg_holidays,
    holiday_names=holiday_names,
    normalize_before_agg=True
)

coefs.show(2)

effects.show()

# granular level effect cleaning - aggregate calculated effects to aggregation levels
# +

coefs, effects = seasonality.calculate_coef(
    sales_df,
    holidays=holidays,
    agg_holidays=agg_holidays,
    holiday_names=holiday_names,
    normalize_before_agg=True,
    aggregate_effects=True
)

coefs.show(2)

effects.show()

# multiple methods
# +

methods = [
    {
        "method": "prophet",
        "name": "prophet_0",
    },
    {
        "method": "prophet",
        "name": "prophet_1",
        "params": {
            **SeasonalityV2.default_prophet_params,
            "yearly_seasonality": 5
        }
    },
    {"method": "dummy"}
]

coefs, effects = seasonality.calculate_coef(
    sales_df,
    methods=methods,
    holidays=holidays,
    agg_holidays=agg_holidays,
    holiday_names=holiday_names,
    normalize_before_agg=True,
    seasonality_period=20)

coefs.cache()
effects.cache()

coefs.show(5)

# ### Smart selection

# #### Smart Selection in aggregated mode

final_coef, lltd_output, lltd_summary, agg_level_summary = \
    seasonality.smart_selection(sales_df, coefs, mode="aggregated")
final_coef.cache()

agg_level_summary.show()

final_coef.show(2)

lltd_output.show(1)

lltd_summary.show()

# #### Smart Selection in individual mode

final_coef, lltd_output, lltd_summary, agg_level_summary = \
    seasonality.smart_selection(sales_df, coefs, mode="individual")
final_coef.cache()

agg_level_summary.show()

final_coef.show(2)

lltd_output.show(1)

lltd_summary.show()

# ### Seasonality predict mode with new items & updated holidays

# +
new_aggregation_levels = aggregation_levels.unionByName(
    spark.createDataFrame(
        [
            (3, "h4_3", "h3"),
            (4, "h4_2", "h3")
        ],
        ["product_id", "AggLevel1", "AggLevel2"])
)
new_holidays = agg_holidays.filter(
    (F.col("date") != "2019-07-08") &
    (F.col("holiday") != "low_promo")
).unionByName(spark.createDataFrame(
    [("2019-07-08", 2019, 28, "h4_1", "agg_level_1", "high_promo")],
    ["date", "year", "week", "target_column", "level", "holiday"])
)
holiday_diff = new_holidays


with tempfile.TemporaryDirectory() as output_path:
    coefs.write.parquet(os.path.join(output_path, "coefs"))
    coefs = spark.read.parquet(os.path.join(output_path, "coefs"))

    effects.write.parquet(os.path.join(output_path, "effects"))
    effects = spark.read.parquet(os.path.join(output_path, "effects"))

    final_coef, lltd_output, lltd_summary, agg_level_summary = \
        seasonality.smart_selection(sales_df, coefs, mode="individual")

    final_coef.write.parquet(os.path.join(output_path, "final_coef"))
    final_coef = spark.read.parquet(
        os.path.join(output_path, "final_coef"))

    agg_level_summary.write.parquet(
        os.path.join(output_path, "agg_level_summary"))
    agg_level_summary = spark.read.parquet(
        os.path.join(output_path, "agg_level_summary"))

    lltd_summary.write.parquet(os.path.join(output_path, "lltd_summary"))
    lltd_summary = spark.read.parquet(
        os.path.join(output_path, "lltd_summary"))

    seasonality = SeasonalityV2(
        join_cols=join_cols,
        aggregation_levels=new_aggregation_levels,
        calendar=calendar,
        target_date=tables.target_date,
        period="week",
    )

    predicted_coefs = seasonality.predict(
        final_coef,
        agg_level_summary,
        coefs,
        lltd_summary,
        effects,
        holiday_diff,
        holiday_names
    )

    predicted_coefs.show(2)

# ### I/O Control
# +

seasonality = SeasonalityV2(
    join_cols=join_cols,
    aggregation_levels=aggregation_levels,
    calendar=calendar,
    target_date=tables.target_date, # last positive sales date
    period="week",
)

coefs, effects = seasonality.calculate_coef(sales_df)
coefs.cache()

final_coef, lltd_output, lltd_summary, agg_level_summary = \
    seasonality.smart_selection(sales_df, coefs, mode="aggregated")
final_coef.cache()
new_final_coef = final_coef

summary_prev = seasonality.generate_summary(final_coef)
summary_prev.cache()
summary = seasonality.generate_summary(new_final_coef)
summary.cache()

seasonality.control(summary_prev, summary)


# ## Attribute Generation

import pyspark.sql.functions as F
from pyspark.sql import SparkSession
import numpy as np
spark = SparkSession.builder.getOrCreate()

from noob.preprocessing.attribute_generation import PromoAttributeGenerator

# #### [PromoAttributeGenerator](api/noob.preprocessing.attribute_generation.rst#noob.preprocessing.attribute_generation.promo.PromoAttributeGenerator)
# Calculates promo attributes.
# Below generate dummy sales, lostsales and seasonality dataframes.

data = spark.createDataFrame(
    [
        (1, 0, 10, 1,), (1, 0, 15, 2,), (1, 1, 30, 3,),
        (1, 0, 5, 4,), (1, 0, 10, 5,), (1, 0, 5, 6,),
        (1, 1, 40, 7,), (1, 1, 30, 8,), (1, 0, 10, 9,),
        (1, 0, 5, 10,), (2, 1, 20, 1,), (2, 1, 15, 2,),
        (2, 1, 10, 3,), (2, 1, 15, 4,), (2, 0, 10, 5,),
        (2, 0, 5, 6,), (2, 0, 10, 7,), (2, 0, 5, 8,),
        (2, 0, 10, 9,), (2, 0, 5, 10,),
    ],
    ["product_id", "is_promo", "sales_quantity", "week_index"]
)
lostsales_df = spark.createDataFrame(
    [
        (1, 10, 1,), (1, 15, 2,), (1, 30, 3,),
        (1, 5, 4,), (1, 10, 5,), (1, 5, 6,),
        (1, 40, 7,), (1, 30, 8,), (1, 10, 9,),
        (1, 5, 10,), (2, 20, 1,), (2, 15, 2,),
        (2, 10, 3,), (2, 15, 4,), (2, 10, 5,),
        (2, 5, 6,), (2, 10, 7,), (2, 5, 8,),
        (2, 10, 9,), (2, 5, 10,),
    ],
    ["product_id", "lostsales", "week_index"]
)
seasonality_df = spark.createDataFrame(
    [
        (1, 1.0, 1,), (1, 1.5, 2,), (1, 1.7, 3,),
        (1, 1.5, 4,), (1, 1.0, 5,), (1, 0.8, 6,),
        (1, 0.6, 7,), (1, 0.8, 8,), (1, 1.0, 9,),
        (1, 1.0, 10,), (2, 1.5, 1,), (2, 1.7, 2,),
        (2, 1.5, 3,), (2, 1.0, 4,), (2, 0.8, 5,),
        (2, 0.6, 6,), (2, 0.8, 7,), (2, 1.0, 8,),
        (2, 1.2, 9,), (2, 1.0, 10,),
    ],
    ["product_id", "multip", "week_index"]
)

data.show(5)

lostsales_df.show(5)

seasonality_df.show(5)

pag = PromoAttributeGenerator(level=["product_id"], period="week")

data = pag.preprocess(data, lostsales_df, seasonality_df)

# Preprocessing joins lostsales data and calculates demand by sales quantity and demand columns.
# Then, seasonality is joined and demand is divided by multip column to get normalized demand

data.show(5)

result = pag.generate_promo_frequency(
    data=data,
    promo_col="is_promo",
    demand_col="normal_demand",
    last_n_periods=5,
    cut_borders=(0, 0.10, 0.40, 1),
    cut_labels=("low", "medium", "high")
)

result.show()

result = pag.generate_promo_sensitivity(
    data=data,
    promo_col="is_promo",
    demand_col="normal_demand",
    last_n_periods=5,
    cut_borders=(0, 1.5, 3, np.inf),
    cut_labels=("low", "medium", "high")
)

result.show()


import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from warp.spark.date import create_calendar_table
import numpy as np
spark = SparkSession.builder.getOrCreate()

# #### [SeasonalAttributeGenerator](api/noob.preprocessing.attribute_generation.rst#noob.preprocessing.attribute_generation.seasonal.SeasonalAttributeGenerator)
# Calculates seasonal attributes.
#
# Let's create a dummy seasonality dataframe first.
# All multipliers are 0.5 when week < 45; otherwise 10.0

data = [(1, 2020, w, 0.5) if w < 45 else
        (1, 2020, w, 10.0) for w in range(1, 54)]
columns = ["product_id", "year", "week", "multiplier"]
seasonality = spark.createDataFrame(data, columns)
seasonality.orderBy(F.desc("week")).show(12)
seasonality.cache()

# Create a SeasonalAttributeGenerator instance with year = 2020 and smoothing_span = 3.

from noob.preprocessing.attribute_generation import SeasonalAttributeGenerator
calendar = create_calendar_table().select(
    F.col("date").cast("string"),
    F.col("iso_8601_year").cast("string"),
    F.col("iso_8601_week").alias("week").cast("int"),
    F.col("month").cast("string"),
    F.col("day_of_week").cast("string"),
    F.col("iso_8601_year").alias("year").cast("int"),
    F.col("week_start_date").cast("string"),
    F.col("week_end_date").cast("string"),
)
seas = SeasonalAttributeGenerator(
    calendar,
    year=2020,
    smoothing_span=0,
    labels=("not_seasonal", "seasonal", "highly_seasonal"),
    borders=(0, 0.25, 0.75, np.inf)
)

#
# Let's generate the attributes

results = seas.generate(seasonality)

# * We see that this is a highly seasonal product and it has only a single season.
#
# * The season starts at week 45 and ends at week 1.
#
# * `intensity` column is created which indicates how strong the seasonality of a product is.
#
# Products are labeled as according to this `intensity` column. Since the `intensity` 1.70 is greater than
# 0.75, this product is classified as `highly_seasonal`

results.show()

# #### Implement your own intensity calculation
#  By default, `intensity` is calculated as:
#
#  $\ std(multiplier) / mean(multiplier)$
#
#  You can write you own `intensity` function.
#  Let's say that we want to define a new `intensity` function as follows.
#
#  $\ max(multiplier)$
#
#  You can simply override the `get_seasonal_intensity_method` and return the seasonality
#  dataframe with the `intensity` column.
#
#  Higher `intensity` should indicate higher seasonality. Feel free to experiment with
#  more sophisticated `intensity` definitions.
#

# +
from pyspark.sql import Window

class MyClass(SeasonalAttributeGenerator):
    def get_seasonal_intensity(self, seasonality):
        window = Window.partitionBy("product_id")
        seasonality = seasonality.withColumn(
            "intensity", F.max("multiplier").over(window))
        return seasonality


# -

# We can now see that `intensity` is 10.0. Again, this product is
# classified as `highly_seasonal` since 10.0 > 0.75

my_attr_generator = MyClass(
    calendar=calendar,
    year=2020,
    smoothing_span=0,
    labels=("not_seasonal", "seasonal", "highly_seasonal"),
    borders=(0, 0.25, 0.75, np.inf)
)
my_attr_generator.generate(seasonality).show()

# To display the time dependent features in "month-date" format, just pass display_date=True.

from warp.spark.date import create_calendar_table
calendar = create_calendar_table()
results = seas.generate(seasonality, display_date=True)

results.show()

# #### [ProductLifeAttributeGenerator](api/noob.preprocessing.attribute_generation.rst#noob.preprocessing.attribute_generation.product_life.ProductLifeAttributeGenerator)


from noob.preprocessing.attribute_generation import (
    ProductLifeAttributeGenerator)
from noob.utils.date import get_period_calendar
from tests.preprocessing.test_attribute_generator import AttributeGeneratorTest

AttributeGeneratorTest.create_test_data(spark)

daily_data = AttributeGeneratorTest.data
calendar, _ = get_period_calendar(AttributeGeneratorTest.calendar, "day")
plag = ProductLifeAttributeGenerator(target_date="2019-01-01",
                                        level=["product_id"],
                                        calendar=calendar,
                                        newness_threshold_days=300,
                                        border=2,
                                        last_n_days=50)
result = plag.generate_product_life_status(daily_data)
result.show()


In [13]:
error.show()

+--------+----------+----------+------+----------+----------------+-----+---------+--------+----+
|store_id|product_id|      date|actual|prediction|forecast_horizon|error|abs_error|model_id|attr|
+--------+----------+----------+------+----------+----------------+-----+---------+--------+----+
|       1|         1|2020-02-02|     5|         2|               1|   -3|        3|       a|   1|
|       1|         3|2020-02-02|     5|         6|               1|    1|        1|       a|   1|
|       1|         1|2020-02-09|     5|         3|               1|   -2|        2|       a|   1|
|       1|         3|2020-02-09|     5|         4|               1|   -1|        1|       a|   1|
|       1|         1|2020-02-02|     5|         3|               1|   -2|        2|       b|   1|
|       1|         2|2020-02-02|     5|         5|               1|    0|        0|       b|   1|
|       1|         3|2020-02-02|     5|         4|               1|   -1|        1|       b|   1|
|       1|         1

In [2]:
from noob.forecasting.smart_selection import BaseSmartSelect
from noob.tests.smartselect.test_smartselect import SmartSelectTest

SmartSelectTest.create_test_data(spark)

error = SmartSelectTest.train
error_ = SmartSelectTest.test
smart = BaseSmartSelect(
    ["attr"])
error_ranked = smart.calculate_metrics(error,metric_scale=False)
smart_select_forecasts = smart.get_smart_selection(error_ranked, error_)
smart_select_forecasts.show(2)


ModuleNotFoundError: No module named 'noob.tests'

In [5]:
import pyspark.sql.functions as F

In [6]:
columns = [
    "store_id",
    "product_id",
    "date",
    "actual",
    "prediction",
    "forecast_horizon",
    "error",
    "abs_error",
    "model_id",
    "attr",
]
data = [
    (1, 1, "2020-02-02", 5, 2, 1, -3, 3, "a", 1),
    (1, 3, "2020-02-02", 5, 6, 1, 1, 1, "a", 1),
    (1, 1, "2020-02-09", 5, 3, 1, -2, 2, "a", 1),
    (1, 3, "2020-02-09", 5, 4, 1, -1, 1, "a", 1),

    (1, 1, "2020-02-02", 5, 3, 1, -2, 2, "b", 1),
    (1, 2, "2020-02-02", 5, 5, 1, 0, 0, "b", 1),

    (1, 3, "2020-02-02", 5, 4, 1, -1, 1, "b", 1),
    (1, 1, "2020-02-03", 5, 3, 1, -2, 2, "b", 1),
    (1, 3, "2020-02-03", 5, 2, 1, -3, 3, "b", 1),
    (1, 1, "2020-02-04", 5, 3, 1, -2, 2, "b", 1),
    (1, 3, "2020-02-04", 5, 2, 1, -3, 3, "b", 1),
    (1, 1, "2020-02-05", 5, 3, 1, -2, 2, "b", 1),
    (1, 3, "2020-02-05", 5, 2, 1, -3, 3, "b", 1),
    (1, 1, "2020-02-06", 5, 3, 1, -2, 2, "b", 1),
    (1, 3, "2020-02-06", 5, 0, 1, -5, 5, "b", 1),
    (1, 1, "2020-02-09", 5, 8, 1, 3, 3, "b", 1),
    (1, 3, "2020-02-09", 5, 3, 1, -2, 2, "b", 1),

    (1, 1, "2020-02-02", 5, 5, 1, 0, 0, "c", 1),
    (1, 2, "2020-02-02", 5, 5, 1, 0, 0, "c", 1),
    (1, 3, "2020-02-02", 5, 5, 1, 0, 0, "c", 1),
    (1, 4, "2020-02-02", 5, 5, 1, 0, 0, "c", 1),
    (1, 5, "2020-02-02", 5, 5, 1, 0, 0, "c", 1),
    (1, 6, "2020-02-02", 5, 5, 1, 0, 0, "c", 1),
    (1, 7, "2020-02-02", 5, 5, 1, 0, 0, "c", 1),
    (1, 8, "2020-02-02", 5, 5, 1, 0, 0, "c", 1),
    (1, 9, "2020-02-02", 5, 5, 1, 0, 0, "c", 1),
    (1, 10, "2020-02-02", 5, 5, 1, 0, 0, "c", 1),

    (1, 1, "2020-02-02", 5, None, 1, None, None, "d", 1),
    (1, 1, "2020-02-03", 5, None, 1, None, None, "d", 1),
    (1, 1, "2020-02-09", 5, None, 1, None, None, "d", 1),
]
train = F.broadcast(spark.createDataFrame(data, columns))
train.cache()
columns = [
    "store_id",
    "product_id",
    "date",
    "prediction",
    "model_id",
    "attr",
]
data = [
    (1, 1, "2020-02-09", None, "a", 1),
    (1, 1, "2020-02-09", None, "b", 1),
    (1, 1, "2020-02-09", 8, "c", 1),
    (1, 1, "2020-02-09", 4, "d", 1),
    (1, 2, "2020-02-09", 2, "d", 1),
    (1, 2, "2020-02-09", 3, "a", 1),
    (1, 3, "2020-02-09", 6, "a", 1),
    (1, 3, "2020-02-09", None, "b", 1),
    (1, 3, "2020-02-09", 7, "c", 1),
]
test = F.broadcast(spark.createDataFrame(data, columns))
test.cache()

DataFrame[store_id: bigint, product_id: bigint, date: string, prediction: bigint, model_id: string, attr: bigint]

In [14]:
columns = [
    "store_id", "product_id", "date", "actual", "prediction",
    "forecast_horizon", "error", "abs_error", "model_id", "time_attr_0",
    "time_attr_1"
]
data = [
    (1, 1, "2020-02-01", 5, 5, 1, 0, 0, "a", 0, 0),
    (1, 1, "2020-02-01", 5, 0, 1, 5, 5, "b", 0, 0),
    (1, 1, "2020-02-01", 5, 2, 1, 3, 3, "c", 0, 0),
    (1, 1, "2020-02-01", 5, 2, 1, 3, 3, "d", 0, 0),

    (1, 1, "2020-02-02", 5, 0, 1, 5, 5, "a", 1, 0),
    (1, 1, "2020-02-02", 5, 5, 1, 0, 0, "b", 1, 0),
    (1, 1, "2020-02-02", 5, 4, 1, 1, 1, "c", 1, 0),
    (1, 1, "2020-02-02", 5, 3, 1, 2, 2, "d", 1, 0),

    (1, 1, "2020-02-03", 5, 3, 1, 2, 2, "a", 0, 1),
    (1, 1, "2020-02-03", 5, 2, 1, 3, 3, "b", 0, 1),
    (1, 1, "2020-02-03", 5, 5, 1, 0, 0, "c", 0, 1),
    (1, 1, "2020-02-03", 5, 1, 1, 4, 4, "d", 0, 1),

    (1, 1, "2020-02-04", 5, 2, 1, 3, 3, "a", 1, 1),
    (1, 1, "2020-02-04", 5, 0, 1, 5, 5, "b", 1, 1),
    (1, 1, "2020-02-04", 5, 3, 1, 2, 2, "c", 1, 1),
    (1, 1, "2020-02-04", 5, 5, 1, 0, 0, "d", 1, 1),
]
train = spark.createDataFrame(data, columns)
train.show()

# #### Prepare test data

columns = ["store_id", "product_id", "date", "prediction", "model_id",
           "time_attr_0", "time_attr_1"]
data = [
    (1, 1, "2020-02-09", 1, "a", 0, 0),
    (1, 1, "2020-02-09", 3, "b", 0, 0),
    (1, 1, "2020-02-09", 8, "c", 0, 0),
    (1, 1, "2020-02-09", 4, "d", 0, 0),

    (1, 1, "2020-02-10", 1, "a", 1, 0),
    (1, 1, "2020-02-10", 3, "b", 1, 0),
    (1, 1, "2020-02-10", 8, "c", 1, 0),
    (1, 1, "2020-02-10", 4, "d", 1, 0),

    (1, 1, "2020-02-11", 1, "a", 0, 1),
    (1, 1, "2020-02-11", 3, "b", 0, 1),
    (1, 1, "2020-02-11", 8, "c", 0, 1),
    (1, 1, "2020-02-11", 4, "d", 0, 1),

    (1, 1, "2020-02-12", 1, "a", 1, 1),
    (1, 1, "2020-02-12", 3, "b", 1, 1),
    (1, 1, "2020-02-12", 8, "c", 1, 1),
    (1, 1, "2020-02-12", 4, "d", 1, 1),
]
test = spark.createDataFrame(data, columns)
test.show()

# #### Create smart selection model

smart_selection = BaseSmartSelect(group_by_cols=["time_attr_0", "time_attr_1"])

# #### Find model ranks

metrics = smart_selection.calculate_metrics(train)
metrics.show()

# #### Make prediction

result = smart_selection.get_smart_selection(metrics, test)
result.show()


+--------+----------+----------+------+----------+----------------+-----+---------+--------+-----------+-----------+
|store_id|product_id|      date|actual|prediction|forecast_horizon|error|abs_error|model_id|time_attr_0|time_attr_1|
+--------+----------+----------+------+----------+----------------+-----+---------+--------+-----------+-----------+
|       1|         1|2020-02-01|     5|         5|               1|    0|        0|       a|          0|          0|
|       1|         1|2020-02-01|     5|         0|               1|    5|        5|       b|          0|          0|
|       1|         1|2020-02-01|     5|         2|               1|    3|        3|       c|          0|          0|
|       1|         1|2020-02-01|     5|         2|               1|    3|        3|       d|          0|          0|
|       1|         1|2020-02-02|     5|         0|               1|    5|        5|       a|          1|          0|
|       1|         1|2020-02-02|     5|         5|              



+--------+----------+----------+----------+--------+-----------+-----------+
|store_id|product_id|      date|prediction|model_id|time_attr_0|time_attr_1|
+--------+----------+----------+----------+--------+-----------+-----------+
|       1|         1|2020-02-09|         1|       a|          0|          0|
|       1|         1|2020-02-09|         3|       b|          0|          0|
|       1|         1|2020-02-09|         8|       c|          0|          0|
|       1|         1|2020-02-09|         4|       d|          0|          0|
|       1|         1|2020-02-10|         1|       a|          1|          0|
|       1|         1|2020-02-10|         3|       b|          1|          0|
|       1|         1|2020-02-10|         8|       c|          1|          0|
|       1|         1|2020-02-10|         4|       d|          1|          0|
|       1|         1|2020-02-11|         1|       a|          0|          1|
|       1|         1|2020-02-11|         3|       b|          0|          1|

In [8]:
error = train
error_ = test

In [9]:
smart = BaseSmartSelect(
    ["attr"])




In [10]:
smart = BaseSmartSelect(
    ["attr"])
error_ranked = smart.calculate_metrics(error,metric_scale=False)
smart_select_forecasts = smart.get_smart_selection(error_ranked, error_)



In [11]:
smart_select_forecasts.show(2)

+----------+--------+----------+-----------+----------+--------+------------+
|      date|store_id|product_id|original_id|prediction|reliable|    model_id|
+----------+--------+----------+-----------+----------+--------+------------+
|2020-02-09|       1|         3|          c|         7|       2|smart_select|
|2020-02-09|       1|         1|          c|         8|       2|smart_select|
+----------+--------+----------+-----------+----------+--------+------------+
only showing top 2 rows

