In [8]:
import numpy as np
import os
import sys
import pickle
import json
import warnings
# Silence every warning for the whole kernel session (broad; consider narrowing to specific categories).
warnings.filterwarnings('ignore')
# Make the project packages imported below (cross_db_benchmark, data_driven_cardinalities, models)
# importable; presumably this path is a checkout of the zero-shot-cost-estimation repo.
sys.path.append("/home/ziniuw/zero-shot-cost-estimation")
from cross_db_benchmark.benchmark_tools.utils import load_json
from data_driven_cardinalities.deepdb.schemas.generate_schema import gen_schema


In [41]:
# Directory holding one pickle per test query: "<query_no>.pkl" -> MSCN sub-plan estimates.
# NOTE: pickle.load executes arbitrary code; only load files from a trusted location.
result_path = "/flash1/ziniuw/CEB/zero-shot-results/MSCN3118603545/queries/zero-shot-test-preds/"
all_MSCN_est = dict()
pkl_names = [name for name in os.listdir(result_path) if name.endswith(".pkl")]
for pkl_name in pkl_names:
    query_no = int(pkl_name.split(".pkl")[0])
    with open(os.path.join(result_path, pkl_name), "rb") as pkl_handle:
        all_MSCN_est[query_no] = pickle.load(pkl_handle)
print(len(all_MSCN_est))

1861


In [7]:
# Raw workload file; each entry appears to be [sql_text, <three numbers>] (meaning of the numbers unconfirmed).
query_sql_file = "/home/ziniuw/zero-shot-data/runs/raw/imdb_full/complex_queries_testing_2k.json"
with open(query_sql_file) as sql_handle:
    query_sql = json.load(sql_handle)

In [9]:
# Parsed plans for the same workload, loaded with namespace=True (attribute access on fields).
query_plan_file = '/home/ziniuw/zero-shot-data/runs/parsed_plans/imdb_full/complex_queries_testing_2k.json'
imdb_dataset_dir = "/home/ziniuw/zero-shot-data/datasets/imdb"
queries = load_json(query_plan_file, namespace=True)
schema = gen_schema("imdb_full", imdb_dataset_dir)

In [26]:
# Compare MSCN's sub-plan estimates for query 10 with its raw SQL entry.
# (MSCN_est only exists inside augment_cardinalities; the top-level dict is all_MSCN_est.)
print(all_MSCN_est[10])
print(query_sql[10])

{('mii',): 1422443.116839281, ('mii', 't'): 1500213.8026668222, ('t',): 2977703.3477305686}
['SELECT COUNT(*) FROM "movie_info_idx"  WHERE "movie_info_idx"."id" >= 239771;', 380088.0, 473972.0, 110.679]


In [53]:
import collections
import itertools
import json
import logging
import types
import os
from json import JSONDecodeError
from time import perf_counter
import copy

import numpy as np
from tqdm import tqdm

from cross_db_benchmark.benchmark_tools.parse_run import dumper
from cross_db_benchmark.benchmark_tools.utils import load_json
from models.training.checkpoint import save_csv

# Module-level logger; not referenced by the functions defined in this cell as shown.
logger = logging.getLogger(__name__)


def _augment_prod(plan):
    """Recursively set ``cc_est_children_card`` on every node of ``plan``.

    Leaves get 1; an inner node gets the product of its children's ``cc_est_card``
    values, so ``cc_est_card`` must already be set on every child.
    """
    if len(plan.children) == 0:
        plan.plan_parameters.cc_est_children_card = 1
        return
    child_card = 1
    for child in plan.children:
        child_card *= child.plan_parameters.cc_est_card
        _augment_prod(child)
    plan.plan_parameters.cc_est_children_card = child_card


def _write_augmented_plans(run, all_MSCN_est, target):
    """Dump run metadata plus the plans that have MSCN estimates to ``target`` as JSON."""
    augmented = types.SimpleNamespace()
    augmented.database_stats = run.database_stats
    augmented.run_kwargs = run.run_kwargs
    augmented.parsed_plans = [p for q_id, p in enumerate(run.parsed_plans) if q_id in all_MSCN_est]
    print(len(augmented.parsed_plans))
    target_dir = os.path.dirname(target)
    if target_dir:  # os.makedirs('') raises when target is a bare file name
        os.makedirs(target_dir, exist_ok=True)
    with open(target, 'w') as outfile:
        json.dump(augmented, outfile, default=dumper)


