## 0) Config

In [0]:
from pyspark.sql import functions as F, types as T
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.stat import Correlation
import numpy as np
import plotly.express as px
import plotly.graph_objects as go

## 1) Load Raw Data

In [0]:
# Load the source table into `raw_df` (`spark` is presumably the session injected by the notebook runtime — TODO confirm).
# NOTE(review): the cells below operate on `df`, which is not assigned from `raw_df` in this cell.
raw_df = spark.table("XXX") 

## 2) Recoding

In [0]:
# Derive date from datepart columns and convert to date type
def first_day_of_period_expr(colname: str):
    """Build a Spark column expression giving the first day of a period code.

    All non-digit characters are stripped from ``colname`` and the remaining
    digit string is interpreted by its length:

    * 8 digits - ``YYYYMMDD``; a day part of ``"00"`` is replaced by ``"01"``
    * 6 digits - ``YYYYMM``  -> first day of that month
    * 5 digits - ``YYYYQ``   -> first day of the quarter's first month
    * 4 digits - ``YYYY``    -> 1 January of that year
    * any other length       -> null

    Args:
        colname: Name of the column holding the period code.

    Returns:
        A Column produced by ``F.to_date(..., "yyyyMMdd")``.
    """
    digits = F.regexp_replace(F.col(colname), r"\D", "")
    n_digits = F.length(digits)

    year_part = digits.substr(1, 4)
    month_part = digits.substr(5, 2)
    day_part = digits.substr(7, 2)

    # quarter -> first month of that quarter (1 -> 01, 2 -> 04, 3 -> 07, 4 -> 10)
    quarter = digits.substr(5, 1).cast("int")
    quarter_month = F.lpad(((quarter - 1) * 3 + 1).cast("string"), 2, "0")

    # Candidate yyyyMMdd strings, one per supported input layout.
    first_of_month = F.concat(year_part, month_part, F.lit("01"))
    first_of_quarter = F.concat(year_part, quarter_month, F.lit("01"))
    first_of_year = F.concat(year_part, F.lit("0101"))
    # YYYYMMDD is kept as-is unless the day is the "00" placeholder (YYYYMM00).
    full_or_first = F.when(day_part == "00", first_of_month).otherwise(digits)

    date_str = (
        F.when(n_digits == 8, full_or_first)
         .when(n_digits == 6, first_of_month)
         .when(n_digits == 5, first_of_quarter)
         .when(n_digits == 4, first_of_year)
         .otherwise(F.lit(None))
    )
    return F.to_date(date_str, "yyyyMMdd")

# Overwrite every date-part column with its parsed first-day-of-period date.
# NOTE(review): `df` and `datepart_cols` are not defined in the visible cells — confirm where they come from.
for c in datepart_cols:
    parsed_date = first_day_of_period_expr(c)
    df = df.withColumn(c, parsed_date)

In [0]:
# Decode column names
df_col_lookup = spark.sql("SELECT * FROM XXX")

# Build {coded header -> readable name} from the lookup rows. Rows missing either
# field, or with a null on either side, are ignored so they cannot break the rename.
col_map = {}
for row in df_col_lookup.collect():
    entry = row.asDict()
    old, new = entry.get('Header Name'), entry.get('Name')
    if old is not None and new is not None:
        col_map[old] = new

# Rename in a single pass. Renaming one column at a time lets an earlier rename
# feed a later one (A -> B followed by B -> C), which makes the result depend on
# lookup order; mapping the original header list once avoids that.
df = df.toDF(*[col_map.get(c, c) for c in df.columns])

## 3) Header Cleansing

In [0]:
# Per-character header rewrite table:
#   separators (".", " ", non-breaking space) -> "_"
#   "%"                                       -> "pct"
#   brackets and punctuation                  -> removed
# The original chain stripped "[" twice and never stripped "]"; both are covered here.
# None of the replacement strings contains a mapped character, so a single
# translate pass gives the same result as applying the replacements one by one.
_HEADER_CHAR_MAP = str.maketrans({
    ".": "_",
    " ": "_",
    "\xa0": "_",
    "%": "pct",
    "/": "",
    "(": "",
    ")": "",
    "[": "",
    "]": "",
    "-": "",
    ",": "",
    ";": "",
    ":": "",
    "{": "",
    "}": "",
    "=": "",
})

# NOTE(review): distinct headers can collapse to the same cleaned name (e.g. "a.b" and "a b") — verify uniqueness if it matters downstream.
df = df.toDF(*[c.translate(_HEADER_CHAR_MAP) for c in df.columns])

## 4) Missingness Handling

In [0]:
# Null profile
n_rows = df.count()

# One aggregate per column: (number of null rows) / (total rows), aliased to the column name.
null_exprs = [
    (F.count(F.when(F.col(col_name).isNull(), F.lit(col_name))) / n_rows).alias(col_name)
    for col_name in df.columns
]

# Single-row frame holding the null fraction of every column.
null_df = df.agg(*null_exprs)

