## Sample notebook to demo the creation of in-memory `context` and `question_sql` from NYC Taxi Trips.

In [None]:
import warnings
import logging

# Silence the two noisiest warning categories for the rest of the notebook.
for category in (FutureWarning, UserWarning):
    warnings.simplefilter(action="ignore", category=category)

In [None]:
import os
import random
from typing import Tuple

import gait as G
import pyspark.sql.functions as F
import wordninja
from rich.pretty import pprint

In [None]:
# Root logger: warnings and above only, printed as the bare message.
# "WARN" is the logging module's alias for the WARNING level.
logging.basicConfig(level="WARN", format="%(message)s")

### Read the trips as a spark dataframe.

In [None]:
# NOTE: `spark` is not created anywhere in the visible cells; it is expected to
# be a SparkSession already provided by the kernel.
filename = os.path.expanduser("~/data/nyc-taxi-trip-duration/train.prq")
trips_raw = spark.read.parquet(filename)
# Cache the frame because the cells below query it once per column.
trips = trips_raw.cache()

### Create a SQL view of the spark dataframe.

In [None]:
# Register the dataframe under the name "trips": the generated SQL examples
# (`SELECT * FROM trips ...`) and `add_describe_table("trips")` refer to it by
# that name.
trips.createOrReplaceTempView("trips")

In [None]:
# Show the column names and types; the names feed the `exclude` filter below.
trips.printSchema()

### Get the names and types of the relevant dataframe columns.

In [None]:
# Columns left out of the generated context and question/SQL examples.
exclude = {
    "pickup_timestamp",
    "dropoff_timestamp",
    "pickup_q",
    "pickup_r",
    "dropoff_q",
    "dropoff_r",
}

# (column name, Spark type name) for every column that is not excluded,
# in schema order.
name_type = []
for field in trips.schema.fields:
    if field.name in exclude:
        continue
    name_type.append((field.name, field.dataType.typeName()))

name_type

### Create a context from the field names.

Note how we use `wordninja` to split each field name into a "nice", human-readable phrase ;-)

In [None]:
# One hint per column: the raw column name paired with the readable phrase
# that `wordninja` splits out of it.
context = [
    f"Use column '{column}' in reference to {' '.join(wordninja.split(column))}."
    for column, _column_type in name_type
]

# Hand-written hints that map phrases in a question to concrete filter values.
context += [
    "Set pickup_boro='Manhattan' whenever from Manhattan is used.",
    "Set dropoff_airport='LGA' whenever to LaGuardia or LGA is used.",
]

context

### Calculate the count of distinct values of each column.

In [None]:
# NOTE: this cell is intentionally disabled; everything below is commented out.
# When enabled, it counts the distinct non-null values of each column listed in
# `name_type` and collects (name, type, countDistinct) tuples in `field_names`.
# Nothing in the visible cells reads `field_names`.

# field_names = []

# for n, t in name_type:
#     rows = (
#         trips.filter(F.col(n).isNotNull())
#         .select(F.countDistinct(n).alias("countDistinct"))
#         .collect()
#     )
#     countDistinct = rows[0].countDistinct
#     field_names.append((n, t, countDistinct))


# field_names

### Create sample question/sql for each column.

In [None]:
# Collects (question, sql) example pairs; reset every time this cell runs.
question_sql = []


def get_ops(t: str, k: str) -> Tuple[str, str]:
    """Pick a random comparison for a column, as an English phrase and SQL operator.

    Args:
        t: Spark type name of the column (e.g. ``"string"``, ``"integer"``).
        k: Name of the column.

    Returns:
        A ``(phrase, sql_operator)`` pair such as ``("is greater than", ">")``.
        String columns, and columns named ``"UNK"``, get ``=`` or ``!=`` with
        equal probability. Every other column gets ``<``, ``=`` or ``>``,
        chosen uniformly.
    """
    # The names must be held in a real tuple. The previous `("UNK")` was just
    # the string "UNK", which turned `in` into a substring test that also
    # matched "", "U", "N", "K", "UN" and "NK".
    if t == "string" or k in ("UNK",):
        if random.random() < 0.5:
            return "is not", "!="
        else:
            return "is", "="
    else:
        return {
            1: ("is less than", "<"),
            2: ("is", "="),
            3: ("is greater than", ">"),
        }[random.randint(1, 3)]


