# 1. Objective
   The objective of this notebook is to perform missing value treatment for the features listed in the config under the "missing value treatment" section. The notebook applies different missing value treatment techniques to the specified features based on the config parameters.

# 2. Imports

In [0]:
import yaml
import glob
import numpy as np
import pandas as pd
from distutils.command.config import config
from datetime import datetime
import os
import shutil


from snowflake.snowpark.context import get_active_session
from snowflake.snowpark import functions as F
# Grab the Snowpark session that is currently active; every later cell uses `session`
session = get_active_session()

In [None]:
# Report which database and schema the active session currently points at
print("✅ Snowpark Session Initialized Successfully!")
print(f"Current Database: {session.get_current_database()}")
print(f"Current Schema: {session.get_current_schema()}")

# 3. Setup environment

## 3.1. Load Config
Currently, we are reading the uploaded config file in each notebook separately. Eventually, we will have the config in a common stage and load from there so that there is one config file at the end

In [None]:
import yaml

# Stage location of the YAML config used by this notebook
stage_path = "@ORANGE_ZONE_SBX_TA.PUBLIC.CONNECTIONS/config_new_PROD.yaml"

# Read the staged file fully, and always release the stream even if the
# read or decode fails.
stream = session.file.get_stream(stage_path)
try:
    yaml_text = stream.read().decode()
finally:
    stream.close()

# safe_load only builds plain Python containers/scalars from the YAML text
app_config = yaml.safe_load(yaml_text)

## 3.2. Update Output Database, Schema, Table

In [None]:
# Output database and schema come from the "general_inputs" section of the config
general_inputs = app_config["general_inputs"]
output_database, output_schema = (
    general_inputs["output_database"],
    general_inputs["output_schema"],
)
print(output_database, output_schema)

In [None]:
# Point the session at the configured output database and schema, so the
# unqualified output table name below is resolved there.
session.use_database(output_database)
session.use_schema(output_schema)
# Name of the table the treated data is written to in section 7
output_table_name = "PROD_MISSING_VALUE_TREATMENT_OUTPUT"

In [None]:
# Optional check: confirm the session's database and schema
print("✅ Snowpark Session Initialized Successfully!")
print(f"Current Database: {session.get_current_database()}")
print(f"Current Schema: {session.get_current_schema()}")

## 3.3. Capturing necessary variables

In [0]:
general_inputs = app_config["general_inputs"]

# Columns used as the grouping keys when the UDF is applied
modeling_granularity_conf = general_inputs["modeling_granularity"]

# Names of the dependent variable and of the date column
dv_config = general_inputs["dependent_variable"]
ds_config = general_inputs["date_var"]

In [0]:
# Module-level values that `missing_value_treatment_UDF` reads when it runs.
# NOTE(review): the `broadcast_` prefix looks like a carry-over from a Spark
# version of this notebook; here these are plain config values — confirm.
broadcast_date_col = ds_config
broadcast_granularity = modeling_granularity_conf
# Per-technique settings; each entry is expected to carry a 'cols' list
broadcast_algo_params = app_config['data_processing']['missing_value_treatment']

# 4. Utility Functions

## 4.1. Function - Implements the `mean_across_years` method of missing value imputation.

This function is called by `impute_missing_data`, which in turn is called from the UDF.

