In [1]:
import os
import pandas as pd
import configparser
from sql_metadata import Parser
from datetime import datetime
from mo_sql_parsing import parse
from mo_sql_parsing import format
from itertools import combinations
import logging
from collections import defaultdict
import xml.etree.ElementTree as ET
from typing import List

"""
Parser is not working as expected for 
pivot queries
IGNORE NULLS OVER (oracledb)
* exclude
cross join
"""
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(name)s - Line: %(lineno)d - %(message)s')

config_file_path = r'C:\Users\DOUNDKARSHUBHAMBALU\Downloads\sql_reports_analysis\config.txt'

table_details_excel = r"C:\Users\DOUNDKARSHUBHAMBALU\Downloads\sql_reports_analysis\tables_coe.xlsx"

sheet_names = pd.ExcelFile(table_details_excel).sheet_names

dfs = []

for sheet in sheet_names:
    df = pd.read_excel(table_details_excel, sheet_name=sheet)
    df = df[['Name','Context']]
    dfs.append(df)

lookup = pd.concat(dfs, ignore_index=True)

lookup[['loc','db','schema']] = lookup['Context'].str.split(' >> ', expand=True)

df_column = pd.DataFrame(columns=['Report_Name', 'Db', 'Schema', 'Table', 'Column'])
df_table = pd.DataFrame(columns=['Report_Name', 'Table'])
df_join_condition = pd.DataFrame(columns=['Report_Name', 'Union/Intersection/Minus if any',
                                          'Joining_Type', 'Joining_Condition',
                                          'Joining_Columns_With_Table_Alias_if_any',
                                          'Joining_Columns_Having_Proper_Table_Details',
                                          'LeftTable_Subquery_if_any',
                                          'RightTable_Subquery_if_any'])
 
 
def getconfig():
    """ to grab input and output file location"""
    config = configparser.ConfigParser()
    config.read(config_file_path)
    config_dict = dict(config.items('PATH'))
    return config_dict
 
def extract_with(sql_dict,c=0):
    columns = []

    if 'with' in sql_dict:
        wit = sql_dict.pop('with')

        try:
            formatted_query = format(sql_dict)
            columns += Parser(formatted_query).columns
        except Exception as e:
            columns += extract_with(sql_dict)
            logging.error(f"Error parsing query inside extract with: {formatted_query}",exc_info=True)
        if isinstance(wit,list):
            for subquery in wit:
                subquery_value = subquery['value']
                columns += extract_with(subquery_value)
        else:
            subquery_value = wit['value']
            columns += extract_with(subquery_value)
    else:
        try:
            formatted_query = format(sql_dict)
            columns += Parser(formatted_query).columns
        except Exception as e:
            if 'from' in sql_dict:
                from_ = sql_dict.pop('from')
                sql_dict['from'] = "placeholder"
                for i in sql_dict.get('select', []):
                    if not isinstance(i, dict):
                        continue
                    value = i.get('value', '')
                    if not isinstance(value, dict):
                        continue
                    if 'select' in value:
                        columns += extract_with(i.get('value'))
                        i['value'] = 'placeholder'
                c += 1
                if c > 10:
                    print(sql_dict,"RECURSION"*10)
                    exit()
                columns += extract_with(sql_dict,c) ## recursion loop prone
                c = 0
                if isinstance(from_, list):
                    for item in from_:
                        columns += extract_with(item)
                elif isinstance(from_, dict):
                    columns += extract_with(from_)
                formatted_query = format(sql_dict)
                try:
                    columns += Parser(formatted_query).columns
                except Exception as e:
                    for i in sql_dict.get('select'):
                        if 'select' in i.get('value'):
                            columns += extract_with(i.get('value'))
                            i['value'] = 'placeholder'

            else:
                if isinstance(sql_dict, list):
                    for item in sql_dict:
                        columns += self.extract_with(item)
    return columns
 
 
def table(input_sql, report_name):
    df_table_temp = pd.DataFrame(Parser(input_sql).tables, columns=['Table'])
    df_table_temp['Report_Name'] = report_name
    df_table_temp = df_table_temp[['Report_Name', 'Table']]
    return df_table_temp
 
 
