In [None]:
import pickle
import sys
sys.path.append("../")
from Join_scheme.data_prepare import process_stats_data
from BayesCard.Models.Bayescard_BN import Bayescard_BN
import time
import pandas as pd
import numpy as np
from BayesCard.Evaluation.cardinality_estimation import parse_query_single_table

In [None]:
from Join_scheme.data_prepare import process_stats_data

# Template for the per-table CSV files (the "{}" placeholder is presumably
# filled with a table name by process_stats_data -- TODO confirm) and the
# folder handed to process_stats_data as its second argument.
data_path = "/home/ubuntu/End-to-End-CardEst-Benchmark/datasets/stats_simplified/{}.csv"
model_folder = "/home/ubuntu/data_CE/saved_models"

# process_stats_data returns nine values when return_bin_means=True; unpack
# them one per line so each result is easy to locate.
(
    data,
    null_values,
    key_attrs,
    table_buckets,
    equivalent_keys,
    schema,
    bin_size,
    all_bin_means,
    all_bin_width,
) = process_stats_data(data_path, model_folder, 200, "sub_optimal", return_bin_means=True)

In [None]:
# Overwrite the bin statistics of three key columns with a uniform layout:
# every bin gets a mean of 1 and an equal share of the column's row count.
for _table, _key, _n_bins in (
    ("tags", "tags.ExcerptPostId", 48),
    ("posts", "posts.Id", 48),
    ("users", "users.Id", 107),
):
    all_bin_means[_key] = np.ones(_n_bins)
    all_bin_width[_key] = np.ones(_n_bins) * len(data[_table][_key]) / _n_bins


In [None]:
def learn_histogram(data, key_attrs, all_bin_means, all_bin_width, all_bin_size, bin_size=50):
    """
    Build a normalised equal-width histogram for every non-key column.

    Args:
        data: mapping table name -> table whose iteration yields column names
            and whose columns expose ``.values`` (e.g. a pandas DataFrame).
        key_attrs: mapping table name -> container of that table's key columns.
        all_bin_means: mapping key column -> per-bin array (only its length is used).
        all_bin_width: mapping key column -> per-bin array (only its length is used).
        all_bin_size: mapping table name -> {key column -> expected number of bins}.
        bin_size: number of bins passed to ``np.histogram`` for non-key columns.

    Returns:
        (all_histogram, all_boundary): two ``{table: {column: array}}`` dicts
        holding, for each non-key column, the bin frequencies (summing to 1)
        and the bin edges. Key columns get no entry; they are only checked.

    Raises:
        AssertionError: if a key column's expected bin count disagrees with the
            lengths of its ``all_bin_means`` / ``all_bin_width`` arrays.
    """
    all_histogram = {}
    all_boundary = {}
    for table in data:
        table_frame = data[table]
        table_hist = {}
        table_edges = {}
        for attr in table_frame:
            if attr in key_attrs[table]:
                # Key columns are binned elsewhere; just sanity-check the sizes.
                assert all_bin_size[table][attr] == len(all_bin_means[attr]) == len(all_bin_width[attr])
                continue
            counts, edges = np.histogram(table_frame[attr].values, bins=bin_size)
            table_hist[attr] = counts / np.sum(counts)
            table_edges[attr] = edges
        all_histogram[table] = table_hist
        all_boundary[table] = table_edges
    return all_histogram, all_boundary

In [None]:
# The `bin_size` returned by process_stats_data is passed as `all_bin_size`
# (per-key bin counts); the literal 50 is the number of histogram bins used
# for non-key columns.
all_histogram, all_boundary = learn_histogram(
    data,
    key_attrs,
    all_bin_means,
    all_bin_width,
    all_bin_size=bin_size,
    bin_size=50,
)

In [None]:
import numpy as np
import copy

from Join_scheme.join_graph import process_condition, get_join_hyper_graph
from Join_scheme.data_prepare import identify_key_values
from BayesCard.Evaluation.cardinality_estimation import timestamp_transorform, construct_table_query

