# Query Transformation Tests

In [1]:
import collections
import itertools
import pathlib
import warnings
from dataclasses import dataclass
from typing import Dict, List, Tuple

import mo_sql_parsing as mosp
import natsort
import psycopg2

import numpy as np
import pandas as pd

In [2]:
workloads = ["explicit", "implicit"]
queries_src_dir = pathlib.Path("../simplicity-done-right/JOB-Queries/")

In [3]:
df_data = collections.defaultdict(list)
for workload in workloads:
    workload_path = queries_src_dir / workload
    query_files = list(workload_path.glob("*.sql"))
    df_data["label"].extend(query_file.stem for query_file in query_files)
    df_data["workload"].extend(itertools.repeat(workload, len(query_files)))
    df_data["query"].extend(query_file.read_text().replace("\n", " ").lower() for query_file in query_files)
df_queries = pd.DataFrame(df_data)
df_queries = df_queries.sort_values(by="label", key=lambda _: np.argsort(natsort.index_natsorted(df_queries["label"]))).reset_index(drop=True)
df_queries["mosp_query"] = df_queries["query"].apply(mosp.parse)

In [4]:
df_queries

Unnamed: 0,label,workload,query,mosp_query
0,1a,explicit,select count(*) from movie_companies as mc j...,"{'select': {'value': {'count': '*'}}, 'from': ..."
1,1a,implicit,"select count(*) from company_type as ct, ...","{'select': {'value': {'count': '*'}}, 'from': ..."
2,1b,explicit,select count(*) from movie_info_idx as mi_idx...,"{'select': {'value': {'count': '*'}}, 'from': ..."
3,1b,implicit,"select count(*) from company_type as ct, ...","{'select': {'value': {'count': '*'}}, 'from': ..."
4,1c,explicit,select count(*) from movie_companies as mc j...,"{'select': {'value': {'count': '*'}}, 'from': ..."
...,...,...,...,...
221,33a,implicit,"select count(*) from company_name as cn1, ...","{'select': {'value': {'count': '*'}}, 'from': ..."
222,33b,explicit,select count(*) from movie_link as ml join l...,"{'select': {'value': {'count': '*'}}, 'from': ..."
223,33b,implicit,"select count(*) from company_name as cn1, ...","{'select': {'value': {'count': '*'}}, 'from': ..."
224,33c,explicit,select count(*) from movie_link as ml join l...,"{'select': {'value': {'count': '*'}}, 'from': ..."


In [5]:
df_expl = df_queries[df_queries.workload == "explicit"]
df_subqueries = df_expl[df_expl["query"].str.rfind("select") > df_expl["query"].str.find("join")]

In [6]:
class QueryUpdate:
    def __init__(self, base_table="", ):
        self.base_table = base_table
        self.table_renamings = collections.defaultdict(int)
        self.table_references = list()
        self.predicates = list()

    def include_subquery(self, subquery):

        # first up, build the rename map
        tables_in_sq = self._collect_tables(subquery)
        for table in tables_in_sq:
            self.table_renamings[table] += 1

    def _collect_tables(self, subquery):
        tables = []
        for clause in subquery["join"]["value"]["from"]:
            if "join" in clause:
                tables.append(clause["join"]["value"])
            else:
                tables.append(clause["value"])
        return tables

    def __str__(self) -> str:
        return f"Tables: {self.table_references}, Predicates: {self.predicates}"

In [8]:
def extract_subqueries(plan):
    from_clause = plan["from"]
    query_update = None
    sq_found = False

    # for each reference in the 'from' clause, check if it constitutes a subquery reference
    for table in from_clause:

        # extract the base table name and proceed with the joined tables
        if not isinstance(table, dict) or "join" not in table:
            query_update = QueryUpdate(table)
            continue

        join_target = table["join"]["value"]
        if isinstance(join_target, dict) and "select" in join_target:
            #print("Found subquery:", join_target)
            sq_found = True
            break
            query_update.include_subquery(table)

    return(sq_found)

