# In [None]:
def _data_input():
    """Load the raw modelling inputs from BigQuery.

    Every query is executed through ``read_bq`` (defined outside this block),
    which is called with ``"wmt-one-demand-dev"`` as its first argument and the
    SQL text as its second.

    Returns:
        tuple: ``(sales_df, inbound, cal_df, event_df, inv_df, item_univ)``
            * ``sales_df``  - sales aggregated per hierarchy ids / report
              levels / ``wm_yr_wk`` / purchase year / purchase month.
            * ``inbound``   - inbound rows with ``inbound >= 0``; the
              ``inbound`` column is exposed as ``calc_inbound`` and
              ``WM_YR_WK`` is renamed to ``wm_yr_wk``.
            * ``cal_df``    - one row per ``WmYrWk`` (renamed ``wm_yr_wk``)
              with the max month / year number of that week.
            * ``event_df``  - distinct National / Cultural / Sporting events
              per week (``WmYrWk`` renamed ``wm_yr_wk``).
            * ``inv_df``    - inventory history rows whose five hierarchy ids
              are all non-null (``WmYrWk`` renamed ``wm_yr_wk``).
            * ``item_univ`` - distinct non-replen hierarchy combinations from
              the latest ``processed_time`` snapshot.

    Note:
        The original body indented its last assignment with a TAB inside a
        4-space block, which Python 3 rejects with ``TabError``; the body is
        now indented with spaces only.  The SQL text itself is unchanged
        (the ``f`` prefixes were dropped because no query has placeholders).
    """
    project = "wmt-one-demand-dev"

    # NOTE(review): volume_cu_ft_w is computed as avg(x*qty)/sum(qty) while the
    # price analogue uses sum(x*qty)/sum(qty); if a quantity-weighted average
    # is intended this is probably a typo - confirm before changing the SQL.
    sales_query = """select
    r2d2_sub_category_id, r2d2_division_id, r2d2_super_department_id, r2d2_department_id, r2d2_category_id,rpt_lvl_0_nm, rpt_lvl_1_nm, rpt_lvl_2_nm, rpt_lvl_3_nm, rpt_lvl_4_nm, wm_yr_wk,
    sum(net_cncl_qty) net_cncl_qty, 
    avg(unit_price_amt) unit_price_amt,
    sum(unit_price_amt*net_cncl_qty)/sum(net_cncl_qty) as unit_price_amt_w, 
    avg(volume_cu_ft*net_cncl_qty)/sum(net_cncl_qty) volume_cu_ft_w, 
    extract(year from order_plcd_dt) as purchase_year,
    extract(month from order_plcd_dt) as purchase_month, 
    from wmt-one-demand-dev.inventory_placement_sn.historic_sales_nonreplen_2022_2025_3yr_sales_v6 a
    where   
    (r2d2_sub_category_id is not null or r2d2_division_id is not null 
    or r2d2_super_department_id is not null or r2d2_department_id is not null 
    or  r2d2_category_id is not null)
    group by r2d2_sub_category_id, r2d2_division_id, r2d2_super_department_id, r2d2_department_id, r2d2_category_id, rpt_lvl_0_nm, rpt_lvl_1_nm, rpt_lvl_2_nm, rpt_lvl_3_nm, rpt_lvl_4_nm, wm_yr_wk, extract(year from order_plcd_dt), extract(month from order_plcd_dt)"""
    sales_df = read_bq(project, sales_query)

    inbound_query = """ Select * except(inbound), inbound as calc_inbound
        from 
        -- wmt-one-demand-dev.inventory_placement_sn.inbound_3yrs_new_nr
        wmt-one-demand-dev.inventory_placement_sn.inbound_3yrs_new_nr_v2
        -- wmt-one-demand-dev.inventory_placement_sn.inbound_3yrs_new_nr_v2
        -- wmt-one-demand-dev.inventory_placement_sn.inbound_3yrs_new_nr_v3
        where 1 = 1
        -- and r2d2_division_name not in ('UNASSIGNED','UNASSIGNED L0')
        and inbound >=0
        """
    inbound = read_bq(project, inbound_query)
    inbound = inbound.rename(columns={'WM_YR_WK': 'wm_yr_wk'})

    calendar_query = """SELECT distinct max(wm_month_nbr) month, max(wm_year_nbr) year, WmYrWk
            -- extract(year from min(WM_DATE)) year, extract(month from min(WM_DATE)) month
            FROM wmt-euclid-prod.geodemand.GeoCalendar
            where CalendarDate between '2022-01-01' and current_date()
            group by WmYrWk"""
    cal_df = read_bq(project, calendar_query)
    cal_df = cal_df.rename(columns={'WmYrWk': 'wm_yr_wk'})

    event_query = """SELECT distinct EventType, EventName, WmYrWk 
            from `wmt-euclid-prod.geodemand.GeoEvent`
            where Year in (2022,2023,2024,2025)
            and EventType in ('National', 'Cultural', 'Sporting')"""
    event_df = read_bq(project, event_query)
    event_df = event_df.rename(columns={'WmYrWk': 'wm_yr_wk'})

    inventory_query = """SELECT * 
            from wmt-one-demand-dev.inventory_placement_sn.historic_inv_nonreplen_2022_2025_3yr_v7
            -- wmt-one-demand-dev.inventory_placement_sn.historic_inv_nonreplen_2022_2024_3yr_v5
            -- wmt-one-demand-dev.inventory_placement_sn.historic_inv_nonreplen_2022_2024_3yr_v4
            -- wmt-one-demand-dev.inventory_placement_sn.historic_inv_nonreplen_2022_2024_3yr_v1
            where r2d2_sub_category_id is not null and r2d2_division_id is not null and r2d2_super_department_id is not null
            and r2d2_department_id is not null and r2d2_category_id is not null"""
    inv_df = read_bq(project, inventory_query)
    inv_df = inv_df.rename(columns={'WmYrWk': 'wm_yr_wk'})

    item_universe_query = """SELECT distinct
			 r2d2_division_id
			, r2d2_super_department_id
			, r2d2_department_id
			, r2d2_category_id
			, r2d2_sub_category_id
			, network_type
			, processed_time
			--, SUM(CASE WHEN apparel_flag = TRUE THEN 1 ELSE 0 END) apparel
			--, sum(case when is_sortable = true and apparel_flag = False then 1 else 0 end) sort
			--, sum(case when is_sortable = false and apparel_flag = False then 1 else 0 end) non_sort
			FROM 
			wmt-sg-prod.us_wmt_placements.item_attributes_v2
			WHERE is_replen = FALSE
			and r2d2_sub_category_id is not null
			and processed_time = (Select max(processed_time) from wmt-sg-prod.us_wmt_placements.item_attributes_v2)
			;"""
    item_univ = read_bq(project, item_universe_query)

    return sales_df, inbound, cal_df, event_df, inv_df, item_univ