# Unpivot that row into one (col, pct_null) record per column.
profile_names = F.array(*[F.lit(col_name) for col_name in null_df.columns])
profile_values = F.array(*[F.col(col_name) for col_name in null_df.columns])
profile_map = F.map_from_arrays(profile_names, profile_values)
null_dist = null_df.select(F.explode(profile_map).alias("col", "pct_null"))

# Profile sorted from most to least sparse.
result = null_dist.orderBy(F.desc("pct_null"))

# Columns that are null in more than half of the rows.
sparse_rows = null_dist.filter(F.col("pct_null") > 0.5).collect()
cols = [row["col"] for row in sparse_rows]

# Drop sparse columns
df = df.drop(*cols)

## 5) Constant Handling

In [0]:
# Constant columns

# Only columns of these types are profiled for constancy.
orderable = {
    T.StringType,
    T.IntegerType,
    T.LongType,
    T.FloatType,
    T.DoubleType,
    T.DateType,
}

cols = [field.name for field in df.schema.fields if type(field.dataType) in orderable]

# Two aggregates per profiled column, emitted as <col>__min followed by <col>__max.
aggs = [
    bound_fn(F.col(c)).alias(f"{c}__{suffix}")
    for c in cols
    for suffix, bound_fn in (("min", F.min), ("max", F.max))
]

# All bounds are computed in one aggregation and compared on the driver.
stats = df.agg(*aggs).collect()[0].asDict()

# A column is constant when its smallest and largest values coincide
# (two None bounds also compare equal, so such a column is dropped as well).
const_cols = list(filter(lambda c: stats[f"{c}__min"] == stats[f"{c}__max"], cols))

df = df.drop(*const_cols)

## 6) Impute Categoricals

In [0]:
# String-typed columns are treated as categoricals.
# (`StructType` exposes its members as `.fields`; `.field` raises AttributeError.)
cat_cols = [f.name for f in df.schema.fields if isinstance(f.dataType, T.StringType)]

# Replace nulls in the categorical columns with an explicit sentinel level.
# `cat_imputer` is the resulting DataFrame; `df` itself is left unchanged.
cat_imputer = df.fillna({c: "__MISSING__" for c in cat_cols})

## 7) Colinearity Handling

In [0]:
# Collinearity pruning: for every numeric column pair with |Pearson r| >= 0.90,
# drop one member of the pair (see the selection rule in the loop below).
# NOTE(review): numeric columns are discovered on `cat_imputer`, but the data is read
# from `df` and the final drop is applied to `df`, so the categorical imputation held
# in `cat_imputer` is not carried into `df` here — confirm this is intended.
numeric_cols = [f.name for f in cat_imputer.schema.fields if isinstance(f.dataType, T.NumericType)]

df_num = df.select(*[F.col(c).cast("double").alias(c) for c in numeric_cols])

# Approximate per-column medians (percentile_approx accuracy argument = 1/0.01 = 100).
meds = (df_num.select([
            F.expr(f'percentile_approx(`{c}`, 0.5, {int(1/0.01)})')
            .alias(c) for c in df_num.columns])
            .first().asDict())

# Median-impute so the correlation is computed on null-free columns.
df_work = df_num.fillna(meds).cache()

# Materialise the cache. fillna keeps every row, so this is also the row count of
# df_num and a second full count is unnecessary.
n_total = df_work.count()

# Non-null count per column, taken on the un-imputed frame.
nn_row = (df_num.agg(*[
            F.sum(F.when(F.col(c).isNotNull(), 1)
            .otherwise(0)).alias(c) for c in df_num.columns])
            .collect()[0].asDict())

missing_frac = {c: 1.0 - (nn_row.get(c, 0) / max(1, n_total)) for c in df_num.columns}

# handleInvalid="skip": rows that still hold invalid values after imputation are left out.
va = VectorAssembler(inputCols=df_num.columns, outputCol="features", handleInvalid="skip")
vecdf = va.transform(df_work).select("features").cache(); _ = vecdf.count()

corr_mat = Correlation.corr(vecdf, "features", method="pearson").head()[0].toArray()
abs_corr = np.abs(corr_mat)

# The correlation matrix now lives on the driver; release the cached intermediates
# instead of leaving them pinned in executor storage.
vecdf.unpersist()
df_work.unpersist()

to_drop = set()
cols = df_num.columns
upper = np.triu(abs_corr, k=1)  # strict upper triangle: each unordered pair once, no diagonal
pairs = np.argwhere(upper >= 0.90)
# Resolve the most strongly correlated pairs first.
pairs = sorted(pairs, key=lambda ij: upper[ij[0], ij[1]], reverse=True)

for i, j in pairs:
    c1, c2 = cols[i], cols[j]
    # A pair is already resolved once either member has been dropped.
    if c1 in to_drop or c2 in to_drop:
        continue
    m1, m2 = missing_frac.get(c1, 0.0), missing_frac.get(c2, 0.0)
    # Drop the column with more missing values; on a tie drop c1 when its name is
    # at least as long as c2's, otherwise drop c2.
    drop = c1 if (m1 > m2 or (m1 == m2 and len(c1) >= len(c2))) else c2
    to_drop.add(drop)

df = df.drop(*to_drop).cache()