#### Simple Cost-Model for predicting query execution cost (assume disk IO dominates) and access paths selected 


Main steps in the cost model algorithm:

1) Identify predicates (quatily and range) and payloads
2) Enumerate all possible access paths (sequential scans, index scans, bitmap index scan + bitmap heap scan)
3) Estimate selectivity for each predicate (i.e. what fraction of data needs to be accessed from a table)
4) Estimate cardinality of each access path (using total number of rows and selectivity information)
5) Estimate disk IO cost for each access path (for index scans, we use cardinality estimate to figure out how many pages need to be fetchs)
6) Compare the estimated costs and choose best access paths for covering the query

-----------------------------------------------------------------------------------------------------------------------------------------------
Some useful notes on access paths:
* sequential scans --> better for low selectivity or if no suitable index available or if the table is really small
* index scan --> better for high selectivity and simpler predicates
* index only scan --> better for high selectivity and index is also a covering index and simpler predicates
* bitmap index scan + bitmap heap scan --> better for medium selectivity and complex predicates

In [131]:
import psycopg2
import numpy as np
from pg_utils import *
from ssb_qgen_class import *

#### Getting Table Statistics

In [132]:
def get_page_size():
    return 8192  # 8 KB

def get_table_stats(table_name):

    conn = create_connection()

    cur = conn.cursor()

    try:
        # Execute the query to get the estimated number of rows in the table
        cur.execute(f"""
                    SELECT reltuples::bigint AS estimated_rows
                    FROM pg_class
                    WHERE relname = '{table_name}';
                    """)
        estimated_rows = cur.fetchone()[0]
    except:
        print(f"Error: Could not get the estimated number of rows in the '{table_name}' table.")
        estimated_rows = None

    try:
        # Query to get column statistics
        cur.execute(f"SELECT * FROM pg_stats WHERE tablename = '{table_name}';")
        column_stats = cur.fetchall()

        # Define the column names based on the pg_stats view
        column_names = [
                        "schemaname", "tablename", "attname", "inherited", "null_frac",
                        "avg_width", "n_distinct", "most_common_vals", "most_common_freqs",
                        "histogram_bounds", "correlation", "most_common_elems",
                        "most_common_elem_freqs", "elem_count_histogram"
                    ]

        # Organize the results into a dictionary
        stats_dict = {}
        for row in column_stats:
            column_name = row[2]  # 'attname' is the third column in the result
            stats_dict[column_name] = {column_names[i]: row[i] for i in range(len(column_names))}
    except:
        print(f"Error: Could not get the statistics for the '{table_name}' table")
        stats_dict = None

    # Close the cursor and connection
    cur.close()
    close_connection(conn)

    return stats_dict, estimated_rows


# Get the statistics for all tables in the SSB database
table_names = ["customer", "dwdate", "lineorder", "part", "supplier"]
stats = {}
estimated_rows = {}
for table_name in table_names:
    stats[table_name], estimated_rows[table_name] = get_table_stats(table_name)


# Print the organized statistics dictionary
#for key, value in stats_dict.items():
#    print(f"{key}")
#    for k, v in value.items():
#        print(f"    {k}: {v}")

#print(f"\nEstimated number of rows in the 'customer' table: {estimated_rows}")
print(estimated_rows)

{'customer': 300000, 'dwdate': 2556, 'lineorder': 59986216, 'part': 800000, 'supplier': 20000}


In [133]:
ssb_tables, pk_columns = get_ssb_schema()

# create a dictionary and specify whether each attribute in each table is numeric or char
data_type_dict = {}
for table_name in ["customer", "dwdate", "lineorder", "part", "supplier"]:
    for column_name, column_type in ssb_tables[table_name]:
        if ("INT" in column_type) or ("DECIMAL" in column_type) or ("BIT" in column_type):
            data_type_dict[column_name] = "numeric"
        else:
            data_type_dict[column_name] = "char"
    
data_type_dict    

