In [None]:
import os
from pyspark.sql import SparkSession, DataFrame, Row
from pyspark.sql import types as T
from pyspark.errors import AnalysisException
from dotenv import load_dotenv

# Load environment variables from ../.env-deploy before anything reads os.getenv.
# override=True lets values from the file replace variables already set in the process.
# NOTE(review): presumably this file supplies the POSTGRES_* settings read later — confirm.
load_dotenv("../.env-deploy", override=True)

In [None]:
# Root directory for local artifacts (JDBC driver jars, Spark warehouse, embedding output).
# MLSGPT_DATA_HOME, when set and non-empty, overrides the historical default so the
# notebook is not tied to one developer's home directory.
data_home = os.getenv("MLSGPT_DATA_HOME") or "/Users/kwesi/Desktop/ai/gpts/mlsgpt/data"

# JDBC driver jars live under {data_home}/jars; they are passed to Spark as one
# comma-separated string.
jar_files = ["postgresql-42.7.3.jar", "mysql-connector-j-8.0.33.jar"]
jar_opts = ",".join(f"{data_home}/jars/{jar}" for jar in jar_files)
warehouse = f"{data_home}/warehouse"

# Hive-enabled session with a UTC session time zone; getOrCreate() reuses an
# already-running session if the kernel has one.
spark: SparkSession = (
    SparkSession.builder
    .appName("MLSGPT")
    .config("spark.dynamicAllocation.enabled", "true")
    .config("spark.shuffle.service.enabled", "true")
    .config("spark.sql.warehouse.dir", warehouse)
    .config("spark.sql.session.timeZone", "UTC")
    .config("spark.jars", jar_opts)
    .enableHiveSupport()
    .getOrCreate()
)
# Restrict Spark's log output to errors.
spark.sparkContext.setLogLevel("ERROR")

In [None]:
def read_table(url: str, props: dict, table_name: str) -> "DataFrame | None":
    """Read one table over JDBC into a Spark DataFrame.

    Args:
        url: JDBC connection URL.
        props: Connection properties forwarded to ``spark.read.jdbc``.
        table_name: Name of the table to read.

    Returns:
        The table as a DataFrame, or ``None`` when Spark raises
        ``AnalysisException``. Callers must check for ``None`` before chaining
        DataFrame methods on the result.
    """
    try:
        return spark.read.jdbc(url=url, table=table_name, properties=props)
    except AnalysisException as e:
        # Best-effort contract: report and return None rather than raise. The
        # underlying error is included because AnalysisException is not limited
        # to missing tables. Any other exception type propagates to the caller.
        print(f"Table {table_name} not found: {e}")
        return None
    
# JDBC URL for Postgres; host, port and database name are read from the environment.
pg_url = (
    f"jdbc:postgresql://{os.getenv('POSTGRES_HOST')}"
    f":{os.getenv('POSTGRES_PORT')}"
    f"/{os.getenv('POSTGRES_DB')}"
)
# Properties handed to the JDBC reader: credentials from the environment plus the driver class.
pg_props = dict(
    user=os.getenv("POSTGRES_USER"),
    password=os.getenv("POSTGRES_PASSWORD"),
    driver="org.postgresql.Driver",
)

In [None]:
# Property rows: identifier columns plus the free-text remarks.
df0 = read_table(url=pg_url, props=pg_props, table_name="rsbr.property").select(
    "property_id", "ListingID", "PublicRemarks"
)
# ListingIDs present in the embedding table.
df1 = read_table(url=pg_url, props=pg_props, table_name="rsbr.embedding").select("ListingID")

In [None]:
# ListingIDs found in the embedding table, pulled to the driver as a plain list.
added = [record["ListingID"] for record in df1.collect()]
# Keep only the properties whose ListingID is not in that list.
df0 = df0.where(~df0["ListingID"].isin(added))
# Collect the remaining (ListingID, PublicRemarks) rows locally.
to_embed = df0.select("ListingID", "PublicRemarks").collect()

In [None]:
# Number of ListingIDs already in the embedding table, then number of rows left to embed.
print(f"{len(added)} {len(to_embed)}")

In [None]:
import tiktoken

