diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/package.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/package.scala index b3cc5eed..79c4ba87 100644 --- a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/package.scala +++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/package.scala @@ -6,6 +6,7 @@ package com.vesoft.nebula.connector import com.vesoft.nebula.connector.ssl.SSLSignType +import com.vesoft.nebula.connector.utils.SparkValidate import org.apache.spark.rdd.RDD import org.apache.spark.sql.{ DataFrame, @@ -30,6 +31,7 @@ package object connector { def nebula(connectionConfig: NebulaConnectionConfig, readConfig: ReadNebulaConfig): NebulaDataFrameReader = { + SparkValidate.validate("3.0.*", "3.1.*", "3.2.*", "3.3.*") this.connectionConfig = connectionConfig this.readConfig = readConfig this @@ -179,6 +181,7 @@ package object connector { */ def nebula(connectionConfig: NebulaConnectionConfig, writeNebulaConfig: WriteNebulaConfig): NebulaDataFrameWriter = { + SparkValidate.validate("3.0.*", "3.1.*", "3.2.*", "3.3.*") this.connectionConfig = connectionConfig this.writeNebulaConfig = writeNebulaConfig this diff --git a/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/SparkVersionValidateSuite.scala b/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/SparkVersionValidateSuite.scala new file mode 100644 index 00000000..81a8a39c --- /dev/null +++ b/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/SparkVersionValidateSuite.scala @@ -0,0 +1,21 @@ +/* Copyright (c) 2022 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. + */ + +package com.vesoft.nebula.connector + +import com.vesoft.nebula.connector.utils.SparkValidate +import org.apache.spark.sql.SparkSession +import org.scalatest.funsuite.AnyFunSuite + +class SparkVersionValidateSuite extends AnyFunSuite { + test("spark version validate") { + try { + val version = SparkSession.getActiveSession.map(_.version).getOrElse("UNKNOWN") + SparkValidate.validate("3.0.*", "3.1.*", "3.2.*", "3.3.*") + } catch { + case e: Exception => assert(false) + } + } +}