In [45]:
#hide
# reload edited modules automatically before each cell runs, so changes to ptypysql are picked up live
%load_ext autoreload
%autoreload 2
%config Completer.use_jedi = False  # workaround for buggy jedi

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [46]:
# default_exp format_file

In [47]:
import os
import sys

# Make the package that lives one folder above the notebooks importable from here.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

In [48]:
#export
import re
import os
import tempfile
import argparse
from glob import glob
import ptypysql
from ptypysql.core import *
from ptypysql.utils import *
from ptypysql.validation import *

# format_file

> Functions to format a SQL file with multiple queries and SQL statements

In [49]:
#hide
from nbdev.showdoc import *

## Use-Case

Assume you have a file called sql_file.py containing SQL statements and queries.

After reading it into Python, we could have something like this:

In [50]:
# Example input: plain statements (`use ...`), comments and two queries in free-form style
sql_file = """
--- Views for some nice data mart ---
use database my_database;
use schema my_schema;

create or replace view first_view as -- my first view
select a.car_id,
       b.car_name, sum(a.price) over (partition by b.car_name order by a.car_id) as sum_price, a.price,
from sales as a left join (select car_id, car_name, from cars) as b 
on a.car_id = b.car_id
where car_id>1 and car_id<=100 order by b.car_name;

-- Table no. 1 --
create or replace table first_table as -- my first table
select car_id,
       avg(price) as avg_price,
from first_view
group by car_id order by car_id;

--- End of file ---
""".strip()

Then we would like to format the SQL queries in this file while leaving every other (non-query) SQL statement untouched. For the example above, we would like to get something like this:

In [51]:
# Expected output for `sql_file`: queries formatted and preceded by two blank lines, other statements untouched
expected_sql_file = """
--- Views for some nice data mart ---
use database my_database;
use schema my_schema;


CREATE OR REPLACE VIEW first_view AS -- my first view
SELECT a.car_id,
    b.car_name,
    sum(a.price) OVER (
        PARTITION BY b.car_name
        ORDER BY a.car_id
    ) AS sum_price,
    a.price
FROM sales AS a
    LEFT JOIN (
        SELECT car_id,
            car_name
        FROM cars
    ) AS b
        ON a.car_id = b.car_id
WHERE car_id > 1
    AND car_id <= 100
ORDER BY b.car_name;


-- Table no. 1 --
CREATE OR REPLACE TABLE first_table AS
SELECT car_id,
    avg(price) AS avg_price -- my first table
FROM first_view
GROUP BY car_id
ORDER BY car_id;

--- End of file ---
""".lstrip()

python file
test_script.py

### Formatting philosophy of SQL files

* Every SQL query is separated from the preceding text by two blank lines
* Every SQL query is formatted via `format_sql`

### Main function to format the SQL commands in a file

This function also applies basic validation. It aborts formatting if the statements `CREATE .. TABLE / VIEW` appear at least twice in the same query after splitting by semicolon, and warns the user that they may have forgotten a semicolon. It also aborts on unbalanced parentheses and on unbalanced `CASE WHEN ... END` blocks.

In [52]:

#export
def _file_line_numbers(validations, container=list):
    """Translate the failing lines of per-query `validations` into line numbers of the whole input.

    Each validation dict is expected to provide `exit_code`, `val_lines` (line numbers
    relative to its own query) and `total_lines` (number of lines of that query).
    The offset of query `i` is the sum of `total_lines` of all preceding queries.
    Only validations whose `exit_code` equals 1 are reported, each wrapped in
    `container` (e.g. `list` or `tuple`).
    """
    file_lines = []
    offset = 0  # running sum of `total_lines` of the queries seen so far
    for val in validations:
        if val["exit_code"] == 1:
            file_lines.append(container(line + offset for line in val["val_lines"]))
        offset += val["total_lines"]
    return file_lines


