#### Generating index selection recommendations via the MS SQL Server Database Engine Tuning Advisor (DTA)

We will use DTA to generate index recommendations on some sample TPC-H OLAP workloads. 

Given a workload containing N queries, we will split it into m rounds. For simplicity we split it evenly into disjoint subsets; overlapping subsets are also possible.

* Experiment 1: On each round, we will execute the queries in that round and measure performance.

* Experiment 2: On each round, we will first use DTA to obtain recommendations, implement those recommendations (i.e. create/drop indexes), then execute the queries in that round and measure performance. To generate recommendations, we will use all queries seen so far, up to and including those in the current round.


In [1]:
import logging
import datetime
import os
import subprocess
import uuid

import pyodbc
import sys
import random
import pandas as pd
import time
import os
from tqdm import tqdm
import logging
import re

import xml.etree.ElementTree as ET


In [2]:
def read_sql_files(base_dir, instances_per_template):
    """Collect query texts from a directory tree of generated SQL files.

    Every sub-directory of ``base_dir`` is treated as one query template.
    Sub-directories and the files inside them are visited in sorted order, and
    at most ``instances_per_template`` ``.sql`` files are read per template.

    For each file, lines consisting solely of ``where rownum <= N;`` are
    dropped, then the first three remaining lines are skipped and the rest is
    joined and stripped of surrounding whitespace.

    Args:
        base_dir: Directory containing one sub-directory per template.
        instances_per_template: Maximum number of ``.sql`` files read from
            each template directory.

    Returns:
        list[str]: The extracted query strings, in visiting order.
    """
    # Matches a stand-alone "where rownum <= N;" line (case-insensitive).
    rownum_line = re.compile(r'^\s*where\s+rownum\s*<=\s*\d+\s*;\s*$', re.IGNORECASE)

    collected = []
    for template_name in sorted(os.listdir(base_dir)):
        template_dir = os.path.join(base_dir, template_name)
        if not os.path.isdir(template_dir):
            continue

        taken = 0  # number of instances already read from this template
        for file_name in sorted(os.listdir(template_dir)):
            if taken >= instances_per_template:
                break

            candidate = os.path.join(template_dir, file_name)
            if not candidate.endswith('.sql'):
                continue

            with open(candidate, 'r') as handle:
                kept = [text for text in handle if rownum_line.match(text) is None]

            # The first three surviving lines are skipped before joining.
            collected.append(''.join(kept[3:]).strip())
            taken += 1

    return collected


# Base directory containing the generated queries
base_dir = '../TPCH_generated_queries'

# Fixed seed so the shuffled workload order (and therefore the composition of
# every round) is the same after each kernel restart.
SHUFFLE_SEED = 42

# Read the SQL files and store the queries in a list
queries = read_sql_files(base_dir, instances_per_template=5)

# Shuffle the queries with a dedicated, seeded generator; this leaves the
# global `random` state untouched.
random.Random(SHUFFLE_SEED).shuffle(queries)

print(len(queries))

110


#### Create a workload file with all the queries.

In [3]:
"""workload_filename = 'workload_tpch_20.sql'

# Write the queries to file
with open(workload_filename, 'w') as file:
    for query in queries:
        file.write(query + '\n\n')
        """

"workload_filename = 'workload_tpch_20.sql'\n\n# Write the queries to file\nwith open(workload_filename, 'w') as file:\n    for query in queries:\n        file.write(query + '\n\n')\n        "

#### Define DTA recommender class.

In [4]:
"""conn_str = (
    "Driver={ODBC Driver 17 for SQL Server};"
    "Server=172.16.6.196,1433;"  # Use the IP address and port directly
    "Database=TPCH1;"  
    "UID=wsl;" 
    "PWD=greatpond501;"  
)

conn = pyodbc.connect(conn_str)
cursor = conn.cursor()
# test the connection
print(cursor.execute("SELECT @@version;"))"""

'conn_str = (\n    "Driver={ODBC Driver 17 for SQL Server};"\n    "Server=172.16.6.196,1433;"  # Use the IP address and port directly\n    "Database=TPCH1;"  \n    "UID=wsl;" \n    "PWD=greatpond501;"  \n)\n\nconn = pyodbc.connect(conn_str)\ncursor = conn.cursor()\n# test the connection\nprint(cursor.execute("SELECT @@version;"))'

