In [0]:
import numpy as np
import pandas as pd
from scipy import stats
import datetime
import re

import pytz
import matplotlib.pyplot as plt
import seaborn as sns

import pyspark.sql.functions as pyf
import pyspark.sql.types as pyt
from pyspark.sql.window import Window

from sklearn.metrics import r2_score, mean_squared_error
from sklearn.ensemble import GradientBoostingRegressor

from acosta.alerting.preprocessing.functions import get_params_from_database_name, get_hash_org_unit_num_udf

from expectation.functions import pivot_pos, get_pos_prod, get_price
from expectation.model import get_latest_file_path
from expectation import parse_widget_or_raise

# NOTE(review): `is_called` is not referenced anywhere else in this notebook; presumably it is
# read by a notebook that `%run`s this one — confirm before removing.
is_called = True

In [0]:
# Register the notebook parameters as Databricks text widgets (all default to empty except
# confidence_level, which defaults to 80, i.e. percent).
dbutils.widgets.text('database_name', '', 'Database Name')
dbutils.widgets.text('company_id', '', 'Company Id')
dbutils.widgets.text('parent_chain_id', '', 'Parent Chain Id')
dbutils.widgets.text('manufacturer_id', '', 'Manufacturer Id')
dbutils.widgets.text('confidence_level', '80', 'Confidence Level')


# Required inputs; parse_widget_or_raise presumably raises on an empty value — confirm.
database_name = parse_widget_or_raise(dbutils.widgets.get('database_name'))
required_int_inputs = (
     'company_id', 'parent_chain_id', 'confidence_level')
int_parsed = [int(parse_widget_or_raise(dbutils.widgets.get(key))) for key in required_int_inputs]
company_id, parent_chain_id, confidence_level = int_parsed

# manufacturer_id is optional: an empty widget value becomes None.
manufacturer_id = dbutils.widgets.get('manufacturer_id')
manufacturer_id = int(manufacturer_id) if len(manufacturer_id) != 0 else None

# NOTE(review): this unpacks the returned dict positionally via .values(), so it relies on the
# helper's key order being (source, country_code, client, retailer) — confirm.
source, country_code, client, retailer = get_params_from_database_name(database_name).values()
print(f'client: {client}, country_code: {country_code}, retailer: {retailer}')

In [0]:
database_config_dict = get_params_from_database_name(database_name)
# NOTE: these upper-case parameters exist only for the 03.0_Retail_Alert_Database_DDL notebook
# (its %run is currently commented out); this notebook itself does not use them.
RETAILER = database_config_dict['retailer']
CLIENT = database_config_dict['client']
COUNTRY_CODE = database_config_dict['country_code']

In [0]:
# %run "./03.0_Retail_Alert_Database_DDL"

# Load and Preprocess Data

In [0]:
# Resolve "today" in the market's local time zone (UK, otherwise the Toronto time zone).
time_zone = 'Europe/London' if country_code == 'uk' else 'America/Toronto'
local_time = pytz.timezone(time_zone)
today_date = datetime.datetime.now(local_time).date()

similar_store_not_available = False
similar_store_expired = False

DatabaseName = f'retail_alert_{retailer}_{client}_{country_code}_im'
TableName = f'{DatabaseName}.similar_stores'
df_store_matches = spark.sql(f"SELECT * FROM {TableName}")

if df_store_matches.count() == 0:
    similar_store_not_available = True
    # Populate the similar-stores table. dbutils.notebook.run arguments are widget values,
    # so every value is passed as a string.
    dbutils.notebook.run(
        '01_CalculateSimilarStores', 10800, {
            'database_name': database_name,
            'company_id': str(company_id),
            'parent_chain_id': str(parent_chain_id),
            'manufacturer_id': dbutils.widgets.get('manufacturer_id')
        })
    # Re-read the table: the DataFrame above was built from the empty table, which would
    # leave max_load_ts as None and crash on .date() below.
    df_store_matches = spark.sql(f"SELECT * FROM {TableName}")

max_load_ts = df_store_matches.agg(pyf.max('LOAD_TS').alias('max_load_ts')).collect()[0]['max_load_ts']
if max_load_ts is None:
    raise ValueError(f'No similar stores available in {TableName}')
latest_date = max_load_ts.date()
print('The latest date of matched stores', latest_date)
# Similar stores older than 30 days are flagged as expired (the flag is only used by a
# commented-out refresh cell at the end of the notebook).
if latest_date < (today_date - datetime.timedelta(days=30)):
    similar_store_expired = True
# Keep only the most recent similar-store calculation.
df_store_matches = df_store_matches.filter(pyf.col('LOAD_TS') == max_load_ts).drop('LOAD_TS')
print(f'N of similar stores = {df_store_matches.cache().count():,}')

