Skip to content

Commit

Permalink
Revert SNOW-719900 Remove random alias for subquery in joins (#669) (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-kdama committed Mar 2, 2023
1 parent 9c64a3c commit 04ce69d
Show file tree
Hide file tree
Showing 7 changed files with 14 additions and 117 deletions.
1 change: 0 additions & 1 deletion src/snowflake/snowpark/_internal/analyzer/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,6 @@ def do_resolve_with_resolved_children(
logical_plan.join_type,
self.analyze(logical_plan.condition) if logical_plan.condition else "",
logical_plan,
self.session.use_constant_subquery_alias,
)

if isinstance(logical_plan, Sort):
Expand Down
52 changes: 9 additions & 43 deletions src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,22 +477,10 @@ def set_operator_statement(left: str, right: str, operator: str) -> str:


def left_semi_or_anti_join_statement(
left: str,
right: str,
join_type: JoinType,
condition: str,
use_constant_subquery_alias: bool,
left: str, right: str, join_type: JoinType, condition: str
) -> str:
left_alias = (
"SNOWPARK_LEFT_"
if use_constant_subquery_alias
else random_name_for_temp_object(TempObjectType.TABLE)
)
right_alias = (
"SNOWPARK_RIGHT_"
if use_constant_subquery_alias
else random_name_for_temp_object(TempObjectType.TABLE)
)
left_alias = random_name_for_temp_object(TempObjectType.TABLE)
right_alias = random_name_for_temp_object(TempObjectType.TABLE)

if isinstance(join_type, LeftSemi):
where_condition = WHERE + EXISTS
Expand Down Expand Up @@ -527,22 +515,10 @@ def left_semi_or_anti_join_statement(


def snowflake_supported_join_statement(
left: str,
right: str,
join_type: JoinType,
condition: str,
use_constant_subquery_alias: bool,
left: str, right: str, join_type: JoinType, condition: str
) -> str:
left_alias = (
"SNOWPARK_LEFT_"
if use_constant_subquery_alias
else random_name_for_temp_object(TempObjectType.TABLE)
)
right_alias = (
"SNOWPARK_RIGHT_"
if use_constant_subquery_alias
else random_name_for_temp_object(TempObjectType.TABLE)
)
left_alias = random_name_for_temp_object(TempObjectType.TABLE)
right_alias = random_name_for_temp_object(TempObjectType.TABLE)

if isinstance(join_type, UsingJoin):
join_sql = join_type.tpe.sql
Expand Down Expand Up @@ -591,24 +567,14 @@ def snowflake_supported_join_statement(
return project_statement([], source)


def join_statement(
left: str,
right: str,
join_type: JoinType,
condition: str,
use_constant_subquery_alias: bool,
) -> str:
def join_statement(left: str, right: str, join_type: JoinType, condition: str) -> str:
if isinstance(join_type, (LeftSemi, LeftAnti)):
return left_semi_or_anti_join_statement(
left, right, join_type, condition, use_constant_subquery_alias
)
return left_semi_or_anti_join_statement(left, right, join_type, condition)
if isinstance(join_type, UsingJoin) and isinstance(
join_type.tpe, (LeftSemi, LeftAnti)
):
raise ValueError(f"Unexpected using clause in {join_type.tpe} join")
return snowflake_supported_join_statement(
left, right, join_type, condition, use_constant_subquery_alias
)
return snowflake_supported_join_statement(left, right, join_type, condition)


def create_table_statement(
Expand Down
5 changes: 1 addition & 4 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,12 +485,9 @@ def join(
join_type: JoinType,
condition: str,
source_plan: Optional[LogicalPlan],
use_constant_subquery_alias: bool,
):
return self.build_binary(
lambda x, y: join_statement(
x, y, join_type, condition, use_constant_subquery_alias
),
lambda x, y: join_statement(x, y, join_type, condition),
left,
right,
source_plan,
Expand Down
19 changes: 2 additions & 17 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,10 +238,7 @@ def _create_internal(
if "paramstyle" not in self._options:
self._options["paramstyle"] = "qmark"
new_session = Session(
ServerConnection({}, conn) if conn else ServerConnection(self._options),
use_constant_subquery_alias=self._options.get(
"use_constant_subquery_alias", True
),
ServerConnection({}, conn) if conn else ServerConnection(self._options)
)
if "password" in self._options:
self._options["password"] = None
Expand All @@ -255,9 +252,7 @@ def __get__(self, obj, objtype=None):
#: and create a :class:`Session` object.
builder: SessionBuilder = SessionBuilder()

def __init__(
self, conn: ServerConnection, use_constant_subquery_alias: bool = True
) -> None:
def __init__(self, conn: ServerConnection) -> None:
if len(_active_sessions) >= 1 and is_in_stored_procedure():
raise SnowparkClientExceptionMessages.DONT_CREATE_SESSION_IN_SP()
self._conn = conn
Expand Down Expand Up @@ -293,8 +288,6 @@ def __init__(
self._sql_simplifier_enabled: bool = self._get_client_side_session_parameter(
_PYTHON_SNOWPARK_USE_SQL_SIMPLIFIER_STRING, True
)
self._use_constant_subquery_alias: bool = use_constant_subquery_alias

_logger.info("Snowpark Session information: %s", self._session_info)

def __enter__(self):
Expand Down Expand Up @@ -336,14 +329,6 @@ def close(self) -> None:
finally:
_remove_session(self)

# Accessor pair for the per-session ``use_constant_subquery_alias`` flag,
# which is stored on the instance as ``_use_constant_subquery_alias``.
@property
def use_constant_subquery_alias(self) -> bool:
# Getter: hands back the stored flag unchanged.
return self._use_constant_subquery_alias

# Setter: overwrites the stored flag with ``value``; no validation is performed.
@use_constant_subquery_alias.setter
def use_constant_subquery_alias(self, value: bool) -> None:
self._use_constant_subquery_alias = value

@property
def sql_simplifier_enabled(self) -> bool:
return self._sql_simplifier_enabled
Expand Down
31 changes: 0 additions & 31 deletions tests/integ/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2591,34 +2591,3 @@ def test_create_or_replace_view_with_multiple_queries(session):
match="Your dataframe may include DDL or DML operations",
):
df.create_or_replace_view("temp")


def test_nested_joins(session):
    """Check that chaining two joins yields the same rows for each rotation.

    Three two-row dataframes with disjoint column names are joined (with no
    explicit join condition) in the orders 1-2-3, 2-3-1 and 3-1-2. Each
    result is sorted and projected to the same column order, collected, and
    then ordered by its first column so the three row lists can be compared
    directly.
    """
    frame_ab = session.create_dataframe([[1, 2], [4, 5]], schema=["a", "b"])
    frame_cd = session.create_dataframe([[1, 3], [4, 6]], schema=["c", "d"])
    frame_ef = session.create_dataframe([[1, 4], [4, 7]], schema=["e", "f"])
    column_order = ("a", "b", "c", "d", "e", "f")

    def collect_joined(first, second, third):
        # Join the three frames in the given order, then normalise both the
        # row order and the column order before pulling the rows back.
        joined = first.join(second).join(third)
        normalised = joined.sort(*column_order).select(*column_order)
        return sorted(normalised.collect(), key=lambda row: row[0])

    rows_1_2_3 = collect_joined(frame_ab, frame_cd, frame_ef)
    rows_2_3_1 = collect_joined(frame_cd, frame_ef, frame_ab)
    rows_3_1_2 = collect_joined(frame_ef, frame_ab, frame_cd)

    assert rows_1_2_3 == rows_2_3_1 == rows_3_1_2
4 changes: 2 additions & 2 deletions tests/unit/test_analyzer_util_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,10 @@ def test_join_statement_negative():
with pytest.raises(
ValueError, match=f"Unexpected using clause in {join_type.tpe} join"
):
join_statement("", "", join_type, "", False)
join_statement("", "", join_type, "")

join_type = UsingJoin(Inner(), ["cond1"])
with pytest.raises(
ValueError, match="A join should either have using clause or a join condition"
):
join_statement("", "", join_type, "cond2", False)
join_statement("", "", join_type, "cond2")
19 changes: 0 additions & 19 deletions tests/unit/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,22 +180,3 @@ def test_create_or_replace_temp_view_bad_input():
"The input of create_or_replace_temp_view() can only a str or list of strs."
in str(exc_info)
)


@pytest.mark.parametrize(
    "join_type",
    ["inner", "leftouter", "rightouter", "fullouter", "leftsemi", "leftanti", "cross"],
)
def test_same_joins_should_generate_same_queries(join_type):
    """Building the same join twice must produce identical ``queries``.

    A session is created on top of an autospec'd ``ServerConnection`` whose
    inner connection and telemetry client are ``MagicMock`` objects. The same
    pair of dataframes is then joined twice with ``how=join_type`` and the
    ``queries`` of both results are compared.
    """
    fake_connection = mock.create_autospec(ServerConnection)
    fake_connection._conn = mock.MagicMock()
    session = snowflake.snowpark.session.Session(fake_connection)
    session._conn._telemetry_client = mock.MagicMock()

    left_frame = session.create_dataframe([[1, 1, "1"], [2, 2, "3"]]).to_df(
        ["a1", "b1", "str1"]
    )
    right_frame = session.create_dataframe([[2, 2, "2"], [3, 3, "4"]]).to_df(
        ["a2", "b2", "str2"]
    )

    # Build the identical join twice and compare the queries of both plans.
    first_queries = left_frame.join(right_frame, how=join_type).queries
    second_queries = left_frame.join(right_frame, how=join_type).queries
    assert first_queries == second_queries

0 comments on commit 04ce69d

Please sign in to comment.