In [None]:
import sqlglot
from sqlglot import parse_one, exp
from sqlglot.lineage import lineage
from sqlglot.schema import MappingSchema
from sqlglot.dialects import Snowflake
from sqlglot.optimizer import optimize
from sqlglot.optimizer.scope import traverse_scope
import timeout_decorator



@timeout_decorator.timeout(2)
def optimize_within_time(*positional, **options):
    """Run sqlglot's ``optimize`` with a hard 2-second time limit.

    Every positional and keyword argument is handed straight through to
    ``optimize`` and its result is returned unchanged. If optimization runs
    longer than 2 seconds, the ``timeout_decorator`` wrapper aborts it and
    raises its timeout exception instead.

    NOTE(review): timeout_decorator's default mode is signal-based, which
    presumably restricts it to POSIX systems and the main thread — confirm
    for the environment this notebook runs in.
    """
    optimized_tree = optimize(*positional, **options)
    return optimized_tree
def filter_predicates(sql, verbose=True):
    """Extract every WHERE clause, and the columns it filters on, from a Snowflake query.

    The query is parsed and optimized with sqlglot (via ``optimize_within_time``,
    so a slow optimization is cut off) before its scopes are walked. Each scope's
    table aliases are collected so filter columns can be mapped back to the
    aliased node's name.

    Args:
        sql: Snowflake SQL text to analyse.
        verbose: when True (the default, matching the earlier behaviour), print
            debugging information for every scope, alias, WHERE clause and
            column visited. Pass False to silence it.

    Returns:
        On success, a list with one dict per WHERE clause found in any scope::

            {
                "where_clause": <the clause rendered as Snowflake SQL>,
                "filter_columns": [
                    {"column": <name>, "table_alias": <qualifier>, "table": <name or None>},
                    ...
                ],
            }

        ``"table"`` is None when the column's qualifier is not an alias
        declared in the same scope (for example an unqualified column).

        On failure to parse or optimize (including the timeout), the raised
        exception object is *returned*, not raised. Callers must check
        ``isinstance(result, Exception)``.
    """
    log = print if verbose else (lambda *_args, **_kwargs: None)

    try:
        optimized = optimize_within_time(parse_one(sql, dialect=Snowflake), dialect=Snowflake)
    except Exception as e:
        # Best-effort contract: hand the error back instead of raising.
        return e

    # Compute the scope list once and reuse it for both the count and the walk.
    scopes = traverse_scope(optimized)
    log(f"number of scopes: {len(scopes)}")

    preds = []
    for scope in scopes:
        log(f"scope: {scope}")

        # alias name -> name of the aliased node (Table, CTE, subquery, ...).
        # NOTE(review): for non-Table parents `.name` is presumably empty — confirm.
        aliases = {}
        for alias in scope.find_all(exp.TableAlias):
            aliases[alias.name] = alias.parent.name
            log(f"alias: {alias.name}, parent: {alias.parent.name}")

        for clause in scope.find_all(exp.Where):
            log(f"where: {clause}")
            filter_columns = []
            for col in clause.find_all(exp.Column):
                # The qualifier may be empty or not declared in this scope;
                # a direct aliases[...] lookup would raise KeyError here.
                table = aliases.get(col.table)
                log(f"col: {col.name}, table_alias: {col.table}, table: {table}")
                filter_columns.append({
                    "column": col.name,
                    "table_alias": col.table,
                    "table": table,
                })
            preds.append(
                {
                    "where_clause": clause.sql(dialect=Snowflake),
                    "filter_columns": filter_columns,
                }
            )
    return preds



In [None]:
# Example queries to run through filter_predicates.

# A CTE over a single table with no WHERE clause (kept for experimentation).
simple_cte_sql = """
with a as ( select x,y,z from b) 
select z from a
"""

# A CTE that renames a column, filtered in the outer query through both a
# qualified (`a.col1`) and an unqualified (`id`) reference.
sql = """
with a as (
    select 
       id,
       original as col1, 
       col2
    from 
       t1
)
select
    col2 as col2_renamed
from
  a
where a.col1 = 3902
and id = 123
"""

filter_predicates(sql)

In [None]:
# Step through the same analysis as filter_predicates at top level, so that
# `optimized`, `scopes` and `preds` remain available to the cells below.
try:
    optimized = optimize_within_time(parse_one(sql, dialect=Snowflake), dialect=Snowflake)
except Exception as e:
    # Show the error, then stop: carrying on would use an undefined `optimized`
    # (or a stale one left over from an earlier run of this cell).
    print(e)
    raise

scopes = traverse_scope(optimized)
print(f"number of scopes: {len(scopes)}")
preds = []
for scope in scopes:
    print(f"scope: {scope}")
    # alias name -> name of the aliased node in this scope.
    aliases = {}
    for alias in scope.find_all(exp.TableAlias):
        aliases[alias.name] = alias.parent.name
        print(f"alias: {alias.name}, parent: {alias.parent.name}")
    for clause in scope.find_all(exp.Where):
        print(f"where: {clause}")
        filter_columns = []
        for col in clause.find_all(exp.Column):
            # Qualifier may be empty or not declared in this scope; avoid KeyError.
            table = aliases.get(col.table)
            print(f"col: {col.name}, table_alias: {col.table}, table: {table}")
            filter_columns.append({
                "column": col.name,
                "table_alias": col.table,
                "table": table
            })
        preds.append(
            {
                "where_clause": clause.sql(dialect=Snowflake),
                "filter_columns": filter_columns
            }
        )

preds



In [None]:
# Collect the optimized tree's scopes into a list for interactive inspection.
scopes = list(traverse_scope(optimized))

In [None]:
# Display how many scopes traverse_scope produced for the optimized query
# (presumably the outer query plus one per CTE / subquery — confirm).
len(scopes)