def excel_formatting(dict_compare, output_file_name, default_row=50):
    cfg = getconfig()
    output_file_path = cfg.get('output_path')
    file_path = os.path.join(output_file_path, output_file_name)
    writer = pd.ExcelWriter(file_path, engine='xlsxwriter')
 
    [dict_compare[i].to_excel(writer, sheet_name=i, index=False) for i in dict_compare.keys()]
 
    workbook = writer.book
    sheet_list = list(dict_compare.keys())
 
    [writer.sheets[sheet].set_default_row(default_row) for sheet in sheet_list]
 
    text_format = workbook.add_format({'text_wrap': True,
                                       'valign': 'left',
                                       'align': 'top',
                                       'border': 1
                                       })
 
    cell_format = workbook.add_format({'text_wrap': True,
                                       'valign': 'left',
                                       'align': 'top',
                                       'border': 0
                                       })
 
 
    [writer.sheets[sheet].set_column(0, dict_compare[sheet].shape[1]-1,
                                     30, text_format) for sheet in sheet_list]
 
 
 
    [writer.sheets[sheet].set_row(i,1 , cell_format) for sheet in sheet_list for i in
     range(dict_compare[sheet].shape[0]+1, dict_compare[sheet].shape[0]+1000)]
 
 
    header_format = workbook.add_format(
        {'bold': True,
         'text_wrap': True,
         'valign': 'center',
         'fg_color': '#D7E4BC',
         'border': 1,
         'align': 'top'}
    )
 
 
    [writer.sheets[sheet].write(0, col_num, value, header_format) for sheet in sheet_list
     for col_num, value in enumerate(dict_compare[sheet].columns.values)]
 
    writer.close()
 
def excel_formatting_metadata(df_table, df_column, df_join_condition, output_file_path, report_name, default_row=50):
    input_file_path = getconfig()['input_path']
    output_file_name = 'Metadata_extract' + '_' + datetime.now().strftime('%Y%m%d') + '_' + datetime.now().strftime(
        '%H%M') + '.xlsx'

    file_path = os.path.join(output_file_path, output_file_name)
    print(file_path)
    writer = pd.ExcelWriter(file_path, engine='xlsxwriter')
    df_table.to_excel(writer, sheet_name='Table List', index=False)
    df_column.to_excel(writer, sheet_name='Column_Details', index=False)
    df_join_condition.to_excel(writer, sheet_name='Joining_Condition', index=False)
    dict_row_column = {}
    dict_df = {}
    workbook = writer.book
    sheet_list = ['Table List',  'Column_Details', 'Joining_Condition'] #+ [i for i in dict_compare.keys()]
    dict_row_column['Table List'] = df_table.shape
    dict_row_column['Column_Details'] = df_column.shape
    dict_row_column['Joining_Condition'] = df_join_condition.shape

    dict_df['Table List'] = df_table
    dict_df['Column_Details'] = df_column
    dict_df['Joining_Condition'] = df_join_condition


    [writer.sheets[sheet].set_default_row(default_row) for sheet in sheet_list]

    text_format = workbook.add_format({'text_wrap': True,
                                       'valign': 'left',
                                       'align': 'top',
                                       'border': 1
                                       })
    cell_format = workbook.add_format({'text_wrap': True,
                                       'valign': 'left',
                                       'align': 'top',
                                       'border': 0
                                       })

    [writer.sheets[sheet].set_column(0, dict_row_column[sheet][1]-1,
                                     30, text_format) for sheet in sheet_list]


    [writer.sheets[sheet].set_row(i,1 , cell_format) for sheet in sheet_list for i in
     range(dict_row_column[sheet][0]+1, dict_row_column[sheet][0]+100)]

    header_format = workbook.add_format(
        {'bold': True,
         'text_wrap': True,
         'valign': 'center',
         'fg_color': '#D7E4BC',
         'border': 1,
         'align': 'top'}
    )


    [writer.sheets[sheet].write(0, col_num, value, header_format) for sheet in sheet_list
     for col_num, value in enumerate(dict_df[sheet].columns.values)]

    writer.close()


