Estimated cardinality of each single table from traditional CardEst methods (PostgreSQL)

In [None]:
# collect suitable sql from cardbench
import glob
import numpy as np
import tqdm

from sparse_deferred.structs import graph_struct
import sparse_deferred.np as sdnp

GraphStruct = graph_struct.GraphStruct
InMemoryDB = graph_struct.InMemoryDB

def clean_query_string(query):
    """Normalize a CardBench SQL string so it fits on a single workload-file line.

    Line breaks are replaced by spaces rather than deleted, so tokens that were
    separated only by a line break (e.g. a column name and a following FROM)
    stay separated. Semicolons are removed, the ``bq-cost-models-exp.`` project
    prefix is stripped from table references, and surrounding whitespace is
    trimmed.
    """
    return (
        query.replace("\n", " ")
        .replace(";", "")
        .replace("bq-cost-models-exp.", "")
        .strip()
    )

def get_cardbench_dataset(join_type, dataset_name,
                          input_root="../CardBench/CardBench_zero_shot_cardinality_training/training_datasets",
                          output_root="/opt/hdd/datasets/user/cardbench"):
    '''
    Extract (query, true cardinality) pairs from a sharded CardBench dataset.

    The training datasets are stored sharded (split into multiple files named
    ``<dataset>_<join_type>.npz-*``), so all shards are located with glob and
    read in order of their numeric suffix. From each shard the graph-level
    features ``feat.n.g.query`` and ``feat.n.g.cardinality`` are collected.
    Queries containing ``IS NOT NULL`` are dropped. The rest are written, one per
    line, as ``<sql>;||<cardinality>||`` to
    ``{output_root}/{join_type}_workload/{dataset_name}_without_pg_est_card.sql``.

    :param join_type: sub-directory / suffix of the dataset, e.g. 'binary_join'.
    :param dataset_name: CardBench dataset name, e.g. 'accidents'.
    :param input_root: directory containing the per-join-type shard folders.
    :param output_root: directory under which ``{join_type}_workload`` is written.
    :return: the list of formatted workload lines (without trailing newline).
    :raises ValueError: if the number of queries and cardinalities differ.
    '''
    import os

    filename = f"{input_root}/{join_type}/{dataset_name}_{join_type}.npz"
    filenames = glob.glob(filename + '-*')
    filenames.sort(key=lambda f: int(f.split('-')[-1]))
    cardinalities = []
    queries = []

    for file in tqdm.tqdm(filenames):
        # Pass the path (not an open handle) so the NpzFile owns the file and
        # the context manager closes it.
        with np.load(file, allow_pickle=True) as np_data:
            for key, np_arr in np_data.items():
                k_parts = key.split('.')
                # Only graph-level features: 'feat.n.g.<name>'.
                if len(k_parts) < 4 or k_parts[:3] != ['feat', 'n', 'g']:
                    continue
                if k_parts[3] == 'cardinality':
                    cardinalities.extend(np_arr.tolist())
                elif k_parts[3] == 'query':
                    queries.extend(np_arr.tolist())

    if len(cardinalities) != len(queries):
        raise ValueError(
            f"{dataset_name}: {len(queries)} queries but {len(cardinalities)} cardinalities"
        )

    sqls = []
    for query, card in zip(queries, cardinalities):
        text = query.decode()
        if 'IS NOT NULL' in text:
            continue
        sqls.append(f'{clean_query_string(text)};||{card}||')
    print(len(sqls))

    out_dir = f'{output_root}/{join_type}_workload'
    os.makedirs(out_dir, exist_ok=True)
    with open(f'{out_dir}/{dataset_name}_without_pg_est_card.sql', "w", encoding="utf-8") as f:
        f.writelines(sql + "\n" for sql in sqls)

    return sqls

