/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import java.io.Closeable
import java.net.URI
import java.util.concurrent.TimeUnit._
import java.util.concurrent.atomic.AtomicLong

import scala.collection.JavaConverters._
import scala.reflect.runtime.universe.TypeTag

import org.apache.arrow.memory.RootAllocator

import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection}
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BoxedLongEncoder, UnboundRowEncoder}
import org.apache.spark.sql.connect.client.{SparkConnectClient, SparkResult}
import org.apache.spark.sql.connect.client.util.{Cleaner, ConvertToArrow}
import org.apache.spark.sql.types.StructType

/**
 * The entry point to programming Spark with the Dataset and DataFrame API.
 *
 * In environments where this has been created upfront (e.g. REPL, notebooks), use the builder
 * to get an existing session:
 *
 * {{{
 *   SparkSession.builder().getOrCreate()
 * }}}
 *
 * The builder can also be used to create a new session:
 *
 * {{{
 *   SparkSession.builder
 *     .master("local")
 *     .appName("Word Count")
 *     .config("spark.some.config.option", "some-value")
 *     .getOrCreate()
 * }}}
 */
class SparkSession private[sql] (
    private val client: SparkConnectClient,
    private val cleaner: Cleaner,
    private val planIdGenerator: AtomicLong)
    extends Serializable
    with Closeable
    with Logging {

  private[this] val allocator = new RootAllocator()

  lazy val version: String = {
    client.analyze(proto.AnalyzePlanRequest.AnalyzeCase.SPARK_VERSION).getSparkVersion.getVersion
  }

  /**
   * Runtime configuration interface for Spark.
   *
   * This is the interface through which the user can get and set all Spark configurations that
   * are relevant to Spark SQL. When getting the value of a config, this defaults to the value
   * set in the server, if any.
   *
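   * A minimal sketch of typical usage (illustrative only; assumes a session named `spark`):
   *
   * {{{
   *   spark.conf.get("spark.sql.session.timeZone")
   * }}}
   *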
   * @since 3.4.0
   */
  val conf: RuntimeConfig = new RuntimeConfig(client)

  /**
   * Executes some code block and prints to stdout the time taken to execute the block. This is
   * available in Scala only and is used primarily for interactive testing and debugging.
   *
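   * A small illustrative sketch (assumes a session named `spark`; output will vary):
   *
   * {{{
   *   val rows = spark.time(spark.range(1000).collect())
   *   // Prints e.g.: Time taken: 42 ms
   * }}}
   *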
   * @since 3.4.0
   */
  def time[T](f: => T): T = {
    val start = System.nanoTime()
    val ret = f
    val end = System.nanoTime()
    // scalastyle:off println
    println(s"Time taken: ${NANOSECONDS.toMillis(end - start)} ms")
    // scalastyle:on println
    ret
  }

  /**
   * Returns a `DataFrame` with no rows or columns.
   *
   * @since 3.4.0
   */
  @transient
  val emptyDataFrame: DataFrame = emptyDataset(UnboundRowEncoder)

  /**
   * Creates a new [[Dataset]] of type T containing zero elements.
   *
   * @since 3.4.0
   */
  def emptyDataset[T: Encoder]: Dataset[T] = createDataset[T](Nil)

  private def createDataset[T](encoder: AgnosticEncoder[T], data: Iterator[T]): Dataset[T] = {
    newDataset(encoder) { builder =>
      val localRelationBuilder = builder.getLocalRelationBuilder
        .setSchema(encoder.schema.catalogString)
      if (data.nonEmpty) {
        val timeZoneId = conf.get("spark.sql.session.timeZone")
        val arrowData = ConvertToArrow(encoder, data, timeZoneId, allocator)
        localRelationBuilder.setData(arrowData)
      }
    }
  }

  /**
   * Creates a `DataFrame` from a local Seq of Product.
   *
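   * A brief illustrative sketch (`Person` is a hypothetical case class; assumes a session
   * named `spark`):
   *
   * {{{
   *   case class Person(name: String, age: Long)
   *   val df = spark.createDataFrame(Seq(Person("Ann", 30L), Person("Bob", 25L)))
   * }}}
   *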
   * @since 3.4.0
   */
  def createDataFrame[A <: Product: TypeTag](data: Seq[A]): DataFrame = {
    createDataset(ScalaReflection.encoderFor[A], data.iterator).toDF()
  }

  /**
   * :: DeveloperApi :: Creates a `DataFrame` from a `java.util.List` containing [[Row]]s using
   * the given schema. It is important to make sure that the structure of every [[Row]] of the
   * provided List matches the provided schema. Otherwise, there will be a runtime exception.
   *
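   * An illustrative sketch (the schema and rows are made up for the example):
   *
   * {{{
   *   import java.util.Arrays
   *   import org.apache.spark.sql.types._
   *
   *   val schema = new StructType().add("name", StringType).add("age", LongType)
   *   val df = spark.createDataFrame(Arrays.asList(Row("Ann", 30L), Row("Bob", 25L)), schema)
   * }}}
   *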
   * @since 3.4.0
   */
  def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = {
    createDataset(RowEncoder.encoderFor(schema), rows.iterator().asScala).toDF()
  }

  /**
   * Applies a schema to a List of Java Beans.
   *
   * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, SELECT * queries
   * will return the columns in an undefined order.
   *
   * @since 3.4.0
   */
  def createDataFrame(data: java.util.List[_], beanClass: Class[_]): DataFrame = {
    val encoder = JavaTypeInference.encoderFor(beanClass.asInstanceOf[Class[Any]])
    createDataset(encoder, data.iterator().asScala).toDF()
  }

  /**
   * Creates a [[Dataset]] from a local Seq of data of a given type. This method requires an
   * encoder (to convert a JVM object of type `T` to and from the internal Spark SQL
   * representation) that is generally created automatically through implicits from a
   * `SparkSession`, or can be created explicitly by calling static methods on [[Encoders]].
   *
   * ==Example==
   *
   * {{{
   *
   *   import spark.implicits._
   *   case class Person(name: String, age: Long)
   *   val data = Seq(Person("Michael", 29), Person("Andy", 30), Person("Justin", 19))
   *   val ds = spark.createDataset(data)
   *
   *   ds.show()
   *   // +-------+---+
   *   // |   name|age|
   *   // +-------+---+
   *   // |Michael| 29|
   *   // |   Andy| 30|
   *   // | Justin| 19|
   *   // +-------+---+
   * }}}
   *
   * @since 3.4.0
   */
  def createDataset[T: Encoder](data: Seq[T]): Dataset[T] = {
    createDataset(encoderFor[T], data.iterator)
  }

  /**
   * Creates a [[Dataset]] from a `java.util.List` of a given type. This method requires an
   * encoder (to convert a JVM object of type `T` to and from the internal Spark SQL
   * representation) that is generally created automatically through implicits from a
   * `SparkSession`, or can be created explicitly by calling static methods on [[Encoders]].
   *
   * ==Java Example==
   *
   * {{{
   *   List<String> data = Arrays.asList("hello", "world");
   *   Dataset<String> ds = spark.createDataset(data, Encoders.STRING());
   * }}}
   *
   * @since 3.4.0
   */
  def createDataset[T: Encoder](data: java.util.List[T]): Dataset[T] = {
    createDataset(data.asScala.toSeq)
  }

  /**
   * Executes a SQL query substituting named parameters by the given arguments, returning the
   * result as a `DataFrame`. This API eagerly runs DDL/DML commands, but not SELECT queries.
   *
   * @param sqlText
   *   A SQL statement with named parameters to execute.
   * @param args
   *   A map of parameter names to literal values.
   *
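   * A minimal illustrative sketch, assuming named parameters are referenced in the statement
   * with a leading `:`:
   *
   * {{{
   *   spark.sql("SELECT * FROM range(10) WHERE id > :minId", Map("minId" -> "5"))
   * }}}
   *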
   * @since 3.4.0
   */
  @Experimental
  def sql(sqlText: String, args: Map[String, String]): DataFrame = {
    sql(sqlText, args.asJava)
  }

  /**
   * Executes a SQL query substituting named parameters by the given arguments, returning the
   * result as a `DataFrame`. This API eagerly runs DDL/DML commands, but not SELECT queries.
   *
   * @param sqlText
   *   A SQL statement with named parameters to execute.
   * @param args
   *   A map of parameter names to literal values.
   *
   * @since 3.4.0
   */
  @Experimental
  def sql(sqlText: String, args: java.util.Map[String, String]): DataFrame = newDataFrame {
    builder =>
      // Send the SQL once to the server and then check the output.
      val cmd = newCommand(b =>
        b.setSqlCommand(proto.SqlCommand.newBuilder().setSql(sqlText).putAllArgs(args)))
      val plan = proto.Plan.newBuilder().setCommand(cmd)
      val responseIter = client.execute(plan.build())
      val response = responseIter.asScala
        .find(_.hasSqlCommandResult)
        .getOrElse(throw new RuntimeException("SQLCommandResult must be present"))
      // Update the builder with the values from the result.
      builder.mergeFrom(response.getSqlCommandResult.getRelation)
  }

  /**
   * Executes a SQL query using Spark, returning the result as a `DataFrame`. This API eagerly
   * runs DDL/DML commands, but not SELECT queries.
   *
   * @since 3.4.0
   */
  def sql(query: String): DataFrame = {
    sql(query, Map.empty[String, String])
  }

  /**
   * Returns a [[DataFrameReader]] that can be used to read non-streaming data in as a
   * `DataFrame`.
   *
   * {{{
   *   sparkSession.read.parquet("/path/to/file.parquet")
   *   sparkSession.read.schema(schema).json("/path/to/file.json")
   * }}}
   *
   * @since 3.4.0
   */
  def read: DataFrameReader = new DataFrameReader(this)

  /**
   * Returns the specified table/view as a `DataFrame`. If it's a table, it must support batch
   * reading and the returned DataFrame is the batch scan query plan of this table. If it's a
   * view, the returned DataFrame is simply the query plan of the view, which can either be a
   * batch or streaming query plan.
   *
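   * An illustrative sketch ("people" is a hypothetical table or view name):
   *
   * {{{
   *   val df = spark.table("people")
   * }}}
   *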
   * @param tableName
   *   is either a qualified or unqualified name that designates a table or view. If a database
   *   is specified, it identifies the table/view from the database. Otherwise, it first
   *   attempts to find a temporary view with the given name and then match the table/view from
   *   the current database. Note that the global temporary view database is also valid here.
   * @since 3.4.0
   */
  def table(tableName: String): DataFrame = {
    read.table(tableName)
  }

  /**
   * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a
   * range from 0 to `end` (exclusive) with step value 1.
   *
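   * For example (illustrative; assumes a session named `spark`), `spark.range(3)` contains the
   * ids 0, 1 and 2:
   *
   * {{{
   *   spark.range(3).show()
   * }}}
   *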
   * @since 3.4.0
   */
  def range(end: Long): Dataset[java.lang.Long] = range(0, end)

  /**
   * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a
   * range from `start` to `end` (exclusive) with step value 1.
   *
   * @since 3.4.0
   */
  def range(start: Long, end: Long): Dataset[java.lang.Long] = {
    range(start, end, step = 1)
  }

  /**
   * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a
   * range from `start` to `end` (exclusive) with a step value.
   *
   * @since 3.4.0
   */
  def range(start: Long, end: Long, step: Long): Dataset[java.lang.Long] = {
    range(start, end, step, None)
  }

  /**
   * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a
   * range from `start` to `end` (exclusive) with a step value, with partition number specified.
   *
   * @since 3.4.0
   */
  def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset[java.lang.Long] = {
    range(start, end, step, Option(numPartitions))
  }

  // scalastyle:off
  // Disable style checker so "implicits" object can start with lowercase i
  /**
   * (Scala-specific) Implicit methods available in Scala for converting common names and
   * [[Symbol]]s into [[Column]]s, and for converting common Scala objects into `DataFrame`s.
   *
   * {{{
   *   val sparkSession = SparkSession.builder.getOrCreate()
   *   import sparkSession.implicits._
   * }}}
   *
   * @since 3.4.0
   */
  object implicits extends SQLImplicits(this)
  // scalastyle:on

  def newSession(): SparkSession = {
    throw new UnsupportedOperationException("newSession is not supported")
  }

  private def range(
      start: Long,
      end: Long,
      step: Long,
      numPartitions: Option[Int]): Dataset[java.lang.Long] = {
    newDataset(BoxedLongEncoder) { builder =>
      val rangeBuilder = builder.getRangeBuilder
        .setStart(start)
        .setEnd(end)
        .setStep(step)
      numPartitions.foreach(rangeBuilder.setNumPartitions)
    }
  }

  private[sql] def newDataFrame(f: proto.Relation.Builder => Unit): DataFrame = {
    newDataset(UnboundRowEncoder)(f)
  }

  private[sql] def newDataset[T](encoder: AgnosticEncoder[T])(
      f: proto.Relation.Builder => Unit): Dataset[T] = {
    val builder = proto.Relation.newBuilder()
    f(builder)
    builder.getCommonBuilder.setPlanId(planIdGenerator.getAndIncrement())
    val plan = proto.Plan.newBuilder().setRoot(builder).build()
    new Dataset[T](this, plan, encoder)
  }

  @DeveloperApi
  def newDataFrame(extension: com.google.protobuf.Any): DataFrame = {
    newDataset(extension, UnboundRowEncoder)
  }

  @DeveloperApi
  def newDataset[T](
      extension: com.google.protobuf.Any,
      encoder: AgnosticEncoder[T]): Dataset[T] = {
    newDataset(encoder)(_.setExtension(extension))
  }

  private[sql] def newCommand[T](f: proto.Command.Builder => Unit): proto.Command = {
    val builder = proto.Command.newBuilder()
    f(builder)
    builder.build()
  }

  private[sql] def analyze(
      plan: proto.Plan,
      method: proto.AnalyzePlanRequest.AnalyzeCase,
      explainMode: Option[proto.AnalyzePlanRequest.Explain.ExplainMode] = None)
      : proto.AnalyzePlanResponse = {
    client.analyze(method, Some(plan), explainMode)
  }

  private[sql] def execute[T](plan: proto.Plan, encoder: AgnosticEncoder[T]): SparkResult[T] = {
    val value = client.execute(plan)
    val result = new SparkResult(value, allocator, encoder)
    cleaner.register(result)
    result
  }

  private[sql] def execute(command: proto.Command): Unit = {
    val plan = proto.Plan.newBuilder().setCommand(command).build()
    // Drain the response stream so the command is fully executed on the server.
    client.execute(plan).asScala.foreach(_ => ())
  }

  @DeveloperApi
  def execute(extension: com.google.protobuf.Any): Unit = {
    val command = proto.Command.newBuilder().setExtension(extension).build()
    execute(command)
  }

  /**
   * Add a single artifact to the client session.
   *
   * Currently only local files with extensions .jar and .class are supported.
   *
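   * An illustrative sketch (the jar path is a made-up example):
   *
   * {{{
   *   spark.addArtifact("/path/to/my-udfs.jar")
   * }}}
   *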
   * @since 3.4.0
   */
  @Experimental
  def addArtifact(path: String): Unit = client.addArtifact(path)

  /**
   * Add a single artifact to the client session.
   *
   * Currently only local files with extensions .jar and .class are supported.
   *
   * @since 3.4.0
   */
  @Experimental
  def addArtifact(uri: URI): Unit = client.addArtifact(uri)

  /**
   * Add one or more artifacts to the session.
   *
   * Currently only local files with extensions .jar and .class are supported.
   *
   * @since 3.4.0
   */
  @Experimental
  @scala.annotation.varargs
  def addArtifacts(uri: URI*): Unit = client.addArtifacts(uri)

  /**
   * This resets the plan id generator so we can produce plans that are comparable.
   *
   * For testing only!
   */
  private[sql] def resetPlanIdGenerator(): Unit = {
    planIdGenerator.set(0)
  }

  /**
   * Synonym for `close()`.
   *
   * @since 3.4.0
   */
  def stop(): Unit = close()

  /**
   * Close the [[SparkSession]]. This closes the connection, and the allocator. The latter will
   * throw an exception if there are still open [[SparkResult]]s.
   *
   * @since 3.4.0
   */
  override def close(): Unit = {
    client.shutdown()
    allocator.close()
  }
}

// The minimal builder needed to create a spark session.
// TODO: implement all methods mentioned in the scaladoc of [[SparkSession]]
object SparkSession extends Logging {
  private val planIdGenerator = new AtomicLong

  def builder(): Builder = new Builder()

  private[sql] lazy val cleaner = {
    val cleaner = new Cleaner
    cleaner.start()
    cleaner
  }

  class Builder() extends Logging {
    private var _client: SparkConnectClient = _

    def remote(connectionString: String): Builder = {
      client(SparkConnectClient.builder().connectionString(connectionString).build())
      this
    }

    private[sql] def client(client: SparkConnectClient): Builder = {
      _client = client
      this
    }

    def build(): SparkSession = {
      if (_client == null) {
        _client = SparkConnectClient.builder().build()
      }
      new SparkSession(_client, cleaner, planIdGenerator)
    }
  }
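
  // A minimal illustrative sketch of connecting with this builder; the connection string
  // below is a made-up example endpoint, not a required value:
  //
  //   val spark = SparkSession.builder().remote("sc://localhost").build()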

  def getActiveSession: Option[SparkSession] = {
    throw new UnsupportedOperationException("getActiveSession is not supported")
  }

  def getDefaultSession: Option[SparkSession] = {
    throw new UnsupportedOperationException("getDefaultSession is not supported")
  }

  def setActiveSession(session: SparkSession): Unit = {
    throw new UnsupportedOperationException("setActiveSession is not supported")
  }

  def clearActiveSession(): Unit = {
    throw new UnsupportedOperationException("clearActiveSession is not supported")
  }

  def active: SparkSession = {
    throw new UnsupportedOperationException("active is not supported")
  }
}