def _data_prep(sales_df,inbound, cal_df, inv_df, event_df, group_cols) :
    """Join sales, inbound, inventory, calendar and event data into one weekly frame.

    Steps:
      1. Aggregate sales, inbound and inventory to ``group_cols`` (inbound
         additionally by ``year`` / ``month``).
      2. Pivot events to one 0/1 column per event type plus ``is_event``.
      3. Cross-join the inbound categories with the calendar and left-join
         the aggregated data onto that grid.
      4. Keep, per category, only weeks between its first and last inbound
         week.
      5. Add each row's share of the category total per year
         (``seasonality_yr_*``) and over the kept history (``seasonality_*``).

    Args:
        sales_df: sales rows with ``net_cncl_qty``, ``unit_price_amt``,
            ``unit_price_amt_w`` and ``volume_cu_ft_w``.
        inbound: inbound rows with ``tot_capacity``, ``calc_inbound``,
            ``orderqty``, ``year`` and ``month``.
        cal_df: calendar rows keyed by ``wm_yr_wk`` with ``year`` / ``month``.
        inv_df: inventory rows with ``DaysInv``, ``DaysOOS`` and the seven
            ``InvQty_<day>`` columns.
        event_df: event rows with ``wm_yr_wk`` and ``EventType``.
        group_cols: list of grouping columns; it must contain ``wm_yr_wk``
            because the result is filtered on that column.

    Returns:
        pandas.DataFrame sorted by ``group_cols`` with the numeric measure
        columns cast to float and rounded to 3 decimals.

    Note:
        ``cat_group`` is read from module scope, it is not a parameter.
        Unlike the previous version, ``event_df`` is no longer modified in
        place, a missing event type no longer raises ``KeyError``, and the
        column subsets are copied before new columns are assigned to them.
    """
    event_types = ['National', 'Cultural', 'Sporting']
    inv_day_cols = ['InvQty_Sun', 'InvQty_Mon', 'InvQty_Tue', 'InvQty_Wed',
                    'InvQty_Thu', 'InvQty_Fri', 'InvQty_Sat']
    inventory_cols = ['DaysInv', 'DaysOOS'] + inv_day_cols
    sales_cols = ['sale_qty', 'unit_price_amt', 'unit_price_amt_w', 'volume_cu_ft_w']
    seasonality_cols = ['seasonality_sales', 'seasonality_ib',
                        'seasonality_yr_sales', 'seasonality_yr_ib']

    # --- 1. aggregate every source to the requested grain -------------------
    sales_weekly = sales_df.groupby(group_cols, as_index=False).agg(
        sale_qty=('net_cncl_qty', 'sum'),
        unit_price_amt=('unit_price_amt', 'mean'),
        unit_price_amt_w=('unit_price_amt_w', 'mean'),
        volume_cu_ft_w=('volume_cu_ft_w', 'median'),
    )
    inbound = inbound.groupby(group_cols + ['year', 'month'], as_index=False).agg(
        tot_capacity=('tot_capacity', 'sum'),
        calc_inbound=('calc_inbound', 'sum'),
        order_qty=('orderqty', 'sum'),
    )
    inv_df = inv_df.groupby(group_cols, as_index=False).agg(
        DaysInv=('DaysInv', 'mean'),
        DaysOOS=('DaysOOS', 'mean'),
        **{day_col: (day_col, 'sum') for day_col in inv_day_cols},
    )

    # --- 2. weekly event flags ----------------------------------------------
    # assign() works on a copy so the caller's event_df keeps its columns.
    event_df_pivot = (
        event_df.assign(value=1)
        .pivot_table(index='wm_yr_wk', columns='EventType', values='value', fill_value=0)
        .reset_index()
    )
    for event_type in event_types:
        # An event type with no rows in the data produces no pivot column.
        if event_type not in event_df_pivot.columns:
            event_df_pivot[event_type] = 0
    event_df_pivot['is_event'] = event_df_pivot[event_types].max(axis=1)
    event_df_pivot.columns.name = None

    # --- first / last observed week per category ----------------------------
    sales_weekly['sale_start_wk'] = sales_weekly.groupby(cat_group)['wm_yr_wk'].transform('min')
    inbound['inbound_start_wk'] = inbound.groupby(cat_group)['wm_yr_wk'].transform('min')
    inbound['inbound_end_wk'] = inbound.groupby(cat_group)['wm_yr_wk'].transform('max')
    inv_df['inv_start_wk'] = inv_df.groupby(cat_group)['wm_yr_wk'].transform('min')

    # --- 3. category x calendar grid with all measures attached -------------
    sales_inbound = sales_weekly.merge(inbound, on=group_cols, how='outer')
    sales_inbound_inv = inv_df.merge(sales_inbound, on=group_cols, how='outer')
    categories = inbound[cat_group].drop_duplicates()
    cal_df = cal_df.merge(event_df_pivot, on='wm_yr_wk', how='left')
    cal_categories = pd.merge(categories, cal_df, how='cross')
    df_merge = pd.merge(cal_categories, sales_inbound_inv, how='left', on=group_cols)

    # year / month exist on both sides of the merge; keep the calendar ones.
    df_merge = (
        df_merge.drop(columns=['month_y', 'year_y'])
        .rename(columns={'month_x': 'month', 'year_x': 'year'})
    )
    df = df_merge[
        group_cols + inventory_cols + ['inv_start_wk'] + sales_cols
        + ['sale_start_wk', 'inbound_start_wk', 'inbound_end_wk', 'tot_capacity',
           'calc_inbound', 'year', 'month', 'is_event', 'order_qty']
    ].copy()

    # --- 4. clip every category to its inbound history window ---------------
    df['start_wk'] = df.groupby(cat_group)['inbound_start_wk'].transform('min')
    df['end_wk'] = df.groupby(cat_group)['inbound_end_wk'].transform('max')
    df = df[(df['wm_yr_wk'] >= df['start_wk']) & (df['wm_yr_wk'] <= df['end_wk'])]

    # --- 5. seasonality shares ----------------------------------------------
    yearly_totals = df.groupby(cat_group + ['year'], as_index=False).agg(
        net_yr_sales=('sale_qty', 'sum'),
        net_yr_ib=('calc_inbound', 'sum'),
    )
    cum_sum = df.merge(yearly_totals, on=cat_group + ['year'])
    cum_sum['seasonality_yr_sales'] = np.round(cum_sum['sale_qty'] / cum_sum['net_yr_sales'], 3)
    cum_sum['seasonality_yr_ib'] = np.round(
        cum_sum['calc_inbound'].astype('float') / cum_sum['net_yr_ib'].astype('float'), 3)

    overall_totals = cum_sum.groupby(cat_group, as_index=False).agg(
        net_sales=('sale_qty', 'sum'),
        net_ib=('calc_inbound', 'sum'),
    )
    cum_sum1 = cum_sum.merge(overall_totals, on=cat_group)
    cum_sum1['seasonality_sales'] = np.round(cum_sum1['sale_qty'] / cum_sum1['net_sales'], 3)
    cum_sum1['seasonality_ib'] = np.round(
        cum_sum1['calc_inbound'].astype('float') / cum_sum1['net_ib'].astype('float'), 3)

    # --- final column selection, typing and ordering ------------------------
    df = cum_sum1[
        group_cols + inventory_cols + sales_cols
        + ['tot_capacity', 'calc_inbound', 'year', 'month', 'order_qty', 'is_event']
        + seasonality_cols
    ].copy()
    float_cols = (inventory_cols + sales_cols
                  + ['tot_capacity', 'calc_inbound', 'order_qty'] + seasonality_cols)
    df[float_cols] = df[float_cols].astype('float').round(3)
    df = df.sort_values(by=group_cols)
    return df

