In [1]:
import os
from typing import Optional

from dotenv import load_dotenv
from psycopg import sql
from pyspark.errors import AnalysisException
from pyspark.sql import DataFrame, SparkSession

from mlsgpt.db import store
# Load deployment settings (e.g. the POSTGRES_* variables read later) into the
# process environment, replacing any values already set. The relative path is
# resolved against the kernel's current working directory.
load_dotenv(dotenv_path="../.env-deploy", override=True)

True

In [2]:
# Root of the local data directory; the JDBC jars (jars/) and the Spark SQL
# warehouse (warehouse/) live under it. Set MLSGPT_DATA_HOME to run the notebook
# on another machine; the default keeps the original author's path.
data_home = os.getenv("MLSGPT_DATA_HOME", "/Users/kwesi/Desktop/ai/gpts/mlsgpt/data")

# JDBC driver jars put on Spark's classpath via `spark.jars` (comma-separated).
jar_files = ["postgresql-42.7.3.jar", "mysql-connector-j-8.0.33.jar"]
jar_opts = ",".join(f"{data_home}/jars/{jar}" for jar in jar_files)
warehouse = f"{data_home}/warehouse"

spark: SparkSession = (
    SparkSession.builder
    .appName("MLSGPT")
    .config("spark.dynamicAllocation.enabled", "true")
    .config("spark.shuffle.service.enabled", "true")
    .config("spark.sql.warehouse.dir", warehouse)
    .config("spark.sql.session.timeZone", "UTC")
    .config("spark.jars", jar_opts)
    .enableHiveSupport()
    .getOrCreate()
)
# Only takes effect after the session exists, so startup WARN lines still print.
spark.sparkContext.setLogLevel("ERROR")

def read_table(url: str, props: dict, table_name: str) -> Optional[DataFrame]:
    """Read a database table into a Spark DataFrame over JDBC.

    Uses the module-level ``spark`` session.

    Args:
        url: JDBC connection URL, e.g. ``jdbc:postgresql://host:port/db``.
        props: JDBC connection properties such as ``user``, ``password``
            and ``driver``.
        table_name: Name of the table to read.

    Returns:
        The table as a DataFrame, or ``None`` if Spark raises an
        ``AnalysisException`` while resolving it.
    """
    try:
        return spark.read.jdbc(url=url, table=table_name, properties=props)
    except AnalysisException as exc:
        # Best-effort lookup: report and return None instead of failing the cell,
        # but keep Spark's reason, since not every AnalysisException means "missing".
        print(f"Table {table_name} not found: {exc}")
        return None
    
# Postgres connection settings pulled from the environment. Any variable that is
# unset comes back as None (and is rendered as "None" inside the URL).
pg_host, pg_port, pg_db, pg_user, pg_pass = (
    os.getenv(var_name)
    for var_name in (
        "POSTGRES_HOST",
        "POSTGRES_PORT",
        "POSTGRES_DB",
        "POSTGRES_USER",
        "POSTGRES_PASSWORD",
    )
)
pg_driver = "org.postgresql.Driver"
pg_url = f"jdbc:postgresql://{pg_host}:{pg_port}/{pg_db}"
# Connection properties in the shape Spark's JDBC reader expects.
pg_props = dict(user=pg_user, password=pg_pass, driver=pg_driver)

# Listing key followed by the 16 index columns H3IndexR00 .. H3IndexR15
# (presumably one per H3 resolution 0-15 — confirm against the parquet schema).
columns = ["ListingID"] + [f"H3IndexR{res:02d}" for res in range(16)]
# Lazily load the H3 index parquet file and project it to `columns`, in that order.
h3_df = spark.read.parquet(f"{data_home}/h3/property.parquet").select(*columns)


24/06/06 21:53:01 WARN Utils: Your hostname, marley.local resolves to a loopback address: 127.0.0.1; using 10.0.0.135 instead (on interface en0)
24/06/06 21:53:01 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
24/06/06 21:53:01 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/06/06 21:53:02 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


In [3]:
conn = store.create_pg_connection("gpts")
cursor = conn.cursor()

# Compose the INSERT from `columns` rather than hand-writing it, so the target
# column names, their order and the placeholder count always match the tuples
# built from `h3_df` (which is selected with the same list). sql.Identifier
# double-quotes each name, preserving the mixed-case column names.
cmd = sql.SQL("INSERT INTO {table} ({fields}) VALUES ({values})").format(
    table=sql.Identifier("rsbr", "h3_index"),
    fields=sql.SQL(", ").join(map(sql.Identifier, columns)),
    values=sql.SQL(", ").join(sql.Placeholder() * len(columns)),
)

In [4]:
# Materialize every Spark Row on the driver as a plain tuple for psycopg.
# Note: collect() loads the whole DataFrame into driver memory.
rows = list(map(tuple, h3_df.collect()))

In [5]:
# Send all rows and commit once. On any failure, roll back explicitly and
# re-raise so the error is visible. The cursor and connection are always
# released; re-run the connection cell before running this cell again.
try:
    cursor.executemany(cmd, rows)
    conn.commit()
except Exception:
    conn.rollback()
    raise
finally:
    cursor.close()
    conn.close()