In [None]:
# Import Libraries

import os
import sys

from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession, SQLContext
# NOTE: this star import also brings in column functions named like Python
# builtins (sum, min, max, abs, round, ...), shadowing them in this notebook.
from pyspark.sql.functions import *

In [None]:
# Set Java Home & Vars

# PySpark's launch scripts locate the JDK through the JAVA_HOME environment
# variable (underscore, not a space). A raw string keeps the Windows
# backslashes literal instead of relying on "\P"/"\J" being invalid escapes,
# which newer Python versions warn about.
os.environ["JAVA_HOME"] = r"C:\Program Files\Java\jdk-18.0.2.1"

In [None]:
# Set Spark Config Details

# Local-mode Spark configuration for the ETL job. The driver classpath entry
# presumably points at the folder holding the JDBC driver jars used below
# (TODO confirm).
conf = (
    SparkConf()
    .setAppName("ETLPipeline")
    .setMaster("local")
    .set("spark.driver.extraClassPath", "G:/pyspark/*")
)

In [None]:
# Initiate Spark Session

# Create (or reuse) the session through the documented builder API, applying
# the SparkConf built above. `sc` is kept so code that wants the underlying
# SparkContext still has it.
etl = SparkSession.builder.config(conf=conf).getOrCreate()
sc = etl.sparkContext

In [None]:
# Set DB details

# Credentials are read from environment variables instead of being written
# into the notebook. The same user/password pair is used for both the source
# (SQL Server) and the target (PostgreSQL) connection.
pwd = os.environ["PGPASS"]
uid = os.environ["PGUID"]

# Both databases are reached on the same host.
server = "localhost"

# Source: SQL Server database, read through the Microsoft JDBC driver.
src_db = "AdventureWorksDW2019"
src_driver = "com.microsoft.sqlserver.jdbc.SQLServerDriver"

# Target: PostgreSQL database, written through the PostgreSQL JDBC driver.
target_db = "AdventureWorks"
target_driver = "org.postgresql.Driver"

In [None]:
# Source connection: SQL Server JDBC URL with semicolon-delimited properties.
# NOTE(review): both URLs embed the password in plain text; avoid printing or
# logging them.
src_url = (
    f"jdbc:sqlserver://{server}:1433;"
    f"databaseName={src_db};user={uid};password={pwd};"
)
# Target connection: PostgreSQL JDBC URL with credentials in the query string.
target_url = (
    f"jdbc:postgresql://{server}:5432/{target_db}"
    f"?user={uid}&password={pwd}"
)

In [None]:
# SQL Statement and Test

# Catalog query returning the names of the source tables to copy.
# AdventureWorksDW names its dimension tables with a "Dim" prefix
# (DimProduct, DimSalesTerritory, ...), so the names must be spelled "Dim",
# otherwise only FactInternetSales matches.
sql = """select t.name as table_name from sys.tables t
where t.name in ('DimProduct', 'DimProductSubcategory', 'DimProductCategory', 'DimSalesTerritory', 'FactInternetSales')"""

In [None]:
# Test the connection

# Run the table-name query once against the source database. The data source
# name must be "jdbc"; a JDBC load resolves the schema eagerly, so bad
# connection details or credentials fail here.
dfs = (
    etl.read
    .format("jdbc")
    .options(driver=src_driver, user=uid, password=pwd, url=src_url, query=sql)
    .load()
)
# Displays the df if connection is successful
dfs.show()

In [None]:
# Bring the query result back to the driver as a list of Row objects.
data_collect = dfs.collect()

# Print every table name the catalog query matched; extract() below runs the
# same query to decide which tables to read.
for row in data_collect:
    print(row["table_name"])

In [None]:
# Function to extract source system data
def extract():
    """Read each table listed by the catalog query from the source database.

    Runs ``sql`` against the source over JDBC to get the table names, then
    reads every ``dbo.<table>`` in full and shows its first 10 rows.

    Errors are printed rather than raised. Because one try block wraps the
    whole loop, a failure stops the remaining tables but not the notebook.
    """
    try:
        dfs = (
            etl.read
            .format("jdbc")
            .options(driver=src_driver, user=uid, password=pwd, url=src_url, query=sql)
            .load()
        )
        # Get table names
        data_collect = dfs.collect()
        if not data_collect:
            print("Data extract error: no source tables matched the catalog query")
            return
        # Loop through each row of the dataframe
        for row in data_collect:
            tbl_name = row["table_name"]
            df = (
                etl.read
                .format("jdbc")
                .option("driver", src_driver)
                .option("user", uid)
                # Pass the password explicitly, as the catalog query does,
                # rather than relying on the copy embedded in src_url.
                .option("password", pwd)
                .option("url", src_url)
                .option("dbtable", f"dbo.{tbl_name}")
                .load()
            )
            # show() prints the rows itself and returns None, so don't wrap it in print().
            df.show(10)
            # TODO: persist df to the target database (originally sketched as Load(df, tbl_name)).
            print("Data loaded successfully")
    except Exception as e:
        print("Data extract error: " + str(e))

In [None]:
# Function to persist data in target DB

def load(df=None, tbl=None):
    """Write one DataFrame to the target PostgreSQL database.

    The rows are written over JDBC to table ``"src_" + tbl`` with save mode
    "overwrite", which replaces any existing table of that name.

    Args:
        df: DataFrame to persist. It defaults to None only so the original
            zero-argument signature still exists. Calling without a
            DataFrame reports an error instead of writing.
        tbl: Source table name, used to build the target table name.

    Returns:
        The number of rows written, or None if the arguments were missing or
        the write failed. Errors are printed, not raised.
    """
    if df is None or tbl is None:
        print("Data load error: load() requires a DataFrame and a table name")
        return None
    try:
        # count() is a Spark action that runs a job, so compute it once.
        row_count = df.count()
        print(f"Importing rows: 0 to {row_count}... for table {tbl}")
        (
            df.write.mode("overwrite")
            .format("jdbc")
            .option("driver", target_driver)
            .option("user", uid)
            .option("password", pwd)
            .option("url", target_url)
            .option("dbtable", "src_" + tbl)
            .save()
        )
        print("Data imported successfully")
        return row_count
    except Exception as e:
        print("Data load error: " + str(e))
        return None

In [None]:
# Function call: run the extract step, which reads every table listed by `sql`
# from the source database and shows the first rows of each.
extract()