def extract_join_info(join_info, dict_table, joining_type, joining_condition, joining_columns_ta, 
                      joining_columns, lefttable_subquery, righttable_subquery):
    """Extract join information from a single join."""
    if isinstance(join_info, str):
        dict_table[join_info.lower()] = format(join_info)
        return
    elif 'pivot' in join_info:
        return
    elif 'on' not in join_info and 'using' not in join_info:
        if 'value' in join_info:
            dict_table[join_info['name'].lower()] = format(join_info['value'])
        else:
            cross_join = list(join_info.keys())[0]
            join_info = join_info[cross_join]
        return
    elif isinstance(join_info[list(join_info.keys())[0]], str):
        join_key = list(join_info.keys())[0]
        join_value = join_info[join_key]
        dict_table[join_value.lower()] = format(join_value)
    else:
        join_key = list(join_info.keys())[0]
        join_value = join_info[join_key]
        dict_table[join_value['name'].lower()] = format(join_value['value'])
    joining_type.append(join_key)
    if 'on' in join_info:
        joining_condition.append(list(join_info['on'].keys())[0])
    elif 'using' in join_info:
        joining_condition.append(list(join_info['using'])[0])
    try:
        if 'on' in join_info:
            joining_columns_ta.append(','.join(list(join_info['on'].values())[0]))
            left_part, right_part = list(join_info['on'].values())[0]
        elif 'using' in join_info:
            joining_columns_ta.append(','.join(join_info['using']))
            left_part = join_info['using']
            right_part = join_info['using']

    except TypeError as e:
        def extract_values(data):
            result = []
            if isinstance(data, dict):
                for key, value in data.items():
                    if key == 'literal':
                        continue
                    result.extend(extract_values(value))
            elif isinstance(data, list):
                for item in data:
                    result.extend(extract_values(item))
            elif isinstance(data, str):
                result.append(data)
    
            return result
        sub_value = extract_values(list(join_info['on'].values()))
        joining_columns_ta.append(', '.join(sub_value))  # Join the strings
        left_part, right_part = sub_value[:2]
    left_part, right_part = left_part.lower(), right_part.lower()
    if '.' in left_part:
        left_part_tbl, left_part_clm = left_part.split('.')
        lefttable_subquery.append(dict_table.get(left_part_tbl, '')[:6].lower() == 'select' and dict_table[left_part_tbl] or '')
    else:
        left_part_tbl, left_part_clm = '', left_part
        lefttable_subquery.append('')

    if '.' in right_part:
        right_part_tbl, right_part_clm = right_part.split('.')
        righttable_subquery.append(dict_table.get(right_part_tbl, '')[:6].lower() == 'select' and dict_table[right_part_tbl] or '')
        
    else:
        right_part_tbl, right_part_clm = '', right_part
        righttable_subquery.append('')
    if left_part_tbl == '' and right_part_tbl == '':
        joining_columns.append(
            f"{left_part_clm},{right_part_clm}")
    elif left_part_tbl == '':
        joining_columns.append(
            f"{left_part_clm},{right_part_tbl if dict_table[right_part_tbl].lower().startswith('select') else dict_table[right_part_tbl]}.{right_part_clm}"
        )
    elif right_part_tbl == '':
        joining_columns.append(
            f"{left_part_tbl if dict_table[left_part_tbl].lower().startswith('select') else dict_table[left_part_tbl]}.{left_part_clm},{right_part_clm}"
        )
    else:
        joining_columns.append(
            f"{left_part_tbl if dict_table[left_part_tbl].lower().startswith('select') else dict_table[left_part_tbl]}.{left_part_clm},"
            f"{right_part_tbl if dict_table[right_part_tbl].lower().startswith('select') else dict_table[right_part_tbl]}.{right_part_clm}"
        )
    # print(len(joining_type), len(joining_condition), len(joining_columns_ta), len(joining_columns), len(lefttable_subquery), len(righttable_subquery))


    

    

