In [2]:
import json
from pathlib import Path
import duckdb

from blendsql.ingredients import LLMQA as RawLLMQA
from blendsql.ingredients import LLMMap as RawLLMMap
from blendsql.models import LlamaCpp

# llama.cpp settings for the quantized Llama 3.1 8B Instruct GGUF loaded below.
llama_cpp_config = {"n_gpu_layers": -1, "n_ctx": 8000, "seed": 100, "n_threads": 16}

model = LlamaCpp(
    model_name_or_path="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
    filename="Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf",
    config=llama_cpp_config,
    caching=False,
)
# Access `model_obj` once up front; presumably this forces the (otherwise lazy)
# model download/load to happen in this cell -- confirm against blendsql.
_ = model.model_obj

def context_formatter(df):
    """Render a DataFrame as a pretty-printed JSON array of row records.

    Args:
        df: Object exposing ``to_dict(orient="records")`` (a pandas DataFrame).

    Returns:
        A JSON string, indented by 4 spaces, with non-ASCII characters kept
        as-is rather than escaped.
    """
    row_records = df.to_dict(orient="records")
    return json.dumps(row_records, ensure_ascii=False, indent=4)

# Two-step construction: `from_args` fixes the few-shot setting and returns a
# callable, which is then invoked with the per-instance arguments.
_llmmap_factory = RawLLMMap.from_args(num_few_shot_examples=3)
LLMMap = _llmmap_factory(
    name="LLMMap",
    db=None,
    session_uuid=None,
    context_formatter=context_formatter,
)

# Zero-shot LLMQA ingredient instance. It is registered under its own name:
# the original passed name="LLMMap", a copy-paste slip from the LLMMap setup
# that would make the two ingredients indistinguishable by name.
LLMQA = RawLLMQA.from_args(
    num_few_shot_examples=0,
)(
    name="LLMQA",
    db=None,
    session_uuid=None,
)

./Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

llama_context: n_ctx_per_seq (512) > n_ctx_train (0) -- possible training context overflow
llama_context: n_ctx_per_seq (8000) < n_ctx_train (131072) -- the full capacity of the model will not be utilized


In [3]:
import time
from functools import lru_cache
from tag_queries import BLENDSQL_ANNOTATED_TAG_DATASET
# Base directory that the dataset paths below are resolved against.
CURR_DIR = Path(".")
# Running total of values returned by LLMMap.run inside run_llmmap.
NUM_VALUES_PASSED = 0

def load_tag_db_path(name: str) -> Path:
    """Build the path to a BIRD-SQL dev database file.

    Args:
        name: Database name; used both as the directory name and as the stem
            of the ``.sqlite`` file inside it.

    Returns:
        ``CURR_DIR / data/bird-sql/dev_20240627/dev_databases/<name>/<name>.sqlite``
        as a ``pathlib.Path``. The path is only constructed here; nothing
        checks that the file exists.
    """
    return (
        CURR_DIR / "data/bird-sql/dev_20240627/dev_databases/" / name / f"{name}.sqlite"
    )

@lru_cache(maxsize=1000)
def run_llmmap(question: str, value: str, return_type: str) -> object:
    """Map one value through ``LLMMap`` and return the single mapped result.

    Results are memoized per ``(question, value, return_type)`` by
    ``lru_cache``, so the body (the print, the model call and the counter
    update) only runs for argument combinations not already cached.

    Args:
        question: Question forwarded to ``LLMMap.run``.
        value: The single value to map; it is sent as a one-element list.
        return_type: Return-type name forwarded to ``LLMMap.run`` (the query
            in this notebook passes ``'bool'``).

    Returns:
        The first element of the list returned by ``LLMMap.run``.

    Side effects:
        Prints ``value`` and adds the number of mapped values to the global
        ``NUM_VALUES_PASSED``.
    """
    global NUM_VALUES_PASSED
    # Echo every uncached value so the cell output shows what reached the model.
    print(value)
    mapped_values = LLMMap.run(
        model=model,
        question=question,
        values=[value],
        list_options_in_prompt=True,
        context_formatter=context_formatter,
        return_type=return_type,
    )
    NUM_VALUES_PASSED += len(mapped_values)
    return mapped_values[0]