In [0]:
def mean_across_years(
    df: pd.DataFrame,
    date_col: str,
    numeric_cols: list,
    modeling_granularity: list,
    time_granularity: str = "weekly",
) -> pd.DataFrame:
    """Impute missing values with the mean of the same calendar period across years.

    Rows are grouped by the modeling granularity plus a calendar key (ISO week
    of year for "weekly", day-of-month and month for "daily"). Every missing
    value in ``numeric_cols`` is replaced by its group's mean. Values that are
    still missing afterwards (no observed value in the group) are left as NaN.

    Parameters
    ----------
    df : pd.DataFrame
        The dataframe which contains values for all the variables. It is not
        modified; the function works on a copy.
    date_col : str
        The column in the df dataframe which contains date values
    numeric_cols : list
        The list of columns containing numeric values to impute
    modeling_granularity : list
        The list of columns containing modeling granularity metrics
    time_granularity : str, optional
        The time granularity at which the dataset is grouped, by default "weekly".
        Possible values - 'weekly', 'daily'

    Returns
    -------
    pd.DataFrame
        A new dataframe with the same columns as ``df``, with ``date_col``
        converted to datetime and missing values imputed. Row order follows
        ``df``; the index is the fresh 0..n-1 index produced by the merge.

    Raises
    ------
    ValueError
        if it fails to convert date column to datetime datatype
    ValueError
        if ``time_granularity`` is neither 'weekly' nor 'daily'
    """
    # Scratch column names are deliberately unusual so they cannot overwrite
    # (and later drop) a real column such as "Month" in the caller's data.
    week_key = "__may_week_of_year__"
    day_key = "__may_day__"
    month_key = "__may_month__"
    mean_prefix = "__may_mean__"

    # Work on a copy: the helper columns and the datetime conversion must not
    # leak into the frame the caller passed in.
    work = df.copy()

    if not pd.api.types.is_datetime64_any_dtype(work[date_col]):
        try:
            work[date_col] = pd.to_datetime(work[date_col])
        except Exception as exc:
            raise ValueError("Date column is not datetime. Failed to convert.") from exc

    if time_granularity == "weekly":
        work[week_key] = work[date_col].dt.isocalendar().week
        model_level_col = [week_key]
    elif time_granularity == "daily":
        work[day_key] = work[date_col].dt.day
        work[month_key] = work[date_col].dt.month
        model_level_col = [day_key, month_key]
    else:
        raise ValueError(
            f"Unsupported time_granularity {time_granularity!r}; expected 'weekly' or 'daily'."
        )

    group_keys = modeling_granularity + model_level_col
    mean_names = {col: mean_prefix + col for col in numeric_cols}

    # One row per group holding the mean of every numeric column
    # (missing values are skipped when the mean is computed).
    df_mean = (
        work[group_keys + numeric_cols]
        .groupby(group_keys, as_index=False)
        .mean()
        .rename(columns=mean_names)
    )
    df_combine = work.merge(df_mean, on=group_keys, how="left")

    # Only fill the gaps; observed values are kept untouched.
    for col in numeric_cols:
        df_combine[col] = df_combine[col].fillna(df_combine[mean_names[col]])

    return df_combine.drop(columns=list(mean_names.values()) + model_level_col)

## 4.2. Function - Imputes the missing values

This function implements the missing value imputation methodology across various imputation methods. The function is called from UDF.