In [9]:
df_nosq = df_expl[~df_expl.mosp_query.apply(extract_subqueries)]

NameError: name 'QueryUpdate' is not defined

In [7]:
conn = psycopg2.connect("dbname=imdb user=rico host=localhost")
cur = conn.cursor()

In [8]:
q_raw = df_subqueries.iloc[0]["query"]
q_raw

"select count(*) from  movie_companies as mc  join company_type as ct on (ct.kind = 'production companies' and ct.id = mc.company_type_id and mc.note not like '%(as metro-goldwyn-mayer pictures)%' and (mc.note like '%(co-production)%'  or mc.note like '%(presents)%')) join title as t on (t.id = mc.movie_id) join  (select movie_id from movie_info_idx as mi_idx  join info_type as it on (it.info = 'top 250 rank' and it.id = mi_idx.info_type_id)) as t_mi_idx  on(t_mi_idx.movie_id = mc.movie_id);"

In [9]:
@dataclass
class TableRef:
    full_name: str
    alias: str
    
    def __repr__(self):
        return str(self)
    
    def __str__(self):
        return f"{self.full_name} AS {self.alias}"


@dataclass
class AttributeRef:
    src_table: TableRef
    attribute: str
        
    def __repr__(self):
        return str(self)
    
    def __str__(self):
        return f"{self.src_table.alias}.{self.attribute}"


@dataclass
class JoinStatement:
    target: TableRef
    on: str
    
    def expand(self, on):
        pass
    
    def __repr__(self):
        return str(self)
    
    def __str__(self):
        return f"JOIN {self.target} ON {self.on}"

In [10]:
class DBSchema:
    def __init__(self, cursor: "psycopg2.cursor"):
        self.cursor = cursor
        
    def lookup_attribute(self, attribute_name: str, candidate_tables: List[TableRef]) -> AttributeRef:
        for table in candidate_tables:
            columns = self._fetch_columns(table.full_name)
            if attribute_name in columns:
                return table
        raise KeyError(f"Attribute not found: {attribute_name} in candidates {candidate_tables}")

    def _fetch_columns(self, table_name):
        base_query = "SELECT column_name FROM information_schema.columns WHERE table_name = %s"
        cur.execute(base_query, (table_name,))
        result_set = cur.fetchall()
        return [col[0] for col in result_set]

In [11]:
def collect_attributes(projection):
    if not isinstance(projection, list):
        return [projection["value"]]
    return [col["value"] for col in projection]

In [12]:
def extract_table_references(from_clause):
    references = []
    
    for table_ref in from_clause:
        table_name, table_alias = "", ""
        
        # base table reference
        if "value" in table_ref:
            table_name = table_ref["value"]
            table_alias = table_ref.get("name", table_name)
        # joined table
        elif "join" in table_ref:
            join_target = table_ref["join"]
            table_name = join_target["value"]
            table_alias = join_target.get("name", table_name)
            
        references.append(TableRef(table_name, table_alias))
    
    return references

In [13]:
def bind_attributes(raw_attributes: List[str], candidate_tables: List[TableRef], *, dbschema) -> List[AttributeRef]:
    bindings = []
    for attribute in raw_attributes:
        src_table = dbschema.lookup_attribute(attribute, candidate_tables)
        bindings.append(AttributeRef(src_table, attribute))
    return bindings

In [14]:
q = df_subqueries.iloc[0].mosp_query
q

