In [None]:
%load_ext autoreload
%autoreload 2
import sys, pathlib
sys.path.append(str(pathlib.Path.cwd().parent.parent))

In [None]:
%reload_ext dotenv

%dotenv ../../env/.env

import warnings
from pandas import Timedelta
# from optiml.utils import sf
import time
from optiml.utils.sf import logger, sql_to_df, run_sql, conn, session
import pandas as pd
warnings.filterwarnings('ignore')

try:
    %load_ext autotime
except:
    !pip install ipython-autotime
    %load_ext autotime

import plotly
plotly.offline.init_notebook_mode()

In [None]:
# Open a Snowpark session. This rebinds the `session` name imported from
# optiml.utils.sf in the previous cell; `session.create_dataframe` further down uses
# this object.
from optiml.utils.sf import snowsession
session = snowsession()

In [None]:
# tables = [
#     'query_pattern',
#     'pruning_candidate_patterns'
# ]

# drops = [f"drop table if exists {t}" for t in tables]
# for d in drops:
#     run_sql(d)

In [None]:
prep = [
    r"""
create or replace function js_regexp_replace(subject text, pattern text, replacement text)
returns string
language javascript
as
$$
    const p = SUBJECT;
    let regex = new RegExp(PATTERN, 'g')
    return p.replace(regex, REPLACEMENT);
$$;
""",
r"""
set db_unqualify_regex = '\\b(\\w+\\.\\w+\\.)';
""",
r"""
set schema_unqualify_regex = '\\b(\\w+\\.)';
""",
r"""
set redaction_regex = '(\\b\\d+\\b)|(''[^'']*'')';
""",
r"""
create table if not exists query_pattern as

select

    query_id,
    query_text,
    start_time,
     js_regexp_replace(query_text, $$(/\*(.|\n|\r)*?\*/)|(--.*$)|(--.*(\n|\r))$$, '') query_no_comment,
rtrim(regexp_replace(trim(lower(query_no_comment)), '\\s+',' '), ';') as query_text_lowered_no_whitespace,
regexp_replace(query_text_lowered_no_whitespace, $redaction_regex, '\'[literal]\'') as query_text_literals_redacted,
regexp_replace(
    regexp_replace(query_text_literals_redacted, $db_unqualify_regex, ''),
    $schema_unqualify_regex,
    ''
) as query_text_sanitized,
hash(query_text) as query_text_hash,
hash(query_text_sanitized) as query_pattern_hash
from stg_query_history
where start_time > dateadd(day, -14, (select max(start_time) from stg_query_history));;
""",
r"""
create table if not exists pruning_candidate_patterns as
with pruning_candidates as (
    select
    query_id
        -- query_text_sanitized_no_qualification
    from query_history_narrowed
    where filtered_select and no_pruning 
)
select  
    query_pattern_hash,
    query_text_sanitized,
    any_value(query_text) as example
from query_pattern qp
where qp.query_id in (select query_id from pruning_candidates)
group by 1,2;
    """,
r"""
select
count(*), 
count(distinct query_text_hash),
count(distinct query_pattern_hash)
from 
query_pattern;
"""
]

In [None]:
# Execute each prep statement in order, echoing the SQL before its result.
# (Dropped the dead `sql = s` alias and its commented-out escaping experiment: the
# statements are raw strings and are sent verbatim.)
for sql in prep:
    print(f"running sql: {sql}")
    display(sql_to_df(sql))

In [None]:
# Pull the candidate patterns (one row per pattern hash / sanitized text) into pandas.
df = sql_to_df("select * from pruning_candidate_patterns;")
print(df.shape[0])

In [None]:
# Installs the timeout helper used by `optimize_within_time` below.
# NOTE(review): move this to the setup cell (or a requirements file) and pin a
# version so a fresh kernel is reproducible.
%pip install timeout-decorator

In [None]:
import sqlglot
from sqlglot import parse_one, exp
from sqlglot.lineage import lineage
from sqlglot.schema import MappingSchema
from sqlglot.dialects import Snowflake
from sqlglot.optimizer import optimize
from sqlglot.optimizer.scope import traverse_scope
import timeout_decorator



@timeout_decorator.timeout(2)
def optimize_within_time(*pos_args, **kw_args):
    """Run sqlglot's ``optimize`` but abandon it after 2 seconds.

    Every argument is forwarded unchanged to ``optimize``. When the limit is hit the
    decorator raises its timeout exception (by default
    ``timeout_decorator.TimeoutError``) instead of returning.

    NOTE: timeout_decorator's default signal-based mode only works in a process's
    main thread.
    """
    optimized_tree = optimize(*pos_args, **kw_args)
    return optimized_tree
