## Notebook to demo the usage of `IdrisTrainer` on the Titanic dataset.

Make sure to install the following modules:

```bash
uv pip install ".[jupyter]"
```

Download the Titanic passenger data from [Kaggle](https://www.kaggle.com/c/titanic/data).

The following sample bash script starts a local PySpark session with JupyterLab as the driver front end. Adjust the memory and core settings to your machine:

```bash
#!/usr/bin/env bash

export SPARK_MAJOR="3.5"
export SPARK_MINOR="${SPARK_MAJOR}.5"
export HADOOP_VER="3"
export SPARK_HOME="${HOME}/spark-${SPARK_MINOR}-bin-hadoop${HADOOP_VER}"
export PYSPARK="${SPARK_HOME}/bin/pyspark"
export PYSPARK_DRIVER_PYTHON="jupyter"
export PYSPARK_DRIVER_PYTHON_OPTS='lab --ip=0.0.0.0 --port 8989 --allow-root --no-browser --IdentityProvider.token=""'
export PYARROW_IGNORE_TIMEZONE=1
export SPARK_LOCAL_IP="localhost"
export EXTRA_JAVA_OPTIONS="-XX:+AggressiveHeap -XX:ParallelGCThreads=20 -Djava.awt.headless=true"
export MASTER="local[*]"

$PYSPARK \
  --master "${MASTER}" \
  --conf spark.default.parallelism=20 \
  --conf spark.driver.extraJavaOptions="${EXTRA_JAVA_OPTIONS}" \
  --conf spark.executor.extraJavaOptions="${EXTRA_JAVA_OPTIONS}" \
  --conf spark.driver.maxResultSize=2G \
  --conf spark.driver.memory=64G \
  --conf spark.executor.cores=10 \
  --conf spark.executor.memory=64G \
  --conf spark.kryo.unsafe=true \
  --conf spark.kryoserializer.buffer.max=256M \
  --conf spark.memory.offHeap.enabled=true \
  --conf spark.memory.offHeap.size=64G \
  --conf spark.serializer=org.apache.spark.serializer.KryoSerializer \
  --conf spark.sql.catalogImplementation=in-memory \
  --conf spark.sql.execution.arrow.pyspark.enabled=true \
  --conf spark.sql.shuffle.partitions=200 \
  --conf spark.cleaner.referenceTracking.cleanCheckpoints=true \
  --conf spark.ui.enabled=false \
  --conf spark.ui.showConsoleProgress=true
```

### References:
- https://www.kaggle.com/competitions/titanic/data
- https://python.langchain.com/docs/integrations/tools/spark_sql/

### Show the Spark instance.

In [1]:
spark

In [2]:
import os
import re

from gait import Idris, IdrisLiteEmb, IdrisLiteLLM, IdrisSparkSQL, IdrisTrainer
from pyspark.sql import functions as F
from pyspark.sql import types as T
from rich.pretty import pprint

### Define Titanic passenger CSV dataset schema.

In [3]:
schema = ",".join(
    [
        "`PassengerId` string",
        "`Survived` integer",
        "`Pclass` integer",
        "`Name` string",
        "`Sex` string",
        "`Age` float",
        "`SibSp` integer",
        "`Parch` integer",
        "`Ticket` string",
        "`Fare` float",
        "`Cabin` string",
        "`Embarked` string",
    ]
)
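
The joined string is a DDL-formatted schema, which `spark.read.csv` accepts in place of a `StructType`. As a quick sanity check that needs no Spark session, the sketch below rebuilds a shortened version of the string and lists the declared column names. The helper `column_names` is hypothetical and is here only for illustration.

```python
import re

# Shortened version of the schema above: backtick-quoted name, then a Spark SQL type.
fields = [
    "`PassengerId` string",
    "`Survived` integer",
    "`Age` float",
]
schema = ",".join(fields)
print(schema)


def column_names(ddl: str) -> list[str]:
    """Return the backtick-quoted column names found in a DDL schema string."""
    return re.findall(r"`([^`]+)`", ddl)


print(column_names(schema))
```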

### Read the data and create the `passengers` table.

Adjust the path to where you downloaded the data.

In [4]:
path = os.path.expanduser("~/data/titanic/train.csv")

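# Read the CSV as a Spark DataFrame, dropping the columns we do not query.
sdf = (
    spark.read.csv(
        path=path,
        schema=schema,
        sep=",",
        header=True,
        encoding="utf-8",
    )
    .drop("PassengerId", "Ticket")
    .cache()
)

# Register the Spark DataFrame as the `passengers` table.
sdf.createOrReplaceTempView("passengers")

# Convert to pandas for training and for the optional DuckDB check.
pdf = sdf.toPandas()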

### Optional - check if we can see the data in DuckDB.

Note: DuckDB resolves the table name `pdf` in the query to the pandas DataFrame variable of the same name. Neat!

In [5]:
# import duckdb

# with duckdb.connect(":memory:") as conn:
#     _ = conn.execute("create or replace view idris as select * from pdf")
#     print(conn.sql("DESCRIBE idris").df())

### Create context information by training the model.

Here we create an alias lookup table. The prefix `_col` refers to the field alias.

In [6]:
aliases = {
    "_col:Survived": "survived",  # 0 = No, 1 = Yes
    "_col:Pclass": "ticket class",  # 1 = 1st, 2 = 2nd, 3 = 3rd
    "_col:SibSp": "number of siblings / spouses aboard",
    "_col:Parch": "number of parents / children aboard",
    "_col:Fare": "passenger fare",
    "_col:Embarked": "port of embarkation",  # C = Cherbourg, Q = Queenstown, S = Southampton
    #
    # What to substitute when we see a value for a field.
    #
    "Survived:1": "yes",
    "Survived:0": "no",
    "Pclass:1": "first",
    "Pclass:2": "2nd",
    "Pclass:3": "3rd",
    "Embarked:C": "Cherbourg",
    "Embarked:Q": "Queenstown",
    "Embarked:S": "Southampton",
}
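
The dictionary mixes two kinds of keys: `_col:<field>` entries give a field alias, and `<field>:<value>` entries give the text to substitute for a stored value. The sketch below separates the two kinds to make the convention explicit. The helper `split_aliases` is hypothetical and is not part of `gait`; how `IdrisTrainer` consumes these keys internally is not shown in this notebook.

```python
def split_aliases(
    aliases: dict[str, str],
) -> tuple[dict[str, str], dict[str, dict[str, str]]]:
    """Split an alias table into field aliases and per-field value aliases."""
    field_aliases: dict[str, str] = {}
    value_aliases: dict[str, dict[str, str]] = {}
    for key, text in aliases.items():
        prefix, _, rest = key.partition(":")
        if prefix == "_col":
            # "_col:Pclass" -> alias for the field Pclass
            field_aliases[rest] = text
        else:
            # "Pclass:1" -> alias for the value 1 of the field Pclass
            value_aliases.setdefault(prefix, {})[rest] = text
    return field_aliases, value_aliases


# A few entries copied from the alias table above.
sample = {
    "_col:Pclass": "ticket class",
    "Pclass:1": "first",
    "Pclass:2": "2nd",
    "Embarked:S": "Southampton",
}
fields, values = split_aliases(sample)
print(fields)
print(values)
```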

### Train the model.

In [7]:
train_result = IdrisTrainer(aliases).train(pdf, "passengers")

In [8]:
# print(train_result.create_table)

In [9]:
# Add more instructions here.
train_result.context.extend(
    [
        "Alias COUNT(*) to 'number_of_passengers'.",
        "Use 1 if the passenger survived. Use 0 if the passenger did not survive.",
        "Use 'male' and 'female' for sex.",
    ]
)

# pprint(train_result.context, expand_all=True)

In [10]:
# pprint(train_result.question_sql, expand_all=True)

In [11]:
# from litellm import embedding
#
# response = embedding(
#     model="openai/mxbai-embed-large:latest",
#     input=["good morning from litellm"],
#     api_base="http://localhost:11434/v1",
#     api_key="ollama",
# )
# pprint(response, expand_all=True)

### Create an IDRIS instance using Ollama services.

In [12]:
rdb = IdrisSparkSQL()

emb = IdrisLiteEmb(
    model_name="openai/mxbai-embed-large:latest",
    api_base="http://localhost:11434/v1",
    api_key="ollama",
)

llm = IdrisLiteLLM(
    # model_name="openai/phi4:14b-q8_0",
    model_name="openai/gemma3:4b",
    api_base="http://localhost:11434/v1",
    api_key="ollama",
)

idris = Idris(rdb, emb, llm)

### Load the initial training results.

In [13]:
idris.add_create_table(train_result.create_table)
idris.load_context(train_result.context)
idris.load_question_sql(train_result.question_sql)

In [14]:
def clean_sql(sql: str) -> str:
    return re.sub(r"^```sql\s*|\s*```$", "", sql)
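
Models often wrap their answer in a Markdown code fence, which is what `clean_sql` removes. Here is a quick check of its behaviour, with made-up model replies as input:

```python
import re


def clean_sql(sql: str) -> str:
    # Same regex as in the cell above.
    return re.sub(r"^```sql\s*|\s*```$", "", sql)


# Made-up replies: one wrapped in a ```sql fence, one already plain.
fenced = "```sql\nSELECT COUNT(*) FROM passengers\n```"
plain = "SELECT COUNT(*) FROM passengers"

print(clean_sql(fenced))
print(clean_sql(plain))
```

The pattern only matches a fence tagged `sql` at the very start of the string, so a reply with leading text before the fence would keep its opening fence.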

### Start asking questions.

In [15]:
sql = idris.generate_sql("What's the name and age of the oldest surviving passenger?")
sql = clean_sql(sql)
print(sql)
idris.execute_sql(sql)

SELECT Name, Age FROM passengers WHERE Survived=1 ORDER BY Age DESC LIMIT 1


| | Name | Age |
|---|---|---|
| 0 | Barkworth, Mr. Algernon Henry Wilson | 80.0 |
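
Each question in this section repeats the same three steps: generate, clean, execute. A small wrapper can cut the repetition. This is a sketch only: `ask` is a hypothetical helper, not part of `gait`, and it assumes nothing beyond the `generate_sql` and `execute_sql` methods used in this notebook.

```python
import re


def clean_sql(sql: str) -> str:
    # Same regex as the notebook's clean_sql.
    return re.sub(r"^```sql\s*|\s*```$", "", sql)


def ask(engine, question: str):
    """Generate SQL for a question, strip any code fence, print it, and run it.

    `engine` is any object exposing generate_sql(question) and execute_sql(sql),
    such as the `idris` instance created earlier.
    """
    sql = clean_sql(engine.generate_sql(question))
    print(sql)
    return engine.execute_sql(sql)
```

With the notebook's instance, a call would look like `ask(idris, "What's the name and age of the oldest surviving passenger?")`.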


In [16]:
# Number of female passengers in 3rd class that survived.
sql = idris.generate_sql("생존한 3등석 여성승객 수")
sql = clean_sql(sql)
print(sql)
idris.execute_sql(sql)

SELECT count(*) FROM passengers WHERE Survived = 1 AND Pclass = 3 AND Sex = 'female'


| | count(1) |
|---|---|
| 0 | 72 |


In [17]:
# Number of passengers who did not survive and boarded from Southampton.
sql = idris.generate_sql(
    "Quel est le nombre de passagers qui n'ont pas survécu et sont montés à bord depuis Southhampton?"
)
sql = clean_sql(sql)
print(sql)
idris.execute_sql(sql)

SELECT COUNT(*) FROM passengers WHERE Survived = 0 AND Embarked = 'S'


| | count(1) |
|---|---|
| 0 | 427 |


In [18]:
sql = idris.generate_sql(
    "What is the average age of passengers in first class that survived by sex?"
)
sql = clean_sql(sql)
print(sql)
idris.execute_sql(sql)

SELECT AVG(Age) FROM passengers WHERE Pclass = 1 AND Survived = 1 GROUP BY Sex


| | avg(Age) |
|---|---|
| 0 | 34.939024 |
| 1 | 36.248 |
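
The generated query groups by `Sex` but does not select it, so the two averages above are not labelled. Adding the grouping column to the `SELECT` list fixes that. The sketch below shows the labelled form using Python's built-in `sqlite3` and a few made-up rows (not the Titanic data); the same SQL shape should apply to the `passengers` view in Spark SQL.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE passengers (Sex TEXT, Age REAL, Pclass INTEGER, Survived INTEGER)"
)
# Made-up rows for illustration only.
conn.executemany(
    "INSERT INTO passengers VALUES (?, ?, ?, ?)",
    [
        ("female", 30.0, 1, 1),
        ("female", 40.0, 1, 1),
        ("male", 50.0, 1, 1),
        ("male", 20.0, 3, 0),  # excluded by the WHERE clause
    ],
)

# Selecting the grouping column labels each average with its group.
labelled = conn.execute(
    "SELECT Sex, AVG(Age) AS average_age FROM passengers "
    "WHERE Pclass = 1 AND Survived = 1 GROUP BY Sex ORDER BY Sex"
).fetchall()
conn.close()
print(labelled)
```

One way to steer the model toward this form would be an extra context instruction, added the same way as the ones passed to `train_result.context.extend` earlier.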