{'select': {'value': {'count': '*'}},
 'from': [{'value': 'movie_companies', 'name': 'mc'},
  {'join': {'value': 'company_type', 'name': 'ct'},
   'on': {'and': [{'eq': ['ct.kind', {'literal': 'production companies'}]},
     {'eq': ['ct.id', 'mc.company_type_id']},
     {'not_like': ['mc.note',
       {'literal': '%(as metro-goldwyn-mayer pictures)%'}]},
     {'or': [{'like': ['mc.note', {'literal': '%(co-production)%'}]},
       {'like': ['mc.note', {'literal': '%(presents)%'}]}]}]}},
  {'join': {'value': 'title', 'name': 't'},
   'on': {'eq': ['t.id', 'mc.movie_id']}},
  {'join': {'value': {'select': {'value': 'movie_id'},
     'from': [{'value': 'movie_info_idx', 'name': 'mi_idx'},
      {'join': {'value': 'info_type', 'name': 'it'},
       'on': {'and': [{'eq': ['it.info', {'literal': 'top 250 rank'}]},
         {'eq': ['it.id', 'mi_idx.info_type_id']}]}}]},
    'name': 't_mi_idx'},
   'on': {'eq': ['t_mi_idx.movie_id', 'mc.movie_id']}}]}

In [39]:
def head(lst):
    if not len(lst):
        raise ValueError("List is empty")
    return lst[0]

def unwrap_singular_dict(d, *, target="value"):
    if target == "value":
        return head(list(d.values()))
    elif target == "key":
        return head(list(d.keys()))
    else:
        raise KeyError("Unknown target: " + target)

In [48]:
def extract_filter_predicates(predicate_tree, base_table):
    # check if we moved down the conjunction/disjunction tree as far as possible
    if isinstance(predicate_tree, list):
        # perform the actual filtering
        matching_predicates = []
        for pred in predicate_tree:
            # each predicate is a dict with just a single key: the comparison operation
            # this slightly weird structure needs to be unwrapped
            # e.g. {'eq': ['it.info', {'literal': 'top 250 rank'}]}
            pred_source = unwrap_singular_dict(pred)[0]
            if not pred_source.startswith(base_table.alias + "."):
                continue
            pred_target = unwrap_singular_dict(pred)[1]
            
            # predicates such as customer.last_login_date > customer.last_purchase_date
            if isinstance(pred_target, str) and pred_target.startswith(base_table.alias + "."):
                matching_predicates.append(pred)
            # predicates such as customer.last_login_date > 2022-01-01
            elif isinstance(pred_target, dict) and "literal" in pred_target:
                matching_predicates.append(pred)
            # join predicates such as customer.id = buyer.id
            else:
                continue

        return matching_predicates
    elif isinstance(predicate_tree, dict):
        operation = unwrap_singular_dict(predicate_tree, target="key")
        return {operation: extract_filter_predicates(predicate_tree[operation], base_table)}

In [16]:
class FlattenedQuery:
    def __init__(self, base_table: TableRef, dbschema: DBSchema):
        self.base_table = base_table
        self.dbschema = dbschema
        self.joins = list()
        self.dangling_joins = list()
        
    def absorb_join(self, join):
        """Includes a join statement in the query. Could be either a subquery, or a plain join."""
        join_data = join["join"]
        join_predicate = join["on"]
        
        # sanity check
        if isinstance(join_data["value"], dict) and not "select" in join_data["value"]:
            warnings.warn("Unknown query structure:", join)
            return
        
        # just a plain join?
        if not isinstance(join_data["value"], dict):
            joined_table = TableRef(join_data["value"], join_data["name"])
            join_stmt = JoinStatement(joined_table, join_predicate)
            self.joins.append(join_stmt)
            return
        
        # at this point, we know we found a subquery
        # we are going to break up its structure first, and gather the relevant data secondly
        subquery = join_data["value"]
        sq_projection = subquery["select"]
        sq_tables = subquery["from"]
        sq_target = join_data["name"] # we actually don't need this
        
        table_refs = extract_table_references(sq_tables)
        attribute_refs = bind_attributes(collect_attributes(sq_projection), table_refs, dbschema=self.dbschema)
        print(attribute_refs)
    
    def __str__(self):
        return f"{self.base_table} {self.joins}"
    