In [5]:
""" Code originally from Malinga Perera's work """
class QueryPlan:
    """Metrics extracted from a SQL Server showplan XML document.

    Attributes:
        estimated_rows: ``StatementEstRows`` of the first ``StmtSimple`` (0 if absent).
        est_statement_sub_tree_cost: ``StatementSubTreeCost`` of that statement.
        elapsed_time: ``QueryTimeStats/@ElapsedTime`` divided by 1000
            (presumably milliseconds converted to seconds).
        cpu_time: ``QueryTimeStats/@CpuTime`` as reported (not rescaled).
        non_clustered_index_usage: One tuple per ``Index Seek``/``Index Scan``
            operator: ``(index_name, operator_elapsed, cpu_time,
            est_statement_sub_tree_cost, rows_read, rows_output)``.
        clustered_index_usage: Same tuple layout for ``Clustered Index
            Scan``/``Clustered Index Seek`` operators, keyed by table name
            instead of index name.
    """

    _NS = {'sp': 'http://schemas.microsoft.com/sqlserver/2004/07/showplan'}
    _NONCLUSTERED_OPS = frozenset({'Index Seek', 'Index Scan'})
    _CLUSTERED_OPS = frozenset({'Clustered Index Scan', 'Clustered Index Seek'})

    def __init__(self, xml_string):
        """Parse ``xml_string`` (a showplan XML document) and populate the metrics.

        Raises:
            xml.etree.ElementTree.ParseError: If ``xml_string`` is not well-formed XML.
            ValueError: If a numeric attribute cannot be converted.
        """
        self.estimated_rows = 0
        self.est_statement_sub_tree_cost = 0
        self.elapsed_time = 0
        self.cpu_time = 0
        self.non_clustered_index_usage = []
        self.clustered_index_usage = []

        ns = self._NS
        root = ET.fromstring(xml_string)

        stmt_simple = root.find('.//sp:StmtSimple', ns)
        if stmt_simple is not None:
            self.estimated_rows = float(stmt_simple.attrib.get('StatementEstRows', 0))
            self.est_statement_sub_tree_cost = float(stmt_simple.attrib.get('StatementSubTreeCost', 0))

        query_stats = root.find('.//sp:QueryTimeStats', ns)
        if query_stats is not None:
            self.cpu_time = float(query_stats.attrib.get('CpuTime', 0))
            self.elapsed_time = float(query_stats.attrib.get('ElapsedTime', 0)) / 1000

        for rel_op in root.findall('.//sp:RelOp', ns):
            physical_op = rel_op.attrib.get('PhysicalOp')
            is_nonclustered = physical_op in self._NONCLUSTERED_OPS
            if not is_nonclustered and physical_op not in self._CLUSTERED_OPS:
                continue  # only index access operators are recorded

            # Rows are summed over threads; elapsed time is the slowest thread.
            rows_read = 0
            elapsed_ms = 0
            for thread_info in rel_op.findall('.//sp:RunTimeCountersPerThread', ns):
                rows_read += int(thread_info.attrib.get('ActualRowsRead', 0))
                elapsed_ms = max(int(thread_info.attrib.get('ActualElapsedms', 0)), elapsed_ms)
            operator_elapsed = elapsed_ms / 1000

            # Fall back to the optimizer estimate when no actual row count is available.
            if rows_read == 0:
                rows_read = float(rel_op.attrib.get('EstimatedRowsRead', 0))
            rows_output = float(rel_op.attrib.get('EstimateRows', 0))

            index_scan = rel_op.find('.//sp:IndexScan', ns)
            if index_scan is None:
                continue
            scanned_object = index_scan.find('.//sp:Object', ns)
            if scanned_object is None:
                # Without an Object element there is no index/table to attribute
                # the usage to; skip instead of failing on `.attrib` of None.
                continue

            if is_nonclustered:
                name = scanned_object.attrib.get('Index', '').strip("[]")
                target = self.non_clustered_index_usage
            else:
                name = scanned_object.attrib.get('Table', '').strip("[]")
                target = self.clustered_index_usage
            target.append(
                (name, operator_elapsed, self.cpu_time, self.est_statement_sub_tree_cost, rows_read, rows_output))
                    

