Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions fips-pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
<scalaPluginVersion>4.3.0</scalaPluginVersion>
<jackson.version>2.13.2</jackson.version>
<jackson.databind.version>2.13.4.2</jackson.databind.version>
<jackson.module.scala.version>2.13.5</jackson.module.scala.version>
</properties>

<dependencyManagement>
Expand Down Expand Up @@ -144,6 +145,11 @@
<artifactId>jackson-annotations</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.module</groupId>
<artifactId>jackson-module-scala_2.12</artifactId>
<version>${jackson.module.scala.version}</version>
</dependency>

<!-- Test -->
<dependency>
Expand Down
7 changes: 7 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
<scalaPluginVersion>4.3.0</scalaPluginVersion>
<jackson.version>2.13.2</jackson.version>
<jackson.databind.version>2.13.4.2</jackson.databind.version>
<jackson.module.scala.version>2.13.5</jackson.module.scala.version>

</properties>
<dependencyManagement>
<dependencies>
Expand Down Expand Up @@ -131,6 +133,11 @@
<artifactId>jackson-annotations</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.module</groupId>
<artifactId>jackson-module-scala_2.12</artifactId>
<version>${jackson.module.scala.version}</version>
</dependency>

<!-- Test -->
<dependency>
Expand Down
44 changes: 44 additions & 0 deletions src/main/java/com/snowflake/snowpark_java/Session.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.snowflake.snowpark_java;

import com.snowflake.snowpark.PublicPreview;
import com.snowflake.snowpark.SnowparkClientException;
import com.snowflake.snowpark.internal.JavaUtils;
import com.snowflake.snowpark_java.types.InternalUtils;
import com.snowflake.snowpark_java.types.StructType;
Expand Down Expand Up @@ -297,6 +298,49 @@ public void unsetQueryTag() {
session.unsetQueryTag();
}

/**
* Updates the query tag that is a JSON encoded string for the current session.
*
* <p>Keep in mind that assigning a value via {@link Session#setQueryTag(String)} will remove any
* current query tag state.
*
* <p>Example 1:
*
* <pre>{@code
* session.setQueryTag("{\"key1\":\"value1\"}");
* session.updateQueryTag("{\"key2\":\"value2\"}");
* System.out.println(session.getQueryTag().get());
* {"key1":"value1","key2":"value2"}
* }</pre>
*
* <p>Example 2:
*
* <pre>{@code
* session.sql("ALTER SESSION SET QUERY_TAG = '{\"key1\":\"value1\"}'").collect();
* session.updateQueryTag("{\"key2\":\"value2\"}");
* System.out.println(session.getQueryTag().get());
* {"key1":"value1","key2":"value2"}
* }</pre>
*
* <p>Example 3:
*
* <pre>{@code
* session.setQueryTag("");
* session.updateQueryTag("{\"key1\":\"value1\"}");
* System.out.println(session.getQueryTag().get());
* {"key1":"value1"}
* }</pre>
*
* @param queryTag A JSON encoded string that provides updates to the current query tag.
* @throws SnowparkClientException If the provided query tag or the query tag of the current
* session are not valid JSON strings; or if it could not serialize the query tag into a JSON
* string.
* @since 1.13.0
*/
public void updateQueryTag(String queryTag) throws SnowparkClientException {
session.updateQueryTag(queryTag);
}