In [0]:
def impute_missing_data(
    df: pd.DataFrame,
    col_name: list,
    imputation_type: str,
    arbitrary_value: int = 0,
    window: int = None,
    modeling_granularity: list = None,
    time_granularity: str = None,
    date_col: str = None,
) -> pd.DataFrame:
    """Function to impute or fill the missing data with values based on the imputation type given

    Parameters
    ----------
    df : pd.DataFrame
        The raw dataset which contains the value for all variables. It is
        copied; the input frame is not modified.
    col_name : list
        The list of names of the columns containing numerical values
    imputation_type : str
        The type of imputation based on which the missing values are filled.
        One of "Mean", "Median", "Scalar", "Backward_fill", "Forward_fill",
        "Linear_Interpolation", "Spline_Interpolation", "Mode", "Zero",
        "Rolling_Mean", "Rolling_Median", "Mean_Across_Years".
    arbitrary_value : int, optional
        The value which is used to fill missing values when imputation type is scalar, by default 0
    window : int, optional
        The size of the window used in the rolling mean and median methods, by default None
    modeling_granularity : list, optional
        The list of names of columns of modeling granularity values, by default
        None (treated as an empty list)
    time_granularity : str, optional
        The time granularity values required for mean across years imputation type, by default None
    date_col : str, optional
        The name of the column containing date values, by default None

    Returns
    -------
    pd.DataFrame
        the input dataset after the missing values are imputed based on the given imputation type

    Raises
    ------
    ValueError
        if the window size is not provided for the rolling mean type
    ValueError
        if the window size is not provided for the rolling median type
    ValueError
        if date_col or time_granularity is not provided for the mean across years type
    ValueError
        if the imputation type is not one of them in this function
    """
    cleaned_data = df.copy()

    if imputation_type == "Mean":
        cleaned_data[col_name] = cleaned_data[col_name].fillna(cleaned_data[col_name].mean())
    elif imputation_type == "Median":
        cleaned_data[col_name] = cleaned_data[col_name].fillna(cleaned_data[col_name].median())
    elif imputation_type == "Scalar":
        cleaned_data[col_name] = cleaned_data[col_name].fillna(arbitrary_value)
    elif imputation_type == "Backward_fill":
        # bfill first; the trailing ffill covers gaps at the end of the series
        cleaned_data[col_name] = cleaned_data[col_name].bfill().ffill()
    elif imputation_type == "Forward_fill":
        # ffill first; the trailing bfill covers gaps at the start of the series
        cleaned_data[col_name] = cleaned_data[col_name].ffill().bfill()
    elif imputation_type == "Linear_Interpolation":
        cleaned_data[col_name] = cleaned_data[col_name].interpolate(method='linear').ffill().bfill()
    elif imputation_type == "Spline_Interpolation":
        # NOTE(review): `option='spline'` is not an interpolation method selector;
        # `method` stays at its default ('linear'), so this branch appears to
        # perform linear interpolation. Left unchanged on purpose because a real
        # spline (method='spline', order=...) needs scipy and would change the
        # imputed values — confirm the intended behaviour before switching.
        cleaned_data[col_name] = cleaned_data[col_name].interpolate(option='spline').ffill().bfill()
    elif imputation_type == "Mode":
        modes = cleaned_data[col_name].mode()
        # mode() is empty when every value is missing; there is nothing to fill with then
        if len(modes) > 0:
            cleaned_data[col_name] = cleaned_data[col_name].fillna(modes.iloc[0])
    elif imputation_type == "Zero":
        cleaned_data[col_name] = cleaned_data[col_name].fillna(0)
    elif imputation_type in ("Rolling_Mean", "Rolling_Median"):
        statistic = "mean" if imputation_type == "Rolling_Mean" else "median"
        if window is None:
            raise ValueError(f"Window Size not provided for rolling {statistic}.")
        roller = cleaned_data[col_name].rolling(window, min_periods=1)
        temp = roller.mean() if statistic == "mean" else roller.median()
        # Positional replacement: take the rolling statistic only where the value is missing
        cleaned_data[col_name] = np.where(cleaned_data[col_name].isna(), temp, cleaned_data[col_name])
        # A window holding no observed value yields NaN; fall back to neighbours
        cleaned_data[col_name] = cleaned_data[col_name].ffill().bfill()
    elif imputation_type == "Mean_Across_Years":
        if date_col is None:
            raise ValueError("date_col must be provided for Mean_Across_Years imputation.")
        if time_granularity is None:
            raise ValueError("time_granularity must be provided for Mean_Across_Years imputation.")
        granularity_cols = modeling_granularity if modeling_granularity is not None else []
        cleaned_data = mean_across_years(cleaned_data, date_col, col_name, granularity_cols, time_granularity)
        # Groups with no observed value stay NaN after the group mean; fill from neighbours
        cleaned_data[col_name] = cleaned_data[col_name].ffill().bfill()
    else:
        raise ValueError("Incorrect imputation type")
    return cleaned_data

# 5. Load Data

- if missing date treatment is performed read the respective output
- else read the harmonized data