# Time the hard-coded Bay Area query with LLMMap registered as a DuckDB scalar
# UDF. Only the first dataset item is used (see the `break` at the end), and
# only its "DB used" field is read. `conn` is deliberately left open: a later
# cell reuses it to EXPLAIN the same query.
for item in BLENDSQL_ANNOTATED_TAG_DATASET:
    conn = duckdb.connect() # New connection for each query
    # Best-effort removal of an existing LLMMap UDF; the exception raised when
    # none is registered is intentionally ignored.
    try:
        conn.remove_function("LLMMap")
    except duckdb.InvalidInputException:
        pass

    # Three VARCHAR arguments (question, value, return_type) -> BOOLEAN.
    conn.create_function(
        name="LLMMap",
        function=run_llmmap,
        parameters=[duckdb.sqltypes.VARCHAR, duckdb.sqltypes.VARCHAR, duckdb.sqltypes.VARCHAR],
        return_type=duckdb.sqltypes.BOOLEAN
    )
    conn.execute(f"""ATTACH '{load_tag_db_path(item["DB used"])}' (TYPE SQLITE)""")
    # perf_counter() is monotonic and high-resolution, so the elapsed time is
    # not distorted by system clock adjustments the way time.time() can be.
    start = time.perf_counter()
    # NOTE(review): the query references california_schools directly, so this
    # presumably relies on the first item's database being that one -- confirm.
    # The result rows are not fetched here.
    result = conn.execute(
        """
        SELECT COUNT(DISTINCT s.CDSCode)
        FROM california_schools.schools s
        JOIN california_schools.satscores sa ON s.CDSCode = sa.cds
        WHERE sa.AvgScrMath > 560
        AND LLMMap(
            'Is this a county in the California Bay Area?',
            s.County,
            'bool'
        ) = TRUE
        """
    )
    print(time.perf_counter() - start)
    break

Alameda


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Alpine


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Amador


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Butte


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Calaveras


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Colusa


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Contra Costa


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Del Norte


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

El Dorado


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Fresno


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Glenn


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Humboldt


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Imperial


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Inyo


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Kern


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Kings


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Lake


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Lassen


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Los Angeles


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Madera


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Marin


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Mariposa


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Mendocino


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Merced


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Modoc


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Mono


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Monterey


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Napa


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Nevada


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Orange


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Placer


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Plumas


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Riverside


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Sacramento


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

San Benito


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

San Bernardino


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

San Diego


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

San Francisco


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

San Joaquin


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

San Luis Obispo


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

San Mateo


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Santa Barbara


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Santa Clara


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Santa Cruz


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Shasta


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Sierra


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Siskiyou


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Solano


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Sonoma


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Stanislaus


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Sutter


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Tehama


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Trinity


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Tulare


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Tuolumne


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Ventura


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Yolo


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

Yuba


LLMMap with batch_size=1:   0%|          | 0/1 [00:00<?, ?it/s]

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

13.89301872253418


In [4]:
# EXPLAIN the Bay Area query on the existing `conn`, to see where DuckDB
# places the LLMMap filter relative to the join.
explain_sql = """
    EXPLAIN
    SELECT COUNT(DISTINCT s.CDSCode)
    FROM california_schools.schools s
    JOIN california_schools.satscores sa ON s.CDSCode = sa.cds
    WHERE sa.AvgScrMath > 560
    AND LLMMap(
        'Is this a county in the California Bay Area?',
        s.County,
        'bool'
    ) = TRUE
    """
explain_result = conn.execute(explain_sql).fetchall()

# Print the second field of each returned row.
for plan_row in explain_result:
    print(plan_row[1])

physical_plan


In [6]:
# Show the second field of every row in `explain_result`, one per line.
for plan_row in explain_result:
    print(plan_row[1])

