Learned Selectivity Prediction using Online Linear Regression:
-------------------------------------------------------------


So, essentially what I am trying to do is develop a simple algorithm (similar to the postgres query optimizer) which, given a query, predicts the cheapest access path for each table, assuming that the postgres query optimizer will always select the cheapest access path for each table involved in the query, where cost is proportional to total number of disk accesses. To make that prediction the algorithm uses table statistics such as a postgres stats histogram to estimate the selectivity of each precicate and then use independence assumption to compute cardinality (if multiple predicates are present on a single table). For skewed data distributions, I understand that postgres stats histograms may not be accurate and therefore my selectivity estimates may also be highly inaccurate. Thats why my initial idea was to learn a CDF function online for each table attribute. 

However, I realize now that this approach may not work very well, and instead maybe I should develop a single regression model on each table which can directly predict the selectivity of a given predicate. I could train this model online using explain analyze results from real time query executions. Beacuse those query plan results can be used to know the exact selectivity of each predicate, and features for predicates are also easy to extract and have a simple form, this type of model could potentially be better suited for my task. 

To keep things simple in the beginning, I want to start with a model for each table that only predicts selectivity for a single predicate, i.e. a predicate on a single attribute. Maybe once I can get such a model trained and working, I could think of ways to extend it to the multi-attribute case, i.e. if a query contains predicates over multiple attributes of a table, then the model shoudl be able to map a feature vector containing information about all of those predicates and predict a combnied selectivity (this could potentially be better that using individual predicate selectivities and then combining them under the independence assumption).

The model can be pre-trained so that its predictions are consistent with uniform data distribution assumption. Then, the model can be refined online using actual query execution results (i.e. the actual selectivities observed in the query plan operators).

Feature extraction for a single predicate:
-----------------------------------------

Even though I am predicting the selectivity for a single attribute, the model should be able to predict for any attribute in the table. So given that, how should the predicate be encoded into a fixed length feature vector which would allow the model to also know which particular attribute the selectivity corresponds to? i.e. the feature vector needs to somehow be able to encode the identity of the predicate attribute along with details of the predicate itself, such as predictate type (e.g. equality or range) and predicate value.

** Idea for **encoding** the predicate information into a fixed size feature vector:

Example: For table with three attributes A, B, C

Make a feature vector containing equal sized "slots" for each attribute. Then given a predicate on a single attribute, fill the corresponding slot with information about that predicate, e.g. predicate type (such as equality or range) and predicate value. And fill the remaining slots with zero.

`[ -- slot A -- | -- slot B -- | -- slot C -- ]`

This kind of encoding scheme can also be naturally extended to the case of multi-attribute predicates.

In addition to predicate type and predicate value, each slot should also contain other relavant information that might directly help the model to learn the underlying distribution of each attribute.




### Pre-training/Bootstrapping Phase


For the pre-training phase, we will borrow the query selectivity estimator from our simple cost model, which uses postgres internal table statistics and assumes uniform data distributions. We will train the selectivity prediction model to match the predictions of the simple cost model's selectivity estimator.

### Fine-tuning Phase

After pretraining, we can finetune the model via online updates using observed selectivities from actual query execution plans.

In [17]:
# auto reload all modules
%load_ext autoreload
%autoreload 2

from simple_cost_model import *
from ssb_qgen_class import *
import time
import pickle
import numpy as np
import hashlib


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [18]:
# set up query generator
qgen = QGEN()

# Get the statistics for all tables in the SSB database
table_names = ["customer", "dwdate", "lineorder", "part", "supplier"]
pg_stats = {}
estimated_rows = {}
for table_name in table_names:
    pg_stats[table_name], estimated_rows[table_name] = get_table_stats(table_name)

table_attributes = {}
for table_name in table_names:
    table_attributes[table_name] = list(pg_stats[table_name].keys())

ssb_tables, pk_columns = get_ssb_schema()
# create a dictionary and specify whether each attribute in each table is numeric or char
data_type_dict = {}
for table_name in ["customer", "dwdate", "lineorder", "part", "supplier"]:
    for column_name, column_type in ssb_tables[table_name]:
        if ("INT" in column_type) or ("DECIMAL" in column_type) or ("BIT" in column_type):
            data_type_dict[column_name] = "numeric"
        else:
            data_type_dict[column_name] = "char"    

#### Selectivity Estimator - Simple Cost Model

