In [1]:
import os
from delta.tables import DeltaTable
from pyspark.sql import SparkSession
from dateutil.parser import parse
import pyspark.sql.types as T
import pyspark.sql.functions as F
import pandas as pd
import psycopg2
import requests

pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)

In [2]:
try:
    spark.stop()
except Exception as e: 
    print(e)

spark = (
    SparkSession.builder \
    .appName("SparkLocalStackS3Integration") \
    .config("spark.hadoop.fs.s3a.endpoint", os.environ['AWS_ENDPOINT_URL'])\
    .config("spark.hadoop.fs.s3a.access.key", os.environ["AWS_ACCESS_KEY_ID"])\
    .config("spark.hadoop.fs.s3a.secret.key",os.environ["AWS_SECRET_ACCESS_KEY"])\
    .config("spark.hadoop.fs.s3a.path.style.access", "true")\
    .config("spark.hadoop.fs.s3a.connection.ssl.enabled", "false")\
    .config("spark.hadoop.fs.s3a.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem")\
    .config("spark.hadoop.fs.s3a.aws.credentials.provider", "org.apache.hadoop.fs.s3a.SimpleAWSCredentialsProvider")\
    .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") \
    .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog")
    .getOrCreate()
)

name 'spark' is not defined


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


In [3]:
# spark.sql("describe history delta.`s3a://my-storage-bucket/append-locations/postgres.public.task_instance_1_1` limit 10").toPandas()
spark.sql(f"select count(distinct(dag_id, task_id , run_id , map_index)) from delta.`s3a://my-storage-bucket/append-locations/postgres.public.task_instance_1_1` limit 1").toPandas()



                                                                                

Unnamed: 0,"count(DISTINCT named_struct(dag_id, dag_id, task_id, task_id, run_id, run_id, map_index, map_index))"
0,55


In [4]:
spark.sql("select count(1) from delta.`s3a://my-storage-bucket/upsert-locations/postgres.public.task_instance_1_1` limit 10").toPandas()


                                                                                

Unnamed: 0,count(1)
0,55


In [5]:
import os
import psycopg2
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, TimestampType
from pyspark.sql.functions import col, max as spark_max, abs as spark_abs, lit
from datetime import datetime, timedelta

RDS_DB_NAME = "airflow"
RDS_USER = "airflow"
RDS_PASSWORD = "airflow"
RDS_HOST = "postgres"
RDS_PORT = 5432

def get_rds_connection():
    """Establishes a connection to the RDS database."""
    try:
        conn = psycopg2.connect(
            dbname=RDS_DB_NAME,
            user=RDS_USER,
            password=RDS_PASSWORD,
            host=RDS_HOST,
            port=RDS_PORT
        )
        conn.autocommit = True
        print(f"Successfully connected to RDS database: {RDS_DB_NAME}")
        return conn
    except Exception as e:
        print(f"Error connecting to RDS: {e}")
        return None

conn = get_rds_connection()

Successfully connected to RDS database: airflow


In [7]:
# TODO: fetch from config API
pipeline_config_id = 1
DATALAKE_PATH = requests.get(f"http://backend:8000/configs/{pipeline_config_id}/").json()["upsert_write_path"]
print(DATALAKE_PATH)
datalake_df = spark.read.format('delta').load(DATALAKE_PATH)

s3a://my-storage-bucket/upsert-locations/postgres.public.task_instance_1_1


## Sampling record-to-record matching
* For the provided columns it will perform,
    * Schema matching
    * Value matching