def join_condition(input_sql, report_name):
    """
    Processes SQL input and extracts join conditions into a DataFrame.

    :param input_sql: The SQL query to analyze.
    :param report_name: The name of the report for identification.
    :return: A DataFrame containing join conditions.
    """
    parsed_sql = parse(input_sql)
    dict_table = {}
    joining_info = {
        "type": [], "condition": [], "columns_ta": [],
        "columns": [], "left_subquery": [], "right_subquery": []
    }

    # Initialize joining_columns
    joining_columns = []

    if parsed_sql.get('from'):
        for join in parsed_sql['from']:
            extract_join_info(join, dict_table, joining_info["type"], joining_info["condition"],
                              joining_info["columns_ta"], joining_columns,
                              joining_info["left_subquery"], joining_info["right_subquery"])

    else:
        # Will do if any case arises like this
        # if list(n.keys()) == ['select']:
        #     logging.info(f'Skipped following query in join parsing {n}')
        #     return pd.DataFrame(columns=['Report_Name', 'Union/Intersection/Minus if any', 'Joining_Type',
        #                        'Joining_Condition', 'Joining_Columns_With_Table_Alias_if_any',
        #                        'Joining_Columns_Having_Proper_Table_Details', 
        #                        'LeftTable_Subquery_if_any', 'RightTable_Subquery_if_any'])
        try:
            for m in parsed_sql.keys():
                for index, n in enumerate(parsed_sql[m]):
                    if n.get('from'):
                        for p in n['from']:
                            extract_join_info(p, dict_table, joining_info["type"], joining_info["condition"],
                                            joining_info["columns_ta"], joining_columns,
                                            joining_info["left_subquery"], joining_info["right_subquery"])
                            joining_info["columns"].append(f"{m}_{index}")
                    else:
                        if list(n.keys()) == ['select']:
                            logging.info(f'Skipped following query in join parsing {n}')
                            continue
                        if 'from' not in n.get('value'):
                            op = list(n.get('value').keys())[0]
                            val = n.get('value').get(op)
                            for index, f in enumerate(val):
                                for p in f.get('from'):
                                    extract_join_info(p, dict_table, joining_info["type"], joining_info["condition"],
                                                    joining_info["columns_ta"], joining_columns,
                                                    joining_info["left_subquery"], joining_info["right_subquery"])
                                    joining_info["columns"].append(f"{m}_{index}")
                        else:
                            for p in n.get('value').get('from'):
                                extract_join_info(p, dict_table, joining_info["type"], joining_info["condition"],
                                                joining_info["columns_ta"], joining_columns,
                                                joining_info["left_subquery"], joining_info["right_subquery"])
                                joining_info["columns"].append(f"{m}_{index}")

        except KeyError as e:
            logging.error(e,"is not found")
        except Exception as e:
            logging.error("Error processing secondary join conditions", exc_info=True)
    df_join_condition = pd.DataFrame({
        'Joining_Type': joining_info["type"],
        'Joining_Condition': joining_info["condition"],
        'Joining_Columns_With_Table_Alias_if_any': joining_info["columns_ta"],
        'Joining_Columns_Having_Proper_Table_Details': joining_columns,
        'LeftTable_Subquery_if_any': joining_info["left_subquery"],
        'RightTable_Subquery_if_any': joining_info["right_subquery"],
    })
    
    df_join_condition['Report_Name'] = report_name
    df_join_condition['Union/Intersection/Minus if any'] = 'NA'
    df_join_condition = df_join_condition.drop_duplicates()
        
    return df_join_condition[['Report_Name', 'Union/Intersection/Minus if any', 'Joining_Type',
                               'Joining_Condition', 'Joining_Columns_With_Table_Alias_if_any',
                               'Joining_Columns_Having_Proper_Table_Details', 
                               'LeftTable_Subquery_if_any', 'RightTable_Subquery_if_any']]
 

def calculate_matching_percentages(common_tables, all_tables, common_joins, all_joins):
    table_percent_match = (len(common_tables) / len(all_tables)) * 100 if all_tables else 0
    join_percent_match = (len(common_joins) / len(all_joins)) * 100 if all_joins else 0
    overall_matching_percentage = table_percent_match if len(all_joins) == 0 else (table_percent_match + join_percent_match) / 2
    return table_percent_match, join_percent_match, overall_matching_percentage