def augment_cardinalities(schema, all_MSCN_est, src, table_aliases, target, scale=1, write_target=False):
    """Annotate parsed plans in ``src`` with MSCN-based cardinality estimates.

    For every plan whose position ``q_id`` is a key of ``all_MSCN_est``, this function:
    - resets the plan's ``est_pg`` / ``est_mscn`` counters;
    - runs ``augment_bottom_up`` to set ``cc_est_card`` on each node;
    - sets ``cc_est_children_card`` (the product of the children's ``cc_est_card``).

    Plans are annotated in memory only.

    Args:
        schema: forwarded to ``augment_bottom_up``.
        all_MSCN_est: ``{q_id: {alias_tuple: estimate}}``; plans without an entry are skipped.
        src: path of the parsed-plan JSON file.
        table_aliases: table name -> alias used in the MSCN estimate keys.
        target: output JSON path, used only when ``write_target`` is True.
        scale: forwarded to ``augment_bottom_up``.
        write_target: if True, also dump the plans that have MSCN estimates to ``target``.

    Returns:
        A list indexed by plan position (``result[q_id]``). Each entry is the list of
        table sets collected for the MSCN-eligible operators of that plan, or ``None``
        if the plan had no MSCN estimates.

    Raises:
        ValueError: if ``src`` is not valid JSON.
    """
    try:
        run = load_json(src, namespace=True)
    except JSONDecodeError as e:
        raise ValueError(f"Error reading {src}") from e

    q_stats = []

    # find out if this an non_inclusive workload (< previously replaced by <=)
    non_inclusive = any(b in src for b in ['job-light', 'scale', 'synthetic'])
    if non_inclusive:
        print("Assuming NON-INCLUSIVE workload")

    all_query_tables = []
    for q_id, p in enumerate(tqdm(run.parsed_plans)):
        if q_id not in all_MSCN_est:
            # Keep a placeholder so all_query_tables[q_id] always refers to plan q_id.
            all_query_tables.append(None)
            continue
        p.plan_parameters.est_pg = 0
        p.plan_parameters.est_mscn = 0
        query_tables = []
        augment_bottom_up(schema, p, q_id, run.database_stats, all_MSCN_est[q_id], table_aliases, q_stats, p,
                          scale, non_inclusive=non_inclusive, all_tables=query_tables)
        all_query_tables.append(query_tables)
        _augment_prod(p)

    if write_target:
        _write_augmented_plans(run, all_MSCN_est, target)
    return all_query_tables


def report_stats(est_mscn, est_pg, q_stats):
    """Print q-error percentiles and the share of nodes estimated by MSCN.

    Args:
        est_mscn: number of plan nodes whose estimate came from MSCN.
        est_pg: number of plan nodes whose estimate came from PostgreSQL.
        q_stats: list of per-node dicts with ``'q_errors_pg'`` and ``'q_errors_mscn'``,
            and optionally ``'latencies'``.

    Prints nothing when ``q_stats`` is empty.
    """
    if len(q_stats) == 0:
        return

    def report_percentiles(key):
        vals = np.array([q_s[key] for q_s in q_stats])
        print(f"{key}: p50={np.median(vals):.2f} p95={np.percentile(vals, 95):.2f} "
              f"p99={np.percentile(vals, 99):.2f} pmax={np.max(vals):.2f}")

    for key in ('q_errors_pg', 'q_errors_mscn', 'latencies'):
        # Report a metric only when every entry has it. 'latencies' in particular is
        # absent from the q_stats entries built by augment_bottom_up, and indexing it
        # unconditionally raised KeyError.
        if all(key in q_s for q_s in q_stats):
            report_percentiles(key)

    total = est_mscn + est_pg
    if total > 0:
        print(f"{est_mscn / total * 100:.2f}% estimated using MSCN")
        

