Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 40 additions & 9 deletions src/main/java/com/snowflake/snowpark_java/CaseExpr.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,56 @@ public class CaseExpr extends Column {
}

/**
* Appends one more WHEN condition to the CASE expression.
* Appends one more WHEN condition to the CASE expression. This method handles any literal value
* and converts it into a `Column` if applies.
*
* <p><b>Example:</b>
*
* <pre>{@code
* Column result = when(col("age").lt(lit(18)), "Minor")
* .when(col("age").lt(lit(65)), "Adult")
* .otherwise("Senior");
* }</pre>
*
* @since 0.12.0
* @param condition The case condition
* @param value The result value in the given condition
* @return The result case expression
* @since 0.12.0
*/
public CaseExpr when(Column condition, Column value) {
return new CaseExpr(caseExpr.when(condition.toScalaColumn(), value.toScalaColumn()));
public CaseExpr when(Column condition, Object value) {
return new CaseExpr(caseExpr.when(condition.toScalaColumn(), toExpr(value).toScalaColumn()));
}

/**
* Sets the default result for this CASE expression.
* Sets the default result for this CASE expression. This method handles any literal value and
* converts it into a `Column` if applies.
*
* <p><b>Example:</b>
*
* <pre>{@code
* Column result = when(col("state").equal(lit("CA")), lit(1000))
* .when(col("state").equal(lit("NY")), lit(2000))
* .otherwise(1000);
* }</pre>
*
* @param value The default value, which can be any literal (e.g., String, int, boolean) or a
* `Column`.
* @return The result column.
* @since 0.12.0
* @param value The default value
* @return The result column
*/
public Column otherwise(Column value) {
return new Column(caseExpr.otherwise(value.toScalaColumn()));
public Column otherwise(Object value) {
return new Column(caseExpr.otherwise(toExpr(value).toScalaColumn()));
}

/**
* Converts any value to an Expression. If the value is already a Column, uses its expression
* directly. Otherwise, wraps it with lit() to create a Column expression.
*/
private Column toExpr(Object exp) {
if (exp instanceof Column) {
return ((Column) exp);
}

return Functions.lit(exp);
}
}
23 changes: 18 additions & 5 deletions src/main/java/com/snowflake/snowpark_java/Functions.java
Original file line number Diff line number Diff line change
Expand Up @@ -4325,19 +4325,20 @@ public static Column get_path(Column col, Column path) {
* <pre>{@code
* import com.snowflake.snowpark_java.Functions;
* df.select(Functions
* .when(df.col("col").is_null, Functions.lit(1))
* .when(df.col("col").equal_to(Functions.lit(1)), Functions.lit(6))
* .when(df.col("col").is_null, 1)
* .when(df.col("col").equal_to(Functions.lit(1)), 6)
* .otherwise(Functions.lit(7)));
* }</pre>
*
* @since 0.12.0
* @param condition The condition
* @param value The result value
* @return The result column
* @since 0.12.0
*/
public static CaseExpr when(Column condition, Column value) {
public static CaseExpr when(Column condition, Object value) {
return new CaseExpr(
com.snowflake.snowpark.functions.when(condition.toScalaColumn(), value.toScalaColumn()));
com.snowflake.snowpark.functions.when(
condition.toScalaColumn(), toExpr(value).toScalaColumn()));
}

/**
Expand Down Expand Up @@ -6038,4 +6039,16 @@ private static UserDefinedFunction userDefinedFunction(
String funcName, Supplier<UserDefinedFunction> func) {
return javaUDF("Functions", funcName, "", "", func);
}

/**
* Converts any value to an Expression. If the value is already a Column, uses its expression
* directly. Otherwise, wraps it with lit() to create a Column expression.
*/
private static Column toExpr(Object exp) {
if (exp instanceof Column) {
return ((Column) exp);
}

return Functions.lit(exp);
}
}
55 changes: 50 additions & 5 deletions src/main/scala/com/snowflake/snowpark/Column.scala
Original file line number Diff line number Diff line change
Expand Up @@ -764,24 +764,69 @@ class CaseExpr private[snowpark] (branches: Seq[(Expression, Expression)])
/**
* Appends one more WHEN condition to the CASE expression.
*
* This method handles any literal value and converts it into a `Column`.
*
* ===Example===
* {{{
* val df = session.sql("SELECT * FROM values (10), (25), (65), (70) as T(age)")
* val result = df.select(
* when(col("age") < lit(18), "Minor")
* .when(col("age") < lit(65), lit("Adult"))
* .otherwise("Senior")
* )
* // The second when condition will be "Adult" for rows where age >= 18 and age < 65
* }}}
*
* @param condition
* The case condition.
* @param value
* The result value, which can be any literal (e.g., String, Int, Boolean) or a `Column`.
* @return
* The result case expression.
* @since 0.2.0
*/
def when(condition: Column, value: Column): CaseExpr =
new CaseExpr(branches :+ ((condition.expr, value.expr)))
def when(condition: Column, value: Any): CaseExpr =
new CaseExpr(branches :+ (condition.expr, toExpr(value)))

/**
* Sets the default result for this CASE expression.
*
* This method handles any literal value and converts it into a `Column` using `lit()`.
*
* ===Example===
* {{{
* val df = session.sql("SELECT * FROM values (10), (25), (65), (70) as T(age)")
* val result = df.select(
* when(col("age") < lit(18), "Minor")
* .when(col("age") < lit(65), lit("Adult"))
* .otherwise("Senior")
* )
* // The age_category column will be "Senior" for rows where age >= 65
* }}}
*
* @param value
* The default value, which can be any literal (e.g., String, Int, Boolean) or a `Column`.
* @return
* The result column.
* @since 0.2.0
*/
def otherwise(value: Column): Column = withExpr {
CaseWhen(branches, Option(value.expr))
def otherwise(value: Any): Column = withExpr {
CaseWhen(branches, Option(toExpr(value)))
}

/**
* Sets the default result for this CASE expression. Alias for [[otherwise]].
*
* @since 0.2.0
*/
def `else`(value: Column): Column = otherwise(value)
def `else`(value: Any): Column = otherwise(value)

/**
* Converts any value to an Expression. If the value is already a Column, uses its expression
* directly. Otherwise, wraps it with lit() to create a Column expression.
*/
private def toExpr(exp: Any) = exp match {
case c: Column => c.expr
case _ => lit(exp).expr
}
}
36 changes: 27 additions & 9 deletions src/main/scala/com/snowflake/snowpark/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3594,21 +3594,31 @@ object functions {
* Works like a cascading if-then-else statement. A series of conditions are evaluated in
* sequence. When a condition evaluates to TRUE, the evaluation stops and the associated result
* (after THEN) is returned. If none of the conditions evaluate to TRUE, then the result after the
* optional OTHERWISE is returned, if present; otherwise NULL is returned. For Example:
* optional OTHERWISE is returned, if present; otherwise NULL is returned.
*
* ===Example===
* {{{
* import functions._
* df.select(
* when(col("col").is_null, lit(1))
* .when(col("col") === 1, lit(2))
* .otherwise(lit(3))
* )
* import functions._
* val df = session.sql("SELECT * FROM values (null, 5), (1, 10), (2, 15) as T(col, numeric_col)")
* val result = df.select(
* when(col("col").is_null, lit(1))
* .when(col("col") === 1, lit(2))
* .when(col("col") === 1, col("numeric_col") * 0.10)
* .otherwise(lit(3))
* )
* }}}
*
* @param condition
* The case condition.
* @param value
* The result value, which can be any literal (e.g., String, Int, Boolean) or a `Column`.
* @return
* The result case expression.
* @group con_func
* @since 0.2.0
*/
def when(condition: Column, value: Column): CaseExpr =
new CaseExpr(Seq((condition.expr, value.expr)))
def when(condition: Column, value: Any): CaseExpr =
new CaseExpr(Seq((condition.expr, toExpr(value))))

/**
* Returns one of two specified expressions, depending on a condition.
Expand Down Expand Up @@ -5672,4 +5682,12 @@ object functions {
"")(func)
}

/**
* Converts any value to an Expression. If the value is already a Column, uses its expression
* directly. Otherwise, wraps it with lit() to create a Column expression.
*/
private def toExpr(exp: Any) = exp match {
case c: Column => c.expr
case _ => lit(exp).expr
}
}
24 changes: 24 additions & 0 deletions src/test/java/com/snowflake/snowpark_test/JavaColumnSuite.java
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,30 @@ public void caseWhen() {
Row.create((Object) null),
Row.create(5)
});