In [None]:
dataset_names = [
    'accidents', 'airline', 'cms_synthetic_patient_data_omop', 'consumer',
    'covid19_weathersource_com', 'crypto_bitcoin_cash', 'employee',
    'ethereum_blockchain', 'geo_openstreetmap', 'github_repos',
    'human_variant_annotation', 'idc_v10', 'movielens',
    'open_targets_genetics', 'samples', 'stackoverflow', 'tpch_10G',
    'usfs_fia', 'uspto_oce_claims', 'wikipedia',
]
# Dump the binary-join workload (query + true cardinality) of every dataset.
for dataset_name in dataset_names:
    get_cardbench_dataset(join_type='binary_join', dataset_name=dataset_name)

In [11]:
# collect querygraphs from cardbench
import numpy as np
import glob
import tqdm

from sparse_deferred.structs import graph_struct
import sparse_deferred.np as sdnp

GraphStruct = graph_struct.GraphStruct
InMemoryDB = graph_struct.InMemoryDB

def get_querygraphs_from_cardbench(join_type, dataset_name,
                                   data_root="/opt/hdd/datasets/user/cardbench"):
    '''
    Load every shard of a CardBench query-graph dataset into one InMemoryDB.

    Shards are files matching ``{data_root}/{join_type}/{dataset_name}_{join_type}.npz-*``.
    They are read in order of their numeric suffix. Every graph of every shard is
    added to a fresh ``graph_struct.InMemoryDB``, which is then finalized. The
    name of the first table of the first graph is printed as a sanity check.

    :param join_type: sub-directory / suffix of the dataset, e.g. 'binary_join'.
    :param dataset_name: CardBench dataset name, e.g. 'accidents'.
    :param data_root: directory containing the per-join-type shard folders.
    :return: the finalized InMemoryDB holding all query graphs.
    :raises FileNotFoundError: if no shard matches the pattern.
    '''
    filename = f"{data_root}/{join_type}/{dataset_name}_{join_type}.npz"
    filenames = glob.glob(filename + '-*')
    if not filenames:
        raise FileNotFoundError(f"no shards found matching {filename}-*")
    filenames.sort(key=lambda f: int(f.split('-')[-1]))

    db = graph_struct.InMemoryDB()
    for file in tqdm.tqdm(filenames):
        shard = graph_struct.InMemoryDB.from_file(file)
        for i in range(shard.size):
            db.add(shard.get_item(i))
    db.finalize()

    print("First table name:", db.get_item(0).nodes["tables"]["name"][0])
    return db
        

In [12]:
get_querygraphs_from_cardbench(join_type='binary_join', dataset_name='accidents')

100%|██████████| 29/29 [00:01<00:00, 16.85it/s]


First table name: b'bq-cost-models-exp.accidents.nesreca'


In [33]:
# add query_id
import numpy as np

def add_query_id(join_type, dataset_name, data_root="/opt/hdd/datasets/user/cardbench"):
    '''
    Collect the ``feat.n.g.database_query_id`` values of an unsharded CardBench ``.npz``.

    Reads ``{data_root}/{join_type}/{dataset_name}_{join_type}.npz``. It prints every
    graph-level feature key (``feat.n.g.<name>``), the query-id list, and the
    number of ids. Nothing is added or written yet; the ids are returned.

    :param join_type: sub-directory / suffix of the dataset, e.g. 'binary_join'.
    :param dataset_name: CardBench dataset name, e.g. 'baseball'.
    :param data_root: directory containing the per-join-type folders.
    :return: list of query ids, as stored in the file (bytes in the visible output).
    '''
    filename = f"{data_root}/{join_type}/{dataset_name}_{join_type}.npz"
    query_ids = []
    # Pass the path rather than an open file object: np.load only closes
    # handles it opened itself, so open(...) here would leak the descriptor.
    with np.load(filename, allow_pickle=True) as np_data:
        for k, np_arr in np_data.items():
            k_parts = k.split('.')
            if len(k_parts) < 4 or k_parts[:3] != ['feat', 'n', 'g']:
                continue
            print(k)
            if k_parts[3] == 'database_query_id':
                ids = np_arr.tolist()
                query_ids += ids
                print(ids)
    print(len(query_ids))
    return query_ids
    