def analyze_reports(overall_df, report_files):
    overall_df['Report Combinations'] = overall_df['Report Combinations'].apply(lambda x: x.split(','))

    # Extract the matching files from the DataFrame as a list of lists
    matching_files = overall_df['Report Combinations'].tolist()
    # Mapping of files to reports
    file_to_report = {}
    for report, files in report_files.items():
        for file in files:
            file_to_report[file] = report
    
    # Analyze full report merging
    report_merge_candidates = defaultdict(set)
    file_merge_map = defaultdict(set)

    for match_group in matching_files:
        # For each group of matching files, associate all reports
        reports_in_group = set(file_to_report[file] for file in match_group)
        for file in match_group:
            report = file_to_report[file]
            report_merge_candidates[report].update(reports_in_group - {report})
            file_merge_map[report].update(match_group)
    
    fully_mergeable_reports = set()
    partially_mergeable_reports = defaultdict(list)
    within_report_consolidations = defaultdict(list)

    for report, files in report_files.items():
        matched_files = set(file_merge_map[report])
        if matched_files == set(files):
            # If all files in the report are matched, the report can be fully merged
            fully_mergeable_reports.add(report)
        else:
            # For partial merging, calculate how many files can be merged with others
            partially_merged_files = matched_files.intersection(files)
            if partially_merged_files:
                partially_mergeable_reports[report] = {
                    "total_files": len(files),
                    "matched_files": len(partially_merged_files),
                    "reports_matched_with": report_merge_candidates[report]
                }
        # Intra-report file consolidation analysis
        within_report_matches = [file for file in files if any(other_file in file_merge_map[report] and other_file != file for other_file in files)]
        if within_report_matches:
            within_report_consolidations[report] = within_report_matches
    
    # Output
    print("Reports that can be fully merged and removed:")
    print(fully_mergeable_reports)

    print("\nReports that can be partially merged:")
    for report, merge_info in partially_mergeable_reports.items():
        print(f"Report: {report}")
        print(f"  Total files: {merge_info['total_files']}")
        print(f"  Matched files: {merge_info['matched_files']}")
        print(f"  Reports matched with: {merge_info['reports_matched_with']}")

    print("\nWithin-report file consolidations (files that can be consolidated):")
    for report, consolidations in within_report_consolidations.items():
        print(f"Report: {report}")
        print(f"  Files that can be consolidated: {consolidations}")


def compare(list1, df_table, df_joining):
    dict_compare = {}
    not_recommended_combinations = set()
    for i in range(2, len(list1) + 1): # loop 1
        sheet_name = f'Comparison_{i}'
        df_temp = pd.DataFrame(columns=[
            'Report Combinations', '# of tables across all reports', 
            '# of joining keys across all reports', '# of common tables',
            '# of common joining keys', 'Table Matching Percentage',
            'Joining Keys Matching Percentage', 'Overall Matching Percentage', 
            'Recommendation'
        ])

        break_loop = True
        for combo in combinations(list1, i): # loop 2
            combo_set = frozenset(combo)  
            
            if any(nrc.issubset(combo_set) for nrc in not_recommended_combinations):
                recommendation = 'Reports are having different sets of data, should not get merged'
                df1 = pd.DataFrame({
                    'Report Combinations': ', '.join(combo),
                    '# of tables across all reports': 'N/A',
                    '# of joining keys across all reports': 'N/A',
                    '# of common tables': 'N/A',
                    '# of common joining keys': 'N/A',
                    'Table Matching Percentage': 'N/A',
                    'Joining Keys Matching Percentage': 'N/A',
                    'Overall Matching Percentage': 'N/A',
                    'Recommendation': recommendation
                }, index=[0])
                # df_temp = pd.concat([df_temp, df1], ignore_index=True)
                continue 
            all_tables = set(df_table[df_table['Report_Name'].isin(combo)]['Table'])
            all_joins = set(df_joining[df_joining['Report_Name'].isin(combo)]['Joining_Columns_Having_Proper_Table_Details'])
            
            dict_set = {report: set(df_table[df_table['Report_Name'] == report]['Table']) for report in combo}
            dict_join = {report: set(df_joining[df_joining['Report_Name'] == report]['Joining_Columns_Having_Proper_Table_Details']) for report in combo}

            common_tables = set.intersection(*(dict_set.values()))
            common_joins = set.intersection(*(dict_join.values()))

            table_percent_match, join_percent_match, overall_matching_percentage = calculate_matching_percentages(
                common_tables, all_tables, common_joins, all_joins
            )
            
            recommendation = 'Reports can be merged' if overall_matching_percentage > 80 else 'Reports are having different sets of data, should not get merged'

            if overall_matching_percentage < 80:
                not_recommended_combinations.add(combo_set)
            else:
                break_loop = False

            df1 = pd.DataFrame({
                'Report Combinations': ','.join(combo),
                '# of tables across all reports': len(all_tables),
                '# of joining keys across all reports': len(all_joins),
                '# of common tables': len(common_tables),
                '# of common joining keys': len(common_joins),
                'Table Matching Percentage': round(table_percent_match, 2),
                'Joining Keys Matching Percentage': round(join_percent_match, 2),
                'Overall Matching Percentage': round(overall_matching_percentage, 2),
                'Recommendation': recommendation
            }, index=[0])
            if df_temp.empty:
                df_temp = df1
            else:
                df_temp = pd.concat([df_temp, df1], ignore_index=True)

        if break_loop and sheet_name != 'Comparison_2':
            break
        dict_compare[sheet_name] = df_temp
    return dict_compare

