### Importing Libraries

In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import to_timestamp, col, lit, array, when
from pyspark import SparkConf, SparkContext
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
import json
import snowflake.connector
import os
from datetime import datetime
import hmac 
from pyspark.sql import functions as F
from cryptography.fernet import Fernet
import hashlib
import re




  warn_incompatible_dep(


### Loading Credentials, Creating Spark Context, and Creating Connections

#### Load configuration properties (placeholder keys, filled from Databricks secrets below)

In [None]:
# Configuration keys needed by the notebook. Each value starts out as its own
# key name (a placeholder) and is replaced by the real secret in the next cell.
config = {
    key: key
    for key in (
        'user', 'password', 'account', 'warehouse', 'database', 'schema', 'role',
        'db_host_src_pg', 'user_src_pg', 'password_src_pg', 'database_src_pg',
        'username_mail', 'password_mail',
        'enc_key',
    )
}
 

In [None]:
# Overwrite every placeholder in `config` with the secret of the same name
# from the "snow_cred" Databricks secret scope (one lookup per key, in order).
for secret_name in list(config):
    config[secret_name] = dbutils.secrets.get(scope="snow_cred", key=secret_name)

#### Set up SparkSession

In [None]:
# Spark configuration: Snowflake JDBC/connector packages, the PostgreSQL
# driver jar on the driver classpath, and 15g driver/executor memory.
# NOTE(review): SparkContext.getOrCreate returns an already-running context
# unchanged, so on Databricks (where one exists) these settings likely do not
# take effect — confirm.
conf = (
    SparkConf()
    .setAppName("PostgreSQL_ETL")
    .setMaster("local[*]")
    .set("spark.jars.packages",
         "net.snowflake:snowflake-jdbc:3.13.3,net.snowflake:spark-snowflake_2.12:2.9.0-spark_3.0")
    .set("spark.driver.extraClassPath", "postgresql-42.2.24.jar")
    .set("spark.driver.memory", "15g")
    .set("spark.executor.memory", "15g")
    # .set("spark.driver.maxResultSize", "6g")
)

sc = SparkContext.getOrCreate(conf=conf)
spark = SparkSession(sc)

#### Define connection properties

In [None]:
# --- PostgreSQL (source) JDBC properties --------------------------------------
fetch_size = 100000
_pg_cert_dir = "/dbfs/FileStore/shared_uploads/some_directory/certs"

props = dict(
    user=config['user_src_pg'],
    password=config['password_src_pg'],
    ssl="true",
    sslmode="require",  # alternative: verify-ca
    sslrootcert=f"{_pg_cert_dir}/root.crt",
    sslcert=f"{_pg_cert_dir}/postgresql.crt",
    sslkey=f"{_pg_cert_dir}/postgresql.unprotected.pk8",
    fetchsize=str(fetch_size),
)

# JDBC URL of the source database (port 5432).
url = f"jdbc:postgresql://{config['db_host_src_pg']}:5432/{config['database_src_pg']}"


# --- Snowflake (target) options for the Spark connector ----------------------
snow_options = dict(
    sfURL=f"{config['account']}.snowflakecomputing.com",
    sfUser=config['user'],
    sfPassword=config['password'],
    sfDatabase=config['database'],
    sfSchema=config['schema'],
    sfWarehouse=config['warehouse'],
    sfRole=config['role'],
    sfUseStagingTable="true",
    sfTemporaryStage="SOME_STAGE",
    sfPartitions="8",
    sfBatchSize="100000",      # adjust the batch size as needed
    sfMaxConcurrency="8",      # adjust the maximum concurrency as needed
    sfCopyOptions="FILE_FORMAT = 'CSV_FORMAT', COMPRESSION = 'AUTO', SKIP_HEADER = 1",
)


# Direct snowflake-connector session, used for the SQL this notebook runs itself
# (watermark MAX() lookups, MERGE, ALTER TABLE, DROP TABLE).
# NOTE(review): unlike snow_options (sfRole), no role is passed here, so the
# user's default role applies — confirm that is intended.
conn = snowflake.connector.connect(
    **{
        setting: config[setting]
        for setting in ("user", "password", "account", "warehouse", "database", "schema")
    }
)


# Notification lines collected during the run; e-mailed later via send_mail.
msg_to_send = []

### Dummy Variables

In [None]:
# Placeholder values that `anonymise` writes over sensitive columns.
dummy_json_str = dict(
    _id="6138763984fe5172962f8147",
    balance="$3,677.47",
    about="Lorem ipsum Aliqua ad elit elit veniam in mollit officia",
    registered="2018-05-04T02:28:05 -02:00",
    latitude=60.763774,
)
dummy_json = json.dumps(dummy_json_str)  # JSON text of the sample record above

# Spark array column containing the single string 'Lorem Ipsum'.
dummy_array = array(lit('Lorem Ipsum'))

dummy_mail, dummy_string, dummy_phone, dummy_ip, dummy_date = (
    'dummy@mail.com',
    'Lorem ipsum',
    '+000000000000',
    '127.0.0.1',
    '1660-01-01',
)


### Main Methods

#### Encryption (salted SHA-256 hashing of sensitive values)

In [None]:
# Check if text is an email and encrypt the data


def split_email(email):
    parts = email.split('@', 1)
    return parts[0], "@" + parts[1]

_EMAIL_PATTERN = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'


def is_valid_email(email):
    """Return True if `email` looks like an e-mail address (simple regex check)."""
    return re.match(_EMAIL_PATTERN, email) is not None

def hash_value(text: str) -> str:
    """Return a salted SHA-256 hex digest of `text`, salted with config['enc_key'].

    Empty or None input is hashed as the literal string 'null'. For values that
    look like e-mail addresses only the local part is hashed and the original
    '@domain' suffix is appended unchanged. This is a one-way keyed hash, not
    reversible encryption.
    """
    value = text if text else 'null'
    if not is_valid_email(value):
        return hashlib.sha256(value.encode() + config['enc_key'].encode()).hexdigest()
    local_part, at_domain = split_email(value)
    return hashlib.sha256(local_part.encode() + config['enc_key'].encode()).hexdigest() + at_domain

# Spark UDF that applies hash_value column-wise (used by `anonymise`).
# `udf` is not imported by name in this notebook, so it is reached through the
# `F` alias of pyspark.sql.functions to avoid a NameError.
encrypt_data_udf = F.udf(hash_value, StringType())



#### Anonymise

In [None]:
def anonymise(tab : str, pg_df: 'DataFrame') -> 'DataFrame':
    """Mask or hash the sensitive columns of `pg_df` according to its source table.

    Every known table has an ordered list of (column, strategy) rules, applied
    with ``withColumn``:
      * hashed                    -> ``encrypt_data_udf`` (salted hash, see hash_value)
      * hashed_for_address_owners -> hashed only on rows whose ``addressable_type``
                                     is 'User', 'BeneficialOwner' or 'CardShipping'
      * as_dummy_*                -> replaced by the matching ``dummy_*`` constant
    Tables without rules are returned unchanged (the very same object).

    Args:
        tab: source table name.
        pg_df: Spark DataFrame read from that table.

    Returns:
        The (possibly) transformed DataFrame.
    """
    # Each strategy receives the column name and returns the replacement Column.
    def hashed(column_name):
        return encrypt_data_udf(col(column_name))

    def hashed_for_address_owners(column_name):
        owner_is_flagged = col("addressable_type").isin(['User', 'BeneficialOwner', 'CardShipping'])
        return when(owner_is_flagged, hashed(column_name)).otherwise(col(column_name))

    def as_dummy_array(_column_name):
        return lit(dummy_array)

    def as_dummy_json(_column_name):
        return lit(dummy_json)

    def as_dummy_phone(_column_name):
        return lit(dummy_phone)

    def as_dummy_date(_column_name):
        return lit(dummy_date)

    rules = {
        'some_table__1': [("answers", as_dummy_array)],
        'some_table__2': [("json", as_dummy_array),
                          ("whitelisted_ibans", as_dummy_array),
                          ("balance", as_dummy_array)],
        'some_table__3': [("sender", hashed), ("recipients", hashed),
                          ("html", hashed), ("subject", hashed)],
        'some_table__4': [("creation_ip", hashed), ("validation_ip", hashed),
                          ("device_id", hashed), ("phone", as_dummy_phone)],
        'some_table__5': [("name", hashed), ("email", hashed)],
        'some_table__6': [("license_plate", hashed)],
        'some_table__7': [("name", hashed)],
        'some_table__8': [("json", as_dummy_json)],
        'some_table__9': [("body", hashed), ("service", as_dummy_phone)],
        'some_table__10': [("stripe_source", as_dummy_json)],
        'some_table__11': [("email", hashed), ("firstname", hashed), ("lastname", hashed),
                           ("current_sign_in_ip", hashed), ("last_sign_in_ip", hashed),
                           ("birthdate", as_dummy_date)],
        'some_table__12': [("holder", hashed), ("bic", as_dummy_json), ("iban", as_dummy_json),
                           ("unique_mandate_reference", hashed)],
        'some_table__13': [("street", hashed_for_address_owners),
                           ("city", hashed_for_address_owners),
                           ("postcode", hashed_for_address_owners),
                           ("phone", as_dummy_phone)],
        'some_table__14': [("firstname", hashed), ("lastname", hashed),
                           ("middle_names", hashed), ("birthdate", as_dummy_date)],
        'some_table__15': [("current_phone", as_dummy_phone), ("next_phone", as_dummy_phone)],
        'some_table__16': [("name", hashed)],
        'some_table__17': [("email", hashed), ("sendgrid_data", as_dummy_json)],
    }

    for column_name, strategy in rules.get(tab, ()):
        pg_df = pg_df.withColumn(column_name, strategy(column_name))
    return pg_df



#### Reading data from tables

In [None]:
# DEV override (disabled). To run against development objects, remove the
# surrounding triple quotes so the statements execute: they point the Spark
# connector at SOME_SCHEMA and reopen `conn` on the PUBLIC_CLONE schema.
# NOTE(review): the two schemas differ — unqualified SQL run through `conn`
# (e.g. the SELECT MAX(...) watermark queries) would read PUBLIC_CLONE while
# MERGE/ALTER statements use snow_options['sfSchema']; confirm this is intended.
"""
snow_options["sfSchema"] = "SOME_SCHEMA"


conn = snowflake.connector.connect(
    user = config['user'],
    password = config['password'],
    account= config['account'],
    warehouse= config['warehouse'],
    database= config['database'],
    schema= "PUBLIC_CLONE"#config['schema']
)
"""

##### Reading data from source table

In [None]:
# Read Data from PostgreSQL Database table

def read_src_pg_table(url: str, query: str, properties: dict) -> 'DataFrame':
    """Read a PostgreSQL table, or a parenthesised aliased sub-query, via JDBC.

    Args:
        url: JDBC URL of the source database.
        query: table name or "(SELECT ...) alias" passed as the JDBC `table`.
        properties: JDBC connection properties (user, password, SSL, fetchsize).

    Returns:
        The Spark DataFrame.

    Raises:
        Whatever Spark/JDBC raises. The error is printed first and then
        re-raised, so callers never receive None and fail later with an
        unrelated AttributeError.
    """
    try:
        return spark.read.jdbc(url=url, table=query, properties=properties)
    except Exception as e:
        print("Error occured -> ", e)
        raise

##### Reading data from target table

In [None]:
# Read Data from Snowflake Database table

def read_tgt_snw_table(query: str, snow_options: dict) -> 'DataFrame':
    """Run `query` on Snowflake through the Spark connector and return the result.

    Args:
        query: SQL passed as the connector's "query" option.
        snow_options: Spark-Snowflake connector options (URL, credentials, ...).

    Returns:
        The Spark DataFrame.

    Raises:
        Whatever the connector raises. The error is printed first and then
        re-raised, so callers never receive None.
    """
    try:
        return spark.read.format("snowflake").options(**snow_options).option("query", query).load()
    except Exception as e:
        print("Error occured -> ", e)
        raise

##### Get source database tables to migrate

In [None]:
# Get all database tables to migrate

def get_all_tabs_to_migrate(url: str, get_all_tables_pg: str, pg_options: dict) -> list:
    """Return the `table_name` values produced by a source-side discovery query.

    Args:
        url: JDBC URL of the source database.
        get_all_tables_pg: "(SELECT table_name FROM ...) alias" sub-query.
        pg_options: JDBC connection properties.
    """
    rows = (
        read_src_pg_table(url, get_all_tables_pg, pg_options)
        .select("table_name")
        .collect()
    )
    return [r["table_name"] for r in rows]
     

##### Get target database tables to migrate

In [None]:
# Get all target database tables

def get_all_tgt_tabs_to_migrate(snow_options: dict) -> list:
    """Return the lower-cased names of the base tables in the configured Snowflake schema.

    Helper/reporting tables listed in `excluded` are left out; the result is
    ordered by name.
    """
    excluded = (
        'job_accuracy', 'migration_monitoring', 'km_data', 'john_paul_sent', 'kpis',
        'all_pub_tables', 'mastercard_interchanges_calculated', 'special_tables',
        'job_execution_time', 'public_tables', 'risky_mcc',
    )
    excluded_sql = ", ".join(f"'{name}'" for name in excluded)
    query = (
        "(SELECT DISTINCT LOWER(TABLE_NAME) as TABLE_NAME FROM INFORMATION_SCHEMA.TABLES "
        f"WHERE TABLE_SCHEMA = '{snow_options['sfSchema']}' "
        f"AND TABLE_CATALOG = '{snow_options['sfDatabase']}' "
        "AND TABLE_TYPE = 'BASE TABLE' "
        f"AND LOWER(TABLE_NAME) NOT IN ({excluded_sql}) "
        "ORDER BY LOWER(TABLE_NAME) ASC) all_tgt_tabs"
    )
    rows = read_tgt_snw_table(query, snow_options).select("TABLE_NAME").collect()
    return [r["TABLE_NAME"] for r in rows]

#### Schema change detection and handling

##### Schema change detection

In [None]:
def check_schema_difference(table_to_check: str) -> bool:
    """Report whether a source table's schema has drifted from its Snowflake copy.

    One row is read from each side. Each schema becomes a set of
    (lower-cased column name, Spark type string, nullable) triples, and a few
    source-side type strings are translated to what the Snowflake side reports.
    The source-minus-target difference is printed.

    Returns True only when that difference is non-empty AND the two tables have
    a different number of columns; a description of the differing source
    columns is then appended to ``msg_to_send``. Type-only differences with an
    unchanged column count return False.
    """
    # Source (PostgreSQL) side.
    source_df = read_src_pg_table(url, f"(SELECT * FROM {table_to_check} LIMIT 1) {table_to_check}", props)
    source_columns = {(field.name.lower(), str(field.dataType), field.nullable) for field in source_df.schema}

    # Target (Snowflake) side.
    target_df = read_tgt_snw_table(f"(SELECT * FROM {table_to_check} LIMIT 1) {table_to_check}", snow_options)
    target_columns = {(field.name.lower(), str(field.dataType), field.nullable) for field in target_df.schema}

    # Source-side Spark type strings -> the type string seen on the Snowflake side.
    postgres_to_snowflake = {
        'LongType()': 'DecimalType(38,0)',
        'bigint': 'DecimalType(38,0)',
        'IntegerType()': 'DecimalType(38,0)',
        'ArrayType(StringType(), True)': 'StringType()',
        'ShortType()': 'DecimalType(38,0)',
    }
    translated_source = {
        (name.lower(), postgres_to_snowflake.get(dtype, dtype), nullable)
        for name, dtype, nullable in source_columns
    }

    schema_diff = translated_source - target_columns
    print(f"\n\nDifference {schema_diff}")

    if not schema_diff or len(source_df.columns) == len(target_df.columns):
        return False

    details = "".join(
        f"Column: {name}, Data Type: {dtype}, Nullable: {nullable}\n"
        for name, dtype, nullable in schema_diff
    )
    msg_to_send.append(f"Database Table {table_to_check}\n" + details)
    return True


##### Schema change handling

In [None]:
def handle_schema_change(data_df: 'DataFrame', table_name: str, versioning_schema: str) -> None:
    """Replace a Snowflake table whose source schema changed with a fresh copy.

    Steps:
      1. Write ``data_df`` to ``<table_name>__<YYYYmmddHHMMSS>`` in the live
         schema (``snow_options['sfSchema']``).
      2. Move the current live table to ``versioning_schema`` under that
         timestamped name, keeping it as a backup.
      3. Rename the staged table to ``table_name`` in the live schema.

    Snowflake DDL (ALTER TABLE ... RENAME) commits implicitly, so wrapping the
    renames in BEGIN/ROLLBACK cannot undo step 2 if step 3 fails. Instead, a
    failure in step 3 is compensated by moving the backup back into place.

    Errors are printed and reported through ``msg_to_send``; nothing is raised.
    """
    try:
        current_timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        temp_new_table = table_name + '__' + current_timestamp
        print(temp_new_table)

        # Step 1: stage the new version next to the live table.
        data_df.write.format("snowflake").options(**snow_options).option("dbtable", temp_new_table).mode("append").save()

        database = snow_options['sfDatabase']
        live_schema = snow_options['sfSchema']
        live_table = f"{database}.{live_schema}.{table_name}"
        staged_table = f"{database}.{live_schema}.{temp_new_table}"
        archived_table = f"{database}.{versioning_schema}.{temp_new_table}"

        cursor = conn.cursor()
        try:
            # Step 2: archive the current version.
            cursor.execute(f"ALTER TABLE {live_table} RENAME TO {archived_table}")
            try:
                # Step 3: promote the staged version.
                cursor.execute(f"ALTER TABLE {staged_table} RENAME TO {live_table}")
            except Exception:
                # Compensate: restore the archived version so the live table is never left missing.
                cursor.execute(f"ALTER TABLE {archived_table} RENAME TO {live_table}")
                raise
            msg_to_send.append(f"Schema change detected and new table created at {temp_new_table}")
        except Exception as e:
            print("Rolling Back...\nError occurred -> ", e)
            msg_to_send.append(f"Schema change detected for table {table_name}, but the new version could not be swapped in because of the following error\n{e}\n Go take a look")
        finally:
            cursor.close()
    except Exception as e:
        print("Error occured -> ", e)
        msg_to_send.append(f"Schema change handling failed for table {table_name} because of the following error\n{e}")


#### Saving Data to Database

##### UPSERT target table

In [None]:
"""
This method essentially updates (MERGE) the target table
"""

def upsert_target_table(tab: str, merge_key: str = None) -> bool:

    if merge_key is None:
        merge_key = 'id'


    cursor = conn.cursor()
    cursor.execute("BEGIN TRANSACTION")

    try:
        ## Get the column names of the target table
        target_columns_query = f"SHOW COLUMNS IN {snow_options['sfDatabase']}.{snow_options['sfSchema']}.{tab}" 
        cursor.execute(target_columns_query)
        column_names = [column[2] for column in cursor.fetchall()]

        # Generate the column mappings for the MERGE query, handling reserved keywords
        column_mappings = ", ".join([f"\"target\".\"{col}\" = \"source\".\"{col}\"" for col in column_names])
        #print(f"column_mappings {column_mappings}")

        # Perform the MERGE operation to insert new rows or update existing rows
        #print(f"Upserting.... snow_options['sfSchema'] {snow_options['sfSchema']}")
        merge_query = (
            "MERGE INTO \"" + snow_options['sfDatabase']+"\"."+"\"" + snow_options['sfSchema']+"\"."+"\""+str(tab).upper() + "\" AS \"target\" "
            "USING " + snow_options['sfDatabase']+"." + snow_options['sfSchema']+".TEMP_TABLE AS \"source\" "
            "ON \"target\".\"" + merge_key + "\" = \"source\".\"" + merge_key + "\" "
            "WHEN MATCHED THEN "
            "UPDATE SET " + column_mappings + " "
            "WHEN NOT MATCHED THEN "
            "INSERT (" + ", ".join(["\"" + col + "\"" for col in column_names]) + ") "
            "VALUES (" + ", ".join(["\"source\".\"" + col + "\"" for col in column_names]) + ")"
        )
        #print(f"\nQuery to be executed {merge_query}")
        cursor.execute(merge_query)
        #print(f"\nQuery executed")
        
        # Delete the previously created TEMP_TABLE to free up the database
        cursor.execute(f"DROP TABLE {snow_options['sfDatabase']}.{snow_options['sfSchema']}.TEMP_TABLE")

        print(f"\nDrop table Query executed")

        # Commit the changes and close the connection
        cursor.execute("COMMIT")  

        return True
    except Exception as e:
        # Rollback the transaction in case of any error    
        msg_to_send.append(f"Could not Merge UPDATE table {tab} because of the following error\n{e}\n, and consequently the table was not migrated.\n Go take a look")
        cursor.execute("ROLLBACK")
        print("Rolling Back...\nError occurred -> ", e)  
        return False


##### Saving data to target table

In [None]:
# Save Data to Snowflake Database table

def load_tgt_snw_table(data_df: 'DataFrame', database_table: str, snow_options: dict, write_mode: str = None, updating: str = None) -> 'DataFrame':
    """Write `data_df` to Snowflake, either directly or as an upsert.

    Args:
        data_df: rows to write.
        database_table: target table name.
        snow_options: Spark-Snowflake connector options.
        write_mode: Spark save mode; defaults to "append".
        updating: when not None, the rows are written to TEMP_TABLE (mode
            "overwrite") and then MERGEd into `database_table` by
            upsert_target_table, keyed on MASTERCARD_TRANSACTION_ID for the two
            mastercard tables and on the default key otherwise.

    Returns:
        None (the annotation is kept for backward compatibility). Failures are
        printed and recorded in ``msg_to_send`` rather than raised.
    """
    target_table = database_table
    try:
        if write_mode is None:
            write_mode = "append"

        if updating is not None:
            # Stage the batch in TEMP_TABLE (replacing the previous one) for the MERGE.
            database_table = "TEMP_TABLE"
            write_mode = "overwrite"

        data_df.write.format("snowflake").options(**snow_options).option("dbtable", database_table).mode(write_mode).save()

        if updating is not None:
            if target_table in ('mastercard_interchanges', 'mastercard_transaction_corrections'):
                upsert_target_table(target_table, 'MASTERCARD_TRANSACTION_ID')
            else:
                upsert_target_table(target_table)
    except Exception as e:
        print("Error occured -> ", e)
        # Record the failure so it reaches the summary e-mail instead of only stdout.
        msg_to_send.append(f"Could not load table {target_table} into Snowflake because of the following error\n{e}")

### Sending Email

In [None]:
def send_mail(subject: str, message: str, any_link: str=None) -> None:
    """Send a notification e-mail through the smtp2go relay (port 2525, STARTTLS).

    Args:
        subject: e-mail subject.
        message: plain-text body (converted with str()).
        any_link: URL shown in an HTML part; 'No link' when None.

    The SMTP connection is opened in a ``with`` block, so QUIT is sent and the
    socket closed even if login or sending fails.
    """
    import smtplib
    from email.mime.text import MIMEText
    from email.mime.multipart import MIMEMultipart

    if any_link is None:
        any_link = 'No link'

    username = config['username_mail']
    password = config['password_mail']

    sender = 'my_email@mail.com'
    recipient = 'my_email@mail.com'

    msg = MIMEMultipart('mixed')
    msg['Subject'] = subject
    msg['From'] = sender
    msg['To'] = recipient
    msg.attach(MIMEText(str(message)))
    msg.attach(MIMEText(f"<br><b>File Located at:</b> <a href='{any_link}'>{any_link}</a>", 'html'))

    with smtplib.SMTP('mail.smtp2go.com', 2525) as mail_server:  # 8025, 587 and 25 can also be used.
        mail_server.ehlo()
        mail_server.starttls()
        mail_server.ehlo()
        mail_server.login(username, password)
        mail_server.sendmail(sender, recipient, msg.as_string())

### Testing DB Connections and getting source and target tables

#### Get list of tables to migrate

In [None]:
# All base tables in the source (PostgreSQL) public schema.
get_all_tables_pg = "(SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_type = 'BASE TABLE' order by table_name asc) all_tabs"
all_pg_tabs = get_all_tabs_to_migrate(url, get_all_tables_pg, props)

# All base tables in the target (Snowflake) schema, minus the excluded helper tables.
all_snw_tabs = get_all_tgt_tabs_to_migrate(snow_options)

# Label each count with the side it actually comes from (all_pg_tabs = source).
print(f"Source tables -> {len(all_pg_tabs)} \nTarget tables -> {len(all_snw_tabs)}")


Target tables -> 188 
Source tables -> 204


In [None]:
# Source-table discovery queries; each yields a single `table_name` column.
# Tables having all of created_at, updated_at and id:
cr_up_query = "(SELECT table_name FROM information_schema.columns WHERE table_schema = 'public' AND column_name IN ('created_at', 'updated_at', 'id') GROUP BY table_name HAVING COUNT(DISTINCT column_name) = 3) created_at__updated_at"
# Every base table in the public schema:
get_all_tables_pg = "(SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_type = 'BASE TABLE' order by table_name asc) all_tabs"
# Tables having at least created_at and id:
only_created_query = "(SELECT table_name FROM information_schema.columns WHERE table_schema = 'public' AND column_name IN ('created_at', 'id') GROUP BY table_name HAVING COUNT(DISTINCT column_name) = 2) only_created_at"
# Tables having an id column:
only_updated_query = "(SELECT table_name FROM information_schema.columns WHERE table_schema = 'public' AND column_name IN ('id') GROUP BY table_name HAVING COUNT(DISTINCT column_name) = 1) only_updated_at"

all_source_tables = get_all_tabs_to_migrate(url, get_all_tables_pg, props)

created_at__updated_at = get_all_tabs_to_migrate(url, cr_up_query, props)
with_updated_at = set(created_at__updated_at)

# created_at + id, with or without updated_at ...
all_created_at = get_all_tabs_to_migrate(url, only_created_query, props)
# ... and the subset without updated_at (no updated_at extension in the incremental filter).
just_created_at = list(set(all_created_at) - with_updated_at)

# id-bearing tables that do not have the full created_at/updated_at/id set.
only_updated_at = list(set(get_all_tabs_to_migrate(url, only_updated_query, props)) - with_updated_at)

len(all_created_at), len(just_created_at), len(created_at__updated_at)

(155, 7, 148)

### Core Migration method

In [None]:
"""
Check if the tables in the add_to_created_at list exists in the all_pg_tabs(ALL SOURCE TABLES) list.
And if yes, append the tables to the all_created_at list and proceed with migrating the tables in the updates all_created_at list.
"""

add_to_created_at = ['loyaltek_text_messages', 'account_movements', 'payment_errors', 'ahoy_messages']


[all_created_at.append(i) for i in add_to_created_at if i in all_pg_tabs]

date_col = {
            'loyaltek_text_messages': 'at',
            'account_movements': 'transaction_date',
            'payment_errors': 'at',
            'ahoy_messages': 'sent_at'
            }
all_created_at = [i for i in all_created_at if i in all_source_tables]
all_created_at = sorted(all_created_at)

cnt = 0
for tab in all_created_at:
    try:
        print("\nmigrating table ", tab, "...\n")
        max_col = 'created_at'
        cnt += 1
        if tab in add_to_created_at:
            max_col = date_col[tab]

        if tab not in all_snw_tabs:        
            pg_query = f"(SELECT * FROM {tab} ORDER BY {max_col} ASC ) AS {tab}"
            pg_df = read_src_pg_table(url, pg_query, props)
            if len(pg_df.head(1)) == 0:
                continue
            pg_df.cache()  # Cache the DataFrame in memory
            load_tgt_snw_table(pg_df, tab, snow_options, "overwrite")
            pg_df.unpersist()
            msg_to_send.append(f"New table {tab} migrated")
            print("\nmigrated table ", tab, "\n")
            continue

        if check_schema_difference(tab):
            print('handle')
            pg_query = f"(SELECT * FROM {tab} ) AS {tab}"
            pg_df = read_src_pg_table(url, pg_query, props)
            if len(pg_df.head(1)) == 0:
                continue
            pg_df.cache()  # Cache the DataFrame in memory
            print(pg_query)
            handle_schema_change(pg_df, tab, 'TEMP_HOLD')
            pg_df.unpersist()
            print("\nmigrated table ", tab, "\n")
        else:
            cursor = conn.cursor()
            cursor.execute(f"SELECT MAX({max_col}) FROM {tab}")
            max_val = cursor.fetchone()[0]
            print('max_date ', max_val)

            if max_val:
                # Extension query: updated_at
                extension = f" OR updated_at >= '{max_val}' "
                if tab in just_created_at or tab in add_to_created_at:
                    extension = ""

                pg_query = f"(SELECT * FROM {tab} WHERE {max_col} >= '{max_val}' {extension} ORDER BY {max_col} ASC ) AS {tab}"
                print('next saved ')
                pg_df = read_src_pg_table(url, pg_query, props)
                pg_df.cache()  # Cache the DataFrame in memory

                # Anonymise Data
                pg_df = anonymise(tab, pg_df)
            
                #display(pg_df)

                load_tgt_snw_table(pg_df, tab, snow_options, "overwrite", "update")
                pg_df.unpersist()
                print("\nmigrated table ", tab, "\n")
            else:
                if max_val is not None:
                    msg_to_send.append(f"Table {tab} could not be migrated because it has no specific DATE to filter source rows")
    except Exception as e:
        print(f"Core Migration Error -> \n{e}")
        msg_to_send.append(f"Error while attempting to migrate the table {tab}")
        continue
print(f"Nummber of tables to migrate {len(all_created_at)}, number of tables actualy migrated {cnt} ")

if cnt != len(all_created_at):
    msg_to_send.append(f"Nummber of tables to migrate {len(all_created_at)}, number of tables actualy migrated {cnt} ")         



In [None]:
# Notebook display: show the final list of date-tracked tables processed above.
all_created_at

In [None]:
# E-mail the collected log lines, one per line, when there are any.
if msg_to_send:
    any_link = 'https://databricks_notebook_link'
    sbj = 'PostgreSQL Migration - Logs from Databrick! '
    final_msg = "".join(f"{line}\n" for line in msg_to_send)
    send_mail(sbj, final_msg, any_link)

In [None]:
# Start the next phase with empty notification buffers.
final_msg, msg_to_send = [], []

#### Special tables

In [None]:
# Tables whose id is not an integer (compared as text, without the ::int cast).
none_int_id = [
    'attendees',
    'email_validations',
    'vat_rates',
    'expense_attendee_guesses',
    'sso_login_domains',
    'active_storage_variant_records',
]

# Tables reloaded in full (this list is redefined in the Always Overwrite cell).
always_overwrite = [
    'card_patterns_expense_categories',
    'cards_tags',
    'billing_invoices_stripe_payments',
    'sso_email_domains',
    'card_mccs',
    'expense_analytical_axes',
    'expense_analytic_codes',
    'supplier_mastercard_merchants',
    'archival_transfers_receipts',
    'schema_migrations',
    'ar_internal_metadata',
]

# Tables migrated incrementally by id (or mastercard_transaction_id).
using_id = [
    'consents',
    'expense_analytic_code_guesses',
    'expense_business_code_guesses',
    'expense_vats',
    'invoice_lines',
    'invoices',
    'mastercard_transaction_corrections',
    'mastercard_interchanges',
]

##### Special tables -> Using IDs

In [None]:
"""
Check if the tables in the none_int_id list exists in the all_pg_tabs(ALL SOURCE TABLES) list.
And if yes, append the tables to the using_id list and proceed with migrating the tables in the using_id list.
"""



[using_id.append(i) for i in none_int_id if i in all_pg_tabs]

using_id = [i for i in using_id if i in all_source_tables]
using_id = sorted(using_id)

for tab in using_id:
    try:
        print("\nmigrating table ", tab, "...\n")
        max_col = 'id'
        int_id = '::int'
        if tab in none_int_id:
            int_id = ''
            print(f"None int it ")
        if tab in ['mastercard_interchanges', 'mastercard_transaction_corrections']:
            max_col = 'mastercard_transaction_id'
            
        if tab not in all_snw_tabs:        
            pg_query = f"(SELECT * FROM {tab} ORDER BY {max_col}{int_id} ASC ) AS {tab}"
            pg_df = read_src_pg_table(url, pg_query, props)
            if len(pg_df.head(1)) == 0:
                continue
            pg_df.cache()  # Cache the DataFrame in memory
            load_tgt_snw_table(pg_df, tab, snow_options, "overwrite")
            pg_df.unpersist()
            msg_to_send.append(f"New table {tab} migrated")
            print("\nmigrated table ", tab, "\n")
            continue

        if check_schema_difference(tab):
            print('handle')
            pg_query = f"(SELECT * FROM {tab} ORDER BY {max_col}{int_id} ASC ) AS {tab}"
            pg_df = read_src_pg_table(url, pg_query, props)
            if len(pg_df.head(1)) == 0:
                continue
            pg_df.cache()  # Cache the DataFrame in memory
            #print(pg_query)
            handle_schema_change(pg_df, tab, 'TEMP_HOLD')
            pg_df.unpersist()
            print("\nmigrated table ", tab, "\n")
        else:
            cursor = conn.cursor()
            cursor.execute(f"SELECT MAX({max_col}{int_id} ) FROM {tab}")
            max_val = cursor.fetchone()[0]
            print('max_val ', max_val)

            if max_val:            
                if tab in none_int_id:
                    pg_query = f"(SELECT * FROM {tab} WHERE {max_col}{int_id}  >= '{max_val}' ORDER BY {max_col}{int_id}  ASC ) AS {tab}"
                else:
                    pg_query = f"(SELECT * FROM {tab} WHERE {max_col}{int_id}  >= {max_val} ORDER BY {max_col}{int_id}  ASC ) AS {tab}"

                print('max ', pg_query)
                pg_df = read_src_pg_table(url, pg_query, props)

                # Anonymise Data
                pg_df = anonymise(tab, pg_df)
                pg_df.cache()  # Cache the DataFrame in memory
                #display(pg_df)

                load_tgt_snw_table(pg_df, tab, snow_options, "overwrite", "update")
                pg_df.unpersist()
                print("\nmigrated table ", tab, "\n")
            else:
                if max_val is not None:
                    msg_to_send.append(f"using_id :: Table {tab} could not be migrated because it has no specific DATE to filter source rows")
    except Exception as e:
        print(f"Using id Error -> \n{e}")



#print(msg_to_send)
if len(msg_to_send) > 0:
    any_link = 'https://databricks_notebook_link'
    sbj = 'Special Tables PostgreSQL Migration - Logs from Databrick! '
    final_msg = ''.join([f"{i}\n" for i in msg_to_send])
    send_mail(sbj, final_msg, any_link)            



##### Special tables -> Always Overwrite

In [None]:
always_overwrite = ['card_patterns_expense_categories', 'billing_invoice_lines', 'disbursements', 'cards_tags', 'billing_invoices_stripe_payments', 'sso_email_domains', 'card_mccs', 'expense_analytical_axes', 'expense_analytic_codes', 'supplier_mastercard_merchants', 'archival_transfers_receipts', 'schema_migrations', 'ar_internal_metadata']

# Fully reload every always_overwrite table that exists in the source (all_pg_tabs).
# Each table is handled independently, so one failure no longer aborts the rest.
# NOTE(review): messages appended here are only e-mailed if a later cell sends msg_to_send.
for tab in [i for i in always_overwrite if i in all_pg_tabs]:
    print(tab)
    try:
        pg_df = read_src_pg_table(url, f"(SELECT * FROM {tab} ) AS {tab}", props)
        # Route through anonymise like every other write path (a no-op for tables without rules).
        pg_df = anonymise(tab, pg_df).cache()
        try:
            load_tgt_snw_table(pg_df, tab, snow_options, "overwrite")
        finally:
            pg_df.unpersist()
    except Exception as e:
        print(f"Always overwrite Error -> \n{e}")
        msg_to_send.append(f"Error while attempting to migrate the table {tab}")

### End Spark

In [None]:
# Uncomment the code below after DEV
#spark.stop()

# Close the connection
#cursor.close()
#conn.close()

## DELETE ALL the cells after this one when you move to PROD

In [None]:
#pg_df.printSchema()

In [None]:
"""

max_col = 'started_at'
tab = 'SOME_TABLE'

cursor = conn.cursor()
cursor.execute(f"SELECT ID FROM {tab}")
original_list = cursor.fetchall()

new_list = [item[0] for item in original_list]

result_tuple = tuple(new_list)
result_tuple

"""

In [None]:
"""

pg_query = f"(SELECT * FROM {tab} WHERE ID NOT IN {result_tuple}) AS {tab}"    # pg_query = f"(SELECT COUNT({max_col})::int AS num_rows_prod FROM {tab} ) AS {tab}"    

pg_df = read_src_pg_table(url, pg_query, props)
display(pg_df)

"""