# Spark `typeName()` values whose literals are written unquoted in SQL.
numeric_types = ("byte", "short", "integer", "long", "float", "double", "decimal")

for field_name, field_type in name_type:
    # Up to 10 distinct non-null values of the column, in random order.
    rows = (
        trips.filter(F.col(field_name).isNotNull())
        .select(field_name)
        .distinct()
        .orderBy(F.rand())
        .limit(10)
        .collect()
    )
    field_word = " ".join(wordninja.split(field_name))
    for (v,) in rows:
        op1, op2 = get_ops(field_type, field_name)
        # Lower-case only actual strings for the question text. The previous
        # type-name check called `.lower()` on any value whose type was not
        # integer/double/timestamp, which raises AttributeError for long,
        # float, boolean or date columns.
        o = v.lower() if isinstance(v, str) else v
        q = f"Show trips where {field_word} {op1} {o}"
        # Numeric literals go into the SQL bare; everything else is quoted.
        w = v if field_type in numeric_types else f"'{v}'"
        s = f"SELECT * FROM trips where {field_name}{op2}{w}"
        question_sql.append((q, s))

In [None]:
# Preview five randomly chosen (question, sql) pairs. `random.choices` samples
# with replacement, so the same pair can appear more than once.
# The loop variable is deliberately not `_`: IPython uses `_` for the last cell
# output, and rebinding it stops that automatic update.
for sample_pair in random.choices(question_sql, k=5):
    pprint(sample_pair, expand_all=True)

### Create an IDRIS instance backed by Apache Spark SQL.

In [None]:
# Database side of the instance (Spark SQL).
rdb = G.IdrisSparkSQL()

# Both Azure deployments share one base URL taken from the environment;
# a missing AZURE_API_URL raises KeyError here.
azure_api_url = os.environ["AZURE_API_URL"]

# Embedding model. The commented-out arguments are the alternative
# configuration pointing at http://localhost:11434.
emb = G.IdrisLiteEmb(
    # model_name="openai/mxbai-embed-large:latest",
    # api_base="http://localhost:11434/v1",
    # api_key="ollama",
    model_name="azure/text-embedding-ada-002",
    api_base=azure_api_url + "/text-embedding-ada-002",
)

# Language model, with the same local alternative kept commented out.
llm = G.IdrisLiteLLM(
    # model_name="openai/phi4:14b-q8_0",
    # api_base="http://localhost:11434/v1",
    # api_key="ollama",
    model_name="azure/gpt-4o-mini",
    api_base=azure_api_url + "/gpt-4o-mini",
)

# Combine database, embedder and LLM into a single Idris instance.
idris = G.Idris(rdb, emb, llm)

### Add context information.

In [None]:
# Feed the instance what was built above.
# "trips" is the temp view registered earlier; presumably this adds the table's
# description/schema to the context (confirm against the gait documentation).
idris.add_describe_table("trips")
# Per-column hints plus the hand-written Manhattan / LGA rules.
idris.load_context(context)
# Generated (question, sql) example pairs.
idris.load_question_sql(question_sql)

### Let's talk to it :-)

In [None]:
# Natural-language question to translate; it exercises the Manhattan and LGA
# hints loaded into the context.
question = "What is the average trip duration and distance from Manhattan to LGA between 4AM and 8AM"
sql = idris.generate_sql(question)
print(sql)

In [None]:
# Check the generated statement before running it; the result is shown as the
# cell output.
G.is_sql_valid(sql)

In [None]:
# Run the generated SQL through the Idris instance; the result is shown as the
# cell output.
idris.execute_sql(sql)