In [19]:
def estimate_selectivity_one_sided_range(attribute, boundary_value, operator, stats_dict, total_rows):
    data_type = data_type_dict[attribute]
    # Get the column statistics
    stats = stats_dict[attribute]
    histogram_bounds = stats['histogram_bounds']
    n_distinct = stats['n_distinct']
    most_common_vals = stats['most_common_vals']
    most_common_freqs = stats['most_common_freqs']

    # Convert most_common_values string to list of correct data type
    if most_common_vals:
        if data_type == 'numeric':
            most_common_vals = [float(x) for x in most_common_vals.strip('{}').split(',')]
        elif data_type == 'char':
            most_common_vals = [x for x in most_common_vals.strip('{}').split(',')]
        else:
            raise ValueError("Data type not supported, needs to be either numeric or char")

    # Convert negative n_distinct to an absolute count
    if n_distinct < 0:
        n_distinct = -n_distinct * total_rows

    selectivity = 0.0

    # Check for overlap with most common values
    if most_common_vals:
        for val, freq in zip(most_common_vals, most_common_freqs):
            if (operator == '>' and val > boundary_value) or (operator == '<' and val < boundary_value):
                selectivity += freq

    if histogram_bounds is not None:
        if data_type == 'numeric':
            histogram_bounds = [float(x) for x in histogram_bounds.strip('{}').split(',')]
        elif data_type == 'char':
            histogram_bounds = [x for x in histogram_bounds.strip('{}').split(',')]
        else:
            raise ValueError("Data type not supported, needs to be either numeric or char")

        total_bins = len(histogram_bounds) - 1

        # Iterate over bins, find overlapping bins
        for i in range(total_bins):
            bin_lower_bound = histogram_bounds[i]
            bin_upper_bound = histogram_bounds[i + 1]

            if data_type == 'numeric':
                # Check for range overlap
                if (operator == '>' and boundary_value < bin_upper_bound) or (operator == '<' and boundary_value > bin_lower_bound):
                    # Calculate the overlap fraction within this bin
                    if operator == '>':
                        overlap_min = max(boundary_value, bin_lower_bound)
                        overlap_fraction = (bin_upper_bound - overlap_min) / (bin_upper_bound - bin_lower_bound)
                    else:  # operator == '<'
                        overlap_max = min(boundary_value, bin_upper_bound)
                        overlap_fraction = (overlap_max - bin_lower_bound) / (bin_upper_bound - bin_lower_bound)

                    # Accumulate to the total selectivity
                    selectivity += overlap_fraction * (1.0 / total_bins)

            elif data_type == 'char':
                if (operator == '>' and boundary_value < bin_upper_bound) or (operator == '<' and boundary_value > bin_lower_bound):
                    # assume the whole bin overlaps
                    overlap_fraction = 1.0
                    # Accumulate to the total selectivity
                    selectivity += overlap_fraction * (1.0 / total_bins)

    if selectivity == 0.0:
        # If no overlap with most common values or histogram bins, assume uniform distribution and estimate selectivity
        selectivity = 1.0 / n_distinct

    return selectivity


def estimate_selectivity_range(attribute, value_range, stats_dict, total_rows):
    data_type = data_type_dict[attribute]
    # get the column statistics
    stats = stats_dict[attribute]
    # get the histogram bounds
    histogram_bounds = stats['histogram_bounds']
    n_distinct = stats['n_distinct']
    most_common_vals = stats['most_common_vals']
    most_common_freqs = stats['most_common_freqs']

    #print(f"Histogram bounds: {histogram_bounds}")
    #print(f"Most common values: {most_common_vals}")
    #print(f"Most common frequencies: {most_common_freqs}")

    # convert most_common_values string to list of correct data type
    if most_common_vals:
        if data_type == 'numeric':
            most_common_vals = [float(x) for x in most_common_vals.strip('{}').split(',')]
        elif data_type == 'char':
            most_common_vals = [x for x in most_common_vals.strip('{}').split(',')]    
        else:
            raise ValueError("Data type not supported, needs ot be either numeric or char")

    # Convert negative n_distinct to an absolute count
    if n_distinct < 0:
        n_distinct = -n_distinct * total_rows

    min_value = value_range[0]
    max_value = value_range[1]
    selectivity = 0.0

    # check for overlap with most common values
    if most_common_vals:
        for val, freq in zip(most_common_vals, most_common_freqs):
            if min_value <= val <= max_value:
                selectivity += freq    

    if histogram_bounds is not None:
        if data_type == 'numeric':
            histogram_bounds = [float(x) for x in histogram_bounds.strip('{}').split(',')] # convert to list of integers
        elif data_type == 'char':
            histogram_bounds = [x for x in histogram_bounds.strip('{}').split(',')]
        else:
            raise ValueError("Data type not supported, needs ot be either numeric or char")    

        total_bins = len(histogram_bounds) - 1

        # iterate over bins, find overlapping bins
        for i in range(total_bins):
            bin_lower_bound = histogram_bounds[i]
            bin_upper_bound = histogram_bounds[i+1]

            # check for range overlap
            if min_value < bin_lower_bound or max_value > bin_upper_bound:
                # does not overlap
                continue    

            if data_type == 'numeric':
                # calculate the overlap fraction within this bin
                overlap_min = max(min_value, bin_lower_bound)
                overlap_max = min(max_value, bin_upper_bound)
                overlap_fraction = (overlap_max - overlap_min) / (bin_upper_bound - bin_lower_bound)

                #print(f"Overlap fraction for bin {i}: {overlap_fraction}")
                #print(f"Bin bounds: {bin_lower_bound}, {bin_upper_bound}")

                # accumulate to the total selectivity
                # Assume each bin represents an equal fraction of the total rows
                selectivity += overlap_fraction * (1.0 / total_bins)

            elif data_type == 'char':
                # assume the whole bin overlaps
                overlap_fraction = 1.0
                # accumulate to the total selectivity
                # Assume each bin represents an equal fraction of the total rows
                selectivity += overlap_fraction * (1.0 / total_bins)

    if selectivity == 0.0:
        # if no overlap with most common values or histogram bins, assume uniform distribution and estimate selectivity
        selectivity = 1.0 / n_distinct       

    return selectivity


