#### Simple Cost-Model for predicting query execution cost (assume disk IO dominates) and access paths selected 


Main steps in the cost model algorithm:

1) Identify predicates (quatily and range) and payloads
2) Enumerate all possible access paths (sequential scans, index scans, bitmap index scan + bitmap heap scan)
3) Estimate selectivity for each predicate (i.e. what fraction of data needs to be accessed from a table)
4) Estimate cardinality of each access path (using total number of rows and selectivity information)
5) Estimate disk IO cost for each access path (for index scans, we use cardinality estimate to figure out how many pages need to be fetchs)
6) Compare the estimated costs and choose best access paths for covering the query

-----------------------------------------------------------------------------------------------------------------------------------------------
Some useful notes on access paths:
* sequential scans --> better for low selectivity or if no suitable index available or if the table is really small
* index scan --> better for high selectivity and simpler predicates
* index only scan --> better for high selectivity and index is also a covering index and simpler predicates
* bitmap index scan + bitmap heap scan --> better for medium selectivity and complex predicates

In [56]:
import psycopg2
import numpy as np
from pg_utils import *
from ssb_qgen_class import *

#### Getting Table Statistics

In [54]:
def get_page_size():
    return 8192  # 8 KB

def get_table_stats(table_name):

    conn = create_connection()

    cur = conn.cursor()

    try:
        # Execute the query to get the estimated number of rows in the table
        cur.execute(f"""
                    SELECT reltuples::bigint AS estimated_rows
                    FROM pg_class
                    WHERE relname = '{table_name}';
                    """)
        estimated_rows = cur.fetchone()[0]
    except:
        print(f"Error: Could not get the estimated number of rows in the '{table_name}' table.")
        estimated_rows = None

    try:
        # Query to get column statistics
        cur.execute(f"SELECT * FROM pg_stats WHERE tablename = '{table_name}';")
        column_stats = cur.fetchall()

        # Define the column names based on the pg_stats view
        column_names = [
                        "schemaname", "tablename", "attname", "inherited", "null_frac",
                        "avg_width", "n_distinct", "most_common_vals", "most_common_freqs",
                        "histogram_bounds", "correlation", "most_common_elems",
                        "most_common_elem_freqs", "elem_count_histogram"
                    ]

        # Organize the results into a dictionary
        stats_dict = {}
        for row in column_stats:
            column_name = row[2]  # 'attname' is the third column in the result
            stats_dict[column_name] = {column_names[i]: row[i] for i in range(len(column_names))}
    except:
        print(f"Error: Could not get the statistics for the '{table_name}' table")
        stats_dict = None

    # Close the cursor and connection
    cur.close()
    close_connection(conn)

    return stats_dict, estimated_rows


# Get the statistics for all tables in the SSB database
table_names = ["customer", "dwdate", "lineorder", "part", "supplier"]
stats = {}
estimated_rows = {}
for table_name in table_names:
    stats[table_name], estimated_rows[table_name] = get_table_stats(table_name)


# Print the organized statistics dictionary
#for key, value in stats_dict.items():
#    print(f"{key}")
#    for k, v in value.items():
#        print(f"    {k}: {v}")

#print(f"\nEstimated number of rows in the 'customer' table: {estimated_rows}")
print(estimated_rows)

{'customer': 300000, 'dwdate': 2556, 'lineorder': 59986216, 'part': 800000, 'supplier': 20000}


In [61]:
ssb_tables, pk_columns = get_ssb_schema()

# create a dictionary and specify whether each attribute in each table is numeric or char
data_type_dict = {}
for table_name in ["customer", "dwdate", "lineorder", "part", "supplier"]:
    for column_name, column_type in ssb_tables[table_name]:
        if ("INT" in column_type) or ("DECIMAL" in column_type) or ("BIT" in column_type):
            data_type_dict[column_name] = "numeric"
        else:
            data_type_dict[column_name] = "char"
    
data_type_dict    