def match_sub_queries(tables, MSCN_est, table_aliases, q_id):
    """Look up the MSCN estimate for the sub-plan covering ``tables``.

    Each table name is translated through ``table_aliases``. The first key of
    ``MSCN_est`` (a tuple of aliases) whose set of aliases equals that set wins.

    Returns the matched estimate, or ``None`` after printing a diagnostic when no
    key matches. Raises ``KeyError`` if a table has no alias.
    """
    wanted_aliases = {table_aliases[name] for name in tables}
    hit = next((key for key in MSCN_est if set(key) == wanted_aliases), None)
    if hit is not None:
        return MSCN_est[hit]
    print(f"query {q_id}: {wanted_aliases} not found in {MSCN_est.keys()}")
    return None


# Operators whose cardinality estimate is replaced by the matching MSCN sub-plan estimate
# (any 'XN ...' operator is also treated this way).
_MSCN_ESTIMATED_OPS = frozenset({
    'Parallel Seq Scan', 'Hash Join', 'Nested Loop', 'Seq Scan', 'Materialize',
    'Hash', 'Parallel Hash', 'Merge Join', 'Gather', 'Gather Merge',
    'Hash Right Join', 'Hash Left Join', 'Nested Loop Left Join',
    'Merge Left Join', 'Merge Right Join', 'Broadcast', 'Distribute',
})
# Operators for which PG cardinality semantics differ: keep PG's estimate and exclude from stats.
_PG_ESTIMATED_OPS = frozenset({
    'Index Only Scan', 'Index Scan', 'Parallel Index Only Scan',
    'Bitmap Index Scan', 'Parallel Bitmap Heap Scan', 'Bitmap Heap Scan',
    'Sort', 'Parallel Index Scan', 'BitmapAnd',
})


def augment_bottom_up(schema, plan, q_id, database_statistics, MSCN_est, table_aliases,
                      q_stats, top_p, scale, all_tables, non_inclusive=False):
    """Post-order walk of ``plan`` that sets ``plan_parameters.cc_est_card`` on every node.

    For each node this function:
    - takes the PostgreSQL estimate when the node cannot use MSCN: no tables were
      found, an Aggregate is at or below it, both cardinalities are <= 1000, the
      operator is index/sort-like, or MSCN has no estimate for its table set;
    - otherwise takes the MSCN estimate for the node's table set, unless its q-error
      is more than 100x PG's.

    Side effects:
    - increments ``top_p.plan_parameters.est_pg`` or ``est_mscn`` once per node;
    - appends ``{'query_id', 'q_errors_pg', 'q_errors_mscn'}`` to ``q_stats`` for each
      node that uses an MSCN estimate;
    - appends a copy of the table set of every MSCN-eligible operator to ``all_tables``.

    ``schema``, ``scale`` and ``non_inclusive`` are only forwarded to the recursive calls.

    Returns:
        ``(aggregation_below, tables)``: whether an Aggregate occurs at or below this
        node, and the set of table names it covers.

    Raises:
        NotImplementedError: for an unknown operator, or table stats lacking both
            ``relname`` and ``table``.
    """
    params = plan.plan_parameters
    workers_planned = vars(params).get('workers_planned')
    if workers_planned is None:
        workers_planned = 0

    aggregation_below = 'Aggregate' in params.op_name

    # Table scanned directly by this node, if any.
    tables = set()
    t_idx = vars(params).get('table')
    if t_idx is not None:
        table_stats = database_statistics.table_stats[t_idx]
        if hasattr(table_stats, 'relname'):
            tables.add(table_stats.relname)
        elif hasattr(table_stats, 'table'):
            tables.add(table_stats.table)
        else:
            raise NotImplementedError

    # Annotate children first; their tables and aggregation flag propagate upwards.
    for c in plan.children:
        c_aggregation_below, c_tables = augment_bottom_up(schema, c, q_id, database_statistics,
                                                          MSCN_est, table_aliases, q_stats,
                                                          top_p, scale,
                                                          non_inclusive=non_inclusive,
                                                          all_tables=all_tables)
        aggregation_below |= c_aggregation_below
        tables.update(c_tables)

    act_card, pg_est_card = get_act_est_card(params)

    def use_pg_estimate():
        params.cc_est_card = pg_est_card
        top_p.plan_parameters.est_pg += 1

    op_name = params.op_name
    if len(tables) == 0:
        print("Could not parse query")
        use_pg_estimate()
    elif aggregation_below:
        # group by not directly supported
        use_pg_estimate()
    elif act_card is not None and pg_est_card <= 1000 and act_card <= 1000:
        # we do not care about really small cardinalities
        use_pg_estimate()
    elif op_name in _MSCN_ESTIMATED_OPS or op_name.startswith('XN '):
        mscn_card = match_sub_queries(tables, MSCN_est, table_aliases, q_id)
        all_tables.append(copy.deepcopy(tables))
        if mscn_card is None:
            # No MSCN estimate for this table set: fall back to PG's estimate exactly
            # like the other fallback paths. It must not go through the worker division
            # below, which every other path leaves off pg_est_card, and it must not be
            # counted or recorded as an MSCN estimate.
            use_pg_estimate()
        else:
            cardinality_predict = mscn_card
            if workers_planned > 0 and op_name.startswith('Parallel'):
                # Presumably converts a whole-query estimate to PG's per-process row count
                # for parallel nodes; TODO confirm against how act_card is recorded.
                cardinality_predict /= (workers_planned + 1)

            if act_card is not None:
                q_err_mscn = q_err(cardinality_predict, act_card)
                q_err_pg = q_err(pg_est_card, act_card)
            else:
                q_err_mscn = 1
                q_err_pg = 1

            # Safeguard: keep PG's estimate if MSCN's q-error is more than 100x worse.
            if q_err_mscn > 100 * q_err_pg:
                use_pg_estimate()
            else:
                params.cc_est_card = cardinality_predict
                top_p.plan_parameters.est_mscn += 1
                q_stats.append({
                    'query_id': q_id,
                    'q_errors_pg': q_err_pg,
                    'q_errors_mscn': q_err_mscn
                })
    elif op_name in _PG_ESTIMATED_OPS:
        # ignore this in the stats since pg semantics for cardinalities are different for this operator
        use_pg_estimate()
    else:
        raise NotImplementedError(op_name)

    return aggregation_below, tables