# more sophisticated query execution metrics
def execute_query(query, connection, cost_type='elapsed_time', verbose=False):
    """Run ``query`` with a cleared buffer cache and return cost plus index usage.

    The function issues ``DBCC DROPCLEANBUFFERS``, enables ``SET STATISTICS XML``,
    executes the query, reads the statistics XML from the following result set
    and parses it with ``QueryPlan``.

    Args:
        query: SQL text to execute.
        connection: Object exposing ``cursor()`` (e.g. a pyodbc connection).
        cost_type: ``'elapsed_time'``, ``'cpu_time'`` or ``'sub_tree_cost'``;
            any other value also selects the estimated sub-tree cost.
        verbose: When true, print the query and its parsed metrics.

    Returns:
        tuple: ``(cost, non_clustered_index_usage, clustered_index_usage)``.
        On any failure the error is logged and ``(-1, [], [])`` is returned.
    """
    cursor = None
    stats_enabled = False
    try:
        cursor = connection.cursor()
        # clear cache
        cursor.execute("DBCC DROPCLEANBUFFERS")
        # enable statistics collection
        cursor.execute("SET STATISTICS XML ON")
        stats_enabled = True
        # execute the query
        cursor.execute(query)
        cursor.nextset()
        # fetch execution stats
        stat_xml = cursor.fetchone()[0]
        cursor.execute("SET STATISTICS XML OFF")
        stats_enabled = False
        # parse query plan
        query_plan = QueryPlan(stat_xml)

        if verbose:
            print(f"QUERY: \n{query}\n")
            print(f"ELAPSED TIME: \n{query_plan.elapsed_time}\n")
            print(f"CPU TIME: \n{query_plan.cpu_time}\n")
            print(f"SUBTREE COST: \n{query_plan.est_statement_sub_tree_cost}\n")
            print(f"NON CLUSTERED INDEX USAGE: \n{query_plan.non_clustered_index_usage}\n")
            print(f"CLUSTERED INDEX USAGE: \n{query_plan.clustered_index_usage}\n")

        # Determine the cost type; unknown values fall back to the sub-tree cost.
        if cost_type == 'elapsed_time':
            cost = query_plan.elapsed_time
        elif cost_type == 'cpu_time':
            cost = query_plan.cpu_time
        else:
            cost = query_plan.est_statement_sub_tree_cost
        return float(cost), query_plan.non_clustered_index_usage, query_plan.clustered_index_usage
    except Exception as e:
        logging.error(f"Exception when executing query: {query}, Error: {e}")
        return -1, [], []
    finally:
        # `cursor` stays None when connection.cursor() itself failed.
        if cursor is not None:
            if stats_enabled:
                # STATISTICS XML is a session setting: switch it off even when
                # the query failed, so later statements are not affected.
                try:
                    cursor.execute("SET STATISTICS XML OFF")
                except Exception as e:
                    logging.error(f"Exception when disabling STATISTICS XML: {e}")
            cursor.close()