def format_sql_commands(s_original, max_len=99, semicolon = True):
    """Format the SQL commands contained in `s_original`.

    The stripped input is split by semicolon into queries. Every query recognised by
    `check_sql_query` and not carrying the skip marker is formatted with `format_sql`
    (SELECT lines longer than `max_len` are reformatted, `semicolon` is forwarded) and
    preceded by three newlines; everything else is kept as-is.

    Returns one of:
    * the formatted text: stripped, at most three consecutive newlines, one trailing newline;
    * `s_original` unchanged when no query needed formatting;
    * a dict describing validation errors when validation fails. Its possible keys are
      `semicolon` (error_code 2, lines as tuples), `unbalanced_parenthesis` (error_code 3)
      and `unbalanced_case` (error_code 4), each holding the offending line numbers.
    """
    s = s_original.strip()  # strip file contents
    split_s = split_by_semicolon(s)  # split by query
    # validate semicolon
    validations_semicolon = [validate_semicolon(sp) for sp in split_s]
    val_summary_semicolon = sum(val["exit_code"] for val in validations_semicolon)
    # validate balanced parenthesis
    validations_balanced = [validate_balanced_parenthesis(sp) for sp in split_s]
    val_summary_balanced = sum(val["exit_code"] for val in validations_balanced)
    # validate balanced case when ... end (empty queries are not checked)
    val_case_end_balanced = [validate_case_when(sp) for sp in split_s if sp != ""]
    val_summary_case = sum(val["exit_code"] for val in val_case_end_balanced)

    if val_summary_semicolon + val_summary_balanced + val_summary_case != 0:
        # validation failed: report the offending lines relative to the whole input
        error_dict = {}
        if val_summary_semicolon > 0:
            error_dict["semicolon"] = {
                "error_code": 2,
                "lines": _file_line_numbers(validations_semicolon, tuple)
            }
        if val_summary_balanced > 0:
            error_dict["unbalanced_parenthesis"] = {
                "error_code": 3,
                "lines": _file_line_numbers(validations_balanced)
            }
        if val_summary_case > 0:
            error_dict["unbalanced_case"] = {
                "error_code": 4,
                "lines": _file_line_numbers(val_case_end_balanced)
            }
        return error_dict

    split_comment_after_semicolon = re.compile("((?:\n|create|select|with))")
    check_comment_after_semicolon = re.compile(r"(?:;)[\r\t\f\v ]*(?:\/\*|--)")
    check_ending_semicolon = re.compile(r";\s*$")

    def _is_formattable(query):
        # only SQL queries without the skip marker get formatted
        return check_sql_query(query) and not check_skip_marker(query)

    split_s_out = []  # initialize container
    last_i = len(split_s) - 1
    formatted = False
    for i, query in enumerate(split_s):  # split by semicolon
        # a comment right after the semicolon: split it off (at the first newline or
        # create/select/with keyword) so that only the remainder is formatted
        if check_comment_after_semicolon.match(query):
            pieces = split_comment_after_semicolon.split(query, maxsplit=1)
        else:
            pieces = [query]
        # remember whether at least one query is formatted
        if _is_formattable(query):
            formatted = True
        formatted_query = "".join(
            "\n\n\n" + format_sql(piece, semicolon=semicolon, max_len=max_len).strip()
            if _is_formattable(piece)
            else piece
            for piece in pieces
        )
        if i != last_i:
            # every query but the last must end with a semicolon (ignoring comments)
            code_only = "".join(
                d["string"] for d in split_comment(formatted_query) if not d["comment"]
            )
            if not (check_ending_semicolon.search(code_only) or formatted_query == ""):
                formatted_query += ";"
        split_s_out.append(formatted_query)

    # if not formatted, return original sql
    if not formatted:
        return s_original
    # join the queries and remove starting and ending newlines
    formatted_s = "".join(split_s_out).strip()
    # collapse runs of 4+ newlines to exactly 3 (= two blank lines)
    formatted_s = re.sub(r"\n{4,}", "\n\n\n", formatted_s)
    # add newline at the end of file
    return formatted_s + "\n"

Basic file formatting

In [53]:
# The example file is formatted into `expected_sql_file`
assert_and_print(
    format_sql_commands(sql_file),
    expected_sql_file
)

--- Views for some nice data mart ---
use database my_database;
use schema my_schema;


CREATE OR REPLACE VIEW first_view AS -- my first view
SELECT a.car_id,
    b.car_name,
    sum(a.price) OVER (
        PARTITION BY b.car_name
        ORDER BY a.car_id
    ) AS sum_price,
    a.price
FROM sales AS a
    LEFT JOIN (
        SELECT car_id,
            car_name
        FROM cars
    ) AS b
        ON a.car_id = b.car_id
WHERE car_id > 1
    AND car_id <= 100
