In [0]:
# mimic_mod/prep: clean MIMIC-IV note text, derive anxiety labels from ICD diagnosis titles, and join notes with those labels
import pyspark.sql.functions as F

# De-identification placeholder pattern: matches bracketed spans of the form
# "[** ... **]". The lazy ".*?" makes adjacent placeholders on one line match
# separately instead of swallowing the text between them.
# NOTE(review): Spark's regexp_replace uses Java regex syntax, where "." does
# not match line terminators, so a placeholder split across lines is not matched.
# NOTE(review): MIMIC-IV notes may mask PHI differently (e.g. "___") than this
# bracketed format; confirm this pattern actually fires on the notes used here.
PHI_RE = r"\[\*\*.*?\*\*\]"

def filter_valid_text(df, text_col="text", min_len=10):
    """Drop rows whose text is null or too short to be useful.

    Args:
        df: Spark DataFrame containing ``text_col``.
        text_col: name of the text column to test.
        min_len: length threshold. A row survives only if its trimmed text is
            strictly LONGER than this (the bound is exclusive).

    Returns:
        The filtered DataFrame; columns are unchanged.
    """
    body = F.col(text_col)
    present = body.isNotNull()
    # Measure after trimming so whitespace-only notes are treated as empty.
    long_enough = F.length(F.trim(body)) > min_len
    return df.filter(present).filter(long_enough)

def clean_text(df, text_col="text", out_col="text_clean"):
    """Add a normalized copy of a text column.

    Applied in order:
      1. every span matching ``PHI_RE`` is replaced with a single space;
      2. each run of whitespace is collapsed to one space;
      3. leading/trailing spaces are stripped;
      4. the result is lowercased.

    Args:
        df: Spark DataFrame containing ``text_col``.
        text_col: name of the raw text column (read only).
        out_col: name of the column to write; replaced if it already exists.

    Returns:
        ``df`` with ``out_col`` added. A null input text yields a null output.
    """
    cleaned = F.regexp_replace(F.col(text_col), PHI_RE, " ")
    cleaned = F.regexp_replace(cleaned, r"\s+", " ")
    # After collapsing, a note that starts/ends with whitespace or a PHI tag
    # would otherwise keep a stray leading/trailing space.
    cleaned = F.trim(cleaned)
    # Build one composed expression instead of three overwriting withColumn
    # passes; the previous trailing .alias() inside withColumn was a no-op.
    return df.withColumn(out_col, F.lower(cleaned))

def anxiety_label_from_icd(df_icd, title_col="long_title"):
    """Derive a per-admission binary anxiety label from diagnosis titles.

    An admission gets 1 if any of its rows has a title containing the
    substring "anxiety" (case-insensitive), otherwise 0. A null title does
    not satisfy the condition and so counts as 0.

    Args:
        df_icd: Spark DataFrame with a ``hadm_id`` column and ``title_col``
            (presumably diagnoses joined to their ICD descriptions; confirm
            upstream).
        title_col: column holding the diagnosis title text.

    Returns:
        DataFrame with one row per ``hadm_id`` and an integer ``label_anxiety``.
    """
    mentions_anxiety = F.lower(F.col(title_col)).contains("anxiety")
    per_row_flag = F.when(mentions_anxiety, F.lit(1)).otherwise(F.lit(0))
    flagged = df_icd.withColumn("anxiety_flag", per_row_flag)
    # max over 0/1 flags == "any row flagged" for the admission.
    label = F.max("anxiety_flag").alias("label_anxiety")
    return flagged.groupBy("hadm_id").agg(label)

def join_notes_labels(df_notes, df_labels, how="left"):
    """Attach per-admission labels to notes on ``hadm_id``.

    Args:
        df_notes: Spark DataFrame of notes with a ``hadm_id`` column.
        df_labels: Spark DataFrame with ``hadm_id`` and ``label_anxiety``.
        how: join type forwarded to ``DataFrame.join`` (default ``"left"``).

    Returns:
        The joined DataFrame. Any null ``label_anxiety`` is set to 0, which
        covers notes without a matching admission in a left join.
    """
    joined = df_notes.join(df_labels, on="hadm_id", how=how)
    return joined.fillna({"label_anxiety": 0})
