In [0]:
import pandas as pd
from mlflow.deployments import get_deploy_client
from yaml import safe_load
from pyspark.sql.functions import col, concat, lit, when, udf
import time 
from pyspark.sql import functions as F
from pyspark.sql.types import DoubleType, StringType
from pyspark.sql.window import Window
import math


# Load the notebook configuration; the catalog and schema read here prefix every
# table name used below.
with open("../params.yml", "r") as f:
    params = safe_load(f)

# Fail fast with a readable message: a missing `data_params` section would otherwise
# raise an opaque AttributeError, and a missing catalog/schema would silently turn
# into table names such as "None.None.device_status_triage".
data_params = (params or {}).get('data_params') or {}
CATALOG = data_params.get('catalog')
SCHEMA = data_params.get('schema')
if not CATALOG or not SCHEMA:
    raise ValueError(
        "../params.yml must define non-empty data_params.catalog and data_params.schema"
    )

# MLflow deployments client for Databricks; the query_agent helpers call client.predict.
client = get_deploy_client("databricks")

def make_prompt(device_error):
    """Build the triage prompt for one device/error description.

    Args:
        device_error: Text describing the device and its error message; it is
            interpolated into the prompt as-is.

    Returns:
        The prompt string. It asks for a concise remote fix and tells the model to
        begin its answer with either ``AUTOMATED FIX`` or
        ``HUMAN INTERVENTION REQUIRED``.
    """
    instructions = (
        "Please find an automated and remote fix for this device and error message: "
        f"{device_error}"
        ". Be as concise as possible, preferably using fewer than 20 words. "
        "If such a solution exists without requiring human intervention, start the response with the phrase AUTOMATED FIX followed by a 1-sentence overview. "
        "If human intervention is necessary, start the response with the phrase HUMAN INTERVENTION REQUIRED and state a single phrase overview for the course of action needed."
    )
    # The prompt ends with a newline followed by four spaces; keep that tail so the
    # text sent to the endpoint is unchanged.
    return instructions + "\n    "

# Example prompt for one known device/error pair, handy for trying the endpoint by hand.
sample_device_error = 'Delta D750 DC Power System, telemetry message Battery voltage below critical threshold at 10.8V triggering low power alarm'
pred_content = make_prompt(sample_device_error)

# Manual smoke test (disabled): uncomment to send the example prompt to the endpoint.
# response = client.predict(
#     endpoint="ka-fce5e7c5-endpoint",
#     inputs={
#         "messages": [{"role": "user", "content": pred_content}],
#         "databricks_options": {"return_trace": True}
#     }
# )
# response


In [0]:
# Triage records plus a `query_column` of the form "<device> message: <telemetry_error>",
# which is the text later embedded in each agent prompt.
triage_table = f"{CATALOG}.{SCHEMA}.device_status_triage"
query_text = concat(col('device'), lit(' message: '), col('telemetry_error'))
df = spark.table(triage_table).withColumn('query_column', query_text)


def query_agent(device_error):
    """Ask the agent endpoint for a fix for one device/error description.

    NOTE: a later cell defines another ``query_agent`` that takes a ready-made
    prompt; once that cell runs, it replaces this definition.

    Args:
        device_error: Device and error text; wrapped with ``make_prompt`` before
            being sent.

    Returns:
        The content of the first answer, or ``'No response'`` when the endpoint
        returned nothing usable.
    """
    prompt = make_prompt(device_error)
    response = client.predict(
        endpoint="ka-fce5e7c5-endpoint",
        inputs={
            "messages": [{"role": "user", "content": prompt}],
            "databricks_options": {"return_trace": True},
        },
    )

    # The other query_agent in this notebook reads the answer from
    # choices[0].message.content, so look there first; reading only a top-level
    # 'messages' key would yield 'No response' for such responses.
    choices = response.get('choices')
    if choices:
        first_message = choices[0].get('message') or {}
        return first_message.get('content', 'No response')

    # Legacy shape: a top-level 'messages' list. An empty or missing list falls back
    # to the placeholder instead of raising IndexError.
    messages = response.get('messages') or [{'content': 'No response'}]
    return messages[0]['content']