In [0]:
# Source table: the missing-date-treatment output when that step is enabled in
# the config, otherwise the data-harmonization output.
source_table = (
    "ORANGE_ZONE_SBX_TA.PUBLIC.PROD_MISSING_DATE_TREATMENT_OUTPUT"
    if app_config["data_processing"]["missing_dates_treatment_needed"]
    else "ORANGE_ZONE_SBX_TA.PUBLIC.PROD_ADS_STABLE_V4"
)
df = session.table(source_table)

Input file path: /dbfs/mnt/solutionsadls2_data/Baseline_Forecasts/processed/Data_Processing/Missing_Date_Treatment/Missing_Dates_Treatment_Results (2025-08-13-10-46-01)


In [None]:
# Show the loaded Snowpark DataFrame as the cell output
df


# 6. Snowpark UDF Code

Code related to applying a pandas UDF on the Snowpark DataFrame.

## 6.1. UDF for Missing Value Treatment

The function API to be used as a pandas-UDF.

In [0]:
def missing_value_treatment_UDF(df: pd.DataFrame) -> pd.DataFrame:
    """Treat the missing values of one group of rows according to the config.

    Reads the module-level ``broadcast_date_col``, ``broadcast_granularity`` and
    ``broadcast_algo_params`` values. The rows are sorted by date, then every
    configured technique that has at least one column listed under ``cols`` is
    applied in config order through ``impute_missing_data``.

    Parameters
    ----------
    df : pd.DataFrame
        the raw dataset which contains the value for all variables

    Returns
    -------
    pd.DataFrame
        Returns the input dataframe after the missing values in them are treated
    """
    df = df.sort_values(by=[broadcast_date_col], ascending=True)

    algo_params = broadcast_algo_params
    modeling_granularity = broadcast_granularity

    # Techniques that honour the `zero_as_missing_value` flag. Forward_fill,
    # Mode and Zero never consult it.
    # NOTE(review): Forward_fill ignoring the flag while Backward_fill honours
    # it looks inconsistent — confirm whether that is intended.
    zero_aware_algos = {
        'Rolling_Mean', 'Rolling_Median', 'Scalar', 'Backward_fill',
        'Linear_Interpolation', 'Spline_Interpolation', 'Mean', 'Median',
        'Mean_Across_Years',
    }
    supported_algos = zero_aware_algos | {'Forward_fill', 'Mode', 'Zero'}

    # Only techniques with at least one configured column are applied
    req_params = {
        algo: params for algo, params in algo_params.items() if len(params['cols']) > 0
    }

    for algo, params in req_params.items():
        if algo not in supported_algos:
            # Unknown config keys are skipped without raising
            continue

        cols = params['cols']

        # Technique-specific keyword arguments for impute_missing_data
        extra_args = {}
        if algo in ('Rolling_Mean', 'Rolling_Median'):
            extra_args['window'] = int(params['window'])
        elif algo == 'Scalar':
            # int() drops any fractional part of the configured value
            extra_args['arbitrary_value'] = int(params['value'])
        elif algo == 'Mean_Across_Years':
            extra_args['modeling_granularity'] = modeling_granularity
            extra_args['time_granularity'] = params['time_granularity']
            # broadcast_date_col is a plain value, not a Spark broadcast object,
            # so it is passed directly (it has no `.value` attribute).
            extra_args['date_col'] = broadcast_date_col

        # Optionally treat zeros as missing before imputing
        if algo in zero_aware_algos and params['zero_as_missing_value'] == True:
            df[cols] = df[cols].replace(0, np.nan)

        df = impute_missing_data(df, cols, algo, **extra_args)

    # Convert numeric columns to native Python types to avoid Snowflake Decimal conversion errors
    for col in df.select_dtypes(include=[np.integer, np.floating]).columns:
        if np.issubdtype(df[col].dtype, np.integer):
            df[col] = df[col].astype('int')  # or 'float' if decimals are expected
        elif np.issubdtype(df[col].dtype, np.floating):
            df[col] = df[col].astype('float')

    return df