# Comparison operators accepted in filter predicates, mapped to the NumPy
# ufunc that evaluates them. "=" and "==" both mean equality.
OPS = dict([
    ('>', np.greater),
    ('<', np.less),
    ('>=', np.greater_equal),
    ('<=', np.less_equal),
    ('=', np.equal),
    ('==', np.equal),
])

class Bound_ensemble:
    """
    Histogram-based estimator of a cardinality bound for join queries.

    The instance stores per-column histograms (used for filter selectivity)
    and per-key binned statistics (used to combine the join keys of a query).

    Attributes:
        hist: ``{table: {column: array}}`` of normalised bin frequencies.
        boundary: ``{table: {column: array}}`` of bin edges matching ``hist``.
        all_bin_means: ``{key: array}`` of per-bin values multiplied together
            across the keys of one join group.
        all_key_size: ``{key: array}`` of per-bin sizes; the element-wise
            minimum across the keys of one join group is used.
        schema: schema object forwarded to ``identify_key_values``.
        all_keys, equivalent_keys: the two results of ``identify_key_values(schema)``.
    """
    def __init__(self, hist, boundary, all_bin_means, all_key_size, schema):
        self.hist = hist
        self.boundary = boundary
        self.schema = schema
        self.all_bin_means = all_bin_means
        self.all_key_size = all_key_size
        self.all_keys, self.equivalent_keys = identify_key_values(schema)

    def parse_query_simple(self, query):
        """
        Parse a flat join query (no aggregation, no nested sub-queries).

        Args:
            query: SQL text of the form ``... FROM t1 as a, t2 b WHERE c1 AND c2;``.
                The WHERE clause is optional.

        Returns:
            A tuple ``(tables_all, final_probs, join_cond, join_keys)``:
              * tables_all: ``{alias: table name}`` from the FROM clause.
              * final_probs: product of the estimated selectivities of all
                filter predicates, floored to 0.001 when it would be 0.
              * join_cond: the join conditions as returned by ``process_condition``.
              * join_keys: ``{table: set of join keys}``.

        Raises:
            AssertionError: if a predicate on a column whose name contains
                "Date" has a value without the ``::timestamp`` suffix.
        """
        query = query.replace(" where ", " WHERE ")
        query = query.replace(" from ", " FROM ")
        query = query.replace(" and ", " AND ")
        query = query.split(";")[0]
        query = query.strip()
        tables_all = {}
        join_cond = []
        table_probs = {}
        join_keys = {}

        where_split = query.split(" WHERE ")
        tables_str = where_split[0].split(" FROM ")[-1]
        for table_str in tables_str.split(","):
            table_str = table_str.strip()
            if " as " in table_str:
                tables_all[table_str.split(" as ")[-1]] = table_str.split(" as ")[0]
            else:
                tables_all[table_str.split(" ")[-1]] = table_str.split(" ")[0]

        # Without a WHERE clause there is nothing to process; previously the
        # whole query string was mistaken for a single condition.
        conditions = where_split[-1].split(" AND ") if len(where_split) > 1 else []
        for cond in conditions:
            table, cond, join, join_key = process_condition(cond, tables_all)
            if table not in table_probs:
                table_probs[table] = 1
            if join:
                join_cond.append(cond)
                for tab in join_key:
                    join_keys.setdefault(tab, set()).add(join_key[tab])
                continue

            attr, op, value = cond[0], cond[1], cond[2]
            if "Date" in attr:
                assert "::timestamp" in value
                value = timestamp_transorform(value.strip().split("::timestamp")[0])
            # Selectivity = total frequency of the bins whose LEFT edge
            # satisfies `edge <op> value`.
            edges = self.boundary[table][attr]
            freqs = self.hist[table][attr]
            curr_prob = 0
            for i in range(0, len(edges) - 1):
                if OPS[op](edges[i], value):
                    curr_prob += freqs[i]
            table_probs[table] *= curr_prob

        final_probs = 1
        for table in table_probs:
            final_probs *= table_probs[table]
        if final_probs == 0:
            # Keep a small non-zero selectivity so the estimate never collapses to 0.
            final_probs = 0.001
        return tables_all, final_probs, join_cond, join_keys

    def multiply_hist_oned(self, all_probs, all_means):
        """
        Combine the binned statistics of several keys into a single number.

        Args:
            all_probs: list of equal-length 1-D arrays (one per key).
            all_means: list of equal-length 1-D arrays (one per key).

        Returns:
            ``sum over bins of (product of all_means) * (minimum of all_probs)``.
        """
        all_probs = np.stack(all_probs, axis=0)
        all_means = np.stack(all_means, axis=0)
        multiplier = np.prod(all_means, axis=0)
        min_number = np.amin(all_probs, axis=0)
        multiplier = multiplier * min_number
        return np.sum(multiplier)

    def eliminate_one_key_group(self, key_group, relevant_keys, res):
        """
        Fold one group of equivalent join keys into the running estimate.

        Args:
            key_group: identifier of the group (not used by the computation).
            relevant_keys: keys of the group; each must be present in
                ``all_bin_means`` and ``all_key_size``.
            res: running estimate from previously eliminated groups. A falsy
                value (``None`` or 0) means "no previous estimate", in which
                case the bin means are used unscaled.

        Returns:
            The updated estimate (see ``multiply_hist_oned``).
        """
        all_means = []
        all_probs = []
        for key in relevant_keys:
            bin_means = self.all_bin_means[key]
            key_size = self.all_key_size[key]
            if res:
                # Rescale this key's bin means so that sum(means * size)
                # equals the running estimate.
                ratio = res / np.sum(bin_means * key_size)
                bin_means = bin_means * ratio
            all_means.append(bin_means)
            all_probs.append(key_size)
        return self.multiply_hist_oned(all_probs, all_means)

    def get_cardinality(self, query_str):
        """
        Estimate the cardinality of ``query_str``.

        The join part is obtained by eliminating each group of equivalent join
        keys in turn, floored to 1, and then multiplied by the combined filter
        selectivity. A query with no join-key group uses the floor of 1 as its
        join part (previously this case raised ``TypeError`` because ``None``
        was compared with an integer).
        """
        tables_all, table_probs, join_cond, join_keys = self.parse_query_simple(query_str)
        equivalent_group = get_join_hyper_graph(join_keys, self.equivalent_keys)
        res = None
        for key_group in equivalent_group:
            res = self.eliminate_one_key_group(key_group, equivalent_group[key_group], res)
        if res is None or res <= 1:
            res = 1
        return res * table_probs