# Collect the distinct, non-null query strings. Solutions are matched back to rows by
# query text, so asking the agent about the same text twice only adds a redundant
# (and slow) endpoint call, and a null query can never be matched back at all.
errorslist = (
    df.select('query_column')
    .toPandas()['query_column']
    .dropna()
    .drop_duplicates()
    .tolist()
)

# One prompt per distinct query, in the same order as errorslist.
errorslist_solution = [make_prompt(query) for query in errorslist]
# errorslist_solution = [client.predict(endpoint='ka-fce5e7c5-endpoint', inputs={'messages': [{'role': 'user', 'content': x}]}) for x in errorslist_solution]
print(errorslist_solution)

# Disabled alternative: query the agent through a Spark UDF instead of on the driver.
# query_agent_udf = udf(query_agent, StringType())
# df = df.withColumn('agent_response', query_agent_udf(col('query_column')))
# df = df.withColumn('prompt', make_prompt_udf(col('query_column')))
# display(df)


In [0]:
def query_agent(msg, delay_seconds=15, endpoint="ka-fce5e7c5-endpoint"):
    """Send one ready-made prompt to the agent endpoint and return its answer text.

    Args:
        msg: The full prompt to send as the single user message.
        delay_seconds: Pause before the request; defaults to the 15 seconds this
            notebook has always used (presumably to stay under an endpoint rate
            limit - confirm). Pass 0 to skip the pause.
        endpoint: Name of the serving endpoint to query.

    Returns:
        The ``content`` of the first choice in the response.

    Raises:
        ValueError: If the response contains no ``choices``.
    """
    if delay_seconds:
        time.sleep(delay_seconds)

    response = client.predict(
        endpoint=endpoint,
        inputs={
            "messages": [{"role": "user", "content": msg}],
            "databricks_options": {"return_trace": True},
        },
    )

    # Surface a malformed response explicitly rather than failing with
    # "'NoneType' object is not subscriptable".
    choices = response.get('choices')
    if not choices:
        raise ValueError(f"Endpoint {endpoint!r} returned no choices: {response!r}")
    return choices[0].get('message').get('content')#.split('[')[0]

# Query the agent once per prompt, in order, so solutions[i] answers errorslist_solution[i].
solutions = list(map(query_agent, errorslist_solution))

In [0]:

# Attach each agent answer to the rows whose query text produced it.
#
# Build the lookup once and add the column with a single withColumn: calling
# withColumn inside a loop adds one projection per solution, which makes the query
# plan grow with the number of distinct errors.
solution_by_query = {
    query: solution
    for query, solution in zip(errorslist, solutions)
    if query is not None  # a null key can never equal a column value
}

# Each branch is null unless its query text matches, so coalesce picks the one
# matching answer (dict keys are unique) or leaves the row null.
solution_branches = [
    when(col('query_column') == query, lit(solution))
    for query, solution in solution_by_query.items()
]
if solution_branches:
    solutions_col = F.coalesce(*solution_branches)
else:
    solutions_col = lit(None)

# Cast so the column is a string even when there are no answers at all.
df = df.withColumn('solutions', solutions_col.cast('string'))

display(df)

In [0]:
# Tower coordinates, used to locate each failing device's station.
towers = spark.table(f"{CATALOG}.{SCHEMA}.sf_mobilelocations").select(['tower_id', 'lat', 'lon'])

# Left join keeps every triage row, with null coordinates when the tower is unknown.
# Columns are picked through the DataFrames they come from rather than through
# hard-coded "<catalog>.<schema>.<table>" qualifiers, so the cell keeps working when
# params.yml points at a different catalog or schema.
df_latlon = (
    df.join(towers, df.tower_id == towers.tower_id, how='left')
    .withColumnsRenamed({'lat': 'station_lat', 'lon': 'station_lon'})
    .select(
        towers['tower_id'],  # tower-side id: null when the join found no tower
        df['device_id'],
        df['device'],
        F.col('solutions'),
        F.col('station_lat'),
        F.col('station_lon'),
    )
)