{'c_custkey': 'numeric',
 'c_name': 'char',
 'c_address': 'char',
 'c_city': 'char',
 'c_nation': 'char',
 'c_region': 'char',
 'c_phone': 'char',
 'c_mktsegment': 'char',
 'd_datekey': 'char',
 'd_date': 'char',
 'd_dayofweek': 'char',
 'd_month': 'char',
 'd_year': 'numeric',
 'd_yearmonthnum': 'numeric',
 'd_yearmonth': 'char',
 'd_daynuminweek': 'numeric',
 'd_daynuminmonth': 'numeric',
 'd_daynuminyear': 'numeric',
 'd_monthnuminyear': 'numeric',
 'd_weeknuminyear': 'numeric',
 'd_sellingseason': 'char',
 'd_lastdayinweekfl': 'numeric',
 'd_lastdayinmonthfl': 'numeric',
 'd_holidayfl': 'numeric',
 'd_weekdayfl': 'numeric',
 'lo_orderkey': 'numeric',
 'lo_linenumber': 'numeric',
 'lo_custkey': 'numeric',
 'lo_partkey': 'numeric',
 'lo_suppkey': 'numeric',
 'lo_orderdate': 'char',
 'lo_orderpriority': 'char',
 'lo_shippriority': 'char',
 'lo_quantity': 'numeric',
 'lo_extendedprice': 'numeric',
 'lo_ordtotalprice': 'numeric',
 'lo_discount': 'numeric',
 'lo_revenue': 'numeric',
 'lo

#### Estimating Selectivity for a value range on a particular column, i.e. what fraction of the data (i.e. tuples) fall in the given range, using the Table Statistics

In [134]:
def estimate_selectivity_range(attribute, value_range, stats_dict, total_rows):
    data_type = data_type_dict[attribute]
    # get the column statistics
    stats = stats_dict[attribute]
    # get the histogram bounds
    histogram_bounds = stats['histogram_bounds']
    n_distinct = stats['n_distinct']
    most_common_vals = stats['most_common_vals']
    most_common_freqs = stats['most_common_freqs']

    # convert most_common_values string to list of correct data type
    if most_common_vals:
        if data_type == 'numeric':
            most_common_vals = [float(x) for x in most_common_vals.strip('{}').split(',')]
        elif data_type == 'char':
            most_common_vals = [x for x in most_common_vals.strip('{}').split(',')]    
        else:
            raise ValueError("Data type not supported, needs ot be either numeric or char")

    # Convert negative n_distinct to an absolute count
    if n_distinct < 0:
        n_distinct = -n_distinct * total_rows

    min_value = value_range[0]
    max_value = value_range[1]
    selectivity = 0.0

    # check for overlap with most common values
    if most_common_vals:
        for val, freq in zip(most_common_vals, most_common_freqs):
            if min_value <= val <= max_value:
                selectivity += freq    

    if histogram_bounds is not None:
        if data_type == 'numeric':
            histogram_bounds = [float(x) for x in histogram_bounds.strip('{}').split(',')] # convert to list of integers
        elif data_type == 'char':
            histogram_bounds = [x for x in histogram_bounds.strip('{}').split(',')]
        else:
            raise ValueError("Data type not supported, needs ot be either numeric or char")    

        total_bins = len(histogram_bounds) - 1

        # iterate over bins, find overlapping bins
        for i in range(total_bins):
            bin_lower_bound = histogram_bounds[i]
            bin_upper_bound = histogram_bounds[i+1]

            # check for range overlap
            if min_value < bin_lower_bound or max_value > bin_upper_bound:
                # does not overlap
                continue    

            # calculate the overlap fraction within this bin
            overlap_min = max(min_value, bin_lower_bound)
            overlap_max = min(max_value, bin_upper_bound)
            overlap_fraction = (overlap_max - overlap_min) / (bin_upper_bound - bin_lower_bound)

            #print(f"Overlap fraction for bin {i}: {overlap_fraction}")
            #print(f"Bin bounds: {bin_lower_bound}, {bin_upper_bound}")

            # accumulate to the total selectivity
            # Assume each bin represents an equal fraction of the total rows
            selectivity += overlap_fraction * (1.0 / total_bins)

    if selectivity == 0.0:
        # if no overlap with most common values or histogram bins, assume uniform distribution and estimate selectivity
        selectivity = 1.0 / n_distinct       

    return selectivity


def estimate_selectivity_eq(attribute, value, stats_dict):
    data_type = data_type_dict[attribute]
    # get the column statistics
    stats = stats_dict[attribute]
    # get the histogram bounds
    histogram_bounds = stats['histogram_bounds']
    n_distinct = stats['n_distinct']
    most_common_vals = stats['most_common_vals']
    most_common_freqs = stats['most_common_freqs']

    # convert most_common_values string to list of correct data type
    if most_common_vals:
        if data_type == 'numeric':
            most_common_vals = [float(x) for x in most_common_vals.strip('{}').split(',')]
        elif data_type == 'char':
            most_common_vals = [x for x in most_common_vals.strip('{}').split(',')]    
        else:
            raise ValueError("Data type not supported, needs ot be either numeric or char")

    # first check if the value is in the most common values
    if most_common_vals and value in most_common_vals:
        selectivity = most_common_freqs[most_common_vals.index(value)] 
        return selectivity

    # if not a common value, estimate using n_distinct
    if n_distinct < 0:
        n_distinct = -n_distinct

    selectivity = 1.0 / n_distinct    

    if histogram_bounds is not None:
        if data_type == 'numeric':
            histogram_bounds = [float(x) for x in histogram_bounds.strip('{}').split(',')] # convert to list of integers
        elif data_type == 'char':
            histogram_bounds = [x for x in histogram_bounds.strip('{}').split(',')]
        else:
            raise ValueError("Data type not supported, needs ot be either numeric or char")    

        total_bins = len(histogram_bounds) - 1

        # iterate over bins, find bin that contains the value
        for i in range(total_bins):
            bin_lower_bound = histogram_bounds[i]
            bin_upper_bound = histogram_bounds[i+1]

            # check for range overlap
            if bin_lower_bound <= value <= bin_upper_bound:
                bin_width = bin_upper_bound - bin_lower_bound
                if bin_width > 0:
                    # assume uniform distribution within this bin and calculate selectivity
                    uniform_selectivity = 1.0 / (bin_width*total_bins)
                    selectivity = min(selectivity, uniform_selectivity)
                break    

    return selectivity


def estimate_selectivity_or(attribute, value, stats_dict):
    combined_selectivity = 0.0
    individual_selectivities = []

    # for each value in the IN list, estimate the selectivity separately
    for val in value:
        individual_selectivities.append(estimate_selectivity_eq(attribute, val, stats_dict))

    # compute combined selectivities using inclusion-exclusion principle and assuming independence
    for selectivity in individual_selectivities:
        combined_selectivity += selectivity 

    overlap_adjustment = 1.0
    for selectivity in individual_selectivities:
        overlap_adjustment *= (1.0 - selectivity)

    combined_selectivity -= overlap_adjustment   

    # make sure the combined selectivity is between 0 and 1
    combined_selectivity = max(0.0, min(combined_selectivity, 1.0))

    return combined_selectivity 


In [135]:
# test the selectivity estimation functions on a numeric column
attribute = 'd_year'
stats_dict = stats['dwdate']

# test range selectivity estimation
value_range = (1992, 1994)
selectivity = estimate_selectivity_range(attribute, value_range, stats_dict, estimated_rows)
print(f"Estimated selectivity for range {value_range}: {selectivity}")

# test equality selectivity estimation
value = 1992
selectivity = estimate_selectivity_eq(attribute, value, stats_dict)
print(f"Estimated selectivity for value {value}: {selectivity}")

# now, let's try a char column
attribute = 'd_dayofweek'

# test range selectivity estimation
value_range = ('Monday', 'Wednesday')
selectivity = estimate_selectivity_range(attribute, value_range, stats_dict, estimated_rows)
print(f"Estimated selectivity for range {value_range}: {selectivity}")

# test equality selectivity estimation
value = 'Monday'
selectivity = estimate_selectivity_eq(attribute, value, stats_dict)
print(f"Estimated selectivity for value {value}: {selectivity}")




Estimated selectivity for range (1992, 1994): 0.42879498
Estimated selectivity for value 1992: 0.14319248
Estimated selectivity for range ('Monday', 'Wednesday'): 0.8571987299999999
Estimated selectivity for value Monday: 0.14280125


### Simple Case: consider three possible access paths: Sequential Scan, Index Scan and Index Only Scan

TODO: Add Bitmap Index Scan + Bitmap Heap Scan

#### Step 1: Extract query predicates and payload, the payload is a list of column names (across possibly multiple tables) and the predicate is a list of dictionaries, each dict contains the column name, operator (either equality or range) and either a range tuple or single value

In [136]:
qgen = QGEN()

In [137]:
example_query = qgen.generate_query(11)
print(example_query)
print(example_query.predicate_dict)

template id: 11, query: 
                SELECT d_year, c_nation, SUM(lo_revenue - lo_supplycost) AS profit
                FROM dwdate, customer, supplier, part, lineorder
                WHERE lo_custkey = c_custkey
                AND lo_suppkey = s_suppkey
                AND lo_partkey = p_partkey
                AND lo_orderdate = d_datekey
                AND c_region = 'EUROPE'
                AND s_region = 'EUROPE'
                AND (p_mfgr = 'MFGR#1' OR p_mfgr = 'MFGR#2')
                GROUP BY d_year, c_nation
                ORDER BY d_year, c_nation;
            , payload: {'lineorder': ['lo_revenue', 'lo_supplycost'], 'dwdate': ['d_year'], 'customer': ['c_nation']}, predicates: {'lineorder': ['lo_custkey', 'lo_suppkey', 'lo_orderdate', 'lo_partkey'], 'dwdate': ['d_year', 'd_datekey'], 'customer': ['c_custkey', 'c_region', 'c_nation'], 'part': ['p_partkey', 'p_mfgr'], 'supplier': ['s_suppkey', 's_region']}, order by: {'dwdate': ['d_year'], 'customer': ['c_nation']}, g

In [138]:
# extract tables and associated columns
tables = {}
#tables['lineorder'] = ['lo_linenumber', 'lo_quantity', 'lo_orderdate']
for table_name in example_query.payload:
    tables[table_name] = example_query.payload[table_name]

for table_name in example_query.predicates:
    if table_name not in tables:
        tables[table_name] = []
    tables[table_name] = list(set(tables[table_name] + example_query.predicates[table_name]))

# extract the payload
payload = example_query.payload

# extract the predicates
predicates = example_query.predicate_dict

print(f"Tables and columns: {tables}")   
print(f"Payload: {payload}")
print(f"Predicates:")
for table_name, predicate_list in predicates.items():
    print(f"\n{table_name}")
    for predicate in predicate_list:
        print(f"\t{predicate}")

# create some index objects
index_1 = Index('lineorder', 'IX_lineorder_lo_orderdate_lo_suppkey', index_columns=['lo_orderdate', 'lo_suppkey'])
index_2 = Index('lineorder', 'IX_lineorder_lo_suppkey_lo_partkey_lo_o', index_columns=['lo_suppkey', 'lo_partkey'], include_columns=['lo_orderdate'])
index_3 = Index('part', 'IX_part_p_mfgr', index_columns=['p_mfgr'])
index_4 = Index('customer', 'IX_customer_c_region', index_columns=['c_region'])
index_5 = Index('supplier', 'IX_supplier_s_region', index_columns=['s_region'])
index_6 = Index('dwdate', 'IX_dwdate_d_datekey', index_columns=['d_datekey'])
indexes = {index.index_id: index for index in [index_1, index_2, index_3, index_4, index_5, index_6]}


Tables and columns: {'lineorder': ['lo_custkey', 'lo_suppkey', 'lo_revenue', 'lo_orderdate', 'lo_supplycost', 'lo_partkey'], 'dwdate': ['d_year', 'd_datekey'], 'customer': ['c_region', 'c_custkey', 'c_nation'], 'part': ['p_partkey', 'p_mfgr'], 'supplier': ['s_suppkey', 's_region']}
Payload: {'lineorder': ['lo_revenue', 'lo_supplycost'], 'dwdate': ['d_year'], 'customer': ['c_nation']}
Predicates:

customer
	{'column': 'c_region', 'operator': 'eq', 'value': 'EUROPE', 'join': False}

supplier
	{'column': 's_region', 'operator': 'eq', 'value': 'EUROPE', 'join': False}

part
	{'column': 'p_mfgr', 'operator': 'or', 'value': ('MFGR#1', 'MFGR#2'), 'join': False}

lineorder
	{'column': 'lo_custkey', 'operator': 'eq', 'value': 'c_custkey', 'join': True}
	{'column': 'lo_suppkey', 'operator': 'eq', 'value': 's_suppkey', 'join': True}
	{'column': 'lo_partkey', 'operator': 'eq', 'value': 'p_partkey', 'join': True}
	{'column': 'lo_orderdate', 'operator': 'eq', 'value': 'd_datekey', 'join': True}


#### Step 2: Enumerate the possible access path for each table involved

In [139]:
# extract join predicate columns
join_predicates = {}
join_predicates_temp = {}
for table_name in predicates:
    table_preicates = predicates[table_name]
    for pred in table_preicates:
        if pred['join'] == True:
            if table_name not in join_predicates_temp:
                join_predicates_temp[table_name] = []
            join_predicates_temp[table_name].append(pred['column'])
            # add the other table's column to the join predicate list
            other_table_column = pred['value']
            # search for the table name containing the other column
            for other_table_name in tables:
                if other_table_column in tables[other_table_name]:
                    if other_table_name not in join_predicates_temp:
                        join_predicates_temp[other_table_name] = []
                    join_predicates_temp[other_table_name].append(other_table_column)
                    join_pred = pred.copy()
                    join_pred['column'] = other_table_column
                    join_pred['value'] = pred['column']
                    if other_table_name not in join_predicates:
                        join_predicates[other_table_name] = []
                    join_predicates[other_table_name].append(join_pred)
                    break

print(f"Predicates: {predicates}")
print(f"Join predicates: {join_predicates}")

# add join predicates to the main predicate dictionary
for table_name in join_predicates:
    if table_name not in predicates:
        predicates[table_name] = []
    # Ensure unique dictionaries in the list
    existing_predicates = {frozenset(pred.items()): pred for pred in predicates[table_name]}
    for pred in join_predicates[table_name]:
        pred_key = frozenset(pred.items())
        if pred_key not in existing_predicates:
            existing_predicates[pred_key] = pred
    predicates[table_name] = list(existing_predicates.values())


access_paths = {}
for table_name in tables:
    if table_name in predicates:
        #table_predicate_cols = [pred['column'] for pred in predicates[table_name] if pred['join'] == False]
        table_predicate_cols = [pred['column'] for pred in predicates[table_name]]
    if table_name in payload:
        table_payload_cols = [col for col in payload[table_name] if col in tables[table_name]]   
    if table_name in join_predicates_temp:
        join_predicate_cols = join_predicates_temp[table_name]
    
    relevant_predicate_cols = set(table_predicate_cols).union(join_predicate_cols)
    table_access_paths = [{'scan_type': 'Sequential Scan', table_name: table_name}]
    print(f"Table predicate columns for {table_name}: {table_predicate_cols}")
    print(f"Relevant predicate columns for {table_name}: {relevant_predicate_cols}")
    for index in indexes.values():
        if index.table_name == table_name:
            print("Checking index: ", index.index_id)
            # Check if index scan is possible
            if set(index.index_columns).intersection(relevant_predicate_cols):
                table_access_paths.append({'scan_type': 'Index Scan', 'index_id': index.index_id})
                print("Index scan possible!")
            # Check if index only scan is possible
            if set(index.index_columns).issuperset(relevant_predicate_cols) and set(
                list(index.index_columns) + list(index.include_columns)).issuperset(table_payload_cols):
                table_access_paths.append({'scan_type': 'Index Only Scan', 'index_id': index.index_id})
                print("Index only scan possible!")

    access_paths[table_name] = table_access_paths


print(f"\nAccess paths: ")
for table, paths in access_paths.items():
    print(f"Table: {table}")
    for path in paths:
        print(f"    {path}")

Predicates: {'customer': [{'column': 'c_region', 'operator': 'eq', 'value': 'EUROPE', 'join': False}], 'supplier': [{'column': 's_region', 'operator': 'eq', 'value': 'EUROPE', 'join': False}], 'part': [{'column': 'p_mfgr', 'operator': 'or', 'value': ('MFGR#1', 'MFGR#2'), 'join': False}], 'lineorder': [{'column': 'lo_custkey', 'operator': 'eq', 'value': 'c_custkey', 'join': True}, {'column': 'lo_suppkey', 'operator': 'eq', 'value': 's_suppkey', 'join': True}, {'column': 'lo_partkey', 'operator': 'eq', 'value': 'p_partkey', 'join': True}, {'column': 'lo_orderdate', 'operator': 'eq', 'value': 'd_datekey', 'join': True}]}
Join predicates: {'customer': [{'column': 'c_custkey', 'operator': 'eq', 'value': 'lo_custkey', 'join': True}], 'supplier': [{'column': 's_suppkey', 'operator': 'eq', 'value': 'lo_suppkey', 'join': True}], 'part': [{'column': 'p_partkey', 'operator': 'eq', 'value': 'lo_partkey', 'join': True}], 'dwdate': [{'column': 'd_datekey', 'operator': 'eq', 'value': 'lo_orderdate', 

#### Estimate selectivity of the predicates

In [141]:
def estimate_selectivity(attribute, operator, value, stats_dict, total_rows):
    if operator == 'eq':
        return estimate_selectivity_eq(attribute, value, stats_dict)
    elif operator == 'range':
        return estimate_selectivity_range(attribute, value, stats_dict, total_rows)
    elif operator == 'or':
        return estimate_selectivity_or(attribute, value, stats_dict)    
    else:
        raise ValueError("Operator not supported, needs to be either 'eq', 'range', or 'or'")

In [142]:
for table_name, table_preds in predicates.items():
    table_stats_dict = stats[table_name]   
    table_estimated_rows = estimated_rows[table_name]
    for pred in table_preds:
        if pred['join'] == False:
            selectivity = estimate_selectivity(pred['column'], pred['operator'], pred['value'], table_stats_dict, table_estimated_rows)
            print(f"Estimated selectivity for predicate {pred}: {selectivity}")

Estimated selectivity for predicate {'column': 'c_region', 'operator': 'eq', 'value': 'EUROPE', 'join': False}: 0.19953333
Estimated selectivity for predicate {'column': 's_region', 'operator': 'eq', 'value': 'EUROPE', 'join': False}: 0.1986
Estimated selectivity for predicate {'column': 'p_mfgr', 'operator': 'or', 'value': ('MFGR#1', 'MFGR#2'), 'join': False}: 0.0


In [143]:
predicates

{'customer': [{'column': 'c_region',
   'operator': 'eq',
   'value': 'EUROPE',
   'join': False},
  {'column': 'c_custkey',
   'operator': 'eq',
   'value': 'lo_custkey',
   'join': True}],
 'supplier': [{'column': 's_region',
   'operator': 'eq',
   'value': 'EUROPE',
   'join': False},
  {'column': 's_suppkey',
   'operator': 'eq',
   'value': 'lo_suppkey',
   'join': True}],
 'part': [{'column': 'p_mfgr',
   'operator': 'or',
   'value': ('MFGR#1', 'MFGR#2'),
   'join': False},
  {'column': 'p_partkey',
   'operator': 'eq',
   'value': 'lo_partkey',
   'join': True}],
 'lineorder': [{'column': 'lo_custkey',
   'operator': 'eq',
   'value': 'c_custkey',
   'join': True},
  {'column': 'lo_suppkey',
   'operator': 'eq',
   'value': 's_suppkey',
   'join': True},
  {'column': 'lo_partkey',
   'operator': 'eq',
   'value': 'p_partkey',
   'join': True},
  {'column': 'lo_orderdate',
   'operator': 'eq',
   'value': 'd_datekey',
   'join': True}],
 'dwdate': [{'column': 'd_datekey',
   'o

#### For each table, estimate selectivity and disk IO cost of all access path for that table and select cheapest path

In [146]:
def calculate_row_overhead(num_nullable_columns=0):
    # Tuple header size
    tuple_header_size = 23  # bytes
    # Null bitmap size (1 byte for every 8 nullable columns)
    null_bitmap_size = (num_nullable_columns + 7) // 8
    # Total overhead
    total_overhead = tuple_header_size + null_bitmap_size

    return total_overhead


def table_avg_rows_per_page(table_stats_dict):
    # add up the average width of all columns to get the average width of a row
    avg_row_size = 0
    avg_row_size = sum(column_stats['avg_width'] for column_stats in table_stats_dict.values())
    # add the row overhead
    avg_row_size += calculate_row_overhead()
    # calculate the average number of rows that can fit in a page
    avg_rows_per_page = int(get_page_size() / avg_row_size)

    return avg_rows_per_page


def index_average_rows_per_page(index, table_stats_dict):
    columns = list(index.index_columns) + list(index.include_columns)   
    # add up the average width of all columns to get the average width of a row
    avg_row_size = sum(table_stats_dict[column]['avg_width'] for column in columns)
    # add the row overhead
    index_row_overhead = 16  # assume 16 bytes 
    avg_row_size += index_row_overhead
    # calculate the average number of rows that can fit in a page
    # (assuming the index is a B+ tree, so only the leaf nodes contain the actual data)
    avg_rows_per_page = int(get_page_size() / avg_row_size)
       
    return avg_rows_per_page


def estimate_index_scan_cost(index, table_stats_dict, table_predicates, total_rows, index_only_scan=False, verbose=False):
    # check if leading index column is in the predicates
    leading_index_column = index.index_columns[0]
    #print(f"\t\t\tTable predicates: {table_predicates}, Leading index column: {leading_index_column}")
    predicate_columns = [pred['column'] for pred in table_predicates]
    
    if leading_index_column not in predicate_columns:
        # assign high cost to prevent using this index, sequential scan will be cheaper
        return float('inf')
    
    # calculate the combined selectivity for this index (assuming attribute independence/no correlations of predicates)
    leading_column_selectivity = 1.0
    combined_selectivity = 1.0
    for pred in table_predicates:
        if pred['column'] in index.index_columns and pred['join'] == False:
            selectivity = estimate_selectivity(pred['column'], pred['operator'], pred['value'], table_stats_dict, total_rows)
            if verbose: print(f"\t\tSelectivity for predicate {pred}: {selectivity}")
            combined_selectivity *= selectivity
            if pred['column'] == leading_index_column:
                leading_column_selectivity = selectivity

    # estimate cardinality of the index scan
    index_cardinality = leading_column_selectivity * total_rows
    # estimate the number of pages that need to be accessed
    avg_rows_per_page = index_average_rows_per_page(index, table_stats_dict)
    index_pages = int(index_cardinality / avg_rows_per_page)
    
    table_pages = 0
    if not index_only_scan: 
        table_cardinality = combined_selectivity * total_rows
        # for index scan, we need to access the table as well
        index_average_rows_per_page_table = table_avg_rows_per_page(table_stats_dict)
        table_pages = int(table_cardinality / index_average_rows_per_page_table)
    
    # return total cost as the sum of index and table pages
    index_scan_cost = index_pages + table_pages
    if verbose: 
        print(f"\tLeading column selectivity: {leading_column_selectivity}, Combined selectivity: {combined_selectivity}")
        print(f"\tEstimated number of pages for index scan: {index_pages}, Table pages: {table_pages}")

    return index_scan_cost


def estimate_sequentail_scan_cost(table_stats_dict, total_rows, verbose=False):
    # estimate cardinality of the scan
    scan_cardinality = total_rows
    # estimate the number of pages that need to be accessed
    avg_rows_per_page = table_avg_rows_per_page(table_stats_dict)
    scan_pages = int(scan_cardinality / avg_rows_per_page)
    # estimate the total cost as the number of pages that need to be accessed
    sequential_scan_cost = scan_pages

    if verbose: print(f"\tEstimated number of pages for sequential scan: {scan_pages}")

    return sequential_scan_cost


def find_cheapest_paths(access_paths, predicates, join_predicates, stats, estimated_rows, verbose=False):
    cheapest_table_access_path = {}    
    if verbose: print(f"Finding cheapest access paths for tables: {access_paths.keys()}")
    # enumerate over tables that need to be accessed
    for table_name in access_paths:
        if verbose: print(f"\nTable: {table_name}")
        # enumerate over access paths for this table
        cheapest_cost = float('inf')
        for path in access_paths[table_name]:
            print(f"\tComputing cost for access path: {path}")
            # compute the cost of this access path
            # (for now, assume cost is proportional to the cardinality of the data that needs to be accessed)
            if path['scan_type'] == 'Sequential Scan':
                cost = estimate_sequentail_scan_cost(stats[table_name], estimated_rows[table_name], verbose=verbose)
            elif path['scan_type'] == 'Index Scan':
                index_id = path['index_id']
                index = indexes[index_id]
                cost = estimate_index_scan_cost(index, stats[table_name], predicates[table_name], estimated_rows[table_name], verbose=verbose)    
            elif path['scan_type'] == 'Index Only Scan':
                index_id = path['index_id']
                index = indexes[index_id]
                cost = estimate_index_scan_cost(index, stats[table_name], predicates[table_name], estimated_rows[table_name], index_only_scan=True, verbose=verbose)
            else:
                raise ValueError("Scan type not supported")            

            if verbose: print(f"\tAccess path: {path}, Cost: {cost}\n")
            if cost < cheapest_cost:
                cheapest_cost = cost
                cheapest_access_path = path
        cheapest_table_access_path[table_name] = cheapest_access_path
        if verbose: print(f"\tCheapest access path: {cheapest_access_path}, Cost: {cheapest_cost}")
    return cheapest_table_access_path

In [145]:
find_cheapest_paths(access_paths, predicates, join_predicates, stats, estimated_rows, verbose=True)

Finding cheapest access paths for tables: dict_keys(['lineorder', 'dwdate', 'customer', 'part', 'supplier'])

Table: lineorder
	Computing cost for access path: {'scan_type': 'Sequential Scan', 'lineorder': 'lineorder'}
	Estimated number of pages for sequential scan: 759319
	Access path: {'scan_type': 'Sequential Scan', 'lineorder': 'lineorder'}, Cost: 759319

	Computing cost for access path: {'scan_type': 'Index Scan', 'index_id': 'IX_lineorder_lo_orderdate_lo_suppkey'}
			Table predicates: [{'column': 'lo_custkey', 'operator': 'eq', 'value': 'c_custkey', 'join': True}, {'column': 'lo_suppkey', 'operator': 'eq', 'value': 's_suppkey', 'join': True}, {'column': 'lo_partkey', 'operator': 'eq', 'value': 'p_partkey', 'join': True}, {'column': 'lo_orderdate', 'operator': 'eq', 'value': 'd_datekey', 'join': True}], Leading index column: lo_orderdate
	Leading column selectivity: 1.0, Combined selectivity: 1.0
	Estimated number of pages for index scan: 175912, Table pages: 759319
	Access path: 

{'lineorder': {'scan_type': 'Sequential Scan', 'lineorder': 'lineorder'},
 'dwdate': {'scan_type': 'Sequential Scan', 'dwdate': 'dwdate'},
 'customer': {'scan_type': 'Index Scan', 'index_id': 'IX_customer_c_region'},
 'part': {'scan_type': 'Index Scan', 'index_id': 'IX_part_p_mfgr'},
 'supplier': {'scan_type': 'Index Scan', 'index_id': 'IX_supplier_s_region'}}