In [0]:
# Load 90 days of POS data (including ON_HAND_INVENTORY_QTY) via the project helper.
df_full_pos = get_pos_prod(
    database_name,
    spark,
    n_days=90,
    method='Gen2_DLA'
)
# Debug output: latest SALES_DT before and after dropping today's (presumably partial) data.
print(df_full_pos.select(pyf.max('SALES_DT')).collect()[0][0])
df_full_pos = df_full_pos.filter(pyf.col('SALES_DT') < today_date)
print(df_full_pos.select(pyf.max('SALES_DT')).collect()[0][0])

# update POS_ITEM_QTY based on 0 inventory
# The raw quantity is kept as POS_ITEM_QTY_ORG. In POS_ITEM_QTY, rows with zero sales AND zero
# on-hand inventory get the sentinel -1; every other row keeps the raw quantity.
df_full_pos = df_full_pos.withColumnRenamed('POS_ITEM_QTY', 'POS_ITEM_QTY_ORG')
df_full_pos = df_full_pos.withColumn(
    'POS_INV',
    pyf.when(
        ((df_full_pos['ON_HAND_INVENTORY_QTY'] == 0)\
        & (df_full_pos['POS_ITEM_QTY_ORG'] == 0)), -1
    ).otherwise(df_full_pos['POS_ITEM_QTY_ORG'])
).withColumnRenamed('POS_INV', 'POS_ITEM_QTY')

print(f'N = {df_full_pos.cache().count():,}')

In [0]:
# Diagnostics for the latest SALES_DT. The counts are: total rows, rows with non-zero raw sales,
# rows with non-zero on-hand inventory, and rows not flagged with the -1 sentinel.
max_date = df_full_pos.select(pyf.max('SALES_DT')).collect()[0][0]
print(max_date)
print(df_full_pos.filter(pyf.col('SALES_DT') == max_date).count())
print(df_full_pos.filter(pyf.col('SALES_DT') == max_date).filter(pyf.col('POS_ITEM_QTY_ORG') != 0).count())
print(df_full_pos.filter(pyf.col('SALES_DT') == max_date).filter(pyf.col('ON_HAND_INVENTORY_QTY') != 0).count())
print(df_full_pos.filter(pyf.col('SALES_DT') == max_date).filter(pyf.col('POS_ITEM_QTY') != -1).count())

In [0]:
# Same diagnostics as for the latest date, but for a fixed reference date. The date has its own
# name so this cell no longer overwrites `max_date` with a string.
# NOTE(review): hard-coded debugging date; update or remove before relying on this output.
DEBUG_SALES_DT = '2024-07-08'
df_debug_day = df_full_pos.filter(pyf.col('SALES_DT') == DEBUG_SALES_DT)
print(df_debug_day.count())
print(df_debug_day.filter(pyf.col('POS_ITEM_QTY_ORG') != 0).count())
print(df_debug_day.filter(pyf.col('ON_HAND_INVENTORY_QTY') != 0).count())
print(df_debug_day.filter(pyf.col('POS_ITEM_QTY') != -1).count())

In [0]:
#join POS data and similar stores
# NOTE(review): pivot_pos(df, 'daily', 2) presumably returns one row per store/item with one
# column per day; the meaning of the `2` argument is defined in expectation.functions — confirm.
df_pivot = pivot_pos(
    df_full_pos,
    'daily',
    2
)

# A left join keeps every similar-store row. Stores/items without POS data get null day columns,
# which the pandas UDFs below drop (if entirely null) or fill with 0. The duplicate key columns
# coming from df_pivot are dropped.
df_pos = df_store_matches.join(
    df_pivot,
    (df_store_matches['ORGANIZATION_UNIT_NUM'] == df_pivot['ORGANIZATION_UNIT_NUM'])
     & (df_store_matches['RETAILER_ITEM_ID']== df_pivot['RETAILER_ITEM_ID']),
    how='left'
).drop(df_pivot['RETAILER_ITEM_ID'])\
.drop(df_pivot['ORGANIZATION_UNIT_NUM'])

print(f'N = {df_pos.cache().count():,}')

In [0]:
# Filter out combinations that will cause index errors
def create_assertion_message(dfi):
    """Return a short 'Item = ... & TestStore = ...' label identifying a UDF group.

    Uses the first distinct RETAILER_ITEM_ID and TEST_ORGANIZATION_UNIT_NUM in ``dfi``.
    It is appended to assertion and error messages.
    """
    item_id = dfi["RETAILER_ITEM_ID"].unique()[0]
    test_store = dfi["TEST_ORGANIZATION_UNIT_NUM"].unique()[0]
    return f'Item = {item_id} & TestStore = {test_store}'


