[SPARK-42653][CONNECT] Artifact transfer from Scala/JVM client to Server
### What changes were proposed in this pull request?

This PR introduces a mechanism to transfer artifacts (currently, local `.jar` and `.class` files) from a Spark Connect JVM/Scala client to the server side of Spark Connect. The mechanism follows the protocol defined in apache#40147 and supports batching (for multiple "small" artifacts) and chunking (for large artifacts).

Note: Server-side artifact handling is not covered in this PR.

### Why are the changes needed?

In the decoupled client-server architecture of Spark Connect, a remote client may reference a local JAR or a newly defined class in a UDF that is not present on the server. To handle these cases of missing "artifacts", we implement a mechanism to transfer artifacts from the client side to the server side, following the protocol defined in apache#40147.

### Does this PR introduce _any_ user-facing change?

Yes. Users can call the `addArtifact` and `addArtifacts` methods on a `SparkSession` instance to transfer local files with `.jar` and `.class` extensions.
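
For illustration, a minimal sketch of the new API surface, assuming an already-constructed Spark Connect `SparkSession` (the file paths below are placeholders):

```scala
import java.net.URI

import org.apache.spark.sql.SparkSession

// `spark` is assumed to be an existing Spark Connect session; the artifact paths are hypothetical.
def uploadUdfDependencies(spark: SparkSession): Unit = {
  // Single artifact, by string path or by URI.
  spark.addArtifact("/tmp/my-udfs.jar")
  spark.addArtifact(new URI("file:///tmp/MyUdf.class"))

  // Several artifacts in one varargs call.
  spark.addArtifacts(
    new URI("file:///tmp/dep1.jar"),
    new URI("file:///tmp/dep2.jar"))
}
```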

### How was this patch tested?

Unit tests - located in `ArtifactSuite`.

Closes apache#40256 from vicennial/SPARK-42653.

Authored-by: vicennial <venkata.gudesa@databricks.com>
Signed-off-by: Herman van Hovell <herman@databricks.com>
(cherry picked from commit 8a0d626)
Signed-off-by: Herman van Hovell <herman@databricks.com>
vicennial authored and dongjoon-hyun committed Mar 3, 2023
1 parent 3d09b40 commit c445b85
Showing 17 changed files with 710 additions and 23 deletions.
@@ -17,6 +17,7 @@
package org.apache.spark.sql

import java.io.Closeable
import java.net.URI
import java.util.concurrent.TimeUnit._
import java.util.concurrent.atomic.AtomicLong

@@ -417,6 +418,37 @@ class SparkSession private[sql] (
execute(command)
}

/**
* Add a single artifact to the client session.
*
* Currently only local files with extensions .jar and .class are supported.
*
* @since 3.4.0
*/
@Experimental
def addArtifact(path: String): Unit = client.addArtifact(path)

/**
* Add a single artifact to the client session.
*
* Currently only local files with extensions .jar and .class are supported.
*
* @since 3.4.0
*/
@Experimental
def addArtifact(uri: URI): Unit = client.addArtifact(uri)

/**
* Add one or more artifacts to the session.
*
* Currently only local files with extensions .jar and .class are supported.
*
* @since 3.4.0
*/
@Experimental
@scala.annotation.varargs
def addArtifacts(uri: URI*): Unit = client.addArtifacts(uri)

/**
* This resets the plan id generator so we can produce plans that are comparable.
*
@@ -0,0 +1,305 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.connect.client

import java.io.InputStream
import java.net.URI
import java.nio.file.{Files, Path, Paths}
import java.util.zip.{CheckedInputStream, CRC32}

import scala.collection.mutable
import scala.concurrent.Promise
import scala.concurrent.duration.Duration
import scala.util.control.NonFatal

import Artifact._
import com.google.protobuf.ByteString
import io.grpc.ManagedChannel
import io.grpc.stub.StreamObserver

import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.AddArtifactsResponse
import org.apache.spark.connect.proto.AddArtifactsResponse.ArtifactSummary
import org.apache.spark.util.{ThreadUtils, Utils}

/**
* The Artifact Manager is responsible for handling and transferring artifacts from the local
* client to the server (local/remote).
* @param userContext the user context attached to every `AddArtifactsRequest` sent to the server
* @param channel the gRPC channel used to communicate with the Spark Connect server
*/
class ArtifactManager(userContext: proto.UserContext, channel: ManagedChannel) {
// Using the midpoint recommendation of 32KiB for chunk size as specified in
// https://github.com/grpc/grpc.github.io/issues/371.
private val CHUNK_SIZE: Int = 32 * 1024

private[this] val stub = proto.SparkConnectServiceGrpc.newStub(channel)

/**
* Add a single artifact to the session.
*
* Currently only local files with extensions .jar and .class are supported.
*/
def addArtifact(path: String): Unit = {
addArtifact(Utils.resolveURI(path))
}

private def parseArtifacts(uri: URI): Seq[Artifact] = {
// Currently only local files with extensions .jar and .class are supported.
uri.getScheme match {
case "file" =>
val path = Paths.get(uri)
val artifact = path.getFileName.toString match {
case jar if jar.endsWith(".jar") =>
newJarArtifact(path.getFileName, new LocalFile(path))
case cf if cf.endsWith(".class") =>
newClassArtifact(path.getFileName, new LocalFile(path))
case other =>
throw new UnsupportedOperationException(s"Unsuppoted file format: $other")
}
Seq[Artifact](artifact)

case other =>
throw new UnsupportedOperationException(s"Unsupported scheme: $other")
}
}

/**
* Add a single artifact to the session.
*
* Currently only local files with extensions .jar and .class are supported.
*/
def addArtifact(uri: URI): Unit = addArtifacts(parseArtifacts(uri))

/**
* Add multiple artifacts to the session.
*
* Currently only local files with extensions .jar and .class are supported.
*/
def addArtifacts(uris: Seq[URI]): Unit = addArtifacts(uris.flatMap(parseArtifacts))

/**
* Add a number of artifacts to the session.
*/
private def addArtifacts(artifacts: Iterable[Artifact]): Unit = {
val promise = Promise[Seq[ArtifactSummary]]
val responseHandler = new StreamObserver[proto.AddArtifactsResponse] {
private val summaries = mutable.Buffer.empty[ArtifactSummary]
override def onNext(v: AddArtifactsResponse): Unit = {
v.getArtifactsList.forEach { summary =>
summaries += summary
}
}
override def onError(throwable: Throwable): Unit = {
promise.failure(throwable)
}
override def onCompleted(): Unit = {
promise.success(summaries.toSeq)
}
}
val stream = stub.addArtifacts(responseHandler)
val currentBatch = mutable.Buffer.empty[Artifact]
var currentBatchSize = 0L

def addToBatch(dep: Artifact, size: Long): Unit = {
currentBatch += dep
currentBatchSize += size
}

def writeBatch(): Unit = {
addBatchedArtifacts(currentBatch.toSeq, stream)
currentBatch.clear()
currentBatchSize = 0
}

artifacts.iterator.foreach { artifact =>
val data = artifact.storage
val size = data.size
if (size > CHUNK_SIZE) {
// Payload can either be a batch OR a single chunked artifact. Write batch if non-empty
// before chunking current artifact.
if (currentBatch.nonEmpty) {
writeBatch()
}
addChunkedArtifact(artifact, stream)
} else {
if (currentBatchSize + size > CHUNK_SIZE) {
writeBatch()
}
addToBatch(artifact, size)
}
}
if (currentBatch.nonEmpty) {
writeBatch()
}
stream.onCompleted()
ThreadUtils.awaitResult(promise.future, Duration.Inf)
// TODO(SPARK-42658): Handle responses containing CRC failures.
}

/**
* Add a batch of artifacts to the stream. All the artifacts in this call are packaged into a
* single [[proto.AddArtifactsRequest]].
*/
private def addBatchedArtifacts(
artifacts: Seq[Artifact],
stream: StreamObserver[proto.AddArtifactsRequest]): Unit = {
val builder = proto.AddArtifactsRequest
.newBuilder()
.setUserContext(userContext)
artifacts.foreach { artifact =>
val in = new CheckedInputStream(artifact.storage.asInstanceOf[LocalData].stream, new CRC32)
try {
val data = proto.AddArtifactsRequest.ArtifactChunk
.newBuilder()
.setData(ByteString.readFrom(in))
.setCrc(in.getChecksum.getValue)

builder.getBatchBuilder
.addArtifactsBuilder()
.setName(artifact.path.toString)
.setData(data)
.build()
} catch {
case NonFatal(e) =>
stream.onError(e)
throw e
} finally {
in.close()
}
}
stream.onNext(builder.build())
}

/**
* Read data from an [[InputStream]] in pieces of `CHUNK_SIZE` bytes and convert to a
* protobuf-compatible [[ByteString]].
* @param in the input stream to read from
* @return the next chunk of data as a [[ByteString]], or an empty [[ByteString]] if the
* stream is exhausted
*/
private def readNextChunk(in: InputStream): ByteString = {
val buf = new Array[Byte](CHUNK_SIZE)
var bytesRead = 0
var count = 0
while (count != -1 && bytesRead < CHUNK_SIZE) {
count = in.read(buf, bytesRead, CHUNK_SIZE - bytesRead)
if (count != -1) {
bytesRead += count
}
}
if (bytesRead == 0) ByteString.empty()
else ByteString.copyFrom(buf, 0, bytesRead)
}

/**
* Add an artifact in chunks to the stream. The artifact's data is spread out over multiple
* [[proto.AddArtifactsRequest requests]].
*/
private def addChunkedArtifact(
artifact: Artifact,
stream: StreamObserver[proto.AddArtifactsRequest]): Unit = {
val builder = proto.AddArtifactsRequest
.newBuilder()
.setUserContext(userContext)

val in = new CheckedInputStream(artifact.storage.asInstanceOf[LocalData].stream, new CRC32)
try {
// First RPC contains the `BeginChunkedArtifact` payload (`begin_chunk`).
// Subsequent RPCs contain the `ArtifactChunk` payload (`chunk`).
val artifactChunkBuilder = proto.AddArtifactsRequest.ArtifactChunk.newBuilder()
var dataChunk = readNextChunk(in)
// Integer division that rounds up to the nearest whole number.
def getNumChunks(size: Long): Long = (size + (CHUNK_SIZE - 1)) / CHUNK_SIZE

builder.getBeginChunkBuilder
.setName(artifact.path.toString)
.setTotalBytes(artifact.size)
.setNumChunks(getNumChunks(artifact.size))
.setInitialChunk(
artifactChunkBuilder
.setData(dataChunk)
.setCrc(in.getChecksum.getValue))
stream.onNext(builder.build())
in.getChecksum.reset()
builder.clearBeginChunk()

dataChunk = readNextChunk(in)
// Consume stream in chunks until there is no data left to read.
while (!dataChunk.isEmpty) {
artifactChunkBuilder.setData(dataChunk).setCrc(in.getChecksum.getValue)
builder.setChunk(artifactChunkBuilder.build())
stream.onNext(builder.build())
in.getChecksum.reset()
builder.clearChunk()
dataChunk = readNextChunk(in)
}
} catch {
case NonFatal(e) =>
stream.onError(e)
throw e
} finally {
in.close()
}
}
}

class Artifact private (val path: Path, val storage: LocalData) {
require(!path.isAbsolute, s"Bad path: $path")

lazy val size: Long = storage match {
case localData: LocalData => localData.size
}
}

object Artifact {
val CLASS_PREFIX: Path = Paths.get("classes")
val JAR_PREFIX: Path = Paths.get("jars")

def newJarArtifact(fileName: Path, storage: LocalData): Artifact = {
newArtifact(JAR_PREFIX, ".jar", fileName, storage)
}

def newClassArtifact(fileName: Path, storage: LocalData): Artifact = {
newArtifact(CLASS_PREFIX, ".class", fileName, storage)
}

private def newArtifact(
prefix: Path,
requiredSuffix: String,
fileName: Path,
storage: LocalData): Artifact = {
require(!fileName.isAbsolute)
require(fileName.toString.endsWith(requiredSuffix))
new Artifact(prefix.resolve(fileName), storage)
}

/**
* Payload stored on this machine.
*/
sealed trait LocalData {
def stream: InputStream
def size: Long
}

/**
* Payload stored in a local file.
*/
class LocalFile(val path: Path) extends LocalData {
override def size: Long = Files.size(path)
override def stream: InputStream = Files.newInputStream(path)
}
}
@@ -36,6 +36,8 @@ private[sql] class SparkConnectClient(

private[this] val stub = proto.SparkConnectServiceGrpc.newBlockingStub(channel)

private[client] val artifactManager: ArtifactManager = new ArtifactManager(userContext, channel)

/**
* Placeholder method.
* @return
@@ -147,6 +149,27 @@ private[sql] class SparkConnectClient(
analyze(request)
}

/**
* Add a single artifact to the client session.
*
* Currently only local files with extensions .jar and .class are supported.
*/
def addArtifact(path: String): Unit = artifactManager.addArtifact(path)

/**
* Add a single artifact to the client session.
*
* Currently only local files with extensions .jar and .class are supported.
*/
def addArtifact(uri: URI): Unit = artifactManager.addArtifact(uri)

/**
* Add multiple artifacts to the session.
*
* Currently only local files with extensions .jar and .class are supported.
*/
def addArtifacts(uri: Seq[URI]): Unit = artifactManager.addArtifacts(uri)

/**
* Shutdown the client's connection to the server.
*/
@@ -0,0 +1,5 @@
The CRCs for a specific file are stored in a text file with the same name (excluding the original extension).

The CRCs are calculated for data chunks of `32768 bytes` (individual CRCs) and are newline delimited.

The CRCs were calculated using https://simplycalc.com/crc32-file.php
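
For reference, a minimal Scala sketch of how equivalent per-chunk CRC32 values could be generated programmatically; the sibling `<base name>.txt` output location is an assumption about the resource layout:

```scala
import java.nio.charset.StandardCharsets
import java.nio.file.{Files, Path}
import java.util.zip.CRC32

// Compute newline-delimited CRC32 values over consecutive 32768-byte chunks of `file`
// and write them to a sibling text file named after the base name (extension dropped).
def writeChunkCrcs(file: Path, chunkSize: Int = 32768): Path = {
  val bytes = Files.readAllBytes(file)
  val crcs = bytes.grouped(chunkSize).map { chunk =>
    val crc = new CRC32
    crc.update(chunk)
    crc.getValue.toString
  }
  val baseName = file.getFileName.toString.replaceAll("\\.[^.]*$", "")
  val out = file.resolveSibling(baseName + ".txt")
  Files.write(out, crcs.mkString("\n").getBytes(StandardCharsets.UTF_8))
  out
}
```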
