Learned Selectivity Prediction:
------------------------------


I am trying to develop a simple algorithm, similar to the Postgres query optimizer, which, given a query, predicts the cheapest access path for each table. It assumes that the Postgres optimizer always selects the cheapest access path for each table involved in the query, where cost is proportional to the total number of disk accesses. To make that prediction, the algorithm uses table statistics, such as the Postgres stats histogram, to estimate the selectivity of each predicate, and then applies the independence assumption to compute cardinality when multiple predicates are present on a single table. For skewed data distributions, the Postgres stats histograms may not be accurate, so my selectivity estimates may also be highly inaccurate. That's why my initial idea was to learn a CDF online for each table attribute.
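
As a minimal sketch of the independence assumption mentioned above (the helper name and the numbers are illustrative and not taken from the cost model):

```python
from math import prod

def estimate_cardinality(total_rows, selectivities):
    """Estimated rows surviving all predicates on one table, assuming the
    predicates are statistically independent (hypothetical helper)."""
    return total_rows * prod(selectivities)

# Two predicates with selectivities 0.3 and 0.5 on a 1000-row table
rows = estimate_cardinality(1000, [0.3, 0.5])
```

When the predicates are correlated, the product can be far from the true cardinality, which is the motivation for predicting a combined selectivity later on.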

However, I now realize that this approach may not work very well. Instead, I could build a single regression model per table that directly predicts the selectivity of a given predicate, and train it online using EXPLAIN ANALYZE results from real-time query executions. Because those query plan results reveal the actual selectivity of the executed predicates, and because predicate features are easy to extract and have a simple form, this type of model could be better suited to my task.
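
A rough sketch of how an observed selectivity could be read off one scan node, assuming the plan was obtained with `EXPLAIN (ANALYZE, FORMAT JSON)` and parsed into a dict. The function name and the sample numbers are made up for illustration. Note that this ratio is the combined selectivity of every filter condition on the node, so it only isolates a single predicate when the scan has exactly one filter.

```python
def observed_selectivity(scan_node):
    """Observed filter selectivity of one scan node (hypothetical helper).

    Only the simple case is handled: a scan that reports the rows it
    discarded under "Rows Removed by Filter". Rows excluded by an index
    condition are not counted by Postgres and are therefore invisible here.
    """
    kept = scan_node["Actual Rows"]
    removed = scan_node.get("Rows Removed by Filter", 0)
    scanned = kept + removed
    if scanned == 0:
        return None  # nothing was scanned, selectivity is undefined
    return kept / scanned

# Made-up node in the shape of Postgres JSON plan output
node = {
    "Node Type": "Seq Scan",
    "Relation Name": "dwdate",
    "Actual Rows": 365,
    "Rows Removed by Filter": 2191,
    "Actual Loops": 1,
}
selectivity = observed_selectivity(node)
```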

To keep things simple in the beginning, I want to start with a model for each table that only predicts selectivity for a single predicate, i.e. a predicate on a single attribute. Once such a model is trained and working, I can think about extending it to the multi-attribute case: if a query contains predicates over multiple attributes of a table, the model should map a feature vector containing information about all of those predicates to a combined selectivity. This could be better than estimating individual predicate selectivities and combining them under the independence assumption.

The model can be pre-trained so that its predictions are consistent with the uniform data distribution assumption. It can then be refined online using actual query execution results (i.e. the actual selectivities observed in the query plan operators).

Feature extraction for a single predicate:
-----------------------------------------

Even though I am predicting the selectivity for a single attribute, the model should be able to predict for any attribute in the table. Given that, how should the predicate be encoded into a fixed-length feature vector that also tells the model which attribute the selectivity corresponds to? The feature vector needs to encode the identity of the predicate attribute along with details of the predicate itself, such as the predicate type (e.g. equality or range) and the predicate value.

**Idea for encoding the predicate information into a fixed-size feature vector:**

Example: For table with three attributes A, B, C

Make a feature vector containing equal-sized "slots", one per attribute. Given a predicate on a single attribute, fill the corresponding slot with information about that predicate, e.g. predicate type (such as equality or range) and predicate value, and fill the remaining slots with zeros.

`[ -- slot A -- | -- slot B -- | -- slot C -- ]`

This kind of encoding scheme can also be naturally extended to the case of multi-attribute predicates.
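
The slot layout above could be sketched as follows. The slot contents (a one-hot operator type followed by a lower and an upper value) are an assumption made for illustration. Values are taken to be numeric here; char attributes would need their own encoding, and in practice the values would probably be normalized first.

```python
ATTRIBUTES = ["A", "B", "C"]           # example table from the text
OPERATORS = ["eq", "range", "<", ">"]  # predicate type, one-hot encoded
SLOT_SIZE = len(OPERATORS) + 2         # one-hot + lower value + upper value

def encode_predicate(attribute, operator, value):
    """Encode a single-attribute predicate into a fixed-length vector."""
    vec = [0.0] * (SLOT_SIZE * len(ATTRIBUTES))
    base = ATTRIBUTES.index(attribute) * SLOT_SIZE
    # mark the predicate type inside this attribute's slot
    vec[base + OPERATORS.index(operator)] = 1.0
    # a range carries two values; other operators repeat their single value
    lo, hi = value if operator == "range" else (value, value)
    vec[base + len(OPERATORS)] = float(lo)
    vec[base + len(OPERATORS) + 1] = float(hi)
    return vec

features = encode_predicate("B", "range", (2, 4))
```

For a multi-attribute query, the same function could be applied once per predicate and the resulting vectors summed, since each predicate only writes into its own slot.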








### Pre-training Phase


For the pre-training phase, we will borrow the query selectivity estimator from our simple cost model, which uses Postgres internal table statistics and assumes uniform data distributions. We will train the selectivity prediction model to match the predictions of that estimator.

### Fine-tuning Phase

After pre-training, we can fine-tune the model via online updates using observed selectivities from actual query execution plans.
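
To make the two phases concrete, here is a toy sketch in which the same update rule is first fed targets from the simple cost model and then targets observed at execution time. The linear model, the class name, the learning rate, and the numbers are all stand-ins chosen for illustration; the actual model family is still undecided.

```python
class OnlineLinearModel:
    """Tiny linear regressor trained by stochastic gradient descent."""

    def __init__(self, n_features, lr=0.01):
        self.w = [0.0] * n_features
        self.b = 0.0
        self.lr = lr

    def _raw(self, x):
        return self.b + sum(wi * xi for wi, xi in zip(self.w, x))

    def predict(self, x):
        # a selectivity must lie in [0, 1]
        return min(1.0, max(0.0, self._raw(x)))

    def update(self, x, target):
        """One SGD step on the squared error between prediction and target."""
        err = self._raw(x) - target
        self.w = [wi - self.lr * err * xi for wi, xi in zip(self.w, x)]
        self.b -= self.lr * err

model = OnlineLinearModel(n_features=2, lr=0.1)
x = [1.0, 0.5]  # feature vector of one predicate (made up)

# Pre-training: the target comes from the simple cost model's estimator
for _ in range(200):
    model.update(x, 0.30)
before = model.predict(x)

# Fine-tuning: the target is the selectivity observed during execution
for _ in range(200):
    model.update(x, 0.10)
after = model.predict(x)
```

In this toy run the prediction for `x` moves from the estimator's value towards the observed one, which is the behavior the fine-tuning phase is meant to produce.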

In [7]:
# auto reload all modules
%load_ext autoreload
%autoreload 2

from simple_cost_model import *
from ssb_qgen_class import *
import time

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
# set up query generator
qgen = QGEN()

# Get the statistics for all tables in the SSB database
table_names = ["customer", "dwdate", "lineorder", "part", "supplier"]
stats = {}
estimated_rows = {}
for table_name in table_names:
    stats[table_name], estimated_rows[table_name] = get_table_stats(table_name)

#### Selectivity Estimator from Simple Cost Model

In [6]:
def estimate_selectivity_one_sided_range(self, attribute, boundary_value, operator, stats_dict, total_rows):
    data_type = self.data_type_dict[attribute]
    # Get the column statistics
    stats = stats_dict[attribute]
    histogram_bounds = stats['histogram_bounds']
    n_distinct = stats['n_distinct']
    most_common_vals = stats['most_common_vals']
    most_common_freqs = stats['most_common_freqs']

    # Convert most_common_values string to list of correct data type
    if most_common_vals:
        if data_type == 'numeric':
            most_common_vals = [float(x) for x in most_common_vals.strip('{}').split(',')]
        elif data_type == 'char':
            most_common_vals = [x for x in most_common_vals.strip('{}').split(',')]
        else:
            raise ValueError("Data type not supported, needs to be either numeric or char")

    # Convert negative n_distinct to an absolute count
    if n_distinct < 0:
        n_distinct = -n_distinct * total_rows

    selectivity = 0.0

    # Check for overlap with most common values
    if most_common_vals:
        for val, freq in zip(most_common_vals, most_common_freqs):
            if (operator == '>' and val > boundary_value) or (operator == '<' and val < boundary_value):
                selectivity += freq

    if histogram_bounds is not None:
        if data_type == 'numeric':
            histogram_bounds = [float(x) for x in histogram_bounds.strip('{}').split(',')]
        elif data_type == 'char':
            histogram_bounds = [x for x in histogram_bounds.strip('{}').split(',')]
        else:
            raise ValueError("Data type not supported, needs to be either numeric or char")

        total_bins = len(histogram_bounds) - 1

        # Iterate over bins, find overlapping bins
        for i in range(total_bins):
            bin_lower_bound = histogram_bounds[i]
            bin_upper_bound = histogram_bounds[i + 1]

            if data_type == 'numeric':
                # Check for range overlap
                if (operator == '>' and boundary_value < bin_upper_bound) or (operator == '<' and boundary_value > bin_lower_bound):
                    # Calculate the overlap fraction within this bin
                    if operator == '>':
                        overlap_min = max(boundary_value, bin_lower_bound)
                        overlap_fraction = (bin_upper_bound - overlap_min) / (bin_upper_bound - bin_lower_bound)
                    else:  # operator == '<'
                        overlap_max = min(boundary_value, bin_upper_bound)
                        overlap_fraction = (overlap_max - bin_lower_bound) / (bin_upper_bound - bin_lower_bound)

                    # Accumulate to the total selectivity
                    selectivity += overlap_fraction * (1.0 / total_bins)

            elif data_type == 'char':
                if (operator == '>' and boundary_value < bin_upper_bound) or (operator == '<' and boundary_value > bin_lower_bound):
                    # assume the whole bin overlaps
                    overlap_fraction = 1.0
                    # Accumulate to the total selectivity
                    selectivity += overlap_fraction * (1.0 / total_bins)

    if selectivity == 0.0:
        # If no overlap with most common values or histogram bins, assume uniform distribution and estimate selectivity
        selectivity = 1.0 / n_distinct

    return selectivity


def estimate_selectivity_range(self, attribute, value_range, stats_dict, total_rows):
    data_type = self.data_type_dict[attribute]
    # get the column statistics
    stats = stats_dict[attribute]
    # get the histogram bounds
    histogram_bounds = stats['histogram_bounds']
    n_distinct = stats['n_distinct']
    most_common_vals = stats['most_common_vals']
    most_common_freqs = stats['most_common_freqs']

    #print(f"Histogram bounds: {histogram_bounds}")
    #print(f"Most common values: {most_common_vals}")
    #print(f"Most common frequencies: {most_common_freqs}")

    # convert most_common_values string to list of correct data type
    if most_common_vals:
        if data_type == 'numeric':
            most_common_vals = [float(x) for x in most_common_vals.strip('{}').split(',')]
        elif data_type == 'char':
            most_common_vals = [x for x in most_common_vals.strip('{}').split(',')]    
        else:
            raise ValueError("Data type not supported, needs to be either numeric or char")

    # Convert negative n_distinct to an absolute count
    if n_distinct < 0:
        n_distinct = -n_distinct * total_rows

    min_value = value_range[0]
    max_value = value_range[1]
    selectivity = 0.0

    # check for overlap with most common values
    if most_common_vals:
        for val, freq in zip(most_common_vals, most_common_freqs):
            if min_value <= val <= max_value:
                selectivity += freq    

    if histogram_bounds is not None:
        if data_type == 'numeric':
            histogram_bounds = [float(x) for x in histogram_bounds.strip('{}').split(',')]  # convert to list of floats
        elif data_type == 'char':
            histogram_bounds = [x for x in histogram_bounds.strip('{}').split(',')]
        else:
            raise ValueError("Data type not supported, needs to be either numeric or char")

        total_bins = len(histogram_bounds) - 1

        # iterate over bins, find overlapping bins
        for i in range(total_bins):
            bin_lower_bound = histogram_bounds[i]
            bin_upper_bound = histogram_bounds[i+1]

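            # check for range overlap: the query range misses this bin only if it
            # ends before the bin starts or starts after the bin ends
            if max_value < bin_lower_bound or min_value > bin_upper_bound:
                # does not overlap
                continue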

            if data_type == 'numeric':
                # calculate the overlap fraction within this bin
                overlap_min = max(min_value, bin_lower_bound)
                overlap_max = min(max_value, bin_upper_bound)
                overlap_fraction = (overlap_max - overlap_min) / (bin_upper_bound - bin_lower_bound)

                #print(f"Overlap fraction for bin {i}: {overlap_fraction}")
                #print(f"Bin bounds: {bin_lower_bound}, {bin_upper_bound}")

                # accumulate to the total selectivity
                # Assume each bin represents an equal fraction of the total rows
                selectivity += overlap_fraction * (1.0 / total_bins)

            elif data_type == 'char':
                # assume the whole bin overlaps
                overlap_fraction = 1.0
                # accumulate to the total selectivity
                # Assume each bin represents an equal fraction of the total rows
                selectivity += overlap_fraction * (1.0 / total_bins)

    if selectivity == 0.0:
        # if no overlap with most common values or histogram bins, assume uniform distribution and estimate selectivity
        selectivity = 1.0 / n_distinct       

    return selectivity


def estimate_selectivity_eq(self, attribute, value, stats_dict):
    data_type = self.data_type_dict[attribute]
    # get the column statistics
    stats = stats_dict[attribute]
    # get the histogram bounds
    histogram_bounds = stats['histogram_bounds']
    n_distinct = stats['n_distinct']
    most_common_vals = stats['most_common_vals']
    most_common_freqs = stats['most_common_freqs']

    # convert most_common_values string to list of correct data type
    if most_common_vals:
        if data_type == 'numeric':
            most_common_vals = [float(x) for x in most_common_vals.strip('{}').split(',')]
        elif data_type == 'char':
            most_common_vals = [x for x in most_common_vals.strip('{}').split(',')]    
        else:
            raise ValueError("Data type not supported, needs to be either numeric or char")

    # first check if the value is in the most common values
    if most_common_vals and value in most_common_vals:
        selectivity = most_common_freqs[most_common_vals.index(value)] 
        return selectivity

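    # if not a common value, estimate using n_distinct
    # NOTE: a negative n_distinct in pg_stats is the distinct count expressed
    # as a fraction of the row count; total_rows is not passed to this
    # function, so the magnitude is used as-is and the result is capped at 1.0
    if n_distinct < 0:
        n_distinct = -n_distinct

    selectivity = min(1.0, 1.0 / n_distinct)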

    if histogram_bounds is not None:
        if data_type == 'numeric':
            histogram_bounds = [float(x) for x in histogram_bounds.strip('{}').split(',')]  # convert to list of floats
        elif data_type == 'char':
            histogram_bounds = [x for x in histogram_bounds.strip('{}').split(',')]
        else:
            raise ValueError("Data type not supported, needs to be either numeric or char")

        total_bins = len(histogram_bounds) - 1

        # iterate over bins, find bin that contains the value
        for i in range(total_bins):
            bin_lower_bound = histogram_bounds[i]
            bin_upper_bound = histogram_bounds[i+1]

            if data_type == 'numeric':
                # check for range overlap
                if bin_lower_bound <= value <= bin_upper_bound:
                    bin_width = bin_upper_bound - bin_lower_bound
                    if bin_width > 0:
                        # assume uniform distribution within this bin and calculate selectivity
                        uniform_selectivity = 1.0 / (bin_width*total_bins)
                        selectivity = min(selectivity, uniform_selectivity)
                    break    

            elif data_type == 'char':
                # check for range overlap
                if bin_lower_bound <= value <= bin_upper_bound:
                    # assume the whole bin overlaps
                    selectivity = 1.0 / total_bins
                    break        

    return selectivity


def estimate_selectivity_or(self, attribute, value, stats_dict):
    combined_selectivity = 0.0
    individual_selectivities = []

    # for each value in the IN list, estimate the selectivity separately
    for val in value:
        individual_selectivities.append(self.estimate_selectivity_eq(attribute, val, stats_dict))

    # combine the per-value selectivities as the probability that at least one
    # value matches, assuming independence: 1 - prod(1 - s_i)
    no_match_probability = 1.0
    for selectivity in individual_selectivities:
        no_match_probability *= (1.0 - selectivity)

    combined_selectivity = 1.0 - no_match_probability

    # make sure the combined selectivity is between 0 and 1
    combined_selectivity = max(0.0, min(combined_selectivity, 1.0))

    return combined_selectivity 


def estimate_selectivity(self, attribute, operator, value, stats_dict, total_rows):
    if operator == 'eq':
        return self.estimate_selectivity_eq(attribute, value, stats_dict)
    elif operator == 'range':
        return self.estimate_selectivity_range(attribute, value, stats_dict, total_rows)
    elif operator == '<' or operator == '>':
        return self.estimate_selectivity_one_sided_range(attribute, value, operator, stats_dict, total_rows)
    elif operator == 'or':
        return self.estimate_selectivity_or(attribute, value, stats_dict)    
    else:
        raise ValueError(f"Operator '{operator}' not supported, needs to be one of 'eq', 'range', '<', '>', or 'or'")
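
For reference, two ways of combining per-value selectivities for an IN list can be compared in isolation. Both helpers are hypothetical and are not part of the cost model. The first treats the matches as independent events; the second treats them as mutually exclusive, which holds when all values are tested against the same attribute.

```python
from math import prod

def combine_or_independent(selectivities):
    """P(at least one matches), treating the events as independent."""
    return 1.0 - prod(1.0 - s for s in selectivities)

def combine_or_disjoint(selectivities):
    """P(at least one matches) when the events are mutually exclusive."""
    return min(1.0, sum(selectivities))

independent = combine_or_independent([0.1, 0.2])  # 1 - 0.9 * 0.8
disjoint = combine_or_disjoint([0.1, 0.2])        # 0.1 + 0.2
```

The independent form is always the smaller of the two, and the gap grows with the individual selectivities.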
    

In [12]:
# generate example query
example_query = qgen.generate_query(1)
# extract the predicates from the query
predicate_dict = example_query.predicate_dict

print("Predicates:")
for table_name, predicates in predicate_dict.items():
    print(f"  Table: {table_name}")
    for predicate in predicates:
        print(f"    {predicate}")
        
    

Predicates:
  Table: lineorder
    {'column': 'lo_orderdate', 'operator': 'eq', 'value': 'd_datekey', 'join': True}
    {'column': 'lo_discount', 'operator': 'range', 'value': (2, 4), 'join': False}
    {'column': 'lo_quantity', 'operator': '<', 'value': 25, 'join': False}
  Table: dwdate
    {'column': 'd_year', 'operator': 'eq', 'value': 1998, 'join': False}
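
Since the selectivity model only concerns predicates on a single table, join predicates would presumably be dropped before feature extraction. A small sketch, assuming the dict layout printed above (the helper name is hypothetical):

```python
def local_predicates(predicate_dict):
    """Keep only non-join predicates, grouped by table."""
    return {
        table: [p for p in predicates if not p["join"]]
        for table, predicates in predicate_dict.items()
    }

# Same structure as the printed example
sample = {
    "lineorder": [
        {"column": "lo_orderdate", "operator": "eq", "value": "d_datekey", "join": True},
        {"column": "lo_discount", "operator": "range", "value": (2, 4), "join": False},
        {"column": "lo_quantity", "operator": "<", "value": 25, "join": False},
    ],
    "dwdate": [
        {"column": "d_year", "operator": "eq", "value": 1998, "join": False},
    ],
}
trainable = local_predicates(sample)
```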