def label_data_as_safe_udf(dfi):
    """Flag whether one (item, test store) group can be sliced into features and target.

    Intended for ``applyInPandas`` grouped by RETAILER_ITEM_ID and TEST_ORGANIZATION_UNIT_NUM.
    The group is safe when all of the following hold:

    * it contains more than one distinct ORGANIZATION_UNIT_NUM;
    * the test store itself is among the ORGANIZATION_UNIT_NUM values;
    * after dropping all-null columns and filling the rest with 0, the comparable-store slice
      and the test-store slice can be built, and the test store has a last value.

    Args:
        dfi: pandas DataFrame holding a single group.

    Returns:
        ``dfi`` with an added boolean ``is_safe`` column (same value on every row).

    Raises:
        AssertionError: if the group contains more than one item or test store.
    """
    assertion_message = create_assertion_message(dfi)
    assert dfi['RETAILER_ITEM_ID'].nunique() == 1, 'Check GroupBy code: ' + assertion_message
    assert dfi['TEST_ORGANIZATION_UNIT_NUM'].nunique() == 1, 'Check data preprocessing process: ' + assertion_message

    test_store_num = dfi['TEST_ORGANIZATION_UNIT_NUM'].unique()[0]
    n_stores = dfi['ORGANIZATION_UNIT_NUM'].nunique()

    # Need the test store's own history plus at least one other store as a feature.
    n_store_cond_passed = n_stores > 1
    test_store_is_present = test_store_num in set(dfi['ORGANIZATION_UNIT_NUM'].values)
    is_safe_test_store = n_store_cond_passed and test_store_is_present

    # Process data (mirrors fit_udf's preprocessing, minus its drop_duplicates step).
    cols_to_drop = ['RETAILER_ITEM_ID', 'TEST_ORGANIZATION_UNIT_NUM', 'DISTANCE',
                    'CAP_VALUE', 'MAX_ALERTS_PER_STORE']
    data = dfi.drop(columns=cols_to_drop)
    data = data.sort_values(by='ORGANIZATION_UNIT_NUM')
    data = data.dropna(axis='columns', how='all')
    data = data.fillna(0)
    data.index = data['ORGANIZATION_UNIT_NUM']

    # Check if is safe: attempt the same slicing fit_udf performs.
    is_safe_data = True
    try:
        _ = data.drop(columns='ORGANIZATION_UNIT_NUM', index=test_store_num).values.T
        _ = data.drop(columns='ORGANIZATION_UNIT_NUM').loc[test_store_num].values
        _ = _[-1]  # the last value of the test store must exist
    except Exception:
        # A bare ``except:`` would also swallow KeyboardInterrupt/SystemExit. The expected
        # failures are KeyError (test store missing) and IndexError (no value columns left).
        is_safe_data = False

    dfi['is_safe'] = is_safe_data and is_safe_test_store
    return dfi


# Output schema: the input columns plus the boolean flag added by label_data_as_safe_udf.
label_data_as_safe_schema = pyt.StructType(
    list(df_pos.schema) +
    [pyt.StructField('is_safe', pyt.BooleanType())]
)

# Label every (item, test store) group, then keep only the groups flagged as safe.
df_pos = df_pos\
    .groupby('RETAILER_ITEM_ID', 'TEST_ORGANIZATION_UNIT_NUM')\
    .applyInPandas(label_data_as_safe_udf, schema=label_data_as_safe_schema)
print(f'N_1 = {df_pos.cache().count():,}')

df_pos = df_pos.filter('is_safe is True')
df_pos = df_pos.drop('is_safe')
print(f'N_f = {df_pos.cache().count():,}')

# Estimate Probabilities

In [0]:
def poisson_geometric_expectation(obs, rate: float, n_zero_days: int):
    """Lower-tail probability of the observed sales under a Poisson sales model.

    * ``obs > 0``: returns P(X <= obs) for X ~ Poisson(rate).
    * ``obs <= 0``: returns the probability of ``n_zero_days`` consecutive zero-sales days. This is
      the geometric survival function ``(1 - p) ** n_zero_days``, where
      ``p = 1 - exp(-rate)`` is the daily probability of a non-zero sale.

    Small values mean the observation is unusually low for the given rate. ``rate`` must be
    > 0 for the zero branch: p = 0 is outside scipy's geometric support and yields nan.

    Args:
        obs: observed quantity; a scalar or an ``np.ndarray``.
        rate: expected quantity (Poisson mean); a scalar or an array broadcastable with ``obs``.
        n_zero_days: number of trailing zero-sales days; a scalar or a broadcastable array.

    Returns:
        A scalar for non-array ``obs``; otherwise an array of the broadcast shape.
    """
    prob_zero_sales = stats.poisson(rate).pmf(0)
    prob_nonzero_sales = 1 - prob_zero_sales

    # Only the poisson distribution is required to compute the probability of non-zero sales
    prob_given_nonzero = stats.poisson(rate).cdf(obs)
    # Both the poisson and the geometric distribution are required to compute the probability
    # of the observed run of zero-sales days
    prob_given_zero = stats.geom(p=prob_nonzero_sales).sf(n_zero_days)

    if isinstance(obs, np.ndarray):
        # np.where broadcasts, so scalar rate/n_zero_days also work. Copying prob_given_zero and
        # assigning into it fails when it is a numpy scalar.
        prob = np.where(obs > 0, prob_given_nonzero, prob_given_zero)
    else:
        prob = prob_given_nonzero if obs > 0 else prob_given_zero
    return prob