In [None]:
# Note: the per-bin widths are what the estimator uses as its per-key sizes.
BE = Bound_ensemble(
    hist=all_histogram,
    boundary=all_boundary,
    all_bin_means=all_bin_means,
    all_key_size=all_bin_width,
    schema=schema,
)

In [None]:
# Workload file: one query per line (newline characters are kept on each entry).
query_file = "/home/ubuntu/End-to-End-CardEst-Benchmark/workloads/stats_CEB/sub_plan_queries/stats_CEB_sub_queries.sql"
with open(query_file, "r") as f:
    queries = list(f)

In [None]:
# Per-query results: predicted/true ratio, wall-clock latency and raw prediction.
qerror = []
latency = []
pred = []
for i, query_str in enumerate(queries):
    # Fields are separated by "||": SQL text first, true cardinality last.
    fields = query_str.split("||")
    # Drop the final character of the SQL field before parsing it.
    query = fields[0][:-1]
    print("========================")
    true_card = int(fields[-1])
    t = time.time()
    res = BE.get_cardinality(query)
    pred.append(res)
    latency.append(time.time() - t)
    qerror.append(res/true_card)
    print(f"estimating query {i}: predicted {res}, true_card {true_card}, qerror {res/true_card}, latency {time.time() - t}")

In [None]:
# Distribution of the predicted/true ratios collected above, followed by the
# summed estimation latency.
for i in (50, 90, 95, 99, 100):
    print(f"q-error {i}% percentile is {np.percentile(qerror, i)}")
print(f"total inference time: {np.sum(latency)}")

In [None]:
# Persist the raw predictions, one per line, in query order.
with open("stats_CEB_join_hist.txt", "w") as f:
    f.writelines(str(p) + "\n" for p in pred)