# USD per 1,000 input tokens for text-embedding-3-small, the model used for
# both token counting and embedding in this notebook. The previous value,
# 0.00013, is the list price of text-embedding-3-large and overstated the
# estimate. NOTE(review): confirm against the current OpenAI pricing page.
cost_per_1k_tokens = 0.00002
enc = tiktoken.encoding_for_model("text-embedding-3-small")
# A row whose PublicRemarks is None counts as zero tokens instead of crashing the tokenizer.
tokens = [len(enc.encode(row["PublicRemarks"] or "")) for row in to_embed]
costs = [cost_per_1k_tokens * (token_count / 1000) for token_count in tokens]
print(f"Total tokens: {sum(tokens)}")
print(f"Total cost: ${sum(costs):.4f}")

In [None]:
from openai import OpenAI

# One client shared by every embed() call. Building a new client per row, as
# before, set up a fresh HTTP connection pool for every request.
client = OpenAI()

def embed(row: Row) -> Row:
    """Request an embedding for one listing's remarks.

    Args:
        row: Record exposing ``ListingID`` and ``PublicRemarks`` by key.

    Returns:
        A new ``Row`` with ``ListingID``, ``PublicRemarks`` and ``Embedding``,
        where ``Embedding`` is ``response.data[0].embedding`` from the
        ``text-embedding-3-small`` model.

    Raises:
        Whatever the OpenAI client raises for a failed request; nothing is
        caught here.
    """
    response = client.embeddings.create(
        input=row["PublicRemarks"],
        model="text-embedding-3-small",
    )
    return Row(
        ListingID=row["ListingID"],
        PublicRemarks=row["PublicRemarks"],
        Embedding=response.data[0].embedding,
    )

In [None]:
import concurrent.futures

# Define a function to process a batch of embeddings
def process_batch(batch, max_workers=100):
    """Run ``embed`` over every record in ``batch`` using a thread pool.

    Args:
        batch: Iterable of records; each one is passed to ``embed``.
        max_workers: Upper bound on worker threads. Defaults to 100, the
            value that used to be hard-coded.

    Returns:
        A list of ``embed`` results in completion order, which is not
        necessarily the order of ``batch``. An empty batch yields ``[]``.

    Raises:
        Any exception raised by ``embed`` for a record; it is re-raised when
        that record's result is collected.
    """
    with concurrent.futures.ThreadPoolExecutor(max_workers) as executor:
        futures = [executor.submit(embed, data) for data in batch]
        # as_completed yields each future as soon as it finishes, so slow
        # requests do not hold up collection of the fast ones.
        results = [future.result() for future in concurrent.futures.as_completed(futures)]
    return results

# Split the 'to_embed' list into consecutive batches of `batch_size` records
# (the last batch may be shorter).
batch_size = 100
batches = [to_embed[i:i+batch_size] for i in range(0, len(to_embed), batch_size)]

# Layout of each written batch. All three columns are declared as non-nullable
# strings. NOTE(review): embed() supplies response.data[0].embedding for the
# Embedding column — confirm how that value is rendered as a string in the CSV.
schema = T.StructType(
    [
        T.StructField("ListingID", T.StringType(), False),
        T.StructField("PublicRemarks", T.StringType(), False),
        T.StructField("Embedding", T.StringType(), False)
    ]
)

# Batches numbered 1..resume_after_batch are skipped; processing starts at the
# next one. NOTE(review): 275 looks like a manual resume point left over from
# an interrupted run — set it to 0 to process every batch.
resume_after_batch = 275

print(f"Processing {len(to_embed)} records in {len(batches)} batches of {batch_size} records each")
total_processed = 0
for batch_number, batch in enumerate(batches, start=1):
    if batch_number <= resume_after_batch:
        continue
    rows = process_batch(batch)
    df = spark.createDataFrame(rows, schema)
    # One output path per batch, numbered with six zero-padded digits; an
    # existing path for the same batch is overwritten.
    df.write.csv(f"{data_home}/embeddings/batch{str(batch_number).zfill(6)}.csv", mode="overwrite", header=True)
    # Counts only the records embedded in this run, not the skipped batches.
    total_processed += len(batch)
    print(f"Processed batch {batch_number} of {len(batches)} for ({total_processed} of {len(to_embed)}) records")