def estimate_selectivity_eq(attribute, value, stats_dict):
    data_type = data_type_dict[attribute]
    # get the column statistics
    stats = stats_dict[attribute]
    # get the histogram bounds
    histogram_bounds = stats['histogram_bounds']
    n_distinct = stats['n_distinct']
    most_common_vals = stats['most_common_vals']
    most_common_freqs = stats['most_common_freqs']

    # convert most_common_values string to list of correct data type
    if most_common_vals:
        if data_type == 'numeric':
            most_common_vals = [float(x) for x in most_common_vals.strip('{}').split(',')]
        elif data_type == 'char':
            most_common_vals = [x for x in most_common_vals.strip('{}').split(',')]    
        else:
            raise ValueError("Data type not supported, needs ot be either numeric or char")

    # first check if the value is in the most common values
    if most_common_vals and value in most_common_vals:
        selectivity = most_common_freqs[most_common_vals.index(value)] 
        return selectivity

    # if not a common value, estimate using n_distinct
    if n_distinct < 0:
        n_distinct = -n_distinct

    selectivity = 1.0 / n_distinct    

    if histogram_bounds is not None:
        if data_type == 'numeric':
            histogram_bounds = [float(x) for x in histogram_bounds.strip('{}').split(',')] # convert to list of integers
        elif data_type == 'char':
            histogram_bounds = [x for x in histogram_bounds.strip('{}').split(',')]
        else:
            raise ValueError("Data type not supported, needs ot be either numeric or char")    

        total_bins = len(histogram_bounds) - 1

        # iterate over bins, find bin that contains the value
        for i in range(total_bins):
            bin_lower_bound = histogram_bounds[i]
            bin_upper_bound = histogram_bounds[i+1]

            if data_type == 'numeric':
                # check for range overlap
                if bin_lower_bound <= value <= bin_upper_bound:
                    bin_width = bin_upper_bound - bin_lower_bound
                    if bin_width > 0:
                        # assume uniform distribution within this bin and calculate selectivity
                        uniform_selectivity = 1.0 / (bin_width*total_bins)
                        selectivity = min(selectivity, uniform_selectivity)
                    break    

            elif data_type == 'char':
                # check for range overlap
                if bin_lower_bound <= value <= bin_upper_bound:
                    # assume the whole bin overlaps
                    selectivity = 1.0 / total_bins
                    break        

    return selectivity


def estimate_selectivity_or(attribute, value, stats_dict):
    combined_selectivity = 0.0
    individual_selectivities = []

    # for each value in the IN list, estimate the selectivity separately
    for val in value:
        individual_selectivities.append(estimate_selectivity_eq(attribute, val, stats_dict))

    # compute combined selectivities using inclusion-exclusion principle and assuming independence
    for selectivity in individual_selectivities:
        combined_selectivity += selectivity 

    overlap_adjustment = 1.0
    for selectivity in individual_selectivities:
        overlap_adjustment *= (1.0 - selectivity)

    combined_selectivity -= overlap_adjustment   

    # make sure the combined selectivity is between 0 and 1
    combined_selectivity = max(0.0, min(combined_selectivity, 1.0))

    return combined_selectivity 


def estimate_selectivity(attribute, operator, value, stats_dict, total_rows):
    if operator == 'eq':
        return estimate_selectivity_eq(attribute, value, stats_dict)
    elif operator == 'range':
        return estimate_selectivity_range(attribute, value, stats_dict, total_rows)
    elif operator == '<' or operator == '>':
        return estimate_selectivity_one_sided_range(attribute, value, operator, stats_dict, total_rows)
    elif operator == 'or':
        return estimate_selectivity_or(attribute, value, stats_dict)    
    else:
        raise ValueError(f"Operator '{operator}' not supported, needs to be either 'eq', 'range', or 'or'")

In [20]:
# generate example query
example_query = qgen.generate_query(1)
# extract the predicates from the query
predicate_dict = example_query.predicate_dict