def count_trailing_zeros(array: np.ndarray) -> int:
    """Count the consecutive zero values at the end of a 1-D array.

    Returns ``len(array)`` when every element is zero and 0 for an empty array. Using
    ``np.argmax`` on an empty array raises ValueError. NaN counts as non-zero.
    """
    nonzero_positions = np.flatnonzero(np.asarray(array) != 0)
    if nonzero_positions.size == 0:
        return len(array)
    return int(len(array) - 1 - nonzero_positions[-1])


def compute_lost_sales(obs, rate: float):
    """Estimated lost units: ``rate - obs`` when ``obs > 0``, otherwise ``rate``; floored at 0.

    Args:
        obs: observed quantity; a scalar or an ``np.ndarray``.
        rate: expected quantity; a scalar or an array broadcastable with ``obs``.

    Returns:
        A non-negative scalar, or an array for array ``obs``.
    """
    if isinstance(obs, np.ndarray):
        # np.where broadcasts a scalar rate and keeps the float result. A np.zeros_like(obs)
        # buffer truncates to integers for int-typed obs, and ``rate[~selector]`` fails for a
        # scalar rate.
        lost_sales = np.where(obs > 0, rate - obs, rate)
    else:
        lost_sales = rate - obs if obs > 0 else rate
    return np.maximum(lost_sales, 0)

In [0]:
# Output schema of fit_udf: one row per (item, test store) group.
fit_udf_schema = pyt.StructType([
    pyt.StructField('RETAILER_ITEM_ID', pyt.StringType()),
    pyt.StructField('TEST_ORGANIZATION_UNIT_NUM', pyt.StringType()),
    pyt.StructField('DATE', pyt.DateType()),
    pyt.StructField('LAST_OBS', pyt.FloatType()),
    pyt.StructField('LAST_RATE', pyt.FloatType()),
    pyt.StructField('LAST_LOST_SALES', pyt.FloatType()),
    pyt.StructField('N_DAYS_LAST_SALE', pyt.IntegerType()),
    pyt.StructField('PROB', pyt.FloatType()),
    pyt.StructField('ZSCORE', pyt.FloatType()),
    pyt.StructField('R2', pyt.FloatType()),
    pyt.StructField('RMSE', pyt.FloatType()),
    pyt.StructField('CAP_VALUE', pyt.DecimalType(6,2)),
    pyt.StructField('MAX_ALERTS_PER_STORE', pyt.IntegerType())
])
# Column order used to select fit_udf's returned frame.
fit_udf_cols = [col.name for col in fit_udf_schema]


def fit_udf(dfi):
    """Fit one model per (item, test store) group and score the group's last day.

    Intended for ``applyInPandas`` grouped by RETAILER_ITEM_ID and TEST_ORGANIZATION_UNIT_NUM.
    After preprocessing there is one row per store (ORGANIZATION_UNIT_NUM), and the remaining
    columns are parsed as dates.

    * Features (x): the other stores' values per day.
    * Target (y): the test store's values per day.
    * All days except the last are used for training; the last day is evaluated.

    Args:
        dfi: pandas DataFrame holding a single group.

    Returns:
        A pandas DataFrame with the columns of ``fit_udf_cols``. The model-derived fields stay
        None when fitting fails. The error is printed rather than raised, so one bad group does
        not fail the whole Spark job.

    Raises:
        AssertionError: if the group contains more than one item or test store.
    """
    assertion_message = create_assertion_message(dfi)
    assert dfi['RETAILER_ITEM_ID'].nunique() == 1, 'Check GroupBy code: ' + assertion_message
    assert dfi['TEST_ORGANIZATION_UNIT_NUM'].nunique() == 1, 'Check data preprocessing process: ' + assertion_message
    test_store_num = dfi['TEST_ORGANIZATION_UNIT_NUM'].unique()[0]

    # One row per store, so .loc[test_store_num] below yields a single row.
    dfi = dfi.drop_duplicates(['RETAILER_ITEM_ID', 'TEST_ORGANIZATION_UNIT_NUM', 'ORGANIZATION_UNIT_NUM'])
    cols_to_drop = ['RETAILER_ITEM_ID', 'TEST_ORGANIZATION_UNIT_NUM',
                    'DISTANCE', 'CAP_VALUE', 'MAX_ALERTS_PER_STORE'
                   ]
    data = (
        dfi.drop(columns=cols_to_drop)
        .sort_values(by='ORGANIZATION_UNIT_NUM')
        .dropna(axis='columns', how='all')  # columns that are null for every store
        .fillna(0)
    )
    data.index = data['ORGANIZATION_UNIT_NUM']
    history = data.drop(columns='ORGANIZATION_UNIT_NUM')

    # x: one row per day and one column per comparable store; y: the test store's daily values.
    x = history.drop(index=test_store_num).values.T
    y = history.loc[test_store_num].values

    dates = pd.to_datetime(history.columns)

    last_obs = y[-1]
    n_days_since_last_sale = None
    last_rate = None
    prob = None
    last_lost_sales = None
    z_score = None
    r2 = None
    rmse = None

    # Fit Models
    try:
        x_train, y_train = x[:-1], y[:-1]
        clf = GradientBoostingRegressor(random_state=1, n_iter_no_change=3, validation_fraction=0.1)
        clf.fit(x_train, y_train)
        last_rate = clf.predict(x[-1:])[0]
        n_days_since_last_sale = count_trailing_zeros(y.flatten())

        # Compute probabilities; a negative prediction is not a valid Poisson mean, so clamp it.
        if last_rate < 0:
            last_rate = 0.0005
        prob = poisson_geometric_expectation(
            obs=last_obs,
            rate=last_rate,
            n_zero_days=n_days_since_last_sale
        )
        z_score = -stats.norm().ppf(prob)  # Larger represent more extreme observations

        # Compute lost items
        last_lost_sales = compute_lost_sales(
            obs=last_obs,
            rate=last_rate,
        )

        # In-sample metrics; predict once and reuse the result for both metrics.
        y_train_pred = clf.predict(x_train)
        r2 = r2_score(y_train, y_train_pred)
        rmse = np.sqrt(mean_squared_error(y_train, y_train_pred))
    except Exception as e:
        import traceback
        print(f'=== ERROR OCCURRED fitting {assertion_message} ===')
        print(e)
        traceback.print_exc()

    # Generate results
    df_results = dfi[['RETAILER_ITEM_ID', 'TEST_ORGANIZATION_UNIT_NUM']].drop_duplicates().astype(str)
    df_results['DATE'] = dates[-1]
    df_results['R2'] = r2
    df_results['RMSE'] = rmse
    df_results['LAST_OBS'] = last_obs
    df_results['LAST_RATE'] = last_rate
    df_results['LAST_LOST_SALES'] = last_lost_sales
    df_results['N_DAYS_LAST_SALE'] = n_days_since_last_sale
    df_results['PROB'] = prob
    df_results['ZSCORE'] = z_score
    df_results['CAP_VALUE'] = dfi['CAP_VALUE'].unique()[0]
    df_results['MAX_ALERTS_PER_STORE'] = dfi['MAX_ALERTS_PER_STORE'].unique()[0]

    return df_results[fit_udf_cols]

