In [4]:
import os
import sqlglot
from sqlglot.expressions import Table
from sql_metadata import Parser

def extract_tables_from_sql(query: str):
    try:
        parser = Parser(query)
        return parser.tables
    except Exception as e:
        print(f"❌ Error parsing query: {e}")
        return []

sql_folder = "/Users/paolocadei/Documents/Masters/Thesis/Spider2/spider2-snow/evaluation_suite/gold/sql"

def load_sql_files_from_folder(folder_path):
    sql_files = {}
    for file_name in os.listdir(folder_path):
        if file_name.endswith(".sql"):
            full_path = os.path.join(folder_path, file_name)
            with open(full_path, "r", encoding="utf-8") as f:
                sql_files[file_name] = f.read()
    return sql_files

# Load SQL files
sql_queries = load_sql_files_from_folder(sql_folder)

# Extract tables
table_map = {
    file_name: extract_tables_from_sql(query)
    for file_name, query in sql_queries.items()
}

# Print the result
for file, tables in table_map.items():
    print(f"{file}: {tables}")


sf_local285.sql: ['BANK_SALES_TRADING.BANK_SALES_TRADING.VEG_TXN_DF', 'BANK_SALES_TRADING.BANK_SALES_TRADING.VEG_WHSLE_DF', 'BANK_SALES_TRADING.BANK_SALES_TRADING.VEG_CAT', 'BANK_SALES_TRADING.BANK_SALES_TRADING.VEG_LOSS_RATE_DF']
sf_bq158.sql: ['PANCANCER_ATLAS_1.PANCANCER_ATLAS_FILTERED.CLINICAL_PANCAN_PATIENT_WITH_FOLLOWUP_FILTERED', 'PANCANCER_ATLAS_1.PANCANCER_ATLAS_FILTERED.MC3_MAF_V5_ONE_PER_TUMOR_SAMPLE']
sf_bq429.sql: ['CENSUS_BUREAU_ACS_2.CENSUS_BUREAU_ACS.ZIP_CODES_2018_5YR', 'CENSUS_BUREAU_ACS_2.CENSUS_BUREAU_ACS.ZIP_CODES_2015_5YR', 'CENSUS_BUREAU_ACS_2.CENSUS_BUREAU_ACS.ZIP_CODES_2017_5YR', 'CENSUS_BUREAU_ACS_2.GEO_US_BOUNDARIES.ZIP_CODES', 'base_census']
sf_bq159.sql: ['PANCANCER_ATLAS_1.PANCANCER_ATLAS_FILTERED.CLINICAL_PANCAN_PATIENT_WITH_FOLLOWUP_FILTERED', 'PANCANCER_ATLAS_1.PANCANCER_ATLAS_FILTERED.MC3_MAF_V5_ONE_PER_TUMOR_SAMPLE']
sf_bq213.sql: ['PATENTS.PATENTS.PUBLICATIONS', 'LATERAL', 'input']
sf_bq167.sql: ['META_KAGGLE.META_KAGGLE.FORUMMESSAGEVOTES', 'META_KAG

In [7]:
import sqlglot
from sqlglot.expressions import Table, Column

def extract_fully_qualified_tables_and_columns(query: str):
    try:
        expression = sqlglot.parse_one(query, read="snowflake")  # or "bigquery", "postgres", etc.

        # Extract fully qualified table names
        tables = set()
        for table in expression.find_all(Table):
            db = table.catalog or ''
            schema = table.db or ''
            name = table.name
            parts = [p for p in [db, schema, name] if p]
            tables.add(".".join(parts))

        # Extract column names
        columns = list({col.name for col in expression.find_all(Column)})

        return {
            "tables": sorted(tables),
            "columns": sorted(columns)
        }

    except Exception as e:
        print(f"❌ Error parsing query: {e}")
        return {"tables": [], "columns": []}


In [8]:
sql_queries = load_sql_files_from_folder(sql_folder)

results = {
    file_name: extract_fully_qualified_tables_and_columns(query)
    for file_name, query in sql_queries.items()
}

# Print sample
for file, info in results.items():
    print(f"{file}:")
    print("  Tables: ", info["tables"])
    print("  Columns:", info["columns"])


sf_local285.sql:
  Tables:  ['BANK_SALES_TRADING.BANK_SALES_TRADING.VEG_CAT', 'BANK_SALES_TRADING.BANK_SALES_TRADING.VEG_LOSS_RATE_DF', 'BANK_SALES_TRADING.BANK_SALES_TRADING.VEG_TXN_DF', 'BANK_SALES_TRADING.BANK_SALES_TRADING.VEG_WHSLE_DF', 'final_item', 'item_2020', 'item_2021', 'item_2022', 'item_2023']
  Columns: ['avg_loss_rate_pct', 'category_code', 'category_name', 'item_code', 'loss_rate_%', 'qty_sold(kg)', 'selling_price', 'txn_date', 'unit_selling_px_rmb/kg', 'whole_sale_price', 'whsle_date', 'whsle_px_rmb-kg']
sf_bq158.sql:
  Tables:  ['PANCANCER_ATLAS_1.PANCANCER_ATLAS_FILTERED.CLINICAL_PANCAN_PATIENT_WITH_FOLLOWUP_FILTERED', 'PANCANCER_ATLAS_1.PANCANCER_ATLAS_FILTERED.MC3_MAF_V5_ONE_PER_TUMOR_SAMPLE', 'percentages', 'summ_table', 'table1', 'table2']
  Columns: ['FILTER', 'Hugo_Symbol', 'Nij', 'ParticipantBarcode', 'Study', 'acronym', 'bcr_patient_barcode', 'data1', 'data2', 'histological_type', 'mutation_percentage', 'symbol']
sf_bq429.sql:
  Tables:  ['CENSUS_BUREAU_ACS_2

In [9]:
import json
with open("gold_col_retrieval.json", "w") as f:
    json.dump(results, f, indent=2)