print(f"Predicates:")
for table_name, predicates in predicate_dict.items():
    print(f"  Table: {table_name}")
    for predicate in predicates:
        print(f"    {predicate}")
        

Predicates:
  Table: lineorder
    {'column': 'lo_orderdate', 'operator': 'eq', 'value': 'd_datekey', 'join': True}
    {'column': 'lo_discount', 'operator': 'range', 'value': (3, 5), 'join': False}
    {'column': 'lo_quantity', 'operator': '<', 'value': 25, 'join': False}
  Table: dwdate
    {'column': 'd_year', 'operator': 'eq', 'value': 1998, 'join': False}


In [21]:
# test out the selectivity estimation functions using the example query predicates
total_rows = estimated_rows["lineorder"]
for table_name, predicates in predicate_dict.items():
    print(f"Table: {table_name}\n")
    for predicate in predicates:
        if predicate['join'] == False:
            attribute = predicate['column']
            operator = predicate['operator']
            value = predicate['value']
            selectivity = estimate_selectivity(attribute, operator, value, pg_stats[table_name], total_rows)
            print(f"  Predicate: {predicate}")
            print(f"  Selectivity: {selectivity}\n")


Table: lineorder

  Predicate: {'column': 'lo_discount', 'operator': 'range', 'value': (3, 5), 'join': False}
  Selectivity: 0.273600006

  Predicate: {'column': 'lo_quantity', 'operator': '<', 'value': 25, 'join': False}
  Selectivity: 0.4762666669999999

Table: dwdate

  Predicate: {'column': 'd_year', 'operator': 'eq', 'value': 1998, 'join': False}
  Selectivity: 0.14241001



#### Selectivity Estimator - Exact

For exact selectivity estimation, we will use the actual uniform distributions of attributes in the tables.

For sanity check, we can compare the simple cost model's selectivity estimates with the exact selectivity.

In [22]:
# load ssb10_stats.pkl
with open('ssb10_stats.pkl', 'rb') as f:
    actual_stats = pickle.load(f)

In [23]:
def estimate_selectivity_eq_exact(attribute, value, stats_dict):
    data_type = data_type_dict[attribute]
    # get the column statistics
    stats = stats_dict[attribute]
    #print(f"Stats: {stats}")
    min = stats['min'] 
    max = stats['max']
    total_count = stats['total_count']
    distinct_count = stats['distinct_count']
    histogram = stats['histogram']

    # check if historgam is available
    if histogram is not None:
        # check if the value is in the histogram
        if value in histogram:
            selectivity = histogram[value] / total_count
        else:
            selectivity = 0.0

    else:
        # assume uniform distribution
        if min <= value <= max:
            selectivity = 1.0 / distinct_count
        else:
            selectivity = 0.0    

    return selectivity


def estimate_selectivity_or_exact(attribute, value, stats_dict):
    combined_selectivity = 0.0
    individual_selectivities = []

    # for each value in the IN list, estimate the selectivity separately
    for val in value:
        individual_selectivities.append(estimate_selectivity_eq_exact(attribute, val, stats_dict))

    # compute combined selectivities using inclusion-exclusion principle and assuming independence
    for selectivity in individual_selectivities:
        combined_selectivity += selectivity 

    overlap_adjustment = 1.0
    for selectivity in individual_selectivities:
        overlap_adjustment *= (1.0 - selectivity)

    combined_selectivity -= overlap_adjustment   

    # make sure the combined selectivity is between 0 and 1
    combined_selectivity = max(0.0, min(combined_selectivity, 1.0))

    return combined_selectivity 


def estimate_selectivity_range_exact(attribute, value_range, stats_dict, total_rows):
    data_type = data_type_dict[attribute]
    # get the column statistics
    stats = stats_dict[attribute]
    #print(f"Stats: {stats}")
    min = stats['min'] 
    max = stats['max']  
    total_count = stats['total_count']
    distinct_count = stats['distinct_count']
    histogram = stats['histogram']
    #print(f"Data type: {data_type}, Min: {min}, Max: {max}, Total count: {total_count}, Distinct count: {distinct_count}")

    selectivity = 0.0
    if min <= value_range[0] <= value_range[1] <= max:
        # use histogram if available
        if histogram is not None:
            # iterate over historgram values and counts
            for val, count in histogram.items():
                #print(f"Historgram, value: {val}, count: {count}")
                if value_range[0] <= val <= value_range[1]:
                    selectivity += count

            selectivity /= total_count

        # otherwise use min and max values and assume uniform distribution
        # Note: this won't work for char type columns
        else:
            selectivity = (value_range[1] - value_range[0]) / (max - min)     

    return selectivity           