def column_details(input_sql, report_name):
    try:
        cols = Parser(input_sql).columns
    except Exception as e:
        try:
            parsed = parse(input_sql)
            cols = extract_with(parsed)
        except Exception as ex:
            logging.error(f"Error parsing query in column details: {input_sql}",exc_info=True)
            tabs = Parser(input_sql).tables
            cols = [i+'.' for i in tabs]
    list1 = [i.split('.') for i in cols]
    if list1 == []:
        return pd.DataFrame(columns=['Report_Name', 'Db', 'Schema', 'Table', 'Column'])
    for i in list1:
        if len(i) == 3:
            i.insert(0, '')
    list3 = [['', ''] + i if len(i) == 2 else  ['', '', ''] + i  if len(i)==1 else i for i in list1]
    #list3 = [['', '', ''] + i if len(i) == 1 else i for i in list1]
    if len(Parser(input_sql).tables)>0:
        table_var = Parser(input_sql).tables[0].split('.')
    else:
        table_var = []
    if len(table_var) == 3:
        db, schema, table = table_var
    elif len(table_var) == 2:
        db = ''
        schema, table = table_var
    elif len(table_var) == 1:
        db = ''
        schema = ''
        table = table_var[0]
    else:
        db, schema, table = '','',''

    if (len(Parser(input_sql).tables) == 1) & (max([len(i) for i in list3]) == 1):
        list4 = [[db, schema, table] + i for i in list3]
        df_column_temp = (
            pd.DataFrame(list4)
            .dropna()
            .apply(lambda col: col.map(lambda x: x.lower() if isinstance(x, str) else x))
            .drop_duplicates()
        )
        df_column_temp = df_column_temp.iloc[:, -4:]
        df_column_temp.columns = ['Db', 'Schema', 'Table', 'Column']
        df_column_temp['Report_Name'] = report_name
        df_column_temp = df_column_temp[['Report_Name', 'Db', 'Schema', 'Table', 'Column']]
    else:
        df_column_temp = (
            pd.DataFrame(list3)
            .dropna()
            .apply(lambda col: col.map(lambda x: x.lower() if isinstance(x, str) else x))
            .drop_duplicates()
        )
        df_column_temp = df_column_temp.iloc[:, -4:]
        df_column_temp.columns = ['Db', 'Schema', 'Table', 'Column']
        df_column_temp['Report_Name'] = report_name
        df_column_temp = df_column_temp[['Report_Name', 'Db', 'Schema', 'Table', 'Column']]
    
    # Identify rows where 'Table' is not empty
    # non_empty_table = df_column_temp[df_column_temp['Table'] != '']

    # # Identify rows where 'Table' is empty and there exists a matching (Report_Name, Column)
    # remove_rows = df_column_temp[
    #     (df_column_temp['Table'] == '') &
    #     df_column_temp[['Report_Name', 'Column']].apply(tuple, axis=1).isin(
    #         non_empty_table[['Report_Name', 'Column']].apply(tuple, axis=1)
    #     )
    # ]
    # # Remove the identified rows
    # df_column_temp = df_column_temp.drop(remove_rows.index)
    return df_column_temp

def remove_subset_files(df):
    df['set_values'] = df['Report Combinations'].apply(lambda x: set(x.split(',')))

    # Remove duplicates and subsets
    df = df.drop_duplicates(subset='set_values', keep='first')  # Remove exact duplicates
    df_cleaned = df[~df['set_values'].apply(lambda x: any(x < y for y in df['set_values']))]

    # Output the cleaned DataFrame
    df_cleaned = df_cleaned.drop(columns=['set_values']).reset_index(drop=True)

    return df_cleaned


def create_sql_files(twb_path: str, output_dir: str = "sql_for_comparison") -> None:
    """
    Extracts SQL statements from Tableau workbook XML files and saves them in individual SQL files.

    Args:
        path (List[str]): List of file paths to Tableau workbook XML files.
        output_dir (str): Directory where SQL files will be saved. Defaults to "sql_for_comparison".

    """
    # Ensure the output directory exists
    os.makedirs(output_dir, exist_ok=True)
    twb_files = os.listdir(twb_path)
    for twb_file in twb_files:
        full_path = os.path.join(twb_path, twb_file)
        tree = ET.parse(full_path) 
        root = tree.getroot()
        sql = root.findall(".//relation")
        unique_sql = {elem.text.strip() for elem in sql if elem.text}
        
        for index, datasource in enumerate(unique_sql):
            if datasource:
                datasource_text = datasource.strip().replace('//', '--')\
                    .replace('<<', '<').replace('>>', '>').replace('::', ':')
                filename = f"{twb_file.removesuffix('.twb')}-sql{index}.sql"
                filepath = os.path.join(output_dir, filename)
                
                # Write `datasource_text` to the file
                with open(filepath, 'w') as file:
                    file.write(datasource_text)
    return output_dir