def _fill_from_rolling_median(frame, column, keys, window):
    """Fill NaNs of ``column`` with a trailing rolling median computed per ``keys`` group."""
    rolling_median = frame.groupby(keys)[column].transform(
        lambda values: values.rolling(window=window, min_periods=1).median()
    )
    frame[column] = frame[column].fillna(rolling_median)
    return frame


def _fill_from_group_mean(frame, column, keys):
    """Fill NaNs of ``column`` with the mean of the rows that share the same ``keys``."""
    frame[column] = frame.groupby(keys)[column].transform(
        lambda values: values.fillna(values.mean())
    )
    return frame


def _substitute_stockout_sales(group):
    """Add ``sale_qty_new`` to one category's rows.

    ``sale_qty_new`` equals ``sale_qty`` unless the week had ``DaysOOS >= 3``;
    in that case it takes the same week's sales of the previous year when
    that week had ``DaysOOS < 3``, else the sales of two years earlier when
    that week had ``DaysOOS < 3``, else the original ``sale_qty``.
    "Same week" is looked up as ``wm_yr_wk - 100`` / ``wm_yr_wk - 200``.
    """
    group['last_year_wk'] = group['wm_yr_wk'] - 100
    two_years_back_wk = group['wm_yr_wk'] - 200

    history = group.set_index('wm_yr_wk')[['sale_qty', 'DaysOOS']]
    last_yr_sales = group['last_year_wk'].map(history['sale_qty'])
    last_2yr_sales = two_years_back_wk.map(history['sale_qty'])
    last_yr_oos = group['last_year_wk'].map(history['DaysOOS'])
    last_2yr_oos = two_years_back_wk.map(history['DaysOOS'])

    stocked_out = group['DaysOOS'] >= 3
    use_last_year = stocked_out & (last_yr_oos < 3) & last_yr_sales.notna()
    use_two_years = (stocked_out & (last_yr_oos >= 3) & (last_2yr_oos < 3)
                     & last_2yr_sales.notna())

    group['sale_qty_new'] = np.where(
        use_last_year,
        last_yr_sales,
        np.where(use_two_years, last_2yr_sales, group['sale_qty']),
    )
    return group