def estimate_selectivity_one_sided_range_exact(attribute, boundary_value, operator, stats_dict, total_rows):
    data_type = data_type_dict[attribute]
    # get the column statistics
    stats = stats_dict[attribute]
    #print(f"Stats: {stats}")
    min = stats['min'] 
    max = stats['max']
    total_count = stats['total_count']
    distinct_count = stats['distinct_count']
    histogram = stats['histogram']

    selectivity = 0.0
    if operator == '<':
        # check if within bounds
        if min <= boundary_value <= max:
            # use histogram if available
            if histogram is not None:
                # iterate over historgram values and counts
                for val, count in histogram.items():
                    if val < boundary_value:
                        selectivity += count

                selectivity /= total_count        
                
            # otherwise use min and max values and assume uniform distribution
            # Note: this won't work for char type columns
            else:
                selectivity = (boundary_value - min) / (max - min)

    elif operator == '>':
        # check if within bounds
        if min <= boundary_value <= max:
            # use histogram if available
            if histogram is not None:
               # iterate over historgram values and counts
                for val, count in histogram.items():
                    if val > boundary_value:
                        selectivity += count

                selectivity /= total_count
               
            # otherwise use min and max values and assume uniform distribution
            # Note: this won't work for char type columns
            else:
                selectivity = (max - boundary_value) / (max - min)

    return selectivity


def estimate_selectivity_exact(attribute, operator, value, stats_dict, total_rows):
    if operator == 'eq':
        return estimate_selectivity_eq_exact(attribute, value, stats_dict)
    elif operator == 'range':
        return estimate_selectivity_range_exact(attribute, value, stats_dict, total_rows)
    elif operator == '<' or operator == '>':
        return estimate_selectivity_one_sided_range_exact(attribute, value, operator, stats_dict, total_rows)
    elif operator == 'or':
        return estimate_selectivity_or_exact(attribute, value, stats_dict)    
    else:
        raise ValueError(f"Operator '{operator}' not supported, needs to be either 'eq', 'range', or 'or'")

In [24]:
# test out the exact selectivity estimation functions using the example query predicates
total_rows = estimated_rows["lineorder"]
for table_name, predicates in predicate_dict.items():
    print(f"Table: {table_name}\n")
    for predicate in predicates:
        if predicate['join'] == False:
            attribute = predicate['column']
            operator = predicate['operator']
            value = predicate['value']
            selectivity = estimate_selectivity_exact(attribute, operator, value, actual_stats[table_name], total_rows)
            print(f"  Predicate: {predicate}")
            print(f"  Selectivity: {selectivity}\n")


Table: lineorder

  Predicate: {'column': 'lo_discount', 'operator': 'range', 'value': (3, 5), 'join': False}
  Selectivity: 0.27278847769922604

  Predicate: {'column': 'lo_quantity', 'operator': '<', 'value': 25, 'join': False}
  Selectivity: 0.47999450340373206

Table: dwdate

  Predicate: {'column': 'd_year', 'operator': 'eq', 'value': 1998, 'join': False}
  Selectivity: 0.14241001564945227



#### Define the Query Selectivity Model


Feature Engineering:

The feature vector contains equal sized slots for each attribute on the table. Each slot will contain the following information:

1) A binary indicator (0 - no, 1 - yes) for whether the predicate is on the corresponding slot
2) A binary indicator for whether the attribute is of numeric or char type (0 - char, 1 - numeric)
3) A binary indicator for predicate type (0 - equality, 1 - range) to indicate whether the predicate is an equality or range
4) Two scalar predicate range values (i.e. upper and lower bounds), for equality set both values to the equality predicate value


For now, we will only predict selectivity for equality and range predicates.

TODO: For char data types, currently we're using a hash function to convert it to an int. Need to figure out a better way to handle the feature encoding of char attributes