# Fit and score every (item, test store) group.
df_prepared = df_pos.groupby('RETAILER_ITEM_ID', 'TEST_ORGANIZATION_UNIT_NUM').applyInPandas(fit_udf, schema=fit_udf_schema)
print(f'N = {df_prepared.cache().count():,}')

In [0]:
# Keep rows with positive lost sales whose scored DATE is at most 3 days old (local time).
df_prepared = df_prepared\
    .filter(pyf.col('LAST_LOST_SALES') > 0)\
    .filter(pyf.col('DATE') >= pyf.date_sub(pyf.from_utc_timestamp(pyf.current_timestamp(), time_zone), 3))
print(f'N after filtering Date and lost sales\n -> {df_prepared.cache().count():,}')

if df_prepared.count() == 0:
    raise ValueError('There is no LSV > 0 or there is no up to date sales record')

df_price = get_price(database_name, spark)

# Attach the test store's on-hand inventory for the scored day; the inner join drops
# unmatched rows. ORGANIZATION_UNIT_NUM is renamed in the selectExpr, so no stray
# ORGANIZATION_UNIT_NUM column reaches the joined result.
df_prepared = df_prepared.join(
    df_full_pos.selectExpr(
        'RETAILER_ITEM_ID',
        'ORGANIZATION_UNIT_NUM AS TEST_ORGANIZATION_UNIT_NUM',
        'SALES_DT AS DATE',
        'ON_HAND_INVENTORY_QTY',
    ),
    ['RETAILER_ITEM_ID', 'TEST_ORGANIZATION_UNIT_NUM', 'DATE'],
    how='inner'
)

# Attach the item price for the same store and day (Spark's default inner join).
df_prepared = df_prepared.join(
    df_price.selectExpr(
        'RETAILER_ITEM_ID',
        'ORGANIZATION_UNIT_NUM AS TEST_ORGANIZATION_UNIT_NUM',
        'SALES_DT AS DATE',
        'PRICE',
    ),
    on=['RETAILER_ITEM_ID', 'TEST_ORGANIZATION_UNIT_NUM', 'DATE']
)

# Lost sales value = lost units x price; a null product (e.g. a null PRICE) becomes 0.0.
df_prepared = df_prepared.withColumn(
    'LostSalesValue',
    pyf.coalesce(pyf.col('LAST_LOST_SALES') * pyf.col('PRICE'), pyf.lit(0.0))
)
print(f'N = {df_prepared.cache().count():,}')

## Test: alert counts at alternative confidence thresholds

In [0]:
# Scratch check: number of rows that pass an 80% confidence threshold (PROB <= 1 - 0.8).
display(df_prepared.filter(pyf.col('PROB') <= 1-80/100).count())

In [0]:
# Scratch check: number of rows that pass a 70% confidence threshold.
df_prepared.filter(pyf.col('PROB') <= 1-70/100).count()

In [0]:
# Scratch check: number of rows that pass a 60% confidence threshold.
df_prepared.filter(pyf.col('PROB') <= 1-60/100).count()