In [None]:
# Columns handed to the UDF: grouping keys, the date column, then every column
# that has a treatment configured.
req_cols = modeling_granularity_conf + [ds_config]
for treated_cols in (
    params['cols'] for params in broadcast_algo_params.values() if len(params['cols']) > 0
):
    req_cols.extend(treated_cols)

# A column listed under more than one treatment would otherwise be selected
# twice, producing duplicate column names; keep the first occurrence only.
req_cols = list(dict.fromkeys(req_cols))
req_cols

## 6.2. Call UDF

Applying the UDF on the input Snowpark DataFrame.

In [None]:
# Show the configured modeling-granularity columns
modeling_granularity_conf

In [None]:
# Build expressions to count NULLs per column
null_counts = [
    F.sum(F.when(F.col(c).is_null(), 1).otherwise(0)).alias(c)
    for c in req_cols
]

# Aggregate once across all rows
missing_result_df = df.agg(*null_counts)
missing_result_df

In [None]:
# Row count of the input data before treatment
df.count()

In [None]:
# Cast the dependent variable to float before it goes through the UDF.
# NOTE(review): presumably so imputed fractional values fit the UDF output
# schema derived from `df` — confirm.
df = df.with_column(dv_config, F.col(dv_config).cast("float"))

In [None]:
# Print the grouping keys that the UDF is about to be applied over
print("Modeling granularity ->",modeling_granularity_conf)

In [0]:
# Keep only the columns the UDF needs; the schema of this projection is also
# declared as the UDF's output schema.
udf_input = df.select(req_cols)
results_s = udf_input.group_by(modeling_granularity_conf).applyInPandas(
    missing_value_treatment_UDF,
    output_schema=udf_input.schema,
)

In [0]:
# Show the treated Snowpark DataFrame as the cell output
results_s

In [None]:
# Build expressions to count NULLs per column
null_counts = [
    F.sum(F.when(F.col(c).is_null(), 1).otherwise(0)).alias(c)
    for c in results_s.columns
]

# Aggregate once across all rows
missing_result_df = results_s.agg(*null_counts)
missing_result_df

In [None]:
# Row count of the treated data, for comparison with the input row count
results_s.count()

In [None]:
# Treated columns = required columns minus the grouping keys and the date column
key_cols = set(modeling_granularity_conf + [ds_config])
cleaned_cols = [col_name for col_name in req_cols if col_name not in key_cols]

In [None]:
# Drop the untreated versions of the treated columns from the input, then
# attach the treated values back on the grouping keys plus the date column.
join_keys = [*modeling_granularity_conf, ds_config]
untreated_part = df.select([col_name for col_name in df.columns if col_name not in cleaned_cols])
df_cleaned = untreated_part.join(results_s, join_keys, how = "left")
df_cleaned.count()

# 7. Store Results

A `LOAD_TS` timestamp column is added to the output table because the notebook might be executed multiple times. This helps identify outputs from the most recent run.

In [0]:
# Stamp every row with the load time, then replace the output table's contents
results_s_with_ts = df_cleaned.with_column("LOAD_TS", F.current_timestamp())
results_s_with_ts.write.mode("overwrite").save_as_table(output_table_name)

# Count the rows from the table that was just written. Counting
# `results_s_with_ts` would re-run the whole lazy plan (UDF and join) a second
# time; after an overwrite the table holds exactly the rows that were written.
print(session.table(output_table_name).count())

File stored successfully.


In [None]:
# Read the stored table back and report its total row count
test_df = session.table(output_table_name)
total_rows = test_df.count()
print("All data ->", total_rows)

# Most recent load timestamp present in the table
max_ts_row = test_df.select(F.max("LOAD_TS")).collect()[0]
latest_ts = max_ts_row[0]

# Rows that belong to that most recent load
is_latest_load = F.col("LOAD_TS") == F.lit(latest_ts)
latest_data = test_df.filter(is_latest_load)
print("Latest data ->", latest_data.count())