Merge pull request #40 from runprism/utils
`requires_dependency` util function
prism-admin committed Jan 23, 2024
2 parents d358b09 + fe9994b commit ee25f84
Showing 14 changed files with 163 additions and 25 deletions.
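
This PR introduces a `requires_dependencies` decorator in `prism/utils.py` and applies it to the adapters' `create_engine` and `execute_sql` methods, so a missing optional package raises a descriptive `ImportError` only when the adapter is actually used. A minimal sketch of the pattern, using a hypothetical adapter method (the real applications appear in the file diffs below):

    from prism.utils import requires_dependencies

    class ExampleAdapter:
        # Hypothetical adapter; the real ones (bigquery, postgres, trino, ...) are shown below.
        @requires_dependencies("psycopg2", "postgres")
        def create_engine(self, adapter_dict, adapter_name, profile_name):
            # The decorator checks that `psycopg2` is importable before this body runs;
            # if it is not, an ImportError pointing at `pip install "prism-ds[postgres]"` is raised.
            import psycopg2
            return psycopg2.connect(**adapter_dict)  # illustrative; the real adapters build the connection from specific config keys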
2 changes: 1 addition & 1 deletion .github/workflows/ci-linux.yml
@@ -17,7 +17,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.7', '3.8', '3.9', '3.10']
python-version: ['3.8', '3.9', '3.10']
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
5 changes: 5 additions & 0 deletions prism/profiles/bigquery.py
@@ -17,6 +17,7 @@
# Prism-specific imports
from .adapter import Adapter
import prism.exceptions
from prism.utils import requires_dependencies


####################
@@ -75,6 +76,10 @@ def is_valid_config(self,
# If no exception has been raised, return True
return True

@requires_dependencies(
["google.cloud", "google.oauth2"], # noqa
"bigquery"
)
def create_engine(self,
adapter_dict: Dict[str, Any],
adapter_name: str,
26 changes: 18 additions & 8 deletions prism/profiles/dbt.py
@@ -1,9 +1,5 @@
"""
DBT adapter class definition. This definition uses source code from:
https://github.com/fal-ai/fal
Modifications are made to ensure compatibility with the rest of prism's architecture
DBT adapter class definition.
Table of Contents
- Imports
@@ -35,8 +31,12 @@
import dbt.tracking
from dbt.task.compile import CompileTask
from dbt.parser.manifest import ManifestLoader
from dbt.contracts.graph.manifest import Manifest, MaybeNonSource, Disabled
from dbt.contracts.graph.nodes import ResultNode
from dbt.contracts.graph.manifest import (
Manifest,
MaybeNonSource,
Disabled,
)
from dbt.contracts.graph.nodes import ResultNode, GraphMemberNode
from dbt.adapters.sql.impl import SQLAdapter
import dbt.adapters.factory as adapters_factory
from dbt.contracts.sql import ResultTable, RemoteRunResult
@@ -45,6 +45,7 @@
# Prism-specific imports
from .adapter import Adapter
import prism.exceptions
from prism.utils import requires_dependencies


##########################
@@ -75,6 +76,7 @@ class InitializeDbtCompileTaskArgs:
exclude: Tuple[str, str]
state: Optional[Path]
single_threaded: Optional[bool]
defer_state: bool = False


####################
@@ -284,7 +286,8 @@ def get_parsed_task_node(self,
target_model_package: Optional[str],
target_model_version: Optional[str],
project_dir: str,
manifest: Manifest
manifest: Manifest,
source_node: Optional[GraphMemberNode] = None
) -> ResultNode:
"""
Get the node associated with the inputted target task
@@ -294,6 +297,8 @@
target_package_name: package containing task
project_dir: project directory
manifest: dbt Manifest
source_node: dbt node from which we call `ref`. Since Prism exists outside
the dbt graph, this should almost always be `None`
returns:
node associated with inputted target task
"""
@@ -304,6 +309,7 @@

# TODO: test target task creation where node_package != project_dir
target_model: MaybeNonSource = manifest.resolve_ref(
source_node=source_node,
target_model_name=target_model_name,
target_model_package=target_model_package,
target_model_version=target_model_version,
@@ -467,6 +473,10 @@ def execute(self, return_type: str = "list"):
else:
return '\n'.join(execute_str)

@requires_dependencies(
"dbt",
"dbt"
)
def create_engine(self,
adapter_dict: Dict[str, Any],
adapter_name: str,
9 changes: 9 additions & 0 deletions prism/profiles/postgres.py
@@ -18,6 +18,7 @@
# Prism-specific imports
from .adapter import Adapter
import prism.exceptions
from prism.utils import requires_dependencies


####################
@@ -83,6 +84,10 @@ def is_valid_config(self,
# If no exception has been raised, return True
return True

@requires_dependencies(
"psycopg2",
"postgres"
)
def create_engine(self,
adapter_dict: Dict[str, Any],
adapter_name: str,
@@ -120,6 +125,10 @@ def create_engine(self,
conn.set_session(autocommit=True)
return conn

@requires_dependencies(
"psycopg2",
"postgres"
)
def execute_sql(self, query: str, return_type: Optional[str]) -> pd.DataFrame:
"""
Execute the SQL query
9 changes: 9 additions & 0 deletions prism/profiles/presto.py
@@ -18,6 +18,7 @@
# Prism-specific imports
from .adapter import Adapter
import prism.exceptions
from prism.utils import requires_dependencies


####################
@@ -100,6 +101,10 @@ def is_valid_config(self,
# If no exception has been raised, return True
return True

@requires_dependencies(
"prestodb",
"presto",
)
def create_engine(self,
adapter_dict: Dict[str, Any],
adapter_name: str,
@@ -162,6 +167,10 @@ def create_engine(self,

return conn

@requires_dependencies(
"prestodb",
"presto",
)
def execute_sql(self, query: str, return_type: Optional[str]) -> pd.DataFrame:
"""
Execute the SQL query
5 changes: 5 additions & 0 deletions prism/profiles/pyspark.py
@@ -16,6 +16,7 @@
# Prism-specific imports
from .adapter import Adapter
import prism.exceptions
from prism.utils import requires_dependencies


####################
@@ -90,6 +91,10 @@ def parse_adapter_dict(self,
else:
return '\n'.join(profile_exec_list)

@requires_dependencies(
"pyspark",
"pyspark",
)
def create_engine(self,
adapter_dict: Dict[str,
Any],
9 changes: 9 additions & 0 deletions prism/profiles/redshift.py
@@ -18,6 +18,7 @@
# Prism-specific imports
from .adapter import Adapter
import prism.exceptions
from prism.utils import requires_dependencies


####################
@@ -83,6 +84,10 @@ def is_valid_config(self,
# If no exception has been raised, return True
return True

@requires_dependencies(
"psycopg2",
"postgres"
)
def create_engine(self,
adapter_dict: Dict[str, Any],
adapter_name: str,
@@ -120,6 +125,10 @@ def create_engine(self,
conn.set_session(autocommit=True)
return conn

@requires_dependencies(
"psycopg2",
"postgres"
)
def execute_sql(self, query: str, return_type: Optional[str]) -> pd.DataFrame:
"""
Execute the SQL query
9 changes: 9 additions & 0 deletions prism/profiles/snowflake.py
@@ -17,6 +17,7 @@
# Prism-specific imports
from .adapter import Adapter
import prism.exceptions
from prism.utils import requires_dependencies


####################
@@ -78,6 +79,10 @@ def is_valid_config(self,
# If no exception has been raised, return True
return True

@requires_dependencies(
["snowflake.connector", "pyarrow"],
"snowflake"
)
def create_engine(self,
adapter_dict: Dict[str, Any],
adapter_name: str,
@@ -112,6 +117,10 @@ def create_engine(self,
)
return ctx

@requires_dependencies(
["snowflake.connector", "pyarrow"],
"snowflake"
)
def execute_sql(self, query: str, return_type: Optional[str]) -> pd.DataFrame:
"""
Execute the SQL query
9 changes: 9 additions & 0 deletions prism/profiles/trino.py
@@ -18,6 +18,7 @@
# Prism-specific imports
from .adapter import Adapter
import prism.exceptions
from prism.utils import requires_dependencies


####################
@@ -100,6 +101,10 @@ def is_valid_config(self,
# If no exception has been raised, return True
return True

@requires_dependencies(
"trino",
"trino",
)
def create_engine(self,
adapter_dict: Dict[str, Any],
adapter_name: str,
@@ -162,6 +167,10 @@ def create_engine(self,

return conn

@requires_dependencies(
"trino",
"trino",
)
def execute_sql(self, query: str, return_type: Optional[str]) -> pd.DataFrame:
"""
Execute the SQL query
5 changes: 3 additions & 2 deletions prism/tests/integration/test_hooks.py
@@ -157,8 +157,9 @@ def test_postgres(self):
# Check output
self.assertTrue(expected_output.is_file())
df = pd.read_csv(expected_output)
self.assertEqual(df.shape[0], 1)
self.assertEqual(df.test_col[0], 1)
self.assertEqual(df.shape[0], 10)
self.assertEqual(df["first_name"][0], "Abel")
self.assertEqual(df["last_name"][0], "Maclead")

# Remove the .compiled directory, if it exists
self._remove_compiled_dir(wkdir)
@@ -1,11 +1,11 @@
CUSTOMER_ID,FIRST_NAME,LAST_NAME,FIRST_ORDER,MOST_RECENT_ORDER,NUMBER_OF_ORDERS,CUSTOMER_LIFETIME_VALUE
1.0,Michael,P.,2018-01-01,2018-02-10,2.0,33.0
2.0,Shawn,M.,2018-01-11,2018-01-11,1.0,23.0
3.0,Kathleen,P.,2018-01-02,2018-03-11,3.0,65.0
4.0,Jimmy,C.,,,,
5.0,Katherine,R.,,,,
6.0,Sarah,R.,2018-02-19,2018-02-19,1.0,8.0
7.0,Martin,M.,2018-01-14,2018-01-14,1.0,26.0
8.0,Frank,R.,2018-01-29,2018-03-12,2.0,45.0
9.0,Jennifer,F.,2018-03-17,2018-03-17,1.0,30.0
10.0,Henry,W.,,,,
1,Michael,P.,2018-01-01,2018-02-10,2.0,33.0
2,Shawn,M.,2018-01-11,2018-01-11,1.0,23.0
3,Kathleen,P.,2018-01-02,2018-03-11,3.0,65.0
4,Jimmy,C.,,,,
5,Katherine,R.,,,,
6,Sarah,R.,2018-02-19,2018-02-19,1.0,8.0
7,Martin,M.,2018-01-14,2018-01-14,1.0,26.0
8,Frank,R.,2018-01-29,2018-03-12,2.0,45.0
9,Jennifer,F.,2018-03-17,2018-03-17,1.0,30.0
10,Henry,W.,,,,
@@ -37,7 +37,16 @@ def run(self, tasks, hooks):
returns:
task output
"""
sql = "SELECT 1 AS test_col"
sql = """
SELECT
first_name
, last_name
FROM us500
ORDER BY
first_name
, last_name
LIMIT 10
"""
df = hooks.sql(
adapter_name="postgres_base",
query=sql,
63 changes: 63 additions & 0 deletions prism/utils.py
@@ -0,0 +1,63 @@
"""
Util functions
"""

# Imports
from typing import (
Any,
Callable,
List,
Optional,
Union,
)
import importlib
from functools import wraps


# Util functions
def requires_dependencies(
dependencies: Union[str, List[str]],
extras: Optional[str] = None,
):
"""
Wrapper used to prompt the user to `pip install` a package and/or Prism extras in
order to run a function. Borrowed heavily from the `unstructured` library:
https://github.com/Unstructured-IO/unstructured/blob/main/unstructured/utils.py
args:
dependencies: required dependencies
extras: Prism extra that the user can `pip install` to get the required dependencies
"""
if isinstance(dependencies, str):
dependencies = [dependencies]

def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(func)
def wrapper(*args, **kwargs):
missing_deps: List[str] = []
for dep in dependencies:
if not dependency_exists(dep):
missing_deps.append(dep)
if len(missing_deps) > 0:
raise ImportError(
f"""Following dependencies are missing: {', '.join(["`" + dep + "`" for dep in missing_deps])}. """ # noqa
+ ( # noqa
f"""Please install them using `pip install "prism-ds[{extras}]"`.""" # noqa
if extras
else f"Please install them using `pip install {' '.join(missing_deps)}`." # noqa
),
)
return func(*args, **kwargs)

return wrapper
return decorator


def dependency_exists(dependency: str):
try:
importlib.import_module(dependency)
except ImportError as e:
# Check to make sure this isn't some unrelated import error.
if dependency in repr(e):
return False
return True
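
For context on how the wrapper behaves at call time: if any listed module cannot be imported, the decorated function never runs and an ImportError built from the names above is raised instead. A small usage sketch, assuming an environment where `psycopg2` is not installed (the decorated function here is illustrative, not part of the diff):

    from prism.utils import requires_dependencies

    @requires_dependencies("psycopg2", "postgres")
    def connect():
        ...

    connect()
    # Raises:
    # ImportError: Following dependencies are missing: `psycopg2`. Please install them
    # using `pip install "prism-ds[postgres]"`.

Note that `dependency_exists` only reports a dependency as missing when the dependency's own name appears in the ImportError, so an unrelated import failure inside the target module is not misreported as a missing optional package.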