In [0]:
# Scratch check: number of rows that pass a 50% confidence threshold.
df_prepared.filter(pyf.col('PROB') <= 1-50/100).count()

## end of test

# Generate Final Alerts

In [0]:
# Keep rows whose tail probability is at most 1 - confidence_level/100 (widget value in percent).
df_prepared = df_prepared.filter(pyf.col('PROB') <= 1-confidence_level/100)
print(f'N after confidence threshold = {df_prepared.cache().count():,}')

In [0]:
# Per-store alert cap. Only one row is fetched; .collect()[0] would pull every row to the driver.
# NOTE(review): assumes MAX_ALERTS_PER_STORE is the same on every row — confirm.
max_alerts_row = df_prepared.select('MAX_ALERTS_PER_STORE').first()
if max_alerts_row is None:
    raise ValueError('No alerts left after applying the confidence threshold')
maximum_num_alerts = float(max_alerts_row[0])
maximum_num_alerts = max(maximum_num_alerts, 1.00) #NOTE: added this to ensure df is not null in the cases that we don't want any OSA alerts

def filter_alerts_by_confidence(df, confidence_level_list, max_alerts=None):
    """Keep at most ``max_alerts`` alerts per test store, preferring confidence, then lost-sales value.

    The function works in three steps:

    1. Each row gets the first confidence level ``c`` in ``confidence_level_list``, in list
       order, with ``PROB <= 1 - c``. The list is expected in descending order, so this is
       the highest such level.
    2. Within each test store, rows are ordered by that confidence (descending). Within a
       confidence level they are ordered by LostSalesValue (descending).
    3. The first ``max_alerts`` rows per test store are kept.

    Args:
        df: Spark DataFrame with TEST_ORGANIZATION_UNIT_NUM, PROB and LostSalesValue columns.
        confidence_level_list: confidence levels as fractions, e.g. ``[0.95, 0.9, ..., 0]``.
        max_alerts: maximum rows kept per test store. Defaults to the notebook-level
            ``maximum_num_alerts`` (the previous, implicit behavior).

    Returns:
        The filtered DataFrame with the helper columns removed.
    """
    if max_alerts is None:
        max_alerts = maximum_num_alerts

    for c in confidence_level_list:
        df = df.withColumn(
            # round() avoids names such as CONFIDENCE_28 for c = 0.29 (0.29 * 100 = 28.999...).
            f'CONFIDENCE_{int(round(c * 100))}',
            pyf.when(
                pyf.col('PROB') <= (1 - c),
                pyf.lit(c)
            ).otherwise(pyf.lit(None).cast(pyt.FloatType()))
        )
    level_cols = [pyf.col(c) for c in df.columns if c.startswith('CONFIDENCE')]
    df = df.withColumn('CONFIDENCE', pyf.coalesce(*level_cols))

    # Rank alerts by lost-sales value within each (test store, confidence) bucket.
    lsv_window = Window.partitionBy(
        'TEST_ORGANIZATION_UNIT_NUM',
        'CONFIDENCE'
    ).orderBy(pyf.col('LostSalesValue').desc())
    df = df.withColumn('lsv_rank', pyf.row_number().over(lsv_window))

    # Both sort keys must go into a single orderBy call: WindowSpec.orderBy replaces any earlier
    # ordering. Chained calls order by CONFIDENCE alone and break ties arbitrarily.
    store_window = Window.partitionBy('TEST_ORGANIZATION_UNIT_NUM').orderBy(
        pyf.col('CONFIDENCE').desc(),
        pyf.col('lsv_rank')
    )
    helper_cols = [c for c in df.columns if c.startswith('CONFIDENCE')]
    df = df.withColumn('row_number', pyf.row_number().over(store_window)).drop(*helper_cols)

    return df\
        .filter(pyf.col('row_number') <= max_alerts)\
        .drop('row_number', 'lsv_rank')


# Confidence buckets in descending order; the trailing 0 matches every remaining row (PROB <= 1).
confidence_level_list = [0.95, 0.9, 0.85, 0.75, 0.5, 0]
df_final = filter_alerts_by_confidence(
    df_prepared, confidence_level_list
)
# From here on, the test store is reported as the alert's store.
df_final = df_final.withColumnRenamed('TEST_ORGANIZATION_UNIT_NUM', 'ORGANIZATION_UNIT_NUM')
print(f'N = {df_final.cache().count():,}')