def _data_cleaning(df_missing):
    """Remove outlier sub-categories and impute inventory, price and stock-out sales.

    Processing order:
      1. Print missing / zero percentages for ``calc_inbound`` and ``order_qty``.
      2. Add ``is_event`` (NaN -> 0), ``last_year_wk`` (``wm_yr_wk - 100``)
         and ``OH_Inv`` (first non-null of Fri, Thu, Wed, Tue, Mon, Sun, Sat
         inventory).  These three columns are written onto the frame that
         was passed in.
      3. Drop every sub-category that owns a row with an extreme z-score in
         ``sale_qty``, ``OH_Inv`` or ``calc_inbound``.
      4. Impute ``OH_Inv`` / ``DaysOOS`` / ``DaysInv``.
      5. Impute ``unit_price_amt`` and ``unit_price_amt_w`` through a cascade
         of rolling medians and hierarchy-level means.
      6. Add ``sale_qty_new`` (see ``_substitute_stockout_sales``).

    Args:
        df_missing: frame produced by ``_data_prep``.

    Returns:
        pandas.DataFrame: the cleaned frame (a new object).

    Note:
        ``cat_group`` and ``group_cols`` are read from module scope.

        Fixes relative to the previous version:
        * outlier rows were located with ``.iloc`` using index *labels*;
          labels differ from positions once the frame has been sorted, so
          the wrong sub-categories could be excluded.  ``.loc`` is used now.
        * the last 104-week fallback of the ``unit_price_amt_w`` cascade
          was applied to ``unit_price_amt`` (copy/paste); it is now applied
          to ``unit_price_amt_w`` as well.
        * the filtered frame is copied before values are assigned to it.
    """
    total_rows = len(df_missing)
    print('Missing inbound values:',round(df_missing['calc_inbound'].isna().sum()*100/total_rows,3))
    print('Zero inbound values:',round(len(df_missing[df_missing['calc_inbound']==0])*100/total_rows,3))
    print('Missing order qty values:',round(df_missing['order_qty'].isna().sum()*100/total_rows,3))
    print('Zero order qty values:',round(len(df_missing[df_missing['order_qty']==0])*100/total_rows,3))

    # --- 2. helper columns ---------------------------------------------------
    df_missing['is_event'] = df_missing['is_event'].fillna(0)
    df_missing['last_year_wk'] = df_missing['wm_yr_wk'] - 100
    on_hand = df_missing['InvQty_Fri']
    for day in ('Thu', 'Wed', 'Tue', 'Mon', 'Sun', 'Sat'):
        on_hand = on_hand.fillna(df_missing[f'InvQty_{day}'])
    df_missing['OH_Inv'] = on_hand

    # --- 3. outlier sub-categories ------------------------------------------
    sub_category = df_missing['r2d2_sub_category_id']

    # Sales: z-score with population std; cut-off is the 99.6th percentile of
    # the z-scores that already exceed 3.
    sale_qty_z = ((df_missing['sale_qty'] - np.nanmean(df_missing['sale_qty']))
                  / np.nanstd(df_missing['sale_qty']))
    sale_cutoff = sale_qty_z[sale_qty_z > 3].quantile(.996)
    flagged = sale_qty_z[sale_qty_z > sale_cutoff].index
    subcat_list_sale = df_missing.loc[flagged, 'r2d2_sub_category_id'].unique()
    print('sales:', len(subcat_list_sale),len(df_missing[~sub_category.isin(subcat_list_sale)])/len(df_missing))

    # Inventory: fixed cut-off of 5 standard deviations.
    inv_z = (df_missing['OH_Inv'] - df_missing['OH_Inv'].mean()) / df_missing['OH_Inv'].std()
    flagged = inv_z[inv_z > 5].index
    subcat_list_inv = df_missing.loc[flagged, 'r2d2_sub_category_id'].unique()
    print('inventory:', len(subcat_list_inv),len(df_missing[~sub_category.isin(subcat_list_inv)])/len(df_missing))

    # Inbound: cut-off is the 95th percentile of the z-scores above 3.
    calc_inbound_z = ((df_missing['calc_inbound'] - df_missing['calc_inbound'].mean())
                      / df_missing['calc_inbound'].std())
    inbound_cutoff = calc_inbound_z[calc_inbound_z > 3].quantile(.95)
    flagged = calc_inbound_z[calc_inbound_z > inbound_cutoff].index
    subcat_list_calc = df_missing.loc[flagged, 'r2d2_sub_category_id'].unique()
    print('inbound:', len(subcat_list_calc),len(df_missing[~sub_category.isin(subcat_list_calc)])/len(df_missing))

    subcat_list = pd.concat([pd.Series(subcat_list_calc), pd.Series(subcat_list_inv),
                             pd.Series(subcat_list_sale)]).unique()
    print('subcat exclusion list:', subcat_list)

    # Copy so that the assignments below do not write into a view of the input.
    df_missing = df_missing[~df_missing['r2d2_sub_category_id'].isin(subcat_list)].copy()

    # --- 4. inventory imputation --------------------------------------------
    # 4a. rows that sold something but have no inventory: carry the last
    #     known value of the category forward.
    mask = df_missing['OH_Inv'].isna() & (df_missing['sale_qty'] > 0)
    if mask.any():
        for column in ('OH_Inv', 'DaysOOS', 'DaysInv'):
            df_missing.loc[mask, column] = df_missing.groupby(cat_group)[column].ffill()
        print(len(df_missing[(df_missing['OH_Inv'].isna()) & (df_missing['sale_qty'] > 0)]))

    # 4b. rows with inbound but no inventory: previous inventory - previous
    #     sales + this week's inbound.
    if df_missing['OH_Inv'].isna().sum() > 0 and (df_missing['calc_inbound'] > 0).any():
        mask = df_missing['OH_Inv'].isna() & (df_missing['calc_inbound'] > 0)
        prev_oh_inv = df_missing.groupby(cat_group)['OH_Inv'].shift(1)
        prev_sale_qty = df_missing.groupby(cat_group)['sale_qty'].shift(1)
        df_missing.loc[mask, 'OH_Inv'] = (prev_oh_inv.fillna(0)
                                          - prev_sale_qty.fillna(0)
                                          + df_missing['calc_inbound'])
        print("Remaining NaNs:", df_missing.loc[mask, 'OH_Inv'].isna().sum())

    # 4c. zero inventory with unknown days-in-stock means a full week out of stock.
    mask = (df_missing['OH_Inv'] == 0) & (df_missing['DaysInv'].isna())
    df_missing.loc[mask, 'DaysInv'] = 0
    df_missing.loc[mask, 'DaysOOS'] = 7

    # --- 5. price imputation -------------------------------------------------
    department_keys = ['r2d2_division_id', 'r2d2_super_department_id', 'r2d2_department_id']
    department_week_keys = department_keys + ['wm_yr_wk']
    category_week_keys = department_keys + ['r2d2_category_id', 'wm_yr_wk']

    # Rolling windows follow row order, so order the rows once; none of the
    # fill steps below re-orders them.
    df_missing = df_missing.sort_values(by = group_cols)

    df_missing = _fill_from_rolling_median(df_missing, 'unit_price_amt', cat_group, 5)
    print(df_missing['unit_price_amt'].isna().sum())
    df_missing = _fill_from_group_mean(df_missing, 'unit_price_amt', category_week_keys)
    print(df_missing['unit_price_amt'].isna().sum())
    df_missing = _fill_from_rolling_median(df_missing, 'unit_price_amt', cat_group, 5)
    df_missing = _fill_from_group_mean(df_missing, 'unit_price_amt', department_week_keys)
    df_missing = _fill_from_rolling_median(df_missing, 'unit_price_amt', department_keys, 104)

    df_missing = _fill_from_rolling_median(df_missing, 'unit_price_amt_w', cat_group, 5)
    df_missing = _fill_from_group_mean(df_missing, 'unit_price_amt_w', category_week_keys)
    df_missing = _fill_from_rolling_median(df_missing, 'unit_price_amt_w', cat_group, 5)
    df_missing = _fill_from_group_mean(df_missing, 'unit_price_amt_w', department_week_keys)
    df_missing = _fill_from_rolling_median(df_missing, 'unit_price_amt_w', cat_group, 5)
    # Department-level fallback for the weighted price (previously missing).
    df_missing = _fill_from_rolling_median(df_missing, 'unit_price_amt_w', department_keys, 104)
    # The previous version repeated the department fallback on unit_price_amt
    # at this point; the repeat is kept so that column is filled exactly as before.
    df_missing = _fill_from_rolling_median(df_missing, 'unit_price_amt', department_keys, 104)

    # --- 6. replace stock-out weeks' sales with a prior year's sales ---------
    # NOTE: pandas >= 2.2 warns that apply() operating on the grouping
    # columns is deprecated; revisit this call when upgrading pandas.
    df_missing = df_missing.groupby(cat_group, group_keys=False).apply(_substitute_stockout_sales)

    return df_missing

