Skip to content

Commit

Permalink
Make Snowpark dependency truly optional for SnowflakeConnection (stre…
Browse files Browse the repository at this point in the history
  • Loading branch information
vdonato authored and zyxue committed Apr 16, 2024
1 parent 331463f commit 6299108
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 7 deletions.
5 changes: 4 additions & 1 deletion lib/streamlit/connections/snowflake_connection.py
Expand Up @@ -52,11 +52,14 @@ class SnowflakeConnection(BaseConnection["InternalSnowflakeConnection"]):
def _connect(self, **kwargs) -> "InternalSnowflakeConnection":
import snowflake.connector # type:ignore[import]
from snowflake.connector import Error as SnowflakeError # type:ignore[import]
from snowflake.snowpark.context import get_active_session # type:ignore[import]

# If we're running in SiS, just call get_active_session() and retrieve the
# lower-level connection from it.
if running_in_sis():
from snowflake.snowpark.context import ( # type:ignore[import] # isort: skip
get_active_session,
)

session = get_active_session()

if hasattr(session, "connection"):
Expand Down
13 changes: 8 additions & 5 deletions lib/streamlit/connections/util.py
Expand Up @@ -81,8 +81,11 @@ def load_from_snowsql_config_file(connection_name: str) -> Dict[str, Any]:

def running_in_sis() -> bool:
"""Return whether this app is running in SiS."""
from snowflake.snowpark._internal.utils import ( # type: ignore[import] # isort: skip
is_in_stored_procedure,
)

return cast(bool, is_in_stored_procedure())
try:
from snowflake.snowpark._internal.utils import ( # type: ignore[import] # isort: skip
is_in_stored_procedure,
)

return cast(bool, is_in_stored_procedure())
except ModuleNotFoundError:
return False
9 changes: 8 additions & 1 deletion lib/tests/streamlit/connections/util_test.py
Expand Up @@ -36,7 +36,6 @@ def test_extract_from_dict(self):
assert extracted == {"k1": "v1", "k2": "v2"}
assert d == {"k3": "v3", "k4": "v4"}

@pytest.mark.require_snowflake
def test_not_running_in_sis(self):
assert not running_in_sis()

Expand All @@ -48,6 +47,14 @@ def test_not_running_in_sis(self):
def test_running_in_sis(self):
assert running_in_sis()

@pytest.mark.require_snowflake
@patch(
"snowflake.snowpark._internal.utils.is_in_stored_procedure",
MagicMock(side_effect=ModuleNotFoundError("oh no")),
)
def test_running_in_sis_module_not_found_error(self):
assert not running_in_sis()

def test_load_from_snowsql_config_file_no_file(self):
assert load_from_snowsql_config_file("my_snowpark_connection") == {}

Expand Down

0 comments on commit 6299108

Please sign in to comment.