In [None]:
import os
import re

from pyspark.sql import SparkSession, DataFrame, Row
from pyspark.sql import types as T
from pyspark.errors import AnalysisException

from rich.pretty import pprint
from dotenv import load_dotenv

from mlsgpt.dbv2 import store, models, schema
# Load deployment settings from ../.env-deploy (resolved against the kernel's
# working directory) into the process environment; override=True makes values
# from the file win over variables already set in the shell.
load_dotenv(dotenv_path="../.env-deploy", override=True)

In [None]:
# Project data reader (mlsgpt.dbv2.store); its get_properties() supplies the
# raw records that are wrapped in models.Property further down.
reader = store.DataReader()

In [None]:
# Root of the local data layout. Set MLSGPT_DATA_HOME (e.g. in the .env file)
# instead of editing this machine-specific absolute path.
data_home = os.environ.get("MLSGPT_DATA_HOME") or "/Users/kwesi/Desktop/ai/gpts/mlsgpt/data"
# PostgreSQL / MySQL JDBC driver jars, expected under {data_home}/jars.
jar_files = ["postgresql-42.7.3.jar", "mysql-connector-j-8.0.33.jar"]
jar_opts = ",".join(f"{data_home}/jars/{jar}" for jar in jar_files)
warehouse = f"{data_home}/warehouse"

spark: SparkSession = (
    SparkSession.builder
    .appName("MLSGPT")
    .config("spark.dynamicAllocation.enabled", "true")
    .config("spark.shuffle.service.enabled", "true")
    .config("spark.sql.warehouse.dir", warehouse)
    # Interpret and display session timestamps in UTC regardless of the host zone.
    .config("spark.sql.session.timeZone", "UTC")
    # Comma-separated list of extra jars (the JDBC drivers above).
    .config("spark.jars", jar_opts)
    .enableHiveSupport()
    .getOrCreate()
)
# Keep notebook output readable: only log Spark errors.
spark.sparkContext.setLogLevel("ERROR")

In [None]:
# Wrap every record returned by the reader in the project's Property model.
# NOTE(review): if models.Property is a plain pydantic BaseModel, positional
# construction raises TypeError; `models.Property.model_validate(record)` may
# be what is intended — confirm against mlsgpt.dbv2.models.
props = list(map(models.Property, reader.get_properties()))

In [None]:
# import tiktoken
# cost_per_1k_tokens = 5.0/1000
# enc = tiktoken.encoding_for_model("gpt-4")
# tokens = []
# costs = [cost_per_1k_tokens * (token_count / 1000) for token_count in tokens]
# print(f"Total tokens: {sum(tokens)}")
# print(f"Total cost: ${sum(costs):.4f}")

In [None]:
# System prompt for the listing summarizer, read from a Markdown file resolved
# against the kernel's working directory. The `with` block closes the handle
# (the bare open().read() leaked it) and the encoding no longer depends on the
# platform default.
with open("summary_system_prompt.md", encoding="utf-8") as prompt_file:
    system_prompt = prompt_file.read()
# User-message template; `{}` is filled with a property's JSON dump.
prompt_template = "Description JSON:\n```json\n{}\n```"

In [None]:
def parse_text(text: str, lang: str = "text") -> str:
    """Extract the body of the first fenced ```<lang> block in a model reply.

    Args:
        text: Raw model output, expected to contain a block such as
            "```text\\n...\\n```". ``None`` or ``""`` is tolerated.
        lang: Info string labelling the fence; defaults to ``"text"``.

    Returns:
        The stripped contents of the first matching fence. An unterminated
        fence (e.g. a truncated reply) yields everything after its opening
        line. Returns ``""`` (never ``None``) when no such fence is present.
    """
    if not text:
        return ""
    # The lazy group must be followed by a terminator (closing fence or end of
    # input); a trailing `(.*?)` on its own always matches the empty string.
    pattern = rf"```{re.escape(lang)}[ \t]*\r?\n(.*?)(?:```|\Z)"
    match = re.search(pattern, text, re.DOTALL)
    return match.group(1).strip() if match else ""

In [None]:
from openai import OpenAI

# One client for the whole notebook. OpenAI() reads its credentials from the
# environment (presumably populated by the .env file loaded at the top).
client = OpenAI()

def summarize(prop: models.Property, model: str = "gpt-4o") -> tuple:
    """Ask the chat model for a prose summary of one property listing.

    Args:
        prop: Listing to summarize; serialized with ``model_dump_json`` into
            the user prompt built from ``prompt_template``.
        model: Chat model name; defaults to ``"gpt-4o"``.

    Returns:
        ``(row, response)``: ``row`` is a Spark ``Row`` with ``ListingID``,
        ``Summary`` (text extracted by ``parse_text``) and ``Embedding=None``;
        ``response`` is the raw completion object, kept for inspection.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt_template.format(prop.model_dump_json())},
    ]
    # Reuse the module-level client instead of building a new one per call.
    response = client.chat.completions.create(messages=messages, model=model)
    # The v1 SDK returns a typed ChatCompletion object (not subscriptable like a
    # dict), and message.content is Optional, so fall back to "".
    content = response.choices[0].message.content or ""
    description = parse_text(content)
    return Row(ListingID=prop.ListingID, Summary=description, Embedding=None), response

In [None]:
# Run the summarizer on the first listing only, as a single-item check.
# NOTE(review): props[0] was already built with models.Property(...) above, so
# this model_validate call looks redundant — confirm whether props holds raw
# records or Property instances.
prop = models.Property.model_validate(props[0])
row, response = summarize(prop)

In [None]:
# import concurrent.futures

# # Define a function that summarizes a batch of properties concurrently
# def process_batch(batch):
#     with concurrent.futures.ThreadPoolExecutor(100) as executor:
#         futures = []
#         for data in batch:
#             futures.append(executor.submit(summarize, data))

#         results = []
#         for future in concurrent.futures.as_completed(futures):
#             results.append(future.result())
#     return results

# # Split the 'props' list into batches of `batch_size` (100) records
# batch_size = 100
# size = len(props)
# batches = [props[i:i+batch_size] for i in range(0, size, batch_size)]

# _schema = T.StructType(
#     [
#         T.StructField("ListingID", T.StringType(), False),
#         T.StructField("Summary", T.TextType(), False),
#         T.StructField("Embedding", T.StringType(), False)
#     ]
# )

# print(f"Processing {len(props)} records in {len(batches)} batches of {batch_size} records each")
# total_processed = 0
# for i, batch in enumerate(batches):
#     if i +  1 > 275:
#         rows = process_batch(batch)
#         df = spark.createDataFrame(rows, _schema)
#         df.write.csv(f"{data_home}/summaries/batch{str(i+1).zfill(6)}.csv", mode="overwrite", header=True)
#         total_processed += len(batch)
#         print(f"Processed batch {i+1} of {len(batches)} for ({total_processed} of {size}) records")