def filter_predicates(sql, debug=False):
    """Extract WHERE clauses, and the columns they reference, from a SQL string.

    The statement is parsed with sqlglot's Snowflake dialect and optimized via
    ``optimize_within_time`` (time-limited), then each scope is scanned for WHERE
    clauses.

    Args:
        sql: SQL text of a single Snowflake statement.
        debug: when True, print the scopes, aliases, WHERE clauses and columns as
            they are visited.

    Returns:
        A list with one dict per WHERE clause found::

            {"where_clause": <clause rendered as Snowflake SQL>,
             "filter_columns": [{"column": ..., "table_alias": ..., "table": ...}]}

        ``"table"`` is None when the column's qualifier is not an alias declared in
        that scope. If parsing or optimizing raises (including the timeout), the
        exception object is *returned*, not raised, so a map over many queries keeps
        going; later cells stringify it.
    """
    try:
        optimized = optimize_within_time(parse_one(sql, dialect=Snowflake), dialect=Snowflake)
    except Exception as e:
        return e

    # Traverse once; debug mode previously walked the tree a second time just to count.
    scopes = traverse_scope(optimized)
    if debug:
        print(f"number of scopes: {len(scopes)}")

    preds = []
    for scope in scopes:
        if debug:
            print(f"scope: {scope}")
        # alias name -> name of the aliased object, for aliases found in this scope
        aliases = {}
        for alias in scope.find_all(exp.TableAlias):
            aliases[alias.name] = alias.parent.name
            if debug:
                print(f"alias: {alias.name}, parent: {alias.parent.name}")
        for clause in scope.find_all(exp.Where):
            if debug:
                print(f"where: {clause}")
            filter_columns = []
            # Expression.find_all also descends into subqueries inside the WHERE, so a
            # column's qualifier may not be an alias of this scope: use .get() for both
            # the record and the debug print (indexing here used to raise KeyError).
            for col in clause.find_all(exp.Column):
                table = aliases.get(col.table)
                if debug:
                    print(f"col: {col.name}, table_alias: {col.table}, table: {table}")
                filter_columns.append({
                    "column": col.name,
                    "table_alias": col.table,
                    "table": table,
                })
            preds.append({
                "where_clause": clause.sql(dialect=Snowflake),
                "filter_columns": filter_columns,
            })
    return preds



In [None]:
# Smoke test: a WHERE clause whose right-hand side is a scalar subquery, with debug
# tracing switched on.
sql = """
SELECT 
  *
from a
where col = (select 1)
"""
filter_predicates(sql=sql, debug=True)

In [None]:
import time


@timeout_decorator.timeout(0.1)
def dosomething():
    """Sleep slightly longer than the 0.1 s limit so the timeout always fires."""
    time.sleep(.11)


# Demonstrate the timeout without leaving an uncaught exception behind, so
# "Restart & Run All" continues past this cell.
try:
    dosomething()
except timeout_decorator.TimeoutError as err:
    print(f"timed out as expected: {err}")

In [None]:
from pandarallel import pandarallel

# Enable `Series.parallel_map` (used below) with 8 workers and a progress bar.
pandarallel.initialize(nb_workers=8, progress_bar=True)


In [None]:

# Run sqlglot over every example query on the pandarallel workers. Each cell holds
# either a list of predicate dicts or the exception filter_predicates returned.
df['filter_predicates_dict'] = df.example.parallel_map(filter_predicates)
df.head()

In [None]:
# df2[df2.filter_predicates_dict.map(str) == "'Timed Out'"].example.iloc[0]

In [None]:
# Number of example queries whose sqlglot optimization hit the time limit. Checking the
# returned exception's type is sturdier than comparing its repr to "'Timed Out'".
int(df.filter_predicates_dict.map(lambda result: isinstance(result, timeout_decorator.TimeoutError)).sum())

In [None]:
import snowflake

import snowflake.snowpark.functions as F
import snowflake.snowpark.dataframe


In [None]:
import json

def json_string(val):
    """Serialise ``val`` as JSON text, or fall back to ``str(val)``.

    The fallback covers values ``json.dumps`` rejects, such as the exception
    objects that ``filter_predicates`` returns on failure.
    """
    try:
        text = json.dumps(val)
    except Exception:
        text = str(val)
    return text

# Text form suitable for upload: JSON for predicate lists, str() for returned errors.
df["filter_predicates"] = df["filter_predicates_dict"].map(json_string)


In [None]:
# Upload everything except the raw Python objects (lists / exception instances); the
# JSON-text column `filter_predicates` carries the same information.
snowdf = session.create_dataframe(df.drop(columns=['filter_predicates_dict']))


def uppercase_all_columns(
    df: snowflake.snowpark.dataframe.DataFrame,
) -> snowflake.snowpark.dataframe.DataFrame:
    """Return ``df`` with every column aliased to its upper-cased name.

    Snowflake resolves unquoted identifiers as upper case, so this lets the later SQL
    refer to columns such as `filter_predicates` without quoting, whatever case the
    pandas-derived column names arrived in.

    (Annotations previously named the module `snowflake.snowpark.dataframe` rather
    than its `DataFrame` class.)
    """
    return df.select([F.col(column).as_(column.upper()) for column in df.columns])


uppercase_all_columns(snowdf).write.mode("overwrite").save_as_table("filter_predicates_limit")


