# Working with PostgreSQL tables using PySpark

In [1]:
from pathlib import Path
import configparser
import pyspark
import re
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

In [2]:
# Location of the INI file whose [postgresql] section is read by the cells below.
# NOTE(review): absolute, user-specific path; adjust it when running on another machine.
config_file = '/home/pybokeh/config.ini'

In [3]:
config = configparser.ConfigParser()
# ConfigParser.read() does not raise when the file is missing: it silently
# skips it and returns the list of files it actually parsed. An empty list
# therefore means nothing was loaded, so stop here with a clear error instead
# of failing later with a KeyError on config['postgresql'].
if not config.read(config_file):
    raise FileNotFoundError(f"config.ini file not found: {config_file}")

In [4]:
# Path to the PostgreSQL JDBC driver, read from the [postgresql] section of the
# config file; it is handed to Spark's `spark.jars` setting when the session is built.
postgres_jdbc_driver = Path(config['postgresql']['jdbc_driver_path'])

In [5]:
# Pull the PostgreSQL connection settings (DSN-less connection) out of the
# [postgresql] section in one pass; the key order matches the target names.
pg_host, pg_port, pg_db, pg_user, pg_pwd = (
    config["postgresql"][setting]
    for setting in ("host", "port", "database", "username", "password")
)

In [6]:
# Fully qualified class name of the PostgreSQL JDBC driver
driver = 'org.postgresql.Driver'
# JDBC connection URL assembled from the host, port and database name read from the config file
url = f'jdbc:postgresql://{pg_host}:{pg_port}/{pg_db}'

In [7]:
# Create (or reuse) a local Spark session that uses all available cores,
# has the PostgreSQL JDBC driver on its jar list, and renders DataFrames
# eagerly when they are the last expression of a cell.
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName("Postgres")
    .config("spark.sql.execution.eagerEval.enabled", "true")
    .config("spark.jars", postgres_jdbc_driver)
    .getOrCreate()
)

23/05/21 11:26:22 WARN Utils: Your hostname, pybokeh-Lemur resolves to a loopback address: 127.0.1.1; using 192.168.1.147 instead (on interface wlp2s0)
23/05/21 11:26:22 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
23/05/21 11:26:23 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


#### Issue a metadata query to get information about the database, such as which tables are available

In [8]:
# Load information_schema.tables over JDBC and print its schema.
(
    spark.read
    .format("jdbc")
    .options(
        driver=driver,
        url=url,
        user=pg_user,
        password=pg_pwd,
        dbtable="information_schema.tables",
    )
    .load()
    .printSchema()
)

root
 |-- table_catalog: string (nullable = true)
 |-- table_schema: string (nullable = true)
 |-- table_name: string (nullable = true)
 |-- table_type: string (nullable = true)
 |-- self_referencing_column_name: string (nullable = true)
 |-- reference_generation: string (nullable = true)
 |-- user_defined_type_catalog: string (nullable = true)
 |-- user_defined_type_schema: string (nullable = true)
 |-- user_defined_type_name: string (nullable = true)
 |-- is_insertable_into: string (nullable = true)
 |-- is_typed: string (nullable = true)
 |-- commit_action: string (nullable = true)



#### Let's obtain a list of schemas and table names

In [9]:
# Show the schema and table name of every table in the `public` schema.
(
    spark.read
    .format("jdbc")
    .options(
        driver=driver,
        url=url,
        user=pg_user,
        password=pg_pwd,
        dbtable="information_schema.tables",
    )
    .load()
    .where(col('table_schema') == 'public')
    .select('table_schema', 'table_name')
    .show(truncate=False)
)

+------------+---------------------------------+
|table_schema|table_name                       |
+------------+---------------------------------+
|public      |us_counties_pop_est_2010_2019_raw|
|public      |regions                          |
|public      |divisions                        |
+------------+---------------------------------+



#### Let's fetch the 3 tables from Postgres and save them as PySpark dataframes

In [12]:
# One full-table SELECT per source table in the `public` schema, in the order
# the later cells expect: county estimates, regions, divisions.
select_queries = [
    f"SELECT * from public.{table_name}"
    for table_name in (
        "us_counties_pop_est_2010_2019_raw",
        "regions",
        "divisions",
    )
]