In [34]:
add_query_id(join_type='binary_join', dataset_name='baseball')

feat.n.g.cardinality
feat.n.g.database_query_id
[b'15234', b'15235', b'15236', b'15237', b'15238', b'15239', b'15240', b'15241', b'15242', b'15243', b'15244', b'15245', b'15246', b'15247', b'15248', b'15249', b'15250', b'15251', b'15252', b'15253', b'15255', b'15256', b'15257', b'15258', b'15259', b'15260', b'15261', b'15262', b'15263', b'15264', b'15265', b'15266', b'15267', b'15268', b'15270', b'15271', b'15272', b'15273', b'15274', b'15275', b'15276', b'15277', b'15278', b'15279', b'15280', b'15281', b'15282', b'15283', b'15284', b'15285', b'15286', b'15287', b'15288', b'15289', b'15290', b'15291', b'15292', b'15293', b'15294', b'15295', b'15296', b'15297', b'15298', b'15299', b'15300', b'15301', b'15302', b'15303', b'15304', b'15305', b'15306', b'15307', b'15308', b'15309', b'15310', b'15311', b'15312', b'15313', b'15314', b'15315', b'15316', b'15317', b'15318', b'15319', b'15320', b'15321', b'15322', b'15323', b'15324', b'15325', b'15326', b'15327', b'15328', b'15329', b'15330', b

In [None]:
# collect the cardinality estimated by traditional methods
import sqlglot
from pilotscope.DBInteractor.PilotDataInteractor import PilotDataInteractor
from pilotscope.PilotConfig import PostgreSQLConfig

def get_pg_est_card(datasets_name=None):
    '''
    Annotate binary-join workloads with PostgreSQL subquery cardinality estimates.

    For each dataset, every line of
    ``binary_join_workload/{dataset}_without_pg_est_card.sql`` (format
    ``<sql>;||<true_card>||``) is executed through pilotscope with subquery
    cardinality pulling enabled. The output file ``binary_join_workload/{dataset}.sql``
    contains, for each query, the query text on its own line. That line is followed
    by one ``<subquery>;||<true_card>||<pg_est_card>`` line per subquery that does
    not reference exactly one table.

    NOTE(review): the whole query's true cardinality is attached to every kept
    subquery. That is only exact when the query itself is the sole multi-table
    subquery (presumably the binary-join case); confirm before reuse on larger joins.

    :param datasets_name: names of the datasets (PostgreSQL databases) to process.
        A single string is accepted as one dataset name; ``None`` processes nothing.
    '''
    if datasets_name is None:
        datasets_name = []
    elif isinstance(datasets_name, str):
        # Iterating a bare string would treat every character as a dataset name.
        datasets_name = [datasets_name]

    config = PostgreSQLConfig()
    config.db_host = "localhost"
    config.db_user = "postgres"
    config.db_user_pwd = "postgres"
    config.db_port = 54323

    for dataset_name in datasets_name:
        workload_in_path = f'/opt/hdd/datasets/user/cardbench/binary_join_workload/{dataset_name}_without_pg_est_card.sql'
        workload_out_path = f'/opt/hdd/datasets/user/cardbench/binary_join_workload/{dataset_name}.sql'
        print(f"dataset: {dataset_name}, workload_in_path: {workload_in_path}, workload_out_path: {workload_out_path}")
        config.db = dataset_name
        data_interactor = PilotDataInteractor(config)

        with open(workload_in_path, 'r', encoding='utf-8') as read_files:
            # Skip blank lines (e.g. a trailing empty line) instead of crashing on them.
            sqls = [s for s in (line.strip() for line in read_files) if s]

        subsqls = []
        for count, line in enumerate(sqls, start=1):
            split_infos = line.split("||")
            sql, true_card = split_infos[0], split_infos[1]
            # Ask pilotscope to record subquery cardinalities for the next execute().
            data_interactor.pull_subquery_card()
            result = data_interactor.execute(sql)
            subsqls.append(sql)
            if count % 100 == 0:
                print(f"{count} sqls has been processed")
            for subquery, est_card in result.subquery_2_card.items():
                try:
                    tables = list(sqlglot.parse_one(subquery).find_all(sqlglot.exp.Table))
                except Exception as e:  # sqlglot raises several parse/token error types
                    print(f"sql: {sql}")
                    print(f"subquery: {subquery}")
                    print(f"parse error: {e}")
                    # Skip only this subquery; keep the remaining ones of this query.
                    continue
                if len(tables) == 1:  # skip single table estimated cardinality
                    continue
                subsqls.append(f'{subquery};||{true_card}||{est_card}')

        with open(workload_out_path, "w", encoding='utf-8') as f:
            f.writelines(s + "\n" for s in subsqls)

        print(f"dataset: {dataset_name}")

In [None]:
# waiting for raw data
dataset_names = [
    'accidents', 'airline', 'cms_synthetic_patient_data_omop', 'consumer',
    'covid19_weathersource_com', 'crypto_bitcoin_cash', 'employee',
    'ethereum_blockchain', 'geo_openstreetmap', 'github_repos',
    'human_variant_annotation', 'idc_v10', 'movielens',
    'open_targets_genetics', 'samples', 'stackoverflow', 'tpch_10G',
    'usfs_fia', 'uspto_oce_claims', 'wikipedia',
]
for dataset_name in dataset_names:
    # get_pg_est_card iterates over a collection of dataset names; wrap the
    # single name so it is not iterated character by character.
    get_pg_est_card([dataset_name])

In [None]:
# generate the pk_fk_test.sql for baseball used in cardbench
# referred to utils/statistics/gen_fanout.py

import os
import re

import pandas as pd

from utils.statistics.tools import load_abbrev_coltype, load_table_datas, load_tbls_cols_types, replace_comments

def dtype2sqltype(datatype):
    """Map a pandas/numpy column dtype to a BigQuery-style SQL type name.

    - Any integer dtype (pandas nullable ``Int64``/``Int32``/... or numpy
      ``int64``/``int32``/...) maps to ``"INT64"``.
    - Any floating dtype maps to ``"FLOAT64"``.
    - ``pd.StringDtype`` maps to ``"STRING"``.

    BigQuery has a single integer and a single float type, so every width
    collapses onto INT64 / FLOAT64.

    Anything else (e.g. bool, datetime, object) prints an error message and
    returns ``None``, as before.
    """
    # NUMERIC/BIGNUMERIC and TIME/TIMESTAMP/DATE/DATETIME are not mapped yet.
    if isinstance(datatype, pd.StringDtype):
        return "STRING"
    if pd.api.types.is_integer_dtype(datatype):
        return "INT64"
    if pd.api.types.is_float_dtype(datatype):
        return "FLOAT64"
    print("ERROR in datatype conversion", datatype)
    return None
      
_CREATE_TABLE_RE = re.compile(r'create\s+table\s+(?:if\s+not\s+exists\s+)?([^\s(]+)', re.IGNORECASE)


def load_tbl_primary_key(folder_path, db):
    '''
    Parse the primary-key columns of every table from a PostgreSQL DDL file.

    Reads ``{folder_path}/datasets/{db}/postgres_create_{db}.sql`` line by line.
    Matching is case-insensitive:
    - A ``create table <name>`` line starts a new table with no key.
    - A later line containing ``primary key`` assigns the key of the most recent
      table. It uses the parenthesised column list after ``primary key``, or,
      for an inline ``<col> <type> ... primary key`` definition, the line's
      first token.

    Table and column names are stripped of whitespace and double quotes, so
    ``primary key ("a", b)`` yields ``['a', 'b']``.

    :return: dict mapping table name -> list of primary-key column names
        (empty list when no key was found).
    '''
    load_file = f"{folder_path}/datasets/{db}/postgres_create_{db}.sql"
    primary_keys = {}
    tbl = None
    print("load table info from " + load_file)
    with open(load_file, 'r') as lf:
        for line in lf:
            create = _CREATE_TABLE_RE.search(line)
            if create:
                tbl = create.group(1).replace('"', '')
                primary_keys[tbl] = []
                continue
            lowered = line.lower()
            pos = lowered.find("primary key")
            if pos < 0 or tbl is None:
                continue  # not a key line, or a key before any table
            tail = line[pos + len("primary key"):]
            cols = re.search(r'\((.*?)\)', tail)
            if cols:
                names = cols.group(1).split(",")
            else:
                # Inline column definition, e.g. `"id" integer primary key,`.
                first = line.split()[0]
                names = [] if first.lower() in ('primary', 'constraint') else [first]
            names = [n.strip().strip('"') for n in names if n.strip()]
            if names:
                primary_keys[tbl] = names
    return primary_keys

_WHERE_RE = re.compile(r'\bwhere\b', re.IGNORECASE)
_AND_RE = re.compile(r'\band\b', re.IGNORECASE)


def _parse_join_pairs(workload_file, abbrev_inv):
    '''
    Return the de-duplicated set of ``(alias.col, alias.col)`` equi-join predicates of a workload.

    Each line is ``<sql>||...``. Comments are removed with ``replace_comments``,
    and the text after the first ``WHERE`` is split on ``AND``. Every
    ``a.x = b.y`` predicate whose two aliases are keys of ``abbrev_inv`` is kept
    once, whichever side order it appears in.

    :raises ValueError: if a non-empty line does not start with ``select``.
    '''
    joins = set()
    with open(workload_file, 'r') as lf:
        for raw in lf:
            line = replace_comments(raw, '').split("||")[0].strip()
            if not line:
                continue  # blank or comment-only line
            if not line.lower().startswith('select'):
                raise ValueError('workload file must start with select or SELECT')
            # Word-boundary, case-insensitive split: 'where' inside an
            # identifier does not split, and mixed-case keywords are found.
            parts = _WHERE_RE.split(line, maxsplit=1)
            if len(parts) < 2:
                continue  # no WHERE clause -> no join predicates
            for pred in _AND_RE.split(parts[1].strip(';').strip()):
                pred = pred.strip('(').strip(')')
                sides = pred.split(' = ')
                if len(sides) < 2 or '.' not in sides[0] or '.' not in sides[1]:
                    continue  # not a column = column predicate
                left, right = (
                    s.strip().replace('(', '').replace(')', '').replace(';', '')
                    for s in pred.strip().split('=')[:2]
                )
                if left.split('.')[0] not in abbrev_inv or right.split('.')[0] not in abbrev_inv:
                    continue
                if (right, left) not in joins:
                    joins.add((left, right))
    return joins


def _format_pk_fk_insert(db, joins, abbrev_inv, tbls_cols_types, primary_keys):
    '''
    Build the ``INSERT INTO ... pk_fk`` statement, one VALUES row per join.

    The side whose column is in the table's primary-key list is written first;
    if the left column is not a primary key the two sides are swapped. Rows are
    emitted in sorted order so the output is deterministic across runs. Returns
    ``None`` when ``joins`` is empty, since an INSERT without VALUES rows is
    invalid.
    '''
    rows = []
    for left, right in sorted(joins):
        left_alias, left_column = left.split('.')[0], left.split('.')[1]
        right_alias, right_column = right.split('.')[0], right.split('.')[1]
        left_table, right_table = abbrev_inv[left_alias], abbrev_inv[right_alias]

        # Type is read from the original left column (both join sides are
        # assumed to share a type).
        column_type = tbls_cols_types[left_table][left_column]

        if left_column not in primary_keys.get(left_table, []):
            left_table, left_column, right_table, right_column = (
                right_table, right_column, left_table, left_column)

        rows.append(f"('cardbench_data', '{db}', '{left_table}', '{left_column}', "
                    f"'{right_table}', '{right_column}', '{dtype2sqltype(column_type)}')")
    if not rows:
        return None
    # NOTE(review): identifiers are single-quoted, which most SQL dialects read
    # as string literals; confirm the consumer of this file expects that.
    header = (
        f"INSERT INTO 'cardbench_metadata.{db}_metadata.pk_fk' \n"
        f"('project_name', 'dataset_name', 'primary_key_table_name', 'primary_key_column_name', 'foreign_key_table_name', 'foreign_key_column_name', 'column_type') \n"
        f"VALUES \n"
    )
    return header + ", \n".join(rows) + ';'


def gen_pk_fk_sql(db, usage, output_path='./test.sql'):
    '''
    Generate a CardBench-style pk_fk INSERT statement for ``db`` from its workload.

    Equi-join predicates are collected from
    ``./datas/workloads/{usage}/{db}/workloads.sql``. Aliases are resolved
    through the table abbreviations, and the primary-key side of each join is
    identified from ``postgres_create_{db}.sql``. Referred to
    utils/statistics/gen_fanout.py.

    :param db: database / dataset name; also used in the metadata table name.
    :param usage: workload usage folder, e.g. 'pretrain'.
    :param output_path: file the statement is written to.
    :return: the generated SQL, or ``None`` (nothing written) if no joins were found.
    '''
    folder_path = "./datas"

    # abbrev: table name -> alias (the column-type map is not needed here).
    abbrev, _ = load_abbrev_coltype(folder_path, db, usage)
    abbrev_inv = {v: k for k, v in abbrev.items()}

    # Each table's column types, used for the column_type field.
    tbls_cols_types, _ = load_tbls_cols_types(folder_path, db)

    primary_keys = load_tbl_primary_key(folder_path, db)

    workload_file = f"{folder_path}/workloads/{usage}/{db}/workloads.sql"
    print(f"read workload: {workload_file}")
    joins = _parse_join_pairs(workload_file, abbrev_inv)
    print('joins:', joins)

    print('------------------calculating fanout...------------------')
    sql = _format_pk_fk_insert(db, joins, abbrev_inv, tbls_cols_types, primary_keys)
    if sql is None:
        print(f"no join predicates found for {db}; nothing written")
        return None
    with open(output_path, 'w') as wf:
        wf.write(sql)
    return sql

In [4]:
gen_pk_fk_sql(db='baseball', usage='pretrain')

abbrev:  {'allstarfull': 'bb_asf', 'appearances': 'bb_apr', 'awardsmanagers': 'bb_am', 'awardsplayers': 'bb_ap', 'awardssharemanagers': 'bb_asm', 'awardsshareplayers': 'bb_asp', 'batting': 'bb_bat', 'battingpost': 'bb_btp', 'els_teamnames': 'bb_etn', 'fielding': 'bb_fld', 'fieldingof': 'bb_fof', 'fieldingpost': 'bb_fp', 'halloffame': 'bb_hof', 'managers': 'bb_mgr', 'managershalf': 'bb_mgh', 'pitching': 'bb_pit', 'pitchingpost': 'bb_ptp', 'players': 'bb_plr', 'salaries': 'bb_slr', 'schools': 'bb_sch', 'schoolsplayers': 'bb_scp', 'seriespost': 'bb_sp', 'teams': 'bb_tm', 'teamsfranchises': 'bb_tf', 'teamshalf': 'bb_th'}
bb_asf.ctn: ['playerid', 'teamid']
bb_asf.dsct: ['yearid', 'gamenum', 'gameid', 'teamid', 'lgid', 'gp', 'startingpos']
bb_asf.PK: []
--------------------------------------------------
bb_apr.ctn: ['yearid', 'teamid', 'playerid', 'g_all', 'g_batting', 'g_defense', 'g_p', 'g_1b', 'g_2b', 'g_3b', 'g_ss', 'g_lf', 'g_cf', 'g_rf', 'g_of', 'g_dh', 'g_ph']
bb_apr.dsct: ['lgid', 'g