In [25]:
class SelectivityModel:

    def __init__(self, table_name, table_attributes, lambda_reg=0.1, epsilon=1e-8):
        self.table_name = table_name
        self.table_attributes = table_attributes
        # mapping from attribute name to id
        self.attr2id = {attr: i for i, attr in enumerate(table_attributes)}
        # mapping from id to attribute name
        self.id2attr = {i: attr for i, attr in enumerate(table_attributes)}
        
        # set feature dimensions
        self.features_per_attribute = 5
        self.feature_dims = self.features_per_attribute * len(table_attributes)   # 4 features per attribute
            
        self.V = lambda_reg * np.eye(self.feature_dims)
        self.b = np.zeros(self.feature_dims)
        self.theta = np.zeros(self.feature_dims)
        self.lambda_reg = lambda_reg
        self.epsilon = epsilon
        self.loss_history = []
        #self.scaler = StandardScaler()
        #self.normalize_features()


    #def normalize_features(self):
    #    all_features = np.array(list(self.feature_vectors.values()))
    #    self.scaler.fit(all_features)
    #    for key in self.feature_vectors:
    #        self.feature_vectors[key] = self.scaler.transform([self.feature_vectors[key]])[0]

    def update(self, predicate, y_actual, verbose=False):
        x = self.predicate_to_feature_vector(predicate)
        self.V += np.outer(x, x)
        self.b += y_actual * x
        # add small epsilon to diagonal of V for conditioning
        #self.theta = np.linalg.solve(self.V + self.epsilon*np.eye(self.feature_dims), self.b)
        self.theta = np.linalg.solve(self.V, self.b)
        loss, y_pred = self.compute_loss(predicate, y_actual)
        if verbose:
            print(f"Actual selectivity: {y_actual}, predicted selectivity: {y_pred:.3f}, loss incurred: {loss}")

    
    def predict(self, predicate):
        if predicate['operator'] == 'or':
            return self.predict_or(predicate)
        x = self.predicate_to_feature_vector(predicate)
        # need to constrain the predicted selectivity to be between 0 and 1
        y_pred = np.dot(self.theta, x)
        y_pred = min(max(0, y_pred), 1)
        return y_pred


    # handle or predicates separately
    def predict_or(self, predicate):
        # for or predicates, we need to predict the selectivity for each value in the list, then combine them
        
        # split the or predicate into individual eq predicates
        attribute = predicate['column']
        value_list = predicate['value']
        selectivities = []
        for value in value_list:
            eq_predicate = {'column': attribute, 'operator': 'eq', 'value': value}
            selectivity = self.predict(eq_predicate)
            selectivities.append(selectivity)

        # compute combined selectivities using inclusion-exclusion principle and assuming independence
        combined_selectivity = 0.0
        for selectivity in selectivities:
            combined_selectivity += selectivity

        overlap_adjustment = 1.0
        for selectivity in selectivities:
            overlap_adjustment *= (1.0 - selectivity)

        combined_selectivity -= overlap_adjustment

        # make sure the combined selectivity is between 0 and 1
        combined_selectivity = max(0.0, min(combined_selectivity, 1.0))

        return combined_selectivity            
        

    def compute_loss(self, predicate, y_actual):
        y_pred = self.predict(predicate)
        mse = (y_actual - y_pred)**2
        reg = self.lambda_reg * np.dot(self.theta, self.theta)
        loss = mse + reg
        self.loss_history.append(loss)
        return loss, y_pred    
        

    def predicate_to_feature_vector(self, predicate):
        # prepare feature vector
        attribute = predicate['column']
        data_type = data_type_dict[attribute]
        operator = predicate['operator']
        value = predicate['value']

        # it data type is char, hash the value to an integer
        if data_type_dict[attribute] == 'char':
            if operator == 'range':
                value = tuple([self.hash_string_to_bin(v) for v in value])    
            else:
                value = self.hash_string_to_bin(value)

        # set attribute type indicator
        attribute_type_indicator = 0 if data_type == 'char' else 1

        #print(f"Attribute: {attribute}, Data type: {data_type_dict[attribute]}, Operator: {operator}, Value: {value}")

        if operator == 'eq':
            predicate_type_indicator = 0
            lower = value
            upper = value
        elif operator == 'range':
            predicate_type_indicator = 1
            lower = value[0]
            upper = value[1]
        elif operator == '<':
            predicate_type_indicator = 1
            lower = 0
            upper = value
        elif operator == '>':
            predicate_type_indicator = 1
            lower = value
            upper = 0

        predicate_features = [1, attribute_type_indicator, predicate_type_indicator, lower, upper]
        slot_start = self.attr2id[attribute] * self.features_per_attribute
        x = np.zeros(self.feature_dims)
        x[slot_start:slot_start+self.features_per_attribute] = predicate_features

        return x


    def hash_string_to_bin(self, value, num_bins=1000000000):
        # Hash the string and convert to an integer
        hash_value = int(hashlib.md5(value.encode()).hexdigest(), 16) % num_bins  
        # Map to a float
        return hash_value / num_bins
    

#### Data Generator for Pre-Training of Query Selectivity Prediction Model