In [13]:
# Run each SELECT through the JDBC reader and keep the resulting DataFrames
# in the same order as `select_queries`.
dataframes = [
    (
        spark.read
        .format("jdbc")
        .options(
            driver=driver,
            url=url,
            user=pg_user,
            password=pg_pwd,
            query=sql,
        )
        .load()
    )
    for sql in select_queries
]

In [14]:
# Give each loaded DataFrame a descriptive name; positions follow `select_queries`.
us_counties_pop_est_2010_2019_raw, regions, divisions = dataframes

In [15]:
# Print the schema of the raw county population estimates DataFrame loaded from PostgreSQL.
us_counties_pop_est_2010_2019_raw.printSchema()

root
 |-- sumlev: string (nullable = true)
 |-- region: short (nullable = true)
 |-- division: short (nullable = true)
 |-- state: string (nullable = true)
 |-- county: string (nullable = true)
 |-- stname: string (nullable = true)
 |-- ctyname: string (nullable = true)
 |-- census2010pop: integer (nullable = true)
 |-- estimatesbase2010: integer (nullable = true)
 |-- popestimate2010: integer (nullable = true)
 |-- popestimate2011: integer (nullable = true)
 |-- popestimate2012: integer (nullable = true)
 |-- popestimate2013: integer (nullable = true)
 |-- popestimate2014: integer (nullable = true)
 |-- popestimate2015: integer (nullable = true)
 |-- popestimate2016: integer (nullable = true)
 |-- popestimate2017: integer (nullable = true)
 |-- popestimate2018: integer (nullable = true)
 |-- popestimate2019: integer (nullable = true)
 |-- npopchg_2010: integer (nullable = true)
 |-- npopchg_2011: integer (nullable = true)
 |-- npopchg_2012: integer (nullable = true)
 |-- npopchg_2013:

In [16]:
# Display the region lookup table (region code -> region_name).
regions.show()

+------+-----------+
|region|region_name|
+------+-----------+
|     1|  Northeast|
|     2|    Midwest|
|     3|      South|
|     4|       West|
+------+-----------+



In [17]:
# Display the division lookup table (division code -> division_name).
# NOTE(review): the stored output lists division 8 as 'Montain', which looks like a
# typo for 'Mountain' in the source table; confirm and correct it upstream.
divisions.show()

+--------+------------------+
|division|     division_name|
+--------+------------------+
|       1|       New England|
|       2|   Middle Atlantic|
|       3|East North Central|
|       4|West North Central|
|       5|    South Atlantic|
|       6|East South Central|
|       7|West South Central|
|       8|           Montain|
|       9|           Pacific|
+--------+------------------+



#### Let's merge the 3 tables

In [18]:
# Build the merged frame: keep the identifying columns (renamed to friendlier
# names), the 2010 census/base figures and the yearly population, birth and
# death columns, then left-join the region and division lookup tables to pick
# up their human-readable names.
temp_df = (
    us_counties_pop_est_2010_2019_raw.select(
        col('state').alias('state_fips'),
        col('county').alias('county_fips'),
        col('stname').alias('state_name'),
        col('ctyname').alias('county_name'),
        col('region'),
        col('division'),
        col('census2010pop'),
        col('estimatesbase2010'),
        # Raw strings keep `\d` as a regex token; in a normal string literal it is
        # an invalid escape sequence that newer Python versions warn about.
        us_counties_pop_est_2010_2019_raw.colRegex(r"`^popestimate20\d{2}$`"),
        us_counties_pop_est_2010_2019_raw.colRegex(r"`^births20\d{2}$`"),
        us_counties_pop_est_2010_2019_raw.colRegex(r"`^deaths20\d{2}$`"),
    ).join(
        regions,
        us_counties_pop_est_2010_2019_raw.region == regions.region,
        'left'
    ).join(
        divisions,
        us_counties_pop_est_2010_2019_raw.division == divisions.division,
        'left'
    )
    # Drop the join keys contributed by the lookup tables so that `region` and
    # `division` appear only once in the result.
    .drop(regions.region)
    .drop(divisions.division)
)

