In [0]:
import warnings

from pyspark.sql import DataFrame
from pyspark.sql.functions import col, lit, regexp_replace, when, concat

class DataMasker:
    """
    Utility class for masking sensitive fields in a Spark DataFrame.

    Masking is driven by column name: each supported name has exactly one
    rule (see ``mask``). Requested columns that are absent from the
    DataFrame are skipped; requested columns that are present but have no
    rule are left unchanged and reported with a warning.
    """

    # Java regex (Spark evaluates it on the JVM). Two alternatives:
    #   1. ``\d{3}-\d{2}``            -> the dashed prefix "123-45" of "123-45-6789"
    #   2. ``(?<!\d)\d{5}(?=\d{4}(?!\d))`` -> the first five digits of a
    #      standalone nine-digit run ("123456789"), so undashed SSNs are no
    #      longer passed through in clear text.
    _SSN_PREFIX_PATTERN = r"\d{3}-\d{2}|(?<!\d)\d{5}(?=\d{4}(?!\d))"
    _SSN_PREFIX_MASK = "***-**"
    _CARD_PREFIX_MASK = "**** **** **** "

    def __init__(self, fields_to_mask: list):
        """
        Initialize with a list of fields to mask.

        Parameters:
        - fields_to_mask: List of column names to apply masking rules.
          ``None`` is treated as "no fields".
        """
        # Copy so that later mutation of the caller's list cannot silently
        # change which columns get masked.
        self.fields_to_mask = (
            list(fields_to_mask) if fields_to_mask is not None else []
        )

    def mask(self, df: DataFrame) -> DataFrame:
        """
        Apply masking rules to the specified fields in the DataFrame.

        Supported Fields:
        - 'ssn': Masks the first 5 digits
          ('123-45-6789' -> '***-**-6789', '123456789' -> '***-**6789')
        - 'card_number': Masks all but last 4 characters
          (**** **** **** 1234)

        Null values stay null. Fields not present in ``df`` are skipped.
        Fields present in ``df`` without a masking rule are left unchanged
        and a ``UserWarning`` is emitted so the gap is visible.

        Returns:
        - Masked Spark DataFrame
        """
        rules = {
            "ssn": self._mask_ssn,
            "card_number": self._mask_card_number,
            # Add more masking rules here as needed
        }

        # withColumn on an existing name replaces it in place, so the set of
        # column names does not change while we iterate.
        existing_columns = df.columns

        for field in self.fields_to_mask:
            if field not in existing_columns:
                continue  # Skip if field doesn't exist

            rule = rules.get(field)
            if rule is None:
                warnings.warn(
                    f"No masking rule defined for field {field!r}; "
                    "column left unmasked.",
                    stacklevel=2,
                )
                continue

            df = df.withColumn(field, rule(col(field)))

        return df

    def _mask_ssn(self, column):
        """Return a column expression with the SSN's first five digits masked."""
        # regexp_replace yields null for null input, so no explicit null
        # guard is required.
        return regexp_replace(
            column, self._SSN_PREFIX_PATTERN, self._SSN_PREFIX_MASK
        )

    def _mask_card_number(self, column):
        """Return a column expression keeping only the last four characters."""
        # concat yields null when any argument is null, so null card numbers
        # stay null rather than becoming the bare mask prefix.
        return concat(lit(self._CARD_PREFIX_MASK), column.substr(-4, 4))