In [6]:
class DTA_recommender:
    """Run a query workload in rounds, optionally tuning indexes with DTA.

    On each round the queries of that round are appended to a workload file;
    on the rounds listed in ``invoke_ta_rounds`` the DTA command-line tool is
    run on that file and its recommendations (index create/drop statements)
    are applied before the round's queries are executed and timed.
    """

    # FIXME(security): legacy connection defaults, kept only so existing calls
    # keep working. Supply real values through the constructor arguments or the
    # DTA_SQL_* environment variables and remove these literals from the notebook.
    _LEGACY_SERVER = "172.16.6.196,1433"
    _LEGACY_DATABASE = "TPCH1"
    _LEGACY_USERNAME = "wsl"
    _LEGACY_PASSWORD = "greatpond501"

    # CREATE NONCLUSTERED INDEX <index> ON <table> (<cols>) [INCLUDE (<cols>)]
    _CREATE_INDEX_RE = re.compile(
        r"CREATE\s+NONCLUSTERED\s+INDEX\s+(?P<index>\[[^\]]*\]|\S+)\s+ON\s+"
        r"(?P<table>[^(]+?)\s*\((?P<cols>[^)]*)\)"
        r"(?:\s*INCLUDE\s*\((?P<incl>[^)]*)\))?",
        re.IGNORECASE,
    )
    # DROP INDEX <index> ON <schema>.<table>
    _DROP_INDEX_RE = re.compile(
        r"DROP\s+INDEX\s+(?P<index>\S+)\s+ON\s+(?P<target>\S+)", re.IGNORECASE)

    def __init__(self, queries, invoke_ta_rounds, verbose=False,
                 server=None, database=None, username=None, password=None):
        """Store the workload and the connection settings (no connection is opened).

        Args:
            queries: List of SQL query strings used as the workload.
            invoke_ta_rounds: Zero-based round indexes on which DTA is invoked.
            verbose: When true, print details of created/dropped indexes.
            server, database, username, password: Optional connection settings.
                Each falls back to the ``DTA_SQL_SERVER`` / ``DTA_SQL_DATABASE`` /
                ``DTA_SQL_USERNAME`` / ``DTA_SQL_PASSWORD`` environment variable
                and then to the legacy built-in default.
        """
        self.queries = queries  # list of queries to be used as workload
        self.invoke_ta_rounds = invoke_ta_rounds  # list specifying the rounds to invoke TA
        self.verbose = verbose

        self.server = server or os.environ.get("DTA_SQL_SERVER", self._LEGACY_SERVER)
        self.database = database or os.environ.get("DTA_SQL_DATABASE", self._LEGACY_DATABASE)
        self.username = username or os.environ.get("DTA_SQL_USERNAME", self._LEGACY_USERNAME)
        self.password = password or os.environ.get("DTA_SQL_PASSWORD", self._LEGACY_PASSWORD)

        # Built from the fields above so the ODBC string and the DTA command
        # line can never disagree.
        self.conn_string = ("Driver={ODBC Driver 17 for SQL Server};"
                            f"Server={self.server};"
                            f"Database={self.database};"
                            f"UID={self.username};"
                            f"PWD={self.password};")

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _quote_identifier(name):
        """Return ``name`` as a bracket-quoted T-SQL identifier (``]`` doubled)."""
        return "[" + str(name).replace("]", "]]") + "]"

    @staticmethod
    def _build_drop_index_sql(schema_name, table_name, index_name):
        """Build ``DROP INDEX [index] ON [schema].[table]`` with quoted identifiers."""
        quote = DTA_recommender._quote_identifier
        return f"DROP INDEX {quote(index_name)} ON {quote(schema_name)}.{quote(table_name)}"

    @staticmethod
    def _parse_create_index(query):
        """Extract ``(index_name, table_name, indexed_columns, included_columns)``.

        Names are returned as written in the statement (brackets included);
        sort-direction keywords after a key column are dropped. Returns
        ``("", "", [], [])`` when the statement does not match the expected shape.
        """
        match = DTA_recommender._CREATE_INDEX_RE.search(query)
        if match is None:
            return "", "", [], []
        indexed_columns = [col.split()[0] for col in match.group("cols").split(",") if col.strip()]
        included = match.group("incl")
        included_columns = [col.strip() for col in included.split(",") if col.strip()] if included else []
        return match.group("index").strip(), match.group("table").strip(), indexed_columns, included_columns

    @staticmethod
    def _split_recommendation_statements(lines):
        """Split the lines of a DTA recommendation script into statements.

        A line that consists only of ``go`` is a batch separator and is turned
        into ``;``; the joined text is then split on ``;``. Matching whole lines
        (rather than the substring ``'go\\n'``) keeps lines that merely end in
        "go" intact and also handles a final ``go`` without a trailing newline.
        """
        pieces = (';' if line.strip().lower() == 'go' else line for line in lines)
        return ' '.join(pieces).split(';')

    def _disable_statistics_xml(self, cursor):
        """Best-effort ``SET STATISTICS XML OFF`` after a failed statement."""
        try:
            cursor.execute("SET STATISTICS XML OFF")
        except pyodbc.Error as e:
            print(f"Error disabling STATISTICS XML: {e}")

    # --------------------------------------------------------------- experiment

    def run_dta(self, num_rounds=1, invoke_DTA=True, clear_indexes_start=False, clear_indexes_end=True):
        """Execute the workload in ``num_rounds`` rounds.

        The workload is cut into ``len(queries) // num_rounds`` queries per
        round; any remainder at the end of the list is not executed.

        Args:
            num_rounds: Number of rounds; nothing is executed when it is <= 0.
            invoke_DTA: When true, append each round's queries to the workload
                file and invoke DTA on the rounds in ``invoke_ta_rounds``.
            clear_indexes_start: Drop all non-clustered indexes before round 1.
            clear_indexes_end: Drop all non-clustered indexes after the last round.
        """
        # establish connection to the database
        self.conn = pyodbc.connect(self.conn_string)
        try:
            # clear all non-clustered indexes at the start
            if clear_indexes_start:
                self.remove_all_nonclustered_indexes()

            if num_rounds > 0:
                # reset workload file
                self.workload_file = "workload_tpch1.sql"
                with open(self.workload_file, 'w'):
                    pass

                num_queries_per_round = len(self.queries) // num_rounds

                # iterate over rounds
                counter = 0
                for i in range(num_rounds):
                    print(f"Round {i+1} of {num_rounds}")
                    current_round_queries = self.queries[i*num_queries_per_round:(i+1)*num_queries_per_round]

                    if invoke_DTA:
                        # The file is opened in append mode, so it accumulates
                        # the queries of all rounds seen so far.
                        with open(self.workload_file, 'a+') as file:
                            for query in current_round_queries:
                                # exclude queries with "view" in them, otherwise DTA will throw a syntax error
                                if "view" not in query.lower():
                                    file.write(query)
                                    file.write('\n\n\n')
                                    counter += 1

                        print(f"{counter} queries written to workload file")

                        # invoke DTA if current round is in invoke_ta_rounds
                        if i in self.invoke_ta_rounds:
                            _, recommendation_output_file = self.get_recommendations()
                            if os.path.isfile(recommendation_output_file):
                                self.implement_recommendations(recommendation_output_file)

                    # now execute the workload for current round
                    self.execute_workload(current_round_queries)

                # clear all indexes
                if clear_indexes_end:
                    self.remove_all_nonclustered_indexes()
                # clear out all the recommendations and session files from directory
                for file in os.listdir():
                    if file.startswith(("recommendations", "session_output")):
                        os.remove(file)
        finally:
            # close the connection even if a round raised
            self.conn.close()

    def get_recommendations(self):
        """Run the DTA command-line tool on the current workload file.

        Returns:
            tuple: ``(seconds_elapsed, recommendation_output_file)``. The output
            file may not exist if DTA failed or produced no recommendations.
        """
        session_name = f"session_{uuid.uuid4()}"
        max_memory = 4*1024  # MB
        max_time = 1  # minutes
        recommendation_output_file = f"recommendations_{session_name}.sql"
        session_output_xml_file = f"session_output_{session_name}.xml"
        dta_exe_path = "/mnt/c/Program Files (x86)/Microsoft SQL Server Management Studio 20/Common7/DTA.exe"
        # The original command passed the bare host (no ",port") to -S; keep that.
        dta_host = self.server.split(',')[0]
        # Argument list, no shell: credentials and file names are passed verbatim,
        # whatever characters they contain.
        dta_command = [
            dta_exe_path,
            "-S", dta_host, "-U", self.username, "-P", self.password,
            "-D", self.database, "-d", self.database,
            "-if", self.workload_file, "-s", session_name,
            "-of", recommendation_output_file,
            "-ox", session_output_xml_file,
            "-fa", "NCL_IDX", "-fp", "NONE", "-fk", "CL_IDX",
            "-B", str(max_memory), "-A", str(max_time), "-F",
        ]

        start_time = datetime.datetime.now()
        try:
            subprocess.run(dta_command)
        except OSError as e:
            # e.g. DTA.exe not found; the caller checks for the output file.
            print(f"Error launching DTA: {e}")
        end_time = datetime.datetime.now()
        time_elapsed = (end_time - start_time).total_seconds()

        print(f"DTA recommendation time --> {time_elapsed} seconds.")

        return time_elapsed, recommendation_output_file

    def implement_recommendations(self, recommendation_output_file):
        """Apply the create/drop index statements of a DTA recommendation script.

        Returns:
            The summed elapsed time reported for the index creations, or 0 if
            the file could not be read.
        """
        try:
            with open(recommendation_output_file, 'r', encoding="utf-16") as file:
                query_lines = file.readlines()
        except Exception as e:
            print(f"Error reading recommendations file: {e}")
            return 0

        recommendation_queries = self._split_recommendation_statements(query_lines)

        total_index_creation_cost = 0
        # The first statement of the script is skipped, as in the original code
        # (presumably the leading `use <database>` statement - TODO confirm).
        for query in recommendation_queries[1:]:
            if not query.strip():
                continue
            lowered = query.lower()
            if "create nonclustered index" in lowered:
                total_index_creation_cost += self.create_nonclustered_index(query)
            elif "drop index" in lowered:
                self.drop_nonclustered_index(query=query)

        print(f"Implemented recommendations.")
        print(f"Total index creation time --> {total_index_creation_cost} seconds. Total size of configuration --> {self.get_current_pds_size()} MB")

        return total_index_creation_cost

    # ------------------------------------------------------------ index actions

    def create_nonclustered_index(self, query):
        """Execute a ``CREATE NONCLUSTERED INDEX`` statement.

        Returns:
            The elapsed time parsed from the statement's statistics XML, or 0
            when the statement fails with a ``pyodbc.Error``.
        """
        cursor = self.conn.cursor()
        elapsed_time = 0
        stats_enabled = False
        try:
            cursor.execute("SET STATISTICS XML ON")
            stats_enabled = True
            cursor.execute(query)
            stat_xml = cursor.fetchone()[0]
            cursor.execute("SET STATISTICS XML OFF")
            stats_enabled = False
            self.conn.commit()

            if self.verbose:
                index_name, table_name, indexed_columns, included_columns = self._parse_create_index(query)
                print(f"Created index --> {table_name}.{index_name}, Indexed Columns --> {indexed_columns}, Included Columns --> {included_columns}")

            # get index creation time
            elapsed_time = QueryPlan(stat_xml).elapsed_time

        except pyodbc.Error as e:
            print(f"Error creating index {query}: {e}")
            elapsed_time = 0
        finally:
            if stats_enabled:
                # Do not leave the session-level setting on after a failure.
                self._disable_statistics_xml(cursor)
            cursor.close()

        return elapsed_time

    def drop_nonclustered_index(self, schema_name=None, table_name=None, index_name=None, query=None):
        """Drop an index, given either its names or a ready-made DROP statement.

        When ``query`` is None the statement is built from the three names;
        otherwise ``query`` is executed as-is and the names are parsed from it
        for logging only.
        """
        if query is None:
            query = self._build_drop_index_sql(schema_name, table_name, index_name)
        else:
            # extract the schema, table and index names from the query (for messages only)
            match = self._DROP_INDEX_RE.search(query)
            if match is not None:
                index_name = match.group("index").strip("[]")
                target_parts = [part.strip("[]") for part in match.group("target").split('.')]
                table_name = target_parts[-1]
                schema_name = target_parts[-2] if len(target_parts) > 1 else None

        cursor = self.conn.cursor()
        try:
            cursor.execute(query)
            self.conn.commit()
            if self.verbose:
                print(f"Dropped index --> [{schema_name}].[{table_name}].[{index_name}]")
        except pyodbc.Error as e:
            print(f"Error dropping index [{schema_name}].[{table_name}].[{index_name}]: {e}")
        finally:
            cursor.close()

    def get_nonclustered_indexes(self):
        """Return ``(schema_name, table_name, index_name)`` rows for all non-clustered
        indexes that are neither primary keys nor unique constraints; ``[]`` on error."""
        cursor = self.conn.cursor()
        query = """
                SELECT 
                s.name AS SchemaName,
                t.name AS TableName,
                i.name AS IndexName
                FROM 
                    sys.indexes i
                JOIN 
                    sys.tables t ON i.object_id = t.object_id
                JOIN 
                    sys.schemas s ON t.schema_id = s.schema_id
                WHERE 
                    i.type_desc = 'NONCLUSTERED'  -- Only non-clustered indexes
                    AND i.is_primary_key = 0  -- Exclude primary key indexes
                    AND i.is_unique_constraint = 0  -- Exclude unique constraints
                ORDER BY 
                    s.name, t.name, i.name; 
                """
        try:
            cursor.execute(query)
            indexes = cursor.fetchall()  # return list of tuples: (schema_name, table_name, index_name)
        except pyodbc.Error as e:
            print(f"Error fetching non-clustered indexes: {e}")
            indexes = []
        finally:
            cursor.close()

        return indexes

    def remove_all_nonclustered_indexes(self):
        """Drop every index returned by ``get_nonclustered_indexes``."""
        # get all non-clustered indexes
        indexes = self.get_nonclustered_indexes()
        print(f"All non-clustered indexes --> {indexes}")
        # drop all non-clustered indexes
        for (schema_name, table_name, index_name) in indexes:
            self.drop_nonclustered_index(schema_name=schema_name, table_name=table_name, index_name=index_name)

        if self.verbose:
            print("All nonclustered indexes removed.")

    # get size of all PDS in the database
    def get_current_pds_size(self):
        """Return the used size in MB summed over ``sys.dm_db_partition_stats``
        (8 KB per used page); 0 on error."""
        cursor = self.conn.cursor()
        query = '''SELECT (SUM(s.[used_page_count]) * 8)/1024.0 AS size_mb FROM sys.dm_db_partition_stats AS s'''
        try:
            cursor.execute(query)
            pds_size = cursor.fetchone()[0]
        except pyodbc.Error as e:
            print(f"Error fetching PDS size: {e}")
            pds_size = 0
        finally:
            cursor.close()

        return pds_size

    # ---------------------------------------------------------- query execution

    def execute_simple(self, query):
        """Execute ``query`` and commit; errors are printed, not raised."""
        cursor = self.conn.cursor()
        try:
            cursor.execute(query)
            self.conn.commit()
        except pyodbc.Error as e:
            print(f"Error executing query {query}: {e}")
        finally:
            cursor.close()

    def execute_query(self, query, cost_type='elapsed_time'):
        """Run ``query`` with a cleared buffer cache and return its cost and index usage.

        Args:
            query: SQL text to execute.
            cost_type: ``'elapsed_time'``, ``'cpu_time'`` or ``'sub_tree_cost'``;
                any other value also selects the estimated sub-tree cost.

        Returns:
            tuple: ``(cost, non_clustered_index_usage, clustered_index_usage)``;
            ``(0, [], [])`` when the query fails with a ``pyodbc.Error``.
        """
        cursor = self.conn.cursor()
        stats_enabled = False
        try:
            # clear cache
            cursor.execute("DBCC DROPCLEANBUFFERS")
            # enable statistics collection
            cursor.execute("SET STATISTICS XML ON")
            stats_enabled = True
            # execute the query
            cursor.execute(query)
            cursor.nextset()
            # fetch execution stats
            stat_xml = cursor.fetchone()[0]
            cursor.execute("SET STATISTICS XML OFF")
            stats_enabled = False
            # parse query plan
            query_plan = QueryPlan(stat_xml)
        except pyodbc.Error as e:
            logging.error(f"Error executing query: {query}, Error: {e}")
            return 0, [], []
        finally:
            if stats_enabled:
                # Do not leave the session-level setting on after a failure.
                self._disable_statistics_xml(cursor)
            cursor.close()

        # Determine the cost type; unknown values fall back to the sub-tree cost.
        if cost_type == 'elapsed_time':
            cost = query_plan.elapsed_time
        elif cost_type == 'cpu_time':
            cost = query_plan.cpu_time
        else:
            cost = query_plan.est_statement_sub_tree_cost
        return float(cost), query_plan.non_clustered_index_usage, query_plan.clustered_index_usage

    def execute_workload(self, workload):
        """Execute every query in ``workload`` and return the summed elapsed time."""
        if self.verbose:
            print(f"Executing workload of {len(workload)} queries")
        total_elapsed_time = 0
        # execute the workload
        for query in workload:
            cost, _, _ = self.execute_query(query)
            total_elapsed_time += cost
        print(f"Current round workload execution time --> {total_elapsed_time} seconds.")

        return total_elapsed_time



In [8]:
# Test DTA: split the workload into 10 rounds and invoke DTA on rounds 0, 4 and 8
# (zero-based) before that round's queries are executed. Existing non-clustered
# indexes are neither cleared at the start nor at the end of the run.
dta_recommender = DTA_recommender(queries, [0, 4, 8], verbose=True)
dta_recommender.run_dta(num_rounds=10, invoke_DTA=True, clear_indexes_start=False, clear_indexes_end=False)

Round 1 of 10
11 queries written to workload file
Microsoft (R) SQL Server dta
Version 20.2.30.0
Copyright (c) Microsoft. All rights reserved.

Tuning session successfully created. Session ID is 41.

Time elapsed: 00:00:10            
Workload consumed:  100%, Estimated improvement:    0%                         


In [None]:
# Baseline: drop all non-clustered indexes first, then run the workload as a
# single round without invoking DTA. Query execution times should be higher
# than in the tuned run.
dta_recommender.run_dta(num_rounds=1, clear_indexes_start=True, invoke_DTA=False)

All non-clustered indexes --> []
All nonclustered indexes removed.
Round 1 of 1
Executing workload of 22 queries
Current round workload execution time --> 28.278999999999996
All non-clustered indexes --> []
All nonclustered indexes removed.