In [19]:
# List the merged frame's columns to check the result of the joins.
temp_df.columns

['state_fips',
 'county_fips',
 'state_name',
 'county_name',
 'region',
 'division',
 'census2010pop',
 'estimatesbase2010',
 'popestimate2010',
 'popestimate2011',
 'popestimate2012',
 'popestimate2013',
 'popestimate2014',
 'popestimate2015',
 'popestimate2016',
 'popestimate2017',
 'popestimate2018',
 'popestimate2019',
 'births2010',
 'births2011',
 'births2012',
 'births2013',
 'births2014',
 'births2015',
 'births2016',
 'births2017',
 'births2018',
 'births2019',
 'deaths2010',
 'deaths2011',
 'deaths2012',
 'deaths2013',
 'deaths2014',
 'deaths2015',
 'deaths2016',
 'deaths2017',
 'deaths2018',
 'deaths2019',
 'region_name',
 'division_name']

#### One problem is that the 'region_name' and 'division_name' columns are at the end, so we need to rearrange the columns

In [20]:
# Desired output column order: identifiers first, then the region/division
# names, then the 2010 census/base figures and the yearly series.
selected_columns = [
    col('state_fips'),
    col('county_fips'),
    col('state_name'),
    col('county_name'),
    col('region_name'),
    col('division_name'),
    col('census2010pop'),
    col('estimatesbase2010'),
    # Raw strings keep `\d` as a regex token; in a normal string literal it is
    # an invalid escape sequence that newer Python versions warn about.
    us_counties_pop_est_2010_2019_raw.colRegex(r"`^popestimate20\d{2}$`"),
    us_counties_pop_est_2010_2019_raw.colRegex(r"`^births20\d{2}$`"),
    us_counties_pop_est_2010_2019_raw.colRegex(r"`^deaths20\d{2}$`"),
]

In [21]:
# Apply the new column order and list the resulting column names to verify it.
temp_df.select(selected_columns).columns

['state_fips',
 'county_fips',
 'state_name',
 'county_name',
 'region_name',
 'division_name',
 'census2010pop',
 'estimatesbase2010',
 'popestimate2010',
 'popestimate2011',
 'popestimate2012',
 'popestimate2013',
 'popestimate2014',
 'popestimate2015',
 'popestimate2016',
 'popestimate2017',
 'popestimate2018',
 'popestimate2019',
 'births2010',
 'births2011',
 'births2012',
 'births2013',
 'births2014',
 'births2015',
 'births2016',
 'births2017',
 'births2018',
 'births2019',
 'deaths2010',
 'deaths2011',
 'deaths2012',
 'deaths2013',
 'deaths2014',
 'deaths2015',
 'deaths2016',
 'deaths2017',
 'deaths2018',
 'deaths2019']

#### The resulting dataframe above has the columns in the order we want, so now, let's write this dataframe as a PostgreSQL table

In [22]:
# Reorder the columns, sort by state and county FIPS codes, and write the
# result to PostgreSQL over JDBC, replacing the target table if it already exists.
(
    temp_df
    .select(selected_columns)
    .orderBy('state_fips', 'county_fips')
    .write
    .format("jdbc")
    # Connection details and destination table for the JDBC writer
    .options(
        url=url,
        driver=driver,
        dbtable="public.us_counties_pop_est_2010_2019_basic",
        user=pg_user,
        password=pg_pwd,
    )
    .mode("overwrite")
    .save()
)

23/05/21 11:26:34 WARN package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.


#### Let's check if our new table has been created

In [23]:
# Re-list the tables in the `public` schema to confirm the new table exists.
(
    spark.read
    .format("jdbc")
    .options(
        driver=driver,
        url=url,
        user=pg_user,
        password=pg_pwd,
        dbtable="information_schema.tables",
    )
    .load()
    .where(col('table_schema') == 'public')
    .select('table_schema', 'table_name')
    .show(truncate=False)
)

In [25]:
# Shut down the Spark session now that the notebook's work is done.
spark.stop()