In [26]:
# a data generator for training the selectivity prediction model, generates predicates and estimated selectivities using simple_cost_model
def generate_predicate(table_name=None): 
    done = False
    while not done:
        if table_name is None:
            # pick a random table
            table_name = np.random.choice(table_names)
        # get the table attributes
        attributes = table_attributes[table_name]
        # pick a random attribute
        attribute = np.random.choice(attributes)
        data_type = data_type_dict[attribute]
        # pick a random operator, with higher probability for equality and range operators
        operator = np.random.choice(['eq', 'range', '<', '>', 'or'], p=[0.3, 0.3, 0.15, 0.15, 0.1])
        # pick a random value
        if operator == 'eq':
            # skip if histogram is not available
            if attribute not in actual_stats[table_name]:
                continue
            # use histogram if available
            if 'histogram' in actual_stats[table_name][attribute]:
                if actual_stats[table_name][attribute]['histogram'] is not None:
                    value = np.random.choice(list(actual_stats[table_name][attribute]['histogram'].keys()))
                    done = True
            else:
                if data_type == 'numeric':
                    # pick uniformly from the min and max values if data type is numeric
                    value = np.random.uniform(actual_stats[table_name][attribute]['min'], actual_stats[table_name][attribute]['max'])      
                    done = True
                else:
                    # for char type, if historgram is not available, then skip
                    continue

        elif operator == 'range':
            # use histogram if available
            # skip if histogram is not available
            if attribute not in actual_stats[table_name]:
                continue
            if 'histogram' in actual_stats[table_name][attribute]:
                if actual_stats[table_name][attribute]['histogram'] is not None:
                    # pick a random range from the histogram
                    value = list(np.sort(np.random.choice(list(actual_stats[table_name][attribute]['histogram'].keys()), 2)))
                    done = True
            else:
                if data_type == 'numeric':
                    # pick a random range from the min and max values if data type is numeric
                    value = list(np.sort(np.random.uniform(actual_stats[table_name][attribute]['min'], actual_stats[table_name][attribute]['max'], 2)))
                    done = True
                else:
                    # for char type, if historgram is not available, then skip
                    continue

        elif operator == '<' or operator == '>':
            # skip if histogram is not available
            if attribute not in actual_stats[table_name]:
                continue
            # use histogram if available
            if 'histogram' in actual_stats[table_name][attribute]:
                if actual_stats[table_name][attribute]['histogram'] is not None:
                    value = np.random.choice(list(actual_stats[table_name][attribute]['histogram'].keys()))
                    done = True
            else:
                if data_type == 'numeric':
                    # pick uniformly from the min and max values if data type is numeric
                    value = np.random.uniform(actual_stats[table_name][attribute]['min'], actual_stats[table_name][attribute]['max'])      
                    done = True
                else:
                    # for char type, if historgram is not available, then skip
                    continue

    predicate = {'column': attribute, 'operator': operator, 'value': value, 'join': False}

    # get simple cost model estimated selectivity
    total_rows = estimated_rows[table_name]
    selectivity = estimate_selectivity(attribute, operator, value, pg_stats[table_name], total_rows)  

    return table_name, predicate, selectivity                


In [27]:
table_name, predicate, selectivity = generate_predicate()

print(f"Table: {table_name}")
print(f"Predicate: {predicate}")
print(f"Simplified cost model estimated selectivity: {selectivity:.3f}")

Table: lineorder
Predicate: {'column': 'lo_shipmode', 'operator': 'eq', 'value': 'AIR', 'join': False}
Simplified cost model estimated selectivity: 0.147


#### Pre-training

In [31]:
# for each table, initialize a selectivity model
selectivity_models = {}
for table_name in table_names:
    attributes = [column[0] for column in ssb_tables[table_name]]
    selectivity_models[table_name] = SelectivityModel(table_name, attributes, lambda_reg=0.01)

In [32]:
# perform training for each table
def train(num_steps=1000000, num_val_steps=100):
    for step in range(num_steps):
        for table_name in table_names:
            table_name, predicate, selectivity = generate_predicate(table_name)
            selectivity_models[table_name].update(predicate, selectivity)
    
        if step % 10000 == 0:
            print(f"\nStep: {step}")
            mean_absolute_error_step = {table_name: [] for table_name in table_names}
            for i in range(num_val_steps):
                mean_absolute_error = evaluate_model(selectivity_models, verbose=False)
                for table_name, mae in mean_absolute_error.items():
                    mean_absolute_error_step[table_name].append(mae)
            
            for table_name in mean_absolute_error_step:
                mean_absolute_error_step[table_name] = np.mean(mean_absolute_error_step[table_name])
                print(f"Table: {table_name} --> Mean absolute error: {mean_absolute_error_step[table_name]:.3f}")
                    
            #for table_name in table_names:
            #    print(f"Table: {table_name}, loss: {selectivity_models[table_name].loss_history[-1]:.3f}")



In [33]:
train(num_steps=200000)


Step: 0
Table: customer --> Mean absolute error: 0.120
Table: dwdate --> Mean absolute error: 0.290
Table: lineorder --> Mean absolute error: 0.239
Table: part --> Mean absolute error: 0.016
Table: supplier --> Mean absolute error: 0.014

Step: 10000
Table: customer --> Mean absolute error: 0.015
Table: dwdate --> Mean absolute error: 0.118
Table: lineorder --> Mean absolute error: 0.132
Table: part --> Mean absolute error: 0.052
Table: supplier --> Mean absolute error: 0.013

Step: 20000
Table: customer --> Mean absolute error: 0.013
Table: dwdate --> Mean absolute error: 0.123
Table: lineorder --> Mean absolute error: 0.130
Table: part --> Mean absolute error: 0.046
Table: supplier --> Mean absolute error: 0.013