def get_act_est_card(params):
    """Return ``(actual_cardinality, postgres_estimate)`` for a plan node's parameters.

    Attribute layouts are checked in this order:
    - ``act_card`` / ``est_card``;
    - ``act_avg_rows`` / ``est_rows`` (selected when ``est_rows`` is present);
    - ``est_card`` only, in which case the actual cardinality is ``None``.

    Prints ``params`` and raises ``NotImplementedError`` if none of these match.
    """
    if hasattr(params, 'act_card'):
        return params.act_card, params.est_card
    if hasattr(params, 'est_rows'):
        return params.act_avg_rows, params.est_rows
    if hasattr(params, 'est_card'):
        # only the estimate is available; no actual cardinality
        return None, params.est_card
    print(params)
    raise NotImplementedError


def q_err(cardinality_predict, cardinality_true):
    """Q-error: ``max(pred / true, true / pred)``.

    Edge cases:
    - if the true cardinality is 0, the error is 1.0, whatever the prediction;
    - if only the prediction is 0, the error is ``cardinality_true``.
    """
    if cardinality_true == 0:
        return 1.
    if cardinality_predict == 0:
        return cardinality_true
    return max(cardinality_predict / cardinality_true, cardinality_true / cardinality_predict)



In [54]:
# IMDB table name -> short alias. match_sub_queries translates table sets through this
# mapping to match the alias-tuple keys of the MSCN estimate dicts.
# (The previous version assigned "movie_link" twice.)
table_aliases = {
    "title": "t",
    "cast_info": "ci",
    "movie_info": "mi",
    "movie_info_idx": "mii",
    "person_info": "pi",
    "name": "n",
    "aka_name": "an",
    "keyword": "k",
    "movie_keyword": "mk",
    "movie_companies": "mc",
    "movie_link": "ml",
    "aka_title": "at",
    "complete_cast": "cc",
    "kind_type": "kt",
    "role_type": "rt",
    "char_name": "chn",
    "info_type": "it",
    "company_type": "ct",
    "company_name": "cn",
    "link_type": "lt",
    "comp_cast_type": "cct",
}