In [8]:
def sampling_record_matching(datalake_df, fraction, rds_table, columns_to_check, pk_cols):
    datalake_df = spark.read.format('delta').load(DATALAKE_PATH)
    datalake_sample_df = datalake_df.sample(withReplacement=False, fraction=fraction, seed=42)

    if datalake_sample_df.count() == 0:
        print(f"Zero sample records selected")

    columns_to_check = list(set(columns_to_check + pk_cols))

    pk_filter = " and ".join([f"{pk_name} = %({pk_name})s" for pk_name in pk_cols])
    query = f"SELECT * FROM {rds_table} where {pk_filter}"

    for row in datalake_sample_df.select(*columns_to_check).collect():
        row_dict = row.asDict()
        print("Datalake row", row_dict)
        pk_vals = {k: v for k,v in row_dict.items() if k in pk_cols}

        with conn.cursor() as cur:
            cur.execute(query, row_dict)
            row = cur.fetchone()

            if not row:
                print(f"Warning: No records found in rds table for {pk_vals}")
                continue
            
            column_names = [desc[0] for desc in cur.description]
            row_dict_rds = dict(zip(column_names, row))

        for c in columns_to_check:
            if row_dict_rds.get(c) != row_dict.get(c):
                print(f"Error: Found mismatch for column {c}. RDS: {row_dict_rds.get(c)}. Datalake: {row_dict.get(c)}")
                break
        else:
            print(f"Matched for {pk_vals}")

In [9]:
sampling_record_matching(datalake_df, 0.1, "task_instance", ["custom_operator_name", "run_id", "try_number"], ["dag_id","task_id","run_id","map_index"])

                                                                                

