In [None]:
import os
from psycopg import sql
from pyspark.sql import SparkSession
from pyspark.sql import types as T
from dotenv import load_dotenv
from mlsgpt.db import store

# Load deployment settings from the env file one directory above the working
# directory. override=True lets values from the file replace variables that
# are already set in the process environment.
load_dotenv(dotenv_path="../.env-deploy", override=True)

In [None]:
# Root directory holding the JDBC jars, the Spark warehouse and the embedding
# batches. It can be redirected by setting MLSGPT_DATA_HOME (for example in the
# env file loaded above); when unset or empty the original location is used,
# so existing setups keep working unchanged.
data_home = (
    os.environ.get("MLSGPT_DATA_HOME")
    or "/Users/kwesi/Desktop/ai/gpts/mlsgpt/data"
)

# JDBC driver jars handed to Spark as one comma-separated list of paths.
jar_files = ["postgresql-42.7.3.jar", "mysql-connector-j-8.0.33.jar"]
jar_opts = ",".join(f"{data_home}/jars/{jar}" for jar in jar_files)

# Directory used as Spark SQL's warehouse location.
warehouse = f"{data_home}/warehouse"

# Settings applied to the session builder, in this order. The warehouse
# directory and the jar list come from the path variables defined above.
spark_conf = {
    "spark.driver.memory": "16G",
    "spark.dynamicAllocation.enabled": "true",
    "spark.shuffle.service.enabled": "true",
    "spark.sql.warehouse.dir": warehouse,
    "spark.sql.session.timeZone": "UTC",
    "spark.jars": jar_opts,
}

builder = SparkSession.builder.appName("MLSGPT")
for conf_key, conf_value in spark_conf.items():
    builder = builder.config(conf_key, conf_value)

# Hive support is enabled; getOrCreate() hands back an already-running session
# if there is one instead of starting a second one.
spark: SparkSession = builder.enableHiveSupport().getOrCreate()

# Only log errors so notebook output is not flooded by Spark messages.
spark.sparkContext.setLogLevel("ERROR")

In [None]:
# All three CSV columns are read as strings and declared non-nullable.
schema = T.StructType(
    [
        T.StructField(column, T.StringType(), False)
        for column in ("ListingID", "PublicRemarks", "Embedding")
    ]
)

# Read every file matching the batch* pattern; the first line of each file is
# treated as a header and the explicit schema above is applied.
df = (
    spark.read
    .schema(schema)
    .option("header", True)
    .csv(f"{data_home}/embeddings/batch*")
)

# Pull the whole dataset to the driver as plain tuples in schema order
# (ListingID, PublicRemarks, Embedding), ready to be used as SQL parameters.
rows = [tuple(record) for record in df.collect()]

In [None]:
# Open a connection through the project's store helper ("gpts" is the
# argument it is given) and prepare a cursor for the insert below.
conn = store.create_pg_connection("gpts")

# Parameterised insert: one %s placeholder per column, filled from each row tuple.
cmd = sql.SQL(
    'INSERT INTO rsbr.embedding ("ListingID", "PublicRemarks", "Embedding") '
    "VALUES (%s, %s, %s)"
)
cursor = conn.cursor()

In [None]:
# Bulk-insert the collected embedding rows and commit them as one transaction.
#
# NOTE(review): the first SKIP_ROWS rows are left out, exactly as the original
# `rows[10:]` slice did — presumably they were inserted during an earlier trial
# run. Set SKIP_ROWS = 0 to load every row into a fresh table. TODO confirm.
SKIP_ROWS = 10
pending_rows = rows[SKIP_ROWS:]

try:
    cursor.executemany(cmd, pending_rows)
    conn.commit()
except Exception:
    # Discard the failed batch so the connection is not left in an aborted
    # transaction when this cell is re-run after the problem is fixed.
    conn.rollback()
    raise