def remove_sql_files(output_dir: str) -> None:
    """
    Removes all SQL files in the specified directory.

    Args:
        output_dir (str): Directory containing SQL files to remove.

    """
    if os.path.exists(output_dir):
        for file in os.listdir(output_dir):
            os.remove(os.path.join(output_dir, file))



if __name__ == "__main__":
    sql_for_comparison = 'sql_for_comparison'
    remove_sql_files(sql_for_comparison)
    input_file_path = getconfig()['input_path']
    sql_file_path = create_sql_files(input_file_path, sql_for_comparison)
    files = os.listdir(sql_file_path) 
    report = {}

    for i in files:
        before_dash = i.split('-')[0].strip()
        if before_dash not in report:
            report[before_dash] = []
        report[before_dash].append(i)
        print(i)
        report_name = i
        with open(os.path.join(sql_file_path, i), 'r') as file:
            input_sql = file.read()
        # input_sql = input_sql.replace('"', '')
        input_sql = input_sql.replace('[', ' ').replace(']',' ') 
        df_column_temp = column_details(input_sql, report_name)
        input_sql = input_sql.replace('[', '')
        input_sql = input_sql.replace(']', '')
        df_table_temp = table(input_sql, report_name)
        df_join_condition_temp = join_condition(input_sql, report_name)
        df_table = pd.concat([df_table, df_table_temp])
        df_join_condition = pd.concat([df_join_condition, df_join_condition_temp])
        df_column = pd.concat([df_column,df_column_temp])

    
    # only when we have completed parsed output
    # df_column = pd.read_excel('SQL Report Analysis-Consolidation with deletion 23-10.xlsx')

    
    df_column['Table'] = df_column['Table'].fillna('').str.lower()
    lookup['table'] = lookup['Name'].fillna('').str.lower()
    df_column['tab'] = df_column['Table'].str.split('.').str[-1]
    merged_df = df_column.merge(lookup, left_on='tab', right_on='table', how='left', suffixes=('', '_lookup'))
    merged_df['Db'] = merged_df['Db'].replace('',pd.NA).fillna(merged_df['db']).str.lower()
    merged_df['Db'] = merged_df['Db'].str.lower()
    merged_df.drop(columns=['Name', 'Context','tab','table', 'loc', 'db', 'schema'], inplace=True)
    merged_df.drop_duplicates(inplace=True)
    excel_formatting_metadata(df_table, merged_df, df_join_condition, '', 'Metadata_extract')
    db_dict = {}
    for index, row in merged_df.iterrows():
        sql_file = row.iloc[0]
        database = row.iloc[1]
        table_ = row.iloc[3] 
        
        if pd.isna(database) or database == '':
            continue
        
        if database not in db_dict:
            db_dict[database] = {}
        
        if table_ not in db_dict[database]:
            db_dict[database][table_] = set()
        
        db_dict[database][table_].add(sql_file)

    df_temp = pd.DataFrame(columns=[
            'Report Combinations', '# of tables across all reports', 
            '# of joining keys across all reports', '# of common tables',
            '# of common joining keys', 'Table Matching Percentage',
            'Joining Keys Matching Percentage', 'Overall Matching Percentage', 
            'Recommendation'
        ])
    
    for db in db_dict:
        for table_ in db_dict[db]:
            output_file_name = f"{db}_{table_}" + '_' + datetime.now().strftime('%Y%m%d') + '_' + datetime.now().strftime('%H%M') + '.xlsx'
            combos = list(db_dict[db][table_])
            if 20 > len(combos) > 1:
                dict_compare = compare(combos, df_table, df_join_condition)
                for sheet_name in dict_compare.keys():
                    sheet = dict_compare[sheet_name]
                    df1 = sheet.loc[sheet['Overall Matching Percentage']>=80,:]
                    if df_temp.empty:
                        df_temp = df1
                    else:
                        df_temp = pd.concat([df_temp, df1], ignore_index=True)
                # excel_formatting(dict_compare, output_file_name)
                
            else:
                # print(f"{db}_{table_}")
                pass
    di = {}
    if not df_temp.empty:
        df_temp.drop_duplicates(inplace=True)     
        df_temp = remove_subset_files(df_temp)
        df_temp['Total files'] = df_temp['Report Combinations'].apply(lambda x: len(x.split(',')))
        col = df_temp.pop('Total files')  # Remove the column
        df_temp.insert(1, 'Total files', col)
        di['Overall'] = df_temp
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        output_file_name = f"Overall_{timestamp}.xlsx"
        excel_formatting(di, output_file_name)
        print("Total number of files :",len(files))
        print(df_temp)
        print(f"{df_temp['Total files'].sum()} files can be consolidated to {df_temp.shape[0]}")
        # Run the analysis
        analyze_reports(df_temp, report)
    else:
        print("No Reports can be consolidated")
    