def flatten_query(mosp_query):
    from_clause = mosp_query["from"]
    flattened_query = None
    for table in from_clause:
        # extract the base table name and proceed with the joined tables
        if not isinstance(table, dict) or "join" not in table:
            table_ref = TableRef(table["value"], table["name"])
            flattened_query = FlattenedQuery(table_ref, DBSchema(cur))
            continue
        else:
            flattened_query.absorb_join(table)
            
    return flattened_query

In [17]:
flatten_query(q)

[mi_idx.movie_id]


<__main__.FlattenedQuery at 0x7fb3db9d1e80>

In [13]:
q_raw

"select count(*) from  movie_companies as mc  join company_type as ct on (ct.kind = 'production companies' and ct.id = mc.company_type_id and mc.note not like '%(as metro-goldwyn-mayer pictures)%' and (mc.note like '%(co-production)%'  or mc.note like '%(presents)%')) join title as t on (t.id = mc.movie_id) join  (select movie_id from movie_info_idx as mi_idx  join info_type as it on (it.info = 'top 250 rank' and it.id = mi_idx.info_type_id)) as t_mi_idx  on(t_mi_idx.movie_id = mc.movie_id);"

## Regression analysis

First up, based on actual results:

In [63]:
regressions = []
regression_found = False
for q_idx, query in enumerate(df_queries.itertuples()):
    orig_query = query.query
    cur.execute(orig_query)
    orig_card = cur.fetchone()[0]
    mosp_query = mosp.format(query.mosp_query)
    cur.execute(mosp_query)
    mosp_card = cur.fetchone()[0]
    if q_idx % 50 == 0:
        print("Now at query", q_idx+1)
    if orig_card != mosp_card:
        regression_found = True
        regressions.append(query.label)
        print("= Regression found for query", query.label, "Orig:", orig_card, "MOSP:", mosp_card)
        
if not regression_found:
    print("== All tests succeeded! ==")
else:
    print("== Regressions found: ==")
    print(regressions)

Now at query 1
Now at query 51
Now at query 101
Now at query 151
Now at query 201


Secondly, based on query plans:

In [22]:
def extract_plan_nodes(plan):
    node_type = plan["Node Type"]
    node_filter = ""
    if "Join Filter" in plan:
        node_filter = plan["Join Filter"]
    elif "Hash Cond" in plan:
        node_filter = plan["Hash Cond"]
    elif "Filter" in plan and "Index Cond" in plan:
        node_filter = plan["Index Cond"] + " // " + plan["Filter"]
    elif "Filter" in plan:
        node_filter = plan["Filter"]
    elif "Index Cond" in plan:
        node_filter = plan["Index Cond"]
    
    nodes = [node_type]
    if node_filter:
        nodes = [f"{node_type} :: {node_filter}"]
    
    for subplan in plan.get("Plans", []):
        nodes.extend(extract_plan_nodes(subplan))
    return nodes

In [25]:
regressions = []
regression_found = False
for q_idx, query in enumerate(df_queries.itertuples()):
    orig_query = query.query
    cur.execute("explain (format json) " + orig_query)
    orig_plan = cur.fetchone()[0][0]["Plan"]
    orig_nodes = extract_plan_nodes(orig_plan)
    
    mosp_query = mosp.format(query.mosp_query)
    cur.execute("explain (format json) " + mosp_query)
    mosp_plan = cur.fetchone()[0][0]["Plan"]
    mosp_nodes = extract_plan_nodes(mosp_plan)
    
    if q_idx % 50 == 0:
        print("Now at query", q_idx+1)
    if orig_nodes != mosp_nodes:
        regression_found = True
        regressions.append(query.label)
        print("= Regression found for query", query.label, "Orig:", orig_nodes, "MOSP:", mosp_nodes)
        
if not regression_found:
    print("== All tests succeeded! ==")
else:
    print("== Regressions found: ==")
    print(regressions)

Now at query 1
Now at query 51
Now at query 101
Now at query 151
Now at query 201
== All tests succeeded! ==