// Handling no column type values
checkAnswer(
df.select(
Functions.when(df.col("a").is_null(), 5)
.when(df.col("a").equal_to(Functions.lit(1)), 6)
.otherwise(7)
.as("a")),
new Row[] {Row.create(5), Row.create(7), Row.create(6), Row.create(7), Row.create(5)});

// Handling null values
checkAnswer(
df.select(
Functions.when(df.col("a").is_null(), null)
.when(df.col("a").equal_to(Functions.lit(1)), null)
.otherwise(null)
.as("a")),
new Row[] {
Row.create((Object) null),
Row.create((Object) null),
Row.create((Object) null),
Row.create((Object) null),
Row.create((Object) null)
});
}

@Test
Expand Down
45 changes: 45 additions & 0 deletions src/test/scala/com/snowflake/snowpark_test/ColumnSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,51 @@ class ColumnSuite extends TestData {
.as("a")),
Seq(Row(5), Row(7), Row(6), Row(7), Row(5)))

// no column typed value
checkAnswer(
nullData1.select(
functions
.when(col("a").is_null, lit(1))
.when(col("a") === 1, col("a") / 2)
.when(col("a") === 2, col("a") * 2)
.when(col("a") === 3, pow(col("a"), 2))
.as("a")),
Seq(Row(0.5), Row(1.0), Row(1.0), Row(4.0), Row(9.0)))

checkAnswer(
nullData1.select(
functions
.when(col("a").is_null, "null_value")
.when(col("a") <= 2, "lower or equal than two")
.when(col("a") >= 3, "greater than two")
.as("a")),
Seq(
Row("greater than two"),
Row("lower or equal than two"),
Row("lower or equal than two"),
Row("null_value"),
Row("null_value")))

// No column otherwise
checkAnswer(
nullData1.select(
functions
.when(col("a").is_null, lit(5))
.when(col("a") === 1, lit(6))
.otherwise(7)
.as("a")),
Seq(Row(5), Row(7), Row(6), Row(7), Row(5)))

// Handling nulls
checkAnswer(
nullData1.select(
functions
.when(col("a").is_null, null)
.when(col("a") === 1, null)
.otherwise(null)
.as("a")),
Seq(Row(null), Row(null), Row(null), Row(null), Row(null)))

// empty otherwise
checkAnswer(
nullData1.select(
Expand Down
Loading