diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml
index fda804df..daf6e83c 100644
--- a/.github/workflows/pull_request.yml
+++ b/.github/workflows/pull_request.yml
@@ -45,6 +45,9 @@ jobs:
- name: Build with Maven
run: |
- mvn -B package
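+          # -pl builds just the named module, -am (also-make) also builds its in-repo
+          # dependencies, and the -P flags pick the matching Scala/Spark profile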
+ mvn clean package -pl nebula-spark-connector_2.2 -am -Pscala-2.11 -Pspark-2.2
+ mvn clean package -pl nebula-spark-connector -am -Pscala-2.11 -Pspark-2.4
+ mvn clean package -pl nebula-spark-connector_3.0 -am -Pscala-2.12 -Pspark-3.0
+
- uses: codecov/codecov-action@v2
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 7132a0a7..f7563646 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -38,10 +38,29 @@ jobs:
popd
popd
- - name: Deploy release to Maven
+ - name: Deploy release for spark2.4 to Maven
uses: samuelmeuli/action-maven-publish@v1
with:
gpg_private_key: ${{ secrets.JAVA_GPG_PRIVATE_KEY }}
gpg_passphrase: ${{ secrets.JAVA_GPG_PASSPHRASE }}
nexus_username: ${{ secrets.OSSRH_USERNAME }}
nexus_password: ${{ secrets.OSSRH_TOKEN }}
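+          # maven_args is appended to the action's mvn deploy command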
+ maven_args: -pl nebula-spark-connector -am -Pscala-2.11 -Pspark-2.4
+
+ - name: Deploy release for spark2.2 to Maven
+ uses: samuelmeuli/action-maven-publish@v1
+ with:
+ gpg_private_key: ${{ secrets.JAVA_GPG_PRIVATE_KEY }}
+ gpg_passphrase: ${{ secrets.JAVA_GPG_PASSPHRASE }}
+ nexus_username: ${{ secrets.OSSRH_USERNAME }}
+ nexus_password: ${{ secrets.OSSRH_TOKEN }}
+ maven_args: -pl nebula-spark-connector_2.2 -am -Pscala-2.11 -Pspark-2.2
+
+ - name: Deploy release for spark3.0 to Maven
+ uses: samuelmeuli/action-maven-publish@v1
+ with:
+ gpg_private_key: ${{ secrets.JAVA_GPG_PRIVATE_KEY }}
+ gpg_passphrase: ${{ secrets.JAVA_GPG_PASSPHRASE }}
+ nexus_username: ${{ secrets.OSSRH_USERNAME }}
+ nexus_password: ${{ secrets.OSSRH_TOKEN }}
+ maven_args: -pl nebula-spark-connector_3.0 -am -Pscala-2.12 -Pspark-3.0
diff --git a/.github/workflows/snapshot.yml b/.github/workflows/snapshot.yml
index 2fee6328..0159d49e 100644
--- a/.github/workflows/snapshot.yml
+++ b/.github/workflows/snapshot.yml
@@ -40,10 +40,29 @@ jobs:
popd
popd
- - name: Deploy SNAPSHOT to Sonatype
+ - name: Deploy SNAPSHOT for spark2.4 to Sonatype
uses: samuelmeuli/action-maven-publish@v1
with:
gpg_private_key: ${{ secrets.JAVA_GPG_PRIVATE_KEY }}
gpg_passphrase: ${{ secrets.JAVA_GPG_PASSPHRASE }}
nexus_username: ${{ secrets.OSSRH_USERNAME }}
nexus_password: ${{ secrets.OSSRH_TOKEN }}
+ maven_args: -pl nebula-spark-connector -am -Pscala-2.11 -Pspark-2.4
+
+ - name: Deploy SNAPSHOT for spark2.2 to Sonatype
+ uses: samuelmeuli/action-maven-publish@v1
+ with:
+ gpg_private_key: ${{ secrets.JAVA_GPG_PRIVATE_KEY }}
+ gpg_passphrase: ${{ secrets.JAVA_GPG_PASSPHRASE }}
+ nexus_username: ${{ secrets.OSSRH_USERNAME }}
+ nexus_password: ${{ secrets.OSSRH_TOKEN }}
+ maven_args: -pl nebula-spark-connector_2.2 -am -Pscala-2.11 -Pspark-2.2
+
+ - name: Deploy SNAPSHOT for spark3.0 to Sonatype
+ uses: samuelmeuli/action-maven-publish@v1
+ with:
+ gpg_private_key: ${{ secrets.JAVA_GPG_PRIVATE_KEY }}
+ gpg_passphrase: ${{ secrets.JAVA_GPG_PASSPHRASE }}
+ nexus_username: ${{ secrets.OSSRH_USERNAME }}
+ nexus_password: ${{ secrets.OSSRH_TOKEN }}
+ maven_args: -pl nebula-spark-connector_3.0 -am -Pscala-2.12 -Pspark-3.0
diff --git a/example/src/main/scala/com/vesoft/nebula/examples/connector/NebulaSparkReaderExample.scala b/example/src/main/scala/com/vesoft/nebula/examples/connector/NebulaSparkReaderExample.scala
index 8a4a94b4..3d5a07d2 100644
--- a/example/src/main/scala/com/vesoft/nebula/examples/connector/NebulaSparkReaderExample.scala
+++ b/example/src/main/scala/com/vesoft/nebula/examples/connector/NebulaSparkReaderExample.scala
@@ -40,11 +40,10 @@ object NebulaSparkReaderExample {
}
def readVertex(spark: SparkSession): Unit = {
- LOG.info("start to read nebula vertices")
val config =
NebulaConnectionConfig
.builder()
- .withMetaAddress("127.0.0.1:9559")
+ .withMetaAddress("192.168.8.171:9559")
.withConenctionRetry(2)
.build()
val nebulaReadVertexConfig: ReadNebulaConfig = ReadNebulaConfig
@@ -52,7 +51,7 @@ object NebulaSparkReaderExample {
.withSpace("test")
.withLabel("person")
.withNoColumn(false)
- .withReturnCols(List("birthday"))
+ .withReturnCols(List())
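+      // an empty returnCols list reads all properties of the tag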
.withLimit(10)
.withPartitionNum(10)
.build()
@@ -63,12 +62,10 @@ object NebulaSparkReaderExample {
}
def readEdges(spark: SparkSession): Unit = {
- LOG.info("start to read nebula edges")
-
val config =
NebulaConnectionConfig
.builder()
- .withMetaAddress("127.0.0.1:9559")
+ .withMetaAddress("192.168.8.171:9559")
.withTimeout(6000)
.withConenctionRetry(2)
.build()
@@ -77,7 +74,7 @@ object NebulaSparkReaderExample {
.withSpace("test")
.withLabel("knows")
.withNoColumn(false)
- .withReturnCols(List("degree"))
+ .withReturnCols(List())
.withLimit(10)
.withPartitionNum(10)
.build()
diff --git a/example/src/main/scala/com/vesoft/nebula/examples/connector/NebulaSparkWriterExample.scala b/example/src/main/scala/com/vesoft/nebula/examples/connector/NebulaSparkWriterExample.scala
index 279935f9..ccde7fdf 100644
--- a/example/src/main/scala/com/vesoft/nebula/examples/connector/NebulaSparkWriterExample.scala
+++ b/example/src/main/scala/com/vesoft/nebula/examples/connector/NebulaSparkWriterExample.scala
@@ -94,16 +94,22 @@ object NebulaSparkWriterExample {
* if your withVidAsProp is true, then tag schema also should have property name: id
*/
def writeVertex(spark: SparkSession): Unit = {
- LOG.info("start to write nebula vertices")
- val df = spark.read.json("example/src/main/resources/vertex")
+ val df = spark.read.json("vertex")
df.show()
- val config = getNebulaConnectionConfig()
+ val config =
+ NebulaConnectionConfig
+ .builder()
+ .withMetaAddress("192.168.8.171:9559")
+ .withGraphAddress("192.168.8.171:9669")
+ .withConenctionRetry(2)
+ .build()
val nebulaWriteVertexConfig: WriteNebulaVertexConfig = WriteNebulaVertexConfig
.builder()
.withSpace("test")
.withTag("person")
.withVidField("id")
+ .withWriteMode(WriteMode.DELETE)
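+      // WriteMode.DELETE deletes the vertices in df instead of inserting them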
.withVidAsProp(false)
.withBatch(1000)
.build()
@@ -117,8 +123,7 @@ object NebulaSparkWriterExample {
* if your withRankAsProperty is true, then edge schema also should have property name: degree
*/
def writeEdge(spark: SparkSession): Unit = {
- LOG.info("start to write nebula edges")
- val df = spark.read.json("example/src/main/resources/edge")
+ val df = spark.read.json("edge")
df.show()
df.persist(StorageLevel.MEMORY_AND_DISK_SER)
diff --git a/nebula-spark-common/pom.xml b/nebula-spark-common/pom.xml
index 9eea5315..d0a761fa 100644
--- a/nebula-spark-common/pom.xml
+++ b/nebula-spark-common/pom.xml
@@ -19,10 +19,6 @@
         <scalatest.version>3.2.3</scalatest.version>
         <junit.version>4.13.1</junit.version>
         <commons-codec.version>1.13</commons-codec.version>
-
-        <scala.binary.version>2.11</scala.binary.version>
-        <spark.version>2.4.4</spark.version>
-        <scala.version>2.11.12</scala.version>
@@ -57,7 +53,7 @@
             <groupId>org.scalatest</groupId>
-            <artifactId>scalatest-funsuite_2.11</artifactId>
+            <artifactId>scalatest-funsuite_${scala.binary.version}</artifactId>
             <version>${scalatest.version}</version>
             <scope>test</scope>
diff --git a/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/NebulaOptions.scala b/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/NebulaOptions.scala
index 21195c5c..2bb630cf 100644
--- a/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/NebulaOptions.scala
+++ b/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/NebulaOptions.scala
@@ -15,22 +15,18 @@ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import scala.collection.mutable.ListBuffer
-class NebulaOptions(@transient val parameters: CaseInsensitiveMap[String])(
- operaType: OperaType.Value)
- extends Serializable
- with Logging {
+class NebulaOptions(@transient val parameters: CaseInsensitiveMap[String]) extends Serializable {
import NebulaOptions._
def this(parameters: Map[String, String], operaType: OperaType.Value) =
- this(CaseInsensitiveMap(parameters))(operaType)
+ this(CaseInsensitiveMap(parameters))
def this(hostAndPorts: String,
spaceName: String,
dataType: String,
label: String,
- parameters: Map[String, String],
- operaType: OperaType.Value) = {
+ parameters: Map[String, String]) = {
this(
CaseInsensitiveMap(
parameters ++ Map(
@@ -39,8 +35,9 @@ class NebulaOptions(@transient val parameters: CaseInsensitiveMap[String])(
NebulaOptions.TYPE -> dataType,
NebulaOptions.LABEL -> label
))
- )(operaType)
+ )
}
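+  // the operate type now travels in the options map (see OPERATE_TYPE) instead of
+  // being passed as a separate constructor argument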
+ val operaType = OperaType.withName(parameters(OPERATE_TYPE))
/**
* Return property with all options
@@ -104,21 +101,24 @@ class NebulaOptions(@transient val parameters: CaseInsensitiveMap[String])(
val label: String = parameters(LABEL)
/** read parameters */
- var returnCols: String = _
- var partitionNums: String = _
- var noColumn: Boolean = _
- var limit: Int = _
- var ngql: String = _
+ var returnCols: String = _
+ var partitionNums: String = _
+ var noColumn: Boolean = _
+ var limit: Int = _
+ var pushDownFiltersEnabled: Boolean = _
+ var ngql: String = _
if (operaType == OperaType.READ) {
returnCols = parameters(RETURN_COLS)
noColumn = parameters.getOrElse(NO_COLUMN, false).toString.toBoolean
partitionNums = parameters(PARTITION_NUMBER)
limit = parameters.getOrElse(LIMIT, DEFAULT_LIMIT).toString.toInt
- ngql = parameters.getOrElse(NGQL,EMPTY_STRING)
- ngql = parameters.getOrElse(NGQL,EMPTY_STRING)
- if(ngql!=EMPTY_STRING){
+      // TODO expose the pushDownFiltersEnabled parameter to users
+ pushDownFiltersEnabled = parameters.getOrElse(PUSHDOWN_FILTERS_ENABLE, false).toString.toBoolean
+      ngql = parameters.getOrElse(NGQL, EMPTY_STRING)
+ if (ngql != EMPTY_STRING) {
require(parameters.isDefinedAt(GRAPH_ADDRESS),
- s"option $GRAPH_ADDRESS is required for ngql and can not be blank")
+ s"option $GRAPH_ADDRESS is required for ngql and can not be blank")
graphAddress = parameters(GRAPH_ADDRESS)
}
}
@@ -187,13 +187,11 @@ class NebulaOptions(@transient val parameters: CaseInsensitiveMap[String])(
def getMetaAddress: List[Address] = {
val hostPorts: ListBuffer[Address] = new ListBuffer[Address]
- metaAddress
- .split(",")
- .foreach(hostPort => {
- // check host & port by getting HostAndPort
- val addr = HostAndPort.fromString(hostPort)
- hostPorts.append((addr.getHostText, addr.getPort))
- })
+ for (hostPort <- metaAddress.split(",")) {
+ // check host & port by getting HostAndPort
+ val addr = HostAndPort.fromString(hostPort)
+ hostPorts.append((addr.getHostText, addr.getPort))
+ }
hostPorts.toList
}
@@ -211,9 +209,6 @@ class NebulaOptions(@transient val parameters: CaseInsensitiveMap[String])(
}
-class NebulaOptionsInWrite(@transient override val parameters: CaseInsensitiveMap[String])
- extends NebulaOptions(parameters)(OperaType.WRITE) {}
-
object NebulaOptions {
/** nebula common config */
@@ -237,14 +232,17 @@ object NebulaOptions {
val CA_SIGN_PARAM: String = "caSignParam"
val SELF_SIGN_PARAM: String = "selfSignParam"
+ val OPERATE_TYPE: String = "operateType"
+
/** read config */
- val RETURN_COLS: String = "returnCols"
- val NO_COLUMN: String = "noColumn"
- val PARTITION_NUMBER: String = "partitionNumber"
- val LIMIT: String = "limit"
+ val RETURN_COLS: String = "returnCols"
+ val NO_COLUMN: String = "noColumn"
+ val PARTITION_NUMBER: String = "partitionNumber"
+ val LIMIT: String = "limit"
+ val PUSHDOWN_FILTERS_ENABLE: String = "pushDownFiltersEnable"
/** read by ngql **/
- val NGQL: String = "ngql"
+ val NGQL: String = "ngql"
/** write config */
val RATE_LIMIT: String = "rateLimit"
diff --git a/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/NebulaUtils.scala b/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/NebulaUtils.scala
index 47ad2d4c..bf48e2a6 100644
--- a/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/NebulaUtils.scala
+++ b/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/NebulaUtils.scala
@@ -7,22 +7,27 @@ package com.vesoft.nebula.connector
import com.vesoft.nebula.PropertyType
import com.vesoft.nebula.client.graph.data.{DateTimeWrapper, DurationWrapper, TimeWrapper}
+import com.vesoft.nebula.connector.nebula.MetaProvider
import com.vesoft.nebula.meta.{ColumnDef, ColumnTypeDef}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.{
BooleanType,
DataType,
+ DataTypes,
DoubleType,
FloatType,
IntegerType,
LongType,
StringType,
+ StructField,
StructType,
TimestampType
}
import org.apache.spark.unsafe.types.UTF8String
import org.slf4j.LoggerFactory
+import scala.collection.mutable.ListBuffer
+
object NebulaUtils {
private val LOG = LoggerFactory.getLogger(this.getClass)
@@ -156,4 +161,70 @@ object NebulaUtils {
s
}
+ /**
+ * return the dataset's schema. Schema includes configured cols in returnCols or includes all properties in nebula.
+ */
+ def getSchema(nebulaOptions: NebulaOptions): StructType = {
+ val returnCols = nebulaOptions.getReturnCols
+ val noColumn = nebulaOptions.noColumn
+ val fields: ListBuffer[StructField] = new ListBuffer[StructField]
+ val metaProvider = new MetaProvider(
+ nebulaOptions.getMetaAddress,
+ nebulaOptions.timeout,
+ nebulaOptions.connectionRetry,
+ nebulaOptions.executionRetry,
+ nebulaOptions.enableMetaSSL,
+ nebulaOptions.sslSignType,
+ nebulaOptions.caSignParam,
+ nebulaOptions.selfSignParam
+ )
+
+ import scala.collection.JavaConverters._
+ var schemaCols: Seq[ColumnDef] = Seq()
+ val isVertex = DataTypeEnum.VERTEX.toString.equalsIgnoreCase(nebulaOptions.dataType)
+
+ // construct vertex or edge default prop
+ if (isVertex) {
+ fields.append(DataTypes.createStructField("_vertexId", DataTypes.StringType, false))
+ } else {
+ fields.append(DataTypes.createStructField("_srcId", DataTypes.StringType, false))
+ fields.append(DataTypes.createStructField("_dstId", DataTypes.StringType, false))
+ fields.append(DataTypes.createStructField("_rank", DataTypes.LongType, false))
+ }
+
+ var dataSchema: StructType = null
+ // read no column
+ if (noColumn) {
+ dataSchema = new StructType(fields.toArray)
+ return dataSchema
+ }
+ // get tag schema or edge schema
+ val schema = if (isVertex) {
+ metaProvider.getTag(nebulaOptions.spaceName, nebulaOptions.label)
+ } else {
+ metaProvider.getEdge(nebulaOptions.spaceName, nebulaOptions.label)
+ }
+
+ schemaCols = schema.columns.asScala
+
+ // read all columns
+ if (returnCols.isEmpty) {
+ schemaCols.foreach(columnDef => {
+ LOG.info(s"prop name ${new String(columnDef.getName)}, type ${columnDef.getType.getType} ")
+ fields.append(
+ DataTypes.createStructField(new String(columnDef.getName),
+ NebulaUtils.convertDataType(columnDef.getType),
+ true))
+ })
+ } else {
+ for (col: String <- returnCols) {
+ fields.append(
+ DataTypes
+ .createStructField(col, NebulaUtils.getColDataType(schemaCols.toList, col), true))
+ }
+ }
+ dataSchema = new StructType(fields.toArray)
+ dataSchema
+ }
+
}
diff --git a/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/reader/NebulaReader.scala b/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/reader/NebulaReader.scala
new file mode 100644
index 00000000..054819a9
--- /dev/null
+++ b/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/reader/NebulaReader.scala
@@ -0,0 +1,291 @@
+/* Copyright (c) 2022 vesoft inc. All rights reserved.
+ *
+ * This source code is licensed under Apache 2.0 License.
+ */
+
+package com.vesoft.nebula.connector.reader
+
+import com.vesoft.nebula.client.graph.data.{
+ CASignedSSLParam,
+ HostAddress,
+ SSLParam,
+ SelfSignedSSLParam,
+ ValueWrapper
+}
+import com.vesoft.nebula.client.storage.StorageClient
+import com.vesoft.nebula.client.storage.data.BaseTableRow
+import com.vesoft.nebula.client.storage.scan.{
+ ScanEdgeResult,
+ ScanEdgeResultIterator,
+ ScanVertexResult,
+ ScanVertexResultIterator
+}
+import com.vesoft.nebula.connector.NebulaUtils.NebulaValueGetter
+import com.vesoft.nebula.connector.{NebulaOptions, NebulaUtils, PartitionUtils}
+import com.vesoft.nebula.connector.exception.GraphConnectException
+import com.vesoft.nebula.connector.nebula.MetaProvider
+import com.vesoft.nebula.connector.ssl.SSLSignType
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
+import org.apache.spark.sql.types.StructType
+import org.slf4j.{Logger, LoggerFactory}
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+import scala.collection.mutable.ListBuffer
+
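+/**
+  * Shared scan logic for the connector readers: it sets up the meta and storage
+  * clients, drives the vertex/edge scan loops and decodes each row into an InternalRow.
+  * Mixed into the spark-2.2 NebulaIterator and the spark-2.4 NebulaPartitionReader.
+  */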
+trait NebulaReader {
+ private val LOG: Logger = LoggerFactory.getLogger(this.getClass)
+
+ private var metaProvider: MetaProvider = _
+ private var schema: StructType = _
+
+ protected var dataIterator: Iterator[BaseTableRow] = _
+ protected var scanPartIterator: Iterator[Integer] = _
+ protected var resultValues: mutable.ListBuffer[List[Object]] = mutable.ListBuffer[List[Object]]()
+ protected var storageClient: StorageClient = _
+ protected var nebulaOptions: NebulaOptions = _
+
+ private var vertexResponseIterator: ScanVertexResultIterator = _
+ private var edgeResponseIterator: ScanEdgeResultIterator = _
+
+ /**
+ * init the reader: init metaProvider, storageClient
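+   * @return the total partition number of the space, used to allocate scan parts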
+ */
+ def init(index: Int, nebulaOptions: NebulaOptions, schema: StructType): Int = {
+ this.schema = schema
+ this.nebulaOptions = nebulaOptions
+
+ metaProvider = new MetaProvider(
+ nebulaOptions.getMetaAddress,
+ nebulaOptions.timeout,
+ nebulaOptions.connectionRetry,
+ nebulaOptions.executionRetry,
+ nebulaOptions.enableMetaSSL,
+ nebulaOptions.sslSignType,
+ nebulaOptions.caSignParam,
+ nebulaOptions.selfSignParam
+ )
+ val address: ListBuffer[HostAddress] = new ListBuffer[HostAddress]
+
+ for (addr <- nebulaOptions.getMetaAddress) {
+ address.append(new HostAddress(addr._1, addr._2))
+ }
+
+ var sslParam: SSLParam = null
+ if (nebulaOptions.enableStorageSSL) {
+ SSLSignType.withName(nebulaOptions.sslSignType) match {
+ case SSLSignType.CA => {
+ val caSSLSignParams = nebulaOptions.caSignParam
+ sslParam = new CASignedSSLParam(caSSLSignParams.caCrtFilePath,
+ caSSLSignParams.crtFilePath,
+ caSSLSignParams.keyFilePath)
+ }
+ case SSLSignType.SELF => {
+ val selfSSLSignParams = nebulaOptions.selfSignParam
+ sslParam = new SelfSignedSSLParam(selfSSLSignParams.crtFilePath,
+ selfSSLSignParams.keyFilePath,
+ selfSSLSignParams.password)
+ }
+ case _ => throw new IllegalArgumentException("ssl sign type is not supported")
+ }
+ this.storageClient = new StorageClient(address.asJava,
+ nebulaOptions.timeout,
+ nebulaOptions.connectionRetry,
+ nebulaOptions.executionRetry,
+ true,
+ sslParam)
+ } else {
+ this.storageClient = new StorageClient(address.asJava, nebulaOptions.timeout)
+ }
+
+ if (!storageClient.connect()) {
+ throw new GraphConnectException("storage connect failed.")
+ }
+ // allocate scanPart to this partition
+ val totalPart = metaProvider.getPartitionNumber(nebulaOptions.spaceName)
+ totalPart
+ }
+
+ /**
+ * resolve the vertex/edge data to InternalRow
+ */
+ protected def getRow(): InternalRow = {
+ val resultSet: Array[ValueWrapper] =
+ dataIterator.next().getValues.toArray.map(v => v.asInstanceOf[ValueWrapper])
+ val getters: Array[NebulaValueGetter] = NebulaUtils.makeGetters(schema)
+ val mutableRow = new SpecificInternalRow(schema.fields.map(x => x.dataType))
+
+ for (i <- getters.indices) {
+ val value: ValueWrapper = resultSet(i)
+ var resolved = false
+ if (value.isNull) {
+ mutableRow.setNullAt(i)
+ resolved = true
+ }
+ if (value.isString) {
+ getters(i).apply(value.asString(), mutableRow, i)
+ resolved = true
+ }
+ if (value.isDate) {
+ getters(i).apply(value.asDate(), mutableRow, i)
+ resolved = true
+ }
+ if (value.isTime) {
+ getters(i).apply(value.asTime(), mutableRow, i)
+ resolved = true
+ }
+ if (value.isDateTime) {
+ getters(i).apply(value.asDateTime(), mutableRow, i)
+ resolved = true
+ }
+ if (value.isLong) {
+ getters(i).apply(value.asLong(), mutableRow, i)
+ }
+ if (value.isBoolean) {
+ getters(i).apply(value.asBoolean(), mutableRow, i)
+ }
+ if (value.isDouble) {
+ getters(i).apply(value.asDouble(), mutableRow, i)
+ }
+ if (value.isGeography) {
+ getters(i).apply(value.asGeography(), mutableRow, i)
+ }
+ if (value.isDuration) {
+ getters(i).apply(value.asDuration(), mutableRow, i)
+ }
+ }
+ mutableRow
+ }
+
+ /**
+ * if the scan response has next vertex
+ */
+ protected def hasNextVertexRow: Boolean = {
+ if (dataIterator == null && vertexResponseIterator == null && !scanPartIterator.hasNext)
+ return false
+
+ var continue: Boolean = false
+ var break: Boolean = false
+ while ((dataIterator == null || !dataIterator.hasNext) && !break) {
+ resultValues.clear()
+ continue = false
+ if (vertexResponseIterator == null || !vertexResponseIterator.hasNext) {
+ if (scanPartIterator.hasNext) {
+ try {
+ if (nebulaOptions.noColumn) {
+ vertexResponseIterator = storageClient.scanVertex(nebulaOptions.spaceName,
+ scanPartIterator.next(),
+ nebulaOptions.label,
+ nebulaOptions.limit,
+ 0,
+ Long.MaxValue,
+ true,
+ true)
+ } else {
+ vertexResponseIterator = storageClient.scanVertex(
+ nebulaOptions.spaceName,
+ scanPartIterator.next(),
+ nebulaOptions.label,
+ nebulaOptions.getReturnCols.asJava,
+ nebulaOptions.limit,
+ 0,
+ Long.MaxValue,
+ true,
+ true)
+ }
+ } catch {
+ case e: Exception =>
+ LOG.error(s"Exception scanning vertex ${nebulaOptions.label}", e)
+ storageClient.close()
+ throw new Exception(e.getMessage, e)
+ }
+ // jump to the next loop
+ continue = true
+ }
+ // break while loop
+ break = !continue
+ } else {
+ val next: ScanVertexResult = vertexResponseIterator.next
+ if (!next.isEmpty) {
+ dataIterator = next.getVertexTableRows.iterator().asScala
+ }
+ }
+ }
+
+ if (dataIterator == null) {
+ return false
+ }
+ dataIterator.hasNext
+ }
+
+ /**
+ * if the scan response has next edge
+ */
+ protected def hasNextEdgeRow: Boolean = {
+ if (dataIterator == null && edgeResponseIterator == null && !scanPartIterator.hasNext)
+ return false
+
+ var continue: Boolean = false
+ var break: Boolean = false
+ while ((dataIterator == null || !dataIterator.hasNext) && !break) {
+ resultValues.clear()
+ continue = false
+ if (edgeResponseIterator == null || !edgeResponseIterator.hasNext) {
+ if (scanPartIterator.hasNext) {
+ try {
+ if (nebulaOptions.noColumn) {
+ edgeResponseIterator = storageClient.scanEdge(nebulaOptions.spaceName,
+ scanPartIterator.next(),
+ nebulaOptions.label,
+ nebulaOptions.limit,
+ 0L,
+ Long.MaxValue,
+ true,
+ true)
+ } else {
+ edgeResponseIterator = storageClient.scanEdge(nebulaOptions.spaceName,
+ scanPartIterator.next(),
+ nebulaOptions.label,
+ nebulaOptions.getReturnCols.asJava,
+ nebulaOptions.limit,
+ 0,
+ Long.MaxValue,
+ true,
+ true)
+ }
+ } catch {
+ case e: Exception =>
+              LOG.error(s"Exception scanning edge ${nebulaOptions.label}", e)
+ storageClient.close()
+ throw new Exception(e.getMessage, e)
+ }
+ // jump to the next loop
+ continue = true
+ }
+ // break while loop
+ break = !continue
+ } else {
+ val next: ScanEdgeResult = edgeResponseIterator.next
+ if (!next.isEmpty) {
+ dataIterator = next.getEdgeTableRows.iterator().asScala
+ }
+ }
+ }
+
+ if (dataIterator == null) {
+ return false
+ }
+ dataIterator.hasNext
+ }
+
+ /**
+ * close the reader
+ */
+ protected def closeReader(): Unit = {
+ metaProvider.close()
+ storageClient.close()
+ }
+}
diff --git a/nebula-spark-connector/pom.xml b/nebula-spark-connector/pom.xml
index 08a65f42..1174f703 100644
--- a/nebula-spark-connector/pom.xml
+++ b/nebula-spark-connector/pom.xml
@@ -47,7 +47,7 @@
             <groupId>org.scalatest</groupId>
             <artifactId>scalatest-funsuite_2.11</artifactId>
-            <version>${scalatest.version}</version>
+            <version>3.2.3</version>
             <scope>test</scope>
diff --git a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala
index c298a890..4233ec81 100644
--- a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala
+++ b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala
@@ -42,7 +42,7 @@ class NebulaDataSource
* Creates a {@link DataSourceReader} to scan the data from Nebula Graph.
*/
override def createReader(options: DataSourceOptions): DataSourceReader = {
- val nebulaOptions = getNebulaOptions(options, OperaType.READ)
+ val nebulaOptions = getNebulaOptions(options)
val dataType = nebulaOptions.dataType
LOG.info("create reader")
@@ -65,7 +65,7 @@ class NebulaDataSource
mode: SaveMode,
options: DataSourceOptions): Optional[DataSourceWriter] = {
- val nebulaOptions = getNebulaOptions(options, OperaType.WRITE)
+ val nebulaOptions = getNebulaOptions(options)
val dataType = nebulaOptions.dataType
if (mode == SaveMode.Ignore || mode == SaveMode.ErrorIfExists) {
LOG.warn(s"Currently do not support mode")
@@ -140,12 +140,12 @@ class NebulaDataSource
/**
* construct nebula options with DataSourceOptions
*/
- def getNebulaOptions(options: DataSourceOptions, operateType: OperaType.Value): NebulaOptions = {
+ def getNebulaOptions(options: DataSourceOptions): NebulaOptions = {
var parameters: Map[String, String] = Map()
for (entry: Entry[String, String] <- options.asMap().entrySet) {
parameters += (entry.getKey -> entry.getValue)
}
- val nebulaOptions = new NebulaOptions(CaseInsensitiveMap(parameters))(operateType)
+ val nebulaOptions = new NebulaOptions(CaseInsensitiveMap(parameters))
nebulaOptions
}
}
diff --git a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/package.scala b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/package.scala
index e9f122c1..0ee3b8bb 100644
--- a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/package.scala
+++ b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/package.scala
@@ -47,6 +47,7 @@ package object connector {
val dfReader = reader
.format(classOf[NebulaDataSource].getName)
.option(NebulaOptions.TYPE, DataTypeEnum.VERTEX.toString)
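+      // tells NebulaOptions to parse the read-side parameters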
+ .option(NebulaOptions.OPERATE_TYPE, OperaType.READ.toString)
.option(NebulaOptions.SPACE_NAME, readConfig.getSpace)
.option(NebulaOptions.LABEL, readConfig.getLabel)
.option(NebulaOptions.PARTITION_NUMBER, readConfig.getPartitionNum)
@@ -84,6 +85,7 @@ package object connector {
val dfReader = reader
.format(classOf[NebulaDataSource].getName)
.option(NebulaOptions.TYPE, DataTypeEnum.EDGE.toString)
+ .option(NebulaOptions.OPERATE_TYPE, OperaType.READ.toString)
.option(NebulaOptions.SPACE_NAME, readConfig.getSpace)
.option(NebulaOptions.LABEL, readConfig.getLabel)
.option(NebulaOptions.RETURN_COLS, readConfig.getReturnCols.mkString(","))
@@ -235,6 +237,7 @@ package object connector {
.format(classOf[NebulaDataSource].getName)
.mode(SaveMode.Overwrite)
.option(NebulaOptions.TYPE, DataTypeEnum.VERTEX.toString)
+ .option(NebulaOptions.OPERATE_TYPE, OperaType.WRITE.toString)
.option(NebulaOptions.SPACE_NAME, writeConfig.getSpace)
.option(NebulaOptions.LABEL, writeConfig.getTagName)
.option(NebulaOptions.USER_NAME, writeConfig.getUser)
@@ -278,6 +281,7 @@ package object connector {
.format(classOf[NebulaDataSource].getName)
.mode(SaveMode.Overwrite)
.option(NebulaOptions.TYPE, DataTypeEnum.EDGE.toString)
+ .option(NebulaOptions.OPERATE_TYPE, OperaType.WRITE.toString)
.option(NebulaOptions.SPACE_NAME, writeConfig.getSpace)
.option(NebulaOptions.USER_NAME, writeConfig.getUser)
.option(NebulaOptions.PASSWD, writeConfig.getPasswd)
diff --git a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaEdgePartitionReader.scala b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaEdgePartitionReader.scala
index b54f5a14..605a36dc 100644
--- a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaEdgePartitionReader.scala
+++ b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaEdgePartitionReader.scala
@@ -5,72 +5,11 @@
package com.vesoft.nebula.connector.reader
-import com.vesoft.nebula.client.storage.scan.{ScanEdgeResult, ScanEdgeResultIterator}
import com.vesoft.nebula.connector.NebulaOptions
import org.apache.spark.sql.types.StructType
-import org.slf4j.{Logger, LoggerFactory}
-import scala.collection.JavaConverters._
class NebulaEdgePartitionReader(index: Int, nebulaOptions: NebulaOptions, schema: StructType)
extends NebulaPartitionReader(index, nebulaOptions, schema) {
- private val LOG: Logger = LoggerFactory.getLogger(this.getClass)
- private var responseIterator: ScanEdgeResultIterator = _
-
- override def next(): Boolean = {
- if (dataIterator == null && responseIterator == null && !scanPartIterator.hasNext)
- return false
-
- var continue: Boolean = false
- var break: Boolean = false
- while ((dataIterator == null || !dataIterator.hasNext) && !break) {
- resultValues.clear()
- continue = false
- if (responseIterator == null || !responseIterator.hasNext) {
- if (scanPartIterator.hasNext) {
- try {
- if (nebulaOptions.noColumn) {
- responseIterator = storageClient.scanEdge(nebulaOptions.spaceName,
- scanPartIterator.next(),
- nebulaOptions.label,
- nebulaOptions.limit,
- 0L,
- Long.MaxValue,
- true,
- true)
- } else {
- responseIterator = storageClient.scanEdge(nebulaOptions.spaceName,
- scanPartIterator.next(),
- nebulaOptions.label,
- nebulaOptions.getReturnCols.asJava,
- nebulaOptions.limit,
- 0,
- Long.MaxValue,
- true,
- true)
- }
- } catch {
- case e: Exception =>
- LOG.error(s"Exception scanning vertex ${nebulaOptions.label}", e)
- storageClient.close()
- throw new Exception(e.getMessage, e)
- }
- // jump to the next loop
- continue = true
- }
- // break while loop
- break = !continue
- } else {
- val next: ScanEdgeResult = responseIterator.next
- if (!next.isEmpty) {
- dataIterator = next.getEdgeTableRows.iterator().asScala
- }
- }
- }
-
- if (dataIterator == null) {
- return false
- }
- dataIterator.hasNext
- }
+ override def next(): Boolean = hasNextEdgeRow
}
diff --git a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartitionReader.scala b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartitionReader.scala
index e72f3460..06e06a59 100644
--- a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartitionReader.scala
+++ b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartitionReader.scala
@@ -5,44 +5,18 @@
package com.vesoft.nebula.connector.reader
-import com.vesoft.nebula.client.graph.data.{
- CASignedSSLParam,
- HostAddress,
- SSLParam,
- SelfSignedSSLParam,
- ValueWrapper
-}
-import com.vesoft.nebula.client.storage.StorageClient
-import com.vesoft.nebula.client.storage.data.BaseTableRow
-import com.vesoft.nebula.connector.NebulaUtils.NebulaValueGetter
-import com.vesoft.nebula.connector.{NebulaOptions, NebulaUtils, PartitionUtils}
-import com.vesoft.nebula.connector.exception.GraphConnectException
-import com.vesoft.nebula.connector.nebula.MetaProvider
-import com.vesoft.nebula.connector.ssl.SSLSignType
+import com.vesoft.nebula.connector.{NebulaOptions, PartitionUtils}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
import org.apache.spark.sql.sources.v2.reader.InputPartitionReader
import org.apache.spark.sql.types.StructType
import org.slf4j.{Logger, LoggerFactory}
-import scala.collection.JavaConverters._
-import scala.collection.mutable
-import scala.collection.mutable.ListBuffer
-
/**
* Read nebula data for each spark partition
*/
-abstract class NebulaPartitionReader extends InputPartitionReader[InternalRow] {
+abstract class NebulaPartitionReader extends InputPartitionReader[InternalRow] with NebulaReader {
private val LOG: Logger = LoggerFactory.getLogger(this.getClass)
- private var metaProvider: MetaProvider = _
- private var schema: StructType = _
-
- protected var dataIterator: Iterator[BaseTableRow] = _
- protected var scanPartIterator: Iterator[Integer] = _
- protected var resultValues: mutable.ListBuffer[List[Object]] = mutable.ListBuffer[List[Object]]()
- protected var storageClient: StorageClient = _
-
/**
* @param index identifier for spark partition
* @param nebulaOptions nebula Options
@@ -50,113 +24,16 @@ abstract class NebulaPartitionReader extends InputPartitionReader[InternalRow] {
*/
def this(index: Int, nebulaOptions: NebulaOptions, schema: StructType) {
this()
- this.schema = schema
-
- metaProvider = new MetaProvider(
- nebulaOptions.getMetaAddress,
- nebulaOptions.timeout,
- nebulaOptions.connectionRetry,
- nebulaOptions.executionRetry,
- nebulaOptions.enableMetaSSL,
- nebulaOptions.sslSignType,
- nebulaOptions.caSignParam,
- nebulaOptions.selfSignParam
- )
- val address: ListBuffer[HostAddress] = new ListBuffer[HostAddress]
-
- for (addr <- nebulaOptions.getMetaAddress) {
- address.append(new HostAddress(addr._1, addr._2))
- }
-
- var sslParam: SSLParam = null
- if (nebulaOptions.enableStorageSSL) {
- SSLSignType.withName(nebulaOptions.sslSignType) match {
- case SSLSignType.CA => {
- val caSSLSignParams = nebulaOptions.caSignParam
- sslParam = new CASignedSSLParam(caSSLSignParams.caCrtFilePath,
- caSSLSignParams.crtFilePath,
- caSSLSignParams.keyFilePath)
- }
- case SSLSignType.SELF => {
- val selfSSLSignParams = nebulaOptions.selfSignParam
- sslParam = new SelfSignedSSLParam(selfSSLSignParams.crtFilePath,
- selfSSLSignParams.keyFilePath,
- selfSSLSignParams.password)
- }
- case _ => throw new IllegalArgumentException("ssl sign type is not supported")
- }
- this.storageClient = new StorageClient(address.asJava,
- nebulaOptions.timeout,
- nebulaOptions.connectionRetry,
- nebulaOptions.executionRetry,
- true,
- sslParam)
- } else {
- this.storageClient = new StorageClient(address.asJava, nebulaOptions.timeout)
- }
-
- if (!storageClient.connect()) {
- throw new GraphConnectException("storage connect failed.")
- }
- // allocate scanPart to this partition
- val totalPart = metaProvider.getPartitionNumber(nebulaOptions.spaceName)
-
+ val totalPart = super.init(index, nebulaOptions, schema)
// index starts with 1
val scanParts = PartitionUtils.getScanParts(index, totalPart, nebulaOptions.partitionNums.toInt)
LOG.info(s"partition index: ${index}, scanParts: ${scanParts.toString}")
scanPartIterator = scanParts.iterator
}
- override def get(): InternalRow = {
- val resultSet: Array[ValueWrapper] =
- dataIterator.next().getValues.toArray.map(v => v.asInstanceOf[ValueWrapper])
- val getters: Array[NebulaValueGetter] = NebulaUtils.makeGetters(schema)
- val mutableRow = new SpecificInternalRow(schema.fields.map(x => x.dataType))
-
- for (i <- getters.indices) {
- val value: ValueWrapper = resultSet(i)
- var resolved = false
- if (value.isNull) {
- mutableRow.setNullAt(i)
- resolved = true
- }
- if (value.isString) {
- getters(i).apply(value.asString(), mutableRow, i)
- resolved = true
- }
- if (value.isDate) {
- getters(i).apply(value.asDate(), mutableRow, i)
- resolved = true
- }
- if (value.isTime) {
- getters(i).apply(value.asTime(), mutableRow, i)
- resolved = true
- }
- if (value.isDateTime) {
- getters(i).apply(value.asDateTime(), mutableRow, i)
- resolved = true
- }
- if (value.isLong) {
- getters(i).apply(value.asLong(), mutableRow, i)
- }
- if (value.isBoolean) {
- getters(i).apply(value.asBoolean(), mutableRow, i)
- }
- if (value.isDouble) {
- getters(i).apply(value.asDouble(), mutableRow, i)
- }
- if (value.isGeography) {
- getters(i).apply(value.asGeography(), mutableRow, i)
- }
- if (value.isDuration) {
- getters(i).apply(value.asDuration(), mutableRow, i)
- }
- }
- mutableRow
- }
+ override def get(): InternalRow = super.getRow()
override def close(): Unit = {
- metaProvider.close()
- storageClient.close()
+ super.closeReader()
}
}
diff --git a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaSourceReader.scala b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaSourceReader.scala
index 0ca8ee54..d859570e 100644
--- a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaSourceReader.scala
+++ b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaSourceReader.scala
@@ -7,15 +7,12 @@ package com.vesoft.nebula.connector.reader
import java.util
-import com.vesoft.nebula.connector.nebula.MetaProvider
-import com.vesoft.nebula.connector.{DataTypeEnum, NebulaOptions, NebulaUtils}
-import com.vesoft.nebula.meta.ColumnDef
+import com.vesoft.nebula.connector.{NebulaOptions, NebulaUtils}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition}
-import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
+import org.apache.spark.sql.types.StructType
import org.slf4j.LoggerFactory
-import scala.collection.mutable.ListBuffer
import scala.collection.JavaConverters._
/**
@@ -27,78 +24,16 @@ abstract class NebulaSourceReader(nebulaOptions: NebulaOptions) extends DataSour
private var datasetSchema: StructType = _
override def readSchema(): StructType = {
- datasetSchema = getSchema(nebulaOptions)
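+    // compute the schema once and reuse it; every getSchema call opens a new meta connection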
+ if (datasetSchema == null) {
+ datasetSchema = NebulaUtils.getSchema(nebulaOptions)
+ }
+
LOG.info(s"dataset's schema: $datasetSchema")
datasetSchema
}
- protected def getSchema: StructType = getSchema(nebulaOptions)
-
- /**
- * return the dataset's schema. Schema includes configured cols in returnCols or includes all properties in nebula.
- */
- def getSchema(nebulaOptions: NebulaOptions): StructType = {
- val returnCols = nebulaOptions.getReturnCols
- val noColumn = nebulaOptions.noColumn
- val fields: ListBuffer[StructField] = new ListBuffer[StructField]
- val metaProvider = new MetaProvider(
- nebulaOptions.getMetaAddress,
- nebulaOptions.timeout,
- nebulaOptions.connectionRetry,
- nebulaOptions.executionRetry,
- nebulaOptions.enableMetaSSL,
- nebulaOptions.sslSignType,
- nebulaOptions.caSignParam,
- nebulaOptions.selfSignParam
- )
-
- import scala.collection.JavaConverters._
- var schemaCols: Seq[ColumnDef] = Seq()
- val isVertex = DataTypeEnum.VERTEX.toString.equalsIgnoreCase(nebulaOptions.dataType)
-
- // construct vertex or edge default prop
- if (isVertex) {
- fields.append(DataTypes.createStructField("_vertexId", DataTypes.StringType, false))
- } else {
- fields.append(DataTypes.createStructField("_srcId", DataTypes.StringType, false))
- fields.append(DataTypes.createStructField("_dstId", DataTypes.StringType, false))
- fields.append(DataTypes.createStructField("_rank", DataTypes.LongType, false))
- }
-
- var dataSchema: StructType = null
- // read no column
- if (noColumn) {
- dataSchema = new StructType(fields.toArray)
- return dataSchema
- }
- // get tag schema or edge schema
- val schema = if (isVertex) {
- metaProvider.getTag(nebulaOptions.spaceName, nebulaOptions.label)
- } else {
- metaProvider.getEdge(nebulaOptions.spaceName, nebulaOptions.label)
- }
-
- schemaCols = schema.columns.asScala
-
- // read all columns
- if (returnCols.isEmpty) {
- schemaCols.foreach(columnDef => {
- LOG.info(s"prop name ${new String(columnDef.getName)}, type ${columnDef.getType.getType} ")
- fields.append(
- DataTypes.createStructField(new String(columnDef.getName),
- NebulaUtils.convertDataType(columnDef.getType),
- true))
- })
- } else {
- for (col: String <- returnCols) {
- fields.append(
- DataTypes
- .createStructField(col, NebulaUtils.getColDataType(schemaCols.toList, col), true))
- }
- }
- dataSchema = new StructType(fields.toArray)
- dataSchema
- }
+ protected def getSchema: StructType =
+ if (datasetSchema == null) NebulaUtils.getSchema(nebulaOptions) else datasetSchema
}
/**
diff --git a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaVertexPartitionReader.scala b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaVertexPartitionReader.scala
index 3466b1dd..d6d056d2 100644
--- a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaVertexPartitionReader.scala
+++ b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaVertexPartitionReader.scala
@@ -5,75 +5,11 @@
package com.vesoft.nebula.connector.reader
-import com.vesoft.nebula.client.storage.scan.{ScanVertexResult, ScanVertexResultIterator}
import com.vesoft.nebula.connector.NebulaOptions
import org.apache.spark.sql.types.StructType
-import org.slf4j.{Logger, LoggerFactory}
-
-import scala.collection.JavaConverters._
class NebulaVertexPartitionReader(index: Int, nebulaOptions: NebulaOptions, schema: StructType)
extends NebulaPartitionReader(index, nebulaOptions, schema) {
- private val LOG: Logger = LoggerFactory.getLogger(this.getClass)
-
- private var responseIterator: ScanVertexResultIterator = _
-
- override def next(): Boolean = {
- if (dataIterator == null && responseIterator == null && !scanPartIterator.hasNext)
- return false
-
- var continue: Boolean = false
- var break: Boolean = false
- while ((dataIterator == null || !dataIterator.hasNext) && !break) {
- resultValues.clear()
- continue = false
- if (responseIterator == null || !responseIterator.hasNext) {
- if (scanPartIterator.hasNext) {
- try {
- if (nebulaOptions.noColumn) {
- responseIterator = storageClient.scanVertex(nebulaOptions.spaceName,
- scanPartIterator.next(),
- nebulaOptions.label,
- nebulaOptions.limit,
- 0,
- Long.MaxValue,
- true,
- true)
- } else {
- responseIterator = storageClient.scanVertex(nebulaOptions.spaceName,
- scanPartIterator.next(),
- nebulaOptions.label,
- nebulaOptions.getReturnCols.asJava,
- nebulaOptions.limit,
- 0,
- Long.MaxValue,
- true,
- true)
- }
- } catch {
- case e: Exception =>
- LOG.error(s"Exception scanning vertex ${nebulaOptions.label}", e)
- storageClient.close()
- throw new Exception(e.getMessage, e)
- }
- // jump to the next loop
- continue = true
- }
- // break while loop
- break = !continue
- } else {
- val next: ScanVertexResult = responseIterator.next
- if (!next.isEmpty) {
- dataIterator = next.getVertexTableRows.iterator().asScala
- }
- }
- }
-
- if (dataIterator == null) {
- return false
- }
- dataIterator.hasNext
- }
-
+ override def next(): Boolean = hasNextVertexRow
}
diff --git a/nebula-spark-connector/src/test/scala/com/vesoft/nebula/connector/mock/NebulaGraphMock.scala b/nebula-spark-connector/src/test/scala/com/vesoft/nebula/connector/mock/NebulaGraphMock.scala
index b8f0e72e..46c8f502 100644
--- a/nebula-spark-connector/src/test/scala/com/vesoft/nebula/connector/mock/NebulaGraphMock.scala
+++ b/nebula-spark-connector/src/test/scala/com/vesoft/nebula/connector/mock/NebulaGraphMock.scala
@@ -169,6 +169,7 @@ class NebulaGraphMock {
LOG.error("create string type space failed," + createResp.getErrorMessage)
sys.exit(-1)
}
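+    // give the newly created space time to sync to meta/storage before the tests use it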
+ Thread.sleep(10000)
}
def mockIntIdGraphSchema(): Unit = {
@@ -184,6 +185,7 @@ class NebulaGraphMock {
LOG.error("create int type space failed," + createResp.getErrorMessage)
sys.exit(-1)
}
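+    // give the newly created space time to sync to meta/storage before the tests use it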
+ Thread.sleep(10000)
}
def close(): Unit = {
diff --git a/nebula-spark-connector_2.2/pom.xml b/nebula-spark-connector_2.2/pom.xml
index 26e46cc3..b68fa0c4 100644
--- a/nebula-spark-connector_2.2/pom.xml
+++ b/nebula-spark-connector_2.2/pom.xml
@@ -46,7 +46,7 @@
             <groupId>org.scalatest</groupId>
             <artifactId>scalatest-funsuite_2.11</artifactId>
-            <version>${scalatest.version}</version>
+            <version>3.2.3</version>
             <scope>test</scope>
diff --git a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala
index 8d978501..6019dc3c 100644
--- a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala
+++ b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala
@@ -43,7 +43,7 @@ class NebulaDataSource
*/
override def createRelation(sqlContext: SQLContext,
parameters: Map[String, String]): BaseRelation = {
- val nebulaOptions = getNebulaOptions(parameters, OperaType.READ)
+ val nebulaOptions = getNebulaOptions(parameters)
LOG.info("create relation")
LOG.info(s"options ${parameters}")
@@ -59,7 +59,7 @@ class NebulaDataSource
parameters: Map[String, String],
data: DataFrame): BaseRelation = {
- val nebulaOptions = getNebulaOptions(parameters, OperaType.WRITE)
+ val nebulaOptions = getNebulaOptions(parameters)
if (mode == SaveMode.Ignore || mode == SaveMode.ErrorIfExists) {
LOG.warn(s"Currently do not support mode")
}
@@ -78,9 +78,8 @@ class NebulaDataSource
/**
* construct nebula options with DataSourceOptions
*/
- def getNebulaOptions(options: Map[String, String],
- operateType: OperaType.Value): NebulaOptions = {
- val nebulaOptions = new NebulaOptions(CaseInsensitiveMap(options))(operateType)
+ def getNebulaOptions(options: Map[String, String]): NebulaOptions = {
+ val nebulaOptions = new NebulaOptions(CaseInsensitiveMap(options))
nebulaOptions
}
diff --git a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/package.scala b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/package.scala
index 21c59f77..8c616727 100644
--- a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/package.scala
+++ b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/package.scala
@@ -110,6 +110,7 @@ package object connector {
val dfReader = reader
.format(classOf[NebulaDataSource].getName)
.option(NebulaOptions.TYPE, DataTypeEnum.VERTEX.toString)
+ .option(NebulaOptions.OPERATE_TYPE, OperaType.READ.toString)
.option(NebulaOptions.SPACE_NAME, readConfig.getSpace)
.option(NebulaOptions.LABEL, readConfig.getLabel)
.option(NebulaOptions.PARTITION_NUMBER, readConfig.getPartitionNum)
@@ -147,6 +148,7 @@ package object connector {
val dfReader = reader
.format(classOf[NebulaDataSource].getName)
.option(NebulaOptions.TYPE, DataTypeEnum.EDGE.toString)
+ .option(NebulaOptions.OPERATE_TYPE, OperaType.READ.toString)
.option(NebulaOptions.SPACE_NAME, readConfig.getSpace)
.option(NebulaOptions.LABEL, readConfig.getLabel)
.option(NebulaOptions.RETURN_COLS, readConfig.getReturnCols.mkString(","))
@@ -298,6 +300,7 @@ package object connector {
.format(classOf[NebulaDataSource].getName)
.mode(SaveMode.Overwrite)
.option(NebulaOptions.TYPE, DataTypeEnum.VERTEX.toString)
+ .option(NebulaOptions.OPERATE_TYPE, OperaType.WRITE.toString)
.option(NebulaOptions.SPACE_NAME, writeConfig.getSpace)
.option(NebulaOptions.LABEL, writeConfig.getTagName)
.option(NebulaOptions.USER_NAME, writeConfig.getUser)
@@ -340,6 +343,7 @@ package object connector {
.format(classOf[NebulaDataSource].getName)
.mode(SaveMode.Overwrite)
.option(NebulaOptions.TYPE, DataTypeEnum.EDGE.toString)
+ .option(NebulaOptions.OPERATE_TYPE, OperaType.WRITE.toString)
.option(NebulaOptions.SPACE_NAME, writeConfig.getSpace)
.option(NebulaOptions.USER_NAME, writeConfig.getUser)
.option(NebulaOptions.PASSWD, writeConfig.getPasswd)
diff --git a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaEdgePartitionReader.scala b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaEdgePartitionReader.scala
new file mode 100644
index 00000000..7b458006
--- /dev/null
+++ b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaEdgePartitionReader.scala
@@ -0,0 +1,17 @@
+/* Copyright (c) 2020 vesoft inc. All rights reserved.
+ *
+ * This source code is licensed under Apache 2.0 License.
+ */
+
+package com.vesoft.nebula.connector.reader
+
+import com.vesoft.nebula.connector.NebulaOptions
+import org.apache.spark.Partition
+import org.apache.spark.sql.types.StructType
+import org.slf4j.{Logger, LoggerFactory}
+
+class NebulaEdgePartitionReader(index: Partition, nebulaOptions: NebulaOptions, schema: StructType)
+ extends NebulaIterator(index, nebulaOptions, schema) {
+
+ override def hasNext(): Boolean = hasNextEdgeRow
+}
diff --git a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaEdgeReader.scala b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaEdgeReader.scala
deleted file mode 100644
index 45fab6c8..00000000
--- a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaEdgeReader.scala
+++ /dev/null
@@ -1,77 +0,0 @@
-/* Copyright (c) 2022 vesoft inc. All rights reserved.
- *
- * This source code is licensed under Apache 2.0 License.
- */
-
-package com.vesoft.nebula.connector.reader
-
-import com.vesoft.nebula.client.storage.scan.{ScanEdgeResult, ScanEdgeResultIterator}
-import com.vesoft.nebula.connector.NebulaOptions
-import org.apache.spark.Partition
-import org.apache.spark.sql.types.StructType
-import org.slf4j.LoggerFactory
-import scala.collection.JavaConverters._
-
-class NebulaEdgeReader(split: Partition, nebulaOptions: NebulaOptions, schema: StructType)
- extends NebulaIterator(split, nebulaOptions, schema) {
- private val LOG = LoggerFactory.getLogger(this.getClass)
-
- private var responseIterator: ScanEdgeResultIterator = _
-
- override def hasNext: Boolean = {
- if (dataIterator == null && responseIterator == null && !scanPartIterator.hasNext)
- return false
-
- var continue: Boolean = false
- var break: Boolean = false
- while ((dataIterator == null || !dataIterator.hasNext) && !break) {
- resultValues.clear()
- continue = false
- if (responseIterator == null || !responseIterator.hasNext) {
- if (scanPartIterator.hasNext) {
- try {
- if (nebulaOptions.noColumn) {
- responseIterator = storageClient.scanEdge(nebulaOptions.spaceName,
- scanPartIterator.next(),
- nebulaOptions.label,
- nebulaOptions.limit,
- 0L,
- Long.MaxValue,
- true,
- true)
- } else {
- responseIterator = storageClient.scanEdge(nebulaOptions.spaceName,
- scanPartIterator.next(),
- nebulaOptions.label,
- nebulaOptions.getReturnCols.asJava,
- nebulaOptions.limit,
- 0,
- Long.MaxValue,
- true,
- true)
- }
- } catch {
- case e: Exception =>
- LOG.error(s"Exception scanning vertex ${nebulaOptions.label}", e)
- storageClient.close()
- throw new Exception(e.getMessage, e)
- }
- // jump to the next loop
- continue = true
- }
- // break while loop
- break = !continue
- } else {
- val next: ScanEdgeResult = responseIterator.next
- if (!next.isEmpty) {
- dataIterator = next.getEdgeTableRows.iterator().asScala
- }
- }
- }
-
- if (dataIterator == null) {
- return false
- }
- dataIterator.hasNext
- }
-}
diff --git a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaIterator.scala b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaIterator.scala
index dc7b085e..3c182aa8 100644
--- a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaIterator.scala
+++ b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaIterator.scala
@@ -5,100 +5,23 @@
package com.vesoft.nebula.connector.reader
-import com.vesoft.nebula.client.graph.data.{
- CASignedSSLParam,
- HostAddress,
- SSLParam,
- SelfSignedSSLParam,
- ValueWrapper
-}
-import com.vesoft.nebula.client.storage.data.BaseTableRow
-import com.vesoft.nebula.client.storage.StorageClient
import com.vesoft.nebula.connector.{NebulaOptions, NebulaUtils, PartitionUtils}
-import com.vesoft.nebula.connector.NebulaUtils.NebulaValueGetter
-import com.vesoft.nebula.connector.exception.GraphConnectException
-import com.vesoft.nebula.connector.nebula.MetaProvider
-import com.vesoft.nebula.connector.ssl.SSLSignType
import org.apache.spark.Partition
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
import org.apache.spark.sql.types.StructType
import org.slf4j.{Logger, LoggerFactory}
-import scala.collection.mutable
-import scala.collection.mutable.ListBuffer
-import scala.collection.JavaConverters._
-
/**
- * @todo
* iterator for nebula vertex or edge data
* convert each vertex data or edge data to Spark SQL's Row
*/
-abstract class NebulaIterator extends Iterator[InternalRow] {
-
- private val LOG: Logger = LoggerFactory.getLogger(classOf[NebulaIterator])
-
- private var metaProvider: MetaProvider = _
- private var schema: StructType = _
-
- protected var dataIterator: Iterator[BaseTableRow] = _
- protected var scanPartIterator: Iterator[Integer] = _
- protected var resultValues: mutable.ListBuffer[List[Object]] = mutable.ListBuffer[List[Object]]()
- protected var storageClient: StorageClient = _
+abstract class NebulaIterator extends Iterator[InternalRow] with NebulaReader {
+ private val LOG: Logger = LoggerFactory.getLogger(this.getClass)
def this(index: Partition, nebulaOptions: NebulaOptions, schema: StructType) {
this()
- this.schema = schema
-
- metaProvider = new MetaProvider(
- nebulaOptions.getMetaAddress,
- nebulaOptions.timeout,
- nebulaOptions.connectionRetry,
- nebulaOptions.executionRetry,
- nebulaOptions.enableMetaSSL,
- nebulaOptions.sslSignType,
- nebulaOptions.caSignParam,
- nebulaOptions.selfSignParam
- )
- val address: ListBuffer[HostAddress] = new ListBuffer[HostAddress]
-
- for (addr <- nebulaOptions.getMetaAddress) {
- address.append(new HostAddress(addr._1, addr._2))
- }
-
- var sslParam: SSLParam = null
- if (nebulaOptions.enableStorageSSL) {
- SSLSignType.withName(nebulaOptions.sslSignType) match {
- case SSLSignType.CA => {
- val caSSLSignParams = nebulaOptions.caSignParam
- sslParam = new CASignedSSLParam(caSSLSignParams.caCrtFilePath,
- caSSLSignParams.crtFilePath,
- caSSLSignParams.keyFilePath)
- }
- case SSLSignType.SELF => {
- val selfSSLSignParams = nebulaOptions.selfSignParam
- sslParam = new SelfSignedSSLParam(selfSSLSignParams.crtFilePath,
- selfSSLSignParams.keyFilePath,
- selfSSLSignParams.password)
- }
- case _ => throw new IllegalArgumentException("ssl sign type is not supported")
- }
- this.storageClient = new StorageClient(address.asJava,
- nebulaOptions.timeout,
- nebulaOptions.connectionRetry,
- nebulaOptions.executionRetry,
- true,
- sslParam)
- } else {
- this.storageClient = new StorageClient(address.asJava, nebulaOptions.timeout)
- }
-
- if (!storageClient.connect()) {
- throw new GraphConnectException("storage connect failed.")
- }
- // allocate scanPart to this partition
- val totalPart = metaProvider.getPartitionNumber(nebulaOptions.spaceName)
-
+ val totalPart = super.init(index.index, nebulaOptions, schema)
+ // index starts with 0
val nebulaPartition = index.asInstanceOf[NebulaPartition]
val scanParts =
nebulaPartition.getScanParts(totalPart, nebulaOptions.partitionNums.toInt)
@@ -107,61 +30,12 @@ abstract class NebulaIterator extends Iterator[InternalRow] {
}
/**
- * @todo
* whether this iterator can provide another element.
*/
override def hasNext: Boolean
/**
- * @todo
* Produces the next vertex or edge of this iterator.
*/
- override def next(): InternalRow = {
- val resultSet: Array[ValueWrapper] =
- dataIterator.next().getValues.toArray.map(v => v.asInstanceOf[ValueWrapper])
- val getters: Array[NebulaValueGetter] = NebulaUtils.makeGetters(schema)
- val mutableRow = new SpecificInternalRow(schema.fields.map(x => x.dataType))
-
- for (i <- getters.indices) {
- val value: ValueWrapper = resultSet(i)
- var resolved = false
- if (value.isNull) {
- mutableRow.setNullAt(i)
- resolved = true
- }
- if (value.isString) {
- getters(i).apply(value.asString(), mutableRow, i)
- resolved = true
- }
- if (value.isDate) {
- getters(i).apply(value.asDate(), mutableRow, i)
- resolved = true
- }
- if (value.isTime) {
- getters(i).apply(value.asTime(), mutableRow, i)
- resolved = true
- }
- if (value.isDateTime) {
- getters(i).apply(value.asDateTime(), mutableRow, i)
- resolved = true
- }
- if (value.isLong) {
- getters(i).apply(value.asLong(), mutableRow, i)
- }
- if (value.isBoolean) {
- getters(i).apply(value.asBoolean(), mutableRow, i)
- }
- if (value.isDouble) {
- getters(i).apply(value.asDouble(), mutableRow, i)
- }
- if (value.isGeography) {
- getters(i).apply(value.asGeography(), mutableRow, i)
- }
- if (value.isDuration) {
- getters(i).apply(value.asDuration(), mutableRow, i)
- }
- }
- mutableRow
- }
-
+ override def next(): InternalRow = super.getRow()
}
diff --git a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaRDD.scala b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaRDD.scala
index bbe5118c..3e29b822 100644
--- a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaRDD.scala
+++ b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaRDD.scala
@@ -27,8 +27,8 @@ class NebulaRDD(val sqlContext: SQLContext, var nebulaOptions: NebulaOptions, sc
override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
val dataType = nebulaOptions.dataType
if (DataTypeEnum.VERTEX.toString.equalsIgnoreCase(dataType))
- new NebulaVertexReader(split, nebulaOptions, schema)
- else new NebulaEdgeReader(split, nebulaOptions, schema)
+ new NebulaVertexPartitionReader(split, nebulaOptions, schema)
+ else new NebulaEdgePartitionReader(split, nebulaOptions, schema)
}
override def getPartitions = {
diff --git a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaRelation.scala b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaRelation.scala
index 53733061..129f1b4d 100644
--- a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaRelation.scala
+++ b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaRelation.scala
@@ -5,95 +5,35 @@
package com.vesoft.nebula.connector.reader
-import com.vesoft.nebula.connector.nebula.MetaProvider
-import com.vesoft.nebula.connector.{DataTypeEnum, NebulaOptions, NebulaUtils}
-import com.vesoft.nebula.meta.ColumnDef
+import com.vesoft.nebula.connector.{NebulaOptions, NebulaUtils}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
-import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType}
+import org.apache.spark.sql.types.StructType
import org.slf4j.LoggerFactory
-import scala.collection.mutable.ListBuffer
-
case class NebulaRelation(override val sqlContext: SQLContext, nebulaOptions: NebulaOptions)
extends BaseRelation
with TableScan {
private val LOG = LoggerFactory.getLogger(this.getClass)
- protected lazy val datasetSchema: StructType = getSchema(nebulaOptions)
+ protected var datasetSchema: StructType = _
override val needConversion: Boolean = false
- override def schema: StructType = getSchema(nebulaOptions)
-
- /**
- * return the dataset's schema. Schema includes configured cols in returnCols or includes all properties in nebula.
- */
- private def getSchema(nebulaOptions: NebulaOptions): StructType = {
- val returnCols = nebulaOptions.getReturnCols
- val noColumn = nebulaOptions.noColumn
- val fields: ListBuffer[StructField] = new ListBuffer[StructField]
- val metaProvider = new MetaProvider(
- nebulaOptions.getMetaAddress,
- nebulaOptions.timeout,
- nebulaOptions.connectionRetry,
- nebulaOptions.executionRetry,
- nebulaOptions.enableMetaSSL,
- nebulaOptions.sslSignType,
- nebulaOptions.caSignParam,
- nebulaOptions.selfSignParam
- )
-
- import scala.collection.JavaConverters._
- var schemaCols: Seq[ColumnDef] = Seq()
- val isVertex = DataTypeEnum.VERTEX.toString.equalsIgnoreCase(nebulaOptions.dataType)
-
- // construct vertex or edge default prop
- if (isVertex) {
- fields.append(DataTypes.createStructField("_vertexId", DataTypes.StringType, false))
- } else {
- fields.append(DataTypes.createStructField("_srcId", DataTypes.StringType, false))
- fields.append(DataTypes.createStructField("_dstId", DataTypes.StringType, false))
- fields.append(DataTypes.createStructField("_rank", DataTypes.LongType, false))
- }
-
- var dataSchema: StructType = null
- // read no column
- if (noColumn) {
- dataSchema = new StructType(fields.toArray)
- return dataSchema
+ override def schema: StructType = {
+ if (datasetSchema == null) {
+ datasetSchema = NebulaUtils.getSchema(nebulaOptions)
}
- // get tag schema or edge schema
- val schema = if (isVertex) {
- metaProvider.getTag(nebulaOptions.spaceName, nebulaOptions.label)
- } else {
- metaProvider.getEdge(nebulaOptions.spaceName, nebulaOptions.label)
- }
-
- schemaCols = schema.columns.asScala
-
- // read all columns
- if (returnCols.isEmpty) {
- schemaCols.foreach(columnDef => {
- LOG.info(s"prop name ${new String(columnDef.getName)}, type ${columnDef.getType.getType} ")
- fields.append(
- DataTypes.createStructField(new String(columnDef.getName),
- NebulaUtils.convertDataType(columnDef.getType),
- true))
- })
- } else {
- for (col: String <- returnCols) {
- fields.append(
- DataTypes
- .createStructField(col, NebulaUtils.getColDataType(schemaCols.toList, col), true))
- }
- }
- dataSchema = new StructType(fields.toArray)
- dataSchema
+ datasetSchema
}
override def buildScan(): RDD[Row] = {
+ if (datasetSchema == null) {
+ datasetSchema = NebulaUtils.getSchema(nebulaOptions)
+ }
if (nebulaOptions.ngql != null && nebulaOptions.ngql.nonEmpty) {
new NebulaNgqlRDD(sqlContext, nebulaOptions, datasetSchema).asInstanceOf[RDD[Row]]
} else {
diff --git a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaVertexPartitionReader.scala b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaVertexPartitionReader.scala
new file mode 100644
index 00000000..2f4aadb5
--- /dev/null
+++ b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaVertexPartitionReader.scala
@@ -0,0 +1,19 @@
+/* Copyright (c) 2020 vesoft inc. All rights reserved.
+ *
+ * This source code is licensed under Apache 2.0 License.
+ */
+
+package com.vesoft.nebula.connector.reader
+
+import com.vesoft.nebula.connector.NebulaOptions
+import org.apache.spark.Partition
+import org.apache.spark.sql.types.StructType
+
+class NebulaVertexPartitionReader(index: Partition,
+ nebulaOptions: NebulaOptions,
+ schema: StructType)
+ extends NebulaIterator(index, nebulaOptions, schema) {
+
+ override def hasNext: Boolean = hasNextVertexRow
+
+}
diff --git a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaVertexReader.scala b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaVertexReader.scala
deleted file mode 100644
index c27dc339..00000000
--- a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaVertexReader.scala
+++ /dev/null
@@ -1,78 +0,0 @@
-/* Copyright (c) 2022 vesoft inc. All rights reserved.
- *
- * This source code is licensed under Apache 2.0 License.
- */
-
-package com.vesoft.nebula.connector.reader
-
-import com.vesoft.nebula.client.storage.scan.{ScanVertexResult, ScanVertexResultIterator}
-import com.vesoft.nebula.connector.NebulaOptions
-import org.apache.spark.Partition
-import org.apache.spark.sql.types.StructType
-import org.slf4j.LoggerFactory
-import scala.collection.JavaConverters._
-
-class NebulaVertexReader(split: Partition, nebulaOptions: NebulaOptions, schema: StructType)
- extends NebulaIterator(split, nebulaOptions, schema) {
-
- private val LOG = LoggerFactory.getLogger(this.getClass)
-
- private var responseIterator: ScanVertexResultIterator = _
-
- override def hasNext: Boolean = {
- if (dataIterator == null && responseIterator == null && !scanPartIterator.hasNext)
- return false
-
- var continue: Boolean = false
- var break: Boolean = false
- while ((dataIterator == null || !dataIterator.hasNext) && !break) {
- resultValues.clear()
- continue = false
- if (responseIterator == null || !responseIterator.hasNext) {
- if (scanPartIterator.hasNext) {
- try {
- if (nebulaOptions.noColumn) {
- responseIterator = storageClient.scanVertex(nebulaOptions.spaceName,
- scanPartIterator.next(),
- nebulaOptions.label,
- nebulaOptions.limit,
- 0,
- Long.MaxValue,
- true,
- true)
- } else {
- responseIterator = storageClient.scanVertex(nebulaOptions.spaceName,
- scanPartIterator.next(),
- nebulaOptions.label,
- nebulaOptions.getReturnCols.asJava,
- nebulaOptions.limit,
- 0,
- Long.MaxValue,
- true,
- true)
- }
- } catch {
- case e: Exception =>
- LOG.error(s"Exception scanning vertex ${nebulaOptions.label}", e)
- storageClient.close()
- throw new Exception(e.getMessage, e)
- }
- // jump to the next loop
- continue = true
- }
- // break while loop
- break = !continue
- } else {
- val next: ScanVertexResult = responseIterator.next
- if (!next.isEmpty) {
- dataIterator = next.getVertexTableRows.iterator().asScala
- }
- }
- }
-
- if (dataIterator == null) {
- return false
- }
- dataIterator.hasNext
- }
-}
diff --git a/nebula-spark-connector_2.2/src/test/scala/com/vesoft/nebula/connector/mock/NebulaGraphMock.scala b/nebula-spark-connector_2.2/src/test/scala/com/vesoft/nebula/connector/mock/NebulaGraphMock.scala
index b8f0e72e..46c8f502 100644
--- a/nebula-spark-connector_2.2/src/test/scala/com/vesoft/nebula/connector/mock/NebulaGraphMock.scala
+++ b/nebula-spark-connector_2.2/src/test/scala/com/vesoft/nebula/connector/mock/NebulaGraphMock.scala
@@ -169,6 +169,7 @@ class NebulaGraphMock {
LOG.error("create string type space failed," + createResp.getErrorMessage)
sys.exit(-1)
}
+ Thread.sleep(10000) // wait for the created space to take effect via meta heartbeat
}
def mockIntIdGraphSchema(): Unit = {
@@ -184,6 +185,7 @@ class NebulaGraphMock {
LOG.error("create int type space failed," + createResp.getErrorMessage)
sys.exit(-1)
}
+ Thread.sleep(10000) // wait for the created space to take effect via meta heartbeat
}
def close(): Unit = {
diff --git a/nebula-spark-connector_3.0/.gitignore b/nebula-spark-connector_3.0/.gitignore
new file mode 100644
index 00000000..84e7a6bc
--- /dev/null
+++ b/nebula-spark-connector_3.0/.gitignore
@@ -0,0 +1,36 @@
+# Compiled class file
+*.class
+
+# Log file
+*.log
+
+# BlueJ files
+*.ctxt
+
+# Mobile Tools for Java (J2ME)
+.mtj.tmp/
+
+# Package Files #
+*.jar
+*.war
+*.nar
+*.ear
+*.zip
+*.tar.gz
+*.rar
+
+# virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml
+hs_err_pid*
+
+# build target
+target/
+
+# IDE
+.idea/
+.eclipse/
+*.iml
+
+spark-importer.ipr
+spark-importer.iws
+
+.DS_Store
diff --git a/nebula-spark-connector_3.0/pom.xml b/nebula-spark-connector_3.0/pom.xml
new file mode 100644
index 00000000..989a87c8
--- /dev/null
+++ b/nebula-spark-connector_3.0/pom.xml
@@ -0,0 +1,288 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+    <parent>
+        <artifactId>nebula-spark</artifactId>
+        <groupId>com.vesoft</groupId>
+        <version>3.0-SNAPSHOT</version>
+        <relativePath>../pom.xml</relativePath>
+    </parent>
+    <modelVersion>4.0.0</modelVersion>
+
+    <artifactId>nebula-spark-connector_3.0</artifactId>
+
+    <properties>
+        <compiler.source.version>1.8</compiler.source.version>
+        <compiler.target.version>1.8</compiler.target.version>
+        <scala.binary.version>2.12</scala.binary.version>
+        <spark3.0.version>3.0.0</spark3.0.version>
+        <scala.version>2.12.10</scala.version>
+        <scalatest.version>3.2.3</scalatest.version>
+    </properties>
+
+    <dependencies>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-core_2.12</artifactId>
+            <version>${spark3.0.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-sql_2.12</artifactId>
+            <version>${spark3.0.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-graphx_2.12</artifactId>
+            <version>${spark3.0.version}</version>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>com.vesoft</groupId>
+            <artifactId>nebula-spark-common</artifactId>
+            <version>${project.version}</version>
+        </dependency>
+        <dependency>
+            <groupId>org.scalatest</groupId>
+            <artifactId>scalatest-funsuite_2.12</artifactId>
+            <version>3.2.3</version>
+            <scope>test</scope>
+        </dependency>
+    </dependencies>
+
+    <build>
+        <plugins>
+            <plugin>
+                <groupId>org.apache.maven.plugins</groupId>
+                <artifactId>maven-deploy-plugin</artifactId>
+                <version>2.8.2</version>
+                <executions>
+                    <execution>
+                        <id>default-deploy</id>
+                        <phase>deploy</phase>
+                    </execution>
+                </executions>
+            </plugin>
+            <plugin>
+                <groupId>org.sonatype.plugins</groupId>
+                <artifactId>nexus-staging-maven-plugin</artifactId>
+                <version>1.6.8</version>
+                <extensions>true</extensions>
+                <configuration>
+                    <serverId>ossrh</serverId>
+                    <nexusUrl>https://oss.sonatype.org/</nexusUrl>
+                    <autoReleaseAfterClose>true</autoReleaseAfterClose>
+                </configuration>
+            </plugin>
+            <plugin>
+                <groupId>org.apache.maven.plugins</groupId>
+                <artifactId>maven-jar-plugin</artifactId>
+                <version>3.2.0</version>
+                <executions>
+                    <execution>
+                        <goals>
+                            <goal>test-jar</goal>
+                        </goals>
+                    </execution>
+                </executions>
+            </plugin>
+            <plugin>
+                <groupId>org.apache.maven.plugins</groupId>
+                <artifactId>maven-compiler-plugin</artifactId>
+                <version>3.1</version>
+                <configuration>
+                    <source>${compiler.source.version}</source>
+                    <target>${compiler.target.version}</target>
+                </configuration>
+            </plugin>
+            <plugin>
+                <groupId>org.apache.maven.plugins</groupId>
+                <artifactId>maven-assembly-plugin</artifactId>
+                <version>2.5.3</version>
+                <executions>
+                    <execution>
+                        <phase>package</phase>
+                        <goals>
+                            <goal>single</goal>
+                        </goals>
+                    </execution>
+                </executions>
+                <configuration>
+                    <descriptorRefs>
+                        <descriptorRef>jar-with-dependencies</descriptorRef>
+                    </descriptorRefs>
+                    <finalName>${project.artifactId}-${project.version}-jar-with-dependencies</finalName>
+                    <appendAssemblyId>false</appendAssemblyId>
+                </configuration>
+            </plugin>
+            <plugin>
+                <groupId>org.scala-tools</groupId>
+                <artifactId>maven-scala-plugin</artifactId>
+                <version>2.15.2</version>
+                <configuration>
+                    <scalaVersion>2.12.10</scalaVersion>
+                    <args>
+                        <arg>-target:jvm-1.8</arg>
+                    </args>
+                    <jvmArgs>
+                        <jvmArg>-Xss4096K</jvmArg>
+                    </jvmArgs>
+                </configuration>
+                <executions>
+                    <execution>
+                        <id>scala-compile</id>
+                        <goals>
+                            <goal>compile</goal>
+                        </goals>
+                        <configuration>
+                            <excludes>
+                                <exclude>META-INF/*.SF</exclude>
+                                <exclude>META-INF/*.DSA</exclude>
+                                <exclude>META-INF/*.RSA</exclude>
+                            </excludes>
+                        </configuration>
+                    </execution>
+                    <execution>
+                        <id>scala-test-compile</id>
+                        <goals>
+                            <goal>testCompile</goal>
+                        </goals>
+                    </execution>
+                </executions>
+            </plugin>
+            <plugin>
+                <groupId>org.apache.maven.plugins</groupId>
+                <artifactId>maven-source-plugin</artifactId>
+                <version>3.2.0</version>
+                <executions>
+                    <execution>
+                        <id>attach-sources</id>
+                        <goals>
+                            <goal>jar</goal>
+                        </goals>
+                    </execution>
+                </executions>
+            </plugin>
+            <plugin>
+                <groupId>net.alchim31.maven</groupId>
+                <artifactId>scala-maven-plugin</artifactId>
+                <version>3.2.2</version>
+                <executions>
+                    <execution>
+                        <goals>
+                            <goal>compile</goal>
+                            <goal>testCompile</goal>
+                        </goals>
+                    </execution>
+                    <execution>
+                        <id>Scaladoc</id>
+                        <goals>
+                            <goal>doc</goal>
+                        </goals>
+                        <phase>prepare-package</phase>
+                        <configuration>
+                            <args>
+                                <arg>-nobootcp</arg>
+                                <arg>-no-link-warnings</arg>
+                            </args>
+                        </configuration>
+                    </execution>
+                    <execution>
+                        <id>attach-javadocs</id>
+                        <goals>
+                            <goal>doc-jar</goal>
+                        </goals>
+                        <configuration>
+                            <args>
+                                <arg>-nobootcp</arg>
+                                <arg>-no-link-warnings</arg>
+                            </args>
+                        </configuration>
+                    </execution>
+                </executions>
+            </plugin>
+            <plugin>
+                <groupId>org.apache.maven.plugins</groupId>
+                <artifactId>maven-surefire-plugin</artifactId>
+                <version>2.12.4</version>
+                <configuration>
+                    <includes>
+                        <include>**/*Test.*</include>
+                        <include>**/*Suite.*</include>
+                    </includes>
+                </configuration>
+            </plugin>
+            <plugin>
+                <groupId>org.scalatest</groupId>
+                <artifactId>scalatest-maven-plugin</artifactId>
+                <version>2.0.0</version>
+                <executions>
+                    <execution>
+                        <id>test</id>
+                        <goals>
+                            <goal>test</goal>
+                        </goals>
+                    </execution>
+                </executions>
+            </plugin>
+            <plugin>
+                <groupId>org.apache.maven.plugins</groupId>
+                <artifactId>maven-javadoc-plugin</artifactId>
+                <version>3.2.0</version>
+                <executions>
+                    <execution>
+                        <id>attach-javadocs</id>
+                        <phase>package</phase>
+                        <goals>
+                            <goal>jar</goal>
+                        </goals>
+                        <configuration>
+                            <encoding>UTF-8</encoding>
+                            <charset>UTF-8</charset>
+                            <additionalJOptions>
+                                <additionalJOption>-source 8</additionalJOption>
+                                <additionalJOption>-Xdoclint:none</additionalJOption>
+                            </additionalJOptions>
+                        </configuration>
+                    </execution>
+                </executions>
+            </plugin>
+            <plugin>
+                <groupId>org.jacoco</groupId>
+                <artifactId>jacoco-maven-plugin</artifactId>
+                <version>0.8.7</version>
+                <executions>
+                    <execution>
+                        <goals>
+                            <goal>prepare-agent</goal>
+                        </goals>
+                    </execution>
+                    <execution>
+                        <id>report</id>
+                        <phase>test</phase>
+                        <goals>
+                            <goal>report</goal>
+                        </goals>
+                    </execution>
+                </executions>
+            </plugin>
+        </plugins>
+    </build>
+
+    <repositories>
+        <repository>
+            <id>snapshots</id>
+            <url>https://oss.sonatype.org/content/repositories/snapshots/</url>
+        </repository>
+    </repositories>
+</project>
diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala
new file mode 100644
index 00000000..a5c629bf
--- /dev/null
+++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala
@@ -0,0 +1,86 @@
+/* Copyright (c) 2020 vesoft inc. All rights reserved.
+ *
+ * This source code is licensed under Apache 2.0 License.
+ */
+
+package com.vesoft.nebula.connector
+
+import java.util
+import java.util.Map.Entry
+
+import com.vesoft.nebula.connector.utils.Validations
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.connector.catalog.{Table, TableProvider}
+import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.sources.DataSourceRegister
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.slf4j.LoggerFactory
+
+import scala.jdk.CollectionConverters.asScalaSetConverter
+
+class NebulaDataSource extends TableProvider with DataSourceRegister {
+ private val LOG = LoggerFactory.getLogger(this.getClass)
+
+ Validations.validateSparkVersion("3.*")
+
+ private var schema: StructType = null
+ private var nebulaOptions: NebulaOptions = _
+
+ /**
+ * The string that represents the format that nebula data source provider uses.
+ */
+ override def shortName(): String = "nebula"
+
+ override def supportsExternalMetadata(): Boolean = true
+
+ override def inferSchema(caseInsensitiveStringMap: CaseInsensitiveStringMap): StructType = {
+ if (schema == null) {
+ nebulaOptions = getNebulaOptions(caseInsensitiveStringMap)
+ if (nebulaOptions.operaType == OperaType.READ) {
+ schema = NebulaUtils.getSchema(nebulaOptions)
+ } else {
+ schema = new StructType()
+ }
+ }
+ schema
+ }
+
+ override def getTable(tableSchema: StructType,
+ transforms: Array[Transform],
+ map: util.Map[String, String]): Table = {
+ if (nebulaOptions == null) {
+ nebulaOptions = getNebulaOptions(new CaseInsensitiveStringMap(map))
+ }
+ new NebulaTable(tableSchema, nebulaOptions)
+ }
+
+ /**
+ * construct nebula options with DataSourceOptions
+ */
+ private def getNebulaOptions(
+ caseInsensitiveStringMap: CaseInsensitiveStringMap): NebulaOptions = {
+ var parameters: Map[String, String] = Map()
+ for (entry: Entry[String, String] <- caseInsensitiveStringMap
+ .asCaseSensitiveMap()
+ .entrySet()
+ .asScala) {
+ parameters += (entry.getKey -> entry.getValue)
+ }
+ val nebulaOptions = new NebulaOptions(CaseInsensitiveMap(parameters))
+ nebulaOptions
+ }
+
+}
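Note: with DataSource V2, Spark resolves this provider either by its fully qualified class name (as the implicit reader and writer in package.scala do) or by the short name registered above. A minimal sketch, assuming an active local SparkSession:

    import org.apache.spark.sql.SparkSession
    import com.vesoft.nebula.connector.NebulaDataSource

    val spark = SparkSession.builder().master("local").getOrCreate()
    // both resolve to NebulaDataSource; "nebula" matches shortName()
    val byClass = spark.read.format(classOf[NebulaDataSource].getName)
    val byShort = spark.read.format("nebula")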
diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/NebulaTable.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/NebulaTable.scala
new file mode 100644
index 00000000..a256fa93
--- /dev/null
+++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/NebulaTable.scala
@@ -0,0 +1,70 @@
+/* Copyright (c) 2022 vesoft inc. All rights reserved.
+ *
+ * This source code is licensed under Apache 2.0 License.
+ */
+
+package com.vesoft.nebula.connector
+
+import java.util
+
+import com.vesoft.nebula.connector.reader.SimpleScanBuilder
+import com.vesoft.nebula.connector.writer.NebulaWriterBuilder
+import org.apache.spark.sql.SaveMode
+import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability}
+import org.apache.spark.sql.connector.read.ScanBuilder
+import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.slf4j.LoggerFactory
+
+import scala.collection.JavaConverters._
+
+class NebulaTable(schema: StructType, nebulaOptions: NebulaOptions)
+ extends Table
+ with SupportsRead
+ with SupportsWrite {
+
+ private val LOG = LoggerFactory.getLogger(this.getClass)
+
+ /**
+ * Creates a {@link ScanBuilder} to scan the data from Nebula Graph.
+ */
+ override def newScanBuilder(caseInsensitiveStringMap: CaseInsensitiveStringMap): ScanBuilder = {
+ LOG.info("create scan builder")
+ LOG.info(s"options ${caseInsensitiveStringMap.asCaseSensitiveMap()}")
+
+ new SimpleScanBuilder(nebulaOptions, schema)
+ }
+
+ /**
+ * Creates a {@link WriteBuilder} to save the data to Nebula Graph.
+ */
+ override def newWriteBuilder(logicalWriteInfo: LogicalWriteInfo): WriteBuilder = {
+ LOG.info("create writer")
+ LOG.info(s"options ${logicalWriteInfo.options().asCaseSensitiveMap()}")
+ new NebulaWriterBuilder(logicalWriteInfo.schema(), SaveMode.Append, nebulaOptions)
+ }
+
+ /**
+ * NebulaGraph table name
+ */
+ override def name(): String = {
+ nebulaOptions.label
+ }
+
+ override def schema(): StructType = schema
+
+ override def capabilities(): util.Set[TableCapability] =
+ Set(
+ TableCapability.BATCH_READ,
+ TableCapability.BATCH_WRITE,
+ TableCapability.ACCEPT_ANY_SCHEMA,
+ TableCapability.OVERWRITE_BY_FILTER,
+ TableCapability.OVERWRITE_DYNAMIC,
+ TableCapability.STREAMING_WRITE,
+ TableCapability.MICRO_BATCH_READ
+ ).asJava
+
+}
diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/package.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/package.scala
new file mode 100644
index 00000000..79c4ba87
--- /dev/null
+++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/package.scala
@@ -0,0 +1,282 @@
+/* Copyright (c) 2020 vesoft inc. All rights reserved.
+ *
+ * This source code is licensed under Apache 2.0 License.
+ */
+
+package com.vesoft.nebula.connector
+
+import com.vesoft.nebula.connector.ssl.SSLSignType
+import com.vesoft.nebula.connector.utils.SparkValidate
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{
+ DataFrame,
+ DataFrameReader,
+ DataFrameWriter,
+ Encoder,
+ Encoders,
+ Row,
+ SaveMode
+}
+
+import scala.collection.mutable.ListBuffer
+
+package object connector {
+
+ /**
+ * spark reader for nebula graph
+ */
+ implicit class NebulaDataFrameReader(reader: DataFrameReader) {
+ var connectionConfig: NebulaConnectionConfig = _
+ var readConfig: ReadNebulaConfig = _
+
+ def nebula(connectionConfig: NebulaConnectionConfig,
+ readConfig: ReadNebulaConfig): NebulaDataFrameReader = {
+ SparkValidate.validate("3.0.*", "3.1.*", "3.2.*", "3.3.*")
+ this.connectionConfig = connectionConfig
+ this.readConfig = readConfig
+ this
+ }
+
+ /**
+ * Reading vertices from Nebula Graph
+ * @return DataFrame
+ */
+ def loadVerticesToDF(): DataFrame = {
+ assert(connectionConfig != null && readConfig != null,
+ "nebula config is not set, please call nebula() before loadVerticesToDF")
+ val dfReader = reader
+ .format(classOf[NebulaDataSource].getName)
+ .option(NebulaOptions.TYPE, DataTypeEnum.VERTEX.toString)
+ .option(NebulaOptions.OPERATE_TYPE, OperaType.READ.toString)
+ .option(NebulaOptions.SPACE_NAME, readConfig.getSpace)
+ .option(NebulaOptions.LABEL, readConfig.getLabel)
+ .option(NebulaOptions.PARTITION_NUMBER, readConfig.getPartitionNum)
+ .option(NebulaOptions.RETURN_COLS, readConfig.getReturnCols.mkString(","))
+ .option(NebulaOptions.NO_COLUMN, readConfig.getNoColumn)
+ .option(NebulaOptions.LIMIT, readConfig.getLimit)
+ .option(NebulaOptions.META_ADDRESS, connectionConfig.getMetaAddress)
+ .option(NebulaOptions.TIMEOUT, connectionConfig.getTimeout)
+ .option(NebulaOptions.CONNECTION_RETRY, connectionConfig.getConnectionRetry)
+ .option(NebulaOptions.EXECUTION_RETRY, connectionConfig.getExecRetry)
+ .option(NebulaOptions.ENABLE_META_SSL, connectionConfig.getEnableMetaSSL)
+ .option(NebulaOptions.ENABLE_STORAGE_SSL, connectionConfig.getEnableStorageSSL)
+
+ if (connectionConfig.getEnableStorageSSL || connectionConfig.getEnableMetaSSL) {
+ dfReader.option(NebulaOptions.SSL_SIGN_TYPE, connectionConfig.getSignType)
+ SSLSignType.withName(connectionConfig.getSignType) match {
+ case SSLSignType.CA =>
+ dfReader.option(NebulaOptions.CA_SIGN_PARAM, connectionConfig.getCaSignParam)
+ case SSLSignType.SELF =>
+ dfReader.option(NebulaOptions.SELF_SIGN_PARAM, connectionConfig.getSelfSignParam)
+ }
+ }
+
+ dfReader.load()
+ }
+
+ /**
+ * Reading edges from Nebula Graph
+ * @return DataFrame
+ */
+ def loadEdgesToDF(): DataFrame = {
+ assert(connectionConfig != null && readConfig != null,
+ "nebula config is not set, please call nebula() before loadEdgesToDF")
+
+ val dfReader = reader
+ .format(classOf[NebulaDataSource].getName)
+ .option(NebulaOptions.TYPE, DataTypeEnum.EDGE.toString)
+ .option(NebulaOptions.OPERATE_TYPE, OperaType.READ.toString)
+ .option(NebulaOptions.SPACE_NAME, readConfig.getSpace)
+ .option(NebulaOptions.LABEL, readConfig.getLabel)
+ .option(NebulaOptions.RETURN_COLS, readConfig.getReturnCols.mkString(","))
+ .option(NebulaOptions.NO_COLUMN, readConfig.getNoColumn)
+ .option(NebulaOptions.LIMIT, readConfig.getLimit)
+ .option(NebulaOptions.PARTITION_NUMBER, readConfig.getPartitionNum)
+ .option(NebulaOptions.META_ADDRESS, connectionConfig.getMetaAddress)
+ .option(NebulaOptions.TIMEOUT, connectionConfig.getTimeout)
+ .option(NebulaOptions.CONNECTION_RETRY, connectionConfig.getConnectionRetry)
+ .option(NebulaOptions.EXECUTION_RETRY, connectionConfig.getExecRetry)
+ .option(NebulaOptions.ENABLE_META_SSL, connectionConfig.getEnableMetaSSL)
+ .option(NebulaOptions.ENABLE_STORAGE_SSL, connectionConfig.getEnableStorageSSL)
+
+ if (connectionConfig.getEnableStorageSSL || connectionConfig.getEnableMetaSSL) {
+ dfReader.option(NebulaOptions.SSL_SIGN_TYPE, connectionConfig.getSignType)
+ SSLSignType.withName(connectionConfig.getSignType) match {
+ case SSLSignType.CA =>
+ dfReader.option(NebulaOptions.CA_SIGN_PARAM, connectionConfig.getCaSignParam)
+ case SSLSignType.SELF =>
+ dfReader.option(NebulaOptions.SELF_SIGN_PARAM, connectionConfig.getSelfSignParam)
+ }
+ }
+
+ dfReader.load()
+ }
+
+ /**
+ * Read nebula vertices into graphx's vertices.
+ * The vertex id must be convertible to Long, so String-type vertex ids need to hold numeric values.
+ */
+ def loadVerticesToGraphx(): RDD[NebulaGraphxVertex] = {
+ val vertexDataset = loadVerticesToDF()
+ implicit val encoder: Encoder[NebulaGraphxVertex] =
+ Encoders.bean[NebulaGraphxVertex](classOf[NebulaGraphxVertex])
+
+ vertexDataset
+ .map(row => {
+ val vertexId = row.get(0)
+ val vid: Long = vertexId.toString.toLong
+ val props: ListBuffer[Any] = ListBuffer()
+ for (i <- row.schema.fields.indices) {
+ if (i != 0) {
+ props.append(row.get(i))
+ }
+ }
+ (vid, props.toList)
+ })(encoder)
+ .rdd
+ }
+
+ /**
+ * Read nebula edges into graphx's edges.
+ * The src and dst ids must be convertible to Long, so String-type ids need to hold numeric values.
+ */
+ def loadEdgesToGraphx(): RDD[NebulaGraphxEdge] = {
+ val edgeDataset = loadEdgesToDF()
+ implicit val encoder: Encoder[NebulaGraphxEdge] =
+ Encoders.bean[NebulaGraphxEdge](classOf[NebulaGraphxEdge])
+
+ edgeDataset
+ .map(row => {
+ val props: ListBuffer[Any] = ListBuffer()
+ for (i <- row.schema.fields.indices) {
+ if (i != 0 && i != 1 && i != 2) {
+ props.append(row.get(i))
+ }
+ }
+ val srcId = row.get(0)
+ val dstId = row.get(1)
+ val edgeSrc = srcId.toString.toLong
+ val edgeDst = dstId.toString.toLong
+ val edgeProp = (row.get(2).toString.toLong, props.toList)
+ org.apache.spark.graphx
+ .Edge(edgeSrc, edgeDst, edgeProp)
+ })(encoder)
+ .rdd
+ }
+
+ }
+
+ /**
+ * spark writer for nebula graph
+ */
+ implicit class NebulaDataFrameWriter(writer: DataFrameWriter[Row]) {
+
+ var connectionConfig: NebulaConnectionConfig = _
+ var writeNebulaConfig: WriteNebulaConfig = _
+
+ /**
+ * config nebula connection
+ * @param connectionConfig connection parameters
+ * @param writeNebulaConfig write parameters for vertex or edge
+ */
+ def nebula(connectionConfig: NebulaConnectionConfig,
+ writeNebulaConfig: WriteNebulaConfig): NebulaDataFrameWriter = {
+ SparkValidate.validate("3.0.*", "3.1.*", "3.2.*", "3.3.*")
+ this.connectionConfig = connectionConfig
+ this.writeNebulaConfig = writeNebulaConfig
+ this
+ }
+
+ /**
+ * write dataframe into nebula vertex
+ */
+ def writeVertices(): Unit = {
+ assert(connectionConfig != null && writeNebulaConfig != null,
+ "nebula config is not set, please call nebula() before writeVertices")
+ val writeConfig = writeNebulaConfig.asInstanceOf[WriteNebulaVertexConfig]
+ val dfWriter = writer
+ .format(classOf[NebulaDataSource].getName)
+ .mode(SaveMode.Overwrite)
+ .option(NebulaOptions.TYPE, DataTypeEnum.VERTEX.toString)
+ .option(NebulaOptions.OPERATE_TYPE, OperaType.WRITE.toString)
+ .option(NebulaOptions.SPACE_NAME, writeConfig.getSpace)
+ .option(NebulaOptions.LABEL, writeConfig.getTagName)
+ .option(NebulaOptions.USER_NAME, writeConfig.getUser)
+ .option(NebulaOptions.PASSWD, writeConfig.getPasswd)
+ .option(NebulaOptions.VERTEX_FIELD, writeConfig.getVidField)
+ .option(NebulaOptions.VID_POLICY, writeConfig.getVidPolicy)
+ .option(NebulaOptions.BATCH, writeConfig.getBatch)
+ .option(NebulaOptions.VID_AS_PROP, writeConfig.getVidAsProp)
+ .option(NebulaOptions.WRITE_MODE, writeConfig.getWriteMode)
+ .option(NebulaOptions.DELETE_EDGE, writeConfig.getDeleteEdge)
+ .option(NebulaOptions.META_ADDRESS, connectionConfig.getMetaAddress)
+ .option(NebulaOptions.GRAPH_ADDRESS, connectionConfig.getGraphAddress)
+ .option(NebulaOptions.TIMEOUT, connectionConfig.getTimeout)
+ .option(NebulaOptions.CONNECTION_RETRY, connectionConfig.getConnectionRetry)
+ .option(NebulaOptions.EXECUTION_RETRY, connectionConfig.getExecRetry)
+ .option(NebulaOptions.ENABLE_GRAPH_SSL, connectionConfig.getEnableGraphSSL)
+ .option(NebulaOptions.ENABLE_META_SSL, connectionConfig.getEnableMetaSSL)
+
+ if (connectionConfig.getEnableGraphSSL || connectionConfig.getEnableMetaSSL) {
+ dfWriter.option(NebulaOptions.SSL_SIGN_TYPE, connectionConfig.getSignType)
+ SSLSignType.withName(connectionConfig.getSignType) match {
+ case SSLSignType.CA =>
+ dfWriter.option(NebulaOptions.CA_SIGN_PARAM, connectionConfig.getCaSignParam)
+ case SSLSignType.SELF =>
+ dfWriter.option(NebulaOptions.SELF_SIGN_PARAM, connectionConfig.getSelfSignParam)
+ }
+ }
+
+ dfWriter.save()
+ }
+
+ /**
+ * write dataframe into nebula edge
+ */
+ def writeEdges(): Unit = {
+
+ assert(connectionConfig != null && writeNebulaConfig != null,
+ "nebula config is not set, please call nebula() before writeEdges")
+ val writeConfig = writeNebulaConfig.asInstanceOf[WriteNebulaEdgeConfig]
+ val dfWriter = writer
+ .format(classOf[NebulaDataSource].getName)
+ .mode(SaveMode.Overwrite)
+ .option(NebulaOptions.TYPE, DataTypeEnum.EDGE.toString)
+ .option(NebulaOptions.OPERATE_TYPE, OperaType.WRITE.toString)
+ .option(NebulaOptions.SPACE_NAME, writeConfig.getSpace)
+ .option(NebulaOptions.USER_NAME, writeConfig.getUser)
+ .option(NebulaOptions.PASSWD, writeConfig.getPasswd)
+ .option(NebulaOptions.LABEL, writeConfig.getEdgeName)
+ .option(NebulaOptions.SRC_VERTEX_FIELD, writeConfig.getSrcFiled)
+ .option(NebulaOptions.DST_VERTEX_FIELD, writeConfig.getDstField)
+ .option(NebulaOptions.SRC_POLICY, writeConfig.getSrcPolicy)
+ .option(NebulaOptions.DST_POLICY, writeConfig.getDstPolicy)
+ .option(NebulaOptions.RANK_FIELD, writeConfig.getRankField)
+ .option(NebulaOptions.BATCH, writeConfig.getBatch)
+ .option(NebulaOptions.SRC_AS_PROP, writeConfig.getSrcAsProp)
+ .option(NebulaOptions.DST_AS_PROP, writeConfig.getDstAsProp)
+ .option(NebulaOptions.RANK_AS_PROP, writeConfig.getRankAsProp)
+ .option(NebulaOptions.WRITE_MODE, writeConfig.getWriteMode)
+ .option(NebulaOptions.META_ADDRESS, connectionConfig.getMetaAddress)
+ .option(NebulaOptions.GRAPH_ADDRESS, connectionConfig.getGraphAddress)
+ .option(NebulaOptions.TIMEOUT, connectionConfig.getTimeout)
+ .option(NebulaOptions.CONNECTION_RETRY, connectionConfig.getConnectionRetry)
+ .option(NebulaOptions.EXECUTION_RETRY, connectionConfig.getExecRetry)
+ .option(NebulaOptions.ENABLE_GRAPH_SSL, connectionConfig.getEnableGraphSSL)
+ .option(NebulaOptions.ENABLE_META_SSL, connectionConfig.getEnableMetaSSL)
+
+ if (connectionConfig.getEnableGraphSSL || connectionConfig.getEnableMetaSSL) {
+ dfWriter.option(NebulaOptions.SSL_SIGN_TYPE, connectionConfig.getSignType)
+ SSLSignType.withName(connectionConfig.getSignType) match {
+ case SSLSignType.CA =>
+ dfWriter.option(NebulaOptions.CA_SIGN_PARAM, connectionConfig.getCaSignParam)
+ case SSLSignType.SELF =>
+ dfWriter.option(NebulaOptions.SELF_SIGN_PARAM, connectionConfig.getSelfSignParam)
+ }
+ }
+
+ dfWriter.save()
+ }
+ }
+
+}
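Note: a minimal read through the implicit reader above; the builder calls follow the connector's public config API, and the meta address, space, and tag names are placeholders:

    import com.vesoft.nebula.connector.connector.NebulaDataFrameReader
    import com.vesoft.nebula.connector.{NebulaConnectionConfig, ReadNebulaConfig}
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local").getOrCreate()
    val config = NebulaConnectionConfig.builder()
      .withMetaAddress("127.0.0.1:9559") // placeholder meta address
      .build()
    val readVertexConfig = ReadNebulaConfig.builder()
      .withSpace("test")   // placeholder space
      .withLabel("person") // placeholder tag
      .withNoColumn(false)
      .withLimit(10)
      .withPartitionNum(10)
      .build()
    val vertexDF = spark.read.nebula(config, readVertexConfig).loadVerticesToDF()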
diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaEdgePartitionReader.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaEdgePartitionReader.scala
new file mode 100644
index 00000000..e1062d49
--- /dev/null
+++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaEdgePartitionReader.scala
@@ -0,0 +1,15 @@
+/* Copyright (c) 2022 vesoft inc. All rights reserved.
+ *
+ * This source code is licensed under Apache 2.0 License.
+ */
+
+package com.vesoft.nebula.connector.reader
+
+import com.vesoft.nebula.connector.NebulaOptions
+import org.apache.spark.sql.types.StructType
+
+class NebulaEdgePartitionReader(index: Int, nebulaOptions: NebulaOptions, schema: StructType)
+ extends NebulaPartitionReader(index, nebulaOptions, schema) {
+
+ override def next(): Boolean = hasNextEdgeRow
+}
diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartitionReader.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartitionReader.scala
new file mode 100644
index 00000000..44fdf698
--- /dev/null
+++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartitionReader.scala
@@ -0,0 +1,39 @@
+/* Copyright (c) 2020 vesoft inc. All rights reserved.
+ *
+ * This source code is licensed under Apache 2.0 License.
+ */
+
+package com.vesoft.nebula.connector.reader
+
+import com.vesoft.nebula.connector.{NebulaOptions, PartitionUtils}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.read.PartitionReader
+import org.apache.spark.sql.types.StructType
+import org.slf4j.{Logger, LoggerFactory}
+
+/**
+ * Read nebula data for each spark partition
+ */
+abstract class NebulaPartitionReader extends PartitionReader[InternalRow] with NebulaReader {
+ private val LOG: Logger = LoggerFactory.getLogger(this.getClass)
+
+ /**
+ * @param index identifier for spark partition
+ * @param nebulaOptions nebula Options
+ * @param schema of data need to read
+ */
+ def this(index: Int, nebulaOptions: NebulaOptions, schema: StructType) {
+ this()
+ val totalPart = super.init(index, nebulaOptions, schema)
+ // index starts with 1
+ val scanParts = PartitionUtils.getScanParts(index, totalPart, nebulaOptions.partitionNums.toInt)
+ LOG.info(s"partition index: ${index}, scanParts: ${scanParts.toString}")
+ scanPartIterator = scanParts.iterator
+ }
+
+ override def get(): InternalRow = super.getRow()
+
+ override def close(): Unit = {
+ super.closeReader()
+ }
+}
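Note: the scan-part assignment above delegates to PartitionUtils.getScanParts; a minimal sketch of the interleaving it is assumed to perform (the helper below is illustrative, not the library code):

    // with reader index starting at 1, reader i scans Nebula parts
    // i, i + sparkPartitionNum, i + 2 * sparkPartitionNum, ...
    def scanPartsSketch(index: Int, totalPart: Int, sparkPartitionNum: Int): List[Int] =
      (index to totalPart by sparkPartitionNum).toList

    // e.g. 10 Nebula parts over 3 Spark partitions:
    // scanPartsSketch(1, 10, 3) == List(1, 4, 7, 10)
    // scanPartsSketch(2, 10, 3) == List(2, 5, 8)
    // scanPartsSketch(3, 10, 3) == List(3, 6, 9)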
diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartitionReaderFactory.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartitionReaderFactory.scala
new file mode 100644
index 00000000..9e5e7e5e
--- /dev/null
+++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartitionReaderFactory.scala
@@ -0,0 +1,25 @@
+/* Copyright (c) 2022 vesoft inc. All rights reserved.
+ *
+ * This source code is licensed under Apache 2.0 License.
+ */
+
+package com.vesoft.nebula.connector.reader
+
+import com.vesoft.nebula.connector.{DataTypeEnum, NebulaOptions}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
+import org.apache.spark.sql.types.StructType
+
+class NebulaPartitionReaderFactory(private val nebulaOptions: NebulaOptions,
+ private val schema: StructType)
+ extends PartitionReaderFactory {
+ override def createReader(inputPartition: InputPartition): PartitionReader[InternalRow] = {
+ val partition = inputPartition.asInstanceOf[NebulaPartition].partition
+ if (DataTypeEnum.VERTEX.toString.equals(nebulaOptions.dataType)) {
+ new NebulaVertexPartitionReader(partition, nebulaOptions, schema)
+ } else {
+ new NebulaEdgePartitionReader(partition, nebulaOptions, schema)
+ }
+ }
+}
diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaVertexPartitionReader.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaVertexPartitionReader.scala
new file mode 100644
index 00000000..da5b02d2
--- /dev/null
+++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaVertexPartitionReader.scala
@@ -0,0 +1,15 @@
+/* Copyright (c) 2020 vesoft inc. All rights reserved.
+ *
+ * This source code is licensed under Apache 2.0 License.
+ */
+
+package com.vesoft.nebula.connector.reader
+
+import com.vesoft.nebula.connector.NebulaOptions
+import org.apache.spark.sql.types.StructType
+
+class NebulaVertexPartitionReader(split: Int, nebulaOptions: NebulaOptions, schema: StructType)
+ extends NebulaPartitionReader(split, nebulaOptions, schema) {
+
+ override def next(): Boolean = hasNextVertexRow
+}
diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/SimpleScanBuilder.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/SimpleScanBuilder.scala
new file mode 100644
index 00000000..f7515917
--- /dev/null
+++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/SimpleScanBuilder.scala
@@ -0,0 +1,76 @@
+/* Copyright (c) 2022 vesoft inc. All rights reserved.
+ *
+ * This source code is licensed under Apache 2.0 License.
+ */
+
+package com.vesoft.nebula.connector.reader
+
+import com.vesoft.nebula.connector.NebulaOptions
+import org.apache.spark.sql.connector.read.{
+ Batch,
+ InputPartition,
+ PartitionReaderFactory,
+ Scan,
+ ScanBuilder,
+ SupportsPushDownFilters,
+ SupportsPushDownRequiredColumns
+}
+import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.types.StructType
+
+class SimpleScanBuilder(nebulaOptions: NebulaOptions, schema: StructType)
+ extends ScanBuilder
+ with SupportsPushDownFilters
+ with SupportsPushDownRequiredColumns {
+
+ private var filters: Array[Filter] = Array[Filter]()
+
+ override def build(): Scan = {
+ new SimpleScan(nebulaOptions, nebulaOptions.partitionNums.toInt, schema)
+ }
+
+ override def pushFilters(pushFilters: Array[Filter]): Array[Filter] = {
+ if (nebulaOptions.pushDownFiltersEnabled) {
+ filters = pushFilters
+ }
+ pushFilters
+ }
+
+ override def pushedFilters(): Array[Filter] = filters
+
+ override def pruneColumns(requiredColumns: StructType): Unit = {
+ // effectively a no-op: the StructType below is discarded, so column
+ // pruning is never applied and the full schema is always read
+ if (!nebulaOptions.pushDownFiltersEnabled || requiredColumns == schema) {
+ new StructType()
+ }
+ }
+}
+
+class SimpleScan(nebulaOptions: NebulaOptions, nebulaTotalPart: Int, schema: StructType)
+ extends Scan
+ with Batch {
+ override def toBatch: Batch = this
+
+ override def planInputPartitions(): Array[InputPartition] = {
+ val partitionSize = nebulaTotalPart
+ val inputPartitions = for (i <- 1 to partitionSize)
+ yield {
+ NebulaPartition(i)
+ }
+ inputPartitions.map(_.asInstanceOf[InputPartition]).toArray
+ }
+
+ override def readSchema(): StructType = schema
+
+ override def createReaderFactory(): PartitionReaderFactory =
+ new NebulaPartitionReaderFactory(nebulaOptions, schema)
+}
+
+/**
+ * An identifier for a partition in a Nebula scan.
+ */
+case class NebulaPartition(partition: Int) extends InputPartition
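Note: roughly how Spark drives the classes in this file during a batch read; the helper is illustrative and assumes an already-populated NebulaOptions:

    import com.vesoft.nebula.connector.NebulaOptions
    import com.vesoft.nebula.connector.reader.SimpleScanBuilder
    import org.apache.spark.sql.types.StructType

    // driver side: build() -> toBatch -> planInputPartitions()/createReaderFactory()
    def planRead(nebulaOptions: NebulaOptions, schema: StructType): Int = {
      val batch      = new SimpleScanBuilder(nebulaOptions, schema).build().toBatch
      val partitions = batch.planInputPartitions() // NebulaPartition(1) .. NebulaPartition(N)
      val factory    = batch.createReaderFactory() // executors call factory.createReader(p)
      partitions.length
    }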
diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/utils/Validations.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/utils/Validations.scala
new file mode 100644
index 00000000..93b9a87c
--- /dev/null
+++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/utils/Validations.scala
@@ -0,0 +1,18 @@
+/* Copyright (c) 2021 vesoft inc. All rights reserved.
+ *
+ * This source code is licensed under Apache 2.0 License.
+ */
+
+package com.vesoft.nebula.connector.utils
+
+import org.apache.spark.sql.SparkSession
+
+object Validations {
+ def validateSparkVersion(supportedVersions: String*): Unit = {
+ val sparkVersion = SparkSession.getActiveSession.map { _.version }.getOrElse("UNKNOWN")
+ if (!(sparkVersion == "UNKNOWN" || supportedVersions.exists(sparkVersion.matches))) {
+ throw new RuntimeException(
+ s"Your current spark version ${sparkVersion} is not supported bt the current connector.")
+ }
+ }
+}
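Note: String.matches treats each supported version as a regular expression, so a pattern like "3.*" accepts any version string starting with 3:

    "3.0.1".matches("3.*") // true
    "3.2.0".matches("3.*") // true
    "2.4.8".matches("3.*") // false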
diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaCommitMessage.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaCommitMessage.scala
new file mode 100644
index 00000000..184f63ba
--- /dev/null
+++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaCommitMessage.scala
@@ -0,0 +1,10 @@
+/* Copyright (c) 2020 vesoft inc. All rights reserved.
+ *
+ * This source code is licensed under Apache 2.0 License.
+ */
+
+package com.vesoft.nebula.connector.writer
+
+import org.apache.spark.sql.connector.write.WriterCommitMessage
+
+case class NebulaCommitMessage(executeStatements: List[String]) extends WriterCommitMessage
diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaEdgeWriter.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaEdgeWriter.scala
new file mode 100644
index 00000000..9c28df82
--- /dev/null
+++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaEdgeWriter.scala
@@ -0,0 +1,115 @@
+/* Copyright (c) 2020 vesoft inc. All rights reserved.
+ *
+ * This source code is licensed under Apache 2.0 License.
+ */
+
+package com.vesoft.nebula.connector.writer
+
+import com.vesoft.nebula.connector.{NebulaEdge, NebulaEdges}
+import com.vesoft.nebula.connector.{KeyPolicy, NebulaOptions, WriteMode}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage}
+import org.apache.spark.sql.types.StructType
+import org.slf4j.LoggerFactory
+
+import scala.collection.mutable.ListBuffer
+
+class NebulaEdgeWriter(nebulaOptions: NebulaOptions,
+ srcIndex: Int,
+ dstIndex: Int,
+ rankIndex: Option[Int],
+ schema: StructType)
+ extends NebulaWriter(nebulaOptions)
+ with DataWriter[InternalRow] {
+
+ private val LOG = LoggerFactory.getLogger(this.getClass)
+
+ val rankIdx = rankIndex.getOrElse(-1)
+ val propNames = NebulaExecutor.assignEdgePropNames(schema,
+ srcIndex,
+ dstIndex,
+ rankIdx,
+ nebulaOptions.srcAsProp,
+ nebulaOptions.dstAsProp,
+ nebulaOptions.rankAsProp)
+ val fieldTypMap: Map[String, Integer] =
+ if (nebulaOptions.writeMode == WriteMode.DELETE) Map[String, Integer]()
+ else metaProvider.getEdgeSchema(nebulaOptions.spaceName, nebulaOptions.label)
+
+ val srcPolicy =
+ if (nebulaOptions.srcPolicy.isEmpty) Option.empty
+ else Option(KeyPolicy.withName(nebulaOptions.srcPolicy))
+ val dstPolicy = {
+ if (nebulaOptions.dstPolicy.isEmpty) Option.empty
+ else Option(KeyPolicy.withName(nebulaOptions.dstPolicy))
+ }
+
+ /** buffer to save batch edges */
+ var edges: ListBuffer[NebulaEdge] = new ListBuffer()
+
+ prepareSpace()
+
+ /**
+ * write one edge record to buffer
+ */
+ override def write(row: InternalRow): Unit = {
+ val srcId = NebulaExecutor.extraID(schema, row, srcIndex, srcPolicy, isVidStringType)
+ val dstId = NebulaExecutor.extraID(schema, row, dstIndex, dstPolicy, isVidStringType)
+ val rank =
+ if (rankIndex.isEmpty) Option.empty
+ else Option(NebulaExecutor.extraRank(schema, row, rankIndex.get))
+ val values =
+ if (nebulaOptions.writeMode == WriteMode.DELETE) List()
+ else
+ NebulaExecutor.assignEdgeValues(schema,
+ row,
+ srcIndex,
+ dstIndex,
+ rankIdx,
+ nebulaOptions.srcAsProp,
+ nebulaOptions.dstAsProp,
+ nebulaOptions.rankAsProp,
+ fieldTypMap)
+ val nebulaEdge = NebulaEdge(srcId, dstId, rank, values)
+ edges.append(nebulaEdge)
+ if (edges.size >= nebulaOptions.batch) {
+ execute()
+ }
+ }
+
+ /**
+ * submit buffer edges to nebula
+ */
+ def execute(): Unit = {
+ val nebulaEdges = NebulaEdges(propNames, edges.toList, srcPolicy, dstPolicy)
+ val exec = nebulaOptions.writeMode match {
+ case WriteMode.INSERT => NebulaExecutor.toExecuteSentence(nebulaOptions.label, nebulaEdges)
+ case WriteMode.UPDATE =>
+ NebulaExecutor.toUpdateExecuteStatement(nebulaOptions.label, nebulaEdges)
+ case WriteMode.DELETE =>
+ NebulaExecutor.toDeleteExecuteStatement(nebulaOptions.label, nebulaEdges)
+ case _ =>
+ throw new IllegalArgumentException(s"write mode ${nebulaOptions.writeMode} not supported.")
+ }
+ edges.clear()
+ submit(exec)
+ }
+
+ override def commit(): WriterCommitMessage = {
+ if (edges.nonEmpty) {
+ execute()
+ }
+ graphProvider.close()
+ metaProvider.close()
+ NebulaCommitMessage(failedExecs.toList)
+ }
+
+ override def abort(): Unit = {
+ LOG.error("insert edge task abort.")
+ graphProvider.close()
+ }
+
+ override def close(): Unit = {
+ graphProvider.close()
+ }
+}
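Note: the buffering contract of write()/commit() in a toy model; assuming batch = 3, seven rows yield two threshold-triggered flushes from write() and a final flush from commit():

    val batch   = 3
    val rows    = (1 to 7).toList
    val flushes = rows.grouped(batch).toList
    // flushes == List(List(1, 2, 3), List(4, 5, 6), List(7))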
diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaSourceWriter.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaSourceWriter.scala
new file mode 100644
index 00000000..18a596d6
--- /dev/null
+++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaSourceWriter.scala
@@ -0,0 +1,105 @@
+/* Copyright (c) 2020 vesoft inc. All rights reserved.
+ *
+ * This source code is licensed under Apache 2.0 License.
+ */
+
+package com.vesoft.nebula.connector.writer
+
+import com.vesoft.nebula.connector.NebulaOptions
+import org.apache.spark.TaskContext
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.write.{
+ BatchWrite,
+ DataWriter,
+ DataWriterFactory,
+ PhysicalWriteInfo,
+ WriterCommitMessage
+}
+import org.apache.spark.sql.types.StructType
+import org.slf4j.LoggerFactory
+
+/**
+ * creating and initializing the actual Nebula vertex writer at executor side
+ */
+class NebulaVertexWriterFactory(nebulaOptions: NebulaOptions, vertexIndex: Int, schema: StructType)
+ extends DataWriterFactory {
+ override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = {
+ new NebulaVertexWriter(nebulaOptions, vertexIndex, schema)
+ }
+}
+
+/**
+ * creating and initializing the actual Nebula edge writer at executor side
+ */
+class NebulaEdgeWriterFactory(nebulaOptions: NebulaOptions,
+ srcIndex: Int,
+ dstIndex: Int,
+ rankIndex: Option[Int],
+ schema: StructType)
+ extends DataWriterFactory {
+ override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = {
+ new NebulaEdgeWriter(nebulaOptions, srcIndex, dstIndex, rankIndex, schema)
+ }
+}
+
+/**
+ * nebula vertex writer to create factory
+ */
+class NebulaDataSourceVertexWriter(nebulaOptions: NebulaOptions,
+ vertexIndex: Int,
+ schema: StructType)
+ extends BatchWrite {
+ private val LOG = LoggerFactory.getLogger(this.getClass)
+
+ override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory = {
+ new NebulaVertexWriterFactory(nebulaOptions, vertexIndex, schema)
+ }
+
+ override def commit(messages: Array[WriterCommitMessage]): Unit = {
+ LOG.debug(s"${messages.length}")
+ for (msg <- messages) {
+ val nebulaMsg = msg.asInstanceOf[NebulaCommitMessage]
+ if (nebulaMsg.executeStatements.nonEmpty) {
+ LOG.error(s"failed execs:\n ${nebulaMsg.executeStatements.toString()}")
+ } else {
+ LOG.info(s"execs for spark partition ${TaskContext.getPartitionId()} all succeed")
+ }
+ }
+ }
+
+ override def abort(messages: Array[WriterCommitMessage]): Unit = {
+ LOG.error("NebulaDataSourceVertexWriter abort")
+ }
+}
+
+/**
+ * nebula edge writer to create factory
+ */
+class NebulaDataSourceEdgeWriter(nebulaOptions: NebulaOptions,
+ srcIndex: Int,
+ dstIndex: Int,
+ rankIndex: Option[Int],
+ schema: StructType)
+ extends BatchWrite {
+ private val LOG = LoggerFactory.getLogger(this.getClass)
+
+ override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory =
+ new NebulaEdgeWriterFactory(nebulaOptions, srcIndex, dstIndex, rankIndex, schema)
+
+ override def commit(messages: Array[WriterCommitMessage]): Unit = {
+ LOG.debug(s"${messages.length}")
+ for (msg <- messages) {
+ val nebulaMsg = msg.asInstanceOf[NebulaCommitMessage]
+ if (nebulaMsg.executeStatements.nonEmpty) {
+ LOG.error(s"failed execs:\n ${nebulaMsg.executeStatements.toString()}")
+ } else {
+ LOG.info(s"execs for spark partition ${TaskContext.getPartitionId()} all succeed")
+ }
+ }
+
+ }
+
+ override def abort(messages: Array[WriterCommitMessage]): Unit = {
+ LOG.error("NebulaDataSourceEdgeWriter abort")
+ }
+}
diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaVertexWriter.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaVertexWriter.scala
new file mode 100644
index 00000000..22a5d311
--- /dev/null
+++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaVertexWriter.scala
@@ -0,0 +1,99 @@
+/* Copyright (c) 2020 vesoft inc. All rights reserved.
+ *
+ * This source code is licensed under Apache 2.0 License.
+ */
+
+package com.vesoft.nebula.connector.writer
+
+import com.vesoft.nebula.connector.{
+ KeyPolicy,
+ NebulaOptions,
+ NebulaVertex,
+ NebulaVertices,
+ WriteMode
+}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage}
+import org.apache.spark.sql.types.StructType
+import org.slf4j.LoggerFactory
+
+import scala.collection.mutable.ListBuffer
+
+class NebulaVertexWriter(nebulaOptions: NebulaOptions, vertexIndex: Int, schema: StructType)
+ extends NebulaWriter(nebulaOptions)
+ with DataWriter[InternalRow] {
+
+ private val LOG = LoggerFactory.getLogger(this.getClass)
+
+ val propNames = NebulaExecutor.assignVertexPropNames(schema, vertexIndex, nebulaOptions.vidAsProp)
+ val fieldTypMap: Map[String, Integer] =
+ if (nebulaOptions.writeMode == WriteMode.DELETE) Map[String, Integer]()
+ else metaProvider.getTagSchema(nebulaOptions.spaceName, nebulaOptions.label)
+
+ val policy = {
+ if (nebulaOptions.vidPolicy.isEmpty) Option.empty
+ else Option(KeyPolicy.withName(nebulaOptions.vidPolicy))
+ }
+
+ /** buffer to save batch vertices */
+ var vertices: ListBuffer[NebulaVertex] = new ListBuffer()
+
+ prepareSpace()
+
+ /**
+ * write one vertex row to buffer
+ */
+ override def write(row: InternalRow): Unit = {
+ val vertex =
+ NebulaExecutor.extraID(schema, row, vertexIndex, policy, isVidStringType)
+ val values =
+ if (nebulaOptions.writeMode == WriteMode.DELETE) List()
+ else
+ NebulaExecutor.assignVertexPropValues(schema,
+ row,
+ vertexIndex,
+ nebulaOptions.vidAsProp,
+ fieldTypMap)
+ val nebulaVertex = NebulaVertex(vertex, values)
+ vertices.append(nebulaVertex)
+ if (vertices.size >= nebulaOptions.batch) {
+ execute()
+ }
+ }
+
+ /**
+ * submit buffer vertices to nebula
+ */
+ def execute(): Unit = {
+ val nebulaVertices = NebulaVertices(propNames, vertices.toList, policy)
+ val exec = nebulaOptions.writeMode match {
+ case WriteMode.INSERT => NebulaExecutor.toExecuteSentence(nebulaOptions.label, nebulaVertices)
+ case WriteMode.UPDATE =>
+ NebulaExecutor.toUpdateExecuteStatement(nebulaOptions.label, nebulaVertices)
+ case WriteMode.DELETE =>
+ NebulaExecutor.toDeleteExecuteStatement(nebulaVertices, nebulaOptions.deleteEdge)
+ case _ =>
+ throw new IllegalArgumentException(s"write mode ${nebulaOptions.writeMode} not supported.")
+ }
+ vertices.clear()
+ submit(exec)
+ }
+
+ override def commit(): WriterCommitMessage = {
+ if (vertices.nonEmpty) {
+ execute()
+ }
+ graphProvider.close()
+ metaProvider.close()
+ NebulaCommitMessage(failedExecs.toList)
+ }
+
+ override def abort(): Unit = {
+ LOG.error("insert vertex task abort.")
+ graphProvider.close()
+ }
+
+ override def close(): Unit = {
+ graphProvider.close()
+ }
+}
diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaWriter.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaWriter.scala
new file mode 100644
index 00000000..05718536
--- /dev/null
+++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaWriter.scala
@@ -0,0 +1,62 @@
+/* Copyright (c) 2020 vesoft inc. All rights reserved.
+ *
+ * This source code is licensed under Apache 2.0 License.
+ */
+
+package com.vesoft.nebula.connector.writer
+
+import java.util.concurrent.TimeUnit
+
+import com.google.common.util.concurrent.RateLimiter
+import com.vesoft.nebula.connector.NebulaOptions
+import com.vesoft.nebula.connector.nebula.{GraphProvider, MetaProvider, VidType}
+import org.slf4j.LoggerFactory
+
+import scala.collection.mutable.ListBuffer
+
+class NebulaWriter(nebulaOptions: NebulaOptions) extends Serializable {
+ private val LOG = LoggerFactory.getLogger(this.getClass)
+
+ val failedExecs: ListBuffer[String] = new ListBuffer[String]
+
+ val metaProvider = new MetaProvider(
+ nebulaOptions.getMetaAddress,
+ nebulaOptions.timeout,
+ nebulaOptions.connectionRetry,
+ nebulaOptions.executionRetry,
+ nebulaOptions.enableMetaSSL,
+ nebulaOptions.sslSignType,
+ nebulaOptions.caSignParam,
+ nebulaOptions.selfSignParam
+ )
+ val graphProvider = new GraphProvider(
+ nebulaOptions.getGraphAddress,
+ nebulaOptions.timeout,
+ nebulaOptions.enableGraphSSL,
+ nebulaOptions.sslSignType,
+ nebulaOptions.caSignParam,
+ nebulaOptions.selfSignParam
+ )
+ val isVidStringType = metaProvider.getVidType(nebulaOptions.spaceName) == VidType.STRING
+
+ def prepareSpace(): Unit = {
+ graphProvider.switchSpace(nebulaOptions.user, nebulaOptions.passwd, nebulaOptions.spaceName)
+ }
+
+ def submit(exec: String): Unit = {
+ val rateLimiter = RateLimiter.create(nebulaOptions.rateLimit)
+ if (rateLimiter.tryAcquire(nebulaOptions.rateTimeOut, TimeUnit.MILLISECONDS)) {
+ val result = graphProvider.submit(exec)
+ if (!result.isSucceeded) {
+ failedExecs.append(exec)
+ LOG.error(s"failed to write ${exec} for " + result.getErrorMessage)
+ } else {
+ LOG.info(s"batch write succeed")
+ LOG.debug(s"batch write succeed: ${exec}")
+ }
+ } else {
+ failedExecs.append(exec)
+ LOG.error(s"failed to acquire reteLimiter for statement {$exec}")
+ }
+ }
+}
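Note: submit() relies on Guava's RateLimiter, whose rate is permits per second and whose tryAcquire waits up to the given timeout for one permit. Since the limiter above is created inside submit(), the first tryAcquire of each call succeeds immediately; a sketch of the API:

    import java.util.concurrent.TimeUnit
    import com.google.common.util.concurrent.RateLimiter

    val limiter  = RateLimiter.create(1024.0) // 1024 permits per second
    val acquired = limiter.tryAcquire(100, TimeUnit.MILLISECONDS)
    // acquired is true if a permit became available within 100 ms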
diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaWriterBuilder.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaWriterBuilder.scala
new file mode 100644
index 00000000..c69f3976
--- /dev/null
+++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaWriterBuilder.scala
@@ -0,0 +1,92 @@
+/* Copyright (c) 2022 vesoft inc. All rights reserved.
+ *
+ * This source code is licensed under Apache 2.0 License.
+ */
+
+package com.vesoft.nebula.connector.writer
+
+import com.vesoft.nebula.connector.exception.IllegalOptionException
+import com.vesoft.nebula.connector.{DataTypeEnum, NebulaOptions}
+import org.apache.spark.sql.SaveMode
+import org.apache.spark.sql.connector.write.{
+ BatchWrite,
+ SupportsOverwrite,
+ SupportsTruncate,
+ WriteBuilder
+}
+import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.types.StructType
+
+class NebulaWriterBuilder(schema: StructType, saveMode: SaveMode, nebulaOptions: NebulaOptions)
+ extends WriteBuilder
+ with SupportsOverwrite
+ with SupportsTruncate {
+
+ override def buildForBatch(): BatchWrite = {
+ val dataType = nebulaOptions.dataType
+ if (DataTypeEnum.VERTEX == DataTypeEnum.withName(dataType)) {
+ val vertexField = nebulaOptions.vertexField
+ val vertexIndex: Int = {
+ var index: Int = -1
+ for (i <- schema.fields.indices) {
+ if (schema.fields(i).name.equals(vertexField)) {
+ index = i
+ }
+ }
+ if (index < 0) {
+ throw new IllegalOptionException(
+ s"vertex field ${vertexField} does not exist in dataframe")
+ }
+ index
+ }
+ new NebulaDataSourceVertexWriter(nebulaOptions, vertexIndex, schema)
+ } else {
+ val srcVertexField = nebulaOptions.srcVertexField
+ val dstVertexField = nebulaOptions.dstVertexField
+ val rankExist = nebulaOptions.rankField.nonEmpty
+ val edgeFieldsIndex = {
+ var srcIndex: Int = -1
+ var dstIndex: Int = -1
+ var rankIndex: Int = -1
+ for (i <- schema.fields.indices) {
+ if (schema.fields(i).name.equals(srcVertexField)) {
+ srcIndex = i
+ }
+ if (schema.fields(i).name.equals(dstVertexField)) {
+ dstIndex = i
+ }
+ if (rankExist) {
+ if (schema.fields(i).name.equals(nebulaOptions.rankField)) {
+ rankIndex = i
+ }
+ }
+ }
+ // check src field and dst field
+ if (srcIndex < 0 || dstIndex < 0) {
+ throw new IllegalOptionException(
+ s"srcVertex field ${srcVertexField} or dstVertex field ${dstVertexField} does not exist in dataframe")
+ }
+ // check rank field
+ if (rankExist && rankIndex < 0) {
+ throw new IllegalOptionException("rank field does not exist in dataframe")
+ }
+
+ if (!rankExist) {
+ (srcIndex, dstIndex, Option.empty)
+ } else {
+ (srcIndex, dstIndex, Option(rankIndex))
+ }
+
+ }
+ new NebulaDataSourceEdgeWriter(nebulaOptions,
+ edgeFieldsIndex._1,
+ edgeFieldsIndex._2,
+ edgeFieldsIndex._3,
+ schema)
+ }
+ }
+
+ override def overwrite(filters: Array[Filter]): WriteBuilder = {
+ new NebulaWriterBuilder(schema, SaveMode.Overwrite, nebulaOptions)
+ }
+}
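Note: this builder is reached through the implicit writer in package.scala; a minimal vertex-write sketch, with placeholder addresses, space, tag, and vid column (the builder calls follow the connector's public config API):

    import com.vesoft.nebula.connector.connector.NebulaDataFrameWriter
    import com.vesoft.nebula.connector.{NebulaConnectionConfig, WriteNebulaVertexConfig}
    import org.apache.spark.sql.DataFrame

    def writePersons(df: DataFrame): Unit = {
      val config = NebulaConnectionConfig.builder()
        .withMetaAddress("127.0.0.1:9559")  // placeholder
        .withGraphAddress("127.0.0.1:9669") // placeholder
        .build()
      val writeVertexConfig = WriteNebulaVertexConfig.builder()
        .withSpace("test")  // placeholder space
        .withTag("person")  // placeholder tag
        .withVidField("id") // placeholder vid column
        .withBatch(512)
        .build()
      df.write.nebula(config, writeVertexConfig).writeVertices()
    }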
diff --git a/nebula-spark-connector_3.0/src/test/resources/docker-compose.yaml b/nebula-spark-connector_3.0/src/test/resources/docker-compose.yaml
new file mode 100644
index 00000000..348f539b
--- /dev/null
+++ b/nebula-spark-connector_3.0/src/test/resources/docker-compose.yaml
@@ -0,0 +1,362 @@
+version: '3.4'
+services:
+ metad0:
+ image: vesoft/nebula-metad:nightly
+ environment:
+ USER: root
+ TZ: "${TZ}"
+ command:
+ - --meta_server_addrs=172.28.1.1:9559,172.28.1.2:9559,172.28.1.3:9559
+ - --local_ip=172.28.1.1
+ - --ws_ip=172.28.1.1
+ - --port=9559
+ - --data_path=/data/meta
+ - --log_dir=/logs
+ - --v=0
+ - --ws_http_port=19559
+ - --minloglevel=0
+ - --heartbeat_interval_secs=2
+ healthcheck:
+ test: ["CMD", "curl", "-f", "http://172.28.1.1:19559/status"]
+ interval: 30s
+ timeout: 10s
+ retries: 3
+ start_period: 20s
+ ports:
+ - "9559:9559"
+ - 19559
+ - 11002
+ volumes:
+ - ./data/meta0:/data/meta:Z
+ - ./logs/meta0:/logs:Z
+ networks:
+ nebula-net:
+ ipv4_address: 172.28.1.1
+ restart: on-failure
+ cap_add:
+ - SYS_PTRACE
+
+ metad1:
+ image: vesoft/nebula-metad:nightly
+ environment:
+ USER: root
+ TZ: "${TZ}"
+ command:
+ - --meta_server_addrs=172.28.1.1:9559,172.28.1.2:9559,172.28.1.3:9559
+ - --local_ip=172.28.1.2
+ - --ws_ip=172.28.1.2
+ - --port=9559
+ - --data_path=/data/meta
+ - --log_dir=/logs
+ - --v=0
+ - --ws_http_port=19559
+ - --minloglevel=0
+ - --heartbeat_interval_secs=2
+ healthcheck:
+ test: ["CMD", "curl", "-f", "http://172.28.1.2:19559/status"]
+ interval: 30s
+ timeout: 10s
+ retries: 3
+ start_period: 20s
+ ports:
+ - "9560:9559"
+ - 19559
+ - 11002
+ volumes:
+ - ./data/meta1:/data/meta:Z
+ - ./logs/meta1:/logs:Z
+ networks:
+ nebula-net:
+ ipv4_address: 172.28.1.2
+ restart: on-failure
+ cap_add:
+ - SYS_PTRACE
+
+ metad2:
+ image: vesoft/nebula-metad:nightly
+ environment:
+ USER: root
+ TZ: "${TZ}"
+ command:
+ - --meta_server_addrs=172.28.1.1:9559,172.28.1.2:9559,172.28.1.3:9559
+ - --local_ip=172.28.1.3
+ - --ws_ip=172.28.1.3
+ - --port=9559
+ - --data_path=/data/meta
+ - --log_dir=/logs
+ - --v=0
+ - --ws_http_port=19559
+ - --minloglevel=0
+ - --heartbeat_interval_secs=2
+ healthcheck:
+ test: ["CMD", "curl", "-f", "http://172.28.1.3:19559/status"]
+ interval: 30s
+ timeout: 10s
+ retries: 3
+ start_period: 20s
+ ports:
+ - "9561:9559"
+ - 19559
+ - 11002
+ volumes:
+ - ./data/meta2:/data/meta:Z
+ - ./logs/meta2:/logs:Z
+ networks:
+ nebula-net:
+ ipv4_address: 172.28.1.3
+ restart: on-failure
+ cap_add:
+ - SYS_PTRACE
+
+ storaged0:
+ image: vesoft/nebula-storaged:nightly
+ environment:
+ USER: root
+ TZ: "${TZ}"
+ command:
+ - --meta_server_addrs=172.28.1.1:9559,172.28.1.2:9559,172.28.1.3:9559
+ - --local_ip=172.28.2.1
+ - --ws_ip=172.28.2.1
+ - --port=9779
+ - --data_path=/data/storage
+ - --log_dir=/logs
+ - --v=0
+ - --ws_http_port=19779
+ - --minloglevel=0
+ - --heartbeat_interval_secs=2
+ depends_on:
+ - metad0
+ - metad1
+ - metad2
+ healthcheck:
+ test: ["CMD", "curl", "-f", "http://172.28.2.1:19779/status"]
+ interval: 30s
+ timeout: 10s
+ retries: 3
+ start_period: 20s
+ ports:
+ - "9779:9779"
+ - 19779
+ - 12002
+ volumes:
+ - ./data/storage0:/data/storage:Z
+ - ./logs/storage0:/logs:Z
+ networks:
+ nebula-net:
+ ipv4_address: 172.28.2.1
+ restart: on-failure
+ cap_add:
+ - SYS_PTRACE
+
+ storaged1:
+ image: vesoft/nebula-storaged:nightly
+ environment:
+ USER: root
+ TZ: "${TZ}"
+ command:
+ - --meta_server_addrs=172.28.1.1:9559,172.28.1.2:9559,172.28.1.3:9559
+ - --local_ip=172.28.2.2
+ - --ws_ip=172.28.2.2
+ - --port=9779
+ - --data_path=/data/storage
+ - --log_dir=/logs
+ - --v=0
+ - --ws_http_port=19779
+ - --minloglevel=0
+ - --heartbeat_interval_secs=2
+ depends_on:
+ - metad0
+ - metad1
+ - metad2
+ healthcheck:
+ test: ["CMD", "curl", "-f", "http://172.28.2.2:19779/status"]
+ interval: 30s
+ timeout: 10s
+ retries: 3
+ start_period: 20s
+ ports:
+ - "9780:9779"
+ - 19779
+ - 12002
+ volumes:
+ - ./data/storage1:/data/storage:Z
+ - ./logs/storage1:/logs:Z
+ networks:
+ nebula-net:
+ ipv4_address: 172.28.2.2
+ restart: on-failure
+ cap_add:
+ - SYS_PTRACE
+
+ storaged2:
+ image: vesoft/nebula-storaged:nightly
+ environment:
+ USER: root
+ TZ: "${TZ}"
+ command:
+ - --meta_server_addrs=172.28.1.1:9559,172.28.1.2:9559,172.28.1.3:9559
+ - --local_ip=172.28.2.3
+ - --ws_ip=172.28.2.3
+ - --port=9779
+ - --data_path=/data/storage
+ - --log_dir=/logs
+ - --v=0
+ - --ws_http_port=19779
+ - --minloglevel=0
+ - --heartbeat_interval_secs=2
+ depends_on:
+ - metad0
+ - metad1
+ - metad2
+ healthcheck:
+ test: ["CMD", "curl", "-f", "http://172.28.2.3:19779/status"]
+ interval: 30s
+ timeout: 10s
+ retries: 3
+ start_period: 20s
+ ports:
+ - "9781:9779"
+ - 19779
+ - 12002
+ volumes:
+ - ./data/storage2:/data/storage:Z
+ - ./logs/storage2:/logs:Z
+ networks:
+ nebula-net:
+ ipv4_address: 172.28.2.3
+ restart: on-failure
+ cap_add:
+ - SYS_PTRACE
+
+ graphd0:
+ image: vesoft/nebula-graphd:nightly
+ environment:
+ USER: root
+ TZ: "${TZ}"
+ command:
+ - --meta_server_addrs=172.28.1.1:9559,172.28.1.2:9559,172.28.1.3:9559
+ - --port=9669
+ - --ws_ip=172.28.3.1
+ - --log_dir=/logs
+ - --v=0
+ - --ws_http_port=19669
+ - --minloglevel=0
+ - --heartbeat_interval_secs=2
+ depends_on:
+ - metad0
+ - metad1
+ - metad2
+ healthcheck:
+ test: ["CMD", "curl", "-f", "http://172.28.3.1:19669/status"]
+ interval: 30s
+ timeout: 10s
+ retries: 3
+ start_period: 20s
+ ports:
+ - "9669:9669"
+ - 19669
+ - 13002
+ volumes:
+ - ./logs/graph0:/logs:Z
+ networks:
+ nebula-net:
+ ipv4_address: 172.28.3.1
+ restart: on-failure
+ cap_add:
+ - SYS_PTRACE
+
+ graphd1:
+ image: vesoft/nebula-graphd:nightly
+ environment:
+ USER: root
+ TZ: "${TZ}"
+ command:
+ - --meta_server_addrs=172.28.1.1:9559,172.28.1.2:9559,172.28.1.3:9559
+ - --port=9669
+ - --ws_ip=172.28.3.2
+ - --log_dir=/logs
+ - --v=0
+ - --ws_http_port=19669
+ - --minloglevel=0
+ - --heartbeat_interval_secs=2
+ depends_on:
+ - metad0
+ - metad1
+ - metad2
+ healthcheck:
+ test: ["CMD", "curl", "-f", "http://172.28.3.2:19669/status"]
+ interval: 30s
+ timeout: 10s
+ retries: 3
+ start_period: 20s
+ ports:
+ - "9670:9669"
+ - 19669
+ - 13002
+ volumes:
+ - ./logs/graph1:/logs:Z
+ networks:
+ nebula-net:
+ ipv4_address: 172.28.3.2
+ restart: on-failure
+ cap_add:
+ - SYS_PTRACE
+
+ graphd2:
+ image: vesoft/nebula-graphd:nightly
+ environment:
+ USER: root
+ TZ: "${TZ}"
+ command:
+ - --meta_server_addrs=172.28.1.1:9559,172.28.1.2:9559,172.28.1.3:9559
+ - --port=9669
+ - --ws_ip=172.28.3.3
+ - --log_dir=/logs
+ - --v=0
+ - --ws_http_port=19669
+ - --minloglevel=0
+ - --heartbeat_interval_secs=2
+ depends_on:
+ - metad0
+ - metad1
+ - metad2
+ healthcheck:
+ test: ["CMD", "curl", "-f", "http://172.28.3.3:19669/status"]
+ interval: 30s
+ timeout: 10s
+ retries: 3
+ start_period: 20s
+ ports:
+ - "9671:9669"
+ - 19669
+ - 13002
+ volumes:
+ - ./logs/graph2:/logs:Z
+ networks:
+ nebula-net:
+ ipv4_address: 172.28.3.3
+ restart: on-failure
+ cap_add:
+ - SYS_PTRACE
+
+ console:
+ image: vesoft/nebula-console:nightly
+ entrypoint: ""
+ command:
+ - sh
+ - -c
+ - |
+ sleep 3 &&
+ nebula-console -addr graphd0 -port 9669 -u root -p nebula -e 'ADD HOSTS "172.28.2.1":9779,"172.28.2.2":9779,"172.28.2.3":9779' &&
+ sleep 36000
+ depends_on:
+ - graphd0
+ networks:
+ - nebula-net
+
+networks:
+ nebula-net:
+ ipam:
+ driver: default
+ config:
+ - subnet: 172.28.0.0/16
diff --git a/nebula-spark-connector_3.0/src/test/resources/edge.csv b/nebula-spark-connector_3.0/src/test/resources/edge.csv
new file mode 100644
index 00000000..2a2380fe
--- /dev/null
+++ b/nebula-spark-connector_3.0/src/test/resources/edge.csv
@@ -0,0 +1,14 @@
+id1,id2,col1,col2,col3,col4,col5,col6,col7,col8,col9,col10,col11,col12,col13,col14
+1,2,Tom,tom,10,20,30,40,2021-01-27,2021-01-01T12:10:10,43535232,true,1.0,2.0,10:10:10,POINT(1 2)
+2,3,Jina,Jina,11,21,31,41,2021-01-28,2021-01-02T12:10:10,43535232,false,1.1,2.1,11:10:10,POINT(3 4)
+3,4,Tim,Tim,12,22,32,42,2021-01-29,2021-01-03T12:10:10,43535232,false,1.2,2.2,12:10:10,POINT(5 6)
+4,5,张三,张三,13,23,33,43,2021-01-30,2021-01-04T12:10:10,43535232,true,1.3,2.3,13:10:10,POINT(6 7)
+5,6,李四,李四,14,24,34,44,2021-02-01,2021-01-05T12:10:10,43535232,false,1.4,2.4,14:10:10,POINT(1 5)
+6,7,王五,王五,15,25,35,45,2021-02-02,2021-01-06T12:10:10,0,false,1.5,2.5,15:10:10,"LINESTRING(1 3, 4.7 73.23)"
+7,1,Jina,Jina,16,26,36,46,2021-02-03,2021-01-07T12:10:10,43535232,true,1.6,2.6,16:10:10,"LINESTRING(1 3, 4.7 73.23)"
+8,1,Jina,Jina,17,27,37,47,2021-02-04,2021-01-08T12:10:10,43535232,false,1.7,2.7,17:10:10,"LINESTRING(1 3, 4.7 73.23)"
+9,1,Jina,Jina,18,28,38,48,2021-02-05,2021-01-09T12:10:10,43535232,true,1.8,2.8,18:10:10,"LINESTRING(1 3, 4.7 73.23)"
+10,2,Jina,Jina,19,29,39,49,2021-02-06,2021-01-10T12:10:10,43535232,false,1.9,2.9,19:10:10,"LINESTRING(1 3, 4.7 73.23)"
+-1,5,Jina,Jina,20,30,40,50,2021-02-07,2021-02-11T12:10:10,43535232,false,2.0,3.0,20:10:10,"POLYGON((0 1, 1 2, 2 3, 0 1))"
+-2,6,Jina,Jina,21,31,41,51,2021-02-08,2021-03-12T12:10:10,43535232,false,2.1,3.1,21:10:10,"POLYGON((0 1, 1 2, 2 3, 0 1))"
+-3,7,Jina,Jina,22,32,42,52,2021-02-09,2021-04-13T12:10:10,43535232,false,2.2,3.2,22:10:10,"POLYGON((0 1, 1 2, 2 3, 0 1))"
diff --git a/nebula-spark-connector_3.0/src/test/resources/log4j.properties b/nebula-spark-connector_3.0/src/test/resources/log4j.properties
new file mode 100644
index 00000000..913391db
--- /dev/null
+++ b/nebula-spark-connector_3.0/src/test/resources/log4j.properties
@@ -0,0 +1,6 @@
+# Global logging configuration
+log4j.rootLogger=INFO, stdout
+# Console output appender
+log4j.appender.stdout=org.apache.log4j.ConsoleAppender
+log4j.appender.stdout.layout=org.apache.log4j.PatternLayout
+log4j.appender.stdout.layout.ConversionPattern=%5p [%t] - %m%n
diff --git a/nebula-spark-connector_3.0/src/test/resources/vertex.csv b/nebula-spark-connector_3.0/src/test/resources/vertex.csv
new file mode 100644
index 00000000..2b74dfa0
--- /dev/null
+++ b/nebula-spark-connector_3.0/src/test/resources/vertex.csv
@@ -0,0 +1,14 @@
+id,col1,col2,col3,col4,col5,col6,col7,col8,col9,col10,col11,col12,col13,col14,col15
+1,Tom,tom,10,20,30,40,2021-01-27,2021-01-01T12:10:10,43535232,true,1.0,2.0,10:10:10,POINT(1 2),"duration({years:1,months:1,seconds:1})"
+2,Jina,Jina,11,21,31,41,2021-01-28,2021-01-02T12:10:10,43535232,false,1.1,2.1,11:10:10,POINT(3 4),"duration({years:1,months:1,seconds:1})"
+3,Tim,Tim,12,22,32,42,2021-01-29,2021-01-03T12:10:10,43535232,false,1.2,2.2,12:10:10,POINT(5 6),"duration({years:1,months:1,seconds:1})"
+4,张三,张三,13,23,33,43,2021-01-30,2021-01-04T12:10:10,43535232,true,1.3,2.3,13:10:10,POINT(6 7),"duration({years:1,months:1,seconds:1})"
+5,李四,李四,14,24,34,44,2021-02-01,2021-01-05T12:10:10,43535232,false,1.4,2.4,14:10:10,POINT(1 5),"duration({years:1,months:1,seconds:1})"
+6,王五,王五,15,25,35,45,2021-02-02,2021-01-06T12:10:10,0,false,1.5,2.5,15:10:10,"LINESTRING(1 3, 4.7 73.23)","duration({years:1,months:1,seconds:1})"
+7,Jina,Jina,16,26,36,46,2021-02-03,2021-01-07T12:10:10,43535232,true,1.6,2.6,16:10:10,"LINESTRING(1 3, 4.7 73.23)","duration({years:1,months:1,seconds:1})"
+8,Jina,Jina,17,27,37,47,2021-02-04,2021-01-08T12:10:10,43535232,false,1.7,2.7,17:10:10,"LINESTRING(1 3, 4.7 73.23)","duration({years:1,months:1,seconds:1})"
+9,Jina,Jina,18,28,38,48,2021-02-05,2021-01-09T12:10:10,43535232,true,1.8,2.8,18:10:10,"LINESTRING(1 3, 4.7 73.23)","duration({years:1,months:1,seconds:1})"
+10,Jina,Jina,19,29,39,49,2021-02-06,2021-01-10T12:10:10,43535232,false,1.9,2.9,19:10:10,"LINESTRING(1 3, 4.7 73.23)","duration({years:1,months:1,seconds:1})"
+-1,Jina,Jina,20,30,40,50,2021-02-07,2021-02-11T12:10:10,43535232,false,2.0,3.0,20:10:10,"POLYGON((0 1, 1 2, 2 3, 0 1))","duration({years:1,months:1,seconds:1})"
+-2,Jina,Jina,21,31,41,51,2021-02-08,2021-03-12T12:10:10,43535232,false,2.1,3.1,21:10:10,"POLYGON((0 1, 1 2, 2 3, 0 1))","duration({years:1,months:1,seconds:1})"
+-3,Jina,Jina,22,32,42,52,2021-02-09,2021-04-13T12:10:10,43535232,false,2.2,3.2,22:10:10,"POLYGON((0 1, 1 2, 2 3, 0 1))","duration({years:1,months:1,seconds:1})"
diff --git a/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/SparkVersionValidateSuite.scala b/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/SparkVersionValidateSuite.scala
new file mode 100644
index 00000000..81a8a39c
--- /dev/null
+++ b/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/SparkVersionValidateSuite.scala
@@ -0,0 +1,21 @@
+/* Copyright (c) 2022 vesoft inc. All rights reserved.
+ *
+ * This source code is licensed under Apache 2.0 License.
+ */
+
+package com.vesoft.nebula.connector
+
+import com.vesoft.nebula.connector.utils.SparkValidate
+import org.apache.spark.sql.SparkSession
+import org.scalatest.funsuite.AnyFunSuite
+
+class SparkVersionValidateSuite extends AnyFunSuite {
+ test("spark version validate") {
+    val version = SparkSession.getActiveSession.map(_.version).getOrElse("UNKNOWN")
+    try {
+      SparkValidate.validate("3.0.*", "3.1.*", "3.2.*", "3.3.*")
+    } catch {
+      case e: Exception => fail(s"version validation failed for Spark $version", e)
+    }
+ }
+}
diff --git a/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/mock/NebulaGraphMock.scala b/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/mock/NebulaGraphMock.scala
new file mode 100644
index 00000000..46c8f502
--- /dev/null
+++ b/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/mock/NebulaGraphMock.scala
@@ -0,0 +1,194 @@
+/* Copyright (c) 2021 vesoft inc. All rights reserved.
+ *
+ * This source code is licensed under Apache 2.0 License.
+ */
+
+package com.vesoft.nebula.connector.mock
+
+import com.vesoft.nebula.client.graph.NebulaPoolConfig
+import com.vesoft.nebula.client.graph.data.HostAddress
+import com.vesoft.nebula.client.graph.net.NebulaPool
+import org.apache.log4j.Logger
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ListBuffer
+
+class NebulaGraphMock {
+ private[this] val LOG = Logger.getLogger(this.getClass)
+
+ @transient val nebulaPoolConfig = new NebulaPoolConfig
+ @transient val pool: NebulaPool = new NebulaPool
+ val address = new ListBuffer[HostAddress]()
+ address.append(new HostAddress("127.0.0.1", 9669))
+
+ val randAddr = scala.util.Random.shuffle(address)
+ pool.init(randAddr.asJava, nebulaPoolConfig)
+
+ def mockStringIdGraph(): Unit = {
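+    // create the string-vid space (person tag, friend edge, geo_shape tag) and seed it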
+ val session = pool.getSession("root", "nebula", true)
+
+ val createSpace = "CREATE SPACE IF NOT EXISTS test_string(partition_num=10,vid_type=fixed_string(8));" +
+ "USE test_string;" + "CREATE TAG IF NOT EXISTS person(col1 string, col2 fixed_string(8), col3 int8, col4 int16, col5 int32, col6 int64, col7 date, col8 datetime, col9 timestamp, col10 bool, col11 double, col12 float, col13 time);" +
+ "CREATE EDGE IF NOT EXISTS friend(col1 string, col2 fixed_string(8), col3 int8, col4 int16, col5 int32, col6 int64, col7 date, col8 datetime, col9 timestamp, col10 bool, col11 double, col12 float, col13 time);" +
+ "CREATE TAG IF NOT EXISTS geo_shape(geo geography);"
+ val createResp = session.execute(createSpace)
+ if (!createResp.isSucceeded) {
+ close()
+ LOG.error("create string type space failed," + createResp.getErrorMessage)
+ sys.exit(-1)
+ }
+
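+    // wait for the new schema to sync to storaged (heartbeat_interval_secs is 2s in the test cluster)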
+ Thread.sleep(10000)
+ val insertTag =
+ "INSERT VERTEX person(col1, col2, col3, col4, col5, col6, col7, col8, col9, col10, col11, col12, col13) VALUES " +
+ " \"1\":(\"person1\", \"person1\", 11, 200, 1000, 188888, date(\"2021-01-01\"), datetime(\"2021-01-01T12:00:00\"),timestamp(\"2021-01-01T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," +
+ " \"2\":(\"person2\", \"person2\", 12, 300, 2000, 288888, date(\"2021-01-02\"), datetime(\"2021-01-02T12:00:00\"),timestamp(\"2021-01-02T12:00:00\"), false, 1.0, 2.0, time(\"12:01:01\"))," +
+ " \"3\":(\"person3\", \"person3\", 13, 400, 3000, 388888, date(\"2021-01-03\"), datetime(\"2021-01-03T12:00:00\"),timestamp(\"2021-01-03T12:00:00\"), false, 1.0, 2.0, time(\"12:01:01\"))," +
+ " \"4\":(\"person4\", \"person4\", 14, 500, 4000, 488888, date(\"2021-01-04\"), datetime(\"2021-01-04T12:00:00\"),timestamp(\"2021-01-04T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," +
+ " \"5\":(\"person5\", \"person5\", 15, 600, 5000, 588888, date(\"2021-01-05\"), datetime(\"2021-01-05T12:00:00\"),timestamp(\"2021-01-05T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," +
+ " \"6\":(\"person6\", \"person6\", 16, 700, 6000, 688888, date(\"2021-01-06\"), datetime(\"2021-01-06T12:00:00\"),timestamp(\"2021-01-06T12:00:00\"), false, 1.0, 2.0, time(\"12:01:01\"))," +
+ " \"7\":(\"person7\", \"person7\", 17, 800, 7000, 788888, date(\"2021-01-07\"), datetime(\"2021-01-07T12:00:00\"),timestamp(\"2021-01-07T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," +
+ " \"8\":(\"person8\", \"person8\", 18, 900, 8000, 888888, date(\"2021-01-08\"), datetime(\"2021-01-08T12:00:00\"),timestamp(\"2021-01-08T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," +
+ " \"9\":(\"person9\", \"person9\", 19, 1000, 9000, 988888, date(\"2021-01-09\"), datetime(\"2021-01-09T12:00:00\"),timestamp(\"2021-01-09T12:00:00\"), false, 1.0, 2.0, time(\"12:01:01\"))," +
+ " \"10\":(\"person10\", \"person10\", 20, 1100, 10000, 1088888, date(\"2021-01-10\"), datetime(\"2021-01-10T12:00:00\"),timestamp(\"2021-01-10T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," +
+ " \"11\":(\"person11\", \"person11\", 21, 1200, 11000, 1188888, date(\"2021-01-11\"), datetime(\"2021-01-11T12:00:00\"),timestamp(\"2021-01-11T12:00:00\"), false, 1.0, 2.0, time(\"12:01:01\"))," +
+ " \"12\":(\"person12\", \"person11\", 22, 1300, 12000, 1288888, date(\"2021-01-12\"), datetime(\"2021-01-12T12:00:00\"),timestamp(\"2021-01-12T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," +
+ " \"-1\":(\"person00\", \"person00\", 23, 1400, 13000, 1388888, date(\"2021-01-13\"), datetime(\"2021-01-12T12:00:00\"),timestamp(\"2021-01-12T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," +
+ " \"-2\":(\"person01\", \"person01\", 24, 1500, 14000, 1488888, date(\"2021-01-14\"), datetime(\"2021-01-12T12:00:00\"),timestamp(\"2021-01-12T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," +
+ " \"-3\":(\"person02\", \"person02\", 24, 1500, 14000, 1488888, date(\"2021-01-14\"), datetime(\"2021-01-12T12:00:00\"),timestamp(\"2021-01-12T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," +
+ " \"19\":(\"person19\", \"person22\", 25, 1500, 14000, 1488888, date(\"2021-01-14\"), datetime(\"2021-01-12T12:00:00\"),timestamp(\"2021-01-12T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," +
+ " \"22\":(\"person22\", \"person22\", 26, 1500, 14000, 1488888, date(\"2021-01-14\"), datetime(\"2021-01-12T12:00:00\"),timestamp(\"2021-01-12T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"));" +
+ "INSERT VERTEX geo_shape(geo) VALUES \"100\":(ST_GeogFromText(\"POINT(1 2)\")), \"101\":(ST_GeogFromText(\"LINESTRING(1 2, 3 4)\")), \"102\":(ST_GeogFromText(\"POLYGON((0 1, 1 2, 2 3, 0 1))\"))"
+ val insertTagResp = session.execute(insertTag)
+ if (!insertTagResp.isSucceeded) {
+ close()
+ LOG.error("insert vertex for string type space failed," + insertTagResp.getErrorMessage)
+ sys.exit(-1)
+ }
+
+ val insertEdge = "INSERT EDGE friend(col1, col2, col3, col4, col5, col6, col7, col8, col9, col10, col11, col12, col13) VALUES " +
+ " \"1\" -> \"2\":(\"friend1\", \"friend2\", 11, 200, 1000, 188888, date(\"2021-01-01\"), datetime(\"2021-01-01T12:00:00\"),timestamp(\"2021-01-01T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," +
+ " \"2\" -> \"3\":(\"friend2\", \"friend3\", 12, 300, 2000, 288888, date(\"2021-01-02\"), datetime(\"2021-01-02T12:00:00\"),timestamp(\"2021-01-02T12:00:00\"), false, 1.0, 2.0, time(\"12:01:01\"))," +
+ " \"3\" -> \"4\":(\"friend3\", \"friend4\", 13, 400, 3000, 388888, date(\"2021-01-03\"), datetime(\"2021-01-03T12:00:00\"),timestamp(\"2021-01-03T12:00:00\"), false, 1.0, 2.0, time(\"12:01:01\"))," +
+ " \"4\" -> \"5\":(\"friend4\", \"friend4\", 14, 500, 4000, 488888, date(\"2021-01-04\"), datetime(\"2021-01-04T12:00:00\"),timestamp(\"2021-01-04T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," +
+ " \"5\" -> \"6\":(\"friend5\", \"friend5\", 15, 600, 5000, 588888, date(\"2021-01-05\"), datetime(\"2021-01-05T12:00:00\"),timestamp(\"2021-01-05T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," +
+ " \"6\" -> \"7\":(\"friend6\", \"friend6\", 16, 700, 6000, 688888, date(\"2021-01-06\"), datetime(\"2021-01-06T12:00:00\"),timestamp(\"2021-01-06T12:00:00\"), false, 1.0, 2.0, time(\"12:01:01\"))," +
+ " \"7\" -> \"8\":(\"friend7\", \"friend7\", 17, 800, 7000, 788888, date(\"2021-01-07\"), datetime(\"2021-01-07T12:00:00\"),timestamp(\"2021-01-07T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," +
+ " \"8\" -> \"9\":(\"friend8\", \"friend8\", 18, 900, 8000, 888888, date(\"2021-01-08\"), datetime(\"2021-01-08T12:00:00\"),timestamp(\"2021-01-08T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," +
+ " \"9\" -> \"10\":(\"friend9\", \"friend9\", 19, 1000, 9000, 988888, date(\"2021-01-09\"), datetime(\"2021-01-09T12:00:00\"),timestamp(\"2021-01-09T12:00:00\"), false, 1.0, 2.0, time(\"12:01:01\"))," +
+ " \"10\" -> \"11\":(\"friend10\", \"friend10\", 20, 1100, 10000, 1088888, date(\"2021-01-10\"), datetime(\"2021-01-10T12:00:00\"),timestamp(\"2021-01-10T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," +
+ " \"11\" -> \"12\":(\"friend11\", \"friend11\", 21, 1200, 11000, 1188888, date(\"2021-01-11\"), datetime(\"2021-01-11T12:00:00\"),timestamp(\"2021-01-11T12:00:00\"), false, 1.0, 2.0, time(\"12:01:01\"))," +
+ " \"12\" -> \"1\":(\"friend12\", \"friend11\", 22, 1300, 12000, 1288888, date(\"2021-01-12\"), datetime(\"2021-01-12T12:00:00\"),timestamp(\"2021-01-12T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," +
+ " \"-1\" -> \"11\":(\"friend13\", \"friend12\", 22, 1300, 12000, 1288888, date(\"2021-01-12\"), datetime(\"2021-01-12T12:00:00\"),timestamp(\"2021-01-12T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," +
+ " \"-2\" -> \"-1\":(\"friend14\", \"friend13\", 22, 1300, 12000, 1288888, date(\"2021-01-12\"), datetime(\"2021-01-12T12:00:00\"),timestamp(\"2021-01-12T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))"
+ val insertEdgeResp = session.execute(insertEdge)
+ if (!insertEdgeResp.isSucceeded) {
+ close()
+ LOG.error("insert edge for string type space failed," + insertEdgeResp.getErrorMessage)
+ sys.exit(-1)
+ }
+ }
+
+ def mockIntIdGraph(): Unit = {
+ val session = pool.getSession("root", "nebula", true)
+
+ val createSpace = "CREATE SPACE IF NOT EXISTS test_int(partition_num=10, vid_type=int64);" +
+ "USE test_int;" + "CREATE TAG IF NOT EXISTS person(col1 string, col2 fixed_string(8), col3 int8, col4 int16, col5 int32, col6 int64, col7 date, col8 datetime, col9 timestamp, col10 bool, col11 double, col12 float, col13 time);" +
+ "CREATE EDGE IF NOT EXISTS friend(col1 string, col2 fixed_string(8), col3 int8, col4 int16, col5 int32, col6 int64, col7 date, col8 datetime, col9 timestamp, col10 bool, col11 double, col12 float, col13 time);" +
+ "CREATE TAG IF NOT EXISTS geo_shape(geo geography);" +
+ "CREATE TAG IF NOT EXISTS tag_duration(col duration);"
+ val createResp = session.execute(createSpace)
+ if (!createResp.isSucceeded) {
+ close()
+ LOG.error("create int type space failed," + createResp.getErrorMessage)
+ sys.exit(-1)
+ }
+
+ Thread.sleep(10000)
+ val insertTag =
+ "INSERT VERTEX person(col1, col2, col3, col4, col5, col6, col7, col8, col9, col10, col11, col12, col13) VALUES " +
+ " 1:(\"person1\", \"person1\", 11, 200, 1000, 188888, date(\"2021-01-01\"), datetime(\"2021-01-01T12:00:00\"),timestamp(\"2021-01-01T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," +
+ " 2:(\"person2\", \"person2\", 12, 300, 2000, 288888, date(\"2021-01-02\"), datetime(\"2021-01-02T12:00:00\"),timestamp(\"2021-01-02T12:00:00\"), false, 1.0, 2.0, time(\"12:01:01\"))," +
+ " 3:(\"person3\", \"person3\", 13, 400, 3000, 388888, date(\"2021-01-03\"), datetime(\"2021-01-03T12:00:00\"),timestamp(\"2021-01-03T12:00:00\"), false, 1.0, 2.0, time(\"12:01:01\"))," +
+ " 4:(\"person4\", \"person4\", 14, 500, 4000, 488888, date(\"2021-01-04\"), datetime(\"2021-01-04T12:00:00\"),timestamp(\"2021-01-04T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," +
+ " 5:(\"person5\", \"person5\", 15, 600, 5000, 588888, date(\"2021-01-05\"), datetime(\"2021-01-05T12:00:00\"),timestamp(\"2021-01-05T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," +
+ " 6:(\"person6\", \"person6\", 16, 700, 6000, 688888, date(\"2021-01-06\"), datetime(\"2021-01-06T12:00:00\"),timestamp(\"2021-01-06T12:00:00\"), false, 1.0, 2.0, time(\"12:01:01\"))," +
+ " 7:(\"person7\", \"person7\", 17, 800, 7000, 788888, date(\"2021-01-07\"), datetime(\"2021-01-07T12:00:00\"),timestamp(\"2021-01-07T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," +
+ " 8:(\"person8\", \"person8\", 18, 900, 8000, 888888, date(\"2021-01-08\"), datetime(\"2021-01-08T12:00:00\"),timestamp(\"2021-01-08T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," +
+ " 9:(\"person9\", \"person9\", 19, 1000, 9000, 988888, date(\"2021-01-09\"), datetime(\"2021-01-09T12:00:00\"),timestamp(\"2021-01-09T12:00:00\"), false, 1.0, 2.0, time(\"12:01:01\"))," +
+ " 10:(\"person10\", \"person10\", 20, 1100, 10000, 1088888, date(\"2021-01-10\"), datetime(\"2021-01-10T12:00:00\"),timestamp(\"2021-01-10T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," +
+ " 11:(\"person11\", \"person11\", 21, 1200, 11000, 1188888, date(\"2021-01-11\"), datetime(\"2021-01-11T12:00:00\"),timestamp(\"2021-01-11T12:00:00\"), false, 1.0, 2.0, time(\"12:01:01\"))," +
+ " 12:(\"person12\", \"person11\", 22, 1300, 12000, 1288888, date(\"2021-01-12\"), datetime(\"2021-01-12T12:00:00\"),timestamp(\"2021-01-12T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," +
+ " -1:(\"person00\", \"person00\", 23, 1400, 13000, 1388888, date(\"2021-01-13\"), datetime(\"2021-01-12T12:00:00\"),timestamp(\"2021-01-12T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," +
+ " -2:(\"person01\", \"person01\", 24, 1500, 14000, 1488888, date(\"2021-01-14\"), datetime(\"2021-01-12T12:00:00\"),timestamp(\"2021-01-12T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," +
+ " -3:(\"person02\", \"person02\", 24, 1500, 14000, 1488888, date(\"2021-01-14\"), datetime(\"2021-01-12T12:00:00\"),timestamp(\"2021-01-12T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," +
+ " 19:(\"person19\", \"person22\", 25, 1500, 14000, 1488888, date(\"2021-01-14\"), datetime(\"2021-01-12T12:00:00\"),timestamp(\"2021-01-12T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," +
+ " 22:(\"person22\", \"person22\", 26, 1500, 14000, 1488888, date(\"2021-01-14\"), datetime(\"2021-01-12T12:00:00\"),timestamp(\"2021-01-12T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\")), " +
+ " 0:(null, null, null, null, null, null, null, null, null, null, null, null, null);" +
+ "INSERT VERTEX geo_shape(geo) VALUES 100:(ST_GeogFromText(\"POINT(1 2)\")), 101:(ST_GeogFromText(\"LINESTRING(1 2, 3 4)\")), 102:(ST_GeogFromText(\"POLYGON((0 1, 1 2, 2 3, 0 1))\"));" +
+ "INSERT VERTEX tag_duration(col) VALUES 200:(duration({months:1, seconds:100, microseconds:20}))"
+
+ val insertTagResp = session.execute(insertTag)
+ if (!insertTagResp.isSucceeded) {
+ close()
+ LOG.error("insert vertex for int type space failed," + insertTagResp.getErrorMessage)
+ sys.exit(-1)
+ }
+
+ val insertEdge = "INSERT EDGE friend(col1, col2, col3, col4, col5, col6, col7, col8, col9, col10, col11, col12, col13) VALUES " +
+ " 1 -> 2:(\"friend1\", \"friend2\", 11, 200, 1000, 188888, date(\"2021-01-01\"), datetime(\"2021-01-01T12:00:00\"),timestamp(\"2021-01-01T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," +
+ " 2 -> 3:(\"friend2\", \"friend3\", 12, 300, 2000, 288888, date(\"2021-01-02\"), datetime(\"2021-01-02T12:00:00\"),timestamp(\"2021-01-02T12:00:00\"), false, 1.0, 2.0, time(\"12:01:01\"))," +
+ " 3 -> 4:(\"friend3\", \"friend4\", 13, 400, 3000, 388888, date(\"2021-01-03\"), datetime(\"2021-01-03T12:00:00\"),timestamp(\"2021-01-03T12:00:00\"), false, 1.0, 2.0, time(\"12:01:01\"))," +
+ " 4 -> 5:(\"friend4\", \"friend4\", 14, 500, 4000, 488888, date(\"2021-01-04\"), datetime(\"2021-01-04T12:00:00\"),timestamp(\"2021-01-04T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," +
+ " 5 -> 6:(\"friend5\", \"friend5\", 15, 600, 5000, 588888, date(\"2021-01-05\"), datetime(\"2021-01-05T12:00:00\"),timestamp(\"2021-01-05T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," +
+ " 6 -> 7:(\"friend6\", \"friend6\", 16, 700, 6000, 688888, date(\"2021-01-06\"), datetime(\"2021-01-06T12:00:00\"),timestamp(\"2021-01-06T12:00:00\"), false, 1.0, 2.0, time(\"12:01:01\"))," +
+ " 7 -> 8:(\"friend7\", \"friend7\", 17, 800, 7000, 788888, date(\"2021-01-07\"), datetime(\"2021-01-07T12:00:00\"),timestamp(\"2021-01-07T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," +
+ " 8 -> 9:(\"friend8\", \"friend8\", 18, 900, 8000, 888888, date(\"2021-01-08\"), datetime(\"2021-01-08T12:00:00\"),timestamp(\"2021-01-08T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," +
+ " 9 -> 10:(\"friend9\", \"friend9\", 19, 1000, 9000, 988888, date(\"2021-01-09\"), datetime(\"2021-01-09T12:00:00\"),timestamp(\"2021-01-09T12:00:00\"), false, 1.0, 2.0, time(\"12:01:01\"))," +
+ " 10 -> 11:(\"friend10\", \"friend10\", 20, 1100, 10000, 1088888, date(\"2021-01-10\"), datetime(\"2021-01-10T12:00:00\"),timestamp(\"2021-01-10T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," +
+ " 11 -> 12:(\"friend11\", \"friend11\", 21, 1200, 11000, 1188888, date(\"2021-01-11\"), datetime(\"2021-01-11T12:00:00\"),timestamp(\"2021-01-11T12:00:00\"), false, 1.0, 2.0, time(\"12:01:01\"))," +
+ " 12 -> 1:(\"friend12\", \"friend11\", 22, 1300, 12000, 1288888, date(\"2021-01-12\"), datetime(\"2021-01-12T12:00:00\"),timestamp(\"2021-01-12T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))"
+ val insertEdgeResp = session.execute(insertEdge)
+ if (!insertEdgeResp.isSucceeded) {
+ close()
+ LOG.error("insert edge for int type space failed," + insertEdgeResp.getErrorMessage)
+ sys.exit(-1)
+ }
+ }
+
+ def mockStringIdGraphSchema(): Unit = {
+ val session = pool.getSession("root", "nebula", true)
+
+ val createSpace = "CREATE SPACE IF NOT EXISTS test_write_string(partition_num=10,vid_type=fixed_string(8));" +
+ "USE test_write_string;" +
+ "CREATE TAG IF NOT EXISTS person_connector(col1 string, col2 fixed_string(8), col3 int8, col4 int16, col5 int32, col6 int64, col7 date, col8 datetime, col9 timestamp, col10 bool, col11 double, col12 float, col13 time, col14 geography, col15 duration);" +
+ "CREATE EDGE IF NOT EXISTS friend_connector(col1 string, col2 fixed_string(8), col3 int8, col4 int16, col5 int32, col6 int64, col7 date, col8 datetime, col9 timestamp, col10 bool, col11 double, col12 float, col13 time, col14 geography);";
+ val createResp = session.execute(createSpace)
+ if (!createResp.isSucceeded) {
+ close()
+ LOG.error("create string type space failed," + createResp.getErrorMessage)
+ sys.exit(-1)
+ }
+ Thread.sleep(10000)
+ }
+
+ def mockIntIdGraphSchema(): Unit = {
+ val session = pool.getSession("root", "nebula", true)
+
+ val createSpace = "CREATE SPACE IF NOT EXISTS test_write_int(partition_num=10, vid_type=int64);" +
+ "USE test_write_int;" +
+ "CREATE TAG IF NOT EXISTS person_connector(col1 string, col2 fixed_string(8), col3 int8, col4 int16, col5 int32, col6 int64, col7 date, col8 datetime, col9 timestamp, col10 bool, col11 double, col12 float, col13 time, col14 geography, col15 duration);" +
+ "CREATE EDGE IF NOT EXISTS friend_connector(col1 string, col2 fixed_string(8), col3 int8, col4 int16, col5 int32, col6 int64, col7 date, col8 datetime, col9 timestamp, col10 bool, col11 double, col12 float, col13 time, col14 geography);";
+ val createResp = session.execute(createSpace)
+ if (!createResp.isSucceeded) {
+ close()
+ LOG.error("create int type space failed," + createResp.getErrorMessage)
+ sys.exit(-1)
+ }
+ Thread.sleep(10000)
+ }
+
+ def close(): Unit = {
+ pool.close()
+ }
+}
diff --git a/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/mock/SparkMock.scala b/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/mock/SparkMock.scala
new file mode 100644
index 00000000..a8eec279
--- /dev/null
+++ b/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/mock/SparkMock.scala
@@ -0,0 +1,219 @@
+/* Copyright (c) 2020 vesoft inc. All rights reserved.
+ *
+ * This source code is licensed under Apache 2.0 License.
+ */
+
+package com.vesoft.nebula.connector.mock
+
+import com.facebook.thrift.protocol.TCompactProtocol
+import com.vesoft.nebula.connector.{
+ NebulaConnectionConfig,
+ WriteMode,
+ WriteNebulaEdgeConfig,
+ WriteNebulaVertexConfig
+}
+import com.vesoft.nebula.connector.connector.NebulaDataFrameWriter
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.SparkSession
+
+object SparkMock {
+
+ /**
+   * write Nebula vertices in insert mode
+ */
+ def writeVertex(): Unit = {
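+    // the connector requires the Kryo serializer with Thrift's TCompactProtocol registered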
+ val sparkConf = new SparkConf
+ sparkConf
+ .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+ .registerKryoClasses(Array[Class[_]](classOf[TCompactProtocol]))
+ val spark = SparkSession
+ .builder()
+ .master("local")
+ .config(sparkConf)
+ .getOrCreate()
+
+ val df = spark.read
+ .option("header", true)
+ .csv("src/test/resources/vertex.csv")
+
+ val config =
+ NebulaConnectionConfig
+ .builder()
+ .withMetaAddress("127.0.0.1:9559")
+ .withGraphAddress("127.0.0.1:9669")
+ .withConenctionRetry(2)
+ .build()
+ val nebulaWriteVertexConfig: WriteNebulaVertexConfig = WriteNebulaVertexConfig
+ .builder()
+ .withSpace("test_write_string")
+ .withTag("person_connector")
+ .withVidField("id")
+ .withVidAsProp(false)
+ .withBatch(5)
+ .build()
+ df.write.nebula(config, nebulaWriteVertexConfig).writeVertices()
+
+ spark.stop()
+ }
+
+ /**
+   * write Nebula vertices in delete mode
+ */
+ def deleteVertex(): Unit = {
+ val sparkConf = new SparkConf
+ sparkConf
+ .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+ .registerKryoClasses(Array[Class[_]](classOf[TCompactProtocol]))
+ val spark = SparkSession
+ .builder()
+ .master("local")
+ .config(sparkConf)
+ .getOrCreate()
+
+ val df = spark.read
+ .option("header", true)
+ .csv("src/test/resources/vertex.csv")
+
+ val config =
+ NebulaConnectionConfig
+ .builder()
+ .withMetaAddress("127.0.0.1:9559")
+ .withGraphAddress("127.0.0.1:9669")
+ .withConenctionRetry(2)
+ .build()
+ val nebulaWriteVertexConfig: WriteNebulaVertexConfig = WriteNebulaVertexConfig
+ .builder()
+ .withSpace("test_write_string")
+ .withTag("person_connector")
+ .withVidField("id")
+ .withVidAsProp(false)
+ .withWriteMode(WriteMode.DELETE)
+ .withBatch(5)
+ .build()
+ df.write.nebula(config, nebulaWriteVertexConfig).writeVertices()
+
+ spark.stop()
+ }
+
+ /**
+   * write Nebula edges in insert mode
+ */
+ def writeEdge(): Unit = {
+ val sparkConf = new SparkConf
+ sparkConf
+ .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+ .registerKryoClasses(Array[Class[_]](classOf[TCompactProtocol]))
+ val spark = SparkSession
+ .builder()
+ .master("local")
+ .config(sparkConf)
+ .getOrCreate()
+
+ val df = spark.read
+ .option("header", true)
+ .csv("src/test/resources/edge.csv")
+
+ val config =
+ NebulaConnectionConfig
+ .builder()
+ .withMetaAddress("127.0.0.1:9559")
+ .withGraphAddress("127.0.0.1:9669")
+ .withConenctionRetry(2)
+ .build()
+ val nebulaWriteEdgeConfig: WriteNebulaEdgeConfig = WriteNebulaEdgeConfig
+ .builder()
+ .withSpace("test_write_string")
+ .withEdge("friend_connector")
+ .withSrcIdField("id1")
+ .withDstIdField("id2")
+ .withRankField("col3")
+ .withRankAsProperty(true)
+ .withBatch(5)
+ .build()
+ df.write.nebula(config, nebulaWriteEdgeConfig).writeEdges()
+
+ spark.stop()
+ }
+
+ /**
+   * write Nebula edges in delete mode
+ */
+ def deleteEdge(): Unit = {
+ val sparkConf = new SparkConf
+ sparkConf
+ .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+ .registerKryoClasses(Array[Class[_]](classOf[TCompactProtocol]))
+ val spark = SparkSession
+ .builder()
+ .master("local")
+ .config(sparkConf)
+ .getOrCreate()
+
+ val df = spark.read
+ .option("header", true)
+ .csv("src/test/resources/edge.csv")
+
+ val config =
+ NebulaConnectionConfig
+ .builder()
+ .withMetaAddress("127.0.0.1:9559")
+ .withGraphAddress("127.0.0.1:9669")
+ .withConenctionRetry(2)
+ .build()
+ val nebulaWriteEdgeConfig: WriteNebulaEdgeConfig = WriteNebulaEdgeConfig
+ .builder()
+ .withSpace("test_write_string")
+ .withEdge("friend_connector")
+ .withSrcIdField("id1")
+ .withDstIdField("id2")
+ .withRankField("col3")
+ .withRankAsProperty(true)
+ .withWriteMode(WriteMode.DELETE)
+ .withBatch(5)
+ .build()
+ df.write.nebula(config, nebulaWriteEdgeConfig).writeEdges()
+
+ spark.stop()
+ }
+
+ /**
+   * write Nebula vertices in delete mode, also deleting attached edges
+ */
+ def deleteVertexWithEdge(): Unit = {
+ val sparkConf = new SparkConf
+ sparkConf
+ .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+ .registerKryoClasses(Array[Class[_]](classOf[TCompactProtocol]))
+ val spark = SparkSession
+ .builder()
+ .master("local")
+ .config(sparkConf)
+ .getOrCreate()
+
+ val df = spark.read
+ .option("header", true)
+ .csv("src/test/resources/vertex.csv")
+
+ val config =
+ NebulaConnectionConfig
+ .builder()
+ .withMetaAddress("127.0.0.1:9559")
+ .withGraphAddress("127.0.0.1:9669")
+ .withConenctionRetry(2)
+ .build()
+ val nebulaWriteVertexConfig: WriteNebulaVertexConfig = WriteNebulaVertexConfig
+ .builder()
+ .withSpace("test_write_string")
+ .withTag("person_connector")
+ .withVidField("id")
+ .withVidAsProp(false)
+ .withWriteMode(WriteMode.DELETE)
+ .withDeleteEdge(true)
+ .withBatch(5)
+ .build()
+ df.write.nebula(config, nebulaWriteVertexConfig).writeVertices()
+
+ spark.stop()
+ }
+
+}
diff --git a/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/reader/ReadSuite.scala b/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/reader/ReadSuite.scala
new file mode 100644
index 00000000..39ee72f5
--- /dev/null
+++ b/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/reader/ReadSuite.scala
@@ -0,0 +1,340 @@
+/* Copyright (c) 2021 vesoft inc. All rights reserved.
+ *
+ * This source code is licensed under Apache 2.0 License.
+ */
+
+package com.vesoft.nebula.connector.reader
+
+import com.facebook.thrift.protocol.TCompactProtocol
+import com.vesoft.nebula.connector.connector.NebulaDataFrameReader
+import com.vesoft.nebula.connector.{NebulaConnectionConfig, ReadNebulaConfig}
+import com.vesoft.nebula.connector.mock.NebulaGraphMock
+import org.apache.log4j.BasicConfigurator
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.SparkSession
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.funsuite.AnyFunSuite
+
+class ReadSuite extends AnyFunSuite with BeforeAndAfterAll {
+ BasicConfigurator.configure()
+  var sparkSession: SparkSession = _
+
+ override def beforeAll(): Unit = {
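+    // seed the int-vid test space once; all reader tests run against it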
+ val graphMock = new NebulaGraphMock
+ graphMock.mockIntIdGraph()
+ graphMock.close()
+ val sparkConf = new SparkConf
+ sparkConf
+ .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+ .registerKryoClasses(Array[Class[_]](classOf[TCompactProtocol]))
+ sparkSession = SparkSession
+ .builder()
+ .master("local")
+ .config(sparkConf)
+ .getOrCreate()
+ }
+
+ override def afterAll(): Unit = {
+ sparkSession.stop()
+ }
+
+ test("read vertex with no properties") {
+ val config =
+ NebulaConnectionConfig
+ .builder()
+ .withMetaAddress("127.0.0.1:9559")
+ .withConenctionRetry(2)
+ .build()
+ val nebulaReadVertexConfig: ReadNebulaConfig = ReadNebulaConfig
+ .builder()
+ .withSpace("test_int")
+ .withLabel("person")
+ .withNoColumn(true)
+ .withLimit(10)
+ .withPartitionNum(10)
+ .build()
+ val vertex = sparkSession.read.nebula(config, nebulaReadVertexConfig).loadVerticesToDF()
+ vertex.printSchema()
+ vertex.show()
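+    // NebulaGraphMock seeds 18 person vertices: ids 1..12, -1..-3, 0, 19 and 22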
+ assert(vertex.count() == 18)
+ assert(vertex.schema.fields.length == 1)
+ }
+
+ test("read vertex with specific properties") {
+ val config =
+ NebulaConnectionConfig
+ .builder()
+ .withMetaAddress("127.0.0.1:9559")
+ .withConenctionRetry(2)
+ .build()
+ val nebulaReadVertexConfig: ReadNebulaConfig = ReadNebulaConfig
+ .builder()
+ .withSpace("test_int")
+ .withLabel("person")
+ .withNoColumn(false)
+ .withReturnCols(List("col1"))
+ .withLimit(10)
+ .withPartitionNum(10)
+ .build()
+ val vertex = sparkSession.read.nebula(config, nebulaReadVertexConfig).loadVerticesToDF()
+ vertex.printSchema()
+ vertex.show()
+ assert(vertex.count() == 18)
+ assert(vertex.schema.fields.length == 2)
+
+    vertex.collect().foreach(row => {
+ row.getAs[Long]("_vertexId") match {
+ case 1L => {
+ assert(row.getAs[String]("col1").equals("person1"))
+ }
+ case 0L => {
+ assert(row.isNullAt(1))
+ }
+        case _ => // other vertex ids are not asserted on
+      }
+    })
+
+ }
+
+ test("read vertex with all properties") {
+ val config =
+ NebulaConnectionConfig
+ .builder()
+ .withMetaAddress("127.0.0.1:9559")
+ .withConenctionRetry(2)
+ .build()
+ val nebulaReadVertexConfig: ReadNebulaConfig = ReadNebulaConfig
+ .builder()
+ .withSpace("test_int")
+ .withLabel("person")
+ .withNoColumn(false)
+ .withReturnCols(List())
+ .withLimit(10)
+ .withPartitionNum(10)
+ .build()
+ val vertex = sparkSession.read.nebula(config, nebulaReadVertexConfig).loadVerticesToDF()
+ vertex.printSchema()
+ vertex.show()
+ assert(vertex.count() == 18)
+ assert(vertex.schema.fields.length == 14)
+
+    vertex.collect().foreach(row => {
+ row.getAs[Long]("_vertexId") match {
+ case 1L => {
+ assert(row.getAs[String]("col1").equals("person1"))
+ assert(row.getAs[String]("col2").equals("person1"))
+ assert(row.getAs[Long]("col3") == 11)
+ assert(row.getAs[Long]("col4") == 200)
+ assert(row.getAs[Long]("col5") == 1000)
+ assert(row.getAs[Long]("col6") == 188888)
+ assert(row.getAs[String]("col7").equals("2021-01-01"))
+ assert(row.getAs[String]("col8").equals("2021-01-01T12:00:00.000"))
+ assert(row.getAs[Long]("col9") == 1609502400)
+ assert(row.getAs[Boolean]("col10"))
+ assert(row.getAs[Double]("col11") < 1.001)
+ assert(row.getAs[Double]("col12") < 2.001)
+ assert(row.getAs[String]("col13").equals("12:01:01"))
+ }
+ case 0L => {
+ assert(row.isNullAt(1))
+ assert(row.isNullAt(2))
+ assert(row.isNullAt(3))
+ assert(row.isNullAt(4))
+ assert(row.isNullAt(5))
+ assert(row.isNullAt(6))
+ assert(row.isNullAt(7))
+ assert(row.isNullAt(8))
+ assert(row.isNullAt(9))
+ assert(row.isNullAt(10))
+ assert(row.isNullAt(11))
+ assert(row.isNullAt(12))
+ assert(row.isNullAt(13))
+ }
+        case _ => // other vertex ids are not asserted on
+      }
+    })
+ }
+
+ test("read vertex for geo_shape") {
+ val config =
+ NebulaConnectionConfig
+ .builder()
+ .withMetaAddress("127.0.0.1:9559")
+ .withConenctionRetry(2)
+ .build()
+ val nebulaReadVertexConfig: ReadNebulaConfig = ReadNebulaConfig
+ .builder()
+ .withSpace("test_int")
+ .withLabel("geo_shape")
+ .withNoColumn(false)
+ .withReturnCols(List("geo"))
+ .withLimit(10)
+ .withPartitionNum(10)
+ .build()
+ val vertex = sparkSession.read.nebula(config, nebulaReadVertexConfig).loadVerticesToDF()
+ vertex.printSchema()
+ vertex.show()
+ assert(vertex.count() == 3)
+ assert(vertex.schema.fields.length == 2)
+
+    vertex.collect().foreach(row => {
+ row.getAs[Long]("_vertexId") match {
+ case 100L => {
+ assert(row.getAs[String]("geo").equals("POINT(1 2)"))
+ }
+ case 101L => {
+ assert(row.getAs[String]("geo").equals("LINESTRING(1 2, 3 4)"))
+ }
+ case 102L => {
+ assert(row.getAs[String]("geo").equals("POLYGON((0 1, 1 2, 2 3, 0 1))"))
+ }
+      }
+    })
+ }
+
+ test("read vertex for tag_duration") {
+ val config =
+ NebulaConnectionConfig
+ .builder()
+ .withMetaAddress("127.0.0.1:9559")
+ .withConenctionRetry(2)
+ .build()
+ val nebulaReadVertexConfig: ReadNebulaConfig = ReadNebulaConfig
+ .builder()
+ .withSpace("test_int")
+ .withLabel("tag_duration")
+ .withNoColumn(false)
+ .withReturnCols(List("col"))
+ .withLimit(10)
+ .withPartitionNum(10)
+ .build()
+ val vertex = sparkSession.read.nebula(config, nebulaReadVertexConfig).loadVerticesToDF()
+ vertex.printSchema()
+ vertex.show()
+ assert(vertex.count() == 1)
+ assert(vertex.schema.fields.length == 2)
+
+    vertex.collect().foreach(row => {
+ row.getAs[Long]("_vertexId") match {
+ case 200L => {
+ assert(
+ row.getAs[String]("col").equals("duration({months:1, seconds:100, microseconds:20})"))
+ }
+      }
+    })
+ }
+
+ test("read edge with no properties") {
+ val config =
+ NebulaConnectionConfig
+ .builder()
+ .withMetaAddress("127.0.0.1:9559")
+ .withConenctionRetry(2)
+ .build()
+ val nebulaReadEdgeConfig: ReadNebulaConfig = ReadNebulaConfig
+ .builder()
+ .withSpace("test_int")
+ .withLabel("friend")
+ .withNoColumn(true)
+ .withLimit(10)
+ .withPartitionNum(10)
+ .build()
+ val edge = sparkSession.read.nebula(config, nebulaReadEdgeConfig).loadEdgesToDF()
+ edge.printSchema()
+ edge.show()
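+    // NebulaGraphMock seeds 12 friend edges: 1->2 through 11->12, plus 12->1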
+ assert(edge.count() == 12)
+ assert(edge.schema.fields.length == 3)
+
+    edge.collect().foreach(row => {
+ row.getAs[Long]("_srcId") match {
+ case 1L => {
+ assert(row.getAs[Long]("_dstId") == 2)
+ assert(row.getAs[Long]("_rank") == 0)
+ }
+        case _ => // other source ids are not asserted on
+      }
+    })
+ }
+
+ test("read edge with specific properties") {
+ val config =
+ NebulaConnectionConfig
+ .builder()
+ .withMetaAddress("127.0.0.1:9559")
+ .withConenctionRetry(2)
+ .build()
+ val nebulaReadEdgeConfig: ReadNebulaConfig = ReadNebulaConfig
+ .builder()
+ .withSpace("test_int")
+ .withLabel("friend")
+ .withNoColumn(false)
+ .withReturnCols(List("col1"))
+ .withLimit(10)
+ .withPartitionNum(10)
+ .build()
+ val edge = sparkSession.read.nebula(config, nebulaReadEdgeConfig).loadEdgesToDF()
+ edge.printSchema()
+ edge.show(20)
+ assert(edge.count() == 12)
+ assert(edge.schema.fields.length == 4)
+    edge.collect().foreach(row => {
+ row.getAs[Long]("_srcId") match {
+ case 1L => {
+ assert(row.getAs[Long]("_dstId") == 2)
+ assert(row.getAs[Long]("_rank") == 0)
+ assert(row.getAs[String]("col1").equals("friend1"))
+ }
+        case _ => // other source ids are not asserted on
+      }
+    })
+ }
+
+ test("read edge with all properties") {
+ val config =
+ NebulaConnectionConfig
+ .builder()
+ .withMetaAddress("127.0.0.1:9559")
+ .withConenctionRetry(2)
+ .build()
+    val nebulaReadEdgeConfig: ReadNebulaConfig = ReadNebulaConfig
+ .builder()
+ .withSpace("test_int")
+ .withLabel("friend")
+ .withNoColumn(false)
+ .withReturnCols(List())
+ .withLimit(10)
+ .withPartitionNum(10)
+ .build()
+    val edge = sparkSession.read.nebula(config, nebulaReadEdgeConfig).loadEdgesToDF()
+ edge.printSchema()
+ edge.show()
+ assert(edge.count() == 12)
+ assert(edge.schema.fields.length == 16)
+
+    edge.collect().foreach(row => {
+ row.getAs[Long]("_srcId") match {
+ case 1L => {
+ assert(row.getAs[Long]("_dstId") == 2)
+ assert(row.getAs[Long]("_rank") == 0)
+ assert(row.getAs[String]("col1").equals("friend1"))
+ assert(row.getAs[String]("col2").equals("friend2"))
+ assert(row.getAs[Long]("col3") == 11)
+ assert(row.getAs[Long]("col4") == 200)
+ assert(row.getAs[Long]("col5") == 1000)
+ assert(row.getAs[Long]("col6") == 188888)
+ assert(row.getAs[String]("col7").equals("2021-01-01"))
+ assert(row.getAs[String]("col8").equals("2021-01-01T12:00:00.000"))
+ assert(row.getAs[Long]("col9") == 1609502400)
+ assert(row.getAs[Boolean]("col10"))
+ assert(row.getAs[Double]("col11") < 1.001)
+ assert(row.getAs[Double]("col12") < 2.001)
+ assert(row.getAs[String]("col13").equals("12:01:01"))
+ }
+        case _ => // other source ids are not asserted on
+      }
+    })
+ }
+
+}
diff --git a/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/writer/NebulaExecutorSuite.scala b/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/writer/NebulaExecutorSuite.scala
new file mode 100644
index 00000000..7a95f623
--- /dev/null
+++ b/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/writer/NebulaExecutorSuite.scala
@@ -0,0 +1,408 @@
+/* Copyright (c) 2020 vesoft inc. All rights reserved.
+ *
+ * This source code is licensed under Apache 2.0 License.
+ */
+
+package com.vesoft.nebula.connector.writer
+
+import com.vesoft.nebula.connector.{KeyPolicy, NebulaEdge, NebulaEdges, NebulaVertex, NebulaVertices}
+import org.apache.log4j.BasicConfigurator
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
+import org.apache.spark.sql.types.{
+ BooleanType,
+ DataTypes,
+ LongType,
+ StringType,
+ StructField,
+ StructType
+}
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.funsuite.AnyFunSuite
+
+import scala.collection.mutable.ListBuffer
+
+class NebulaExecutorSuite extends AnyFunSuite with BeforeAndAfterAll {
+ BasicConfigurator.configure()
+ var schema: StructType = _
+ var row: InternalRow = _
+
+ override def beforeAll(): Unit = {
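+    // shared fixture: a (string, boolean, long) schema and one matching row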
+ val fields = new ListBuffer[StructField]
+ fields.append(DataTypes.createStructField("col1", StringType, false))
+ fields.append(DataTypes.createStructField("col2", BooleanType, false))
+ fields.append(DataTypes.createStructField("col3", LongType, false))
+ schema = new StructType(fields.toArray)
+
+ val values = new ListBuffer[Any]
+ values.append("aaa")
+ values.append(true)
+ values.append(1L)
+ row = new GenericInternalRow(values.toArray)
+ }
+
+ override def afterAll(): Unit = super.afterAll()
+
+ test("test extraID") {
+ val index: Int = 1
+ val policy: Option[KeyPolicy.Value] = None
+ val isVidStringType: Boolean = true
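+    // index 1 selects the boolean col2; a string vid type wraps the value in quotes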
+ val stringId = NebulaExecutor.extraID(schema, row, index, policy, isVidStringType)
+ assert("\"true\"".equals(stringId))
+
+ // test hash vertexId
+ val hashId = NebulaExecutor.extraID(schema, row, index, Some(KeyPolicy.HASH), false)
+ assert("true".equals(hashId))
+ }
+
+ test("test extraRank") {
+ // test correct type for rank
+ assert(NebulaExecutor.extraRank(schema, row, 2) == 1)
+
+    // test wrong type for rank: a boolean column must be rejected
+    assertThrows[java.lang.AssertionError] {
+      NebulaExecutor.extraRank(schema, row, 1)
+    }
+ }
+
+ test("test vid as prop for assignVertexPropValues ") {
+ val fieldTypeMap: Map[String, Integer] = Map("col1" -> 6, "col2" -> 1, "col3" -> 2)
+ // test vid as prop
+ val props = NebulaExecutor.assignVertexPropValues(schema, row, 0, true, fieldTypeMap)
+ assert(props.size == 3)
+ assert(props.contains("\"aaa\""))
+ }
+
+ test("test vid not as prop for assignVertexPropValues ") {
+ val fieldTypeMap: Map[String, Integer] = Map("col1" -> 6, "col2" -> 1, "col3" -> 2)
+ // test vid not as prop
+ val props = NebulaExecutor.assignVertexPropValues(schema, row, 0, false, fieldTypeMap)
+ assert(props.size == 2)
+ assert(!props.contains("\"aaa\""))
+ }
+
+ test("test src & dst & rank all as prop for assignEdgeValues") {
+ val fieldTypeMap: Map[String, Integer] = Map("col1" -> 6, "col2" -> 1, "col3" -> 2)
+
+ val prop = NebulaExecutor.assignEdgeValues(schema, row, 0, 1, 2, true, true, true, fieldTypeMap)
+ assert(prop.size == 3)
+ }
+
+ test("test src & dst & rank all not as prop for assignEdgeValues") {
+ val fieldTypeMap: Map[String, Integer] = Map("col1" -> 6, "col2" -> 1, "col3" -> 2)
+
+ val prop =
+ NebulaExecutor.assignEdgeValues(schema, row, 0, 1, 2, false, false, false, fieldTypeMap)
+ assert(prop.isEmpty)
+ }
+
+ test("test assignVertexPropNames") {
+ // test vid as prop
+ val propNames = NebulaExecutor.assignVertexPropNames(schema, 0, true)
+ assert(propNames.size == 3)
+ assert(propNames.contains("col1") && propNames.contains("col2") && propNames.contains("col3"))
+
+ // test vid not as prop
+ val propNames1 = NebulaExecutor.assignVertexPropNames(schema, 0, false)
+ assert(propNames1.size == 2)
+ assert(!propNames1.contains("col1") && propNames.contains("col2") && propNames.contains("col3"))
+ }
+
+ test("test assignEdgePropNames") {
+ // test src / dst / rank all as prop
+ val propNames = NebulaExecutor.assignEdgePropNames(schema, 0, 1, 2, true, true, true)
+ assert(propNames.size == 3)
+
+ // test src / dst / rank all not as prop
+ val propNames1 = NebulaExecutor.assignEdgePropNames(schema, 0, 1, 2, false, false, false)
+ assert(propNames1.isEmpty)
+ }
+
+ test("test toExecuteSentence for vertex") {
+ val vertices: ListBuffer[NebulaVertex] = new ListBuffer[NebulaVertex]
+ val tagName = "person"
+ val propNames = List("col_string",
+ "col_fixed_string",
+ "col_bool",
+ "col_int",
+ "col_int64",
+ "col_double",
+ "col_date",
+ "col_geo")
+
+ val props1 = List("\"Tom\"", "\"Tom\"", true, 10, 100L, 1.0, "2021-11-12", "POINT(3 8)")
+ val props2 =
+ List("\"Bob\"", "\"Bob\"", false, 20, 200L, 2.0, "2021-05-01", "LINESTRING(1 2, 3 4)")
+ vertices.append(NebulaVertex("\"vid1\"", props1))
+ vertices.append(NebulaVertex("\"vid2\"", props2))
+
+ val nebulaVertices = NebulaVertices(propNames, vertices.toList, None)
+ val vertexStatement = NebulaExecutor.toExecuteSentence(tagName, nebulaVertices)
+
+ val expectStatement = "INSERT vertex `person`(`col_string`,`col_fixed_string`,`col_bool`," +
+ "`col_int`,`col_int64`,`col_double`,`col_date`,`col_geo`) VALUES \"vid1\": (" + props1
+ .mkString(", ") + "), \"vid2\": (" + props2.mkString(", ") + ")"
+ assert(expectStatement.equals(vertexStatement))
+ }
+
+ test("test toExecuteSentence for vertex with hash policy") {
+ val vertices: ListBuffer[NebulaVertex] = new ListBuffer[NebulaVertex]
+ val tagName = "person"
+ val propNames = List("col_string",
+ "col_fixed_string",
+ "col_bool",
+ "col_int",
+ "col_int64",
+ "col_double",
+ "col_date",
+ "col_geo")
+
+ val props1 = List("\"Tom\"", "\"Tom\"", true, 10, 100L, 1.0, "2021-11-12", "POINT(1 2)")
+ val props2 =
+ List("\"Bob\"", "\"Bob\"", false, 20, 200L, 2.0, "2021-05-01", "LINESTRING(1 2, 3 4)")
+ vertices.append(NebulaVertex("vid1", props1))
+ vertices.append(NebulaVertex("vid2", props2))
+
+ val nebulaVertices = NebulaVertices(propNames, vertices.toList, Some(KeyPolicy.HASH))
+ val vertexStatement = NebulaExecutor.toExecuteSentence(tagName, nebulaVertices)
+
+ val expectStatement = "INSERT vertex `person`(`col_string`,`col_fixed_string`,`col_bool`," +
+ "`col_int`,`col_int64`,`col_double`,`col_date`,`col_geo`) VALUES hash(\"vid1\"): (" + props1
+ .mkString(", ") +
+ "), hash(\"vid2\"): (" + props2.mkString(", ") + ")"
+ assert(expectStatement.equals(vertexStatement))
+ }
+
+ test("test toExecuteSentence for edge") {
+ val edges: ListBuffer[NebulaEdge] = new ListBuffer[NebulaEdge]
+ val edgeName = "friend"
+ val propNames = List("col_string",
+ "col_fixed_string",
+ "col_bool",
+ "col_int",
+ "col_int64",
+ "col_double",
+ "col_date",
+ "col_geo")
+ val props1 =
+ List("\"Tom\"", "\"Tom\"", true, 10, 100L, 1.0, "2021-11-12", "POLYGON((0 1, 1 2, 2 3, 0 1))")
+ val props2 = List("\"Bob\"",
+ "\"Bob\"",
+ false,
+ 20,
+ 200L,
+ 2.0,
+ "2021-05-01",
+ "POLYGON((0 1, 1 2, 2 3, 0 1))")
+ edges.append(NebulaEdge("\"vid1\"", "\"vid2\"", Some(1L), props1))
+ edges.append(NebulaEdge("\"vid2\"", "\"vid1\"", Some(2L), props2))
+
+ val nebulaEdges = NebulaEdges(propNames, edges.toList, None, None)
+ val edgeStatement = NebulaExecutor.toExecuteSentence(edgeName, nebulaEdges)
+
+ val expectStatement = "INSERT edge `friend`(`col_string`,`col_fixed_string`,`col_bool`,`col_int`" +
+ ",`col_int64`,`col_double`,`col_date`,`col_geo`) VALUES \"vid1\"->\"vid2\"@1: (" +
+ props1.mkString(", ") + "), \"vid2\"->\"vid1\"@2: (" + props2.mkString(", ") + ")"
+ assert(expectStatement.equals(edgeStatement))
+ }
+
+ test("test toUpdateExecuteSentence for vertex") {
+ val vertices: ListBuffer[NebulaVertex] = new ListBuffer[NebulaVertex]
+ val propNames = List("col_string",
+ "col_fixed_string",
+ "col_bool",
+ "col_int",
+ "col_int64",
+ "col_double",
+ "col_date",
+ "col_geo")
+
+ val props1 =
+ List("\"name\"", "\"name\"", true, 10, 100L, 1.0, "2021-11-12", "LINESTRING(1 2, 3 4)")
+ val props2 = List("\"name2\"",
+ "\"name2\"",
+ false,
+ 11,
+ 101L,
+ 2.0,
+ "2021-11-13",
+ "POLYGON((0 1, 1 2, 2 3, 0 1))")
+
+ vertices.append(NebulaVertex("\"vid1\"", props1))
+ vertices.append(NebulaVertex("\"vid2\"", props2))
+ val nebulaVertices = NebulaVertices(propNames, vertices.toList, None)
+
+ val updateVertexStatement =
+ NebulaExecutor.toUpdateExecuteStatement("person", nebulaVertices)
+
+ val expectVertexUpdate =
+ "UPDATE VERTEX ON `person` \"vid1\" SET `col_string`=\"name\",`col_fixed_string`=\"name\"," +
+ "`col_bool`=true,`col_int`=10,`col_int64`=100,`col_double`=1.0,`col_date`=2021-11-12," +
+ "`col_geo`=LINESTRING(1 2, 3 4);UPDATE VERTEX ON `person` \"vid2\" SET " +
+ "`col_string`=\"name2\",`col_fixed_string`=\"name2\",`col_bool`=false,`col_int`=11," +
+ "`col_int64`=101,`col_double`=2.0,`col_date`=2021-11-13," +
+ "`col_geo`=POLYGON((0 1, 1 2, 2 3, 0 1))"
+ assert(expectVertexUpdate.equals(updateVertexStatement))
+ }
+
+ test("test toUpdateExecuteSentence for vertex with hash policy") {
+ val vertices: ListBuffer[NebulaVertex] = new ListBuffer[NebulaVertex]
+ val propNames = List("col_string",
+ "col_fixed_string",
+ "col_bool",
+ "col_int",
+ "col_int64",
+ "col_double",
+ "col_date")
+
+ val props1 = List("\"name\"", "\"name\"", true, 10, 100L, 1.0, "2021-11-12")
+ val props2 = List("\"name2\"", "\"name2\"", false, 11, 101L, 2.0, "2021-11-13")
+
+ vertices.append(NebulaVertex("vid1", props1))
+ vertices.append(NebulaVertex("vid2", props2))
+ val nebulaVertices = NebulaVertices(propNames, vertices.toList, Some(KeyPolicy.HASH))
+
+ val updateVertexStatement =
+ NebulaExecutor.toUpdateExecuteStatement("person", nebulaVertices)
+ val expectVertexUpdate =
+ "UPDATE VERTEX ON `person` hash(\"vid1\") SET `col_string`=\"name\",`col_fixed_string`=\"name\"," +
+ "`col_bool`=true,`col_int`=10,`col_int64`=100,`col_double`=1.0,`col_date`=2021-11-12;" +
+ "UPDATE VERTEX ON `person` hash(\"vid2\") SET `col_string`=\"name2\",`col_fixed_string`=\"name2\"," +
+ "`col_bool`=false,`col_int`=11,`col_int64`=101,`col_double`=2.0,`col_date`=2021-11-13"
+ assert(expectVertexUpdate.equals(updateVertexStatement))
+ }
+
+ test("test toUpdateExecuteSentence for edge") {
+ val edges: ListBuffer[NebulaEdge] = new ListBuffer[NebulaEdge]
+ val propNames = List("col_string",
+ "col_fixed_string",
+ "col_bool",
+ "col_int",
+ "col_int64",
+ "col_double",
+ "col_date",
+ "col_geo")
+ val props1 = List("\"Tom\"", "\"Tom\"", true, 10, 100L, 1.0, "2021-11-12", "POINT(1 2)")
+ val props2 = List("\"Bob\"", "\"Bob\"", false, 20, 200L, 2.0, "2021-05-01", "POINT(2 3)")
+ edges.append(NebulaEdge("\"vid1\"", "\"vid2\"", Some(1L), props1))
+ edges.append(NebulaEdge("\"vid2\"", "\"vid1\"", Some(2L), props2))
+
+ val nebulaEdges = NebulaEdges(propNames, edges.toList, None, None)
+ val updateEdgeStatement = NebulaExecutor.toUpdateExecuteStatement("friend", nebulaEdges)
+ val expectEdgeUpdate =
+ "UPDATE EDGE ON `friend` \"vid1\"->\"vid2\"@1 SET `col_string`=\"Tom\"," +
+ "`col_fixed_string`=\"Tom\",`col_bool`=true,`col_int`=10,`col_int64`=100," +
+ "`col_double`=1.0,`col_date`=2021-11-12,`col_geo`=POINT(1 2);" +
+ "UPDATE EDGE ON `friend` \"vid2\"->\"vid1\"@2 SET `col_string`=\"Bob\"," +
+ "`col_fixed_string`=\"Bob\",`col_bool`=false,`col_int`=20,`col_int64`=200," +
+ "`col_double`=2.0,`col_date`=2021-05-01,`col_geo`=POINT(2 3)"
+ assert(expectEdgeUpdate.equals(updateEdgeStatement))
+ }
+
+ test("test toUpdateExecuteSentence for edge with hash policy") {
+ val edges: ListBuffer[NebulaEdge] = new ListBuffer[NebulaEdge]
+ val propNames = List("col_string",
+ "col_fixed_string",
+ "col_bool",
+ "col_int",
+ "col_int64",
+ "col_double",
+ "col_date")
+ val props1 = List("\"Tom\"", "\"Tom\"", true, 10, 100L, 1.0, "2021-11-12")
+ val props2 = List("\"Bob\"", "\"Bob\"", false, 20, 200L, 2.0, "2021-05-01")
+ edges.append(NebulaEdge("vid1", "vid2", Some(1L), props1))
+ edges.append(NebulaEdge("vid2", "vid1", Some(2L), props2))
+
+ val nebulaEdges =
+ NebulaEdges(propNames, edges.toList, Some(KeyPolicy.HASH), Some(KeyPolicy.HASH))
+ val updateEdgeStatement = NebulaExecutor.toUpdateExecuteStatement("friend", nebulaEdges)
+ val expectEdgeUpdate =
+ "UPDATE EDGE ON `friend` hash(\"vid1\")->hash(\"vid2\")@1 SET `col_string`=\"Tom\"," +
+ "`col_fixed_string`=\"Tom\",`col_bool`=true,`col_int`=10,`col_int64`=100," +
+ "`col_double`=1.0,`col_date`=2021-11-12;" +
+ "UPDATE EDGE ON `friend` hash(\"vid2\")->hash(\"vid1\")@2 SET `col_string`=\"Bob\"," +
+ "`col_fixed_string`=\"Bob\",`col_bool`=false,`col_int`=20,`col_int64`=200," +
+ "`col_double`=2.0,`col_date`=2021-05-01"
+ assert(expectEdgeUpdate.equals(updateEdgeStatement))
+ }
+
+ test("test toDeleteExecuteStatement for vertex") {
+ val vertices: ListBuffer[NebulaVertex] = new ListBuffer[NebulaVertex]
+ vertices.append(NebulaVertex("\"vid1\"", List()))
+ vertices.append(NebulaVertex("\"vid2\"", List()))
+
+ val nebulaVertices = NebulaVertices(List(), vertices.toList, None)
+ val vertexStatement = NebulaExecutor.toDeleteExecuteStatement(nebulaVertices, false)
+ val expectVertexDeleteStatement = "DELETE VERTEX \"vid1\",\"vid2\""
+ assert(expectVertexDeleteStatement.equals(vertexStatement))
+
+ val vertexWithEdgeStatement = NebulaExecutor.toDeleteExecuteStatement(nebulaVertices, true)
+ val expectVertexWithEdgeDeleteStatement = "DELETE VERTEX \"vid1\",\"vid2\" WITH EDGE"
+ assert(expectVertexWithEdgeDeleteStatement.equals(vertexWithEdgeStatement))
+ }
+
+ test("test toDeleteExecuteStatement for vertex with HASH policy") {
+ val vertices: ListBuffer[NebulaVertex] = new ListBuffer[NebulaVertex]
+ vertices.append(NebulaVertex("vid1", List()))
+ vertices.append(NebulaVertex("vid2", List()))
+
+ val nebulaVertices = NebulaVertices(List(), vertices.toList, Some(KeyPolicy.HASH))
+ val vertexStatement = NebulaExecutor.toDeleteExecuteStatement(nebulaVertices, false)
+ val expectVertexDeleteStatement = "DELETE VERTEX hash(\"vid1\"),hash(\"vid2\")"
+ assert(expectVertexDeleteStatement.equals(vertexStatement))
+
+ val vertexWithEdgeStatement = NebulaExecutor.toDeleteExecuteStatement(nebulaVertices, true)
+ val expectVertexWithEdgeDeleteStatement =
+ "DELETE VERTEX hash(\"vid1\"),hash(\"vid2\") WITH EDGE"
+ assert(expectVertexWithEdgeDeleteStatement.equals(vertexWithEdgeStatement))
+ }
+
+ test("test toDeleteExecuteStatement for edge") {
+ val edges: ListBuffer[NebulaEdge] = new ListBuffer[NebulaEdge]
+ edges.append(NebulaEdge("\"vid1\"", "\"vid2\"", Some(1L), List()))
+ edges.append(NebulaEdge("\"vid2\"", "\"vid1\"", Some(2L), List()))
+
+ val nebulaEdges = NebulaEdges(List(), edges.toList, None, None)
+ val edgeStatement = NebulaExecutor.toDeleteExecuteStatement("friend", nebulaEdges)
+ val expectEdgeDeleteStatement = "DELETE EDGE `friend` \"vid1\"->\"vid2\"@1,\"vid2\"->\"vid1\"@2"
+ assert(expectEdgeDeleteStatement.equals(edgeStatement))
+ }
+
+ test("test toDeleteExecuteStatement for edge without rank") {
+ val edges: ListBuffer[NebulaEdge] = new ListBuffer[NebulaEdge]
+ edges.append(NebulaEdge("\"vid1\"", "\"vid2\"", Option.empty, List()))
+ edges.append(NebulaEdge("\"vid2\"", "\"vid1\"", Option.empty, List()))
+
+ val nebulaEdges = NebulaEdges(List(), edges.toList, None, None)
+ val edgeStatement = NebulaExecutor.toDeleteExecuteStatement("friend", nebulaEdges)
+ val expectEdgeDeleteStatement = "DELETE EDGE `friend` \"vid1\"->\"vid2\"@0,\"vid2\"->\"vid1\"@0"
+ assert(expectEdgeDeleteStatement.equals(edgeStatement))
+ }
+
+ test("test toDeleteExecuteStatement for edge with src HASH policy") {
+ val edges: ListBuffer[NebulaEdge] = new ListBuffer[NebulaEdge]
+ edges.append(NebulaEdge("vid1", "\"vid2\"", Some(1L), List()))
+ edges.append(NebulaEdge("vid2", "\"vid1\"", Some(2L), List()))
+
+ val nebulaEdges = NebulaEdges(List(), edges.toList, Some(KeyPolicy.HASH), None)
+ val edgeStatement = NebulaExecutor.toDeleteExecuteStatement("friend", nebulaEdges)
+ val expectEdgeDeleteStatement =
+ "DELETE EDGE `friend` hash(\"vid1\")->\"vid2\"@1,hash(\"vid2\")->\"vid1\"@2"
+ assert(expectEdgeDeleteStatement.equals(edgeStatement))
+ }
+
+ test("test toDeleteExecuteStatement for edge with all HASH policy") {
+ val edges: ListBuffer[NebulaEdge] = new ListBuffer[NebulaEdge]
+ edges.append(NebulaEdge("vid1", "vid2", Some(1L), List()))
+ edges.append(NebulaEdge("vid2", "vid1", Some(2L), List()))
+
+ val nebulaEdges = NebulaEdges(List(), edges.toList, Some(KeyPolicy.HASH), Some(KeyPolicy.HASH))
+ val edgeStatement = NebulaExecutor.toDeleteExecuteStatement("friend", nebulaEdges)
+ val expectEdgeDeleteStatement =
+ "DELETE EDGE `friend` hash(\"vid1\")->hash(\"vid2\")@1,hash(\"vid2\")->hash(\"vid1\")@2"
+ assert(expectEdgeDeleteStatement.equals(edgeStatement))
+ }
+}
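
Note: every toDeleteExecuteStatement assertion above pins down the same statement shape. As a compact reference, here is a hedged sketch of that shape (expectedDeleteEdge is our own helper for illustration, not the connector's implementation):

    object DeleteEdgeSketch {
      // Sketch only: rebuilds the DELETE EDGE text the assertions above expect.
      // With a HASH key policy the id is wrapped in hash("..."); without one the
      // tests pass pre-quoted ids (e.g. "\"vid1\""); a missing rank defaults to 0.
      def expectedDeleteEdge(edgeType: String,
                             edges: Seq[(String, String, Option[Long])],
                             srcHash: Boolean,
                             dstHash: Boolean): String = {
        def id(v: String, hashed: Boolean): String =
          if (hashed) s"""hash("$v")""" else v
        val triples = edges
          .map { case (src, dst, rank) =>
            s"${id(src, srcHash)}->${id(dst, dstHash)}@${rank.getOrElse(0L)}"
          }
          .mkString(",")
        s"DELETE EDGE `$edgeType` $triples"
      }
    }

For instance, expectedDeleteEdge("friend", Seq(("vid1", "vid2", Some(1L)), ("vid2", "vid1", Some(2L))), srcHash = true, dstHash = true) reproduces the expected string of the all-HASH test above.
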
diff --git a/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/writer/WriteDeleteSuite.scala b/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/writer/WriteDeleteSuite.scala
new file mode 100644
index 00000000..bab62143
--- /dev/null
+++ b/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/writer/WriteDeleteSuite.scala
@@ -0,0 +1,82 @@
+/* Copyright (c) 2021 vesoft inc. All rights reserved.
+ *
+ * This source code is licensed under Apache 2.0 License.
+ */
+
+package com.vesoft.nebula.connector.writer
+
+import com.vesoft.nebula.client.graph.data.ResultSet
+import com.vesoft.nebula.connector.Address
+import com.vesoft.nebula.connector.mock.{NebulaGraphMock, SparkMock}
+import com.vesoft.nebula.connector.nebula.GraphProvider
+import org.apache.log4j.BasicConfigurator
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.funsuite.AnyFunSuite
+
+class WriteDeleteSuite extends AnyFunSuite with BeforeAndAfterAll {
+ BasicConfigurator.configure()
+
+ override def beforeAll(): Unit = {
+ val graphMock = new NebulaGraphMock
+ graphMock.mockStringIdGraphSchema()
+ graphMock.mockIntIdGraphSchema()
+ graphMock.close()
+ // wait for the mocked schema to take effect before writing
+ Thread.sleep(10000)
+ SparkMock.writeVertex()
+ SparkMock.writeEdge()
+ }
+
+ test("write vertex into test_write_string space with delete mode") {
+ SparkMock.deleteVertex()
+ val addresses: List[Address] = List(new Address("127.0.0.1", 9669))
+ val graphProvider = new GraphProvider(addresses, 3000)
+
+ graphProvider.switchSpace("root", "nebula", "test_write_string")
+ val resultSet: ResultSet =
+ graphProvider.submit("use test_write_string;"
+ + "match (v:person_connector) return v limit 100000;")
+ assert(resultSet.isSucceeded)
+ assert(resultSet.getColumnNames.size() == 1)
+ assert(resultSet.isEmpty)
+ }
+
+ test("write vertex into test_write_with_edge_string space with delete with edge mode") {
+ SparkMock.writeVertex()
+ SparkMock.writeEdge()
+ SparkMock.deleteVertexWithEdge()
+ val addresses: List[Address] = List(new Address("127.0.0.1", 9669))
+ val graphProvider = new GraphProvider(addresses, 3000)
+
+ graphProvider.switchSpace("root", "nebula", "test_write_string")
+ // assert vertex is deleted
+ val vertexResultSet: ResultSet =
+ graphProvider.submit("use test_write_string;"
+ + "match (v:person_connector) return v limit 1000000;")
+ assert(vertexResultSet.isSucceeded)
+ assert(vertexResultSet.getColumnNames.size() == 1)
+ assert(vertexResultSet.isEmpty)
+
+ // assert edge is deleted
+ val edgeResultSet: ResultSet =
+ graphProvider.submit("use test_write_string;"
+ + "fetch prop on friend_connector \"1\"->\"2\"@10 yield edge as e")
+ assert(edgeResultSet.isSucceeded)
+ assert(edgeResultSet.getColumnNames.size() == 1)
+ assert(edgeResultSet.isEmpty)
+
+ }
+
+ test("write edge into test_write_string space with delete mode") {
+ SparkMock.deleteEdge()
+ val addresses: List[Address] = List(new Address("127.0.0.1", 9669))
+ val graphProvider = new GraphProvider(addresses, 3000)
+
+ graphProvider.switchSpace("root", "nebula", "test_write_string")
+ val resultSet: ResultSet =
+ graphProvider.submit("use test_write_string;"
+ + "fetch prop on friend_connector \"1\"->\"2\"@10 yield edge as e;")
+ assert(resultSet.isSucceeded)
+ assert(resultSet.getColumnNames.size() == 1)
+ assert(resultSet.isEmpty)
+ }
+}
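
The SparkMock helpers keep the writer setup out of the suite. As a hedged sketch of what deleteVertexWithEdge presumably configures (builder option names, in particular withDeleteEdge, are assumptions tied to the WITH EDGE behavior tested above; the addresses match the test defaults):

    import com.vesoft.nebula.connector.connector._
    import com.vesoft.nebula.connector.{NebulaConnectionConfig, WriteMode, WriteNebulaVertexConfig}
    import org.apache.spark.sql.SparkSession

    object DeleteWithEdgeSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local").getOrCreate()
        // ids of the vertices to delete; only the vid column matters in delete mode
        val df = spark.createDataFrame(Seq(("1", "a"), ("2", "b"))).toDF("id", "col1")

        val connConfig = NebulaConnectionConfig
          .builder()
          .withMetaAddress("127.0.0.1:9559")
          .withGraphAddress("127.0.0.1:9669")
          .build()
        val writeConfig = WriteNebulaVertexConfig
          .builder()
          .withSpace("test_write_string")
          .withTag("person_connector")
          .withVidField("id")
          .withWriteMode(WriteMode.DELETE)
          .withDeleteEdge(true) // assumed option name for the cascading WITH EDGE delete
          .build()
        df.write.nebula(connConfig, writeConfig).writeVertices()
        spark.stop()
      }
    }
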
diff --git a/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/writer/WriteInsertSuite.scala b/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/writer/WriteInsertSuite.scala
new file mode 100644
index 00000000..8856aa29
--- /dev/null
+++ b/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/writer/WriteInsertSuite.scala
@@ -0,0 +1,71 @@
+/* Copyright (c) 2021 vesoft inc. All rights reserved.
+ *
+ * This source code is licensed under Apache 2.0 License.
+ */
+
+package com.vesoft.nebula.connector.writer
+
+import com.vesoft.nebula.client.graph.data.ResultSet
+import com.vesoft.nebula.connector.Address
+import com.vesoft.nebula.connector.mock.{NebulaGraphMock, SparkMock}
+import com.vesoft.nebula.connector.nebula.GraphProvider
+import org.apache.log4j.BasicConfigurator
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.funsuite.AnyFunSuite
+
+class WriteInsertSuite extends AnyFunSuite with BeforeAndAfterAll {
+ BasicConfigurator.configure()
+
+ override def beforeAll(): Unit = {
+ val graphMock = new NebulaGraphMock
+ graphMock.mockStringIdGraphSchema()
+ graphMock.mockIntIdGraphSchema()
+ graphMock.close()
+ // wait for the mocked schema to take effect before the tests write data
+ Thread.sleep(10000)
+ }
+
+ test("write vertex into test_write_string space with insert mode") {
+ SparkMock.writeVertex()
+ val addresses: List[Address] = List(new Address("127.0.0.1", 9669))
+ val graphProvider = new GraphProvider(addresses, 3000)
+
+ graphProvider.switchSpace("root", "nebula", "test_write_string")
+ graphProvider.submit(
+ "use test_write_string; "
+ + "create tag index if not exists person_index on person_connector(col1(20));")
+ // wait for the index to be created before rebuilding it
+ Thread.sleep(5000)
+ graphProvider.submit("rebuild tag index person_index;")
+
+ // wait for the rebuild job to finish so MATCH can use the index
+ Thread.sleep(5000)
+
+ graphProvider.submit("use test_write_string;")
+ val resultSet: ResultSet =
+ graphProvider.submit("match (v:person_connector) return v;")
+ assert(resultSet.isSucceeded)
+ assert(resultSet.getColumnNames.size() == 1)
+ assert(resultSet.getRows.size() == 13)
+ }
+
+ test("write edge into test_write_string space with insert mode") {
+ SparkMock.writeEdge()
+
+ val addresses: List[Address] = List(new Address("127.0.0.1", 9669))
+ val graphProvider = new GraphProvider(addresses, 3000)
+
+ graphProvider.switchSpace("root", "nebula", "test_write_string")
+ graphProvider.submit(
+ "use test_write_string; "
+ + "create edge index if not exists friend_index on friend_connector(col1(20));")
+ // wait for the index to be created before rebuilding it
+ Thread.sleep(5000)
+ graphProvider.submit("rebuild edge index friend_index;")
+
+ // wait for the rebuild job to finish so MATCH can use the index
+ Thread.sleep(5000)
+
+ graphProvider.submit("use test_write_string;")
+ val resultSet: ResultSet =
+ graphProvider.submit("match (v:person_connector)-[e:friend_connector]-> () return e;")
+ assert(resultSet.isSucceeded)
+ assert(resultSet.getColumnNames.size() == 1)
+ assert(resultSet.getRows.size() == 13)
+ }
+}
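
Both insert tests gate on fixed Thread.sleep calls after creating and rebuilding an index, which makes them timing-sensitive on slow CI machines. A hedged alternative, meant to live inside the suite as a private helper (awaitRows is our name, not repo API), polls the read-back query until it returns the expected row count:

    // Retry a statement until it yields the expected number of rows, bounded
    // by a timeout, instead of sleeping for a fixed interval.
    private def awaitRows(graphProvider: GraphProvider,
                          statement: String,
                          expectedRows: Int,
                          timeoutMs: Long = 30000L): Boolean = {
      val deadline = System.currentTimeMillis() + timeoutMs
      var satisfied = false
      while (!satisfied && System.currentTimeMillis() < deadline) {
        val rs = graphProvider.submit(statement)
        satisfied = rs.isSucceeded && rs.getRows.size() == expectedRows
        if (!satisfied) Thread.sleep(500) // brief back-off between polls
      }
      satisfied
    }

The second sleep in each test could then become assert(awaitRows(graphProvider, "match (v:person_connector) return v;", 13)).
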
diff --git a/pom.xml b/pom.xml
index e0c2e1df..cbc42452 100644
--- a/pom.xml
+++ b/pom.xml
@@ -11,6 +11,9 @@
         <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
+        <spark.version>2.4.4</spark.version>
+        <scala.binary.version>2.12</scala.binary.version>
+        <scala.version>2.12.10</scala.version>
         <scalatest.version>3.2.3</scalatest.version>
@@ -48,6 +51,7 @@
         <module>nebula-spark-connector</module>
         <module>nebula-spark-connector_2.2</module>
+        <module>nebula-spark-connector_3.0</module>
         <module>example</module>
         <module>nebula-spark-common</module>
@@ -126,6 +130,38 @@
+        <profile>
+            <id>scala-2.11</id>
+            <properties>
+                <scala.version>2.11.12</scala.version>
+                <scala.binary.version>2.11</scala.binary.version>
+            </properties>
+        </profile>
+        <profile>
+            <id>scala-2.12</id>
+            <properties>
+                <scala.version>2.12.10</scala.version>
+                <scala.binary.version>2.12</scala.binary.version>
+            </properties>
+        </profile>
+        <profile>
+            <id>spark-2.2</id>
+            <properties>
+                <spark.version>2.2.0</spark.version>
+            </properties>
+        </profile>
+        <profile>
+            <id>spark-2.4</id>
+            <properties>
+                <spark.version>2.4.4</spark.version>
+            </properties>
+        </profile>
+        <profile>
+            <id>spark-3.0</id>
+            <properties>
+                <spark.version>3.0.0</spark.version>
+            </properties>
+        </profile>