ORDER BY b.car_name;


-- Table no. 1 --
CREATE OR REPLACE TABLE first_table AS
SELECT car_id,
    avg(price) AS avg_price -- my first table
FROM first_view
GROUP BY car_id
ORDER BY car_id;

--- End of file ---



Using the `/*skip-formatter*/` marker to leave a query unformatted

In [54]:
# The query after /*skip-formatter*/ is kept verbatim, the following query is formatted
assert_and_print(
    format_sql_commands("""
use database my_database;

/*skip-formatter*/
create Or replace View my_view aS
select asdf, qwer
from table1;

create or replace table my_table As
Select asdf, qwer
From table2
group by asdf;
"""),
    """use database my_database;

/*skip-formatter*/
create Or replace View my_view aS
select asdf, qwer
from table1;


CREATE OR REPLACE TABLE my_table AS
SELECT asdf,
    qwer
FROM table2
GROUP BY asdf;
""")

use database my_database;

/*skip-formatter*/
create Or replace View my_view aS
select asdf, qwer
from table1;


CREATE OR REPLACE TABLE my_table AS
SELECT asdf,
    qwer
FROM table2
GROUP BY asdf;



In [55]:
# A single query: surrounding newlines are stripped and exactly one trailing newline is kept
assert_and_print(
    format_sql_commands("""
create or replace table my_table As
Select asdf, qwer
From table2
group by asdf;
"""),
    """
CREATE OR REPLACE TABLE my_table AS
SELECT asdf,
    qwer
FROM table2
GROUP BY asdf;
""".lstrip())

CREATE OR REPLACE TABLE my_table AS
SELECT asdf,
    qwer
FROM table2
GROUP BY asdf;



In [56]:
# need fix: problem with comment location
# assert_and_print(
#     format_sql_commands("""
# create or replace table my_table As
# Select asdf, qwer
# From table2
# group by asdf -- some comment
# ;
# """),
#     """
# CREATE OR REPLACE TABLE my_table AS 
# SELECT asdf,
#     qwer
# FROM table2
# GROUP BY asdf; -- some comment
# """.lstrip())

In [57]:
# need fix: problem with comment location
# assert_and_print(
#     format_sql_commands(
# """
# create table my_table As
# select asdf, Qwer, /* ; */
# qwer2, -- ;
# replace(';', '', qwer3) as Qwer4
# from table1; /* Some comment */

# create view my_view As
# Select asdf
# From my_table; /* Another comment */
# """
#     ),
# """
# CREATE TABLE my_table AS
# SELECT asdf,
#        qwer, /* ; */
#        qwer2, -- ;
#        replace(';', '', qwer3) AS qwer4
# FROM   table1; /* Some comment */


# CREATE VIEW my_view AS
# SELECT asdf
# FROM   my_table; /* Another comment */
# """.lstrip()
# )

In [58]:
# need fix: problem with comment location

# assert_and_print(
#     format_sql_commands(
# """
# create or replace transient table my_table As
# select asdf, Qwer, /* ; */
# qwer2, -- ;
# replace(';', '', qwer3) as Qwer4
# from table1;

# create view my_view As
# Select asdf
# From my_table;
# """.lstrip()
#     ),
# """
# CREATE OR REPLACE TRANSIENT TABLE my_table AS
# SELECT asdf,
#        qwer, /* ; */
#        qwer2, -- ;
#        replace(';', '', qwer3) AS qwer4
# FROM   table1;


# CREATE VIEW my_view AS
# SELECT asdf
# FROM   my_table;
# """.lstrip()
# )

If validation fails, the function returns a dictionary with information about the error instead of the formatted queries.

Semicolon validation error

In [59]:
# Two CREATE statements without a separating semicolon -> error code 2, reported at lines 1 and 7
assert_and_print(
    format_sql_commands(
"""
create or replace transient table my_table As
select asdf, Qwer, /* ; */
qwer2, -- ;
replace(';', '', qwer3) as Qwer4
from table1

create view my_view As
Select asdf
From my_table;
""".lstrip()
    ), 
    {"semicolon": {"error_code": 2, "lines": [(1, 7)]}}
)

{'semicolon': {'error_code': 2, 'lines': [(1, 7)]}}


Unbalanced parenthesis error