In [0]:
def performance_summary(df_prepared, df_alerts):
    """Display diagnostics for the scored rows and the final alerts.

    The function shows:

    * ``describe()`` tables for both inputs;
    * histograms of the in-sample R2, RMSE and PROB of ``df_prepared``, and of LostSalesValue
      of ``df_alerts``;
    * the mean R2 and RMSE;
    * the min/max number of alerts per store;
    * the alert, store and product counts.

    Args:
        df_prepared: Spark DataFrame with R2, RMSE and PROB columns.
        df_alerts: Spark DataFrame with LostSalesValue, ORGANIZATION_UNIT_NUM and
            RETAILER_ITEM_ID columns.
    """
    print('Summary of df_prepared:')
    display(df_prepared.describe())

    # Collect each needed column once; R2 and RMSE are reused for the means below.
    metrics_pd = df_prepared.select('R2', 'RMSE', 'PROB').toPandas()
    lsv_pd = df_alerts.select('LostSalesValue').toPandas()

    fig, axs = plt.subplots(1, 4, figsize=(15, 5))
    sns.histplot(metrics_pd['R2'], kde=True, ax=axs[0])
    sns.histplot(metrics_pd['RMSE'], kde=True, ax=axs[1])
    sns.histplot(metrics_pd['PROB'], kde=True, ax=axs[2])
    sns.histplot(lsv_pd['LostSalesValue'], kde=True, ax=axs[3])

    fig.suptitle('Model Performance on Training Data')
    # Reference line at 0 on the LostSalesValue panel; plt.axvline draws on the current axes,
    # which is this panel, so the line is drawn there explicitly.
    axs[3].axvline(0)
    # plt.show() returns None; wrapping it in display() would only display None.
    plt.show()
    plt.close(fig)

    print(f"mean R2: {metrics_pd['R2'].mean():.2f}, mean RMSE: {metrics_pd['RMSE'].mean():.2f}")

    print('\n Summary of final alerts')
    display(df_alerts.describe())

    unique_stores = df_alerts.select(pyf.countDistinct('ORGANIZATION_UNIT_NUM')).collect()[0][0]
    unique_products = df_alerts.select(pyf.countDistinct('RETAILER_ITEM_ID')).collect()[0][0]
    print('\n Min and max number of alerts per ORGANIZATION_UNIT_NUM')
    display(df_alerts.groupBy('ORGANIZATION_UNIT_NUM').count().select(pyf.max('count'), pyf.min('count')))

    print(f'Total Alerts = {df_alerts.count():,} \n Number of unique Stores = {unique_stores} \n Number of unique products = {unique_products}')

# Print and plot diagnostics for the scored rows and the final alerts.
performance_summary(df_prepared, df_final)

# Gen2 Magic: Convert Alerts into Baseline Forecasts

In [0]:
# Every (store, item) pair seen in the POS data; pairs without an alert get a zero forecast below.
unique_org_item = df_full_pos.select('ORGANIZATION_UNIT_NUM', 'RETAILER_ITEM_ID').distinct()
# LSV = lost sales value plus CAP_VALUE (CAP_VALUE is carried from the per-store input,
# presumably the similar-stores table — its meaning is not defined in this notebook).
df_final = df_final.withColumn('LSV', pyf.col('LostSalesValue') + pyf.col('CAP_VALUE'))
# Forecast units: LSV converted back to units at the item price, added to the last observed quantity.
df_final = df_final.withColumn('FORECAST', pyf.col('LSV') / pyf.col('PRICE') + pyf.col('LAST_OBS'))

# join unique_org_item with df_final and ensure date is correct and forecast is zero
# NOTE(review): despite its name, `last_date` is the MINIMUM alert DATE, and every output row gets it.
# fillna sets FORECAST = 0 for pairs without an alert (and for any other null FORECAST).
last_date = df_final.agg({'DATE': 'min'}).collect()[0][0]
df_final = unique_org_item.join(df_final,
                                    on=['ORGANIZATION_UNIT_NUM', 'RETAILER_ITEM_ID'],
                                    how='left')\
                              .fillna({'FORECAST': 0})\
                              .withColumn('DATE', pyf.lit(last_date))

In [0]:
# Output schema of fit_forecast: one row per item, store and day.
fit_forecast_schema = pyt.StructType([
    pyt.StructField('RETAILER_ITEM_ID', pyt.StringType()),
    pyt.StructField('ORGANIZATION_UNIT_NUM', pyt.StringType()),
    pyt.StructField('DATE', pyt.DateType()),
    pyt.StructField('FORECAST', pyt.FloatType()),
])
# Column order used to select fit_forecast's returned frame.
fit_forecast_cols = [col.name for col in fit_forecast_schema]

def fit_forecast(dfi):
    """Expand one (item, store) forecast into a daily series spanning DATE +/- 7 days.

    The anchor day is the earliest DATE in the group, and the forecast value is the group's
    first FORECAST.

    * Non-zero forecast: days before the anchor get 0; the anchor value is forward-filled
      from the anchor day onwards.
    * Zero forecast: every day gets 0.

    Store and item identifiers are filled across the new days. Note that the input frame's
    index is replaced by the parsed dates, as before.

    Args:
        dfi: pandas DataFrame for one (RETAILER_ITEM_ID, ORGANIZATION_UNIT_NUM) group.

    Returns:
        A pandas DataFrame with the ``fit_forecast_cols`` columns, one row per day.
    """
    dfi.index = pd.to_datetime(dfi['DATE'])
    anchor = dfi.index.min()
    anchor_forecast = dfi['FORECAST'].values[0]
    one_week = datetime.timedelta(7)
    calendar = pd.date_range(start=anchor - one_week, end=anchor + one_week, freq='D')

    expanded = dfi.reindex(calendar)
    id_cols = ['ORGANIZATION_UNIT_NUM', 'RETAILER_ITEM_ID']
    expanded.loc[:, id_cols] = expanded.loc[:, id_cols].ffill().bfill()
    expanded['DATE'] = expanded.index

    if anchor_forecast == 0:
        expanded.loc[:, 'FORECAST'] = 0
    else:
        before_anchor = expanded.index < anchor
        expanded.loc[before_anchor, 'FORECAST'] = 0
        expanded.loc[~before_anchor, 'FORECAST'] = expanded.loc[~before_anchor, 'FORECAST'].ffill()

    return expanded[fit_forecast_cols]

