Skip to content

Commit

Permalink
fix: Correct minor test issues
Browse files Browse the repository at this point in the history
  • Loading branch information
tomasfarias committed Apr 29, 2022
1 parent eec0b1c commit 6f3e009
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 20 deletions.
7 changes: 5 additions & 2 deletions airflow_dbt_python/hooks/backends/s3.py
@@ -1,6 +1,7 @@
"""An implementation for an S3 backend for dbt."""
from __future__ import annotations

import os
from pathlib import Path
from typing import TYPE_CHECKING, Optional
from zipfile import ZipFile
Expand Down Expand Up @@ -139,7 +140,9 @@ def push_many(
if _file.is_dir():
continue

s3_key = f"s3://{bucket_name}/{key}{ _file.relative_to(source)}"
s3_key = os.path.join(
f"s3://{bucket_name}/{key}", str(_file.relative_to(source))
)

self.load_file_handle_replace_error(
_file,
Expand Down Expand Up @@ -239,7 +242,7 @@ def load_file_handle_replace_error(
self.log.info("Loading file %s to S3: %s", file_path, key)
try:
self.hook.load_file(
file_path,
str(file_path),
key,
bucket_name=bucket_name,
replace=replace,
Expand Down
21 changes: 15 additions & 6 deletions airflow_dbt_python/hooks/dbt.py
Expand Up @@ -43,6 +43,7 @@
from dbt.tracking import initialize_from_flags

from airflow.exceptions import AirflowException
from airflow.version import version as airflow_version

try:
from airflow.hooks.base import BaseHook
Expand Down Expand Up @@ -389,10 +390,6 @@ def create_dbt_profile(
if self.profiles_dir is not None:
raw_profiles = read_profile(self.profiles_dir)
else:
profiles_path = Path.home() / ".dbt/profiles.yml"
if not profiles_path.exists():
profiles_path.parent.mkdir(exist_ok=True)
profiles_path.touch()
raw_profiles = {}

if extra_targets:
Expand Down Expand Up @@ -657,6 +654,8 @@ class DbtHook(BaseHook):

def __init__(self, *args, **kwargs):
self.backends: dict[tuple[str, Optional[str]], DbtBackend] = {}
if airflow_version.split()[0] == "1":
kwargs["source"] = None
super().__init__(*args, **kwargs)

def get_backend(self, scheme: str, conn_id: Optional[str]) -> DbtBackend:
Expand Down Expand Up @@ -741,6 +740,7 @@ def run_dbt_task(self, config: BaseConfig) -> tuple[bool, Optional[RunResult]]:

config.dbt_task.pre_init_hook(config)
task, runtime_config = config.create_dbt_task(extra_target)
self.ensure_profiles(config.profiles_dir)

# When creating tasks via from_args, dbt switches to the project directory.
# We have to do that here as we are not using from_args.
Expand All @@ -749,11 +749,9 @@ def run_dbt_task(self, config: BaseConfig) -> tuple[bool, Optional[RunResult]]:
if not isinstance(runtime_config, (UnsetProfileConfig, type(None))):
# The deps command installs the dependencies, which means they may not exist
# before deps runs and the following would raise a CompilationError.
print(runtime_config.args)
runtime_config.load_dependencies()

results = None

with adapter_management():
if not isinstance(runtime_config, (UnsetProfileConfig, type(None))):
register_adapter(runtime_config)
Expand All @@ -764,6 +762,17 @@ def run_dbt_task(self, config: BaseConfig) -> tuple[bool, Optional[RunResult]]:

return success, results

def ensure_profiles(self, profiles_dir: Optional[str]):
"""Ensure a profiles file exists."""
if profiles_dir is not None:
# We expect one to exist given that we have passsed a profiles_dir.
return

profiles_path = Path.home() / ".dbt/profiles.yml"
if not profiles_path.exists():
profiles_path.parent.mkdir(exist_ok=True)
profiles_path.touch()

def get_target_from_connection(
self, target: Optional[str]
) -> Optional[dict[str, Any]]:
Expand Down
10 changes: 5 additions & 5 deletions tests/hooks/dbt/backends/test_dbt_s3_backend.py
Expand Up @@ -266,12 +266,9 @@ def test_push_dbt_project_to_zip_file(s3_bucket, s3_hook, tmpdir, test_files):
# Ensure zip file is not already present.
s3_hook.delete_objects(
s3_bucket,
[zip_s3_key],
)
key = s3_hook.check_for_key(
zip_s3_key,
s3_bucket,
"project/project.zip",
)
key = s3_hook.check_for_key(zip_s3_key)
assert key is False

backend = DbtS3Backend()
Expand All @@ -289,6 +286,9 @@ def test_push_dbt_project_to_zip_file(s3_bucket, s3_hook, tmpdir, test_files):

def test_push_dbt_project_to_files(s3_bucket, s3_hook, tmpdir, test_files):
"""Test pushing a dbt project to a S3 path."""
keys = s3_hook.list_keys(bucket_name=s3_bucket)
assert len(keys) == 0

prefix = f"s3://{s3_bucket}/project/"

backend = DbtS3Backend()
Expand Down
2 changes: 1 addition & 1 deletion tests/hooks/dbt/test_dbt_hook_base.py
Expand Up @@ -109,7 +109,7 @@ def test_dbt_hook_get_target_from_connection(airflow_conns, database):
assert extra_target[conn_id]["dbname"] == database.dbname


@pytest.mark.parametrize("conn_id", [("non_existent",), (None,)])
@pytest.mark.parametrize("conn_id", ["non_existent", None])
def test_dbt_hook_get_target_from_connection_non_existent(conn_id):
"""Test None is returned when Airflow connections do not exist."""
hook = DbtHook()
Expand Down
12 changes: 6 additions & 6 deletions tests/operators/test_dbt_deps.py
Expand Up @@ -108,7 +108,7 @@ def test_dbt_deps_push_to_s3(
# Ensure we are working with an empty dbt_packages dir in S3.
keys = s3_hook.list_keys(
s3_bucket,
f"s3://{s3_bucket}/project/dbt_packages/",
"project/dbt_packages/",
)
if keys is not None and len(keys) > 0:
s3_hook.delete_objects(
Expand All @@ -117,7 +117,7 @@ def test_dbt_deps_push_to_s3(
)
keys = s3_hook.list_keys(
s3_bucket,
f"s3://{s3_bucket}/project/dbt_packages/",
"project/dbt_packages/",
)
assert keys is None or len(keys) == 0

Expand All @@ -139,7 +139,7 @@ def test_dbt_deps_push_to_s3(

keys = s3_hook.list_keys(
s3_bucket,
f"s3://{s3_bucket}/project/dbt_packages/",
f"project/dbt_packages/",
)
assert len(keys) >= 0
# dbt_utils files may be anything, let's just check that at least
Expand Down Expand Up @@ -229,7 +229,7 @@ def test_dbt_deps_push_to_s3_with_no_replace(
# Ensure we are working with an empty dbt_packages dir in S3.
keys = s3_hook.list_keys(
s3_bucket,
f"s3://{s3_bucket}/project/dbt_packages/",
f"project/dbt_packages/",
)
if keys is not None and len(keys) > 0:
s3_hook.delete_objects(
Expand All @@ -238,7 +238,7 @@ def test_dbt_deps_push_to_s3_with_no_replace(
)
keys = s3_hook.list_keys(
s3_bucket,
f"s3://{s3_bucket}/project/dbt_packages/",
f"project/dbt_packages/",
)
assert keys is None or len(keys) == 0

Expand All @@ -256,7 +256,7 @@ def test_dbt_deps_push_to_s3_with_no_replace(

keys = s3_hook.list_keys(
s3_bucket,
f"s3://{s3_bucket}/project/dbt_packages/",
f"project/dbt_packages/",
)
assert len(keys) >= 0
# dbt_utils files may be anything, let's just check that at least
Expand Down

0 comments on commit 6f3e009

Please sign in to comment.