┌───────────────────────────┐
│    UNGROUPED_AGGREGATE    │
│    ────────────────────   │
│        Aggregates:        │
│     count(DISTINCT #0)    │
└─────────────┬─────────────┘
┌─────────────┴─────────────┐
│         PROJECTION        │
│    ────────────────────   │
│          CDSCode          │
│                           │
│         ~706 rows         │
└─────────────┬─────────────┘
┌─────────────┴─────────────┐
│         HASH_JOIN         │
│    ────────────────────   │
│      Join Type: INNER     │
│                           │
│        Conditions:        ├──────────────┐
│       CDSCode = cds       │              │
│                           │              │
│         ~706 rows         │              │
└─────────────┬─────────────┘              │
┌─────────────┴─────────────┐┌─────────────┴─────────────┐
│         PROJECTION        ││         PROJECTION        │
│    ────────────────────   ││    ────────────────────   │
│             #0            ││             #0            │

In [None]:
# Display the running total that run_llmmap accumulates: the number of values
# returned by LLMMap.run across its uncached calls.
NUM_VALUES_PASSED

In [None]:
# Notes comparing values passed to LLMMap and latency for the Bay Area query.
# BlendSQL values passed: 19
# DuckDB values passed: 58 (the DuckDB UDF run in this notebook printed 58 values)
# Total distinct schools: 58
# Distinct schools after join: 49
# Distinct schools after filter: 19
# NOTE(review): the values printed by the UDF run are county names, so
# "schools" in the three lines above presumably means distinct County values
# -- confirm.
# DuckDB latency: 13.9 sec
# BlendSQL latency: 1.5 sec
# Difference between the after-join (49) and after-filter (19) counts above.
49 - 19

In [4]:
# Gap between 55.7 and 46.4. These match the 8b "Program Execution" and "RAG"
# denotation-accuracy entries (899 Program Executable column) in the results
# table below -- presumably that is the comparison intended; confirm.
55.7 - 46.4

9.300000000000004

| Model | Mode |        Accuracy |                        | F1 | | Denotation Accuracy | |
|:------|:-----|----------------:|-----------------------:|--------:|--------:|--------:|--------:|
| | | All 1k Examples | 899 Program Executable | All 1k Examples | 899 Program Executable | All 1k Examples | 899 Program Executable |
| **1b** | No Context |             1.5 |                    1.4 | 3.85 | 3.8 | 2.1 | 2.1 |
| | All Context |             8.0 |                    7.7 | 12.27 | 12.0 | 8.8 | 8.3 |
| | RAG |             8.9 |                    8.7 | 12.27 | 11.8 | 9.7 | 9.5 |
| | Program Execution |            12.1 |                   13.5 | 16.93 | 18.8 | 12.7 | 14.5 |
| **3b** | No Context |             3.1 |                    3.0 | 5.67 | 5.5 | 3.6 | 3.6 |
| | All Context |            37.3 |                   37.7 | 45.02 | 45.3 | 38.8 | 39.3 |
| | RAG |            35.7 |                   36.0 | 42.57 | 43.0 | 37.7 | 38.3 |
| | Program Execution |            41.8 |                   46.5 | 48.70 | 54.1 | 45.3 | 50.4 |
| **8b** | No Context |             3.8 |                    3.8 | 7.38 | 7.5 | 4.5 | 4.7 |
| | All Context |            42.9 |                   42.7 | 50.5 | 50.5 | 44.4 | 44.3 |
| | RAG |            43.8 |                   44.5 | 50.98 | 51.8 | 45.6 | 46.4 |
| | Program Execution |            46.9 |                   52.2 | 53.85 | 59.9 | 50.1 | 55.7 |
| **70b** | No Context |             5.7 |                    6.0 | 10.19 | 10.1 | 6.6 | 7.0 |
| | All Context |               - |                      - | - | - | - | - |
| | RAG |            54.5 |                   55.6 | 63.55 | 64.7 | 57.8 | 59.2 |
| | Program Execution |            48.3 |                   53.7 | 55.5 | 61.7 | 51.6 | 57.4 |