/**
* Creates a new DataFrame via Generator function.
*
Expand Down
90 changes: 90 additions & 0 deletions src/main/scala/com/snowflake/snowpark/Session.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
package com.snowflake.snowpark

import com.fasterxml.jackson.databind.json.JsonMapper
import com.fasterxml.jackson.module.scala.{ClassTagExtensions, DefaultScalaModule}

import java.io.{File, FileInputStream, FileNotFoundException}
import java.net.URI
import java.sql.{Connection, Date, Time, Timestamp}
Expand All @@ -26,6 +29,7 @@ import net.snowflake.client.jdbc.{SnowflakeConnectionV1, SnowflakeDriver, Snowfl
import scala.concurrent.{ExecutionContext, Future}
import scala.collection.JavaConverters._
import scala.reflect.runtime.universe.TypeTag
import scala.util.Try

/**
*
Expand Down Expand Up @@ -61,6 +65,11 @@ import scala.reflect.runtime.universe.TypeTag
* @since 0.1.0
*/
class Session private (private[snowpark] val conn: ServerConnection) extends Logging {
private val jsonMapper = JsonMapper
.builder()
.addModule(DefaultScalaModule)
.build() :: ClassTagExtensions

private val STAGE_PREFIX = "@"
// URI and file name with md5
private val classpathURIs = new ConcurrentHashMap[URI, Option[String]]().asScala
Expand Down Expand Up @@ -321,6 +330,87 @@ class Session private (private[snowpark] val conn: ServerConnection) extends Log
*/
def getQueryTag(): Option[String] = this.conn.getQueryTag()

/**
* Updates the query tag that is a JSON encoded string for the current session.
*
* Keep in mind that assigning a value via [[setQueryTag]] will remove any current query tag
* state.
*
* Example 1:
* {{{
* session.setQueryTag("""{"key1":"value1"}""")
* session.updateQueryTag("""{"key2":"value2"}""")
* print(session.getQueryTag().get)
* {"key1":"value1","key2":"value2"}
* }}}
*
* Example 2:
* {{{
* session.sql("""ALTER SESSION SET QUERY_TAG = '{"key1":"value1"}'""").collect()
* session.updateQueryTag("""{"key2":"value2"}""")
* print(session.getQueryTag().get)
* {"key1":"value1","key2":"value2"}
* }}}
*
* Example 3:
* {{{
* session.setQueryTag("")
* session.updateQueryTag("""{"key1":"value1"}""")
* print(session.getQueryTag().get)
* {"key1":"value1"}
* }}}
*
* @param queryTag A JSON encoded string that provides updates to the current query tag.
* @throws SnowparkClientException If the provided query tag or the query tag of the current
* session are not valid JSON strings; or if it could not
* serialize the query tag into a JSON string.
* @since 1.13.0
*/
def updateQueryTag(queryTag: String): Unit = synchronized {
val newQueryTagMap = parseJsonString(queryTag)
if (newQueryTagMap.isEmpty) {
throw ErrorMessage.MISC_INVALID_INPUT_QUERY_TAG()
}

var currentQueryTag = this.conn.getParameterValue("query_tag")
currentQueryTag = if (currentQueryTag.isEmpty) "{}" else currentQueryTag

val currentQueryTagMap = parseJsonString(currentQueryTag)
if (currentQueryTagMap.isEmpty) {
throw ErrorMessage.MISC_INVALID_CURRENT_QUERY_TAG(currentQueryTag)
}

val updatedQueryTagMap = currentQueryTagMap.get ++ newQueryTagMap.get
val updatedQueryTagStr = toJsonString(updatedQueryTagMap)
if (updatedQueryTagStr.isEmpty) {
throw ErrorMessage.MISC_FAILED_TO_SERIALIZE_QUERY_TAG()
}

setQueryTag(updatedQueryTagStr.get)
}

/**
* Attempts to parse a JSON-encoded string into a [[scala.collection.immutable.Map]].
*
* @param jsonString The JSON-encoded string to parse.
* @return An `Option` containing the `Map` if the parsing of the JSON string was
* successful, or `None` otherwise.
*/
private def parseJsonString(jsonString: String): Option[Map[String, Any]] = {
Try(jsonMapper.readValue[Map[String, Any]](jsonString)).toOption
}

/**
* Attempts to convert a [[scala.collection.immutable.Map]] into a JSON-encoded string.
*
* @param map The `Map` to convert.
* @return An `Option` containing the JSON-encoded string if the conversion was successful,
* or `None` otherwise.
*/
private def toJsonString(map: Map[String, Any]): Option[String] = {
Try(jsonMapper.writeValueAsString(map)).toOption
}

/*
* Checks that the latest version of all jar dependencies is
* uploaded to a stage and returns the staged URLs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,10 @@ private[snowpark] object ErrorMessage {
"""Invalid input argument type, the input argument type of Explode function should be either Map or Array types.
|The input argument type: %s
|""".stripMargin,
"0425" -> "Unsupported Geometry output format: %s. Please set session parameter GEOMETRY_OUTPUT_FORMAT to GeoJSON.")
"0425" -> "Unsupported Geometry output format: %s. Please set session parameter GEOMETRY_OUTPUT_FORMAT to GeoJSON.",
"0426" -> "The given query tag must be a valid JSON string. Ensure it's correctly formatted as JSON.",
"0427" -> "The query tag of the current session must be a valid JSON string. Current query tag: %s",
"0428" -> "Failed to serialize the query tag into a JSON string.")
// scalastyle:on

/*
Expand Down Expand Up @@ -409,6 +412,15 @@ private[snowpark] object ErrorMessage {
def MISC_UNSUPPORTED_GEOMETRY_FORMAT(typeName: String): SnowparkClientException =
createException("0425", typeName)

def MISC_INVALID_INPUT_QUERY_TAG(): SnowparkClientException =
createException("0426")

def MISC_INVALID_CURRENT_QUERY_TAG(currentQueryTag: String): SnowparkClientException =
createException("0427", currentQueryTag)

def MISC_FAILED_TO_SERIALIZE_QUERY_TAG(): SnowparkClientException =
createException("0428")

/**
* Create Snowpark client Exception.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package com.snowflake.snowpark.internal
import com.fasterxml.jackson.annotation.JsonView
import com.fasterxml.jackson.core.TreeNode
import com.fasterxml.jackson.databind.JsonNode
import com.fasterxml.jackson.module.scala.DefaultScalaModule

import java.io.File
import java.net.{URI, URLClassLoader}
Expand All @@ -24,6 +25,8 @@ object UDFClassPath extends Logging {
val jacksonDatabindClass: Class[JsonNode] = classOf[com.fasterxml.jackson.databind.JsonNode]
val jacksonCoreClass: Class[TreeNode] = classOf[com.fasterxml.jackson.core.TreeNode]
val jacksonAnnotationClass: Class[JsonView] = classOf[com.fasterxml.jackson.annotation.JsonView]
val jacksonModuleScalaClass: Class[DefaultScalaModule] =
classOf[com.fasterxml.jackson.module.scala.DefaultScalaModule]
val jacksonJarSeq = Seq(
RequiredLibrary(
getPathForClass(jacksonDatabindClass),
Expand All @@ -33,7 +36,11 @@ object UDFClassPath extends Logging {
RequiredLibrary(
getPathForClass(jacksonAnnotationClass),
"jackson-annotation",
jacksonAnnotationClass))
jacksonAnnotationClass),
RequiredLibrary(
getPathForClass(jacksonModuleScalaClass),
"jackson-module-scala",
jacksonModuleScalaClass))

/*
* Libraries required to compile java code generated by snowpark for user's lambda.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// to make sure all API can be accessed from public
package com.snowflake.snowpark_test;

import static org.junit.Assert.assertThrows;

import com.snowflake.snowpark.SnowparkClientException;
import com.snowflake.snowpark.TestUtils;
import com.snowflake.snowpark_java.*;
Expand Down Expand Up @@ -56,6 +58,84 @@ public void tags() {
assert !getSession().getQueryTag().isPresent();
}

@Test
public void updateQueryTagAddNewKeyValuePairs() {
String queryTag1 = "{\"key1\":\"value1\"}";
getSession().setQueryTag(queryTag1);

String queryTag2 = "{\"key2\":\"value2\",\"key3\":{\"key4\":0},\"key5\":{\"key6\":\"value6\"}}";
getSession().updateQueryTag(queryTag2);

String expected =
"{\"key1\":\"value1\",\"key2\":\"value2\",\"key3\":{\"key4\":0},\"key5\":{\"key6\":\"value6\"}}";
assert getSession().getQueryTag().isPresent();
assert getSession().getQueryTag().get().equals(expected);
}

@Test
public void updateQueryTagUpdateKeyValuePairs() {
String queryTag1 = "{\"key1\":\"value1\",\"key2\":\"value2\",\"key3\":\"value3\"}";
getSession().setQueryTag(queryTag1);

String queryTag2 = "{\"key2\":\"newValue2\"}";
getSession().updateQueryTag(queryTag2);

String expected = "{\"key1\":\"value1\",\"key2\":\"newValue2\",\"key3\":\"value3\"}";
assert getSession().getQueryTag().isPresent();
assert getSession().getQueryTag().get().equals(expected);
}

@Test
public void updateQueryTagEmptySessionQueryTag() {
getSession().setQueryTag("");

String queryTag = "{\"key1\":\"value1\"}";
getSession().updateQueryTag(queryTag);

assert getSession().getQueryTag().isPresent();
assert getSession().getQueryTag().get().equals(queryTag);
}

@Test
public void updateQueryTagInvalidInputQueryTag() {
String queryTag = "tag1";

SnowparkClientException exception =
assertThrows(SnowparkClientException.class, () -> getSession().updateQueryTag(queryTag));
assert exception
.getMessage()
.equals(
"Error Code: 0426, Error message: "
+ "The given query tag must be a valid JSON string. Ensure it's correctly formatted as JSON.");
}

@Test
public void updateQueryTagInvalidSessionQueryTag() {
String queryTag1 = "tag1";
getSession().setQueryTag(queryTag1);

String queryTag2 = "{\"key1\":\"value1\"}";
SnowparkClientException exception =
assertThrows(SnowparkClientException.class, () -> getSession().updateQueryTag(queryTag2));
assert exception
.getMessage()
.equals(
"Error Code: 0427, Error message: "
+ "The query tag of the current session must be a valid JSON string. Current query tag: tag1");
}

@Test
public void updateQueryTagFromAlterSession() {
getSession().sql("ALTER SESSION SET QUERY_TAG = '{\"key1\":\"value1\"}'").collect();

String queryTag2 = "{\"key2\":\"value2\"}";
getSession().updateQueryTag(queryTag2);

String expected = "{\"key1\":\"value1\",\"key2\":\"value2\"}";
assert getSession().getQueryTag().isPresent();
assert getSession().getQueryTag().get().equals(expected);
}

@Test
public void dbAndSchema() {
assert getSession()
Expand Down
27 changes: 27 additions & 0 deletions src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -859,4 +859,31 @@ class ErrorMessageSuite extends FunSuite {
"Unsupported Geometry output format: KWT." +
" Please set session parameter GEOMETRY_OUTPUT_FORMAT to GeoJSON."))
}

test("MISC_INVALID_INPUT_QUERY_TAG") {
val ex = ErrorMessage.MISC_INVALID_INPUT_QUERY_TAG()
assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0426")))
assert(
ex.message.startsWith(
"Error Code: 0426, Error message: " +
"The given query tag must be a valid JSON string. " +
"Ensure it's correctly formatted as JSON."))
}

test("MISC_INVALID_CURRENT_QUERY_TAG") {
val ex = ErrorMessage.MISC_INVALID_CURRENT_QUERY_TAG("myTag")
assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0427")))
assert(
ex.message.startsWith(
"Error Code: 0427, Error message: The query tag of the current session " +
"must be a valid JSON string. Current query tag: myTag"))
}

test("MISC_FAILED_TO_SERIALIZE_QUERY_TAG") {
val ex = ErrorMessage.MISC_FAILED_TO_SERIALIZE_QUERY_TAG()
assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0428")))
assert(
ex.message.startsWith(
"Error Code: 0428, Error message: Failed to serialize the query tag into a JSON string."))
}
}
Loading