Datalake row {'custom_operator_name': None, 'run_id': 'manual__2025-06-15T06:50:31.485049+00:00', 'try_number': 1, 'dag_id': 'frequency_one', 'map_index': -1, 'task_id': 'docker_task_append_1_postgres.public.task_instance'}
Matched for {'run_id': 'manual__2025-06-15T06:50:31.485049+00:00', 'dag_id': 'frequency_one', 'map_index': -1, 'task_id': 'docker_task_append_1_postgres.public.task_instance'}
Datalake row {'custom_operator_name': None, 'run_id': 'scheduled__2025-06-15T06:49:37.820221+00:00', 'try_number': 1, 'dag_id': 'continuous_python_dag', 'map_index': -1, 'task_id': 'run_my_continuous_task'}
Matched for {'run_id': 'scheduled__2025-06-15T06:49:37.820221+00:00', 'dag_id': 'continuous_python_dag', 'map_index': -1, 'task_id': 'run_my_continuous_task'}
Datalake row {'custom_operator_name': None, 'run_id': 'scheduled__2025-06-15T06:49:46.194138+00:00', 'try_number': 1, 'dag_id': 'continuous_python_dag', 'map_index': -1, 'task_id': 'run_my_continuous_task'}
Matched for {'run_id': 'sch

                                                                                

In [10]:
sampling_record_matching(datalake_df, 0.1, "task_instance", ["custom_operator_name", "run_id", "try_number", "end_date"], ["dag_id","task_id","run_id","map_index"])

                                                                                

Datalake row {'custom_operator_name': None, 'run_id': 'manual__2025-06-15T06:50:31.485049+00:00', 'end_date': None, 'try_number': 1, 'dag_id': 'frequency_one', 'map_index': -1, 'task_id': 'docker_task_append_1_postgres.public.task_instance'}
Error: Found mismatch for column end_date. RDS: 2025-06-15 06:50:53.221776+00:00. Datalake: None
Datalake row {'custom_operator_name': None, 'run_id': 'scheduled__2025-06-15T06:49:37.820221+00:00', 'end_date': '2025-06-15T06:49:41.453703Z', 'try_number': 1, 'dag_id': 'continuous_python_dag', 'map_index': -1, 'task_id': 'run_my_continuous_task'}
Error: Found mismatch for column end_date. RDS: 2025-06-15 06:49:41.453703+00:00. Datalake: 2025-06-15T06:49:41.453703Z
Datalake row {'custom_operator_name': None, 'run_id': 'scheduled__2025-06-15T06:49:46.194138+00:00', 'end_date': '2025-06-15T06:49:48.275005Z', 'try_number': 1, 'dag_id': 'continuous_python_dag', 'map_index': -1, 'task_id': 'run_my_continuous_task'}
Error: Found mismatch for column end_date

## Heuristics count matching check
* This will try to check if a given row is there datalake table or not.
* It also considers any expected lag due to schedule

In [23]:
def heuristics_count_matching(datalake_df, rds_table, pk_cols, updated_at_field, frequency):
    datalake_df = spark.read.format('delta').load(DATALAKE_PATH)

    pk_cols_select = ", ".join(pk_cols)

    #### NOTE
    # This is not the most optimized query
    query = f"""
        SELECT {pk_cols_select} FROM {rds_table} where {updated_at_field} IS NOT NULL
        AND {updated_at_field} < now() - INTERVAL '{frequency} hour'
        ORDER BY random() LIMIT 5
    """

    with conn.cursor() as cur:
        cur.execute(query)
        rows = cur.fetchall()
        column_names = [desc[0] for desc in cur.description]
        row_dict_rds = [dict(zip(column_names, row)) for row in rows]

    for row_rds in row_dict_rds:
        tmp_datalake_df = datalake_df
        for pk_name, pk_val in row_rds.items():
            tmp_datalake_df = tmp_datalake_df.filter(F.col(pk_name) == pk_val)
        count = tmp_datalake_df.count()
        if count == 1:
            print(f"Single matched redord found for {row_rds}")
        elif count > 1:
            print(f"Warning: Multiple matched redords found for {row_rds}")
        else:
            print(f"Error: No matched record found for {row_rds}")

In [37]:
heuristics_count_matching(datalake_df, "task_instance", ["dag_id","task_id","run_id","map_index"], "end_date", 1)

## Table Lag checking

In [47]:
def table_lag_checing(datalake_df, rds_table, pk_cols, updated_at_field, frequency):
    datalake_df = spark.read.format('delta').load(DATALAKE_PATH)

    pk_cols_select = ", ".join(pk_cols)

    #### NOTE
    # This is not the most optimized query
    query = f"""
        SELECT {pk_cols_select}, {updated_at_field} FROM {rds_table}
        where {updated_at_field} IS NOT NULL
        AND {updated_at_field} < now() - INTERVAL '{frequency} hour'
        ORDER BY {updated_at_field} DESC LIMIT 1
    """

    with conn.cursor() as cur:
        cur.execute(query)
        row = cur.fetchone()
        if not row:
            print(f"Warning: No records found in RDS table")
            return
        
        column_names = [desc[0] for desc in cur.description]
        row_dict_rds = dict(zip(column_names, row))

    for pk_name, pk_val in row_dict_rds.items():
        datalake_df = datalake_df.filter(F.col(pk_name) == pk_val)

    dl_rows = datalake_df.orderBy(F.col(updated_at_field).desc()).collect()
    
    if len(dl_rows) == 0:
        print(f"Error: No records found for {row_dict_rds}")
        return
    elif len(dl_rows) > 1:
        print(f"Error: Multiple records found for {row_dict_rds} taking the latest one.")
    
    dl_row = dl_rows[0]

    dl_row_updated_at = parse(dl_row[updated_at_field]) if isinsance(dl_row[updated_at_field], str) else dl_row[updated_at_field]

    if (row_dict_rds[updated_at_field] > dl_row_updated_at):
        print(f"Error: Found Lag for {row_dict_rds}. RDS updated at {row_dict_rds[updated_at_field]}. Datalake update at {dl_row_updated_at}")
    else:
        print(f"No Lag found")

In [49]:
table_lag_checing(datalake_df, "task_instance", ["dag_id","task_id","run_id","map_index"], "end_date", 1)

                                                                                

Error: No records found for {'dag_id': 'continuous_python_dag', 'task_id': 'run_my_continuous_task', 'run_id': 'scheduled__2025-06-15T07:05:23.547339+00:00', 'map_index': -1, 'end_date': datetime.datetime(2025, 6, 15, 7, 5, 26, 279552, tzinfo=datetime.timezone.utc)}


In [50]:
spark.sql("""
select max(end_date) from delta.`s3a://my-storage-bucket/upsert-locations/postgres.public.task_instance_1_1`
""").show(10, False)



+---------------------------+
|max(end_date)              |
+---------------------------+
|2025-06-15T06:50:41.839462Z|
+---------------------------+



                                                                                