# Expand every (item, store) pair into a daily forecast series with fit_forecast.
df_predictions = df_final.select('RETAILER_ITEM_ID', 'ORGANIZATION_UNIT_NUM', 'DATE', 'FORECAST')\
        .groupby('RETAILER_ITEM_ID', 'ORGANIZATION_UNIT_NUM')\
        .applyInPandas(fit_forecast, schema=fit_forecast_schema)

# Checkpoint to truncate the Spark lineage before the master-table joins below.
# NOTE(review): hard-coded mount path; confirm it exists in every environment.
spark.sparkContext.setCheckpointDir(f'/mnt/processed/gen2_dla/checkpoints/{retailer}_{client}_{country_code}/')
df_predictions = df_predictions.checkpoint()
print(f'N = {df_predictions.cache().count():,}')

# Write Final Alerts into Table

In [0]:
# rename columns and add additional metadata to the predictions
prediction_results_dataframe = df_predictions.selectExpr(
    'ORGANIZATION_UNIT_NUM',
    'RETAILER_ITEM_ID',
    'CURRENT_TIMESTAMP() as LOAD_TS',
    '"Dynamic.Retail.Engine.Gen2" as RECORD_SOURCE_CD',
    'FORECAST as BASELINE_POS_ITEM_QTY',
    'DATE as SALES_DT',
    '"GEN2-DLA" as MODEL_ID'
)

if source is not None:
    itemMasterTableName = f'{database_name}.vw_latest_sat_retailer_item'
    storeMasterTableName = f'{database_name}.vw_latest_sat_organization_unit'

elif source is None:
    itemMasterTableName = f'{database_name}.hub_retailer_item'
    storeMasterTableName = f'{database_name}.hub_organization_unit'

else:
    raise ValueError('Store and item master tables do not exist')

store_master_table = spark.read.table(storeMasterTableName)
store_master_table = store_master_table.alias('OUH')

prediction_results_dataframe = prediction_results_dataframe.alias('PRDF') \
    .join(store_master_table,
          pyf.col('OUH.ORGANIZATION_UNIT_NUM') == pyf.col('PRDF.ORGANIZATION_UNIT_NUM'), 'inner') \
    .join(sqlContext.read.table(itemMasterTableName).alias('RIH'),
          pyf.col('RIH.RETAILER_ITEM_ID') == pyf.col('PRDF.RETAILER_ITEM_ID'), 'inner') \
    .select('PRDF.*', 'OUH.HUB_ORGANIZATION_UNIT_HK', 'RIH.HUB_RETAILER_ITEM_HK')

In [0]:
# Sanity check: the row count should be non-zero after the master-table inner joins.
print(f'Confirm non-zero count {prediction_results_dataframe.cache().count():,}')

In [0]:
# insertDatabaseName = f'retail_alert_{retailer}_{client}_{country_code}_im'
# insertTableName = f'{insertDatabaseName}.DRFE_FORECAST_BASELINE_UNIT'

In [0]:
# try:
#     prediction_results_dataframe\
#         .select('HUB_ORGANIZATION_UNIT_HK', 'HUB_RETAILER_ITEM_HK', 'LOAD_TS', 'RECORD_SOURCE_CD', 'BASELINE_POS_ITEM_QTY',
#                 'MODEL_ID', 'SALES_DT')\
#         .write.mode('overwrite').insertInto(insertTableName, overwrite=True)
# except:
#     selected_data = prediction_results_dataframe.select(
#         'HUB_ORGANIZATION_UNIT_HK', 'HUB_RETAILER_ITEM_HK', 'LOAD_TS',
#         'RECORD_SOURCE_CD', 'BASELINE_POS_ITEM_QTY', 'MODEL_ID', 'SALES_DT')
#     selected_data.write.mode('overwrite').saveAsTable("temp_table")
#     spark.sql(f"INSERT INTO {insertTableName} SELECT * FROM temp_table")
#     print('used modified version')

In [0]:
# if similar_store_expired:
#     dbutils.notebook.run(
#         '01_CalculateSimilarStores', 10800, {
#             'database_name': database_name,
#             'company_id': company_id,
#             'parent_chain_id': parent_chain_id,
#             'manufacturer_id': dbutils.widgets.get('manufacturer_id')
#         })