Skip to content

Commit

Permalink
Merge pull request #21 from barrywhart/bhart-pass_sql_string_in_reque…
Browse files Browse the repository at this point in the history
…st_body

For /lint endpoint, pass SQL string in request body
  • Loading branch information
z3z1ma committed Oct 13, 2022
2 parents 95b3d00 + 386f02b commit cdcb911
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 14 deletions.
26 changes: 22 additions & 4 deletions src/dbt_osmosis/core/server_v2.py
Expand Up @@ -65,6 +65,7 @@ class OsmosisErrorCode(int, Enum):
ProjectParseFailure = 3
ProjectNotRegistered = 4
ProjectHeaderNotSupplied = 5
SqlNotSupplied = 6


class OsmosisError(BaseModel):
Expand Down Expand Up @@ -214,8 +215,8 @@ async def compile_sql(
},
)
async def lint_sql(
request: Request,
response: Response,
sql: Optional[str] = None,
sql_path: Optional[str] = None,
# TODO: Should config_path be part of /register instead?
extra_config_path: Optional[str] = None,
Expand All @@ -238,12 +239,29 @@ async def lint_sql(
)

# Query Linting
if sql_path is not None:
# Lint a file
sql = Path(sql_path)
else:
# Lint a string
sql = (await request.body()).decode("utf-8")
if not sql:
# No SQL provided -- error.
response.status_code = status.HTTP_400_BAD_REQUEST
return OsmosisErrorContainer(
error=OsmosisError(
code=OsmosisErrorCode.SqlNotSupplied,
message="No SQL provided. Either provide a SQL file path or a SQL string to lint.",
data={},
)
)
try:
result = lint_command(
temp_result = lint_command(
Path(project.project_root),
sql=Path(sql_path) if sql_path else sql,
sql=sql,
extra_config_path=Path(extra_config_path) if extra_config_path else None,
)["violations"]
)
result = temp_result["violations"] if temp_result is not None else []
except Exception as lint_err:
logging.exception("Linting failed")
response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
Expand Down
5 changes: 2 additions & 3 deletions src/dbt_osmosis/sqlfluff_util.py
Expand Up @@ -43,7 +43,7 @@ def lint_command(
sql: Union[Path, str],
extra_config_path: Optional[Path] = None,
ignore_local_config: bool = False,
) -> Dict:
) -> Optional[Dict]:
"""Lint specified file or SQL string.
This is essentially a streamlined version of the SQLFluff command-line lint
Expand Down Expand Up @@ -77,8 +77,7 @@ def lint_command(
ignore_files=False,
)
records = result.as_records()
assert len(records) == 1
return records[0]
return records[0] if records else None


def test_lint_command():
Expand Down
52 changes: 45 additions & 7 deletions tests/sqlfluff_templater/test_server_v2.py
Expand Up @@ -17,17 +17,25 @@
SQL_PATH = Path(DBT_FLUFF_CONFIG["templater"]["dbt"]["project_dir"]) / "models/my_new_project/issue_1608.sql"


@pytest.mark.parametrize("param_name, param_value", [
("sql_path", SQL_PATH),
("sql", SQL_PATH.read_text()),
])
@pytest.mark.parametrize(
"param_name, param_value",
[
("sql_path", SQL_PATH),
(None, SQL_PATH.read_text()),
],
)
def test_lint(param_name, param_value, profiles_dir, project_dir, sqlfluff_config_path, caplog):
params = {}
kwargs = {}
if param_name:
params[param_name] = param_value
else:
kwargs["data"] = param_value
response = client.post(
"/lint",
headers={"X-dbt-Project": "dbt_project"},
params={
param_name: param_value,
},
params=params,
**kwargs,
)
assert response.status_code == 200
response_json = response.json()
Expand All @@ -48,3 +56,33 @@ def test_lint(param_name, param_value, profiles_dir, project_dir, sqlfluff_confi
},
]
}


def test_lint_error_no_sql_provided(profiles_dir, project_dir, sqlfluff_config_path, caplog):
response = client.post(
"/lint",
headers={"X-dbt-Project": "dbt_project"},
)
assert response.status_code == 400
response_json = response.json()
assert response_json == {
"error": {
"code": 6,
"data": {},
"message": "No SQL provided. Either provide a SQL file path or a SQL string to lint.",
}
}


def test_lint_parse_failure(profiles_dir, project_dir, sqlfluff_config_path, caplog):
response = client.post(
"/lint",
headers={"X-dbt-Project": "dbt_project"},
data="""select
{{ dbt_utils.star(ref("base_cases")) }}
from {{ ref("base_cases") }}
li""",
)
assert response.status_code == 200
response_json = response.json()
assert response_json == {"result": []}

0 comments on commit cdcb911

Please sign in to comment.