#### SQL Query Parsing

In [41]:
import sqlglot
from sqlglot import expressions as exp
 

In [79]:
query = """
        select l_returnflag,
                100.00 * sum(case
                        when p_type like 'PROMO%'
                                then l_extendedprice * (1 - l_discount)
                        else 0
                end) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue
        from
                lineitem,
                part
        where
                l_partkey = p_partkey
                and l_shipdate >= CAST('1993-04-01' AS date)
                and l_shipdate < DATEADD(mm, 1, CAST('1993-04-01' AS date))
        order by
                l_returnflag;
        """        
        

In [80]:
# parse query
parsed = sqlglot.parse_one(query)


In [81]:
parsed

Select(
  expressions=[
    Column(
      this=Identifier(this=l_returnflag, quoted=False)),
    Alias(
      this=Div(
        this=Mul(
          this=Literal(this=100.00, is_string=False),
          expression=Sum(
            this=Case(
              ifs=[
                If(
                  this=Like(
                    this=Column(
                      this=Identifier(this=p_type, quoted=False)),
                    expression=Literal(this=PROMO%, is_string=True)),
                  true=Mul(
                    this=Column(
                      this=Identifier(this=l_extendedprice, quoted=False)),
                    expression=Paren(
                      this=Sub(
                        this=Literal(this=1, is_string=False),
                        expression=Column(
                          this=Identifier(this=l_discount, quoted=False))))))],
              default=Literal(this=0, is_string=False)))),
        expression=Sum(
          this=Mul(
            this=Column(

In [82]:
# extract all column names
print("Column names: ", [c.name for c in parsed.find_all(sqlglot.expressions.Column)])

# extract all table names
print("Table names: ", [t.name for t in parsed.find_all(sqlglot.expressions.Table)])

Column names:  ['l_returnflag', 'l_returnflag', 'l_shipdate', 'l_extendedprice', 'l_partkey', 'p_partkey', 'l_shipdate', 'mm', 'l_discount', 'p_type', 'l_extendedprice', 'l_discount']
Table names:  ['lineitem', 'part']


In [84]:
def get_columns(query):
    # Parse the query
    parsed = sqlglot.parse_one(query)

    def extract_columns(node):
        columns = set()
        if isinstance(node, exp.Column):
            columns.add(node.name)
        for child in node.args.values():
            if isinstance(child, exp.Expression):
                columns.update(extract_columns(child))
        return columns

    # Find the WHERE clause
    where_clause = parsed.find(exp.Where)
    # Extract columns from the WHERE clause
    predicate_columns = set()
    if where_clause:
        predicate_columns = extract_columns(where_clause)

    # Find the SELECT clause
    select_clause = parsed.find(exp.Select)
    # Extract columns from the SELECT clause
    payload_columns = set()
    if select_clause:
        for projection in select_clause.expressions:
            payload_columns.update(extract_columns(projection))

    # Find the ORDER BY clause
    order_by_clause = parsed.find(exp.Order)
    # Extract columns from the ORDER BY clause
    order_by_columns = set()
    if order_by_clause:
        for order in order_by_clause.expressions:
            order_by_columns.update(extract_columns(order))

    # Find the GROUP BY clause
    group_by_clause = parsed.find(exp.Group)
    # Extract columns from the GROUP BY clause
    group_by_columns = set()
    if group_by_clause:
        for group in group_by_clause.expressions:
            group_by_columns.update(extract_columns(group))


    return predicate_columns, payload_columns, order_by_columns, group_by_columns



predicate_columns, payload_columns, order_by_columns, group_by_columns = get_columns(query)

# Print the columns involved in predicates
print(f"Predicate columns: {predicate_columns}, Payload columns: {payload_columns}, Order by columns: {order_by_columns}, Group by columns: {group_by_columns}")


Predicate columns: {'l_shipdate', 'l_partkey', 'p_partkey'}, Payload columns: {'l_discount', 'l_returnflag', 'l_extendedprice'}, Order by columns: {'l_returnflag'}, Group by columns: set()
