diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index d463af6883280..a8a88d63b1a63 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql import java.io.Closeable +import java.net.URI import java.util.concurrent.TimeUnit._ import java.util.concurrent.atomic.AtomicLong @@ -417,6 +418,37 @@ class SparkSession private[sql] ( execute(command) } + /** + * Add a single artifact to the client session. + * + * Currently only local files with extensions .jar and .class are supported. + * + * @since 3.4.0 + */ + @Experimental + def addArtifact(path: String): Unit = client.addArtifact(path) + + /** + * Add a single artifact to the client session. + * + * Currently only local files with extensions .jar and .class are supported. + * + * @since 3.4.0 + */ + @Experimental + def addArtifact(uri: URI): Unit = client.addArtifact(uri) + + /** + * Add one or more artifacts to the session. + * + * Currently only local files with extensions .jar and .class are supported. + * + * @since 3.4.0 + */ + @Experimental + @scala.annotation.varargs + def addArtifacts(uri: URI*): Unit = client.addArtifacts(uri) + /** * This resets the plan id generator so we can produce plans that are comparable. * diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala new file mode 100644 index 0000000000000..ead500a53e639 --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala @@ -0,0 +1,305 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.connect.client + +import java.io.InputStream +import java.net.URI +import java.nio.file.{Files, Path, Paths} +import java.util.zip.{CheckedInputStream, CRC32} + +import scala.collection.mutable +import scala.concurrent.Promise +import scala.concurrent.duration.Duration +import scala.util.control.NonFatal + +import Artifact._ +import com.google.protobuf.ByteString +import io.grpc.ManagedChannel +import io.grpc.stub.StreamObserver + +import org.apache.spark.connect.proto +import org.apache.spark.connect.proto.AddArtifactsResponse +import org.apache.spark.connect.proto.AddArtifactsResponse.ArtifactSummary +import org.apache.spark.util.{ThreadUtils, Utils} + +/** + * The Artifact Manager is responsible for handling and transferring artifacts from the local + * client to the server (local or remote). + * @param userContext the user context attached to every request sent by this client session + * @param channel the gRPC channel used to communicate with the Spark Connect server + */ +class ArtifactManager(userContext: proto.UserContext, channel: ManagedChannel) { + // Using the midpoint recommendation of 32KiB for chunk size as specified in + // https://github.com/grpc/grpc.github.io/issues/371. + private val CHUNK_SIZE: Int = 32 * 1024 + + private[this] val stub = proto.SparkConnectServiceGrpc.newStub(channel) + + /** + * Add a single artifact to the session. + * + * Currently only local files with extensions .jar and .class are supported. + */ + def addArtifact(path: String): Unit = { + addArtifact(Utils.resolveURI(path)) + } + + private def parseArtifacts(uri: URI): Seq[Artifact] = { + // Currently only local files with extensions .jar and .class are supported. + uri.getScheme match { + case "file" => + val path = Paths.get(uri) + val artifact = path.getFileName.toString match { + case jar if jar.endsWith(".jar") => + newJarArtifact(path.getFileName, new LocalFile(path)) + case cf if cf.endsWith(".class") => + newClassArtifact(path.getFileName, new LocalFile(path)) + case other => + throw new UnsupportedOperationException(s"Unsupported file format: $other") + } + Seq[Artifact](artifact) + + case other => + throw new UnsupportedOperationException(s"Unsupported scheme: $other") + } + } + + /** + * Add a single artifact to the session. + * + * Currently only local files with extensions .jar and .class are supported. + */ + def addArtifact(uri: URI): Unit = addArtifacts(parseArtifacts(uri)) + + /** + * Add multiple artifacts to the session. + * + * Currently only local files with extensions .jar and .class are supported. + */ + def addArtifacts(uris: Seq[URI]): Unit = addArtifacts(uris.flatMap(parseArtifacts)) + + /** + * Add a number of artifacts to the session.
+ */ + private def addArtifacts(artifacts: Iterable[Artifact]): Unit = { + val promise = Promise[Seq[ArtifactSummary]] + val responseHandler = new StreamObserver[proto.AddArtifactsResponse] { + private val summaries = mutable.Buffer.empty[ArtifactSummary] + override def onNext(v: AddArtifactsResponse): Unit = { + v.getArtifactsList.forEach { summary => + summaries += summary + } + } + override def onError(throwable: Throwable): Unit = { + promise.failure(throwable) + } + override def onCompleted(): Unit = { + promise.success(summaries.toSeq) + } + } + val stream = stub.addArtifacts(responseHandler) + val currentBatch = mutable.Buffer.empty[Artifact] + var currentBatchSize = 0L + + def addToBatch(dep: Artifact, size: Long): Unit = { + currentBatch += dep + currentBatchSize += size + } + + def writeBatch(): Unit = { + addBatchedArtifacts(currentBatch.toSeq, stream) + currentBatch.clear() + currentBatchSize = 0 + } + + artifacts.iterator.foreach { artifact => + val data = artifact.storage + val size = data.size + if (size > CHUNK_SIZE) { + // Payload can either be a batch OR a single chunked artifact. Write batch if non-empty + // before chunking current artifact. + if (currentBatch.nonEmpty) { + writeBatch() + } + addChunkedArtifact(artifact, stream) + } else { + if (currentBatchSize + size > CHUNK_SIZE) { + writeBatch() + } + addToBatch(artifact, size) + } + } + if (currentBatch.nonEmpty) { + writeBatch() + } + stream.onCompleted() + ThreadUtils.awaitResult(promise.future, Duration.Inf) + // TODO(SPARK-42658): Handle responses containing CRC failures. + } + + /** + * Add a batch of artifacts to the stream. All the artifacts in this call are packaged into a + * single [[proto.AddArtifactsRequest]]. + */ + private def addBatchedArtifacts( + artifacts: Seq[Artifact], + stream: StreamObserver[proto.AddArtifactsRequest]): Unit = { + val builder = proto.AddArtifactsRequest + .newBuilder() + .setUserContext(userContext) + artifacts.foreach { artifact => + val in = new CheckedInputStream(artifact.storage.asInstanceOf[LocalData].stream, new CRC32) + try { + val data = proto.AddArtifactsRequest.ArtifactChunk + .newBuilder() + .setData(ByteString.readFrom(in)) + .setCrc(in.getChecksum.getValue) + + builder.getBatchBuilder + .addArtifactsBuilder() + .setName(artifact.path.toString) + .setData(data) + .build() + } catch { + case NonFatal(e) => + stream.onError(e) + throw e + } finally { + in.close() + } + } + stream.onNext(builder.build()) + } + + /** + * Read data from an [[InputStream]] in pieces of `CHUNK_SIZE` bytes and convert to + * protobuf-compatible [[ByteString]]. + * @param in the input stream to read from + * @return the next chunk of at most `CHUNK_SIZE` bytes, or an empty [[ByteString]] if the stream is exhausted + */ + private def readNextChunk(in: InputStream): ByteString = { + val buf = new Array[Byte](CHUNK_SIZE) + var bytesRead = 0 + var count = 0 + while (count != -1 && bytesRead < CHUNK_SIZE) { + count = in.read(buf, bytesRead, CHUNK_SIZE - bytesRead) + if (count != -1) { + bytesRead += count + } + } + if (bytesRead == 0) ByteString.empty() + else ByteString.copyFrom(buf, 0, bytesRead) + } + + /** + * Add an artifact in chunks to the stream. The artifact's data is spread out over multiple + * [[proto.AddArtifactsRequest requests]].
+ */ + private def addChunkedArtifact( + artifact: Artifact, + stream: StreamObserver[proto.AddArtifactsRequest]): Unit = { + val builder = proto.AddArtifactsRequest + .newBuilder() + .setUserContext(userContext) + + val in = new CheckedInputStream(artifact.storage.asInstanceOf[LocalData].stream, new CRC32) + try { + // First RPC contains the `BeginChunkedArtifact` payload (`begin_chunk`). + // Subsequent RPCs contain the `ArtifactChunk` payload (`chunk`). + val artifactChunkBuilder = proto.AddArtifactsRequest.ArtifactChunk.newBuilder() + var dataChunk = readNextChunk(in) + // Integer division that rounds up to the nearest whole number. + def getNumChunks(size: Long): Long = (size + (CHUNK_SIZE - 1)) / CHUNK_SIZE + + builder.getBeginChunkBuilder + .setName(artifact.path.toString) + .setTotalBytes(artifact.size) + .setNumChunks(getNumChunks(artifact.size)) + .setInitialChunk( + artifactChunkBuilder + .setData(dataChunk) + .setCrc(in.getChecksum.getValue)) + stream.onNext(builder.build()) + in.getChecksum.reset() + builder.clearBeginChunk() + + dataChunk = readNextChunk(in) + // Consume stream in chunks until there is no data left to read. + while (!dataChunk.isEmpty) { + artifactChunkBuilder.setData(dataChunk).setCrc(in.getChecksum.getValue) + builder.setChunk(artifactChunkBuilder.build()) + stream.onNext(builder.build()) + in.getChecksum.reset() + builder.clearChunk() + dataChunk = readNextChunk(in) + } + } catch { + case NonFatal(e) => + stream.onError(e) + throw e + } finally { + in.close() + } + } +} + +class Artifact private (val path: Path, val storage: LocalData) { + require(!path.isAbsolute, s"Bad path: $path") + + lazy val size: Long = storage match { + case localData: LocalData => localData.size + } +} + +object Artifact { + val CLASS_PREFIX: Path = Paths.get("classes") + val JAR_PREFIX: Path = Paths.get("jars") + + def newJarArtifact(fileName: Path, storage: LocalData): Artifact = { + newArtifact(JAR_PREFIX, ".jar", fileName, storage) + } + + def newClassArtifact(fileName: Path, storage: LocalData): Artifact = { + newArtifact(CLASS_PREFIX, ".class", fileName, storage) + } + + private def newArtifact( + prefix: Path, + requiredSuffix: String, + fileName: Path, + storage: LocalData): Artifact = { + require(!fileName.isAbsolute) + require(fileName.toString.endsWith(requiredSuffix)) + new Artifact(prefix.resolve(fileName), storage) + } + + /** + * Payload stored on this machine. + */ + sealed trait LocalData { + def stream: InputStream + def size: Long + } + + /** + * Payload stored in a local file. + */ + class LocalFile(val path: Path) extends LocalData { + override def size: Long = Files.size(path) + override def stream: InputStream = Files.newInputStream(path) + } +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala index cdc0b381a4474..599aab441deb5 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala @@ -36,6 +36,8 @@ private[sql] class SparkConnectClient( private[this] val stub = proto.SparkConnectServiceGrpc.newBlockingStub(channel) + private[client] val artifactManager: ArtifactManager = new ArtifactManager(userContext, channel) + /** * Placeholder method.
* @return @@ -147,6 +149,27 @@ private[sql] class SparkConnectClient( analyze(request) } + /** + * Add a single artifact to the client session. + * + * Currently only local files with extensions .jar and .class are supported. + */ + def addArtifact(path: String): Unit = artifactManager.addArtifact(path) + + /** + * Add a single artifact to the client session. + * + * Currently only local files with extensions .jar and .class are supported. + */ + def addArtifact(uri: URI): Unit = artifactManager.addArtifact(uri) + + /** + * Add multiple artifacts to the session. + * + * Currently only local files with extensions .jar and .class are supported. + */ + def addArtifacts(uri: Seq[URI]): Unit = artifactManager.addArtifacts(uri) + /** * Shutdown the client's connection to the server. */ diff --git a/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/README.md b/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/README.md new file mode 100644 index 0000000000000..df9af41064444 --- /dev/null +++ b/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/README.md @@ -0,0 +1,5 @@ +The CRCs for a specific file are stored in a text file with the same name (excluding the original extension). + +The CRCs are calculated for data chunks of `32768 bytes` (individual CRCs) and are newline delimited. + +The CRCs were calculated using https://simplycalc.com/crc32-file.php \ No newline at end of file diff --git a/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/junitLargeJar.txt b/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/junitLargeJar.txt new file mode 100644 index 0000000000000..3e89631dea57c --- /dev/null +++ b/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/junitLargeJar.txt @@ -0,0 +1,12 @@ +902183889 +2415704507 +1084811487 +1951510 +1158852476 +2003120166 +3026803842 +3850244775 +3409267044 +652109216 +104029242 +3019434266 \ No newline at end of file diff --git a/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/smallClassFile.txt b/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/smallClassFile.txt new file mode 100644 index 0000000000000..531f98ce9a225 --- /dev/null +++ b/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/smallClassFile.txt @@ -0,0 +1 @@ +1935693963 \ No newline at end of file diff --git a/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/smallClassFileDup.txt b/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/smallClassFileDup.txt new file mode 100644 index 0000000000000..531f98ce9a225 --- /dev/null +++ b/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/smallClassFileDup.txt @@ -0,0 +1 @@ +1935693963 \ No newline at end of file diff --git a/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/smallJar.txt b/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/smallJar.txt new file mode 100644 index 0000000000000..df32adcce7ab5 --- /dev/null +++ b/connector/connect/client/jvm/src/test/resources/artifact-tests/crc/smallJar.txt @@ -0,0 +1 @@ +1631702900 \ No newline at end of file diff --git a/connector/connect/client/jvm/src/test/resources/artifact-tests/junitLargeJar.jar b/connector/connect/client/jvm/src/test/resources/artifact-tests/junitLargeJar.jar new file mode 100755 index 0000000000000..6da55d8b8520d Binary files /dev/null and b/connector/connect/client/jvm/src/test/resources/artifact-tests/junitLargeJar.jar differ diff --git 
a/connector/connect/client/jvm/src/test/resources/artifact-tests/smallClassFile.class b/connector/connect/client/jvm/src/test/resources/artifact-tests/smallClassFile.class new file mode 100755 index 0000000000000..e796030e471b0 Binary files /dev/null and b/connector/connect/client/jvm/src/test/resources/artifact-tests/smallClassFile.class differ diff --git a/connector/connect/client/jvm/src/test/resources/artifact-tests/smallClassFileDup.class b/connector/connect/client/jvm/src/test/resources/artifact-tests/smallClassFileDup.class new file mode 100755 index 0000000000000..e796030e471b0 Binary files /dev/null and b/connector/connect/client/jvm/src/test/resources/artifact-tests/smallClassFileDup.class differ diff --git a/connector/connect/client/jvm/src/test/resources/artifact-tests/smallJar.jar b/connector/connect/client/jvm/src/test/resources/artifact-tests/smallJar.jar new file mode 100755 index 0000000000000..3c4930e8e9549 Binary files /dev/null and b/connector/connect/client/jvm/src/test/resources/artifact-tests/smallJar.jar differ diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala old mode 100644 new mode 100755 index 67dc92a747233..6e9583ae725eb --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala @@ -68,27 +68,7 @@ class PlanGenerationTestSuite // Borrowed from SparkFunSuite private val regenerateGoldenFiles: Boolean = System.getenv("SPARK_GENERATE_GOLDEN_FILES") == "1" - // Borrowed from SparkFunSuite - private def getWorkspaceFilePath(first: String, more: String*): Path = { - if (!(sys.props.contains("spark.test.home") || sys.env.contains("SPARK_HOME"))) { - fail("spark.test.home or SPARK_HOME is not set.") - } - val sparkHome = sys.props.getOrElse("spark.test.home", sys.env("SPARK_HOME")) - java.nio.file.Paths.get(sparkHome, first +: more: _*) - } - - protected val baseResourcePath: Path = { - getWorkspaceFilePath( - "connector", - "connect", - "common", - "src", - "test", - "resources", - "query-tests").toAbsolutePath - } - - protected val queryFilePath: Path = baseResourcePath.resolve("queries") + protected val queryFilePath: Path = commonResourcePath.resolve("queries") // A relative path to /connector/connect/server, used by `ProtoToParsedPlanTestSuite` to run // with the datasource. diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala new file mode 100644 index 0000000000000..adb2b3f190811 --- /dev/null +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connect.client + +import java.io.InputStream +import java.nio.file.{Files, Path, Paths} +import java.util.concurrent.TimeUnit + +import collection.JavaConverters._ +import com.google.protobuf.ByteString +import io.grpc.{ManagedChannel, Server} +import io.grpc.inprocess.{InProcessChannelBuilder, InProcessServerBuilder} +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.connect.proto +import org.apache.spark.connect.proto.AddArtifactsRequest +import org.apache.spark.sql.connect.client.util.ConnectFunSuite + +class ArtifactSuite extends ConnectFunSuite with BeforeAndAfterEach { + + private var client: SparkConnectClient = _ + private var service: DummySparkConnectService = _ + private var server: Server = _ + private var artifactManager: ArtifactManager = _ + private var channel: ManagedChannel = _ + + private def startDummyServer(): Unit = { + service = new DummySparkConnectService() + server = InProcessServerBuilder + .forName(getClass.getName) + .addService(service) + .build() + server.start() + } + + private def createArtifactManager(): Unit = { + channel = InProcessChannelBuilder.forName(getClass.getName).directExecutor().build() + artifactManager = new ArtifactManager(proto.UserContext.newBuilder().build(), channel) + } + + override def beforeEach(): Unit = { + super.beforeEach() + startDummyServer() + createArtifactManager() + client = null + } + + override def afterEach(): Unit = { + if (server != null) { + server.shutdownNow() + assert(server.awaitTermination(5, TimeUnit.SECONDS), "server failed to shutdown") + } + + if (channel != null) { + channel.shutdownNow() + } + + if (client != null) { + client.shutdown() + } + } + + private val CHUNK_SIZE: Int = 32 * 1024 + protected def artifactFilePath: Path = baseResourcePath.resolve("artifact-tests") + protected def artifactCrcPath: Path = artifactFilePath.resolve("crc") + + private def getCrcValues(filePath: Path): Seq[Long] = { + val fileName = filePath.getFileName.toString + val crcFileName = fileName.split('.').head + ".txt" + Files + .readAllLines(artifactCrcPath.resolve(crcFileName)) + .asScala + .map(_.toLong) + } + + /** + * Check if the data sent to the server (stored in `artifactChunk`) is equivalent to the local + * data at `localPath`. 
+ * @param artifactChunk the artifact chunk received by the server + * @param localPath the path of the local file the artifact was created from + */ + private def assertFileDataEquality( + artifactChunk: AddArtifactsRequest.ArtifactChunk, + localPath: Path): Unit = { + val localData = ByteString.readFrom(Files.newInputStream(localPath)) + val expectedCrc = getCrcValues(localPath).head + assert(artifactChunk.getData == localData) + assert(artifactChunk.getCrc == expectedCrc) + } + + private def singleChunkArtifactTest(path: String): Unit = { + test(s"Single Chunk Artifact - $path") { + val artifactPath = artifactFilePath.resolve(path) + artifactManager.addArtifact(artifactPath.toString) + + val receivedRequests = service.getAndClearLatestAddArtifactRequests() + // Single `AddArtifactRequest` + assert(receivedRequests.size == 1) + + val request = receivedRequests.head + assert(request.hasBatch) + + val batch = request.getBatch + // Single artifact in batch + assert(batch.getArtifactsList.size() == 1) + + val singleChunkArtifact = batch.getArtifacts(0) + val namePrefix = artifactPath.getFileName.toString match { + case jar if jar.endsWith(".jar") => "jars" + case cf if cf.endsWith(".class") => "classes" + } + assert(singleChunkArtifact.getName.equals(namePrefix + "/" + path)) + assertFileDataEquality(singleChunkArtifact.getData, artifactPath) + } + } + + singleChunkArtifactTest("smallClassFile.class") + + singleChunkArtifactTest("smallJar.jar") + + private def readNextChunk(in: InputStream): ByteString = { + val buf = new Array[Byte](CHUNK_SIZE) + var bytesRead = 0 + var count = 0 + while (count != -1 && bytesRead < CHUNK_SIZE) { + count = in.read(buf, bytesRead, CHUNK_SIZE - bytesRead) + if (count != -1) { + bytesRead += count + } + } + if (bytesRead == 0) ByteString.empty() + else ByteString.copyFrom(buf, 0, bytesRead) + } + + /** + * Reads the local file at `filePath` in chunks of `CHUNK_SIZE` bytes and verifies that each + * chunk received by the server matches the corresponding local data and CRC. + * @param filePath the local file backing the transferred artifact + * @param chunks the artifact chunks received by the server + */ + private def checkChunksDataAndCrc( + filePath: Path, + chunks: Seq[AddArtifactsRequest.ArtifactChunk]): Unit = { + val in = Files.newInputStream(filePath) + val crcs = getCrcValues(filePath) + chunks.zip(crcs).foreach { case (chunk, expectedCrc) => + val expectedData = readNextChunk(in) + assert(chunk.getData == expectedData && chunk.getCrc == expectedCrc) + } + } + + test("Chunked Artifact - junitLargeJar.jar") { + val artifactPath = artifactFilePath.resolve("junitLargeJar.jar") + artifactManager.addArtifact(artifactPath.toString) + // Expected chunks = roundUp(file_size / chunk_size) = 12 + // File size of `junitLargeJar.jar` is 384581 bytes.
+ val expectedChunks = (384581 + (CHUNK_SIZE - 1)) / CHUNK_SIZE + val receivedRequests = service.getAndClearLatestAddArtifactRequests() + assert(384581 == Files.size(artifactPath)) + assert(receivedRequests.size == expectedChunks) + assert(receivedRequests.head.hasBeginChunk) + val beginChunkRequest = receivedRequests.head.getBeginChunk + assert(beginChunkRequest.getName == "jars/junitLargeJar.jar") + assert(beginChunkRequest.getTotalBytes == 384581) + assert(beginChunkRequest.getNumChunks == expectedChunks) + val dataChunks = Seq(beginChunkRequest.getInitialChunk) ++ + receivedRequests.drop(1).map(_.getChunk) + checkChunksDataAndCrc(artifactPath, dataChunks) + } + + test("Batched SingleChunkArtifacts") { + val file1 = artifactFilePath.resolve("smallClassFile.class").toUri + val file2 = artifactFilePath.resolve("smallJar.jar").toUri + artifactManager.addArtifacts(Seq(file1, file2)) + val receivedRequests = service.getAndClearLatestAddArtifactRequests() + // Single request containing 2 artifacts. + assert(receivedRequests.size == 1) + + val request = receivedRequests.head + assert(request.hasBatch) + + val batch = request.getBatch + assert(batch.getArtifactsList.size() == 2) + + val artifacts = batch.getArtifactsList + assert(artifacts.get(0).getName == "classes/smallClassFile.class") + assert(artifacts.get(1).getName == "jars/smallJar.jar") + + assertFileDataEquality(artifacts.get(0).getData, Paths.get(file1)) + assertFileDataEquality(artifacts.get(1).getData, Paths.get(file2)) + } + + test("Mix of SingleChunkArtifact and chunked artifact") { + val file1 = artifactFilePath.resolve("smallClassFile.class").toUri + val file2 = artifactFilePath.resolve("junitLargeJar.jar").toUri + val file3 = artifactFilePath.resolve("smallClassFileDup.class").toUri + val file4 = artifactFilePath.resolve("smallJar.jar").toUri + artifactManager.addArtifacts(Seq(file1, file2, file3, file4)) + val receivedRequests = service.getAndClearLatestAddArtifactRequests() + // There are a total of 14 requests. + // The 1st request contains a single artifact - smallClassFile.class (There are no + // other artifacts batched with it since the next one is large multi-chunk artifact) + // Requests 2-13 (1-indexed) belong to the transfer of junitLargeJar.jar. This includes + // the first "beginning chunk" and the subsequent data chunks. + // The last request (14) contains both smallClassFileDup.class and smallJar.jar batched + // together. + assert(receivedRequests.size == 1 + 12 + 1) + + val firstReqBatch = receivedRequests.head.getBatch.getArtifactsList + assert(firstReqBatch.size() == 1) + assert(firstReqBatch.get(0).getName == "classes/smallClassFile.class") + assertFileDataEquality(firstReqBatch.get(0).getData, Paths.get(file1)) + + val secondReq = receivedRequests(1) + assert(secondReq.hasBeginChunk) + val beginChunkRequest = secondReq.getBeginChunk + assert(beginChunkRequest.getName == "jars/junitLargeJar.jar") + assert(beginChunkRequest.getTotalBytes == 384581) + assert(beginChunkRequest.getNumChunks == 12) + // Large artifact data chunks are requests number 3 to 13. 
+ val dataChunks = Seq(beginChunkRequest.getInitialChunk) ++ + receivedRequests.drop(2).dropRight(1).map(_.getChunk) + checkChunksDataAndCrc(Paths.get(file2), dataChunks) + + val lastBatch = receivedRequests.last.getBatch + assert(lastBatch.getArtifactsCount == 2) + val remainingArtifacts = lastBatch.getArtifactsList + assert(remainingArtifacts.get(0).getName == "classes/smallClassFileDup.class") + assert(remainingArtifacts.get(1).getName == "jars/smallJar.jar") + + assertFileDataEquality(remainingArtifacts.get(0).getData, Paths.get(file3)) + assertFileDataEquality(remainingArtifacts.get(1).getData, Paths.get(file4)) + } +} diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala old mode 100644 new mode 100755 index 8cead49de0c1c..dcb135892064a --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala @@ -22,9 +22,10 @@ import io.grpc.{Server, StatusRuntimeException} import io.grpc.netty.NettyServerBuilder import io.grpc.stub.StreamObserver import org.scalatest.BeforeAndAfterEach +import scala.collection.mutable import org.apache.spark.connect.proto -import org.apache.spark.connect.proto.{AnalyzePlanRequest, AnalyzePlanResponse, ExecutePlanRequest, ExecutePlanResponse, SparkConnectServiceGrpc} +import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse, AnalyzePlanRequest, AnalyzePlanResponse, ExecutePlanRequest, ExecutePlanResponse, SparkConnectServiceGrpc} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connect.client.util.ConnectFunSuite import org.apache.spark.sql.connect.common.config.ConnectCommon @@ -181,6 +182,8 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach { class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectServiceImplBase { private var inputPlan: proto.Plan = _ + private val inputArtifactRequests: mutable.ListBuffer[AddArtifactsRequest] = + mutable.ListBuffer.empty private[sql] def getAndClearLatestInputPlan(): proto.Plan = { val plan = inputPlan @@ -188,6 +191,12 @@ class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectSer plan } + private[sql] def getAndClearLatestAddArtifactRequests(): Seq[AddArtifactsRequest] = { + val requests = inputArtifactRequests.toSeq + inputArtifactRequests.clear() + requests + } + override def executePlan( request: ExecutePlanRequest, responseObserver: StreamObserver[ExecutePlanResponse]): Unit = { @@ -229,4 +238,16 @@ class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectSer responseObserver.onNext(response) responseObserver.onCompleted() } + + override def addArtifacts(responseObserver: StreamObserver[AddArtifactsResponse]) + : StreamObserver[AddArtifactsRequest] = new StreamObserver[AddArtifactsRequest] { + override def onNext(v: AddArtifactsRequest): Unit = inputArtifactRequests.append(v) + + override def onError(throwable: Throwable): Unit = responseObserver.onError(throwable) + + override def onCompleted(): Unit = { + responseObserver.onNext(proto.AddArtifactsResponse.newBuilder().build()) + responseObserver.onCompleted() + } + } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/ConnectFunSuite.scala 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/ConnectFunSuite.scala old mode 100644 new mode 100755 index 5100fa7d229f1..1ece0838b1bf4 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/ConnectFunSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/ConnectFunSuite.scala @@ -16,9 +16,43 @@ */ package org.apache.spark.sql.connect.client.util +import java.nio.file.Path + import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite /** * The basic testsuite the client tests should extend from. */ -trait ConnectFunSuite extends AnyFunSuite {} // scalastyle:ignore funsuite +trait ConnectFunSuite extends AnyFunSuite { // scalastyle:ignore funsuite + + // Borrowed from SparkFunSuite + protected def getWorkspaceFilePath(first: String, more: String*): Path = { + if (!(sys.props.contains("spark.test.home") || sys.env.contains("SPARK_HOME"))) { + fail("spark.test.home or SPARK_HOME is not set.") + } + val sparkHome = sys.props.getOrElse("spark.test.home", sys.env("SPARK_HOME")) + java.nio.file.Paths.get(sparkHome, first +: more: _*) + } + + protected val baseResourcePath: Path = { + getWorkspaceFilePath( + "connector", + "connect", + "client", + "jvm", + "src", + "test", + "resources").toAbsolutePath + } + + protected val commonResourcePath: Path = { + getWorkspaceFilePath( + "connector", + "connect", + "common", + "src", + "test", + "resources", + "query-tests").toAbsolutePath + } +} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala old mode 100644 new mode 100755 index d6446eae4b781..cd353b6ff6097 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala @@ -39,6 +39,7 @@ import org.json4s.jackson.JsonMethods.{compact, render} import org.apache.spark.{SparkEnv, SparkException, SparkThrowable} import org.apache.spark.api.python.PythonException import org.apache.spark.connect.proto +import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse} import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_BINDING_PORT @@ -179,6 +180,28 @@ class SparkConnectService(debug: Boolean) new SparkConnectConfigHandler(responseObserver).handle(request) } catch handleError("config", observer = responseObserver) } + + /** + * This is the main entry method for all calls to add/transfer artifacts. + * + * @param responseObserver + * @return + */ + override def addArtifacts(responseObserver: StreamObserver[AddArtifactsResponse]) + : StreamObserver[AddArtifactsRequest] = { + // TODO: Handle artifact files + // No-Op StreamObserver + new StreamObserver[AddArtifactsRequest] { + override def onNext(v: AddArtifactsRequest): Unit = {} + + override def onError(throwable: Throwable): Unit = responseObserver.onError(throwable) + + override def onCompleted(): Unit = { + responseObserver.onNext(proto.AddArtifactsResponse.newBuilder().build()) + responseObserver.onCompleted() + } + } + } } /**
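For reviewers, a minimal sketch of how the new client-side API added above might be used. The file paths, the `AddArtifactExample` object, and the way the Spark Connect `SparkSession` is obtained are illustrative placeholders and are not part of this patch; only the `addArtifact`/`addArtifacts` signatures and the 32 KiB batching/chunking behaviour come from the diff itself.

```scala
import java.net.URI

import org.apache.spark.sql.SparkSession

object AddArtifactExample {
  // `spark` is assumed to be an already-created Spark Connect session;
  // establishing the connection is outside the scope of this patch.
  def registerUdfArtifacts(spark: SparkSession): Unit = {
    // Add a single local artifact by path. Only .jar and .class files are
    // accepted; other extensions throw UnsupportedOperationException.
    spark.addArtifact("/tmp/udfs/my-udfs.jar")

    // Add a single artifact by URI. Only the `file` scheme is supported.
    spark.addArtifact(new URI("file:///tmp/classes/MyUdf.class"))

    // Add several artifacts in one call. Files of at most 32 KiB are batched
    // into a single AddArtifactsRequest; larger files are streamed in 32 KiB
    // chunks, each carrying its own CRC.
    spark.addArtifacts(
      new URI("file:///tmp/udfs/my-udfs.jar"),
      new URI("file:///tmp/classes/MyUdf.class"))
  }
}
```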