In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lag, lead, lit, concat
from pyspark.sql.window import Window

# Make sure a SparkSession exists before `spark` is used below. getOrCreate()
# returns the already-running session if there is one (on Databricks one is
# normally pre-created), so this is safe to re-run; without it this cell would
# depend on a session that the notebook only creates in a later cell.
spark = SparkSession.builder.getOrCreate()

# Sample sales data: one row per (Product, Month) with that month's revenue.
# Month is a zero-padded "YYYY-MM" string, so sorting it as a string is also
# chronological order.
data = [
    ("ProductA", "2024-01", 100),
    ("ProductA", "2024-02", 120),
    ("ProductA", "2024-03", 90),
    ("ProductB", "2024-01", 200),
    ("ProductB", "2024-02", 210),
    ("ProductB", "2024-03", 200),
]

columns = ["Product", "Month", "Revenue"]
df = spark.createDataFrame(data, columns)

In [0]:
# Per-product time ordering: each Product is its own window partition, and rows
# inside it are sorted by Month.
windowSpec = Window.partitionBy("Product").orderBy("Month")

# Attach the neighbouring months' revenue within each product:
#   Prev_Revenue -> Revenue of the row one step earlier (lag)
#   Next_Revenue -> Revenue of the row one step later (lead)
# With no default supplied, both are null at the edges of a partition.
df_with_lag_lead = (
    df
    .withColumn("Prev_Revenue", lag(col("Revenue"), 1).over(windowSpec))
    .withColumn("Next_Revenue", lead(col("Revenue"), 1).over(windowSpec))
)

In [0]:
# Revenue_Gap = (Revenue - Prev_Revenue) / Revenue * 100, rendered as text with a
# trailing "%". Rows without a previous month (null Prev_Revenue) get a null gap,
# because both the arithmetic and concat() propagate nulls.
# NOTE(review): the denominator is the *current* month's Revenue; a conventional
# month-over-month % change divides by Prev_Revenue — confirm which is intended.
# NOTE(review): the value is an unrounded double cast to string; storing it as
# text also makes it non-numeric for downstream consumers.
_gap_pct = ((col("Revenue") - col("Prev_Revenue")) / col("Revenue")) * lit(100)

df_with_lag_lead_wri = (
    df_with_lag_lead
    .withColumn("Revenue_Gap", concat(_gap_pct.cast("string"), lit("%")))
    # Constant placeholder column: every row carries the literal string "Flag".
    .withColumn("Flag", lit("Flag"))
)

In [0]:
from pyspark.sql import SparkSession, DataFrame

# Obtain the SparkSession: getOrCreate() reuses an already-running session if one
# exists. validate_schema below reads table schemas through this module-level `spark`.
spark = SparkSession.builder.getOrCreate()

def validate_schema(
    df: DataFrame,
    table_name: str,
    strict: bool = True,
    case_sensitive: bool = True,
) -> bool:
    """
    Check that a DataFrame's column *names* line up with a catalog table's columns.

    Only names are compared. Column order and data types are NOT checked.

    Parameters:
    - df: Incoming DataFrame to validate
    - table_name: Table to compare against, read via the module-level `spark`
      session (`spark.table(table_name)`)
    - strict: If True, the DataFrame must have exactly the table's set of column
      names. If False, the DataFrame must contain every table column but may
      carry extra ones.
    - case_sensitive: If False, names are lower-cased before comparing, which
      mirrors Spark's default case-insensitive column resolution
      (`spark.sql.caseSensitive=false`). Defaults to True (exact names).

    Returns:
    - True if the names match under the chosen mode, False otherwise. False is
      also returned when the table schema cannot be read. On a mismatch, the
      missing and/or unexpected column names are printed so the caller's warning
      can be acted on.
    """
    try:
        expected_cols = [field.name for field in spark.table(table_name).schema.fields]
    except Exception as e:
        # Best-effort: an unreadable or missing table is reported as "no match"
        # instead of raising, so callers only need to branch on the bool.
        print(f"[❌] Failed to retrieve schema for table '{table_name}': {e}")
        return False

    def _normalize(names):
        # Set of names, optionally lower-cased for case-insensitive comparison.
        return {name if case_sensitive else name.lower() for name in names}

    expected = _normalize(expected_cols)
    actual = _normalize(df.columns)

    missing = expected - actual      # table columns the DataFrame does not have
    unexpected = actual - expected   # DataFrame columns the table does not have

    # strict: exact set equality; non-strict: table columns must be a subset.
    matches = not missing and (not strict or not unexpected)

    if not matches:
        if missing:
            print(f"[⚠️] Columns of '{table_name}' missing from DataFrame: {sorted(missing)}")
        if strict and unexpected:
            print(f"[⚠️] DataFrame columns not present in '{table_name}': {sorted(unexpected)}")

    return matches

In [0]:
# Target table is a notebook parameter (Databricks widget), defaulting to rro.sales_data.
dbutils.widgets.text("TableName", "rro.sales_data")
table_name = dbutils.widgets.get("TableName")

# Append the processed DataFrame only when its column names exactly match the
# target table's; otherwise skip the write and just emit a warning.
if not validate_schema(df_with_lag_lead_wri, table_name, strict=True):
    print("[⚠️] Schema mismatch detected. Investigate before writing!")
else:
    df_with_lag_lead_wri.write.saveAsTable(table_name, mode="append")