In [7]:
import os


def get_database_info(base_path: str):
    """Pair every sub-directory of ``base_path`` with its expected SQLite file.

    Each directory found directly under ``base_path`` is treated as one
    database named after that directory, and its file path is built as
    ``<base_path>/<name>/<name>.sqlite``. The path is only constructed; the
    file itself is not checked for existence. Non-directory entries are
    ignored.

    Args:
        base_path: Directory whose immediate sub-directories are databases.

    Returns:
        A list of ``(database_name, database_path)`` tuples, in the order
        ``os.listdir`` reports the entries.
    """
    return [
        (entry, os.path.join(base_path, entry, f"{entry}.sqlite"))
        for entry in os.listdir(base_path)
        if os.path.isdir(os.path.join(base_path, entry))
    ]


# Collect (database_name, sqlite_path) pairs from the dev directory first and
# then from the train directory; the bare name on the last line displays the
# combined list as the cell output.
database_info = [
    entry
    for split_root in (
        "/home/data2/luzhan/projects/bird_bench/dev/dev_databases",
        "/home/data2/luzhan/projects/bird_bench/train/train_databases",
    )
    for entry in get_database_info(split_root)
]
database_info

[('codebase_community',
  '/home/data2/luzhan/projects/bird_bench/dev/dev_databases/codebase_community/codebase_community.sqlite'),
 ('formula_1',
  '/home/data2/luzhan/projects/bird_bench/dev/dev_databases/formula_1/formula_1.sqlite'),
 ('card_games',
  '/home/data2/luzhan/projects/bird_bench/dev/dev_databases/card_games/card_games.sqlite'),
 ('superhero',
  '/home/data2/luzhan/projects/bird_bench/dev/dev_databases/superhero/superhero.sqlite'),
 ('student_club',
  '/home/data2/luzhan/projects/bird_bench/dev/dev_databases/student_club/student_club.sqlite'),
 ('toxicology',
  '/home/data2/luzhan/projects/bird_bench/dev/dev_databases/toxicology/toxicology.sqlite'),
 ('california_schools',
  '/home/data2/luzhan/projects/bird_bench/dev/dev_databases/california_schools/california_schools.sqlite'),
 ('thrombosis_prediction',
  '/home/data2/luzhan/projects/bird_bench/dev/dev_databases/thrombosis_prediction/thrombosis_prediction.sqlite'),
 ('financial',
  '/home/data2/luzhan/projects/bird_benc

In [1]:
from pymilvus import DataType, MilvusClient, utility, connections, db, model
import tool


# from pymilvus.model.hybrid import BGEM3EmbeddingFunction
# import pandas as pd

# bge_m3_ef = BGEM3EmbeddingFunction(
#     model_name="BAAI/bge-m3",  # Specify the model name
#     device="cpu",  # Specify the device to use, e.g., 'cpu' or 'cuda:0'
#     use_fp16=False,  # Specify whether to use fp16. Set to `False` if `device` is `cpu`.
# )
# Dense sentence-embedding function shared by every cell that builds or
# queries a Milvus collection in this notebook.
# NOTE(review): device="cuda:0" requires a GPU, and local_files_only=True
# presumably requires the model to be cached locally already - confirm before
# running on a fresh machine.
embedding_fn = model.dense.SentenceTransformerEmbeddingFunction(
    model_name="all-MiniLM-L6-v2",  # Specify the model name
    device="cuda:0",  # Specify the device to use, e.g., 'cpu' or 'cuda:0'
    trust_remote_code=True,
    local_files_only=True,
)

# Collection-name templates. Call .format(db_name=...) to obtain the name of
# the collection that stores, for one source database, its table names, its
# column names/descriptions, or its distinct column values respectively.
COLLECTION_NAME_FOR_TABLE_SCHEMA = "{db_name}_table_schema"
COLLECTION_NAME_FOR_COLUMN_SCHEMA = "{db_name}_column_schema"
COLLECTION_NAME_FOR_COLUMN_VALUE = "{db_name}_column_value"

  from .autonotebook import tqdm as notebook_tqdm


In [10]:
def build_table_schema(client: MilvusClient, engine, db_name: str):
    """Rebuild the table-name collection for one database and populate it.

    The collection named by ``COLLECTION_NAME_FOR_TABLE_SCHEMA`` is dropped and
    recreated, a COSINE / IVF_FLAT index is created on the vector field, and one
    entity (table name plus its embedding) is inserted per table reported by
    ``tool.get_tables_in_database(engine)``. The insert result is printed.

    Args:
        client: Milvus client used for all collection operations.
        engine: Database handle passed through to ``tool.get_tables_in_database``.
        db_name: Name substituted into the collection-name template.
    """
    collection_name = COLLECTION_NAME_FOR_TABLE_SCHEMA.format(db_name=db_name)
    # Drop first so that re-running the cell starts from an empty collection.
    client.drop_collection(collection_name=collection_name)

    schema = MilvusClient.create_schema(
        auto_id=True,
        enable_dynamic_field=True,
    )
    schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
    schema.add_field(field_name="table_name", datatype=DataType.VARCHAR, max_length=128)
    # Size the vector field from the embedding model instead of hard-coding a
    # dimension, so the schema always matches what encode_documents() returns.
    schema.add_field(
        field_name="table_name_vector",
        datatype=DataType.FLOAT_VECTOR,
        dim=embedding_fn.dim,
    )
    client.create_collection(
        collection_name=collection_name,
        schema=schema,
    )

    index_params = MilvusClient.prepare_index_params()
    index_params.add_index(
        field_name="table_name_vector",
        metric_type="COSINE",
        index_type="IVF_FLAT",
        index_name="table_name_vector_index",
        params={"nlist": 128},
    )
    client.create_index(collection_name=collection_name, index_params=index_params)

    table_names = list(tool.get_tables_in_database(engine))
    # encode_documents() takes a list of texts and returns one vector per text;
    # passing a bare string does not reliably yield a full vector at index 0.
    # Encoding all names in a single call also avoids one model call per table.
    table_name_vectors = (
        embedding_fn.encode_documents(table_names) if table_names else []
    )
    data = [
        {"table_name": table_name, "table_name_vector": table_name_vector}
        for table_name, table_name_vector in zip(table_names, table_name_vectors)
    ]
    res = client.insert(collection_name=collection_name, data=data)
    print(res)

In [11]:
import pandas as pd


def build_column_schema(client: MilvusClient, df: pd.DataFrame, db_name: str):
    """Rebuild the column-metadata collection for one database and populate it.

    The collection named by ``COLLECTION_NAME_FOR_COLUMN_SCHEMA`` is dropped and
    recreated, COSINE / IVF_FLAT indexes are created on both vector fields, and
    one entity is inserted per row of ``df``. The insert result is printed.

    Args:
        client: Milvus client used for all collection operations.
        df: One row per column to index; must provide the ``table_name``,
            ``original_column_name`` and ``description`` columns.
        db_name: Name substituted into the collection-name template.
    """
    collection_name = COLLECTION_NAME_FOR_COLUMN_SCHEMA.format(db_name=db_name)
    # Drop first so that re-running the cell starts from an empty collection.
    client.drop_collection(collection_name=collection_name)

    # Take the vector size from the embedding model instead of hard-coding a
    # dimension, so both vector fields match what encode_documents() returns.
    vector_dim = embedding_fn.dim

    schema = MilvusClient.create_schema(
        auto_id=True,
        enable_dynamic_field=True,
    )
    schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
    schema.add_field(field_name="table_name", datatype=DataType.VARCHAR, max_length=128)
    schema.add_field(
        field_name="column_name", datatype=DataType.VARCHAR, max_length=128
    )
    schema.add_field(
        field_name="column_name_vector", datatype=DataType.FLOAT_VECTOR, dim=vector_dim
    )
    schema.add_field(
        field_name="column_description", datatype=DataType.VARCHAR, max_length=4096
    )
    schema.add_field(
        field_name="column_description_vector",
        datatype=DataType.FLOAT_VECTOR,
        dim=vector_dim,
    )

    client.create_collection(
        collection_name=collection_name,
        schema=schema,
    )

    index_params = MilvusClient.prepare_index_params()
    for vector_field in ("column_name_vector", "column_description_vector"):
        index_params.add_index(
            field_name=vector_field,
            metric_type="COSINE",
            index_type="IVF_FLAT",
            index_name=f"{vector_field}_index",
            params={"nlist": 128},
        )
    client.create_index(collection_name=collection_name, index_params=index_params)

    table_names = df["table_name"].tolist()
    column_names = df["original_column_name"].tolist()
    column_descriptions = df["description"].tolist()

    # Encode each text column in one batched call rather than once per row.
    if column_names:
        column_name_vectors = embedding_fn.encode_documents(column_names)
        column_description_vectors = embedding_fn.encode_documents(column_descriptions)
    else:
        column_name_vectors, column_description_vectors = [], []

    data = [
        {
            "table_name": table_name,
            "column_name": column_name,
            "column_name_vector": column_name_vector,
            "column_description": column_description,
            "column_description_vector": column_description_vector,
        }
        for (
            table_name,
            column_name,
            column_name_vector,
            column_description,
            column_description_vector,
        ) in zip(
            table_names,
            column_names,
            column_name_vectors,
            column_descriptions,
            column_description_vectors,
        )
    ]
    res = client.insert(collection_name=collection_name, data=data)
    print(res)

In [12]:
from sqlalchemy import select, text


def build_column_value(client: MilvusClient, engine, db_name: str):
    """Rebuild the column-value collection for one database and populate it.

    The collection named by ``COLLECTION_NAME_FOR_COLUMN_VALUE`` is dropped and
    recreated with a COSINE / IVF_FLAT index on the vector field. For every
    table from ``tool.get_tables_in_database`` and every column from
    ``tool.get_string_columns_in_database``, the distinct values are read,
    converted with ``str()``, embedded, and inserted (one insert per column).

    Args:
        client: Milvus client used for all collection operations.
        engine: Object exposing ``connect()``; also passed to the ``tool`` helpers.
        db_name: Name substituted into the collection-name template.
    """
    # Imported here so this function is self-contained with respect to the
    # notebook's import cells.
    from sqlalchemy import column as sa_column, table as sa_table

    collection_name = COLLECTION_NAME_FOR_COLUMN_VALUE.format(db_name=db_name)
    # Drop first so that re-running the cell starts from an empty collection.
    client.drop_collection(collection_name=collection_name)

    schema = MilvusClient.create_schema(
        auto_id=True,
        enable_dynamic_field=True,
    )
    schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
    schema.add_field(field_name="table_name", datatype=DataType.VARCHAR, max_length=128)
    schema.add_field(
        field_name="column_name", datatype=DataType.VARCHAR, max_length=128
    )
    # NOTE(review): values longer than this max_length will presumably be
    # rejected on insert - confirm, and truncate or skip long values if needed.
    schema.add_field(
        field_name="column_value", datatype=DataType.VARCHAR, max_length=128
    )
    # Size the vector field from the embedding model instead of hard-coding a
    # dimension, so the schema always matches what encode_documents() returns.
    schema.add_field(
        field_name="column_value_vector",
        datatype=DataType.FLOAT_VECTOR,
        dim=embedding_fn.dim,
    )
    client.create_collection(
        collection_name=collection_name,
        schema=schema,
    )

    index_params = MilvusClient.prepare_index_params()
    index_params.add_index(
        field_name="column_value_vector",
        metric_type="COSINE",
        index_type="IVF_FLAT",
        index_name="column_value_vector_index",
        params={"nlist": 128},
    )
    client.create_index(collection_name=collection_name, index_params=index_params)

    for table_name in tool.get_tables_in_database(engine):
        string_columns_names = tool.get_string_columns_in_database(engine, table_name)
        print(string_columns_names)
        for column_name in string_columns_names:
            # column()/table() let SQLAlchemy quote the identifiers, so names
            # containing spaces, punctuation or reserved words still produce
            # valid SQL (raw text() would splice them in unquoted).
            # NOTE(review): assumes the tool helpers return raw, unquoted names.
            stmt = (
                select(sa_column(column_name))
                .distinct()
                .select_from(sa_table(table_name))
            )

            with engine.connect() as connection:
                rows = connection.execute(stmt).fetchall()

            # str() is applied to every value; a NULL presumably arrives as
            # None and is therefore stored as the string "None".
            unique_values = [str(row[0]) for row in rows]
            if not unique_values:
                # Nothing to embed or insert for an empty column.
                continue

            column_value_vectors = embedding_fn.encode_documents(unique_values)
            data = [
                {
                    "table_name": table_name,
                    "column_name": column_name,
                    "column_value": column_value,
                    "column_value_vector": column_value_vector,
                }
                for column_value, column_value_vector in zip(
                    unique_values, column_value_vectors
                )
            ]
            client.insert(collection_name=collection_name, data=data)

In [13]:
import pandas as pd

# Load the train and dev CSV files and stack them into one frame with a fresh
# 0..n-1 index.
df_train = pd.read_csv("./public_dataset/rag/bird_train.csv")
df_dev = pd.read_csv("./public_dataset/rag/bird_dev.csv")
df = pd.concat([df_train, df_dev], ignore_index=True)
# Replace missing cells with empty strings, one column at a time.
df = df.apply(lambda x: x.fillna(""))

# prepare milvus client and database
# Open a connection, create the "bird_bench" Milvus database if it is not
# listed yet, select it, and build the client handed to the build_* functions.
conn = connections.connect(host="127.0.0.1", port=19530)
db_name_milvus = "bird_bench"
if db_name_milvus not in db.list_database():
    database = db.create_database(db_name_milvus)
db.using_database(db_name_milvus)
client = MilvusClient(uri="http://localhost:19530", db_name=db_name_milvus)


# Rebuild the table-name and column-metadata collections for every
# (database name, sqlite path) pair. `database_info` is produced by an earlier
# cell, so that cell must have been run first.
for db_name, db_path in database_info:
    # Connection description passed to tool.get_engine.
    db_info = {
        "type": "sqlite",
        "url": db_path,
    }
    db_engine = tool.get_engine(db_info)

    build_table_schema(client, db_engine, db_name)
    # Only the rows whose database_name matches this database are indexed.
    build_column_schema(client, df[df["database_name"] == db_name], db_name)
    # Column-value indexing is currently disabled:
    # build_column_value(client, db_engine, db_name)

{'insert_count': 8, 'ids': [450680172451281395, 450680172451281396, 450680172451281397, 450680172451281398, 450680172451281399, 450680172451281400, 450680172451281401, 450680172451281402], 'cost': 0}
{'insert_count': 71, 'ids': [450680172451281404, 450680172451281405, 450680172451281406, 450680172451281407, 450680172451281408, 450680172451281409, 450680172451281410, 450680172451281411, 450680172451281412, 450680172451281413, 450680172451281414, 450680172451281415, 450680172451281416, 450680172451281417, 450680172451281418, 450680172451281419, 450680172451281420, 450680172451281421, 450680172451281422, 450680172451281423, 450680172451281424, 450680172451281425, 450680172451281426, 450680172451281427, 450680172451281428, 450680172451281429, 450680172451281430, 450680172451281431, 450680172451281432, 450680172451281433, 450680172451281434, 450680172451281435, 450680172451281436, 450680172451281437, 450680172451281438, 450680172451281439, 450680172451281440, 450680172451281441, 45068017245

In [6]:
from typing import List, Set, Tuple
from tool import ColumnSchema, TableSchema
from collections import defaultdict


# prepare milvus client and database
# These are the same connection steps as in the indexing cell, repeated here so
# the search cell can be run without it.
conn = connections.connect(host="127.0.0.1", port=19530)
db_name_milvus = "bird_bench"
if db_name_milvus not in db.list_database():
    database = db.create_database(db_name_milvus)
db.using_database(db_name_milvus)
client = MilvusClient(uri="http://localhost:19530", db_name=db_name_milvus)
# Embed the search phrases with encode_queries(), the query-side entry point of
# the embedding function, so that any query instruction configured on the model
# is applied. With the default (empty) instructions the vectors are the same as
# encode_documents() would return.
query_embeddings = embedding_fn.encode_queries(
    [
        "full name",
        "driver",
        "delivered",
        "most shipments",
        "least populated city",
        "Min(population)",
        "first_name",
        "last_name",
        "driver_id",
        "Max(Count(ship_id))",
    ]
)


def search_column_schema(
    client: MilvusClient,
    collection_name: str,
    query_emdbeddings: List[List[float]],
    schema,
):
    """Search column names by vector similarity and record the hits in ``schema``.

    The collection is loaded, the ``column_name_vector`` field is searched with
    COSINE similarity (top 3 hits per query vector), the raw result is printed,
    and each hit is added to ``schema[<table_name>]`` via ``add_column``.

    Args:
        client: Milvus client used for the load and search calls.
        collection_name: Collection to search.
        query_emdbeddings: One query vector per search phrase.
        schema: Mapping from table name to an object exposing ``add_column``;
            it is modified in place.

    Returns:
        None.
    """
    # Load the collection before searching; the load-state reply is not used.
    client.load_collection(collection_name=collection_name)
    client.get_load_state(collection_name=collection_name)

    hits_per_query = client.search(
        collection_name=collection_name,
        data=query_emdbeddings,
        anns_field="column_name_vector",
        search_params={"metric_type": "COSINE", "params": {}},
        output_fields=["table_name", "column_name", "column_description"],
        limit=3,
    )
    # parse_table_column_from_result(hits_per_query)
    print(hits_per_query)

    for hits in hits_per_query:
        for hit in hits:
            entity = hit["entity"]
            table = schema[entity["table_name"]]
            # Third ColumnSchema argument is left as None here; presumably it
            # is the sample value, which this search does not provide.
            table.add_column(
                ColumnSchema(
                    entity["column_name"],
                    entity["column_description"],
                    None,
                )
            )


def search_column_value(
    client: MilvusClient,
    collection_name: str,
    query_emdbeddings: List[List[float]],
    schema,
):
    """Search stored column values by vector similarity and record the hits.

    The collection is loaded, the ``column_value_vector`` field is searched
    with COSINE similarity (top 3 hits per query vector), the raw result is
    printed, and each hit is added to ``schema[<table_name>]`` via
    ``add_column``.

    Args:
        client: Milvus client used for the load and search calls.
        collection_name: Collection to search.
        query_emdbeddings: One query vector per search phrase.
        schema: Mapping from table name to an object exposing ``add_column``;
            it is modified in place.

    Returns:
        None.
    """
    # Load the collection before searching; the load-state reply is not used.
    client.load_collection(collection_name=collection_name)
    client.get_load_state(collection_name=collection_name)

    hits_per_query = client.search(
        collection_name=collection_name,
        data=query_emdbeddings,
        anns_field="column_value_vector",
        search_params={"metric_type": "COSINE", "params": {}},
        output_fields=["table_name", "column_name", "column_value"],
        limit=3,
    )
    print(hits_per_query)

    for hits in hits_per_query:
        for hit in hits:
            entity = hit["entity"]
            table = schema[entity["table_name"]]
            # Second ColumnSchema argument is left as None here; presumably it
            # is the description, which this search does not provide.
            table.add_column(
                ColumnSchema(
                    entity["column_name"],
                    None,
                    entity["column_value"],
                )
            )


# Accumulate search hits per table. Missing keys are created by calling
# TableSchema() with no arguments.
# NOTE(review): the cell output shows Table(table_name=None, ...) for every
# entry, so the table name is not being set on these auto-created objects -
# confirm whether that is intended.
schema = defaultdict(TableSchema)
# Search the "shipping" database's column-metadata collection ...
search_column_schema(
    client,
    COLLECTION_NAME_FOR_COLUMN_SCHEMA.format(db_name="shipping"),
    query_embeddings,
    schema,
)
# ... then its column-value collection, adding to the same mapping.
search_column_value(
    client,
    COLLECTION_NAME_FOR_COLUMN_VALUE.format(db_name="shipping"),
    query_embeddings,
    schema,
)
# Display the result as a plain dict.
dict(schema)
# Reference notes kept with this cell - presumably the SQL for the question
# being tested, followed by a set of (table, column) pairs for comparison:
# SELECT T1.first_name, T1.last_name FROM driver AS T1 INNER JOIN shipment AS T2 ON T1.driver_id = T2.driver_id INNER JOIN city AS T3 ON T3.city_id = T2.city_id GROUP BY T1.first_name, T1.last_name, T3.population HAVING T3.population = MAX(T3.population) ORDER BY COUNT(*) DESC LIMIT 1
# {('city', 'population '), ('customer', 'address'), ('customer', 'cust_name'), ('city', 'city_name'), ('driver', 'driver_id'), ('shipment', 'ship_date'), ('city', 'state'), ('driver', 'address'), ('city', 'area'), ('shipment', 'driver_id'), ('shipment', 'weight'), ('customer', 'phone'), ('shipment', 'cust_id'), ('customer', 'city'), ('driver', 'last_name'), ('driver', 'first_name'), ('shipment', 'truck_id'), ('customer', 'zip'), ('shipment', 'ship_id'), ('driver', 'city')}

data: ["[{'id': 450680172451291297, 'distance': 0.6023199558258057, 'entity': {'table_name': 'driver', 'column_name': 'last_name', 'column_description': 'Family name of the drivercommonsense evidence: full name = first_name + last_name'}}, {'id': 450680172451291296, 'distance': 0.5427758693695068, 'entity': {'table_name': 'driver', 'column_name': 'first_name', 'column_description': 'First given name of the driver'}}, {'id': 450680172451291272, 'distance': 0.45189914107322693, 'entity': {'table_name': 'customer', 'column_name': 'cust_name', 'column_description': 'Business name of the customer'}}]", "[{'id': 450680172451291295, 'distance': 0.673912763595581, 'entity': {'table_name': 'driver', 'column_name': 'driver_id', 'column_description': 'Unique identifier for the driver'}}, {'id': 450680172451291289, 'distance': 0.673912763595581, 'entity': {'table_name': 'shipment', 'column_name': 'driver_id', 'column_description': 'A reference to the driver table that indicates which driver transp

{'driver': Table(table_name=None, columns={'last_name': Column(column_name=last_name, column_description=Family name of the drivercommonsense evidence: full name = first_name + last_name, sample_values=set()), 'first_name': Column(column_name=first_name, column_description=First given name of the driver, sample_values=set()), 'driver_id': Column(column_name=driver_id, column_description=Unique identifier for the driver, sample_values=set()), 'address': Column(column_name=address, column_description=Street address of the driver's home, sample_values=set()), 'city': Column(column_name=city, column_description=City the driver lives in, sample_values=set())}),
 'customer': Table(table_name=None, columns={'cust_name': Column(column_name=cust_name, column_description=Business name of the customer, sample_values=set()), 'phone': Column(column_name=phone, column_description=Telephone number to reach the customer, sample_values=set()), 'zip': Column(column_name=zip, column_description=Postal co