Keep in mind that assigning a value via {@link Session#setQueryTag(String)} will remove any + * current query tag state. + * + *
Example 1: + * + *
{@code
+ * session.setQueryTag("{\"key1\":\"value1\"}");
+ * session.updateQueryTag("{\"key2\":\"value2\"}");
+ * System.out.println(session.getQueryTag().get());
+ * {"key1":"value1","key2":"value2"}
+ * }
+ *
+ * Example 2: + * + *
{@code
+ * session.sql("ALTER SESSION SET QUERY_TAG = '{\"key1\":\"value1\"}'").collect();
+ * session.updateQueryTag("{\"key2\":\"value2\"}");
+ * System.out.println(session.getQueryTag().get());
+ * {"key1":"value1","key2":"value2"}
+ * }
+ *
+ * Example 3: + * + *
{@code
+ * session.setQueryTag("");
+ * session.updateQueryTag("{\"key1\":\"value1\"}");
+ * System.out.println(session.getQueryTag().get());
+ * {"key1":"value1"}
+ * }
+ *
+ * @param queryTag A JSON encoded string that provides updates to the current query tag.
+ * @throws SnowparkClientException If the provided query tag or the query tag of the current
+ * session are not valid JSON strings; or if it could not serialize the query tag into a JSON
+ * string.
+ * @since 1.13.0
+ */
+ public void updateQueryTag(String queryTag) throws SnowparkClientException {
+ session.updateQueryTag(queryTag);
+ }
+
/**
* Creates a new DataFrame via Generator function.
*
diff --git a/src/main/scala/com/snowflake/snowpark/Session.scala b/src/main/scala/com/snowflake/snowpark/Session.scala
index 633c8e42..2b6c580d 100644
--- a/src/main/scala/com/snowflake/snowpark/Session.scala
+++ b/src/main/scala/com/snowflake/snowpark/Session.scala
@@ -1,5 +1,8 @@
package com.snowflake.snowpark
+import com.fasterxml.jackson.databind.json.JsonMapper
+import com.fasterxml.jackson.module.scala.{ClassTagExtensions, DefaultScalaModule}
+
import java.io.{File, FileInputStream, FileNotFoundException}
import java.net.URI
import java.sql.{Connection, Date, Time, Timestamp}
@@ -26,6 +29,7 @@ import net.snowflake.client.jdbc.{SnowflakeConnectionV1, SnowflakeDriver, Snowfl
import scala.concurrent.{ExecutionContext, Future}
import scala.collection.JavaConverters._
import scala.reflect.runtime.universe.TypeTag
+import scala.util.Try
/**
*
@@ -61,6 +65,11 @@ import scala.reflect.runtime.universe.TypeTag
* @since 0.1.0
*/
class Session private (private[snowpark] val conn: ServerConnection) extends Logging {
+ private val jsonMapper = JsonMapper
+ .builder()
+ .addModule(DefaultScalaModule)
+ .build() :: ClassTagExtensions
+
private val STAGE_PREFIX = "@"
// URI and file name with md5
private val classpathURIs = new ConcurrentHashMap[URI, Option[String]]().asScala
@@ -321,6 +330,87 @@ class Session private (private[snowpark] val conn: ServerConnection) extends Log
*/
def getQueryTag(): Option[String] = this.conn.getQueryTag()
+ /**
+ * Updates the query tag that is a JSON encoded string for the current session.
+ *
+ * Keep in mind that assigning a value via [[setQueryTag]] will remove any current query tag
+ * state.
+ *
+ * Example 1:
+ * {{{
+ * session.setQueryTag("""{"key1":"value1"}""")
+ * session.updateQueryTag("""{"key2":"value2"}""")
+ * print(session.getQueryTag().get)
+ * {"key1":"value1","key2":"value2"}
+ * }}}
+ *
+ * Example 2:
+ * {{{
+ * session.sql("""ALTER SESSION SET QUERY_TAG = '{"key1":"value1"}'""").collect()
+ * session.updateQueryTag("""{"key2":"value2"}""")
+ * print(session.getQueryTag().get)
+ * {"key1":"value1","key2":"value2"}
+ * }}}
+ *
+ * Example 3:
+ * {{{
+ * session.setQueryTag("")
+ * session.updateQueryTag("""{"key1":"value1"}""")
+ * print(session.getQueryTag().get)
+ * {"key1":"value1"}
+ * }}}
+ *
+ * @param queryTag A JSON encoded string that provides updates to the current query tag.
+ * @throws SnowparkClientException If the provided query tag or the query tag of the current
+ * session are not valid JSON strings; or if it could not
+ * serialize the query tag into a JSON string.
+ * @since 1.13.0
+ */
+ def updateQueryTag(queryTag: String): Unit = synchronized {
+ val newQueryTagMap = parseJsonString(queryTag)
+ if (newQueryTagMap.isEmpty) {
+ throw ErrorMessage.MISC_INVALID_INPUT_QUERY_TAG()
+ }
+
+ var currentQueryTag = this.conn.getParameterValue("query_tag")
+ currentQueryTag = if (currentQueryTag.isEmpty) "{}" else currentQueryTag
+
+ val currentQueryTagMap = parseJsonString(currentQueryTag)
+ if (currentQueryTagMap.isEmpty) {
+ throw ErrorMessage.MISC_INVALID_CURRENT_QUERY_TAG(currentQueryTag)
+ }
+
+ val updatedQueryTagMap = currentQueryTagMap.get ++ newQueryTagMap.get
+ val updatedQueryTagStr = toJsonString(updatedQueryTagMap)
+ if (updatedQueryTagStr.isEmpty) {
+ throw ErrorMessage.MISC_FAILED_TO_SERIALIZE_QUERY_TAG()
+ }
+
+ setQueryTag(updatedQueryTagStr.get)
+ }
+
+ /**
+ * Attempts to parse a JSON-encoded string into a [[scala.collection.immutable.Map]].
+ *
+ * @param jsonString The JSON-encoded string to parse.
+ * @return An `Option` containing the `Map` if the parsing of the JSON string was
+ * successful, or `None` otherwise.
+ */
+ private def parseJsonString(jsonString: String): Option[Map[String, Any]] = {
+ Try(jsonMapper.readValue[Map[String, Any]](jsonString)).toOption
+ }
+
+ /**
+ * Attempts to convert a [[scala.collection.immutable.Map]] into a JSON-encoded string.
+ *
+ * @param map The `Map` to convert.
+ * @return An `Option` containing the JSON-encoded string if the conversion was successful,
+ * or `None` otherwise.
+ */
+ private def toJsonString(map: Map[String, Any]): Option[String] = {
+ Try(jsonMapper.writeValueAsString(map)).toOption
+ }
+
/*
* Checks that the latest version of all jar dependencies is
* uploaded to a stage and returns the staged URLs
diff --git a/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala b/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala
index e0df3d6b..ea14da1e 100644
--- a/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala
+++ b/src/main/scala/com/snowflake/snowpark/internal/ErrorMessage.scala
@@ -159,7 +159,10 @@ private[snowpark] object ErrorMessage {
"""Invalid input argument type, the input argument type of Explode function should be either Map or Array types.
|The input argument type: %s
|""".stripMargin,
- "0425" -> "Unsupported Geometry output format: %s. Please set session parameter GEOMETRY_OUTPUT_FORMAT to GeoJSON.")
+ "0425" -> "Unsupported Geometry output format: %s. Please set session parameter GEOMETRY_OUTPUT_FORMAT to GeoJSON.",
+ "0426" -> "The given query tag must be a valid JSON string. Ensure it's correctly formatted as JSON.",
+ "0427" -> "The query tag of the current session must be a valid JSON string. Current query tag: %s",
+ "0428" -> "Failed to serialize the query tag into a JSON string.")
// scalastyle:on
/*
@@ -409,6 +412,15 @@ private[snowpark] object ErrorMessage {
def MISC_UNSUPPORTED_GEOMETRY_FORMAT(typeName: String): SnowparkClientException =
createException("0425", typeName)
+ def MISC_INVALID_INPUT_QUERY_TAG(): SnowparkClientException =
+ createException("0426")
+
+ def MISC_INVALID_CURRENT_QUERY_TAG(currentQueryTag: String): SnowparkClientException =
+ createException("0427", currentQueryTag)
+
+ def MISC_FAILED_TO_SERIALIZE_QUERY_TAG(): SnowparkClientException =
+ createException("0428")
+
/**
* Create Snowpark client Exception.
*
diff --git a/src/main/scala/com/snowflake/snowpark/internal/UDFClassPath.scala b/src/main/scala/com/snowflake/snowpark/internal/UDFClassPath.scala
index aedbc8a6..31c8200f 100644
--- a/src/main/scala/com/snowflake/snowpark/internal/UDFClassPath.scala
+++ b/src/main/scala/com/snowflake/snowpark/internal/UDFClassPath.scala
@@ -3,6 +3,7 @@ package com.snowflake.snowpark.internal
import com.fasterxml.jackson.annotation.JsonView
import com.fasterxml.jackson.core.TreeNode
import com.fasterxml.jackson.databind.JsonNode
+import com.fasterxml.jackson.module.scala.DefaultScalaModule
import java.io.File
import java.net.{URI, URLClassLoader}
@@ -24,6 +25,8 @@ object UDFClassPath extends Logging {
val jacksonDatabindClass: Class[JsonNode] = classOf[com.fasterxml.jackson.databind.JsonNode]
val jacksonCoreClass: Class[TreeNode] = classOf[com.fasterxml.jackson.core.TreeNode]
val jacksonAnnotationClass: Class[JsonView] = classOf[com.fasterxml.jackson.annotation.JsonView]
+ val jacksonModuleScalaClass: Class[DefaultScalaModule] =
+ classOf[com.fasterxml.jackson.module.scala.DefaultScalaModule]
val jacksonJarSeq = Seq(
RequiredLibrary(
getPathForClass(jacksonDatabindClass),
@@ -33,7 +36,11 @@ object UDFClassPath extends Logging {
RequiredLibrary(
getPathForClass(jacksonAnnotationClass),
"jackson-annotation",
- jacksonAnnotationClass))
+ jacksonAnnotationClass),
+ RequiredLibrary(
+ getPathForClass(jacksonModuleScalaClass),
+ "jackson-module-scala",
+ jacksonModuleScalaClass))
/*
* Libraries required to compile java code generated by snowpark for user's lambda.
diff --git a/src/test/java/com/snowflake/snowpark_test/JavaSessionNonStoredProcSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaSessionNonStoredProcSuite.java
index e4062ebc..e651c948 100644
--- a/src/test/java/com/snowflake/snowpark_test/JavaSessionNonStoredProcSuite.java
+++ b/src/test/java/com/snowflake/snowpark_test/JavaSessionNonStoredProcSuite.java
@@ -2,6 +2,8 @@
// to make sure all API can be accessed from public
package com.snowflake.snowpark_test;
+import static org.junit.Assert.assertThrows;
+
import com.snowflake.snowpark.SnowparkClientException;
import com.snowflake.snowpark.TestUtils;
import com.snowflake.snowpark_java.*;
@@ -56,6 +58,84 @@ public void tags() {
assert !getSession().getQueryTag().isPresent();
}
+ @Test
+ public void updateQueryTagAddNewKeyValuePairs() {
+ String queryTag1 = "{\"key1\":\"value1\"}";
+ getSession().setQueryTag(queryTag1);
+
+ String queryTag2 = "{\"key2\":\"value2\",\"key3\":{\"key4\":0},\"key5\":{\"key6\":\"value6\"}}";
+ getSession().updateQueryTag(queryTag2);
+
+ String expected =
+ "{\"key1\":\"value1\",\"key2\":\"value2\",\"key3\":{\"key4\":0},\"key5\":{\"key6\":\"value6\"}}";
+ assert getSession().getQueryTag().isPresent();
+ assert getSession().getQueryTag().get().equals(expected);
+ }
+
+ @Test
+ public void updateQueryTagUpdateKeyValuePairs() {
+ String queryTag1 = "{\"key1\":\"value1\",\"key2\":\"value2\",\"key3\":\"value3\"}";
+ getSession().setQueryTag(queryTag1);
+
+ String queryTag2 = "{\"key2\":\"newValue2\"}";
+ getSession().updateQueryTag(queryTag2);
+
+ String expected = "{\"key1\":\"value1\",\"key2\":\"newValue2\",\"key3\":\"value3\"}";
+ assert getSession().getQueryTag().isPresent();
+ assert getSession().getQueryTag().get().equals(expected);
+ }
+
+ @Test
+ public void updateQueryTagEmptySessionQueryTag() {
+ getSession().setQueryTag("");
+
+ String queryTag = "{\"key1\":\"value1\"}";
+ getSession().updateQueryTag(queryTag);
+
+ assert getSession().getQueryTag().isPresent();
+ assert getSession().getQueryTag().get().equals(queryTag);
+ }
+
+ @Test
+ public void updateQueryTagInvalidInputQueryTag() {
+ String queryTag = "tag1";
+
+ SnowparkClientException exception =
+ assertThrows(SnowparkClientException.class, () -> getSession().updateQueryTag(queryTag));
+ assert exception
+ .getMessage()
+ .equals(
+ "Error Code: 0426, Error message: "
+ + "The given query tag must be a valid JSON string. Ensure it's correctly formatted as JSON.");
+ }
+
+ @Test
+ public void updateQueryTagInvalidSessionQueryTag() {
+ String queryTag1 = "tag1";
+ getSession().setQueryTag(queryTag1);
+
+ String queryTag2 = "{\"key1\":\"value1\"}";
+ SnowparkClientException exception =
+ assertThrows(SnowparkClientException.class, () -> getSession().updateQueryTag(queryTag2));
+ assert exception
+ .getMessage()
+ .equals(
+ "Error Code: 0427, Error message: "
+ + "The query tag of the current session must be a valid JSON string. Current query tag: tag1");
+ }
+
+ @Test
+ public void updateQueryTagFromAlterSession() {
+ getSession().sql("ALTER SESSION SET QUERY_TAG = '{\"key1\":\"value1\"}'").collect();
+
+ String queryTag2 = "{\"key2\":\"value2\"}";
+ getSession().updateQueryTag(queryTag2);
+
+ String expected = "{\"key1\":\"value1\",\"key2\":\"value2\"}";
+ assert getSession().getQueryTag().isPresent();
+ assert getSession().getQueryTag().get().equals(expected);
+ }
+
@Test
public void dbAndSchema() {
assert getSession()
diff --git a/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala b/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala
index 937b93e6..0ad6d802 100644
--- a/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala
@@ -859,4 +859,31 @@ class ErrorMessageSuite extends FunSuite {
"Unsupported Geometry output format: KWT." +
" Please set session parameter GEOMETRY_OUTPUT_FORMAT to GeoJSON."))
}
+
+ test("MISC_INVALID_INPUT_QUERY_TAG") {
+ val ex = ErrorMessage.MISC_INVALID_INPUT_QUERY_TAG()
+ assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0426")))
+ assert(
+ ex.message.startsWith(
+ "Error Code: 0426, Error message: " +
+ "The given query tag must be a valid JSON string. " +
+ "Ensure it's correctly formatted as JSON."))
+ }
+
+ test("MISC_INVALID_CURRENT_QUERY_TAG") {
+ val ex = ErrorMessage.MISC_INVALID_CURRENT_QUERY_TAG("myTag")
+ assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0427")))
+ assert(
+ ex.message.startsWith(
+ "Error Code: 0427, Error message: The query tag of the current session " +
+ "must be a valid JSON string. Current query tag: myTag"))
+ }
+
+ test("MISC_FAILED_TO_SERIALIZE_QUERY_TAG") {
+ val ex = ErrorMessage.MISC_FAILED_TO_SERIALIZE_QUERY_TAG()
+ assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0428")))
+ assert(
+ ex.message.startsWith(
+ "Error Code: 0428, Error message: Failed to serialize the query tag into a JSON string."))
+ }
}
diff --git a/src/test/scala/com/snowflake/snowpark_test/SessionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/SessionSuite.scala
index 9b82ba7b..3a4a2dbc 100644
--- a/src/test/scala/com/snowflake/snowpark_test/SessionSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/SessionSuite.scala
@@ -239,6 +239,75 @@ class SessionSuite extends SNTestBase {
assert(getParameterValue("query_tag", session) == queryTag2)
}
+ test("updateQueryTag when adding new key-value pairs") {
+ val queryTag1 = """{"key1":"value1"}"""
+ session.setQueryTag(queryTag1)
+
+ val queryTag2 = """{"key2":"value2","key3":{"key4":0},"key5":{"key6":"value6"}}"""
+ session.updateQueryTag(queryTag2)
+
+ val expected = {
+ """{"key1":"value1","key2":"value2","key3":{"key4":0},"key5":{"key6":"value6"}}"""
+ }
+ val actual = getParameterValue("query_tag", session)
+ assert(actual == expected)
+ }
+
+ test("updateQueryTag when updating an existing key-value pair") {
+ val queryTag1 = """{"key1":"value1","key2":"value2","key3":"value3"}"""
+ session.setQueryTag(queryTag1)
+
+ val queryTag2 = """{"key2":"newValue2"}"""
+ session.updateQueryTag(queryTag2)
+
+ val expected = """{"key1":"value1","key2":"newValue2","key3":"value3"}"""
+ val actual = getParameterValue("query_tag", session)
+ assert(actual == expected)
+ }
+
+ test("updateQueryTag when the query tag of the current session is empty") {
+ session.setQueryTag("")
+
+ val queryTag = """{"key1":"value1"}"""
+ session.updateQueryTag(queryTag)
+
+ val actual = getParameterValue("query_tag", session)
+ assert(actual == queryTag)
+ }
+
+ test("updateQueryTag when the given query tag is not a valid JSON") {
+ val queryTag = "tag1"
+ val exception = intercept[SnowparkClientException](session.updateQueryTag(queryTag))
+ assert(
+ exception.message.startsWith(
+ "Error Code: 0426, Error message: The given query tag must be a valid JSON string. " +
+ "Ensure it's correctly formatted as JSON."))
+ }
+
+ test("updateQueryTag when the query tag of the current session is not a valid JSON") {
+ val queryTag1 = "tag1"
+ session.setQueryTag(queryTag1)
+
+ val queryTag2 = """{"key1":"value1"}"""
+ val exception = intercept[SnowparkClientException](session.updateQueryTag(queryTag2))
+ assert(
+ exception.message.startsWith(
+ "Error Code: 0427, Error message: The query tag of the current session must be a valid " +
+ "JSON string. Current query tag: tag1"))
+ }
+
+ test("updateQueryTag when the query tag of the current session is set with an ALTER SESSION") {
+ val queryTag1 = """{"key1":"value1"}"""
+ session.sql(s"ALTER SESSION SET QUERY_TAG = '$queryTag1'").collect()
+
+ val queryTag2 = """{"key2":"value2"}"""
+ session.updateQueryTag(queryTag2)
+
+ val expected = """{"key1":"value1","key2":"value2"}"""
+ val actual = getParameterValue("query_tag", session)
+ assert(actual == expected)
+ }
+
test("Multiple queries test for query tags") {
val queryTag = randomName()
session.setQueryTag(queryTag)