display(df_latlon)


In [0]:


# Haversine distance in kilometers
def haversine(lat1, lon1, lat2, lon2):
    """Great-circle distance between two points given in decimal degrees.

    Args:
        lat1, lon1: Latitude and longitude of the first point, in degrees.
        lat2, lon2: Latitude and longitude of the second point, in degrees.

    Returns:
        The distance in kilometers as a float, or ``None`` when any coordinate is
        ``None`` (for example a device whose tower had no match in the left join).
    """
    # Null coordinates must yield a null distance: raising here would fail the whole
    # Spark job when this function runs as a UDF.
    if lat1 is None or lon1 is None or lat2 is None or lon2 is None:
        return None

    # Normalise to float so Decimal and float inputs can be mixed.
    lat1, lon1, lat2, lon2 = float(lat1), float(lon1), float(lat2), float(lon2)

    R = 6371.0  # sphere radius in km used for the distance
    phi1 = math.radians(lat1)
    phi2 = math.radians(lat2)
    dphi = math.radians(lat2 - lat1)
    dlambda = math.radians(lon2 - lon1)
    a = math.sin(dphi / 2) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(dlambda / 2) ** 2

    # Rounding can push `a` a hair above 1 for near-antipodal points, which would make
    # sqrt(1 - a) raise a math domain error. NaN compares false and passes through.
    if a > 1.0:
        a = 1.0
    return R * 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))

# Wrap the Python distance function so it can be applied to Spark columns.
haversine_udf = F.udf(haversine, DoubleType())

# Broadcast the small table
df_small_b = F.broadcast(df_latlon)

# Keep only rows from the most recent route date whose random_text is not "not scheduled".
all_routes = spark.table(f"{CATALOG}.{SCHEMA}.fieldtech_route")
max_date = all_routes.agg(F.max("date").alias("max_date")).collect()[0]["max_date"]
is_latest_date = F.col("date") == max_date
is_scheduled = F.col("random_text") != "not scheduled"
routes = all_routes.filter(is_latest_date & is_scheduled)

# Cross join and calculate distance: every remaining route row is paired with every
# device row, and distance_km is measured between the route and station coordinates.
distance_km = haversine_udf(
    F.col("lat"), F.col("lon"), F.col("station_lat"), F.col("station_lon")
)
joined = (
    routes.crossJoin(df_small_b)
    .withColumn("distance_km", distance_km)
    .withColumn('triage_timestamp', F.current_timestamp())
)
    
    # result = joined.select(
    # [F.col(f"left.{c}") for c in df_large.columns] 

In [0]:
# `joined` holds one row per (route row, device row) pair. Writing it as-is would
# append every device once per technician, each row naming a different person as
# "routed". Keep only the closest technician for each device row instead; ties are
# broken by name so the choice is deterministic, and rows without a distance sort last.
nearest_tech_window = Window.partitionBy(
    'tower_id', 'device_id', 'device', 'solutions'
).orderBy(F.col('distance_km').asc_nulls_last(), F.col('name'))

nearest_assignment = (
    joined
    .withColumn('_tech_rank', F.row_number().over(nearest_tech_window))
    .filter(F.col('_tech_rank') == 1)
)

# The prompt asks the agent to start its answer with HUMAN INTERVENTION REQUIRED when
# a site visit is needed; trim first so leading whitespace does not hide the marker.
needs_human = F.trim(F.col("solutions")).startswith("HUMAN INTERVENTION REQUIRED")
status = F.when(
    needs_human,
    F.concat(F.lit("Fix in progress: routing "), F.col("name"), F.lit(" to the site")),
).otherwise("Automatically triaged")

# Output columns: ['tower_id', 'device_id', 'device', 'status', 'event_timestamp']
status_update = nearest_assignment.select(
    'tower_id',
    'device_id',
    'device',
    status.alias('status'),
    F.current_timestamp().alias('event_timestamp'),
)

status_update.write.mode("append").saveAsTable(f"{CATALOG}.{SCHEMA}.device_status")