In [55]:
# Annotate every parsed plan with MSCN cardinalities and keep the collected table sets.
query_plan_file = '/home/ziniuw/zero-shot-data/runs/parsed_plans/imdb_full/complex_queries_testing_2k.json'
target = '/home/ziniuw/zero-shot-data/runs/MSCN_augmented/imdb_full/complex_queries_testing_2k.json'
all_tables = augment_cardinalities(
    schema,
    all_MSCN_est,
    src=query_plan_file,
    table_aliases=table_aliases,
    target=target,
)

100%|████████████████████████████████████| 2029/2029 [00:00<00:00, 14553.17it/s]

query 3: {'it', 'mii'} not found in dict_keys([('mii',), ('mii', 't'), ('t',)])
query 7: {'an'} not found in dict_keys([('it',), ('it', 'mii'), ('it', 'mii', 't'), ('mii',), ('mii', 't'), ('t',)])
query 15: {'mii'} not found in dict_keys([('k',), ('k', 'mk'), ('mk',)])
query 16: {'an'} not found in dict_keys([('k',), ('k', 'mk'), ('mk',)])
query 18: {'an'} not found in dict_keys([('k',), ('k', 'mk'), ('mk',)])
query 24: {'an'} not found in dict_keys([('k',), ('k', 'mk'), ('mk',)])
query 28: {'mii'} not found in dict_keys([('an',), ('an', 'n'), ('n',)])
query 34: {'mc'} not found in dict_keys([('it',), ('it', 'mii'), ('it', 'mii', 't'), ('mii',), ('mii', 't'), ('t',)])
query 40: {'mc'} not found in dict_keys([('kt',), ('kt', 'mii', 't'), ('kt', 't'), ('mii',), ('mii', 't'), ('t',)])
query 42: {'mc'} not found in dict_keys([('kt',), ('kt', 'mii', 't'), ('kt', 't'), ('mii',), ('mii', 't'), ('t',)])
query 44: {'mc'} not found in dict_keys([('it',), ('it', 'mii'), ('mii',)])
query 49: {'mii




In [66]:
idx = 410
# MSCN's sub-plan estimates for this query (None if no prediction file existed for it).
# The previous version read MSCN_est, which only exists inside augment_cardinalities.
print(all_MSCN_est.get(idx))
print("====================================================")
# NOTE: pickle.load executes arbitrary code; only load trusted files.
with open(f"/flash1/ziniuw/CEB/queries/zero-shot-test/zero-shot-test-all/{idx}.pkl", "rb") as f:
    mscn_query = pickle.load(f)
print(mscn_query['sql'])
print("====================================================")
print(query_sql[idx])
print("====================================================")
# NOTE(review): only refers to query idx if all_tables has one entry per plan position.
print(all_tables[idx])

{('kt',): 8.14919714468732, ('kt', 'mc', 't'): 4144642.153147094, ('kt', 't'): 2432500.360294634, ('mc',): 2786319.252112494, ('mc', 't'): 2684213.3873678246, ('t',): 2977703.3477305686}
SELECT AVG("name"."id")  FROM "title" AS t,"kind_type" AS kt,"movie_companies" AS mc WHERE kt."id"=t."kind_id" AND t."id"=mc."movie_id" AND CT AVG("name"."id") FROM "name";
['SELECT AVG("name"."id") FROM "name";', 1389164.0, 1736515.0, 408.857]
[{'title'}, {'title', 'kind_type'}]


In [57]:

sample_query_entry = query_sql[10]  # raw workload entry for query 10
print(sample_query_entry)

[{'movie_info_idx'}, {'info_type', 'movie_info_idx'}]

In [35]:
# Bare last expression: the notebook shows the entry's repr. The entry appears to be
# [sql_text, <three numbers>]; the meaning of the numbers is not established here.
query_sql[10]

['SELECT COUNT(*) FROM "movie_info_idx"  WHERE "movie_info_idx"."id" >= 239771;',
 380088.0,
 473972.0,
 110.679]