Step: 30000
Table: customer --> Mean absolute error: 0.013
Table: dwdate --> Mean absolute error: 0.123
Table: lineorder --> Mean absolute error: 0.131
Table: part --> Mean absolute error: 0.046
Table: supplier --> Mean absolute error: 0.013

Step: 40000
Table: customer -->

#### Model Evaluation on SSB Query Template Predicates

In [34]:
def evaluate_model(selectivity_models, verbose=False):

    # use the model to make some predictions on predicate selectivity for all query templates, keep track of cumlative absolute error for each table
    mean_absolute_error = {table_name: [] for table_name in table_names}
    for i in range(1, 16):
        # generate example query
        query = qgen.generate_query(i)
        # extract the predicates from the query
        predicate_dict = query.predicate_dict
        if verbose: print(f"Query Template: {i}\n")

        for table_name, predicates in predicate_dict.items():
            if verbose: print(f"  Table: {table_name}\n")
            for predicate in predicates:
                if predicate['join'] == False:
                    if verbose: print(f"    Predicate: {predicate}")
                    attribute = predicate['column']
                    operator = predicate['operator']
                    value = predicate['value']
                    y_pred = selectivity_models[table_name].predict(predicate)
                    if verbose: print(f"    Attribute data type: {data_type_dict[attribute]}")
                    if verbose: print(f"    Predicted selectivity: {y_pred:.3f}")
                    # simple cost model estimated selectivity
                    total_rows = estimated_rows[table_name]
                    selectivity = estimate_selectivity(attribute, operator, value, pg_stats[table_name], total_rows)
                    if verbose: print(f"    Simple cost model selectivity: {selectivity:.3f}\n")
                    mean_absolute_error[table_name].append(abs(y_pred - selectivity))

    for table_name in table_names:
        mean_absolute_error[table_name] = np.mean(mean_absolute_error[table_name])
        if verbose:
            print(f"Table: {table_name} --> Mean absolute error: {mean_absolute_error[table_name]:.3f}")

    return mean_absolute_error    


In [35]:
evaluate_model(selectivity_models, verbose=True)

Query Template: 1

  Table: lineorder

    Predicate: {'column': 'lo_discount', 'operator': 'range', 'value': (1, 3), 'join': False}
    Attribute data type: numeric
    Predicted selectivity: 0.465
    Simple cost model selectivity: 0.272

    Predicate: {'column': 'lo_quantity', 'operator': '>', 'value': 25, 'join': False}
    Attribute data type: numeric
    Predicted selectivity: 0.258
    Simple cost model selectivity: 0.503

  Table: dwdate

    Predicate: {'column': 'd_year', 'operator': 'eq', 'value': 1992, 'join': False}
    Attribute data type: numeric
    Predicted selectivity: 0.143
    Simple cost model selectivity: 0.143

Query Template: 2

  Table: lineorder

    Predicate: {'column': 'lo_discount', 'operator': 'range', 'value': (7, 9), 'join': False}
    Attribute data type: numeric
    Predicted selectivity: 0.412
    Simple cost model selectivity: 0.272

    Predicate: {'column': 'lo_quantity', 'operator': 'range', 'value': (8, 17), 'join': False}
    Attribute data t

{'customer': 0.011533925368874934,
 'dwdate': 0.1255201845977264,
 'lineorder': 0.1631757050049823,
 'part': 0.05326037923682262,
 'supplier': 0.010863420282814085}

In [36]:
for attribute, stats in actual_stats['lineorder'].items():
    print(attribute)
    print(f"  {stats}\n")

lo_orderkey
  {'min': 1, 'max': 60000000, 'total_count': 59986214, 'distinct_count': 15000000, 'histogram': None}

lo_linenumber
  {'min': 1, 'max': 7, 'total_count': 59986214, 'distinct_count': 7, 'histogram': {1: 15000000, 2: 12856079, 3: 10710991, 4: 8569709, 5: 6425628, 6: 4283621, 7: 2140186}}

lo_custkey
  {'min': 1, 'max': 299999, 'total_count': 59986214, 'distinct_count': 200000, 'histogram': None}

lo_partkey
  {'min': 1, 'max': 600000, 'total_count': 59986214, 'distinct_count': 600000, 'histogram': None}

lo_suppkey
  {'min': 1, 'max': 20000, 'total_count': 59986214, 'distinct_count': 20000, 'histogram': None}

lo_orderdate
  {'min': datetime.date(1992, 1, 1), 'max': datetime.date(1998, 8, 2)}

lo_orderpriority
  {'min': '1-URGENT', 'max': '5-LOW', 'total_count': 59986214, 'distinct_count': 5, 'histogram': {'1-URGENT': 12008321, '2-HIGH': 12002304, '3-MEDIUM': 11990511, '4-NOT SPECIFIED': 11999553, '5-LOW': 11985525}}

lo_shippriority
  {'min': '0', 'max': '0', 'total_count':