In [60]:
# Unbalanced parentheses in both queries -> error code 3 with the offending lines grouped per query
assert_and_print(
    format_sql_commands(
"""
create or replace transient table my_table As
select asdf, Qwer, /* ; */
(qwer2, -- ; ()
( /* ) */
replace(';', '', qwer3) as Qwer4
from table1;

create view my_view As
(Select asdf
From my_table;
""".lstrip()
    ), 
    {"unbalanced_parenthesis": {"error_code": 3, "lines": [[3, 4], [9]]}}
)

{'unbalanced_parenthesis': {'error_code': 3, 'lines': [[3, 4], [9]]}}


Unbalanced parenthesis + semicolon error

In [61]:
# Missing semicolon AND unbalanced parentheses -> both errors are reported in the same dict
assert_and_print(
    format_sql_commands(
"""
create or replace transient table my_table As
select asdf, Qwer, /* ; */
(qwer2, -- ; ()
( /* ) */
replace(';', '', qwer3) as Qwer4
from table1

create view my_view As
(Select asdf
From my_table;
""".lstrip()
    ), 
    {
        "semicolon": {"error_code": 2, "lines": [(1, 8)]},        
        "unbalanced_parenthesis": {"error_code": 3, "lines": [[3, 4, 9]]},
    }
)

{'semicolon': {'error_code': 2, 'lines': [(1, 8)]}, 'unbalanced_parenthesis': {'error_code': 3, 'lines': [[3, 4, 9]]}}


Unbalanced case when ... end

In [62]:
# `case when` without a matching `end` -> error code 4 at the line of the `case`
assert_and_print(
    format_sql_commands(
"""
create or replace transient table my_table As
select asdf, Qwer, /* ; */
case when asdf = 1 then 1 as qwer,
replace(';', '', qwer3) as Qwer4
from table1;

create view my_view As
""".lstrip()
    ), 
    {"unbalanced_case": {"error_code": 4, "lines": [[3]]}}
)

{'unbalanced_case': {'error_code': 4, 'lines': [[3]]}}


### Function to format 1 SQL file

In [63]:
#export
def _print_validation_warnings(f, errors):
    "Print one warning per validation error in `errors` (a dict returned by `format_sql_commands`) for file `f`."
    print(f"Something went wrong in file: {f}")
    if "semicolon" in errors:
        print(
            "[WARNING] Identified CREATE keyword at least twice within the same query " +
            f"at lines {errors['semicolon']['lines']}\n"
            "You may have forgotten a semicolon (;) to delimit the queries"
        )
    if "unbalanced_parenthesis" in errors:
        print(
            "[WARNING] Identified unbalanced parenthesis " +
            f"at lines {errors['unbalanced_parenthesis']['lines']}\n"
            "You should check your parenthesis"
        )
    if "unbalanced_case" in errors:
        print(
            "[WARNING] Identified unbalanced case when ... end " +
            f"at lines {errors['unbalanced_case']['lines']}\n"
            "You should check for missing case or end keywords"
        )
    print(f"Aborting formatting for file {f}")