In [None]:
# Analysis over the sqlglot output persisted in `filter_predicates_limit`.
analytics = [
# 1. Parse outcome per pattern: failures were stored as error text instead of JSON.
r"""
select
    case
        when contains(filter_predicates, 'could not be resolved') then 'column could not be resolved'
        when try_parse_json(filter_predicates) is not null then 'successful sql parse'
        else 'other'
    end as parse_status,
    count(*)
from filter_predicates_limit
group by all;
""",

# 2. One row per (pattern, WHERE clause, referenced column).
#    A SELECT-list alias is not visible in the FROM clause, so the first FLATTEN
#    re-parses `q.filter_predicates` instead of referencing `filters_parsed`; the
#    second FLATTEN reads the explicitly qualified `f.value`.
#    NOTE(review): the parsed JSON is a list (see filter_predicates), so
#    `filters_parsed['aliases']` is presumably always NULL; kept for schema stability.
r"""
create or replace table query_filters_flattened as
select
    query_pattern_hash,
    example,
    query_text_sanitized,
    try_parse_json(filter_predicates) as filters_parsed,
    filters_parsed['aliases'] as aliases,
    f.value['where_clause']::text as where_clause,
    f2.*,
    f2.value['column']::text as col_name,
    f2.value['table']::text as table_name,
    row_number() over (partition by query_pattern_hash, where_clause, table_name, col_name order by null) as col_reference_number
from filter_predicates_limit q,
    lateral flatten(input => try_parse_json(q.filter_predicates)) f,
    lateral flatten(input => f.value:filter_columns) f2;
""",

# 3. Sanity check: (pattern, table, column) should be unique after grouping.
r"""
with query_pattern_cluster_key_candidate as
(
select
    query_pattern_hash,
    table_name,
    col_name,
    any_value(query_text_sanitized),
    count(*) as num_column_references
from query_filters_flattened
where table_name is not null
group by 1,2,3
)
select
    count(*),
    count(distinct query_pattern_hash, table_name, col_name)
from query_pattern_cluster_key_candidate
;
""",

# 4. Fan candidate columns back out to individual queries and flag whether the
#    filtered table is among the tables each query actually accessed.
#    NOTE(review): qff.table_name is the bare table name from sqlglot; confirm that
#    base_object_access_event.table_name uses the same form, or matches will be missed.
r"""
create or replace table query_cluster_key_candidate as
with query_pattern_cluster_key_candidate as
(
select
    query_pattern_hash,
    table_name,
    col_name,
    any_value(query_text_sanitized) query_text_sanitized,
    any_value(where_clause) where_clause_example,
    count(*) as num_column_references
from query_filters_flattened
where table_name is not null
group by all
),
tables_accessed_by_query as (
    select
        query_id,
        array_agg(lower(table_name)) tables_accessed
    from base_object_access_event
    where objectdomain = 'Table'
    group by 1
)
select
    qp.query_id,
    qff.query_pattern_hash,
    qff.query_text_sanitized,
    qp.query_text,
    qff.table_name,
    qff.col_name,
    where_clause_example,
    tables_accessed,
    array_contains(lower(qff.table_name)::variant, tables_accessed) as is_source_table
from query_pattern_cluster_key_candidate qff
left join query_pattern qp
on qff.query_pattern_hash = qp.query_pattern_hash
left join tables_accessed_by_query taq
on qp.query_id = taq.query_id;
""",

# 5. How many candidate rows point at a table the query really read.
r"""
select
    is_source_table,
    count(*)
from
    query_cluster_key_candidate
group by all;
""",

# 6. Sanity check on (query, table, column) uniqueness.
r"""
select
    count(*),
    count(distinct query_id, table_name, col_name),
    count(distinct query_id)
from query_cluster_key_candidate;
""",

# 7. Final report: one row per (table, candidate pruning column).
r"""
create or replace table cluster_key_report as
select
    table_name,
    col_name as pruning_key_candidate,

    -- metadata/observations
    'fill me in' as is_current_clustering_key,
    count(distinct query_pattern_hash) num_query_patterns,
    count(distinct qckc.query_id) as num_queries,
    count(*) as num_hits, --todo: should be same as num_queries - table_name/col_name/query_id should be a unique key
    sum(execution_time_s)/(3600) total_latency_hours,
    avg(execution_time_s) avg_latency_sec,
    median(execution_time_s) median_latency_sec,
    min(start_time) earliest_hit_time,
    max(start_time) latest_hit_time,

    any_value(where_clause_example) where_clause_example,
    sum(query_cost) total_query_cost,
    avg(query_cost) avg_query_cost,

    -- filter sources
    any_value(qckc.query_text_sanitized) as example_query_pattern,
    array_unique_agg(role_name) as roles_hitting,
    'fill me in' as apps_hitting

    -- helper_sql_manual_cluster
    -- helper_sql_auto_cluster

from query_cluster_key_candidate qckc
left join query_history_enriched qhe
on qckc.query_id = qhe.query_id
where is_source_table
group by 1,2;
""",
]

In [None]:
for sql in analytics:
    # Echo each statement before its result so a failing statement is easy to spot.
    print(f"running sql: {sql}")
    frame = sql_to_df(sql)
    display(frame)