{'c_custkey': 'numeric',
 'c_name': 'char',
 'c_address': 'char',
 'c_city': 'char',
 'c_nation': 'char',
 'c_region': 'char',
 'c_phone': 'char',
 'c_mktsegment': 'char',
 'd_datekey': 'char',
 'd_date': 'char',
 'd_dayofweek': 'char',
 'd_month': 'char',
 'd_year': 'numeric',
 'd_yearmonthnum': 'numeric',
 'd_yearmonth': 'char',
 'd_daynuminweek': 'numeric',
 'd_daynuminmonth': 'numeric',
 'd_daynuminyear': 'numeric',
 'd_monthnuminyear': 'numeric',
 'd_weeknuminyear': 'numeric',
 'd_sellingseason': 'char',
 'd_lastdayinweekfl': 'numeric',
 'd_lastdayinmonthfl': 'numeric',
 'd_holidayfl': 'numeric',
 'd_weekdayfl': 'numeric',
 'lo_orderkey': 'numeric',
 'lo_linenumber': 'numeric',
 'lo_custkey': 'numeric',
 'lo_partkey': 'numeric',
 'lo_suppkey': 'numeric',
 'lo_orderdate': 'char',
 'lo_orderpriority': 'char',
 'lo_shippriority': 'char',
 'lo_quantity': 'numeric',
 'lo_extendedprice': 'numeric',
 'lo_ordtotalprice': 'numeric',
 'lo_discount': 'numeric',
 'lo_revenue': 'numeric',
 'lo

In [55]:
stats['customer']

{'c_custkey': {'schemaname': 'public',
  'tablename': 'customer',
  'attname': 'c_custkey',
  'inherited': False,
  'null_frac': 0.0,
  'avg_width': 4,
  'n_distinct': -1.0,
  'most_common_vals': None,
  'most_common_freqs': None,
  'histogram_bounds': '{6,2899,6114,8927,12227,15096,18002,21064,24084,27083,30233,33121,36597,39582,42470,45223,48127,50938,54109,56972,59898,62789,65654,68491,71562,74603,77584,80229,83386,86443,89528,92546,95625,98636,101754,104837,107890,111162,114410,117435,120672,123823,126668,129469,132808,135752,138847,141812,144740,147797,150948,153926,157031,159873,162949,165980,168812,172023,174986,177768,180770,183770,186610,189492,192193,195225,198284,201308,204255,207206,210257,213558,216737,219703,222943,225823,228723,231728,234660,237726,240430,243481,246317,249467,252487,255400,258396,261199,264278,267219,270075,273093,275941,279014,282138,285426,288188,291181,293711,296956,299991}',
  'correlation': 1.0,
  'most_common_elems': None,
  'most_common_elem_freqs

#### Estimating Selectivity for a value range on a particular column, i.e. what fraction of the data (i.e. tuples) fall in the given range, using the Table Statistics

In [62]:
def estimate_selectivity_range(attribute, value_range, stats_dict, total_rows):
    data_type = data_type_dict[attribute]
    # get the column statistics
    stats = stats_dict[attribute]
    # get the histogram bounds
    histogram_bounds = stats['histogram_bounds']
    n_distinct = stats['n_distinct']
    most_common_vals = stats['most_common_vals']
    most_common_freqs = stats['most_common_freqs']

    # convert most_common_values string to list of correct data type
    if most_common_vals:
        if data_type == 'numeric':
            most_common_vals = [float(x) for x in most_common_vals.strip('{}').split(',')]
        elif data_type == 'char':
            most_common_vals = [x for x in most_common_vals.strip('{}').split(',')]    
        else:
            raise ValueError("Data type not supported, needs ot be either numeric or char")

    # Convert negative n_distinct to an absolute count
    if n_distinct < 0:
        n_distinct = -n_distinct * total_rows

    min_value = value_range[0]
    max_value = value_range[1]
    selectivity = 0.0

    # check for overlap with most common values
    if most_common_vals:
        for val, freq in zip(most_common_vals, most_common_freqs):
            if min_value <= val <= max_value:
                selectivity += freq    

    if histogram_bounds is not None:
        if data_type == 'numeric':
            histogram_bounds = [float(x) for x in histogram_bounds.strip('{}').split(',')] # convert to list of integers
        elif data_type == 'char':
            histogram_bounds = [x for x in histogram_bounds.strip('{}').split(',')]
        else:
            raise ValueError("Data type not supported, needs ot be either numeric or char")    

        total_bins = len(histogram_bounds) - 1

        # iterate over bins, find overlapping bins
        for i in range(total_bins):
            bin_lower_bound = histogram_bounds[i]
            bin_upper_bound = histogram_bounds[i+1]

            # check for range overlap
            if min_value < bin_lower_bound or max_value > bin_upper_bound:
                # does not overlap
                continue    

            # calculate the overlap fraction within this bin
            overlap_min = max(min_value, bin_lower_bound)
            overlap_max = min(max_value, bin_upper_bound)
            overlap_fraction = (overlap_max - overlap_min) / (bin_upper_bound - bin_lower_bound)

            #print(f"Overlap fraction for bin {i}: {overlap_fraction}")
            #print(f"Bin bounds: {bin_lower_bound}, {bin_upper_bound}")

            # accumulate to the total selectivity
            # Assume each bin represents an equal fraction of the total rows
            selectivity += overlap_fraction * (1.0 / total_bins)

    if selectivity == 0.0:
        # if no overlap with most common values or histogram bins, assume uniform distribution and estimate selectivity
        selectivity = 1.0 / n_distinct       

    return selectivity


def estimate_selectivity_eq(attribute, value, stats_dict):
    data_type = data_type_dict[attribute]
    # get the column statistics
    stats = stats_dict[attribute]
    # get the histogram bounds
    histogram_bounds = stats['histogram_bounds']
    n_distinct = stats['n_distinct']
    most_common_vals = stats['most_common_vals']
    most_common_freqs = stats['most_common_freqs']

    # convert most_common_values string to list of correct data type
    if most_common_vals:
        if data_type == 'numeric':
            most_common_vals = [float(x) for x in most_common_vals.strip('{}').split(',')]
        elif data_type == 'char':
            most_common_vals = [x for x in most_common_vals.strip('{}').split(',')]    
        else:
            raise ValueError("Data type not supported, needs ot be either numeric or char")

    # first check if the value is in the most common values
    if most_common_vals and value in most_common_vals:
        selectivity = most_common_freqs[most_common_vals.index(value)] 
        return selectivity

    # if not a common value, estimate using n_distinct
    if n_distinct < 0:
        n_distinct = -n_distinct

    selectivity = 1.0 / n_distinct    

    if histogram_bounds is not None:
        if data_type == 'numeric':
            histogram_bounds = [float(x) for x in histogram_bounds.strip('{}').split(',')] # convert to list of integers
        elif data_type == 'char':
            histogram_bounds = [x for x in histogram_bounds.strip('{}').split(',')]
        else:
            raise ValueError("Data type not supported, needs ot be either numeric or char")    

        total_bins = len(histogram_bounds) - 1

        # iterate over bins, find bin that contains the value
        for i in range(total_bins):
            bin_lower_bound = histogram_bounds[i]
            bin_upper_bound = histogram_bounds[i+1]

            # check for range overlap
            if bin_lower_bound <= value <= bin_upper_bound:
                bin_width = bin_upper_bound - bin_lower_bound
                if bin_width > 0:
                    # assume uniform distribution within this bin and calculate selectivity
                    uniform_selectivity = 1.0 / (bin_width*total_bins)
                    selectivity = min(selectivity, uniform_selectivity)
                break    

    return selectivity


In [64]:
# test the selectivity estimation functions on a numeric column
attribute = 'd_year'
stats_dict = stats['dwdate']

# test range selectivity estimation
value_range = (1992, 1994)
selectivity = estimate_selectivity_range(attribute, value_range, stats_dict, estimated_rows)
print(f"Estimated selectivity for range {value_range}: {selectivity}")

# test equality selectivity estimation
value = 1992
selectivity = estimate_selectivity_eq(attribute, value, stats_dict)
print(f"Estimated selectivity for value {value}: {selectivity}")

# now, let's try a char column
attribute = 'd_dayofweek'

# test range selectivity estimation
value_range = ('Monday', 'Wednesday')
selectivity = estimate_selectivity_range(attribute, value_range, stats_dict, estimated_rows)
print(f"Estimated selectivity for range {value_range}: {selectivity}")

# test equality selectivity estimation
value = 'Monday'
selectivity = estimate_selectivity_eq(attribute, value, stats_dict)
print(f"Estimated selectivity for value {value}: {selectivity}")




Estimated selectivity for range (1992, 1994): 0.42879498
Estimated selectivity for value 1992: 0.14319248
Estimated selectivity for range ('Monday', 'Wednesday'): 0.8571987299999999
Estimated selectivity for value Monday: 0.14280125


### Simple Case: consider three possible access paths: Sequential Scan, Index Scan and Index Only Scan (will add in the rest later)

#### Step 1: Extract query predicates and payload, the payload is a list of column names (across possibly multiple tables) and the predicate is a list of dictionaries, each dict contains the column name, operator (either equality or range) and either a range tuple or single value

In [65]:
# simple example query on single table (will later be expanded to join queries)
example_query = """
                SELECT lo_linenumber, lo_quantity, lo_orderdate  
                FROM lineorder
                WHERE lo_linenumber >= {linenumber_low} AND lo_linenumber <= {linenumber_high}
                AND lo_quantity = {quantity};
                """

# extract tables and associated columns
tables = {}
tables['lineorder'] = ['lo_linenumber', 'lo_quantity', 'lo_orderdate']

# extract the payload
payload = {}
payload['lineorder'] = ['lo_linenumber', 'lo_quantity', 'lo_orderdate']

# extract the predicates
predicates = {}
predicates['lineorder'] =  [  {'column': 'lo_linenumber', 'operator': 'range', 'value': (1, 5)},
                              {'column': 'lo_quantity', 'operator': 'eq', 'value': 10}
                           ]

print(f"Payload: {payload}")
print(f"Predicates: {predicates}")

# create some index objects
index_1 = Index('lineorder', 'IX_lineorder_lo_linenumber_lo_quantity', index_columns=['lo_linenumber', 'lo_quantity'])
index_2 = Index('lineorder', 'IX_lineorder_lo_linenumber_lo_quantity_lo_o', index_columns=['lo_linenumber', 'lo_quantity'], include_columns=['lo_orderdate'])
index_3 = Index('lineorder', 'IX_lineorder_lo_orderdate', index_columns=['lo_orderdate'])
index_4 = Index('lineorder', 'IX_lineorder_lo_quantity', index_columns=['lo_quantity'])
indexes = {index.index_id: index for index in [index_1, index_2, index_3, index_4]}


Payload: {'lineorder': ['lo_linenumber', 'lo_quantity', 'lo_orderdate']}
Predicates: {'lineorder': [{'column': 'lo_linenumber', 'operator': 'range', 'value': (1, 5)}, {'column': 'lo_quantity', 'operator': 'eq', 'value': 10}]}


#### Step 2: Enumerate the possible access path for each table involved

In [66]:
access_paths = {}
for table_name in tables:
    table_predicate_cols = [pred['column'] for pred in predicates[table_name]]
    table_payload_cols = [col for col in payload[table_name] if col in tables[table_name]]   
    table_access_paths = [{'scan_type': 'Sequential Scan', table_name: table_name}]
    for index in indexes.values():
        if index.table_name == table_name:
            print("Checking index: ", index.index_id)
            # check if index scan is possible, at least one index column should be in the predicate
            if set(index.index_columns).intersection(table_predicate_cols):
                table_access_paths.append({'scan_type': 'Index Scan', 'index_id': index.index_id})
                print("Index scan possible")
            # check if index only scan is possible
            if set(index.index_columns).issuperset(table_predicate_cols) and set(list(index.index_columns) + list(index.include_columns)).issuperset(table_payload_cols):
                table_access_paths.append({'scan_type': 'Index Only Scan', 'index_id': index.index_id})
                print("Index only scan possible")
    access_paths[table_name] = table_access_paths

print(f"Access paths: ")
for table, paths in access_paths.items():
    print(f"Table: {table}")
    for path in paths:
        print(f"    {path}")

Checking index:  IX_lineorder_lo_linenumber_lo_quantity
Index scan possible
Checking index:  IX_lineorder_lo_linenumber_lo_quantity_lo_o
Index scan possible
Index only scan possible
Checking index:  IX_lineorder_lo_orderdate
Checking index:  IX_lineorder_lo_quantity
Index scan possible
Access paths: 
Table: lineorder
    {'scan_type': 'Sequential Scan', 'lineorder': 'lineorder'}
    {'scan_type': 'Index Scan', 'index_id': 'IX_lineorder_lo_linenumber_lo_quantity'}
    {'scan_type': 'Index Scan', 'index_id': 'IX_lineorder_lo_linenumber_lo_quantity_lo_o'}
    {'scan_type': 'Index Only Scan', 'index_id': 'IX_lineorder_lo_linenumber_lo_quantity_lo_o'}
    {'scan_type': 'Index Scan', 'index_id': 'IX_lineorder_lo_quantity'}


#### Estimate selectivity of the predicates

In [67]:
def estimate_selectivity(attribute, operator, value, stats_dict, total_rows):
    if operator == 'eq':
        return estimate_selectivity_eq(attribute, value, stats_dict)
    elif operator == 'range':
        return estimate_selectivity_range(attribute, value, stats_dict, total_rows)
    else:
        raise ValueError("Operator not supported, needs to be either eq or range")

In [68]:
for table_name, table_preds in predicates.items():
    table_stats_dict = stats[table_name]   
    table_estimated_rows = estimated_rows[table_name]
    for pred in table_preds:
        selectivity = estimate_selectivity(pred['column'], pred['operator'], pred['value'], table_stats_dict, table_estimated_rows)
        print(f"Estimated selectivity for predicate {pred}: {selectivity}")

Estimated selectivity for predicate {'column': 'lo_linenumber', 'operator': 'range', 'value': (1, 5)}: 0.894000005
Estimated selectivity for predicate {'column': 'lo_quantity', 'operator': 'eq', 'value': 10}: 0.0182


In [69]:
predicates

{'lineorder': [{'column': 'lo_linenumber',
   'operator': 'range',
   'value': (1, 5)},
  {'column': 'lo_quantity', 'operator': 'eq', 'value': 10}]}

#### For each table, estimate selectivity and disk IO cost of all access path for that table and select cheapest path

In [70]:
def calculate_row_overhead(num_nullable_columns=0):
    # Tuple header size
    tuple_header_size = 23  # bytes
    # Null bitmap size (1 byte for every 8 nullable columns)
    null_bitmap_size = (num_nullable_columns + 7) // 8
    # Total overhead
    total_overhead = tuple_header_size + null_bitmap_size

    return total_overhead


def table_avg_rows_per_page(table_stats_dict):
    # add up the average width of all columns to get the average width of a row
    avg_row_size = 0
    for column, column_stats in table_stats_dict.items():
        avg_row_size += column_stats['avg_width']
    # add the row overhead
    avg_row_size += calculate_row_overhead()
    # calculate the average number of rows that can fit in a page
    avg_rows_per_page = int(get_page_size() / avg_row_size)

    return avg_rows_per_page

def index_average_rows_per_page(index, table_stats_dict):
    columns = list(index.index_columns) + list(index.include_columns)   
    # add up the average width of all columns to get the average width of a row
    avg_row_size = 0
    for column in columns:
        avg_row_size += table_stats_dict[column]['avg_width']
    # add the row overhead
    index_row_overhead = 16  # assume 16 bytes 
    avg_row_size += index_row_overhead
    # calculate the average number of rows that can fit in a page
    # (assuming the index is a B+ tree, so only the leaf nodes contain the actual data)
    avg_rows_per_page = int(get_page_size() / avg_row_size)
       
    return avg_rows_per_page


def estimate_index_scan_cost(index, table_stats_dict, table_predicates, total_rows, index_only_scan=False):
    # check if leading index column is in the predicates
    leading_index_column = index.index_columns[0]
    predicate_columns = [pred['column'] for pred in table_predicates]
    
    if leading_index_column not in predicate_columns:
        # assign high cost to prevent using this index, sequential scan will be cheaper
        return float('inf')
    
    # calculate the combined selectivity for this index (assuming attribute independence/no correlations of predicates)
    combined_selectivity = 1.0
    for pred in table_predicates:
        if pred['column'] in index.index_columns:
        
            selectivity = estimate_selectivity(pred['column'], pred['operator'], pred['value'], table_stats_dict, total_rows)
            combined_selectivity *= selectivity

    # estimate cardinality of the index scan
    index_cardinality = combined_selectivity * total_rows
    # estimate the number of pages that need to be accessed
    avg_rows_per_page = index_average_rows_per_page(index, table_stats_dict)
    index_pages = int(index_cardinality / avg_rows_per_page)
    
    if index_only_scan: 
        index_scan_cost = index_pages
    else:
        # for index scan, we need to access the table as well
        index_average_rows_per_page_table = table_avg_rows_per_page(table_stats_dict)
        table_pages = int(index_cardinality / index_average_rows_per_page_table)
        # return total cost as the sum of index and table pages
        index_scan_cost = index_pages + table_pages

    return index_scan_cost



def estimate_sequentail_scan_cost(table_stats_dict, total_rows):
    # estimate cardinality of the scan
    scan_cardinality = total_rows
    # estimate the number of pages that need to be accessed
    avg_rows_per_page = table_avg_rows_per_page(table_stats_dict)
    scan_pages = int(scan_cardinality / avg_rows_per_page)
    # estimate the total cost as the number of pages that need to be accessed
    sequential_scan_cost = scan_pages

    return sequential_scan_cost


def find_cheapest_paths(access_paths, predicates, stats, estimated_rows, verbose=False):
    cheapest_table_access_path = {}    
    if verbose: print(f"Finding cheapest access paths for tables: {access_paths.keys()}")
    # enumerate over tables that need to be accessed
    for table_name in access_paths:
        if verbose: print(f"\nTable: {table_name}")
        # enumerate over access paths for this table
        cheapest_cost = float('inf')
        for path in access_paths[table_name]:
            # compute the cost of this access path
            # (for now, assume cost is proportional to the cardinality of the data that needs to be accessed)
            if path['scan_type'] == 'Sequential Scan':
                cost = estimate_sequentail_scan_cost(stats[table_name], estimated_rows[table_name])
            elif path['scan_type'] == 'Index Scan':
                index_id = path['index_id']
                index = indexes[index_id]
                cost = estimate_index_scan_cost(index, stats[table_name], predicates[table_name], estimated_rows[table_name])
            elif path['scan_type'] == 'Index Only Scan':
                index_id = path['index_id']
                index = indexes[index_id]
                cost = estimate_index_scan_cost(index, stats[table_name], predicates[table_name], estimated_rows[table_name], index_only_scan=True)
            else:
                raise ValueError("Scan type not supported")            

            if verbose: print(f"\t\tAccess path: {path}, Cost: {cost}")
            if cost < cheapest_cost:
                cheapest_cost = cost
                cheapest_access_path = path
        cheapest_table_access_path[table_name] = cheapest_access_path
        if verbose: print(f"\tCheapest access path: {cheapest_access_path}, Cost: {cheapest_cost}")
    return cheapest_table_access_path

In [71]:
find_cheapest_paths(access_paths, predicates, stats, estimated_rows, verbose=True)

Finding cheapest access paths for tables: dict_keys(['lineorder'])

Table: lineorder
		Access path: {'scan_type': 'Sequential Scan', 'lineorder': 'lineorder'}, Cost: 759319
		Access path: {'scan_type': 'Index Scan', 'index_id': 'IX_lineorder_lo_linenumber_lo_quantity'}, Cost: 15216
		Access path: {'scan_type': 'Index Scan', 'index_id': 'IX_lineorder_lo_linenumber_lo_quantity_lo_o'}, Cost: 15696
		Access path: {'scan_type': 'Index Only Scan', 'index_id': 'IX_lineorder_lo_linenumber_lo_quantity_lo_o'}, Cost: 3342
		Access path: {'scan_type': 'Index Scan', 'index_id': 'IX_lineorder_lo_quantity'}, Cost: 16488
	Cheapest access path: {'scan_type': 'Index Only Scan', 'index_id': 'IX_lineorder_lo_linenumber_lo_quantity_lo_o'}, Cost: 3342


{'lineorder': {'scan_type': 'Index Only Scan',
  'index_id': 'IX_lineorder_lo_linenumber_lo_quantity_lo_o'}}