In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, regexp_replace, lit, when, concat

spark = SparkSession.builder.appName("MaskSensitiveFields").getOrCreate()

# Sample rows: (name, ssn, card_number). None marks a missing value.
data = [
    ("John Doe", "123-45-6789", "4111111111111111"),
    ("Jane Smith", "987-65-4321", "5500000000000004"),
    ("Alice Brown", None, "340000000000009"),
    ("Bob White", "111-22-3333", None),
]

columns = ["name", "ssn", "card_number"]
# Declare every column explicitly as a (nullable) string rather than relying on
# schema inference, which cannot determine a type for a column whose values are
# all None and would otherwise depend on whatever the sample happens to contain.
schema = ", ".join(f"{c} string" for c in columns)
df = spark.createDataFrame(data, schema)

In [0]:
def mask_sensitive_fields(df, fields_to_mask):
    """Return ``df`` with each listed sensitive column replaced by a masked value.

    Supported fields and their masks:

    * ``"ssn"`` -- every digit except the last four becomes ``*`` while any
      separators are kept, so ``"123-45-6789"`` -> ``"***-**-6789"`` and an
      unformatted ``"123456789"`` -> ``"*****6789"``.
    * ``"card_number"`` -- ``"**** **** **** "`` followed by the last four
      characters of the value.

    Null values stay null in both columns.

    Args:
        df: Spark DataFrame that contains every column named in ``fields_to_mask``.
        fields_to_mask: Iterable of column names to mask.

    Returns:
        A new DataFrame with the listed columns masked; the input DataFrame is
        not modified (``withColumn`` returns a new DataFrame).

    Raises:
        ValueError: If a requested field has no masking rule. Failing loudly
            prevents a sensitive column from silently passing through unmasked.
    """
    maskers = {
        # Mask a digit when at least four more digits follow it, i.e. all but
        # the last four digits, regardless of how the SSN is formatted.
        "ssn": lambda c: regexp_replace(c, r"\d(?=(?:\D*\d){4})", "*"),
        # substr(-4, 4) takes the last four characters; concat() returns null
        # when any argument is null, so null card numbers stay null.
        "card_number": lambda c: concat(lit("**** **** **** "), c.substr(-4, 4)),
    }

    # Materialize once so a one-shot iterator is not exhausted by validation.
    fields = list(fields_to_mask)
    unsupported = [field for field in fields if field not in maskers]
    if unsupported:
        raise ValueError(
            f"No masking rule for field(s) {unsupported}; "
            f"supported fields: {sorted(maskers)}"
        )

    for field in fields:
        df = df.withColumn(field, maskers[field](col(field)))
    return df

In [0]:
fields_to_mask = ["ssn", "card_number"]
masked_df = mask_sensitive_fields(df, fields_to_mask)

# Invariant: no original (non-null) value of a masked column may survive masking.
# Compared as per-column sets so the check does not depend on row order.
# collect() pulls values to the driver -- fine for this small sample only.
for field in fields_to_mask:
    raw_values = {row[field] for row in df.select(field).collect() if row[field] is not None}
    masked_values = {row[field] for row in masked_df.select(field).collect()}
    assert raw_values.isdisjoint(masked_values), f"unmasked values leaked in column {field!r}"

masked_df.show(truncate=False)