ADS Channel Summary Weekly-sql0.sql
ADS Channel Summary-sql0.sql
ADS vs ESF vs SC Dashboard-sql0.sql
Metadata_extract_20241223_1749.xlsx
Total number of files : 3
                                 Report Combinations  Total files  \
0  ADS Channel Summary Weekly-sql0.sql,ADS Channe...            2   

   # of tables across all reports  # of joining keys across all reports  \
0                              18                                     1   

   # of common tables  # of common joining keys  Table Matching Percentage  \
0                  12                         1                      66.67   

   Joining Keys Matching Percentage  Overall Matching Percentage  \
0                             100.0                        83.33   

          Recommendation  
0  Reports can be merged  
2 files can be consolidated to 1
Reports that can be fully merged and removed:
set()

Reports that can be partially merged:
Report: ADS Channel Summary Weekly
  Total files: 1
  Matched files: 1
  Re

In [20]:
!pip install xlsxwriter

Collecting xlsxwriter
  Downloading XlsxWriter-3.2.0-py3-none-any.whl.metadata (2.6 kB)
Downloading XlsxWriter-3.2.0-py3-none-any.whl (159 kB)
Installing collected packages: xlsxwriter
Successfully installed xlsxwriter-3.2.0


In [7]:
!pip install sql_metadata

Collecting sql_metadata
  Downloading sql_metadata-2.15.0-py3-none-any.whl.metadata (9.8 kB)
Collecting sqlparse<0.6.0,>=0.4.1 (from sql_metadata)
  Downloading sqlparse-0.5.3-py3-none-any.whl.metadata (3.9 kB)
Downloading sql_metadata-2.15.0-py3-none-any.whl (22 kB)
Downloading sqlparse-0.5.3-py3-none-any.whl (44 kB)
Installing collected packages: sqlparse, sql_metadata
Successfully installed sql_metadata-2.15.0 sqlparse-0.5.3


In [5]:
!pip install mo_sql_parsing

Collecting mo_sql_parsing
  Downloading mo_sql_parsing-11.658.24326-py3-none-any.whl.metadata (10 kB)
Collecting mo-dots==10.647.24166 (from mo_sql_parsing)
  Downloading mo_dots-10.647.24166-py3-none-any.whl.metadata (4.9 kB)
Collecting mo-future==7.584.24095 (from mo_sql_parsing)
  Downloading mo_future-7.584.24095-py3-none-any.whl.metadata (2.0 kB)
Collecting mo-imports==7.584.24095 (from mo_sql_parsing)
  Downloading mo_imports-7.584.24095-py3-none-any.whl.metadata (6.1 kB)
Collecting mo-parsing==8.654.24251 (from mo_sql_parsing)
  Downloading mo_parsing-8.654.24251-py3-none-any.whl.metadata (7.8 kB)
Downloading mo_sql_parsing-11.658.24326-py3-none-any.whl (44 kB)
Downloading mo_dots-10.647.24166-py3-none-any.whl (27 kB)
Downloading mo_future-7.584.24095-py3-none-any.whl (10 kB)
Downloading mo_imports-7.584.24095-py3-none-any.whl (11 kB)
Downloading mo_parsing-8.654.24251-py3-none-any.whl (62 kB)
Installing collected packages: mo-future, mo-imports, mo-dots, mo-parsing, mo_sql_pars