def format_sql_file(f, max_len=99):
    """Format the SQL queries embedded in the Python file `f` and overwrite the file.

    Only triple-quoted strings passed to `DB.fetch*` / `DB.execute*` calls (optionally
    starting with a `--sql` marker) whose content starts with `with` or `select` are
    formatted. SELECT lines longer than `max_len` characters, counting the indentation
    of the embedded string, are reformatted.

    Return exit_code:
    * 0 = Everything already formatted
    * 1 = Formatting applied
    * 2 = Problem detected, formatting aborted
    """
    exit_code = 0

    # open the file
    with open(f, "r") as file:
        py_scripts = file.read()

    # use for python
    # TODO: support for custom SQL string searching
    # `.*` (not `.+`) so that DB calls starting at column 0 are found as well
    sql_regex = re.compile(r'.*DB\.(?:fetch|execute)(?:\_\w+)?\(\s*\"\"\"\s*(?:--sql)?\s*([\s\S]+?)\"\"\"\,?')
    sql_heading = re.compile(r'(DB\.(?:fetch|execute)(?:\_\w+)?\()(\s*\"\"\"\s*(?:--sql)?\s*)')
    for sql in sql_regex.finditer(py_scripts):
        sql_function = sql.group()
        sql_commands = sql.group(1)
        # now only support for select clause (not support for create/insert/update etc.)
        if not sql_commands.strip().lower().startswith(("with", "select")):
            continue
        # the formatted SQL is indented 4 spaces deeper than the matched call's leading whitespace
        indent_length = len(sql_function) - len(sql_function.lstrip()) + 4
        indent = " " * indent_length
        # request a terminating semicolon only when the closing quotes are followed by ","
        formatted_file = format_sql_commands(
            sql_commands,
            max_len=max_len - indent_length,
            semicolon=sql_function.endswith(","),
        )
        if isinstance(formatted_file, dict):
            _print_validation_warnings(f, formatted_file)
            return 2

        # compare after stripping `indent` from every line of the original SQL
        sql_commands_indented = "\n".join(
            remove_prefix(line, indent) for line in sql_commands.split("\n")
        )
        if formatted_file in (sql_commands_indented, sql_commands):
            continue  # already formatted

        exit_code = 1
        # for safety: only rewrite when the matched text is found verbatim in the file
        if sql_function in py_scripts and sql_commands in sql_function:
            formatted_file = "\n".join(indent + line for line in formatted_file.split("\n"))
            formatted_function = sql_heading.sub(r'\1' + f'\n{indent}"""\n', sql_function)
            formatted_function = formatted_function.replace(sql_commands, formatted_file)
            py_scripts = py_scripts.replace(sql_function, formatted_function)
        else:
            print(f"Something went wrong in file: {f}")
            print(
                (
                f"[WARNING] The original SQL query does not exist in the file : {f}\n" +
                f"The corresponding SQL query is:\n{sql_commands} \n" +
                f"The formatted SQL query is:\n{formatted_file} \n" +
                "You may want to replace the SQL query manually"
                )
            )
            return 2

    # overwrite file
    if exit_code == 1:
        with open(f, "w") as file:
            file.write(py_scripts)

    return exit_code

### Function to format many SQL files

In [64]:
#export
def format_sql_files(files, recursive=False, max_len=99):
    """Format the SQL embedded in every file of `files` with `format_sql_file`.

    If `files` is a single pattern containing the wildcard `*`, it is expanded with
    `glob` (searching subfolders too when `recursive` is True). Prints a summary and
    returns the highest exit code: 0 = nothing to format, 1 = formatting applied,
    2 = formatting aborted for at least one file.
    """
    # if wildcard "*" is input then use it
    if len(files) == 1 and "*" in files[0]:
        pattern = os.path.join("**", files[0]) if recursive else files[0]
        files = glob(pattern, recursive=recursive)
    exit_codes = [format_sql_file(file, max_len=max_len) for file in files]
    if 2 in exit_codes:
        # format_sql_file already printed the reason for every aborted file
        print("Formatting was aborted for some files, see the warnings above!")
    elif sum(exit_codes) == 0:
        print("Nothing to format, everything is fine!")
    else:
        print("All specified files were formatted!")
    return max(exit_codes, default=0)

In [65]:
#export
def format_sql_files_cli():
    """Command line entry point: parse the arguments and run `format_sql_files`.

    Arguments: positional `files` (paths, or a single wildcard pattern), `-r/--recursive`,
    `-m/--max-line-length` (default 99) and `-v/--version`.
    """
    parser = argparse.ArgumentParser(description="Format SQL files")
    parser.add_argument(
        "files",
        help='Path to files with SQL queries. You can also use a wildcard, e.g. "*.py"',
        type=str,
        nargs="+"
    )
    parser.add_argument(
        "-r",
        "--recursive",
        help="Should files also be searched in subfolders?",
        action="store_true"
    )
    parser.add_argument(
        "-m",
        "--max-line-length",
        help="Maximum line length for truncation of SELECT fields",
        type=int,
        default=99
    )
    parser.add_argument(
        "-v",
        "--version",
        action="version",
        version=f"ptypysql version {ptypysql.__version__}"
    )
    args = parser.parse_args()
    format_sql_files(files=args.files, recursive=args.recursive, max_len=args.max_line_length)

In [66]:
#hide
# convert the notebooks' `#export` cells into the library's Python modules
from nbdev.export import notebook2script
notebook2script()

Converted 00_core.ipynb.
Converted 01_format_file.ipynb.
Converted 02_utils.ipynb.
Converted 03_validation.ipynb.
Converted 99_additional_tests.ipynb.
Converted index.ipynb.