def _build_feature(df, group_col, lag_features, lead_features, lags, leads, ly_features, target_feature, forecast_horizon, rolling_windows, time_col):
    
    df = df.sort_values(by = time_col)
    
    # Function to create lag features for any column
    def create_lag_features(df, column, group_col, lags):
        for lag in lags:
            df[f'{column}_lag_{lag}'] = df.groupby(group_col)[column].shift(lag)
        return df
    # Function to create lag features for any column
    def create_lead_features(df, column, group_col, leads):
        for lead in leads:
            df[f'{column}_lead_{lead}'] = df.groupby(group_col)[column].shift(-lead)
        return df
    # Function to create difference features for any column
    def create_diff_features(df, column, group_col, diffs):
        for diff in diffs:
            df[f'{column}_diff_{diff}'] = df.groupby(group_col)[column].diff(diff)
        return df
    # Function to create rolling averages (e.g., 4-week moving average)
    def create_rolling_average(df, column, group_col, weeks):
        for week in weeks:
            df[f'{column}_rolling_{week}w'] = df.groupby(group_col)[column].transform(lambda x: x.rolling(week).mean())
        return df
    # Function to create cumulative sum features for any column
    def create_cumsum_features(df, column, group_col):
        df[f'{column}_cumsum'] = df.groupby(group_col)[column].cumsum()
        return df
    
    # Function to create cumulative sum features for any column
    def create_ly_features(df, column, group_col):
        df_grouped = df.groupby(group_col, group_keys = False)
        df['last_2year_wk'] = df['wm_yr_wk']-200
        
        def create_features(group):
            last_yr_map = group.set_index('wm_yr_wk')[column].to_dict()
            
            group[f'{column}_ly_cur_wk'] = group['last_year_wk'].map(last_yr_map)
            group[f'{column}_l2y_cur_wk'] = group['last_2year_wk'].map(last_yr_map)
            
            for lag in range(1,4):
                group[f'{column}_ly_cur_wk_lag_{lag}'] = group['last_year_wk'
                                                              ].map(lambda x:last_yr_map.get(x-lag, None))
                group[f'{column}_l2y_cur_wk_lag_{lag}'] = group['last_2year_wk'
                                                              ].map(lambda x:last_yr_map.get(x-lag, None))
            for lead in range(1,4):
                group[f'{column}_ly_cur_wk_lead_{lead}'] = group['last_year_wk'
                                                              ].map(lambda x:last_yr_map.get(x+lead, None))
                group[f'{column}_l2y_cur_wk_lead_{lead}'] = group['last_2year_wk'
                                                              ].map(lambda x:last_yr_map.get(x+lead, None))
            
            return group
        
        df = df_grouped.apply(create_features)
        
        return df
    
    
    def add_frequency_features(df, group_col, time_col, value_col):
        """
        Adds frequency domain features to a time-series DataFrame.

        Parameters:
        - df: Input DataFrame.
        - group_col: List of columns to group by (e.g., hierarchical IDs).
        - time_col: Column representing time (e.g., 'wm_yr_wk').
        - value_col: Column for which frequency domain features will be computed (e.g., 'calc_inbound').

        Returns:
        - Updated DataFrame with frequency features.
        """
        def extract_fft_features(group):
            # Ensure the group is sorted by time
            group = group.sort_values(by=time_col)

            # Fill missing values with 0 for FFT computation
            signal = group[value_col].fillna(0).values

            # Apply FFT
            fft_result = np.fft.fft(signal)
            fft_magnitude = np.abs(fft_result)  # Magnitude of the FFT
            fft_freqs = np.fft.fftfreq(len(signal))  # Corresponding frequencies

            # Extract features
            spectral_energy = np.sum(fft_magnitude**2)  # Total energy
            dominant_frequency = fft_freqs[np.argmax(fft_magnitude[:len(fft_freqs)//2])]  # Dominant frequency
            dominant_amplitude = np.max(fft_magnitude[:len(fft_freqs)//2])  # Amplitude of the dominant frequency
            spectral_entropy = -np.sum((fft_magnitude**2 / spectral_energy) * np.log(fft_magnitude**2 / spectral_energy + 1e-8))

            # Add features as new columns
            group[f'{value_col}_spectral_energy'] = spectral_energy
            group[f'{value_col}_dominant_frequency'] = dominant_frequency
            group[f'{value_col}_dominant_amplitude'] = dominant_amplitude
            group[f'{value_col}_spectral_entropy'] = spectral_entropy

            return group
    
        # Apply the function group-wise
        
        valid_grp = df.groupby(group_col)[value_col].count()
        valid_grp = valid_grp[valid_grp>1].index
        df = df[df[group_col].apply(tuple, axis=1).isin(valid_grp)]
        
        df = df.groupby(group_col, group_keys=False).apply(extract_fft_features)

        return df
    
    
#     leads = np.arange(1, 13, 1) # 1-week, 2-week, ... 52-week lags
    diffs = [1, 2, 4] # Week-over-week, 2-week difference, and 4-week difference
    
    # Create lag, lead, and difference features
    for col in lag_features:
        df = create_lag_features(df, col, group_col, lags)
        df = create_diff_features(df, col, group_col, diffs)
        df = create_rolling_average(df, col, group_col, rolling_windows)

    for col in lead_features:
        df = create_lead_features(df, col, group_col, leads)
    
    for col in target_feature:
        df = create_lead_features(df, col, group_col, forecast_horizon)
    
    for col in ly_features:
        df = create_ly_features(df, col, group_col)

    for col in ['calc_inbound']:
        df = add_frequency_features(df, group_col, 'wm_yr_wk', col)
    
    print(f"Columns after feature engineering: {df.columns}")
    
    return df

def generate_hierarchical_features(
    df,
    hierarchy_cols,
    time_col,
    features_to_agg,
    aggs=("mean", "sum"),
    rolling_windows=(4,),
    lags=(1,),
    sort_col=None,
    prevent_leakage=True
):
    """Add features computed over the coarser levels of a hierarchy.

    For every column in ``hierarchy_cols[1:]`` (the first entry is skipped)
    and every column in ``features_to_agg`` the following columns are added:

      * ``<level>_<feat>_<agg>`` - ``aggs`` of the feature per
        ``(level, time_col)``; with ``prevent_leakage`` the feature is first
        shifted by one row inside the level group;
      * ``<level>_<feat>_rolling<w>w`` - mean over the ``w`` previous rows of
        the level group (current row excluded, ``min_periods=1``);
      * ``<level>_<feat>_lag<k>w`` - value ``k`` rows earlier in the level
        group.

    "Rows" are ordered by ``[level, sort_col]``.

    Args:
        df: input frame; it is not modified.
        hierarchy_cols: hierarchy columns, most granular first.
        time_col: time column used as aggregation key.
        features_to_agg: feature columns to process.
        aggs: aggregation names passed to ``agg``.
        rolling_windows: rolling window sizes.
        lags: lag sizes.
        sort_col: column used for ordering; defaults to ``time_col``.
        prevent_leakage: shift the feature by one row before aggregating.

    Returns:
        pandas.DataFrame: the rows of ``df`` in their original order with the
        new columns appended.  The index is a fresh ``RangeIndex`` whenever
        at least one feature was merged in (as before, because ``merge``
        discards the index); otherwise the input index is kept.

    Note:
        Fixes relative to the previous version:
        * lag values were taken from the *sorted* frame and written
          positionally (``.values``) into the *unsorted* output, attaching
          them to the wrong rows;
        * rolling values were aligned by index label against an output whose
          index had been reset by ``merge``.
        Both are now computed on a positionally indexed copy and aligned by
        that index.

        NOTE(review): ``shift(1)`` moves by one *row* of the level group, not
        by one time period.  When a level has several rows per period the
        shifted value can come from the same period - confirm whether the
        shift should happen after aggregating to ``(level, time_col)``.
    """
    aggs = list(aggs)
    sort_col = sort_col or time_col

    # merge() below returns a frame with a fresh RangeIndex; working on a
    # positionally indexed copy keeps every intermediate Series aligned with it.
    base = df.reset_index(drop=True)
    df_out = base.copy()
    merged_any = False

    for level in hierarchy_cols[1:]:  # Skip most granular level
        print(f"\n Processing level: {level}")

        df_sorted = base.sort_values([level, sort_col])
        by_level = df_sorted.groupby(level)

        for feat in features_to_agg:
            shifted = by_level[feat].shift(1) if prevent_leakage else df_sorted[feat]
            shifted_name = f"{feat}_shifted"

            # --- Aggregations ---
            print(f" Aggregating {feat} with {aggs}")
            agg_df = (
                df_sorted.assign(**{shifted_name: shifted})
                .groupby([level, time_col])[shifted_name]
                .agg(aggs)
                .reset_index()
            )
            agg_df.columns = [level, time_col] + [f"{level}_{feat}_{agg}" for agg in aggs]
            df_out = df_out.merge(agg_df, on=[level, time_col], how='left')
            merged_any = True

            # --- Rolling means ---
            for window in rolling_windows:
                print(f" Rolling mean (window={window}) for {feat}")
                df_out[f"{level}_{feat}_rolling{window}w"] = by_level[feat].transform(
                    lambda x: x.shift(1).rolling(window=window, min_periods=1).mean()
                )

            # --- Lag Features ---
            for lag in lags:
                print(f" Lag {lag} for {feat}")
                # Assigned as a Series so pandas aligns it row by row.
                df_out[f"{level}_{feat}_lag{lag}w"] = by_level[feat].shift(lag)

    if not merged_any:
        df_out.index = df.index

    return df_out

from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import MinMaxScaler, StandardScaler

def norm_apply(quantiles, test, num_col, cols):
    """Log-scale the numeric columns of ``test`` against per-column reference values.

    Each column ``c`` in ``num_col`` becomes
    ``log1p(max(c, 0) + 1) / log1p(max(quantiles[c], 1e-6) + 1)``.

    Args:
        quantiles: mapping from column name to its reference value.
        test: frame holding the columns to transform.
        num_col: names of the columns to scale.
        cols: names of the columns appended unchanged.

    Returns:
        pandas.DataFrame with the scaled ``num_col`` columns followed by the
        untouched ``cols`` columns.
    """
    def _scale(series):
        # Floor the reference value so the denominator is never log1p(1) or less.
        reference = max(quantiles[series.name], 1e-6)
        # Negative inputs are treated as zero.
        non_negative = series.clip(lower=0)
        return np.log1p(non_negative + 1) / np.log1p(reference + 1)

    scaled = test[num_col].apply(_scale, axis=0)
    passthrough = test[cols]
    return pd.concat([scaled, passthrough], axis=1)

def _inverse_norm(norm_df, pred, quantiles, num_col, predicted_cols):
    norm_df_denorm = pd.concat([norm_df, pred], axis = 1)
    def inverse_normalize_log(normalized_column, q):
        q_safe = max(q, 1e-6)  # Ensure q is not too small
        out = np.expm1(normalized_column * np.log1p(q_safe+1))-1 
        out[out<=0] = 0
        return out  # Apply the inverse transformation
    norm_df_denorm[num_col] = norm_df_denorm[num_col].apply(lambda col:
                                                            inverse_normalize_log(col, quantiles[col.name]), 
                                                            axis = 0)
    for i in range(1,len(predicted_cols)+1):
        norm_df_denorm[predicted_cols[i-1]] = inverse_normalize_log(pred[f'calc_inbound_lead_{i}_predicted'],
                                                                  quantiles[f'calc_inbound_lead_{i}'])
    return norm_df_denorm

def smape(A,F):
    """Symmetric mean absolute percentage error between actuals and forecasts.

    Each element contributes ``2 * |F - A| / (|A| + |F|)``. Terms that come
    out as NaN (for example when both values are zero) are left out of the
    average. When no valid term remains the function returns 100.

    Args:
        A: Actual values (array-like supporting numpy arithmetic).
        F: Forecast values, same shape as ``A``.

    Returns:
        The error expressed in percent.
    """
    terms = 2 * np.abs(F - A) / (np.abs(A) + np.abs(F))
    valid_count = np.count_nonzero(~np.isnan(terms))
    total = np.nansum(terms)
    if valid_count == 0 and total == 0:
        return 100
    return 100 / valid_count * total


# Pull every raw input, then anchor the run on the latest week present in sales.
sales_df, inbound, cal_df, event_df, inv_df, item_univ = _data_input()
predict_week = sales_df.wm_yr_wk.max()
cal_df = cal_df[cal_df.wm_yr_wk <= predict_week]  # calendar stops at the anchor week
df_prep = _data_prep(sales_df, inbound, cal_df, inv_df, event_df, group_cols)
df_missing = _data_cleaning(df_prep.copy())  # clean a copy; df_prep stays as prepared

# Groups with more than 10 non-null inbound observations.
valid_grp = df_missing.groupby(cat_group)['calc_inbound'].count()
print(np.median(valid_grp), np.mean(valid_grp))
valid_grp = valid_grp[valid_grp>10].index
row_keys = df_missing[cat_group].apply(tuple, axis=1)
testing = df_missing[row_keys.isin(valid_grp)]

# One row per category combination of the item universe, stamped with the anchor week.
univ = item_univ[cat_group].drop_duplicates().assign(wm_yr_wk=predict_week)

# Expand each observed category combination over every calendar week and keep
# only the weeks between its first observed week and the anchor week.
temp = df_missing.copy()
temp['min_week'] = temp.groupby(cat_group)['wm_yr_wk'].transform('min')
temp['max_week'] = predict_week
subcats = temp[cat_group].drop_duplicates().merge(cal_df['wm_yr_wk'], how='cross')

temp = subcats.merge(temp, on=group_cols, how='left')
in_window = temp['wm_yr_wk'].between(temp['min_week'], temp['max_week'])
temp = temp[in_window]

# Combine the universe rows with the history grid; rows lacking a week or a
# sub-category id are discarded.
univ_val = univ.merge(temp, left_on=group_cols, right_on=group_cols, how='outer')
has_keys = univ_val.wm_yr_wk.notna() & univ_val.r2d2_sub_category_id.notna()
univ_val = univ_val[has_keys]

# Feature-engineering configuration passed to _build_feature below.
params_feature = dict(
    lag=np.arange(1, 13),  # lag steps 1..12
    lead=[],
    forecast_horizon=np.arange(1, 27),  # horizon steps 1..26
    rolling_window=[4, 12],  # rolling window sizes

    # Feature sets
    lag_features=[
        'sale_qty_new', 'calc_inbound',
        'unit_price_amt_w',
        'tot_capacity', 'order_qty',
    ],
    lead_features=['sale_qty_new'],  # sale has to be replaced demand
    target_features=['calc_inbound'],
    ly_features=['sale_qty_new', 'calc_inbound'],

    # Categorical grouping
    cat_group=cat_group,
)

df_feats = _build_feature(
    univ_val,
    params_feature['cat_group'],
    params_feature['lag_features'],
    params_feature['lead_features'],
    params_feature['lag'],
    params_feature['lead'],
    params_feature['ly_features'],
    params_feature['target_features'],
    params_feature['forecast_horizon'],
    params_feature['rolling_window'],
    time_col='wm_yr_wk',
)

# Aggregate features for each level of the listed hierarchy columns.
hierarchy_levels = [
    'r2d2_sub_category_id',
    'r2d2_category_id',
    'r2d2_department_id',
    'r2d2_super_department_id',
    'r2d2_division_id',
]
df_feats = generate_hierarchical_features(
    df=df_feats,
    hierarchy_cols=hierarchy_levels,
    time_col='wm_yr_wk',
    features_to_agg=['calc_inbound', 'sale_qty_new', 'unit_price_amt_w'],
    aggs=['mean', 'sum'],
    rolling_windows=[4],
    lags=[1, 2, 4],
    prevent_leakage=True,
)


def _one_lagged_featues(df, group_cols, cols, time_col=None):
    """Add a one-row lag (``<col>_lag1``) of each column in ``cols`` per group.

    The frame is sorted by ``group_cols`` (and by ``time_col`` within each
    group when it is given) before shifting. The sort is stable, so when no
    ``time_col`` is supplied the rows of a group keep the order they had in
    ``df`` and the lag is taken relative to that order.

    Args:
        df: Input DataFrame; it is not modified.
        group_cols: Column label or list of labels identifying a group.
        cols: Labels of the columns to lag.
        time_col: Optional column that orders rows inside a group. Pass it
            whenever ``df`` is not already in chronological order.

    Returns:
        A sorted copy of ``df`` with one ``<col>_lag1`` column per entry of
        ``cols``; the first row of every group holds NaN.
    """
    sort_keys = [group_cols] if isinstance(group_cols, str) else list(group_cols)
    if time_col is not None:
        sort_keys.append(time_col)
    # A non-stable sort may reorder rows inside a group, which would make
    # shift(1) pick an arbitrary row instead of the preceding one.
    df = df.sort_values(by=sort_keys, kind='stable')
    for col in cols:
        df[f"{col}_lag1"] = df.groupby(group_cols)[col].shift(1)
    return df
# One-row lag of the inventory signals within each category group.
cols = ['OH_Inv','DaysInv', 'DaysOOS']
df_feats = _one_lagged_featues(df_feats, params_feature['cat_group'], cols)


volume_cols = ['sale_qty_new', 'calc_inbound']

# Per-group, per-year totals of sales and inbound.
seasonality_year = (
    df_feats.groupby(cat_group + ['year'], as_index=False)[volume_cols]
    .sum()
    .rename(columns={'sale_qty_new': 'net_yr_sales', 'calc_inbound': 'net_yr_ib'})
)

# Per-group totals across all rows.
seasonality_total = (
    df_feats.groupby(cat_group, as_index=False)[volume_cols]
    .sum()
    .rename(columns={'sale_qty_new': 'net_sales', 'calc_inbound': 'net_ib'})
)

df_feats = df_feats.merge(seasonality_year, on=cat_group + ['year'], how='left')
df_feats = df_feats.merge(seasonality_total, on=cat_group, how='left')

# Share of each row in its yearly / overall total, rounded to 3 decimals.
seasonality_specs = {
    'seasonality_yr_sales': ('sale_qty_new', 'net_yr_sales'),
    'seasonality_yr_ib': ('calc_inbound', 'net_yr_ib'),
    'seasonality_sales': ('sale_qty_new', 'net_sales'),
    'seasonality_ib': ('calc_inbound', 'net_ib'),
}
for share_col, (value_col, total_col) in seasonality_specs.items():
    df_feats[share_col] = np.round(df_feats[value_col] / df_feats[total_col], 3)

# Raw and intermediate columns removed from the feature frame.
cols_to_drop = [
    'InvQty_Sun', 'InvQty_Mon', 'InvQty_Tue', 'InvQty_Wed', 'InvQty_Thu', 'InvQty_Fri', 'InvQty_Sat',
    'last_year_wk', 'last_2year_wk',
    'unit_price_amt',
    'sale_qty',
    'min_week', 'max_week',
    'OH_Inv', 'DaysInv', 'DaysOOS',
    'tot_capacity', 'calc_inbound', 'order_qty',
    'net_yr_sales', 'net_yr_ib', 'net_sales', 'net_ib',
]
df_feats = df_feats.drop(columns=cols_to_drop)

# Columns that are not scaled: encoded ids, year, raw group keys, time and event columns.
cols = [f'{cat}_encoded' for cat in cat_group] + ['year'] + cat_group + time_col + events_col
cols_to_exclude = [] # if we want to exclude some columns
num_col = df_feats.columns[~df_feats.columns.isin(cols +cols_to_exclude)]

# Scale the remaining (numeric) columns; keys, events and time columns pass through.
norm_forecast_df = norm_apply(quantiles, df_feats, num_col, cat_group + events_col + time_col)

original_cols = [
    'r2d2_sub_category_id', 'r2d2_division_id', 'r2d2_super_department_id',
    'r2d2_department_id', 'r2d2_category_id', 'year'
]

from joblib import dump, load
# Use the imported `load`; the bare module name `joblib` is not bound by the
# import above.
# NOTE(review): joblib files are pickle-based, so only load artifacts that this
# pipeline produced itself.
label_enc = load('encoders.pkl')

# Map each raw value to the position of its class in the fitted encoder;
# values the encoder has never seen become -1.
for col in original_cols:
    encoder = label_enc[col]
    mapping = {label: code for code, label in enumerate(encoder.classes_)}
    norm_forecast_df[f"{col}_encoded"] = norm_forecast_df[col].map(mapping).fillna(-1).astype(int)

def _cyclic_features(df1):
    """Add sine/cosine encodings of the month and of the week number.

    The month angle uses a period of 12 on ``month - 1``; the week angle uses
    a period of 52 on ``wm_yr_wk % 100 - 1``. The columns ``month_sin``,
    ``month_cos``, ``wk_sin`` and ``wk_cos`` are written into ``df1`` itself.

    Args:
        df1: DataFrame with ``month`` and ``wm_yr_wk`` columns; modified in place.

    Returns:
        The same ``df1`` object, with the four new columns.
    """
    month_angle = (df1['month'] - 1) * (2. * np.pi / 12)
    week_angle = (df1['wm_yr_wk'] % 100 - 1) * (2. * np.pi / 52)
    for prefix, angle in (('month', month_angle), ('wk', week_angle)):
        df1[f'{prefix}_sin'] = np.sin(angle)
        df1[f'{prefix}_cos'] = np.cos(angle)
    return df1

forecast_encoded_df = _cyclic_features(norm_forecast_df)

# Score only the rows of the anchor week; targets, group keys and time columns
# are removed from the model input and remaining gaps are filled with -999.
is_predict_week = forecast_encoded_df.wm_yr_wk == predict_week
forecast_norm_df = forecast_encoded_df[is_predict_week]
forecast_norm_df1 = (
    forecast_norm_df
    .drop(columns=target_cols + cat_group + time_col)
    .fillna(-999)
)

# 26_week_forecast
model26 = load('26_week_forecast_0619.joblib')


raw_predictions = model26.predict(forecast_norm_df1)
y_forecast = pd.DataFrame(
    raw_predictions,
    columns=predicted_cols,
    index=forecast_norm_df1.index,
)

# Bring features and predictions back to their original scale.
y_forecast_denorm1 = _inverse_norm(forecast_norm_df, y_forecast, quantiles, num_col, predicted_cols)

# Left-join onto the universe so every category combination gets a row;
# the universe's week column is kept, the forecast frame's is dropped.
full_set = (
    univ.merge(y_forecast_denorm1, on=cat_group, how='left')
    .rename(columns={'wm_yr_wk_x': 'wm_yr_wk'})
    .drop(columns=['wm_yr_wk_y'])
)

# Work on an explicit copy so the assignments below never write into a slice
# of full_set.
full_set1 = full_set[group_cols + predicted_cols].copy()

# Fill missing predictions with the mean of progressively coarser groups:
# category -> department -> super department -> division -> whole week.
fallback_hierarchy = [
    'r2d2_division_id',
    'r2d2_super_department_id',
    'r2d2_department_id',
    'r2d2_category_id',
]
for depth in range(len(fallback_hierarchy), -1, -1):
    level_keys = fallback_hierarchy[:depth] + ['wm_yr_wk']
    level_mean = full_set1.groupby(level_keys)[predicted_cols].transform('mean')
    # Only missing cells are filled; rows whose key is NaN at this level keep
    # whatever prediction they already have.
    full_set1[predicted_cols] = full_set1[predicted_cols].fillna(level_mean)


# Filter historical data to the appropriate window
# NOTE(review): window_size is subtracted from the week code itself. If the
# last two digits of wm_yr_wk are the week number (as the `% 100` in
# _cyclic_features suggests), 200 steps back two year-digits rather than
# 200 weeks - confirm that this is the intended window.
window_size = 200
history_df = inbound[
    (inbound.wm_yr_wk < predict_week) &
    (inbound.wm_yr_wk >= predict_week - window_size)
]

# Historical inbound per (sub-category, network type) and its share of the
# sub-category total.
disagg_level = (
    history_df.groupby(['r2d2_sub_category_id', 'network_type'])['calc_inbound']
    .sum()
    .reset_index(name='network_type_total')
)
disagg_level['sub_cat_total'] = (
    disagg_level.groupby(['r2d2_sub_category_id'])['network_type_total'].transform('sum')
)
disagg_level['proportions'] = disagg_level['network_type_total']/disagg_level['sub_cat_total']

# Every (sub-category, network type) pair of the item universe, with the number
# of distinct network types per sub-category.
all_combinations = item_univ[['r2d2_sub_category_id','network_type']].drop_duplicates()
all_combinations['network_type_count'] = (
    all_combinations.groupby(['r2d2_sub_category_id'])['network_type'].transform('nunique')
)
print(all_combinations)
full_proportions = all_combinations.merge(
        disagg_level[['r2d2_sub_category_id','network_type', 'proportions']],
        on=['r2d2_sub_category_id','network_type'],
        how='left'
    )
# .copy() so the column assignments below act on an independent frame.
full_proportions = full_proportions[~full_proportions.network_type.isna()].copy()

# Pairs without history get a zero share. Assign the result back instead of
# calling fillna(inplace=True) on the selected column, which does not update
# the frame under pandas Copy-on-Write.
full_proportions['proportions'] = full_proportions['proportions'].fillna(0)
# 'promotion_sum' holds the sum of the shares per sub-category.
full_proportions['promotion_sum'] = (
    full_proportions.groupby(['r2d2_sub_category_id'])['proportions'].transform('sum')
)
def calculate_final_proportion(row):
    """Return the share of a sub-category forecast assigned to one row.

    Args:
        row: Mapping-like record with ``promotion_sum``, ``proportions`` and
            ``network_type_count`` entries.

    Returns:
        ``proportions / promotion_sum`` when the sum is positive, otherwise an
        equal split ``1 / network_type_count`` when the count is positive,
        otherwise 1.0.
    """
    share_total = row['promotion_sum']
    if share_total > 0:
        # Rescale so the shares of a sub-category add up to 1.
        return row['proportions'] / share_total

    type_count = row['network_type_count']
    if type_count > 0:
        # No usable history: split evenly across the network types.
        return 1.0 / type_count

    return 1.0
def _add_wm_weeks(base_week, offset, weeks_per_year=52):
    """Return the week code ``offset`` weeks after ``base_week``.

    Treats the last two digits of the code as a 1-based week number (the same
    reading as the ``% 100`` in _cyclic_features) and rolls over into the next
    year once ``weeks_per_year`` is exceeded.
    NOTE(review): assumes every year has ``weeks_per_year`` weeks; a 53-week
    year needs the fiscal calendar instead - confirm.
    """
    year_part, week_part = divmod(base_week, 100)
    zero_based = week_part - 1 + offset
    return (year_part + zero_based // weeks_per_year) * 100 + zero_based % weeks_per_year + 1


# Make sure both inputs of calculate_final_proportion are numeric; anything
# that cannot be parsed counts as 0.
full_proportions['promotion_sum'] = pd.to_numeric(full_proportions['promotion_sum'], errors='coerce').fillna(0)
full_proportions['network_type_count'] = pd.to_numeric(full_proportions['network_type_count'], errors='coerce').fillna(0)

full_proportions['final_proportion'] = full_proportions.apply(calculate_final_proportion, axis=1)

# Attach the network-type shares to each sub-category forecast and split every
# forecast column by its share.
result = full_set1[['r2d2_sub_category_id','wm_yr_wk']+
                   predicted_cols].merge(full_proportions, on = ['r2d2_sub_category_id'], how = 'left')

result[predicted_cols] = result[predicted_cols].mul(result['final_proportion'], axis=0)

result = result.drop_duplicates()

# One target week code per forecast column. Plain `predict_week + i` would run
# past the last week of the year and yield week codes that do not exist.
output_col = [_add_wm_weeks(predict_week, step) for step in range(1, len(predicted_cols) + 1)]

y_forecast_denorm = result.rename(columns=dict(zip(predicted_cols, output_col)))

# Wide -> long: one row per (id columns, target week) with the demand value.
forecasted_output = y_forecast_denorm.melt(id_vars=[col for col in y_forecast_denorm.columns if col not in output_col],
                                           value_vars = output_col,
                                           var_name = 'wm_yr_wk_1',
                                           value_name = 'demand')

forecasted_output.rename(columns = {'wm_yr_wk':'start_date',
                                    'wm_yr_wk_1':'wm_yr_wk',
                                    'r2d2_sub_category_id': 'item_hierarchy_id'}, inplace = True)

today = pd.Timestamp.today()

# Friday (weekday 4) of the current Monday-Sunday week: a forward step for
# Monday-Thursday, a backward step for Saturday/Sunday, zero on Friday.
friday = today + pd.Timedelta(days=4 - today.weekday())

# Set the column
friday = friday.strftime('%Y-%m-%d')

forecasted_output['item_hierarchy_type']  = 'sub_cat_id'
forecasted_output['start_date'] = friday

dummy = forecasted_output[['item_hierarchy_id','item_hierarchy_type','network_type','wm_yr_wk','demand','start_date']]

# Collapse duplicate key rows to their mean demand.
dummy = dummy.groupby(['item_hierarchy_id', 'item_hierarchy_type', 'network_type', 'wm_yr_wk','start_date'], as_index = False)['demand'].mean()

dummy['start_date'] = dummy['start_date'].astype(str)
# The run's Friday date names the `start_date=` directory of the output path.
output_path = (
    'gs://inbound-forecast-sn/inbound_national_forecast_PANDAS/'
    f'start_date={friday}/inbound_forecast.parquet'
)
dummy.to_parquet(output_path, index=False)