From 9da30d8300f5900da134d55352891107dc7a7133 Mon Sep 17 00:00:00 2001 From: Anqi Date: Mon, 14 Nov 2022 16:26:14 +0800 Subject: [PATCH 01/13] add connector for spark3.0 --- nebula-spark-connector_3.0/.gitignore | 36 ++ nebula-spark-connector_3.0/pom.xml | 292 +++++++++++++ .../nebula/connector/NebulaDataSource.scala | 149 +++++++ .../vesoft/nebula/connector/NebulaTable.scala | 70 +++ .../com/vesoft/nebula/connector/package.scala | 279 ++++++++++++ .../reader/NebulaEdgePartitionReader.scala | 76 ++++ .../connector/reader/NebulaPartition.scala | 10 + .../reader/NebulaPartitionReader.scala | 162 +++++++ .../reader/NebulaPartitionReaderFactory.scala | 25 ++ .../reader/NebulaVertexPartitionReader.scala | 79 ++++ .../connector/reader/SimpleScanBuilder.scala | 73 ++++ .../nebula/connector/utils/Validations.scala | 18 + .../writer/NebulaCommitMessage.scala | 10 + .../connector/writer/NebulaEdgeWriter.scala | 115 +++++ .../connector/writer/NebulaSourceWriter.scala | 105 +++++ .../connector/writer/NebulaVertexWriter.scala | 99 +++++ .../connector/writer/NebulaWriter.scala | 62 +++ .../writer/NebulaWriterBuilder.scala | 92 ++++ .../src/test/resources/docker-compose.yaml | 362 ++++++++++++++++ .../src/test/resources/edge.csv | 14 + .../src/test/resources/log4j.properties | 6 + .../src/test/resources/vertex.csv | 14 + .../connector/mock/NebulaGraphMock.scala | 192 +++++++++ .../nebula/connector/mock/SparkMock.scala | 219 ++++++++++ .../nebula/connector/reader/ReadSuite.scala | 340 +++++++++++++++ .../writer/NebulaExecutorSuite.scala | 408 ++++++++++++++++++ .../connector/writer/WriteDeleteSuite.scala | 82 ++++ .../connector/writer/WriteInsertSuite.scala | 71 +++ pom.xml | 1 + 29 files changed, 3461 insertions(+) create mode 100644 nebula-spark-connector_3.0/.gitignore create mode 100644 nebula-spark-connector_3.0/pom.xml create mode 100644 nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala create mode 100644 nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/NebulaTable.scala create mode 100644 nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/package.scala create mode 100644 nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaEdgePartitionReader.scala create mode 100644 nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartition.scala create mode 100644 nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartitionReader.scala create mode 100644 nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartitionReaderFactory.scala create mode 100644 nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaVertexPartitionReader.scala create mode 100644 nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/SimpleScanBuilder.scala create mode 100644 nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/utils/Validations.scala create mode 100644 nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaCommitMessage.scala create mode 100644 nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaEdgeWriter.scala create mode 100644 nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaSourceWriter.scala create mode 100644 nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaVertexWriter.scala create mode 100644 
nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaWriter.scala create mode 100644 nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaWriterBuilder.scala create mode 100644 nebula-spark-connector_3.0/src/test/resources/docker-compose.yaml create mode 100644 nebula-spark-connector_3.0/src/test/resources/edge.csv create mode 100644 nebula-spark-connector_3.0/src/test/resources/log4j.properties create mode 100644 nebula-spark-connector_3.0/src/test/resources/vertex.csv create mode 100644 nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/mock/NebulaGraphMock.scala create mode 100644 nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/mock/SparkMock.scala create mode 100644 nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/reader/ReadSuite.scala create mode 100644 nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/writer/NebulaExecutorSuite.scala create mode 100644 nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/writer/WriteDeleteSuite.scala create mode 100644 nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/writer/WriteInsertSuite.scala diff --git a/nebula-spark-connector_3.0/.gitignore b/nebula-spark-connector_3.0/.gitignore new file mode 100644 index 00000000..84e7a6bc --- /dev/null +++ b/nebula-spark-connector_3.0/.gitignore @@ -0,0 +1,36 @@ +# Compiled class file +*.class + +# Log file +*.log + +# BlueJ files +*.ctxt + +# Mobile Tools for Java (J2ME) +.mtj.tmp/ + +# Package Files # +*.jar +*.war +*.nar +*.ear +*.zip +*.tar.gz +*.rar + +# virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml +hs_err_pid* + +# build target +target/ + +# IDE +.idea/ +.eclipse/ +*.iml + +spark-importer.ipr +spark-importer.iws + +.DS_Store diff --git a/nebula-spark-connector_3.0/pom.xml b/nebula-spark-connector_3.0/pom.xml new file mode 100644 index 00000000..bcfae8ec --- /dev/null +++ b/nebula-spark-connector_3.0/pom.xml @@ -0,0 +1,292 @@ + + + + nebula-spark + com.vesoft + 3.0-SNAPSHOT + ../pom.xml + + 4.0.0 + + nebula-spark-connector_3.0 + + + 1.8 + 1.8 + + + + + com.vesoft + nebula-spark-common + ${project.version} + + + + + + + org.apache.maven.plugins + maven-deploy-plugin + 2.8.2 + + + default-deploy + deploy + + + + + org.sonatype.plugins + nexus-staging-maven-plugin + 1.6.8 + true + + ossrh + https://oss.sonatype.org/ + true + + + + + org.apache.maven.plugins + maven-jar-plugin + 3.2.0 + + + + test-jar + + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.1 + + ${compiler.source.version} + ${compiler.target.version} + + + + + + org.apache.maven.plugins + maven-shade-plugin + 3.2.1 + + + package + + shade + + + false + + + org.apache.spark:* + org.apache.hadoop:* + org.apache.hive:* + log4j:log4j + org.apache.orc:* + xml-apis:xml-apis + javax.inject:javax.inject + org.spark-project.hive:hive-exec + stax:stax-api + org.glassfish.hk2.external:aopalliance-repackaged + + + + + + *:* + + com/vesoft/tools/** + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + + + + org.scala-tools + maven-scala-plugin + 2.15.2 + + 2.11.12 + + -target:jvm-1.8 + + + -Xss4096K + + + + + scala-compile + + compile + + + + com/vesoft/tools/** + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + scala-test-compile + + testCompile + + + + com/vesoft/tools/** + + + + + + + org.apache.maven.plugins + maven-source-plugin + 3.2.0 + + + attach-sources + + jar + + + + + + + + 
net.alchim31.maven + scala-maven-plugin + 3.2.2 + + + + compile + testCompile + + + + Scaladoc + + doc + + prepare-package + + + -nobootcp + -no-link-warnings + + + + + attach-javadocs + + doc-jar + + + + -nobootcp + -no-link-warnings + + + + + + + org.apache.maven.plugins + maven-surefire-plugin + 2.12.4 + + + **/*Test.* + **/*Suite.* + + + + + org.scalatest + scalatest-maven-plugin + 2.0.0 + + + test + + test + + + + + + org.apache.maven.plugins + maven-javadoc-plugin + 3.2.0 + + com.facebook.thrift:com.facebook.thrift.* + + + + + attach-javadocs + package + + jar + + + UTF-8 + UTF-8 + + -source 8 + -Xdoclint:none + + + + + + + org.jacoco + jacoco-maven-plugin + 0.8.7 + + + + prepare-agent + + + + report + test + + report + + + + + + + + + snapshots + https://oss.sonatype.org/content/repositories/snapshots/ + + + diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala new file mode 100644 index 00000000..0c6ef613 --- /dev/null +++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala @@ -0,0 +1,149 @@ +/* Copyright (c) 2020 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. + */ + +package com.vesoft.nebula.connector + +import java.util +import java.util.Map.Entry + +import com.vesoft.nebula.connector.nebula.MetaProvider +import com.vesoft.nebula.connector.reader.SimpleScanBuilder +import com.vesoft.nebula.connector.utils.Validations +import com.vesoft.nebula.meta.ColumnDef +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.connector.catalog.{ + SupportsRead, + SupportsWrite, + Table, + TableCapability, + TableProvider +} +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.connector.read.ScanBuilder +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.types.{DataTypes, StructField, StructType} +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.slf4j.LoggerFactory + +import scala.collection.mutable.ListBuffer +import scala.jdk.CollectionConverters.asScalaSetConverter + +class NebulaDataSource extends TableProvider with DataSourceRegister { + private val LOG = LoggerFactory.getLogger(this.getClass) + + Validations.validateSparkVersion("3.*") + + private var schema: StructType = null + private var nebulaOptions: NebulaOptions = _ + + /** + * The string that represents the format that nebula data source provider uses. + */ + override def shortName(): String = "nebula" + + override def supportsExternalMetadata(): Boolean = true + + override def inferSchema(caseInsensitiveStringMap: CaseInsensitiveStringMap): StructType = { + if (schema == null) { + nebulaOptions = getNebulaOptions(caseInsensitiveStringMap) + if (nebulaOptions.operaType == OperaType.READ) { + schema = getSchema(nebulaOptions) + } else { + schema = new StructType() + } + } + schema + } + + override def getTable(tableSchema: StructType, + transforms: Array[Transform], + map: util.Map[String, String]): Table = { + new NebulaTable(tableSchema, nebulaOptions) + } + + /** + * return the dataset's schema. Schema includes configured cols in returnCols or includes all properties in nebula. 
+ */ + private def getSchema(nebulaOptions: NebulaOptions): StructType = { + val returnCols = nebulaOptions.getReturnCols + val noColumn = nebulaOptions.noColumn + val fields: ListBuffer[StructField] = new ListBuffer[StructField] + val metaProvider = new MetaProvider( + nebulaOptions.getMetaAddress, + nebulaOptions.timeout, + nebulaOptions.connectionRetry, + nebulaOptions.executionRetry, + nebulaOptions.enableMetaSSL, + nebulaOptions.sslSignType, + nebulaOptions.caSignParam, + nebulaOptions.selfSignParam + ) + + import scala.collection.JavaConverters._ + var schemaCols: Seq[ColumnDef] = Seq() + val isVertex = DataTypeEnum.VERTEX.toString.equalsIgnoreCase(nebulaOptions.dataType) + + // construct vertex or edge default prop + if (isVertex) { + fields.append(DataTypes.createStructField("_vertexId", DataTypes.StringType, false)) + } else { + fields.append(DataTypes.createStructField("_srcId", DataTypes.StringType, false)) + fields.append(DataTypes.createStructField("_dstId", DataTypes.StringType, false)) + fields.append(DataTypes.createStructField("_rank", DataTypes.LongType, false)) + } + + var dataSchema: StructType = null + // read no column + if (noColumn) { + dataSchema = new StructType(fields.toArray) + return dataSchema + } + // get tag schema or edge schema + val schema = if (isVertex) { + metaProvider.getTag(nebulaOptions.spaceName, nebulaOptions.label) + } else { + metaProvider.getEdge(nebulaOptions.spaceName, nebulaOptions.label) + } + + schemaCols = schema.columns.asScala + + // read all columns + if (returnCols.isEmpty) { + schemaCols.foreach(columnDef => { + LOG.info(s"prop name ${new String(columnDef.getName)}, type ${columnDef.getType.getType} ") + fields.append( + DataTypes.createStructField(new String(columnDef.getName), + NebulaUtils.convertDataType(columnDef.getType), + true)) + }) + } else { + for (col: String <- returnCols) { + fields.append( + DataTypes + .createStructField(col, NebulaUtils.getColDataType(schemaCols.toList, col), true)) + } + } + dataSchema = new StructType(fields.toArray) + dataSchema + } + + /** + * construct nebula options with DataSourceOptions + */ + private def getNebulaOptions( + caseInsensitiveStringMap: CaseInsensitiveStringMap): NebulaOptions = { + var parameters: Map[String, String] = Map() + for (entry: Entry[String, String] <- caseInsensitiveStringMap + .asCaseSensitiveMap() + .entrySet() + .asScala) { + parameters += (entry.getKey -> entry.getValue) + } + val nebulaOptions = new NebulaOptions(CaseInsensitiveMap(parameters)) + nebulaOptions + } + +} diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/NebulaTable.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/NebulaTable.scala new file mode 100644 index 00000000..a256fa93 --- /dev/null +++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/NebulaTable.scala @@ -0,0 +1,70 @@ +/* Copyright (c) 2022 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. 
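To make the result of getSchema above concrete, the sketch below shows the kind of StructType it produces when reading a tag; the tag and property names are hypothetical, and the leading _vertexId column is the default prop appended by the code above.

// Hypothetical inferred schema for a tag "person" read with returnCols = ["name", "age"],
// assuming name is a Nebula string and age a Nebula int64 (mapped to LongType by NebulaUtils.convertDataType).
val inferredSchema = new StructType(Array(
  DataTypes.createStructField("_vertexId", DataTypes.StringType, false),
  DataTypes.createStructField("name", DataTypes.StringType, true),
  DataTypes.createStructField("age", DataTypes.LongType, true)
))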
+ */ + +package com.vesoft.nebula.connector + +import java.util +import java.util.Map.Entry + +import com.vesoft.nebula.connector.reader.SimpleScanBuilder +import com.vesoft.nebula.connector.writer.NebulaWriterBuilder +import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability} +import org.apache.spark.sql.connector.read.ScanBuilder +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.slf4j.LoggerFactory + +import scala.collection.JavaConverters._ + +class NebulaTable(schema: StructType, nebulaOptions: NebulaOptions) + extends Table + with SupportsRead + with SupportsWrite { + + private val LOG = LoggerFactory.getLogger(this.getClass) + + /** + * Creates a {@link DataSourceReader} to scan the data from Nebula Graph. + */ + override def newScanBuilder(caseInsensitiveStringMap: CaseInsensitiveStringMap): ScanBuilder = { + LOG.info("create scan builder") + LOG.info(s"options ${caseInsensitiveStringMap.asCaseSensitiveMap()}") + + new SimpleScanBuilder(nebulaOptions, schema) + } + + /** + * Creates an optional {@link DataSourceWriter} to save the data to Nebula Graph. + */ + override def newWriteBuilder(logicalWriteInfo: LogicalWriteInfo): WriteBuilder = { + LOG.info("create writer") + LOG.info(s"options ${logicalWriteInfo.options().asCaseSensitiveMap()}") + new NebulaWriterBuilder(logicalWriteInfo.schema(), SaveMode.Append, nebulaOptions) + } + + /** + * NebulaGraph table name + */ + override def name(): String = { + nebulaOptions.label + } + + override def schema(): StructType = schema + + override def capabilities(): util.Set[TableCapability] = + Set( + TableCapability.BATCH_READ, + TableCapability.BATCH_WRITE, + TableCapability.ACCEPT_ANY_SCHEMA, + TableCapability.OVERWRITE_BY_FILTER, + TableCapability.OVERWRITE_DYNAMIC, + TableCapability.STREAMING_WRITE, + TableCapability.MICRO_BATCH_READ + ).asJava + +} diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/package.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/package.scala new file mode 100644 index 00000000..b3cc5eed --- /dev/null +++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/package.scala @@ -0,0 +1,279 @@ +/* Copyright (c) 2020 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. 
+ */ + +package com.vesoft.nebula.connector + +import com.vesoft.nebula.connector.ssl.SSLSignType +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{ + DataFrame, + DataFrameReader, + DataFrameWriter, + Encoder, + Encoders, + Row, + SaveMode +} + +import scala.collection.mutable.ListBuffer + +package object connector { + + /** + * spark reader for nebula graph + */ + implicit class NebulaDataFrameReader(reader: DataFrameReader) { + var connectionConfig: NebulaConnectionConfig = _ + var readConfig: ReadNebulaConfig = _ + + def nebula(connectionConfig: NebulaConnectionConfig, + readConfig: ReadNebulaConfig): NebulaDataFrameReader = { + this.connectionConfig = connectionConfig + this.readConfig = readConfig + this + } + + /** + * Reading com.vesoft.nebula.tools.connector.vertices from Nebula Graph + * @return DataFrame + */ + def loadVerticesToDF(): DataFrame = { + assert(connectionConfig != null && readConfig != null, + "nebula config is not set, please call nebula() before loadVerticesToDF") + val dfReader = reader + .format(classOf[NebulaDataSource].getName) + .option(NebulaOptions.TYPE, DataTypeEnum.VERTEX.toString) + .option(NebulaOptions.OPERATE_TYPE, OperaType.READ.toString) + .option(NebulaOptions.SPACE_NAME, readConfig.getSpace) + .option(NebulaOptions.LABEL, readConfig.getLabel) + .option(NebulaOptions.PARTITION_NUMBER, readConfig.getPartitionNum) + .option(NebulaOptions.RETURN_COLS, readConfig.getReturnCols.mkString(",")) + .option(NebulaOptions.NO_COLUMN, readConfig.getNoColumn) + .option(NebulaOptions.LIMIT, readConfig.getLimit) + .option(NebulaOptions.META_ADDRESS, connectionConfig.getMetaAddress) + .option(NebulaOptions.TIMEOUT, connectionConfig.getTimeout) + .option(NebulaOptions.CONNECTION_RETRY, connectionConfig.getConnectionRetry) + .option(NebulaOptions.EXECUTION_RETRY, connectionConfig.getExecRetry) + .option(NebulaOptions.ENABLE_META_SSL, connectionConfig.getEnableMetaSSL) + .option(NebulaOptions.ENABLE_STORAGE_SSL, connectionConfig.getEnableStorageSSL) + + if (connectionConfig.getEnableStorageSSL || connectionConfig.getEnableMetaSSL) { + dfReader.option(NebulaOptions.SSL_SIGN_TYPE, connectionConfig.getSignType) + SSLSignType.withName(connectionConfig.getSignType) match { + case SSLSignType.CA => + dfReader.option(NebulaOptions.CA_SIGN_PARAM, connectionConfig.getCaSignParam) + case SSLSignType.SELF => + dfReader.option(NebulaOptions.SELF_SIGN_PARAM, connectionConfig.getSelfSignParam) + } + } + + dfReader.load() + } + + /** + * Reading edges from Nebula Graph + * @return DataFrame + */ + def loadEdgesToDF(): DataFrame = { + assert(connectionConfig != null && readConfig != null, + "nebula config is not set, please call nebula() before loadEdgesToDF") + + val dfReader = reader + .format(classOf[NebulaDataSource].getName) + .option(NebulaOptions.TYPE, DataTypeEnum.EDGE.toString) + .option(NebulaOptions.OPERATE_TYPE, OperaType.READ.toString) + .option(NebulaOptions.SPACE_NAME, readConfig.getSpace) + .option(NebulaOptions.LABEL, readConfig.getLabel) + .option(NebulaOptions.RETURN_COLS, readConfig.getReturnCols.mkString(",")) + .option(NebulaOptions.NO_COLUMN, readConfig.getNoColumn) + .option(NebulaOptions.LIMIT, readConfig.getLimit) + .option(NebulaOptions.PARTITION_NUMBER, readConfig.getPartitionNum) + .option(NebulaOptions.META_ADDRESS, connectionConfig.getMetaAddress) + .option(NebulaOptions.TIMEOUT, connectionConfig.getTimeout) + .option(NebulaOptions.CONNECTION_RETRY, connectionConfig.getConnectionRetry) + .option(NebulaOptions.EXECUTION_RETRY, 
connectionConfig.getExecRetry) + .option(NebulaOptions.ENABLE_META_SSL, connectionConfig.getEnableMetaSSL) + .option(NebulaOptions.ENABLE_STORAGE_SSL, connectionConfig.getEnableStorageSSL) + + if (connectionConfig.getEnableStorageSSL || connectionConfig.getEnableMetaSSL) { + dfReader.option(NebulaOptions.SSL_SIGN_TYPE, connectionConfig.getSignType) + SSLSignType.withName(connectionConfig.getSignType) match { + case SSLSignType.CA => + dfReader.option(NebulaOptions.CA_SIGN_PARAM, connectionConfig.getCaSignParam) + case SSLSignType.SELF => + dfReader.option(NebulaOptions.SELF_SIGN_PARAM, connectionConfig.getSelfSignParam) + } + } + + dfReader.load() + } + + /** + * read nebula vertex edge to graphx's vertex + * use hash() for String type vertex id. + */ + def loadVerticesToGraphx(): RDD[NebulaGraphxVertex] = { + val vertexDataset = loadVerticesToDF() + implicit val encoder: Encoder[NebulaGraphxVertex] = + Encoders.bean[NebulaGraphxVertex](classOf[NebulaGraphxVertex]) + + vertexDataset + .map(row => { + val vertexId = row.get(0) + val vid: Long = vertexId.toString.toLong + val props: ListBuffer[Any] = ListBuffer() + for (i <- row.schema.fields.indices) { + if (i != 0) { + props.append(row.get(i)) + } + } + (vid, props.toList) + })(encoder) + .rdd + } + + /** + * read nebula edge edge to graphx's edge + * use hash() for String type srcId and dstId. + */ + def loadEdgesToGraphx(): RDD[NebulaGraphxEdge] = { + val edgeDataset = loadEdgesToDF() + implicit val encoder: Encoder[NebulaGraphxEdge] = + Encoders.bean[NebulaGraphxEdge](classOf[NebulaGraphxEdge]) + + edgeDataset + .map(row => { + val props: ListBuffer[Any] = ListBuffer() + for (i <- row.schema.fields.indices) { + if (i != 0 && i != 1 && i != 2) { + props.append(row.get(i)) + } + } + val srcId = row.get(0) + val dstId = row.get(1) + val edgeSrc = srcId.toString.toLong + val edgeDst = dstId.toString.toLong + val edgeProp = (row.get(2).toString.toLong, props.toList) + org.apache.spark.graphx + .Edge(edgeSrc, edgeDst, edgeProp) + })(encoder) + .rdd + } + + } + + /** + * spark writer for nebula graph + */ + implicit class NebulaDataFrameWriter(writer: DataFrameWriter[Row]) { + + var connectionConfig: NebulaConnectionConfig = _ + var writeNebulaConfig: WriteNebulaConfig = _ + + /** + * config nebula connection + * @param connectionConfig connection parameters + * @param writeNebulaConfig write parameters for vertex or edge + */ + def nebula(connectionConfig: NebulaConnectionConfig, + writeNebulaConfig: WriteNebulaConfig): NebulaDataFrameWriter = { + this.connectionConfig = connectionConfig + this.writeNebulaConfig = writeNebulaConfig + this + } + + /** + * write dataframe into nebula vertex + */ + def writeVertices(): Unit = { + assert(connectionConfig != null && writeNebulaConfig != null, + "nebula config is not set, please call nebula() before writeVertices") + val writeConfig = writeNebulaConfig.asInstanceOf[WriteNebulaVertexConfig] + val dfWriter = writer + .format(classOf[NebulaDataSource].getName) + .mode(SaveMode.Overwrite) + .option(NebulaOptions.TYPE, DataTypeEnum.VERTEX.toString) + .option(NebulaOptions.OPERATE_TYPE, OperaType.WRITE.toString) + .option(NebulaOptions.SPACE_NAME, writeConfig.getSpace) + .option(NebulaOptions.LABEL, writeConfig.getTagName) + .option(NebulaOptions.USER_NAME, writeConfig.getUser) + .option(NebulaOptions.PASSWD, writeConfig.getPasswd) + .option(NebulaOptions.VERTEX_FIELD, writeConfig.getVidField) + .option(NebulaOptions.VID_POLICY, writeConfig.getVidPolicy) + .option(NebulaOptions.BATCH, 
writeConfig.getBatch) + .option(NebulaOptions.VID_AS_PROP, writeConfig.getVidAsProp) + .option(NebulaOptions.WRITE_MODE, writeConfig.getWriteMode) + .option(NebulaOptions.DELETE_EDGE, writeConfig.getDeleteEdge) + .option(NebulaOptions.META_ADDRESS, connectionConfig.getMetaAddress) + .option(NebulaOptions.GRAPH_ADDRESS, connectionConfig.getGraphAddress) + .option(NebulaOptions.TIMEOUT, connectionConfig.getTimeout) + .option(NebulaOptions.CONNECTION_RETRY, connectionConfig.getConnectionRetry) + .option(NebulaOptions.EXECUTION_RETRY, connectionConfig.getExecRetry) + .option(NebulaOptions.ENABLE_GRAPH_SSL, connectionConfig.getEnableGraphSSL) + .option(NebulaOptions.ENABLE_META_SSL, connectionConfig.getEnableMetaSSL) + + if (connectionConfig.getEnableGraphSSL || connectionConfig.getEnableMetaSSL) { + dfWriter.option(NebulaOptions.SSL_SIGN_TYPE, connectionConfig.getSignType) + SSLSignType.withName(connectionConfig.getSignType) match { + case SSLSignType.CA => + dfWriter.option(NebulaOptions.CA_SIGN_PARAM, connectionConfig.getCaSignParam) + case SSLSignType.SELF => + dfWriter.option(NebulaOptions.SELF_SIGN_PARAM, connectionConfig.getSelfSignParam) + } + } + + dfWriter.save() + } + + /** + * write dataframe into nebula edge + */ + def writeEdges(): Unit = { + + assert(connectionConfig != null && writeNebulaConfig != null, + "nebula config is not set, please call nebula() before writeEdges") + val writeConfig = writeNebulaConfig.asInstanceOf[WriteNebulaEdgeConfig] + val dfWriter = writer + .format(classOf[NebulaDataSource].getName) + .mode(SaveMode.Overwrite) + .option(NebulaOptions.TYPE, DataTypeEnum.EDGE.toString) + .option(NebulaOptions.OPERATE_TYPE, OperaType.WRITE.toString) + .option(NebulaOptions.SPACE_NAME, writeConfig.getSpace) + .option(NebulaOptions.USER_NAME, writeConfig.getUser) + .option(NebulaOptions.PASSWD, writeConfig.getPasswd) + .option(NebulaOptions.LABEL, writeConfig.getEdgeName) + .option(NebulaOptions.SRC_VERTEX_FIELD, writeConfig.getSrcFiled) + .option(NebulaOptions.DST_VERTEX_FIELD, writeConfig.getDstField) + .option(NebulaOptions.SRC_POLICY, writeConfig.getSrcPolicy) + .option(NebulaOptions.DST_POLICY, writeConfig.getDstPolicy) + .option(NebulaOptions.RANK_FIELD, writeConfig.getRankField) + .option(NebulaOptions.BATCH, writeConfig.getBatch) + .option(NebulaOptions.SRC_AS_PROP, writeConfig.getSrcAsProp) + .option(NebulaOptions.DST_AS_PROP, writeConfig.getDstAsProp) + .option(NebulaOptions.RANK_AS_PROP, writeConfig.getRankAsProp) + .option(NebulaOptions.WRITE_MODE, writeConfig.getWriteMode) + .option(NebulaOptions.META_ADDRESS, connectionConfig.getMetaAddress) + .option(NebulaOptions.GRAPH_ADDRESS, connectionConfig.getGraphAddress) + .option(NebulaOptions.TIMEOUT, connectionConfig.getTimeout) + .option(NebulaOptions.CONNECTION_RETRY, connectionConfig.getConnectionRetry) + .option(NebulaOptions.EXECUTION_RETRY, connectionConfig.getExecRetry) + .option(NebulaOptions.ENABLE_GRAPH_SSL, connectionConfig.getEnableGraphSSL) + .option(NebulaOptions.ENABLE_META_SSL, connectionConfig.getEnableMetaSSL) + + if (connectionConfig.getEnableGraphSSL || connectionConfig.getEnableMetaSSL) { + dfWriter.option(NebulaOptions.SSL_SIGN_TYPE, connectionConfig.getSignType) + SSLSignType.withName(connectionConfig.getSignType) match { + case SSLSignType.CA => + dfWriter.option(NebulaOptions.CA_SIGN_PARAM, connectionConfig.getCaSignParam) + case SSLSignType.SELF => + dfWriter.option(NebulaOptions.SELF_SIGN_PARAM, connectionConfig.getSelfSignParam) + } + } + + dfWriter.save() + } + } + +} diff --git 
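A usage sketch for the implicit reader and writer defined in this package object; the NebulaConnectionConfig and ReadNebulaConfig builders come from nebula-spark-common, and the builder method names, addresses, and space/tag names below are assumptions for illustration only.

// Read vertices of a hypothetical tag "person" from a hypothetical space "test" into a DataFrame.
import com.vesoft.nebula.connector.connector._

val connectionConfig = NebulaConnectionConfig
  .builder()
  .withMetaAddress("127.0.0.1:9559") // assumed local meta address
  .withTimeout(6000)
  .build()
val readVertexConfig = ReadNebulaConfig
  .builder()
  .withSpace("test")
  .withLabel("person")
  .withReturnCols(List("name", "age"))
  .withPartitionNum(10)
  .build()
val vertexDF = spark.read.nebula(connectionConfig, readVertexConfig).loadVerticesToDF()
vertexDF.show()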
a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaEdgePartitionReader.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaEdgePartitionReader.scala new file mode 100644 index 00000000..d1a607d9 --- /dev/null +++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaEdgePartitionReader.scala @@ -0,0 +1,76 @@ +/* Copyright (c) 2022 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. + */ + +package com.vesoft.nebula.connector.reader + +import com.vesoft.nebula.client.storage.scan.{ScanEdgeResult, ScanEdgeResultIterator} +import com.vesoft.nebula.connector.NebulaOptions +import org.apache.spark.sql.types.StructType +import org.slf4j.{Logger, LoggerFactory} +import scala.collection.JavaConverters._ + +class NebulaEdgePartitionReader(index: Int, nebulaOptions: NebulaOptions, schema: StructType) + extends NebulaPartitionReader(index, nebulaOptions, schema) { + private val LOG: Logger = LoggerFactory.getLogger(this.getClass) + + private var responseIterator: ScanEdgeResultIterator = _ + + override def next(): Boolean = { + if (dataIterator == null && responseIterator == null && !scanPartIterator.hasNext) + return false + + var continue: Boolean = false + var break: Boolean = false + while ((dataIterator == null || !dataIterator.hasNext) && !break) { + resultValues.clear() + continue = false + if (responseIterator == null || !responseIterator.hasNext) { + if (scanPartIterator.hasNext) { + try { + if (nebulaOptions.noColumn) { + responseIterator = storageClient.scanEdge(nebulaOptions.spaceName, + scanPartIterator.next(), + nebulaOptions.label, + nebulaOptions.limit, + 0L, + Long.MaxValue, + true, + true) + } else { + responseIterator = storageClient.scanEdge(nebulaOptions.spaceName, + scanPartIterator.next(), + nebulaOptions.label, + nebulaOptions.getReturnCols.asJava, + nebulaOptions.limit, + 0, + Long.MaxValue, + true, + true) + } + } catch { + case e: Exception => + LOG.error(s"Exception scanning edge ${nebulaOptions.label}", e) + storageClient.close() + throw new Exception(e.getMessage, e) + } + // jump to the next loop + continue = true + } + // break while loop + break = !continue + } else { + val next: ScanEdgeResult = responseIterator.next + if (!next.isEmpty) { + dataIterator = next.getEdgeTableRows.iterator().asScala + } + } + } + + if (dataIterator == null) { + return false + } + dataIterator.hasNext + } +} diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartition.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartition.scala new file mode 100644 index 00000000..d296376b --- /dev/null +++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartition.scala @@ -0,0 +1,10 @@ +/* Copyright (c) 2020 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License.
+ */ + +package com.vesoft.nebula.connector.reader + +import org.apache.spark.sql.connector.read.InputPartition + +case class NebulaPartition(partition: Int) extends InputPartition diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartitionReader.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartitionReader.scala new file mode 100644 index 00000000..a6fb5704 --- /dev/null +++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartitionReader.scala @@ -0,0 +1,162 @@ +/* Copyright (c) 2020 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. + */ + +package com.vesoft.nebula.connector.reader + +import com.vesoft.nebula.client.graph.data.{ + CASignedSSLParam, + HostAddress, + SSLParam, + SelfSignedSSLParam, + ValueWrapper +} +import com.vesoft.nebula.client.storage.StorageClient +import com.vesoft.nebula.client.storage.data.BaseTableRow +import com.vesoft.nebula.connector.NebulaUtils.NebulaValueGetter +import com.vesoft.nebula.connector.{NebulaOptions, NebulaUtils, PartitionUtils} +import com.vesoft.nebula.connector.exception.GraphConnectException +import com.vesoft.nebula.connector.nebula.MetaProvider +import com.vesoft.nebula.connector.ssl.SSLSignType +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.types.StructType +import org.slf4j.{Logger, LoggerFactory} + +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.collection.mutable.ListBuffer + +/** + * Read nebula data for each spark partition + */ +abstract class NebulaPartitionReader extends PartitionReader[InternalRow] { + private val LOG: Logger = LoggerFactory.getLogger(this.getClass) + + private var metaProvider: MetaProvider = _ + private var schema: StructType = _ + + protected var dataIterator: Iterator[BaseTableRow] = _ + protected var scanPartIterator: Iterator[Integer] = _ + protected var resultValues: mutable.ListBuffer[List[Object]] = mutable.ListBuffer[List[Object]]() + protected var storageClient: StorageClient = _ + + /** + * @param index identifier for spark partition + * @param nebulaOptions nebula Options + * @param schema of data need to read + */ + def this(index: Int, nebulaOptions: NebulaOptions, schema: StructType) { + this() + this.schema = schema + + metaProvider = new MetaProvider( + nebulaOptions.getMetaAddress, + nebulaOptions.timeout, + nebulaOptions.connectionRetry, + nebulaOptions.executionRetry, + nebulaOptions.enableMetaSSL, + nebulaOptions.sslSignType, + nebulaOptions.caSignParam, + nebulaOptions.selfSignParam + ) + val address: ListBuffer[HostAddress] = new ListBuffer[HostAddress] + + for (addr <- nebulaOptions.getMetaAddress) { + address.append(new HostAddress(addr._1, addr._2)) + } + + var sslParam: SSLParam = null + if (nebulaOptions.enableStorageSSL) { + SSLSignType.withName(nebulaOptions.sslSignType) match { + case SSLSignType.CA => { + val caSSLSignParams = nebulaOptions.caSignParam + sslParam = new CASignedSSLParam(caSSLSignParams.caCrtFilePath, + caSSLSignParams.crtFilePath, + caSSLSignParams.keyFilePath) + } + case SSLSignType.SELF => { + val selfSSLSignParams = nebulaOptions.selfSignParam + sslParam = new SelfSignedSSLParam(selfSSLSignParams.crtFilePath, + selfSSLSignParams.keyFilePath, + selfSSLSignParams.password) + } + case _ => throw 
new IllegalArgumentException("ssl sign type is not supported") + } + this.storageClient = new StorageClient(address.asJava, + nebulaOptions.timeout, + nebulaOptions.connectionRetry, + nebulaOptions.executionRetry, + true, + sslParam) + } else { + this.storageClient = new StorageClient(address.asJava, nebulaOptions.timeout) + } + + if (!storageClient.connect()) { + throw new GraphConnectException("storage connect failed.") + } + // allocate scanPart to this partition + val totalPart = metaProvider.getPartitionNumber(nebulaOptions.spaceName) + + // index starts with 1 + val scanParts = PartitionUtils.getScanParts(index, totalPart, nebulaOptions.partitionNums.toInt) + LOG.info(s"partition index: ${index}, scanParts: ${scanParts.toString}") + scanPartIterator = scanParts.iterator + } + + override def get(): InternalRow = { + val resultSet: Array[ValueWrapper] = + dataIterator.next().getValues.toArray.map(v => v.asInstanceOf[ValueWrapper]) + val getters: Array[NebulaValueGetter] = NebulaUtils.makeGetters(schema) + val mutableRow = new SpecificInternalRow(schema.fields.map(x => x.dataType)) + + for (i <- getters.indices) { + val value: ValueWrapper = resultSet(i) + var resolved = false + if (value.isNull) { + mutableRow.setNullAt(i) + resolved = true + } + if (value.isString) { + getters(i).apply(value.asString(), mutableRow, i) + resolved = true + } + if (value.isDate) { + getters(i).apply(value.asDate(), mutableRow, i) + resolved = true + } + if (value.isTime) { + getters(i).apply(value.asTime(), mutableRow, i) + resolved = true + } + if (value.isDateTime) { + getters(i).apply(value.asDateTime(), mutableRow, i) + resolved = true + } + if (value.isLong) { + getters(i).apply(value.asLong(), mutableRow, i) + } + if (value.isBoolean) { + getters(i).apply(value.asBoolean(), mutableRow, i) + } + if (value.isDouble) { + getters(i).apply(value.asDouble(), mutableRow, i) + } + if (value.isGeography) { + getters(i).apply(value.asGeography(), mutableRow, i) + } + if (value.isDuration) { + getters(i).apply(value.asDuration(), mutableRow, i) + } + } + mutableRow + } + + override def close(): Unit = { + metaProvider.close() + storageClient.close() + } +} diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartitionReaderFactory.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartitionReaderFactory.scala new file mode 100644 index 00000000..9e5e7e5e --- /dev/null +++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartitionReaderFactory.scala @@ -0,0 +1,25 @@ +/* Copyright (c) 2022 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. 
+ */ + +package com.vesoft.nebula.connector.reader + +import com.vesoft.nebula.connector.{DataTypeEnum, NebulaOptions} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory} +import org.apache.spark.sql.types.StructType + +class NebulaPartitionReaderFactory(private val nebulaOptions: NebulaOptions, + private val schema: StructType) + extends PartitionReaderFactory { + override def createReader(inputPartition: InputPartition): PartitionReader[InternalRow] = { + val partition = inputPartition.asInstanceOf[NebulaPartition].partition + if (DataTypeEnum.VERTEX.toString.equals(nebulaOptions.dataType)) { + + new NebulaVertexPartitionReader(partition, nebulaOptions, schema) + } else { + new NebulaEdgePartitionReader(partition, nebulaOptions, schema) + } + } +} diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaVertexPartitionReader.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaVertexPartitionReader.scala new file mode 100644 index 00000000..3466b1dd --- /dev/null +++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaVertexPartitionReader.scala @@ -0,0 +1,79 @@ +/* Copyright (c) 2020 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. + */ + +package com.vesoft.nebula.connector.reader + +import com.vesoft.nebula.client.storage.scan.{ScanVertexResult, ScanVertexResultIterator} +import com.vesoft.nebula.connector.NebulaOptions +import org.apache.spark.sql.types.StructType +import org.slf4j.{Logger, LoggerFactory} + +import scala.collection.JavaConverters._ + +class NebulaVertexPartitionReader(index: Int, nebulaOptions: NebulaOptions, schema: StructType) + extends NebulaPartitionReader(index, nebulaOptions, schema) { + + private val LOG: Logger = LoggerFactory.getLogger(this.getClass) + + private var responseIterator: ScanVertexResultIterator = _ + + override def next(): Boolean = { + if (dataIterator == null && responseIterator == null && !scanPartIterator.hasNext) + return false + + var continue: Boolean = false + var break: Boolean = false + while ((dataIterator == null || !dataIterator.hasNext) && !break) { + resultValues.clear() + continue = false + if (responseIterator == null || !responseIterator.hasNext) { + if (scanPartIterator.hasNext) { + try { + if (nebulaOptions.noColumn) { + responseIterator = storageClient.scanVertex(nebulaOptions.spaceName, + scanPartIterator.next(), + nebulaOptions.label, + nebulaOptions.limit, + 0, + Long.MaxValue, + true, + true) + } else { + responseIterator = storageClient.scanVertex(nebulaOptions.spaceName, + scanPartIterator.next(), + nebulaOptions.label, + nebulaOptions.getReturnCols.asJava, + nebulaOptions.limit, + 0, + Long.MaxValue, + true, + true) + } + } catch { + case e: Exception => + LOG.error(s"Exception scanning vertex ${nebulaOptions.label}", e) + storageClient.close() + throw new Exception(e.getMessage, e) + } + // jump to the next loop + continue = true + } + // break while loop + break = !continue + } else { + val next: ScanVertexResult = responseIterator.next + if (!next.isEmpty) { + dataIterator = next.getVertexTableRows.iterator().asScala + } + } + } + + if (dataIterator == null) { + return false + } + dataIterator.hasNext + } + +} diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/SimpleScanBuilder.scala 
b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/SimpleScanBuilder.scala new file mode 100644 index 00000000..7009838b --- /dev/null +++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/SimpleScanBuilder.scala @@ -0,0 +1,73 @@ +/* Copyright (c) 2022 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. + */ + +package com.vesoft.nebula.connector.reader + +import java.util + +import com.vesoft.nebula.connector.NebulaOptions +import org.apache.spark.sql.connector.read.{ + Batch, + InputPartition, + PartitionReaderFactory, + Scan, + ScanBuilder, + SupportsPushDownFilters, + SupportsPushDownRequiredColumns +} +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType + +import scala.collection.mutable.ListBuffer +import scala.jdk.CollectionConverters.asScalaBufferConverter + +class SimpleScanBuilder(nebulaOptions: NebulaOptions, schema: StructType) + extends ScanBuilder + with SupportsPushDownFilters + with SupportsPushDownRequiredColumns { + + private var filters: Array[Filter] = Array[Filter]() + + override def build(): Scan = { + new SimpleScan(nebulaOptions, 10, schema) + } + + override def pushFilters(pushFilters: Array[Filter]): Array[Filter] = { + if (nebulaOptions.pushDownFiltersEnabled) { + filters = pushFilters + } + pushFilters + } + + override def pushedFilters(): Array[Filter] = filters + + override def pruneColumns(requiredColumns: StructType): Unit = { + if (!nebulaOptions.pushDownFiltersEnabled || requiredColumns == schema) { + new StructType() + } else { + requiredColumns + } + } +} + +class SimpleScan(nebulaOptions: NebulaOptions, nebulaTotalPart: Int, schema: StructType) + extends Scan + with Batch { + override def toBatch: Batch = this + + override def planInputPartitions(): Array[InputPartition] = { + val partitionSize = nebulaTotalPart + val inputPartitions: java.util.List[InputPartition] = new util.ArrayList[InputPartition]() + for (i <- 1 to partitionSize) { + inputPartitions.add(NebulaPartition(i)) + } + inputPartitions.asScala.toArray + } + + override def readSchema(): StructType = schema + + override def createReaderFactory(): PartitionReaderFactory = + new NebulaPartitionReaderFactory(nebulaOptions, schema) +} diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/utils/Validations.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/utils/Validations.scala new file mode 100644 index 00000000..93b9a87c --- /dev/null +++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/utils/Validations.scala @@ -0,0 +1,18 @@ +/* Copyright (c) 2021 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. 
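The SimpleScan above plans 1-based input partitions, and each partition reader later asks PartitionUtils.getScanParts which Nebula storage parts it owns. A minimal sketch of that mapping, assuming the same stride-based assignment used by the Spark 2.x connector (the real getScanParts lives in nebula-spark-common and may differ in detail):

// Assumed round-robin assignment: Spark partition i scans Nebula parts i, i + n, i + 2n, ...
def scanPartsFor(sparkPartitionIndex: Int, nebulaTotalPart: Int, sparkPartitionNum: Int): Seq[Int] =
  sparkPartitionIndex to nebulaTotalPart by sparkPartitionNum

// Example: 10 Nebula parts spread over 3 Spark partitions
// scanPartsFor(1, 10, 3) == Seq(1, 4, 7, 10)
// scanPartsFor(2, 10, 3) == Seq(2, 5, 8)
// scanPartsFor(3, 10, 3) == Seq(3, 6, 9)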
+ */ + +package com.vesoft.nebula.connector.utils + +import org.apache.spark.sql.SparkSession + +object Validations { + def validateSparkVersion(supportedVersions: String*): Unit = { + val sparkVersion = SparkSession.getActiveSession.map { _.version }.getOrElse("UNKNOWN") + if (!(sparkVersion == "UNKNOWN" || supportedVersions.exists(sparkVersion.matches))) { + throw new RuntimeException( + s"Your current spark version ${sparkVersion} is not supported by the current connector.") + } + } +} diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaCommitMessage.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaCommitMessage.scala new file mode 100644 index 00000000..184f63ba --- /dev/null +++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaCommitMessage.scala @@ -0,0 +1,10 @@ +/* Copyright (c) 2020 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. + */ + +package com.vesoft.nebula.connector.writer + +import org.apache.spark.sql.connector.write.WriterCommitMessage + +case class NebulaCommitMessage(executeStatements: List[String]) extends WriterCommitMessage diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaEdgeWriter.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaEdgeWriter.scala new file mode 100644 index 00000000..9c28df82 --- /dev/null +++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaEdgeWriter.scala @@ -0,0 +1,115 @@ +/* Copyright (c) 2020 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. + */ + +package com.vesoft.nebula.connector.writer + +import com.vesoft.nebula.connector.{NebulaEdge, NebulaEdges} +import com.vesoft.nebula.connector.{KeyPolicy, NebulaOptions, WriteMode} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage} +import org.apache.spark.sql.types.StructType +import org.slf4j.LoggerFactory + +import scala.collection.mutable.ListBuffer + +class NebulaEdgeWriter(nebulaOptions: NebulaOptions, + srcIndex: Int, + dstIndex: Int, + rankIndex: Option[Int], + schema: StructType) + extends NebulaWriter(nebulaOptions) + with DataWriter[InternalRow] { + + private val LOG = LoggerFactory.getLogger(this.getClass) + + val rankIdx = if (rankIndex.isDefined) rankIndex.get else -1 + val propNames = NebulaExecutor.assignEdgePropNames(schema, + srcIndex, + dstIndex, + rankIdx, + nebulaOptions.srcAsProp, + nebulaOptions.dstAsProp, + nebulaOptions.rankAsProp) + val fieldTypMap: Map[String, Integer] = + if (nebulaOptions.writeMode == WriteMode.DELETE) Map[String, Integer]() + else metaProvider.getEdgeSchema(nebulaOptions.spaceName, nebulaOptions.label) + + val srcPolicy = + if (nebulaOptions.srcPolicy.isEmpty) Option.empty + else Option(KeyPolicy.withName(nebulaOptions.srcPolicy)) + val dstPolicy = { + if (nebulaOptions.dstPolicy.isEmpty) Option.empty + else Option(KeyPolicy.withName(nebulaOptions.dstPolicy)) + } + + /** buffer to save batch edges */ + var edges: ListBuffer[NebulaEdge] = new ListBuffer() + + prepareSpace() + + /** + * write one edge record to buffer + */ + override def write(row: InternalRow): Unit = { + val srcId = NebulaExecutor.extraID(schema, row, srcIndex, srcPolicy, isVidStringType) + val dstId = NebulaExecutor.extraID(schema, row, dstIndex, dstPolicy, isVidStringType)
+ val rank = + if (rankIndex.isEmpty) Option.empty + else Option(NebulaExecutor.extraRank(schema, row, rankIndex.get)) + val values = + if (nebulaOptions.writeMode == WriteMode.DELETE) List() + else + NebulaExecutor.assignEdgeValues(schema, + row, + srcIndex, + dstIndex, + rankIdx, + nebulaOptions.srcAsProp, + nebulaOptions.dstAsProp, + nebulaOptions.rankAsProp, + fieldTypMap) + val nebulaEdge = NebulaEdge(srcId, dstId, rank, values) + edges.append(nebulaEdge) + if (edges.size >= nebulaOptions.batch) { + execute() + } + } + + /** + * submit buffer edges to nebula + */ + def execute(): Unit = { + val nebulaEdges = NebulaEdges(propNames, edges.toList, srcPolicy, dstPolicy) + val exec = nebulaOptions.writeMode match { + case WriteMode.INSERT => NebulaExecutor.toExecuteSentence(nebulaOptions.label, nebulaEdges) + case WriteMode.UPDATE => + NebulaExecutor.toUpdateExecuteStatement(nebulaOptions.label, nebulaEdges) + case WriteMode.DELETE => + NebulaExecutor.toDeleteExecuteStatement(nebulaOptions.label, nebulaEdges) + case _ => + throw new IllegalArgumentException(s"write mode ${nebulaOptions.writeMode} not supported.") + } + edges.clear() + submit(exec) + } + + override def commit(): WriterCommitMessage = { + if (edges.nonEmpty) { + execute() + } + graphProvider.close() + metaProvider.close() + NebulaCommitMessage.apply(failedExecs.toList) + } + + override def abort(): Unit = { + LOG.error("insert edge task abort.") + graphProvider.close() + } + + override def close(): Unit = { + graphProvider.close() + } +} diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaSourceWriter.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaSourceWriter.scala new file mode 100644 index 00000000..18a596d6 --- /dev/null +++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaSourceWriter.scala @@ -0,0 +1,105 @@ +/* Copyright (c) 2020 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. 
+ */ + +package com.vesoft.nebula.connector.writer + +import com.vesoft.nebula.connector.NebulaOptions +import org.apache.spark.TaskContext +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.write.{ + BatchWrite, + DataWriter, + DataWriterFactory, + PhysicalWriteInfo, + WriterCommitMessage +} +import org.apache.spark.sql.types.StructType +import org.slf4j.LoggerFactory + +/** + * creating and initializing the actual Nebula vertex writer at executor side + */ +class NebulaVertexWriterFactory(nebulaOptions: NebulaOptions, vertexIndex: Int, schema: StructType) + extends DataWriterFactory { + override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = { + new NebulaVertexWriter(nebulaOptions, vertexIndex, schema) + } +} + +/** + * creating and initializing the actual Nebula edge writer at executor side + */ +class NebulaEdgeWriterFactory(nebulaOptions: NebulaOptions, + srcIndex: Int, + dstIndex: Int, + rankIndex: Option[Int], + schema: StructType) + extends DataWriterFactory { + override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = { + new NebulaEdgeWriter(nebulaOptions, srcIndex, dstIndex, rankIndex, schema) + } +} + +/** + * nebula vertex writer to create factory + */ +class NebulaDataSourceVertexWriter(nebulaOptions: NebulaOptions, + vertexIndex: Int, + schema: StructType) + extends BatchWrite { + private val LOG = LoggerFactory.getLogger(this.getClass) + + override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory = { + new NebulaVertexWriterFactory(nebulaOptions, vertexIndex, schema) + } + + override def commit(messages: Array[WriterCommitMessage]): Unit = { + LOG.debug(s"${messages.length}") + for (msg <- messages) { + val nebulaMsg = msg.asInstanceOf[NebulaCommitMessage] + if (nebulaMsg.executeStatements.nonEmpty) { + LOG.error(s"failed execs:\n ${nebulaMsg.executeStatements.toString()}") + } else { + LOG.info(s"execs for spark partition ${TaskContext.getPartitionId()} all succeed") + } + } + } + + override def abort(messages: Array[WriterCommitMessage]): Unit = { + LOG.error("NebulaDataSourceVertexWriter abort") + } +} + +/** + * nebula edge writer to create factory + */ +class NebulaDataSourceEdgeWriter(nebulaOptions: NebulaOptions, + srcIndex: Int, + dstIndex: Int, + rankIndex: Option[Int], + schema: StructType) + extends BatchWrite { + private val LOG = LoggerFactory.getLogger(this.getClass) + + override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory = + new NebulaEdgeWriterFactory(nebulaOptions, srcIndex, dstIndex, rankIndex, schema) + + override def commit(messages: Array[WriterCommitMessage]): Unit = { + LOG.debug(s"${messages.length}") + for (msg <- messages) { + val nebulaMsg = msg.asInstanceOf[NebulaCommitMessage] + if (nebulaMsg.executeStatements.nonEmpty) { + LOG.error(s"failed execs:\n ${nebulaMsg.executeStatements.toString()}") + } else { + LOG.info(s"execs for spark partition ${TaskContext.getPartitionId()} all succeed") + } + } + + } + + override def abort(messages: Array[WriterCommitMessage]): Unit = { + LOG.error("NebulaDataSourceEdgeWriter abort") + } +} diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaVertexWriter.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaVertexWriter.scala new file mode 100644 index 00000000..22a5d311 --- /dev/null +++ 
b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaVertexWriter.scala @@ -0,0 +1,99 @@ +/* Copyright (c) 2020 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. + */ + +package com.vesoft.nebula.connector.writer + +import com.vesoft.nebula.connector.{ + KeyPolicy, + NebulaOptions, + NebulaVertex, + NebulaVertices, + WriteMode +} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage} +import org.apache.spark.sql.types.StructType +import org.slf4j.LoggerFactory + +import scala.collection.mutable.ListBuffer + +class NebulaVertexWriter(nebulaOptions: NebulaOptions, vertexIndex: Int, schema: StructType) + extends NebulaWriter(nebulaOptions) + with DataWriter[InternalRow] { + + private val LOG = LoggerFactory.getLogger(this.getClass) + + val propNames = NebulaExecutor.assignVertexPropNames(schema, vertexIndex, nebulaOptions.vidAsProp) + val fieldTypMap: Map[String, Integer] = + if (nebulaOptions.writeMode == WriteMode.DELETE) Map[String, Integer]() + else metaProvider.getTagSchema(nebulaOptions.spaceName, nebulaOptions.label) + + val policy = { + if (nebulaOptions.vidPolicy.isEmpty) Option.empty + else Option(KeyPolicy.withName(nebulaOptions.vidPolicy)) + } + + /** buffer to save batch vertices */ + var vertices: ListBuffer[NebulaVertex] = new ListBuffer() + + prepareSpace() + + /** + * write one vertex row to buffer + */ + override def write(row: InternalRow): Unit = { + val vertex = + NebulaExecutor.extraID(schema, row, vertexIndex, policy, isVidStringType) + val values = + if (nebulaOptions.writeMode == WriteMode.DELETE) List() + else + NebulaExecutor.assignVertexPropValues(schema, + row, + vertexIndex, + nebulaOptions.vidAsProp, + fieldTypMap) + val nebulaVertex = NebulaVertex(vertex, values) + vertices.append(nebulaVertex) + if (vertices.size >= nebulaOptions.batch) { + execute() + } + } + + /** + * submit buffer vertices to nebula + */ + def execute(): Unit = { + val nebulaVertices = NebulaVertices(propNames, vertices.toList, policy) + val exec = nebulaOptions.writeMode match { + case WriteMode.INSERT => NebulaExecutor.toExecuteSentence(nebulaOptions.label, nebulaVertices) + case WriteMode.UPDATE => + NebulaExecutor.toUpdateExecuteStatement(nebulaOptions.label, nebulaVertices) + case WriteMode.DELETE => + NebulaExecutor.toDeleteExecuteStatement(nebulaVertices, nebulaOptions.deleteEdge) + case _ => + throw new IllegalArgumentException(s"write mode ${nebulaOptions.writeMode} not supported.") + } + vertices.clear() + submit(exec) + } + + override def commit(): WriterCommitMessage = { + if (vertices.nonEmpty) { + execute() + } + graphProvider.close() + metaProvider.close() + NebulaCommitMessage(failedExecs.toList) + } + + override def abort(): Unit = { + LOG.error("insert vertex task abort.") + graphProvider.close() + } + + override def close(): Unit = { + graphProvider.close() + } +} diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaWriter.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaWriter.scala new file mode 100644 index 00000000..05718536 --- /dev/null +++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaWriter.scala @@ -0,0 +1,62 @@ +/* Copyright (c) 2020 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. 
+ */ + +package com.vesoft.nebula.connector.writer + +import java.util.concurrent.TimeUnit + +import com.google.common.util.concurrent.RateLimiter +import com.vesoft.nebula.connector.NebulaOptions +import com.vesoft.nebula.connector.nebula.{GraphProvider, MetaProvider, VidType} +import org.slf4j.LoggerFactory + +import scala.collection.mutable.ListBuffer + +class NebulaWriter(nebulaOptions: NebulaOptions) extends Serializable { + private val LOG = LoggerFactory.getLogger(this.getClass) + + val failedExecs: ListBuffer[String] = new ListBuffer[String] + + val metaProvider = new MetaProvider( + nebulaOptions.getMetaAddress, + nebulaOptions.timeout, + nebulaOptions.connectionRetry, + nebulaOptions.executionRetry, + nebulaOptions.enableMetaSSL, + nebulaOptions.sslSignType, + nebulaOptions.caSignParam, + nebulaOptions.selfSignParam + ) + val graphProvider = new GraphProvider( + nebulaOptions.getGraphAddress, + nebulaOptions.timeout, + nebulaOptions.enableGraphSSL, + nebulaOptions.sslSignType, + nebulaOptions.caSignParam, + nebulaOptions.selfSignParam + ) + val isVidStringType = metaProvider.getVidType(nebulaOptions.spaceName) == VidType.STRING + + def prepareSpace(): Unit = { + graphProvider.switchSpace(nebulaOptions.user, nebulaOptions.passwd, nebulaOptions.spaceName) + } + + def submit(exec: String): Unit = { + @transient val rateLimiter = RateLimiter.create(nebulaOptions.rateLimit) + if (rateLimiter.tryAcquire(nebulaOptions.rateTimeOut, TimeUnit.MILLISECONDS)) { + val result = graphProvider.submit(exec) + if (!result.isSucceeded) { + failedExecs.append(exec) + LOG.error(s"failed to write ${exec} for " + result.getErrorMessage) + } else { + LOG.info(s"batch write succeeded") + LOG.debug(s"batch write succeeded: ${exec}") + } + } else { + failedExecs.append(exec) + LOG.error(s"failed to acquire rateLimiter for statement {$exec}") + } + } +} diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaWriterBuilder.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaWriterBuilder.scala new file mode 100644 index 00000000..c69f3976 --- /dev/null +++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaWriterBuilder.scala @@ -0,0 +1,92 @@ +/* Copyright (c) 2022 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License.
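And a matching usage sketch for the write path that NebulaWriterBuilder below wires up; WriteNebulaVertexConfig comes from nebula-spark-common, and the builder method names and field names are again assumptions for illustration, reusing the connectionConfig from the read sketch above.

// Write a DataFrame with columns (id, name, age) as vertices of a hypothetical tag "person".
import com.vesoft.nebula.connector.connector._

val writeVertexConfig = WriteNebulaVertexConfig
  .builder()
  .withSpace("test")
  .withTag("person")
  .withVidField("id") // DataFrame column used as the vertex id
  .withBatch(512)     // flush every 512 rows, see NebulaVertexWriter and NebulaWriter.submit above
  .build()
df.write.nebula(connectionConfig, writeVertexConfig).writeVertices()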
+ */ + +package com.vesoft.nebula.connector.writer + +import com.vesoft.nebula.connector.exception.IllegalOptionException +import com.vesoft.nebula.connector.{DataTypeEnum, NebulaOptions} +import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.connector.write.{ + BatchWrite, + SupportsOverwrite, + SupportsTruncate, + WriteBuilder +} +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType + +class NebulaWriterBuilder(schema: StructType, saveMode: SaveMode, nebulaOptions: NebulaOptions) + extends WriteBuilder + with SupportsOverwrite + with SupportsTruncate { + + override def buildForBatch(): BatchWrite = { + val dataType = nebulaOptions.dataType + if (DataTypeEnum.VERTEX == DataTypeEnum.withName(dataType)) { + val vertexFiled = nebulaOptions.vertexField + val vertexIndex: Int = { + var index: Int = -1 + for (i <- schema.fields.indices) { + if (schema.fields(i).name.equals(vertexFiled)) { + index = i + } + } + if (index < 0) { + throw new IllegalOptionException( + s" vertex field ${vertexFiled} does not exist in dataframe") + } + index + } + new NebulaDataSourceVertexWriter(nebulaOptions, vertexIndex, schema) + } else { + val srcVertexFiled = nebulaOptions.srcVertexField + val dstVertexField = nebulaOptions.dstVertexField + val rankExist = !nebulaOptions.rankField.isEmpty + val edgeFieldsIndex = { + var srcIndex: Int = -1 + var dstIndex: Int = -1 + var rankIndex: Int = -1 + for (i <- schema.fields.indices) { + if (schema.fields(i).name.equals(srcVertexFiled)) { + srcIndex = i + } + if (schema.fields(i).name.equals(dstVertexField)) { + dstIndex = i + } + if (rankExist) { + if (schema.fields(i).name.equals(nebulaOptions.rankField)) { + rankIndex = i + } + } + } + // check src filed and dst field + if (srcIndex < 0 || dstIndex < 0) { + throw new IllegalOptionException( + s" srcVertex field ${srcVertexFiled} or dstVertex field ${dstVertexField} do not exist in dataframe") + } + // check rank field + if (rankExist && rankIndex < 0) { + throw new IllegalOptionException(s"rank field does not exist in dataframe") + } + + if (!rankExist) { + (srcIndex, dstIndex, Option.empty) + } else { + (srcIndex, dstIndex, Option(rankIndex)) + } + + } + new NebulaDataSourceEdgeWriter(nebulaOptions, + edgeFieldsIndex._1, + edgeFieldsIndex._2, + edgeFieldsIndex._3, + schema) + } + } + + override def overwrite(filters: Array[Filter]): WriteBuilder = { + new NebulaWriterBuilder(schema, SaveMode.Overwrite, nebulaOptions) + } +} diff --git a/nebula-spark-connector_3.0/src/test/resources/docker-compose.yaml b/nebula-spark-connector_3.0/src/test/resources/docker-compose.yaml new file mode 100644 index 00000000..348f539b --- /dev/null +++ b/nebula-spark-connector_3.0/src/test/resources/docker-compose.yaml @@ -0,0 +1,362 @@ +version: '3.4' +services: + metad0: + image: vesoft/nebula-metad:nightly + environment: + USER: root + TZ: "${TZ}" + command: + - --meta_server_addrs=172.28.1.1:9559,172.28.1.2:9559,172.28.1.3:9559 + - --local_ip=172.28.1.1 + - --ws_ip=172.28.1.1 + - --port=9559 + - --data_path=/data/meta + - --log_dir=/logs + - --v=0 + - --ws_http_port=19559 + - --minloglevel=0 + - --heartbeat_interval_secs=2 + healthcheck: + test: ["CMD", "curl", "-f", "http://172.28.1.1:19559/status"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 20s + ports: + - "9559:9559" + - 19559 + - 11002 + volumes: + - ./data/meta0:/data/meta:Z + - ./logs/meta0:/logs:Z + networks: + nebula-net: + ipv4_address: 172.28.1.1 + restart: on-failure + cap_add: + - SYS_PTRACE + + metad1: + 
image: vesoft/nebula-metad:nightly + environment: + USER: root + TZ: "${TZ}" + command: + - --meta_server_addrs=172.28.1.1:9559,172.28.1.2:9559,172.28.1.3:9559 + - --local_ip=172.28.1.2 + - --ws_ip=172.28.1.2 + - --port=9559 + - --data_path=/data/meta + - --log_dir=/logs + - --v=0 + - --ws_http_port=19559 + - --minloglevel=0 + - --heartbeat_interval_secs=2 + healthcheck: + test: ["CMD", "curl", "-f", "http://172.28.1.2:19559/status"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 20s + ports: + - "9560:9559" + - 19559 + - 11002 + volumes: + - ./data/meta1:/data/meta:Z + - ./logs/meta1:/logs:Z + networks: + nebula-net: + ipv4_address: 172.28.1.2 + restart: on-failure + cap_add: + - SYS_PTRACE + + metad2: + image: vesoft/nebula-metad:nightly + environment: + USER: root + TZ: "${TZ}" + command: + - --meta_server_addrs=172.28.1.1:9559,172.28.1.2:9559,172.28.1.3:9559 + - --local_ip=172.28.1.3 + - --ws_ip=172.28.1.3 + - --port=9559 + - --data_path=/data/meta + - --log_dir=/logs + - --v=0 + - --ws_http_port=19559 + - --minloglevel=0 + - --heartbeat_interval_secs=2 + healthcheck: + test: ["CMD", "curl", "-f", "http://172.28.1.3:19559/status"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 20s + ports: + - "9561:9559" + - 19559 + - 11002 + volumes: + - ./data/meta2:/data/meta:Z + - ./logs/meta2:/logs:Z + networks: + nebula-net: + ipv4_address: 172.28.1.3 + restart: on-failure + cap_add: + - SYS_PTRACE + + storaged0: + image: vesoft/nebula-storaged:nightly + environment: + USER: root + TZ: "${TZ}" + command: + - --meta_server_addrs=172.28.1.1:9559,172.28.1.2:9559,172.28.1.3:9559 + - --local_ip=172.28.2.1 + - --ws_ip=172.28.2.1 + - --port=9779 + - --data_path=/data/storage + - --log_dir=/logs + - --v=0 + - --ws_http_port=19779 + - --minloglevel=0 + - --heartbeat_interval_secs=2 + depends_on: + - metad0 + - metad1 + - metad2 + healthcheck: + test: ["CMD", "curl", "-f", "http://172.28.2.1:19779/status"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 20s + ports: + - "9779:9779" + - 19779 + - 12002 + volumes: + - ./data/storage0:/data/storage:Z + - ./logs/storage0:/logs:Z + networks: + nebula-net: + ipv4_address: 172.28.2.1 + restart: on-failure + cap_add: + - SYS_PTRACE + + storaged1: + image: vesoft/nebula-storaged:nightly + environment: + USER: root + TZ: "${TZ}" + command: + - --meta_server_addrs=172.28.1.1:9559,172.28.1.2:9559,172.28.1.3:9559 + - --local_ip=172.28.2.2 + - --ws_ip=172.28.2.2 + - --port=9779 + - --data_path=/data/storage + - --log_dir=/logs + - --v=0 + - --ws_http_port=19779 + - --minloglevel=0 + - --heartbeat_interval_secs=2 + depends_on: + - metad0 + - metad1 + - metad2 + healthcheck: + test: ["CMD", "curl", "-f", "http://172.28.2.2:19779/status"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 20s + ports: + - "9780:9779" + - 19779 + - 12002 + volumes: + - ./data/storage1:/data/storage:Z + - ./logs/storage1:/logs:Z + networks: + nebula-net: + ipv4_address: 172.28.2.2 + restart: on-failure + cap_add: + - SYS_PTRACE + + storaged2: + image: vesoft/nebula-storaged:nightly + environment: + USER: root + TZ: "${TZ}" + command: + - --meta_server_addrs=172.28.1.1:9559,172.28.1.2:9559,172.28.1.3:9559 + - --local_ip=172.28.2.3 + - --ws_ip=172.28.2.3 + - --port=9779 + - --data_path=/data/storage + - --log_dir=/logs + - --v=0 + - --ws_http_port=19779 + - --minloglevel=0 + - --heartbeat_interval_secs=2 + depends_on: + - metad0 + - metad1 + - metad2 + healthcheck: + test: ["CMD", "curl", "-f", "http://172.28.2.3:19779/status"] + interval: 30s + 
timeout: 10s + retries: 3 + start_period: 20s + ports: + - "9781:9779" + - 19779 + - 12002 + volumes: + - ./data/storage2:/data/storage:Z + - ./logs/storage2:/logs:Z + networks: + nebula-net: + ipv4_address: 172.28.2.3 + restart: on-failure + cap_add: + - SYS_PTRACE + + graphd0: + image: vesoft/nebula-graphd:nightly + environment: + USER: root + TZ: "${TZ}" + command: + - --meta_server_addrs=172.28.1.1:9559,172.28.1.2:9559,172.28.1.3:9559 + - --port=9669 + - --ws_ip=172.28.3.1 + - --log_dir=/logs + - --v=0 + - --ws_http_port=19669 + - --minloglevel=0 + - --heartbeat_interval_secs=2 + depends_on: + - metad0 + - metad1 + - metad2 + healthcheck: + test: ["CMD", "curl", "-f", "http://172.28.3.1:19669/status"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 20s + ports: + - "9669:9669" + - 19669 + - 13002 + volumes: + - ./logs/graph0:/logs:Z + networks: + nebula-net: + ipv4_address: 172.28.3.1 + restart: on-failure + cap_add: + - SYS_PTRACE + + graphd1: + image: vesoft/nebula-graphd:nightly + environment: + USER: root + TZ: "${TZ}" + command: + - --meta_server_addrs=172.28.1.1:9559,172.28.1.2:9559,172.28.1.3:9559 + - --port=9669 + - --ws_ip=172.28.3.2 + - --log_dir=/logs + - --v=0 + - --ws_http_port=19669 + - --minloglevel=0 + - --heartbeat_interval_secs=2 + depends_on: + - metad0 + - metad1 + - metad2 + healthcheck: + test: ["CMD", "curl", "-f", "http://172.28.3.2:19669/status"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 20s + ports: + - "9670:9669" + - 19669 + - 13002 + volumes: + - ./logs/graph1:/logs:Z + networks: + nebula-net: + ipv4_address: 172.28.3.2 + restart: on-failure + cap_add: + - SYS_PTRACE + + graphd2: + image: vesoft/nebula-graphd:nightly + environment: + USER: root + TZ: "${TZ}" + command: + - --meta_server_addrs=172.28.1.1:9559,172.28.1.2:9559,172.28.1.3:9559 + - --port=9669 + - --ws_ip=172.28.3.3 + - --log_dir=/logs + - --v=0 + - --ws_http_port=19669 + - --minloglevel=0 + - --heartbeat_interval_secs=2 + depends_on: + - metad0 + - metad1 + - metad2 + healthcheck: + test: ["CMD", "curl", "-f", "http://172.28.3.3:19669/status"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 20s + ports: + - "9671:9669" + - 19669 + - 13002 + volumes: + - ./logs/graph2:/logs:Z + networks: + nebula-net: + ipv4_address: 172.28.3.3 + restart: on-failure + cap_add: + - SYS_PTRACE + + console: + image: vesoft/nebula-console:nightly + entrypoint: "" + command: + - sh + - -c + - | + sleep 3 && + nebula-console -addr graphd0 -port 9669 -u root -p nebula -e 'ADD HOSTS "172.28.2.1":9779,"172.28.2.2":9779,"172.28.2.3":9779' && + sleep 36000 + depends_on: + - graphd0 + networks: + - nebula-net + +networks: + nebula-net: + ipam: + driver: default + config: + - subnet: 172.28.0.0/16 diff --git a/nebula-spark-connector_3.0/src/test/resources/edge.csv b/nebula-spark-connector_3.0/src/test/resources/edge.csv new file mode 100644 index 00000000..2a2380fe --- /dev/null +++ b/nebula-spark-connector_3.0/src/test/resources/edge.csv @@ -0,0 +1,14 @@ +id1,id2,col1,col2,col3,col4,col5,col6,col7,col8,col9,col10,col11,col12,col13,col14 +1,2,Tom,tom,10,20,30,40,2021-01-27,2021-01-01T12:10:10,43535232,true,1.0,2.0,10:10:10,POINT(1 2) +2,3,Jina,Jina,11,21,31,41,2021-01-28,2021-01-02T12:10:10,43535232,false,1.1,2.1,11:10:10,POINT(3 4) +3,4,Tim,Tim,12,22,32,42,2021-01-29,2021-01-03T12:10:10,43535232,false,1.2,2.2,12:10:10,POINT(5 6) +4,5,张三,张三,13,23,33,43,2021-01-30,2021-01-04T12:10:10,43535232,true,1.3,2.3,13:10:10,POINT(6 7) 
+5,6,李四,李四,14,24,34,44,2021-02-01,2021-01-05T12:10:10,43535232,false,1.4,2.4,14:10:10,POINT(1 5) +6,7,王五,王五,15,25,35,45,2021-02-02,2021-01-06T12:10:10,0,false,1.5,2.5,15:10:10,"LINESTRING(1 3, 4.7 73.23)" +7,1,Jina,Jina,16,26,36,46,2021-02-03,2021-01-07T12:10:10,43535232,true,1.6,2.6,16:10:10,"LINESTRING(1 3, 4.7 73.23)" +8,1,Jina,Jina,17,27,37,47,2021-02-04,2021-01-08T12:10:10,43535232,false,1.7,2.7,17:10:10,"LINESTRING(1 3, 4.7 73.23)" +9,1,Jina,Jina,18,28,38,48,2021-02-05,2021-01-09T12:10:10,43535232,true,1.8,2.8,18:10:10,"LINESTRING(1 3, 4.7 73.23)" +10,2,Jina,Jina,19,29,39,49,2021-02-06,2021-01-10T12:10:10,43535232,false,1.9,2.9,19:10:10,"LINESTRING(1 3, 4.7 73.23)" +-1,5,Jina,Jina,20,30,40,50,2021-02-07,2021-02-11T12:10:10,43535232,false,2.0,3.0,20:10:10,"POLYGON((0 1, 1 2, 2 3, 0 1))" +-2,6,Jina,Jina,21,31,41,51,2021-02-08,2021-03-12T12:10:10,43535232,false,2.1,3.1,21:10:10,"POLYGON((0 1, 1 2, 2 3, 0 1))" +-3,7,Jina,Jina,22,32,42,52,2021-02-09,2021-04-13T12:10:10,43535232,false,2.2,3.2,22:10:10,"POLYGON((0 1, 1 2, 2 3, 0 1))" diff --git a/nebula-spark-connector_3.0/src/test/resources/log4j.properties b/nebula-spark-connector_3.0/src/test/resources/log4j.properties new file mode 100644 index 00000000..913391db --- /dev/null +++ b/nebula-spark-connector_3.0/src/test/resources/log4j.properties @@ -0,0 +1,6 @@ +# Global logging configuration +log4j.rootLogger=INFO, stdout +# Console output... +log4j.appender.stdout=org.apache.log4j.ConsoleAppender +log4j.appender.stdout.layout=org.apache.log4j.PatternLayout +log4j.appender.stdout.layout.ConversionPattern=%5p [%t] - %m%n diff --git a/nebula-spark-connector_3.0/src/test/resources/vertex.csv b/nebula-spark-connector_3.0/src/test/resources/vertex.csv new file mode 100644 index 00000000..2b74dfa0 --- /dev/null +++ b/nebula-spark-connector_3.0/src/test/resources/vertex.csv @@ -0,0 +1,14 @@ +id,col1,col2,col3,col4,col5,col6,col7,col8,col9,col10,col11,col12,col13,col14,col15 +1,Tom,tom,10,20,30,40,2021-01-27,2021-01-01T12:10:10,43535232,true,1.0,2.0,10:10:10,POINT(1 2),"duration({years:1,months:1,seconds:1})" +2,Jina,Jina,11,21,31,41,2021-01-28,2021-01-02T12:10:10,43535232,false,1.1,2.1,11:10:10,POINT(3 4),"duration({years:1,months:1,seconds:1})" +3,Tim,Tim,12,22,32,42,2021-01-29,2021-01-03T12:10:10,43535232,false,1.2,2.2,12:10:10,POINT(5 6),"duration({years:1,months:1,seconds:1})" +4,张三,张三,13,23,33,43,2021-01-30,2021-01-04T12:10:10,43535232,true,1.3,2.3,13:10:10,POINT(6 7),"duration({years:1,months:1,seconds:1})" +5,李四,李四,14,24,34,44,2021-02-01,2021-01-05T12:10:10,43535232,false,1.4,2.4,14:10:10,POINT(1 5),"duration({years:1,months:1,seconds:1})" +6,王五,王五,15,25,35,45,2021-02-02,2021-01-06T12:10:10,0,false,1.5,2.5,15:10:10,"LINESTRING(1 3, 4.7 73.23)","duration({years:1,months:1,seconds:1})" +7,Jina,Jina,16,26,36,46,2021-02-03,2021-01-07T12:10:10,43535232,true,1.6,2.6,16:10:10,"LINESTRING(1 3, 4.7 73.23)","duration({years:1,months:1,seconds:1})" +8,Jina,Jina,17,27,37,47,2021-02-04,2021-01-08T12:10:10,43535232,false,1.7,2.7,17:10:10,"LINESTRING(1 3, 4.7 73.23)","duration({years:1,months:1,seconds:1})" +9,Jina,Jina,18,28,38,48,2021-02-05,2021-01-09T12:10:10,43535232,true,1.8,2.8,18:10:10,"LINESTRING(1 3, 4.7 73.23)","duration({years:1,months:1,seconds:1})" +10,Jina,Jina,19,29,39,49,2021-02-06,2021-01-10T12:10:10,43535232,false,1.9,2.9,19:10:10,"LINESTRING(1 3, 4.7 73.23)","duration({years:1,months:1,seconds:1})" +-1,Jina,Jina,20,30,40,50,2021-02-07,2021-02-11T12:10:10,43535232,false,2.0,3.0,20:10:10,"POLYGON((0 1, 1 2, 2 3, 0 
1))","duration({years:1,months:1,seconds:1})" +-2,Jina,Jina,21,31,41,51,2021-02-08,2021-03-12T12:10:10,43535232,false,2.1,3.1,21:10:10,"POLYGON((0 1, 1 2, 2 3, 0 1))","duration({years:1,months:1,seconds:1})" +-3,Jina,Jina,22,32,42,52,2021-02-09,2021-04-13T12:10:10,43535232,false,2.2,3.2,22:10:10,"POLYGON((0 1, 1 2, 2 3, 0 1))","duration({years:1,months:1,seconds:1})" diff --git a/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/mock/NebulaGraphMock.scala b/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/mock/NebulaGraphMock.scala new file mode 100644 index 00000000..b8f0e72e --- /dev/null +++ b/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/mock/NebulaGraphMock.scala @@ -0,0 +1,192 @@ +/* Copyright (c) 2021 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. + */ + +package com.vesoft.nebula.connector.mock + +import com.vesoft.nebula.client.graph.NebulaPoolConfig +import com.vesoft.nebula.client.graph.data.HostAddress +import com.vesoft.nebula.client.graph.net.NebulaPool +import org.apache.log4j.Logger + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ListBuffer + +class NebulaGraphMock { + private[this] val LOG = Logger.getLogger(this.getClass) + + @transient val nebulaPoolConfig = new NebulaPoolConfig + @transient val pool: NebulaPool = new NebulaPool + val address = new ListBuffer[HostAddress]() + address.append(new HostAddress("127.0.0.1", 9669)) + + val randAddr = scala.util.Random.shuffle(address) + pool.init(randAddr.asJava, nebulaPoolConfig) + + def mockStringIdGraph(): Unit = { + val session = pool.getSession("root", "nebula", true) + + val createSpace = "CREATE SPACE IF NOT EXISTS test_string(partition_num=10,vid_type=fixed_string(8));" + + "USE test_string;" + "CREATE TAG IF NOT EXISTS person(col1 string, col2 fixed_string(8), col3 int8, col4 int16, col5 int32, col6 int64, col7 date, col8 datetime, col9 timestamp, col10 bool, col11 double, col12 float, col13 time);" + + "CREATE EDGE IF NOT EXISTS friend(col1 string, col2 fixed_string(8), col3 int8, col4 int16, col5 int32, col6 int64, col7 date, col8 datetime, col9 timestamp, col10 bool, col11 double, col12 float, col13 time);" + + "CREATE TAG IF NOT EXISTS geo_shape(geo geography);" + val createResp = session.execute(createSpace) + if (!createResp.isSucceeded) { + close() + LOG.error("create string type space failed," + createResp.getErrorMessage) + sys.exit(-1) + } + + Thread.sleep(10000) + val insertTag = + "INSERT VERTEX person(col1, col2, col3, col4, col5, col6, col7, col8, col9, col10, col11, col12, col13) VALUES " + + " \"1\":(\"person1\", \"person1\", 11, 200, 1000, 188888, date(\"2021-01-01\"), datetime(\"2021-01-01T12:00:00\"),timestamp(\"2021-01-01T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," + + " \"2\":(\"person2\", \"person2\", 12, 300, 2000, 288888, date(\"2021-01-02\"), datetime(\"2021-01-02T12:00:00\"),timestamp(\"2021-01-02T12:00:00\"), false, 1.0, 2.0, time(\"12:01:01\"))," + + " \"3\":(\"person3\", \"person3\", 13, 400, 3000, 388888, date(\"2021-01-03\"), datetime(\"2021-01-03T12:00:00\"),timestamp(\"2021-01-03T12:00:00\"), false, 1.0, 2.0, time(\"12:01:01\"))," + + " \"4\":(\"person4\", \"person4\", 14, 500, 4000, 488888, date(\"2021-01-04\"), datetime(\"2021-01-04T12:00:00\"),timestamp(\"2021-01-04T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," + + " \"5\":(\"person5\", \"person5\", 15, 600, 5000, 588888, date(\"2021-01-05\"), 
datetime(\"2021-01-05T12:00:00\"),timestamp(\"2021-01-05T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," + + " \"6\":(\"person6\", \"person6\", 16, 700, 6000, 688888, date(\"2021-01-06\"), datetime(\"2021-01-06T12:00:00\"),timestamp(\"2021-01-06T12:00:00\"), false, 1.0, 2.0, time(\"12:01:01\"))," + + " \"7\":(\"person7\", \"person7\", 17, 800, 7000, 788888, date(\"2021-01-07\"), datetime(\"2021-01-07T12:00:00\"),timestamp(\"2021-01-07T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," + + " \"8\":(\"person8\", \"person8\", 18, 900, 8000, 888888, date(\"2021-01-08\"), datetime(\"2021-01-08T12:00:00\"),timestamp(\"2021-01-08T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," + + " \"9\":(\"person9\", \"person9\", 19, 1000, 9000, 988888, date(\"2021-01-09\"), datetime(\"2021-01-09T12:00:00\"),timestamp(\"2021-01-09T12:00:00\"), false, 1.0, 2.0, time(\"12:01:01\"))," + + " \"10\":(\"person10\", \"person10\", 20, 1100, 10000, 1088888, date(\"2021-01-10\"), datetime(\"2021-01-10T12:00:00\"),timestamp(\"2021-01-10T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," + + " \"11\":(\"person11\", \"person11\", 21, 1200, 11000, 1188888, date(\"2021-01-11\"), datetime(\"2021-01-11T12:00:00\"),timestamp(\"2021-01-11T12:00:00\"), false, 1.0, 2.0, time(\"12:01:01\"))," + + " \"12\":(\"person12\", \"person11\", 22, 1300, 12000, 1288888, date(\"2021-01-12\"), datetime(\"2021-01-12T12:00:00\"),timestamp(\"2021-01-12T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," + + " \"-1\":(\"person00\", \"person00\", 23, 1400, 13000, 1388888, date(\"2021-01-13\"), datetime(\"2021-01-12T12:00:00\"),timestamp(\"2021-01-12T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," + + " \"-2\":(\"person01\", \"person01\", 24, 1500, 14000, 1488888, date(\"2021-01-14\"), datetime(\"2021-01-12T12:00:00\"),timestamp(\"2021-01-12T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," + + " \"-3\":(\"person02\", \"person02\", 24, 1500, 14000, 1488888, date(\"2021-01-14\"), datetime(\"2021-01-12T12:00:00\"),timestamp(\"2021-01-12T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," + + " \"19\":(\"person19\", \"person22\", 25, 1500, 14000, 1488888, date(\"2021-01-14\"), datetime(\"2021-01-12T12:00:00\"),timestamp(\"2021-01-12T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," + + " \"22\":(\"person22\", \"person22\", 26, 1500, 14000, 1488888, date(\"2021-01-14\"), datetime(\"2021-01-12T12:00:00\"),timestamp(\"2021-01-12T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"));" + + "INSERT VERTEX geo_shape(geo) VALUES \"100\":(ST_GeogFromText(\"POINT(1 2)\")), \"101\":(ST_GeogFromText(\"LINESTRING(1 2, 3 4)\")), \"102\":(ST_GeogFromText(\"POLYGON((0 1, 1 2, 2 3, 0 1))\"))" + val insertTagResp = session.execute(insertTag) + if (!insertTagResp.isSucceeded) { + close() + LOG.error("insert vertex for string type space failed," + insertTagResp.getErrorMessage) + sys.exit(-1) + } + + val insertEdge = "INSERT EDGE friend(col1, col2, col3, col4, col5, col6, col7, col8, col9, col10, col11, col12, col13) VALUES " + + " \"1\" -> \"2\":(\"friend1\", \"friend2\", 11, 200, 1000, 188888, date(\"2021-01-01\"), datetime(\"2021-01-01T12:00:00\"),timestamp(\"2021-01-01T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," + + " \"2\" -> \"3\":(\"friend2\", \"friend3\", 12, 300, 2000, 288888, date(\"2021-01-02\"), datetime(\"2021-01-02T12:00:00\"),timestamp(\"2021-01-02T12:00:00\"), false, 1.0, 2.0, time(\"12:01:01\"))," + + " \"3\" -> \"4\":(\"friend3\", \"friend4\", 13, 400, 3000, 388888, date(\"2021-01-03\"), 
datetime(\"2021-01-03T12:00:00\"),timestamp(\"2021-01-03T12:00:00\"), false, 1.0, 2.0, time(\"12:01:01\"))," + + " \"4\" -> \"5\":(\"friend4\", \"friend4\", 14, 500, 4000, 488888, date(\"2021-01-04\"), datetime(\"2021-01-04T12:00:00\"),timestamp(\"2021-01-04T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," + + " \"5\" -> \"6\":(\"friend5\", \"friend5\", 15, 600, 5000, 588888, date(\"2021-01-05\"), datetime(\"2021-01-05T12:00:00\"),timestamp(\"2021-01-05T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," + + " \"6\" -> \"7\":(\"friend6\", \"friend6\", 16, 700, 6000, 688888, date(\"2021-01-06\"), datetime(\"2021-01-06T12:00:00\"),timestamp(\"2021-01-06T12:00:00\"), false, 1.0, 2.0, time(\"12:01:01\"))," + + " \"7\" -> \"8\":(\"friend7\", \"friend7\", 17, 800, 7000, 788888, date(\"2021-01-07\"), datetime(\"2021-01-07T12:00:00\"),timestamp(\"2021-01-07T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," + + " \"8\" -> \"9\":(\"friend8\", \"friend8\", 18, 900, 8000, 888888, date(\"2021-01-08\"), datetime(\"2021-01-08T12:00:00\"),timestamp(\"2021-01-08T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," + + " \"9\" -> \"10\":(\"friend9\", \"friend9\", 19, 1000, 9000, 988888, date(\"2021-01-09\"), datetime(\"2021-01-09T12:00:00\"),timestamp(\"2021-01-09T12:00:00\"), false, 1.0, 2.0, time(\"12:01:01\"))," + + " \"10\" -> \"11\":(\"friend10\", \"friend10\", 20, 1100, 10000, 1088888, date(\"2021-01-10\"), datetime(\"2021-01-10T12:00:00\"),timestamp(\"2021-01-10T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," + + " \"11\" -> \"12\":(\"friend11\", \"friend11\", 21, 1200, 11000, 1188888, date(\"2021-01-11\"), datetime(\"2021-01-11T12:00:00\"),timestamp(\"2021-01-11T12:00:00\"), false, 1.0, 2.0, time(\"12:01:01\"))," + + " \"12\" -> \"1\":(\"friend12\", \"friend11\", 22, 1300, 12000, 1288888, date(\"2021-01-12\"), datetime(\"2021-01-12T12:00:00\"),timestamp(\"2021-01-12T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," + + " \"-1\" -> \"11\":(\"friend13\", \"friend12\", 22, 1300, 12000, 1288888, date(\"2021-01-12\"), datetime(\"2021-01-12T12:00:00\"),timestamp(\"2021-01-12T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," + + " \"-2\" -> \"-1\":(\"friend14\", \"friend13\", 22, 1300, 12000, 1288888, date(\"2021-01-12\"), datetime(\"2021-01-12T12:00:00\"),timestamp(\"2021-01-12T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))" + val insertEdgeResp = session.execute(insertEdge) + if (!insertEdgeResp.isSucceeded) { + close() + LOG.error("insert edge for string type space failed," + insertEdgeResp.getErrorMessage) + sys.exit(-1) + } + } + + def mockIntIdGraph(): Unit = { + val session = pool.getSession("root", "nebula", true) + + val createSpace = "CREATE SPACE IF NOT EXISTS test_int(partition_num=10, vid_type=int64);" + + "USE test_int;" + "CREATE TAG IF NOT EXISTS person(col1 string, col2 fixed_string(8), col3 int8, col4 int16, col5 int32, col6 int64, col7 date, col8 datetime, col9 timestamp, col10 bool, col11 double, col12 float, col13 time);" + + "CREATE EDGE IF NOT EXISTS friend(col1 string, col2 fixed_string(8), col3 int8, col4 int16, col5 int32, col6 int64, col7 date, col8 datetime, col9 timestamp, col10 bool, col11 double, col12 float, col13 time);" + + "CREATE TAG IF NOT EXISTS geo_shape(geo geography);" + + "CREATE TAG IF NOT EXISTS tag_duration(col duration);" + val createResp = session.execute(createSpace) + if (!createResp.isSucceeded) { + close() + LOG.error("create int type space failed," + createResp.getErrorMessage) + sys.exit(-1) + } + + Thread.sleep(10000) + val insertTag = + "INSERT 
VERTEX person(col1, col2, col3, col4, col5, col6, col7, col8, col9, col10, col11, col12, col13) VALUES " + + " 1:(\"person1\", \"person1\", 11, 200, 1000, 188888, date(\"2021-01-01\"), datetime(\"2021-01-01T12:00:00\"),timestamp(\"2021-01-01T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," + + " 2:(\"person2\", \"person2\", 12, 300, 2000, 288888, date(\"2021-01-02\"), datetime(\"2021-01-02T12:00:00\"),timestamp(\"2021-01-02T12:00:00\"), false, 1.0, 2.0, time(\"12:01:01\"))," + + " 3:(\"person3\", \"person3\", 13, 400, 3000, 388888, date(\"2021-01-03\"), datetime(\"2021-01-03T12:00:00\"),timestamp(\"2021-01-03T12:00:00\"), false, 1.0, 2.0, time(\"12:01:01\"))," + + " 4:(\"person4\", \"person4\", 14, 500, 4000, 488888, date(\"2021-01-04\"), datetime(\"2021-01-04T12:00:00\"),timestamp(\"2021-01-04T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," + + " 5:(\"person5\", \"person5\", 15, 600, 5000, 588888, date(\"2021-01-05\"), datetime(\"2021-01-05T12:00:00\"),timestamp(\"2021-01-05T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," + + " 6:(\"person6\", \"person6\", 16, 700, 6000, 688888, date(\"2021-01-06\"), datetime(\"2021-01-06T12:00:00\"),timestamp(\"2021-01-06T12:00:00\"), false, 1.0, 2.0, time(\"12:01:01\"))," + + " 7:(\"person7\", \"person7\", 17, 800, 7000, 788888, date(\"2021-01-07\"), datetime(\"2021-01-07T12:00:00\"),timestamp(\"2021-01-07T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," + + " 8:(\"person8\", \"person8\", 18, 900, 8000, 888888, date(\"2021-01-08\"), datetime(\"2021-01-08T12:00:00\"),timestamp(\"2021-01-08T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," + + " 9:(\"person9\", \"person9\", 19, 1000, 9000, 988888, date(\"2021-01-09\"), datetime(\"2021-01-09T12:00:00\"),timestamp(\"2021-01-09T12:00:00\"), false, 1.0, 2.0, time(\"12:01:01\"))," + + " 10:(\"person10\", \"person10\", 20, 1100, 10000, 1088888, date(\"2021-01-10\"), datetime(\"2021-01-10T12:00:00\"),timestamp(\"2021-01-10T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," + + " 11:(\"person11\", \"person11\", 21, 1200, 11000, 1188888, date(\"2021-01-11\"), datetime(\"2021-01-11T12:00:00\"),timestamp(\"2021-01-11T12:00:00\"), false, 1.0, 2.0, time(\"12:01:01\"))," + + " 12:(\"person12\", \"person11\", 22, 1300, 12000, 1288888, date(\"2021-01-12\"), datetime(\"2021-01-12T12:00:00\"),timestamp(\"2021-01-12T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," + + " -1:(\"person00\", \"person00\", 23, 1400, 13000, 1388888, date(\"2021-01-13\"), datetime(\"2021-01-12T12:00:00\"),timestamp(\"2021-01-12T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," + + " -2:(\"person01\", \"person01\", 24, 1500, 14000, 1488888, date(\"2021-01-14\"), datetime(\"2021-01-12T12:00:00\"),timestamp(\"2021-01-12T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," + + " -3:(\"person02\", \"person02\", 24, 1500, 14000, 1488888, date(\"2021-01-14\"), datetime(\"2021-01-12T12:00:00\"),timestamp(\"2021-01-12T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," + + " 19:(\"person19\", \"person22\", 25, 1500, 14000, 1488888, date(\"2021-01-14\"), datetime(\"2021-01-12T12:00:00\"),timestamp(\"2021-01-12T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," + + " 22:(\"person22\", \"person22\", 26, 1500, 14000, 1488888, date(\"2021-01-14\"), datetime(\"2021-01-12T12:00:00\"),timestamp(\"2021-01-12T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\")), " + + " 0:(null, null, null, null, null, null, null, null, null, null, null, null, null);" + + "INSERT VERTEX geo_shape(geo) VALUES 100:(ST_GeogFromText(\"POINT(1 2)\")), 
101:(ST_GeogFromText(\"LINESTRING(1 2, 3 4)\")), 102:(ST_GeogFromText(\"POLYGON((0 1, 1 2, 2 3, 0 1))\"));" + + "INSERT VERTEX tag_duration(col) VALUES 200:(duration({months:1, seconds:100, microseconds:20}))" + + val insertTagResp = session.execute(insertTag) + if (!insertTagResp.isSucceeded) { + close() + LOG.error("insert vertex for int type space failed," + insertTagResp.getErrorMessage) + sys.exit(-1) + } + + val insertEdge = "INSERT EDGE friend(col1, col2, col3, col4, col5, col6, col7, col8, col9, col10, col11, col12, col13) VALUES " + + " 1 -> 2:(\"friend1\", \"friend2\", 11, 200, 1000, 188888, date(\"2021-01-01\"), datetime(\"2021-01-01T12:00:00\"),timestamp(\"2021-01-01T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," + + " 2 -> 3:(\"friend2\", \"friend3\", 12, 300, 2000, 288888, date(\"2021-01-02\"), datetime(\"2021-01-02T12:00:00\"),timestamp(\"2021-01-02T12:00:00\"), false, 1.0, 2.0, time(\"12:01:01\"))," + + " 3 -> 4:(\"friend3\", \"friend4\", 13, 400, 3000, 388888, date(\"2021-01-03\"), datetime(\"2021-01-03T12:00:00\"),timestamp(\"2021-01-03T12:00:00\"), false, 1.0, 2.0, time(\"12:01:01\"))," + + " 4 -> 5:(\"friend4\", \"friend4\", 14, 500, 4000, 488888, date(\"2021-01-04\"), datetime(\"2021-01-04T12:00:00\"),timestamp(\"2021-01-04T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," + + " 5 -> 6:(\"friend5\", \"friend5\", 15, 600, 5000, 588888, date(\"2021-01-05\"), datetime(\"2021-01-05T12:00:00\"),timestamp(\"2021-01-05T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," + + " 6 -> 7:(\"friend6\", \"friend6\", 16, 700, 6000, 688888, date(\"2021-01-06\"), datetime(\"2021-01-06T12:00:00\"),timestamp(\"2021-01-06T12:00:00\"), false, 1.0, 2.0, time(\"12:01:01\"))," + + " 7 -> 8:(\"friend7\", \"friend7\", 17, 800, 7000, 788888, date(\"2021-01-07\"), datetime(\"2021-01-07T12:00:00\"),timestamp(\"2021-01-07T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," + + " 8 -> 9:(\"friend8\", \"friend8\", 18, 900, 8000, 888888, date(\"2021-01-08\"), datetime(\"2021-01-08T12:00:00\"),timestamp(\"2021-01-08T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," + + " 9 -> 10:(\"friend9\", \"friend9\", 19, 1000, 9000, 988888, date(\"2021-01-09\"), datetime(\"2021-01-09T12:00:00\"),timestamp(\"2021-01-09T12:00:00\"), false, 1.0, 2.0, time(\"12:01:01\"))," + + " 10 -> 11:(\"friend10\", \"friend10\", 20, 1100, 10000, 1088888, date(\"2021-01-10\"), datetime(\"2021-01-10T12:00:00\"),timestamp(\"2021-01-10T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))," + + " 11 -> 12:(\"friend11\", \"friend11\", 21, 1200, 11000, 1188888, date(\"2021-01-11\"), datetime(\"2021-01-11T12:00:00\"),timestamp(\"2021-01-11T12:00:00\"), false, 1.0, 2.0, time(\"12:01:01\"))," + + " 12 -> 1:(\"friend12\", \"friend11\", 22, 1300, 12000, 1288888, date(\"2021-01-12\"), datetime(\"2021-01-12T12:00:00\"),timestamp(\"2021-01-12T12:00:00\"), true, 1.0, 2.0, time(\"12:01:01\"))" + val insertEdgeResp = session.execute(insertEdge) + if (!insertEdgeResp.isSucceeded) { + close() + LOG.error("insert edge for int type space failed," + insertEdgeResp.getErrorMessage) + sys.exit(-1) + } + } + + def mockStringIdGraphSchema(): Unit = { + val session = pool.getSession("root", "nebula", true) + + val createSpace = "CREATE SPACE IF NOT EXISTS test_write_string(partition_num=10,vid_type=fixed_string(8));" + + "USE test_write_string;" + + "CREATE TAG IF NOT EXISTS person_connector(col1 string, col2 fixed_string(8), col3 int8, col4 int16, col5 int32, col6 int64, col7 date, col8 datetime, col9 timestamp, col10 bool, col11 double, col12 float, col13 
time, col14 geography, col15 duration);" + + "CREATE EDGE IF NOT EXISTS friend_connector(col1 string, col2 fixed_string(8), col3 int8, col4 int16, col5 int32, col6 int64, col7 date, col8 datetime, col9 timestamp, col10 bool, col11 double, col12 float, col13 time, col14 geography);"; + val createResp = session.execute(createSpace) + if (!createResp.isSucceeded) { + close() + LOG.error("create string type space failed," + createResp.getErrorMessage) + sys.exit(-1) + } + } + + def mockIntIdGraphSchema(): Unit = { + val session = pool.getSession("root", "nebula", true) + + val createSpace = "CREATE SPACE IF NOT EXISTS test_write_int(partition_num=10, vid_type=int64);" + + "USE test_write_int;" + + "CREATE TAG IF NOT EXISTS person_connector(col1 string, col2 fixed_string(8), col3 int8, col4 int16, col5 int32, col6 int64, col7 date, col8 datetime, col9 timestamp, col10 bool, col11 double, col12 float, col13 time, col14 geography, col15 duration);" + + "CREATE EDGE IF NOT EXISTS friend_connector(col1 string, col2 fixed_string(8), col3 int8, col4 int16, col5 int32, col6 int64, col7 date, col8 datetime, col9 timestamp, col10 bool, col11 double, col12 float, col13 time, col14 geography);"; + val createResp = session.execute(createSpace) + if (!createResp.isSucceeded) { + close() + LOG.error("create int type space failed," + createResp.getErrorMessage) + sys.exit(-1) + } + } + + def close(): Unit = { + pool.close() + } +} diff --git a/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/mock/SparkMock.scala b/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/mock/SparkMock.scala new file mode 100644 index 00000000..a8eec279 --- /dev/null +++ b/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/mock/SparkMock.scala @@ -0,0 +1,219 @@ +/* Copyright (c) 2020 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. 
+ */ + +package com.vesoft.nebula.connector.mock + +import com.facebook.thrift.protocol.TCompactProtocol +import com.vesoft.nebula.connector.{ + NebulaConnectionConfig, + WriteMode, + WriteNebulaEdgeConfig, + WriteNebulaVertexConfig +} +import com.vesoft.nebula.connector.connector.NebulaDataFrameWriter +import org.apache.spark.SparkConf +import org.apache.spark.sql.SparkSession + +object SparkMock { + + /** + * write nebula vertex with insert mode + */ + def writeVertex(): Unit = { + val sparkConf = new SparkConf + sparkConf + .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + .registerKryoClasses(Array[Class[_]](classOf[TCompactProtocol])) + val spark = SparkSession + .builder() + .master("local") + .config(sparkConf) + .getOrCreate() + + val df = spark.read + .option("header", true) + .csv("src/test/resources/vertex.csv") + + val config = + NebulaConnectionConfig + .builder() + .withMetaAddress("127.0.0.1:9559") + .withGraphAddress("127.0.0.1:9669") + .withConenctionRetry(2) + .build() + val nebulaWriteVertexConfig: WriteNebulaVertexConfig = WriteNebulaVertexConfig + .builder() + .withSpace("test_write_string") + .withTag("person_connector") + .withVidField("id") + .withVidAsProp(false) + .withBatch(5) + .build() + df.write.nebula(config, nebulaWriteVertexConfig).writeVertices() + + spark.stop() + } + + /** + * write nebula vertex with delete mode + */ + def deleteVertex(): Unit = { + val sparkConf = new SparkConf + sparkConf + .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + .registerKryoClasses(Array[Class[_]](classOf[TCompactProtocol])) + val spark = SparkSession + .builder() + .master("local") + .config(sparkConf) + .getOrCreate() + + val df = spark.read + .option("header", true) + .csv("src/test/resources/vertex.csv") + + val config = + NebulaConnectionConfig + .builder() + .withMetaAddress("127.0.0.1:9559") + .withGraphAddress("127.0.0.1:9669") + .withConenctionRetry(2) + .build() + val nebulaWriteVertexConfig: WriteNebulaVertexConfig = WriteNebulaVertexConfig + .builder() + .withSpace("test_write_string") + .withTag("person_connector") + .withVidField("id") + .withVidAsProp(false) + .withWriteMode(WriteMode.DELETE) + .withBatch(5) + .build() + df.write.nebula(config, nebulaWriteVertexConfig).writeVertices() + + spark.stop() + } + + /** + * write nebula edge with insert mode + */ + def writeEdge(): Unit = { + val sparkConf = new SparkConf + sparkConf + .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + .registerKryoClasses(Array[Class[_]](classOf[TCompactProtocol])) + val spark = SparkSession + .builder() + .master("local") + .config(sparkConf) + .getOrCreate() + + val df = spark.read + .option("header", true) + .csv("src/test/resources/edge.csv") + + val config = + NebulaConnectionConfig + .builder() + .withMetaAddress("127.0.0.1:9559") + .withGraphAddress("127.0.0.1:9669") + .withConenctionRetry(2) + .build() + val nebulaWriteEdgeConfig: WriteNebulaEdgeConfig = WriteNebulaEdgeConfig + .builder() + .withSpace("test_write_string") + .withEdge("friend_connector") + .withSrcIdField("id1") + .withDstIdField("id2") + .withRankField("col3") + .withRankAsProperty(true) + .withBatch(5) + .build() + df.write.nebula(config, nebulaWriteEdgeConfig).writeEdges() + + spark.stop() + } + + /** + * write nebula edge with delete mode + */ + def deleteEdge(): Unit = { + val sparkConf = new SparkConf + sparkConf + .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + 
.registerKryoClasses(Array[Class[_]](classOf[TCompactProtocol])) + val spark = SparkSession + .builder() + .master("local") + .config(sparkConf) + .getOrCreate() + + val df = spark.read + .option("header", true) + .csv("src/test/resources/edge.csv") + + val config = + NebulaConnectionConfig + .builder() + .withMetaAddress("127.0.0.1:9559") + .withGraphAddress("127.0.0.1:9669") + .withConenctionRetry(2) + .build() + val nebulaWriteEdgeConfig: WriteNebulaEdgeConfig = WriteNebulaEdgeConfig + .builder() + .withSpace("test_write_string") + .withEdge("friend_connector") + .withSrcIdField("id1") + .withDstIdField("id2") + .withRankField("col3") + .withRankAsProperty(true) + .withWriteMode(WriteMode.DELETE) + .withBatch(5) + .build() + df.write.nebula(config, nebulaWriteEdgeConfig).writeEdges() + + spark.stop() + } + + /** + * write nebula vertex with delete_with_edge mode + */ + def deleteVertexWithEdge(): Unit = { + val sparkConf = new SparkConf + sparkConf + .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + .registerKryoClasses(Array[Class[_]](classOf[TCompactProtocol])) + val spark = SparkSession + .builder() + .master("local") + .config(sparkConf) + .getOrCreate() + + val df = spark.read + .option("header", true) + .csv("src/test/resources/vertex.csv") + + val config = + NebulaConnectionConfig + .builder() + .withMetaAddress("127.0.0.1:9559") + .withGraphAddress("127.0.0.1:9669") + .withConenctionRetry(2) + .build() + val nebulaWriteVertexConfig: WriteNebulaVertexConfig = WriteNebulaVertexConfig + .builder() + .withSpace("test_write_string") + .withTag("person_connector") + .withVidField("id") + .withVidAsProp(false) + .withWriteMode(WriteMode.DELETE) + .withDeleteEdge(true) + .withBatch(5) + .build() + df.write.nebula(config, nebulaWriteVertexConfig).writeVertices() + + spark.stop() + } + +} diff --git a/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/reader/ReadSuite.scala b/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/reader/ReadSuite.scala new file mode 100644 index 00000000..39ee72f5 --- /dev/null +++ b/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/reader/ReadSuite.scala @@ -0,0 +1,340 @@ +/* Copyright (c) 2021 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. 
+ */ + +package com.vesoft.nebula.connector.reader + +import com.facebook.thrift.protocol.TCompactProtocol +import com.vesoft.nebula.connector.connector.NebulaDataFrameReader +import com.vesoft.nebula.connector.{NebulaConnectionConfig, ReadNebulaConfig} +import com.vesoft.nebula.connector.mock.NebulaGraphMock +import org.apache.log4j.BasicConfigurator +import org.apache.spark.SparkConf +import org.apache.spark.sql.{Encoders, SparkSession} +import org.scalatest.BeforeAndAfterAll +import org.scalatest.funsuite.AnyFunSuite + +class ReadSuite extends AnyFunSuite with BeforeAndAfterAll { + BasicConfigurator.configure() + var sparkSession: SparkSession = null + + override def beforeAll(): Unit = { + val graphMock = new NebulaGraphMock + graphMock.mockIntIdGraph() + graphMock.close() + val sparkConf = new SparkConf + sparkConf + .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + .registerKryoClasses(Array[Class[_]](classOf[TCompactProtocol])) + sparkSession = SparkSession + .builder() + .master("local") + .config(sparkConf) + .getOrCreate() + } + + override def afterAll(): Unit = { + sparkSession.stop() + } + + test("read vertex with no properties") { + val config = + NebulaConnectionConfig + .builder() + .withMetaAddress("127.0.0.1:9559") + .withConenctionRetry(2) + .build() + val nebulaReadVertexConfig: ReadNebulaConfig = ReadNebulaConfig + .builder() + .withSpace("test_int") + .withLabel("person") + .withNoColumn(true) + .withLimit(10) + .withPartitionNum(10) + .build() + val vertex = sparkSession.read.nebula(config, nebulaReadVertexConfig).loadVerticesToDF() + vertex.printSchema() + vertex.show() + assert(vertex.count() == 18) + assert(vertex.schema.fields.length == 1) + } + + test("read vertex with specific properties") { + val config = + NebulaConnectionConfig + .builder() + .withMetaAddress("127.0.0.1:9559") + .withConenctionRetry(2) + .build() + val nebulaReadVertexConfig: ReadNebulaConfig = ReadNebulaConfig + .builder() + .withSpace("test_int") + .withLabel("person") + .withNoColumn(false) + .withReturnCols(List("col1")) + .withLimit(10) + .withPartitionNum(10) + .build() + val vertex = sparkSession.read.nebula(config, nebulaReadVertexConfig).loadVerticesToDF() + vertex.printSchema() + vertex.show() + assert(vertex.count() == 18) + assert(vertex.schema.fields.length == 2) + + vertex.map(row => { + row.getAs[Long]("_vertexId") match { + case 1L => { + assert(row.getAs[String]("col1").equals("person1")) + } + case 0L => { + assert(row.isNullAt(1)) + } + } + "" + })(Encoders.STRING) + + } + + test("read vertex with all properties") { + val config = + NebulaConnectionConfig + .builder() + .withMetaAddress("127.0.0.1:9559") + .withConenctionRetry(2) + .build() + val nebulaReadVertexConfig: ReadNebulaConfig = ReadNebulaConfig + .builder() + .withSpace("test_int") + .withLabel("person") + .withNoColumn(false) + .withReturnCols(List()) + .withLimit(10) + .withPartitionNum(10) + .build() + val vertex = sparkSession.read.nebula(config, nebulaReadVertexConfig).loadVerticesToDF() + vertex.printSchema() + vertex.show() + assert(vertex.count() == 18) + assert(vertex.schema.fields.length == 14) + + vertex.map(row => { + row.getAs[Long]("_vertexId") match { + case 1L => { + assert(row.getAs[String]("col1").equals("person1")) + assert(row.getAs[String]("col2").equals("person1")) + assert(row.getAs[Long]("col3") == 11) + assert(row.getAs[Long]("col4") == 200) + assert(row.getAs[Long]("col5") == 1000) + assert(row.getAs[Long]("col6") == 188888) + 
assert(row.getAs[String]("col7").equals("2021-01-01")) + assert(row.getAs[String]("col8").equals("2021-01-01T12:00:00.000")) + assert(row.getAs[Long]("col9") == 1609502400) + assert(row.getAs[Boolean]("col10")) + assert(row.getAs[Double]("col11") < 1.001) + assert(row.getAs[Double]("col12") < 2.001) + assert(row.getAs[String]("col13").equals("12:01:01")) + } + case 0L => { + assert(row.isNullAt(1)) + assert(row.isNullAt(2)) + assert(row.isNullAt(3)) + assert(row.isNullAt(4)) + assert(row.isNullAt(5)) + assert(row.isNullAt(6)) + assert(row.isNullAt(7)) + assert(row.isNullAt(8)) + assert(row.isNullAt(9)) + assert(row.isNullAt(10)) + assert(row.isNullAt(11)) + assert(row.isNullAt(12)) + assert(row.isNullAt(13)) + } + } + "" + })(Encoders.STRING) + } + + test("read vertex for geo_shape") { + val config = + NebulaConnectionConfig + .builder() + .withMetaAddress("127.0.0.1:9559") + .withConenctionRetry(2) + .build() + val nebulaReadVertexConfig: ReadNebulaConfig = ReadNebulaConfig + .builder() + .withSpace("test_int") + .withLabel("geo_shape") + .withNoColumn(false) + .withReturnCols(List("geo")) + .withLimit(10) + .withPartitionNum(10) + .build() + val vertex = sparkSession.read.nebula(config, nebulaReadVertexConfig).loadVerticesToDF() + vertex.printSchema() + vertex.show() + assert(vertex.count() == 3) + assert(vertex.schema.fields.length == 2) + + vertex.map(row => { + row.getAs[Long]("_vertexId") match { + case 100L => { + assert(row.getAs[String]("geo").equals("POINT(1 2)")) + } + case 101L => { + assert(row.getAs[String]("geo").equals("LINESTRING(1 2, 3 4)")) + } + case 102L => { + assert(row.getAs[String]("geo").equals("POLYGON((0 1, 1 2, 2 3, 0 1))")) + } + } + "" + })(Encoders.STRING) + } + + test("read vertex for tag_duration") { + val config = + NebulaConnectionConfig + .builder() + .withMetaAddress("127.0.0.1:9559") + .withConenctionRetry(2) + .build() + val nebulaReadVertexConfig: ReadNebulaConfig = ReadNebulaConfig + .builder() + .withSpace("test_int") + .withLabel("tag_duration") + .withNoColumn(false) + .withReturnCols(List("col")) + .withLimit(10) + .withPartitionNum(10) + .build() + val vertex = sparkSession.read.nebula(config, nebulaReadVertexConfig).loadVerticesToDF() + vertex.printSchema() + vertex.show() + assert(vertex.count() == 1) + assert(vertex.schema.fields.length == 2) + + vertex.map(row => { + row.getAs[Long]("_vertexId") match { + case 200L => { + assert( + row.getAs[String]("col").equals("duration({months:1, seconds:100, microseconds:20})")) + } + } + "" + })(Encoders.STRING) + } + + test("read edge with no properties") { + val config = + NebulaConnectionConfig + .builder() + .withMetaAddress("127.0.0.1:9559") + .withConenctionRetry(2) + .build() + val nebulaReadEdgeConfig: ReadNebulaConfig = ReadNebulaConfig + .builder() + .withSpace("test_int") + .withLabel("friend") + .withNoColumn(true) + .withLimit(10) + .withPartitionNum(10) + .build() + val edge = sparkSession.read.nebula(config, nebulaReadEdgeConfig).loadEdgesToDF() + edge.printSchema() + edge.show() + assert(edge.count() == 12) + assert(edge.schema.fields.length == 3) + + edge.map(row => { + row.getAs[Long]("_srcId") match { + case 1L => { + assert(row.getAs[Long]("_dstId") == 2) + assert(row.getAs[Long]("_rank") == 0) + } + } + "" + })(Encoders.STRING) + } + + test("read edge with specific properties") { + val config = + NebulaConnectionConfig + .builder() + .withMetaAddress("127.0.0.1:9559") + .withConenctionRetry(2) + .build() + val nebulaReadEdgeConfig: ReadNebulaConfig = ReadNebulaConfig + .builder() 
+ .withSpace("test_int") + .withLabel("friend") + .withNoColumn(false) + .withReturnCols(List("col1")) + .withLimit(10) + .withPartitionNum(10) + .build() + val edge = sparkSession.read.nebula(config, nebulaReadEdgeConfig).loadEdgesToDF() + edge.printSchema() + edge.show(20) + assert(edge.count() == 12) + assert(edge.schema.fields.length == 4) + edge.map(row => { + row.getAs[Long]("_srcId") match { + case 1L => { + assert(row.getAs[Long]("_dstId") == 2) + assert(row.getAs[Long]("_rank") == 0) + assert(row.getAs[String]("col1").equals("friend1")) + } + } + "" + })(Encoders.STRING) + } + + test("read edge with all properties") { + val config = + NebulaConnectionConfig + .builder() + .withMetaAddress("127.0.0.1:9559") + .withConenctionRetry(2) + .build() + val nebulaReadVertexConfig: ReadNebulaConfig = ReadNebulaConfig + .builder() + .withSpace("test_int") + .withLabel("friend") + .withNoColumn(false) + .withReturnCols(List()) + .withLimit(10) + .withPartitionNum(10) + .build() + val edge = sparkSession.read.nebula(config, nebulaReadVertexConfig).loadEdgesToDF() + edge.printSchema() + edge.show() + assert(edge.count() == 12) + assert(edge.schema.fields.length == 16) + + edge.map(row => { + row.getAs[Long]("_srcId") match { + case 1L => { + assert(row.getAs[Long]("_dstId") == 2) + assert(row.getAs[Long]("_rank") == 0) + assert(row.getAs[String]("col1").equals("friend1")) + assert(row.getAs[String]("col2").equals("friend2")) + assert(row.getAs[Long]("col3") == 11) + assert(row.getAs[Long]("col4") == 200) + assert(row.getAs[Long]("col5") == 1000) + assert(row.getAs[Long]("col6") == 188888) + assert(row.getAs[String]("col7").equals("2021-01-01")) + assert(row.getAs[String]("col8").equals("2021-01-01T12:00:00.000")) + assert(row.getAs[Long]("col9") == 1609502400) + assert(row.getAs[Boolean]("col10")) + assert(row.getAs[Double]("col11") < 1.001) + assert(row.getAs[Double]("col12") < 2.001) + assert(row.getAs[String]("col13").equals("12:01:01")) + } + } + "" + })(Encoders.STRING) + } + +} diff --git a/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/writer/NebulaExecutorSuite.scala b/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/writer/NebulaExecutorSuite.scala new file mode 100644 index 00000000..7a95f623 --- /dev/null +++ b/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/writer/NebulaExecutorSuite.scala @@ -0,0 +1,408 @@ +/* Copyright (c) 2020 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. 
+ */ + +package com.vesoft.nebula.connector.writer + +import com.vesoft.nebula.connector.KeyPolicy +import com.vesoft.nebula.connector.{NebulaEdge, NebulaEdges, NebulaVertex, NebulaVertices} +import org.apache.log4j.BasicConfigurator +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.types.{ + BooleanType, + DataTypes, + LongType, + StringType, + StructField, + StructType +} +import org.scalatest.BeforeAndAfterAll +import org.scalatest.funsuite.AnyFunSuite + +import scala.collection.mutable.ListBuffer + +class NebulaExecutorSuite extends AnyFunSuite with BeforeAndAfterAll { + BasicConfigurator.configure() + var schema: StructType = _ + var row: InternalRow = _ + + override def beforeAll(): Unit = { + val fields = new ListBuffer[StructField] + fields.append(DataTypes.createStructField("col1", StringType, false)) + fields.append(DataTypes.createStructField("col2", BooleanType, false)) + fields.append(DataTypes.createStructField("col3", LongType, false)) + schema = new StructType(fields.toArray) + + val values = new ListBuffer[Any] + values.append("aaa") + values.append(true) + values.append(1L) + row = new GenericInternalRow(values.toArray) + } + + override def afterAll(): Unit = super.afterAll() + + test("test extraID") { + val index: Int = 1 + val policy: Option[KeyPolicy.Value] = None + val isVidStringType: Boolean = true + val stringId = NebulaExecutor.extraID(schema, row, index, policy, isVidStringType) + assert("\"true\"".equals(stringId)) + + // test hash vertexId + val hashId = NebulaExecutor.extraID(schema, row, index, Some(KeyPolicy.HASH), false) + assert("true".equals(hashId)) + } + + test("test extraRank") { + // test correct type for rank + assert(NebulaExecutor.extraRank(schema, row, 2) == 1) + + // test wrong type for rank + try { + NebulaExecutor.extraRank(schema, row, 1) + } catch { + case e: java.lang.AssertionError => assert(true) + } + } + + test("test vid as prop for assignVertexPropValues ") { + val fieldTypeMap: Map[String, Integer] = Map("col1" -> 6, "col2" -> 1, "col3" -> 2) + // test vid as prop + val props = NebulaExecutor.assignVertexPropValues(schema, row, 0, true, fieldTypeMap) + assert(props.size == 3) + assert(props.contains("\"aaa\"")) + } + + test("test vid not as prop for assignVertexPropValues ") { + val fieldTypeMap: Map[String, Integer] = Map("col1" -> 6, "col2" -> 1, "col3" -> 2) + // test vid not as prop + val props = NebulaExecutor.assignVertexPropValues(schema, row, 0, false, fieldTypeMap) + assert(props.size == 2) + assert(!props.contains("\"aaa\"")) + } + + test("test src & dst & rank all as prop for assignEdgeValues") { + val fieldTypeMap: Map[String, Integer] = Map("col1" -> 6, "col2" -> 1, "col3" -> 2) + + val prop = NebulaExecutor.assignEdgeValues(schema, row, 0, 1, 2, true, true, true, fieldTypeMap) + assert(prop.size == 3) + } + + test("test src & dst & rank all not as prop for assignEdgeValues") { + val fieldTypeMap: Map[String, Integer] = Map("col1" -> 6, "col2" -> 1, "col3" -> 2) + + val prop = + NebulaExecutor.assignEdgeValues(schema, row, 0, 1, 2, false, false, false, fieldTypeMap) + assert(prop.isEmpty) + } + + test("test assignVertexPropNames") { + // test vid as prop + val propNames = NebulaExecutor.assignVertexPropNames(schema, 0, true) + assert(propNames.size == 3) + assert(propNames.contains("col1") && propNames.contains("col2") && propNames.contains("col3")) + + // test vid not as prop + val propNames1 = 
NebulaExecutor.assignVertexPropNames(schema, 0, false) + assert(propNames1.size == 2) + assert(!propNames1.contains("col1") && propNames.contains("col2") && propNames.contains("col3")) + } + + test("test assignEdgePropNames") { + // test src / dst / rank all as prop + val propNames = NebulaExecutor.assignEdgePropNames(schema, 0, 1, 2, true, true, true) + assert(propNames.size == 3) + + // test src / dst / rank all not as prop + val propNames1 = NebulaExecutor.assignEdgePropNames(schema, 0, 1, 2, false, false, false) + assert(propNames1.isEmpty) + } + + test("test toExecuteSentence for vertex") { + val vertices: ListBuffer[NebulaVertex] = new ListBuffer[NebulaVertex] + val tagName = "person" + val propNames = List("col_string", + "col_fixed_string", + "col_bool", + "col_int", + "col_int64", + "col_double", + "col_date", + "col_geo") + + val props1 = List("\"Tom\"", "\"Tom\"", true, 10, 100L, 1.0, "2021-11-12", "POINT(3 8)") + val props2 = + List("\"Bob\"", "\"Bob\"", false, 20, 200L, 2.0, "2021-05-01", "LINESTRING(1 2, 3 4)") + vertices.append(NebulaVertex("\"vid1\"", props1)) + vertices.append(NebulaVertex("\"vid2\"", props2)) + + val nebulaVertices = NebulaVertices(propNames, vertices.toList, None) + val vertexStatement = NebulaExecutor.toExecuteSentence(tagName, nebulaVertices) + + val expectStatement = "INSERT vertex `person`(`col_string`,`col_fixed_string`,`col_bool`," + + "`col_int`,`col_int64`,`col_double`,`col_date`,`col_geo`) VALUES \"vid1\": (" + props1 + .mkString(", ") + "), \"vid2\": (" + props2.mkString(", ") + ")" + assert(expectStatement.equals(vertexStatement)) + } + + test("test toExecuteSentence for vertex with hash policy") { + val vertices: ListBuffer[NebulaVertex] = new ListBuffer[NebulaVertex] + val tagName = "person" + val propNames = List("col_string", + "col_fixed_string", + "col_bool", + "col_int", + "col_int64", + "col_double", + "col_date", + "col_geo") + + val props1 = List("\"Tom\"", "\"Tom\"", true, 10, 100L, 1.0, "2021-11-12", "POINT(1 2)") + val props2 = + List("\"Bob\"", "\"Bob\"", false, 20, 200L, 2.0, "2021-05-01", "LINESTRING(1 2, 3 4)") + vertices.append(NebulaVertex("vid1", props1)) + vertices.append(NebulaVertex("vid2", props2)) + + val nebulaVertices = NebulaVertices(propNames, vertices.toList, Some(KeyPolicy.HASH)) + val vertexStatement = NebulaExecutor.toExecuteSentence(tagName, nebulaVertices) + + val expectStatement = "INSERT vertex `person`(`col_string`,`col_fixed_string`,`col_bool`," + + "`col_int`,`col_int64`,`col_double`,`col_date`,`col_geo`) VALUES hash(\"vid1\"): (" + props1 + .mkString(", ") + + "), hash(\"vid2\"): (" + props2.mkString(", ") + ")" + assert(expectStatement.equals(vertexStatement)) + } + + test("test toExecuteSentence for edge") { + val edges: ListBuffer[NebulaEdge] = new ListBuffer[NebulaEdge] + val edgeName = "friend" + val propNames = List("col_string", + "col_fixed_string", + "col_bool", + "col_int", + "col_int64", + "col_double", + "col_date", + "col_geo") + val props1 = + List("\"Tom\"", "\"Tom\"", true, 10, 100L, 1.0, "2021-11-12", "POLYGON((0 1, 1 2, 2 3, 0 1))") + val props2 = List("\"Bob\"", + "\"Bob\"", + false, + 20, + 200L, + 2.0, + "2021-05-01", + "POLYGON((0 1, 1 2, 2 3, 0 1))") + edges.append(NebulaEdge("\"vid1\"", "\"vid2\"", Some(1L), props1)) + edges.append(NebulaEdge("\"vid2\"", "\"vid1\"", Some(2L), props2)) + + val nebulaEdges = NebulaEdges(propNames, edges.toList, None, None) + val edgeStatement = NebulaExecutor.toExecuteSentence(edgeName, nebulaEdges) + + val expectStatement = "INSERT edge 
`friend`(`col_string`,`col_fixed_string`,`col_bool`,`col_int`" + + ",`col_int64`,`col_double`,`col_date`,`col_geo`) VALUES \"vid1\"->\"vid2\"@1: (" + + props1.mkString(", ") + "), \"vid2\"->\"vid1\"@2: (" + props2.mkString(", ") + ")" + assert(expectStatement.equals(edgeStatement)) + } + + test("test toUpdateExecuteSentence for vertex") { + val vertices: ListBuffer[NebulaVertex] = new ListBuffer[NebulaVertex] + val propNames = List("col_string", + "col_fixed_string", + "col_bool", + "col_int", + "col_int64", + "col_double", + "col_date", + "col_geo") + + val props1 = + List("\"name\"", "\"name\"", true, 10, 100L, 1.0, "2021-11-12", "LINESTRING(1 2, 3 4)") + val props2 = List("\"name2\"", + "\"name2\"", + false, + 11, + 101L, + 2.0, + "2021-11-13", + "POLYGON((0 1, 1 2, 2 3, 0 1))") + + vertices.append(NebulaVertex("\"vid1\"", props1)) + vertices.append(NebulaVertex("\"vid2\"", props2)) + val nebulaVertices = NebulaVertices(propNames, vertices.toList, None) + + val updateVertexStatement = + NebulaExecutor.toUpdateExecuteStatement("person", nebulaVertices) + + val expectVertexUpdate = + "UPDATE VERTEX ON `person` \"vid1\" SET `col_string`=\"name\",`col_fixed_string`=\"name\"," + + "`col_bool`=true,`col_int`=10,`col_int64`=100,`col_double`=1.0,`col_date`=2021-11-12," + + "`col_geo`=LINESTRING(1 2, 3 4);UPDATE VERTEX ON `person` \"vid2\" SET " + + "`col_string`=\"name2\",`col_fixed_string`=\"name2\",`col_bool`=false,`col_int`=11," + + "`col_int64`=101,`col_double`=2.0,`col_date`=2021-11-13," + + "`col_geo`=POLYGON((0 1, 1 2, 2 3, 0 1))" + assert(expectVertexUpdate.equals(updateVertexStatement)) + } + + test("test toUpdateExecuteSentence for vertex with hash policy") { + val vertices: ListBuffer[NebulaVertex] = new ListBuffer[NebulaVertex] + val propNames = List("col_string", + "col_fixed_string", + "col_bool", + "col_int", + "col_int64", + "col_double", + "col_date") + + val props1 = List("\"name\"", "\"name\"", true, 10, 100L, 1.0, "2021-11-12") + val props2 = List("\"name2\"", "\"name2\"", false, 11, 101L, 2.0, "2021-11-13") + + vertices.append(NebulaVertex("vid1", props1)) + vertices.append(NebulaVertex("vid2", props2)) + val nebulaVertices = NebulaVertices(propNames, vertices.toList, Some(KeyPolicy.HASH)) + + val updateVertexStatement = + NebulaExecutor.toUpdateExecuteStatement("person", nebulaVertices) + val expectVertexUpdate = + "UPDATE VERTEX ON `person` hash(\"vid1\") SET `col_string`=\"name\",`col_fixed_string`=\"name\"," + + "`col_bool`=true,`col_int`=10,`col_int64`=100,`col_double`=1.0,`col_date`=2021-11-12;" + + "UPDATE VERTEX ON `person` hash(\"vid2\") SET `col_string`=\"name2\",`col_fixed_string`=\"name2\"," + + "`col_bool`=false,`col_int`=11,`col_int64`=101,`col_double`=2.0,`col_date`=2021-11-13" + assert(expectVertexUpdate.equals(updateVertexStatement)) + } + + test("test toUpdateExecuteSentence for edge") { + val edges: ListBuffer[NebulaEdge] = new ListBuffer[NebulaEdge] + val propNames = List("col_string", + "col_fixed_string", + "col_bool", + "col_int", + "col_int64", + "col_double", + "col_date", + "col_geo") + val props1 = List("\"Tom\"", "\"Tom\"", true, 10, 100L, 1.0, "2021-11-12", "POINT(1 2)") + val props2 = List("\"Bob\"", "\"Bob\"", false, 20, 200L, 2.0, "2021-05-01", "POINT(2 3)") + edges.append(NebulaEdge("\"vid1\"", "\"vid2\"", Some(1L), props1)) + edges.append(NebulaEdge("\"vid2\"", "\"vid1\"", Some(2L), props2)) + + val nebulaEdges = NebulaEdges(propNames, edges.toList, None, None) + val updateEdgeStatement = NebulaExecutor.toUpdateExecuteStatement("friend", 
nebulaEdges) + val expectEdgeUpdate = + "UPDATE EDGE ON `friend` \"vid1\"->\"vid2\"@1 SET `col_string`=\"Tom\"," + + "`col_fixed_string`=\"Tom\",`col_bool`=true,`col_int`=10,`col_int64`=100," + + "`col_double`=1.0,`col_date`=2021-11-12,`col_geo`=POINT(1 2);" + + "UPDATE EDGE ON `friend` \"vid2\"->\"vid1\"@2 SET `col_string`=\"Bob\"," + + "`col_fixed_string`=\"Bob\",`col_bool`=false,`col_int`=20,`col_int64`=200," + + "`col_double`=2.0,`col_date`=2021-05-01,`col_geo`=POINT(2 3)" + assert(expectEdgeUpdate.equals(updateEdgeStatement)) + } + + test("test toUpdateExecuteSentence for edge with hash policy") { + val edges: ListBuffer[NebulaEdge] = new ListBuffer[NebulaEdge] + val propNames = List("col_string", + "col_fixed_string", + "col_bool", + "col_int", + "col_int64", + "col_double", + "col_date") + val props1 = List("\"Tom\"", "\"Tom\"", true, 10, 100L, 1.0, "2021-11-12") + val props2 = List("\"Bob\"", "\"Bob\"", false, 20, 200L, 2.0, "2021-05-01") + edges.append(NebulaEdge("vid1", "vid2", Some(1L), props1)) + edges.append(NebulaEdge("vid2", "vid1", Some(2L), props2)) + + val nebulaEdges = + NebulaEdges(propNames, edges.toList, Some(KeyPolicy.HASH), Some(KeyPolicy.HASH)) + val updateEdgeStatement = NebulaExecutor.toUpdateExecuteStatement("friend", nebulaEdges) + val expectEdgeUpdate = + "UPDATE EDGE ON `friend` hash(\"vid1\")->hash(\"vid2\")@1 SET `col_string`=\"Tom\"," + + "`col_fixed_string`=\"Tom\",`col_bool`=true,`col_int`=10,`col_int64`=100," + + "`col_double`=1.0,`col_date`=2021-11-12;" + + "UPDATE EDGE ON `friend` hash(\"vid2\")->hash(\"vid1\")@2 SET `col_string`=\"Bob\"," + + "`col_fixed_string`=\"Bob\",`col_bool`=false,`col_int`=20,`col_int64`=200," + + "`col_double`=2.0,`col_date`=2021-05-01" + assert(expectEdgeUpdate.equals(updateEdgeStatement)) + } + + test("test toDeleteExecuteStatement for vertex") { + val vertices: ListBuffer[NebulaVertex] = new ListBuffer[NebulaVertex] + vertices.append(NebulaVertex("\"vid1\"", List())) + vertices.append(NebulaVertex("\"vid2\"", List())) + + val nebulaVertices = NebulaVertices(List(), vertices.toList, None) + val vertexStatement = NebulaExecutor.toDeleteExecuteStatement(nebulaVertices, false) + val expectVertexDeleteStatement = "DELETE VERTEX \"vid1\",\"vid2\"" + assert(expectVertexDeleteStatement.equals(vertexStatement)) + + val vertexWithEdgeStatement = NebulaExecutor.toDeleteExecuteStatement(nebulaVertices, true) + val expectVertexWithEdgeDeleteStatement = "DELETE VERTEX \"vid1\",\"vid2\" WITH EDGE" + assert(expectVertexWithEdgeDeleteStatement.equals(vertexWithEdgeStatement)) + } + + test("test toDeleteExecuteStatement for vertex with HASH policy") { + val vertices: ListBuffer[NebulaVertex] = new ListBuffer[NebulaVertex] + vertices.append(NebulaVertex("vid1", List())) + vertices.append(NebulaVertex("vid2", List())) + + val nebulaVertices = NebulaVertices(List(), vertices.toList, Some(KeyPolicy.HASH)) + val vertexStatement = NebulaExecutor.toDeleteExecuteStatement(nebulaVertices, false) + val expectVertexDeleteStatement = "DELETE VERTEX hash(\"vid1\"),hash(\"vid2\")" + assert(expectVertexDeleteStatement.equals(vertexStatement)) + + val vertexWithEdgeStatement = NebulaExecutor.toDeleteExecuteStatement(nebulaVertices, true) + val expectVertexWithEdgeDeleteStatement = + "DELETE VERTEX hash(\"vid1\"),hash(\"vid2\") WITH EDGE" + assert(expectVertexWithEdgeDeleteStatement.equals(vertexWithEdgeStatement)) + } + + test("test toDeleteExecuteStatement for edge") { + val edges: ListBuffer[NebulaEdge] = new ListBuffer[NebulaEdge] + 
edges.append(NebulaEdge("\"vid1\"", "\"vid2\"", Some(1L), List())) + edges.append(NebulaEdge("\"vid2\"", "\"vid1\"", Some(2L), List())) + + val nebulaEdges = NebulaEdges(List(), edges.toList, None, None) + val edgeStatement = NebulaExecutor.toDeleteExecuteStatement("friend", nebulaEdges) + val expectEdgeDeleteStatement = "DELETE EDGE `friend` \"vid1\"->\"vid2\"@1,\"vid2\"->\"vid1\"@2" + assert(expectEdgeDeleteStatement.equals(edgeStatement)) + } + + test("test toDeleteExecuteStatement for edge without rank") { + val edges: ListBuffer[NebulaEdge] = new ListBuffer[NebulaEdge] + edges.append(NebulaEdge("\"vid1\"", "\"vid2\"", Option.empty, List())) + edges.append(NebulaEdge("\"vid2\"", "\"vid1\"", Option.empty, List())) + + val nebulaEdges = NebulaEdges(List(), edges.toList, None, None) + val edgeStatement = NebulaExecutor.toDeleteExecuteStatement("friend", nebulaEdges) + val expectEdgeDeleteStatement = "DELETE EDGE `friend` \"vid1\"->\"vid2\"@0,\"vid2\"->\"vid1\"@0" + assert(expectEdgeDeleteStatement.equals(edgeStatement)) + } + + test("test toDeleteExecuteStatement for edge with src HASH policy") { + val edges: ListBuffer[NebulaEdge] = new ListBuffer[NebulaEdge] + edges.append(NebulaEdge("vid1", "\"vid2\"", Some(1L), List())) + edges.append(NebulaEdge("vid2", "\"vid1\"", Some(2L), List())) + + val nebulaEdges = NebulaEdges(List(), edges.toList, Some(KeyPolicy.HASH), None) + val edgeStatement = NebulaExecutor.toDeleteExecuteStatement("friend", nebulaEdges) + val expectEdgeDeleteStatement = + "DELETE EDGE `friend` hash(\"vid1\")->\"vid2\"@1,hash(\"vid2\")->\"vid1\"@2" + assert(expectEdgeDeleteStatement.equals(edgeStatement)) + } + + test("test toDeleteExecuteStatement for edge with all HASH policy") { + val edges: ListBuffer[NebulaEdge] = new ListBuffer[NebulaEdge] + edges.append(NebulaEdge("vid1", "vid2", Some(1L), List())) + edges.append(NebulaEdge("vid2", "vid1", Some(2L), List())) + + val nebulaEdges = NebulaEdges(List(), edges.toList, Some(KeyPolicy.HASH), Some(KeyPolicy.HASH)) + val edgeStatement = NebulaExecutor.toDeleteExecuteStatement("friend", nebulaEdges) + val expectEdgeDeleteStatement = + "DELETE EDGE `friend` hash(\"vid1\")->hash(\"vid2\")@1,hash(\"vid2\")->hash(\"vid1\")@2" + assert(expectEdgeDeleteStatement.equals(edgeStatement)) + } +} diff --git a/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/writer/WriteDeleteSuite.scala b/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/writer/WriteDeleteSuite.scala new file mode 100644 index 00000000..bab62143 --- /dev/null +++ b/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/writer/WriteDeleteSuite.scala @@ -0,0 +1,82 @@ +/* Copyright (c) 2021 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. 
+ */ + +package com.vesoft.nebula.connector.writer + +import com.vesoft.nebula.client.graph.data.ResultSet +import com.vesoft.nebula.connector.Address +import com.vesoft.nebula.connector.mock.{NebulaGraphMock, SparkMock} +import com.vesoft.nebula.connector.nebula.GraphProvider +import org.apache.log4j.BasicConfigurator +import org.scalatest.BeforeAndAfterAll +import org.scalatest.funsuite.AnyFunSuite + +class WriteDeleteSuite extends AnyFunSuite with BeforeAndAfterAll { + BasicConfigurator.configure() + + override def beforeAll(): Unit = { + val graphMock = new NebulaGraphMock + graphMock.mockStringIdGraphSchema() + graphMock.mockIntIdGraphSchema() + graphMock.close() + Thread.sleep(10000) + SparkMock.writeVertex() + SparkMock.writeEdge() + } + + test("write vertex into test_write_string space with delete mode") { + SparkMock.deleteVertex() + val addresses: List[Address] = List(new Address("127.0.0.1", 9669)) + val graphProvider = new GraphProvider(addresses, 3000) + + graphProvider.switchSpace("root", "nebula", "test_write_string") + val resultSet: ResultSet = + graphProvider.submit("use test_write_string;" + + "match (v:person_connector) return v limit 100000;") + assert(resultSet.isSucceeded) + assert(resultSet.getColumnNames.size() == 1) + assert(resultSet.isEmpty) + } + + test("write vertex into test_write_with_edge_string space with delete with edge mode") { + SparkMock.writeVertex() + SparkMock.writeEdge() + SparkMock.deleteVertexWithEdge() + val addresses: List[Address] = List(new Address("127.0.0.1", 9669)) + val graphProvider = new GraphProvider(addresses, 3000) + + graphProvider.switchSpace("root", "nebula", "test_write_string") + // assert vertex is deleted + val vertexResultSet: ResultSet = + graphProvider.submit("use test_write_string;" + + "match (v:person_connector) return v limit 1000000;") + assert(vertexResultSet.isSucceeded) + assert(vertexResultSet.getColumnNames.size() == 1) + assert(vertexResultSet.isEmpty) + + // assert edge is deleted + val edgeResultSet: ResultSet = + graphProvider.submit("use test_write_string;" + + "fetch prop on friend_connector \"1\"->\"2\"@10 yield edge as e") + assert(edgeResultSet.isSucceeded) + assert(edgeResultSet.getColumnNames.size() == 1) + assert(edgeResultSet.isEmpty) + + } + + test("write edge into test_write_string space with delete mode") { + SparkMock.deleteEdge() + val addresses: List[Address] = List(new Address("127.0.0.1", 9669)) + val graphProvider = new GraphProvider(addresses, 3000) + + graphProvider.switchSpace("root", "nebula", "test_write_string") + val resultSet: ResultSet = + graphProvider.submit("use test_write_string;" + + "fetch prop on friend_connector \"1\"->\"2\"@10 yield edge as e;") + assert(resultSet.isSucceeded) + assert(resultSet.getColumnNames.size() == 1) + assert(resultSet.isEmpty) + } +} diff --git a/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/writer/WriteInsertSuite.scala b/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/writer/WriteInsertSuite.scala new file mode 100644 index 00000000..3cd6c7d6 --- /dev/null +++ b/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/writer/WriteInsertSuite.scala @@ -0,0 +1,71 @@ +/* Copyright (c) 2021 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License.
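For orientation, the delete-mode behaviour exercised above through SparkMock.deleteVertex() corresponds to the connector's ordinary write path with the write mode switched to delete. The sketch below is a minimal illustration, not the exact SparkMock helper: it reuses the public builder API (NebulaConnectionConfig, WriteNebulaVertexConfig, WriteMode) and assumes a source DataFrame df with an "id" column; the "delete with edge" variant is assumed to toggle an additional config flag on top of this.

    import com.vesoft.nebula.connector.connector._   // assumed package object providing df.write.nebula(...)
    import com.vesoft.nebula.connector.{NebulaConnectionConfig, WriteMode, WriteNebulaVertexConfig}

    // minimal sketch of a delete-mode vertex write
    val connectionConfig = NebulaConnectionConfig
      .builder()
      .withMetaAddress("127.0.0.1:9559")
      .withGraphAddress("127.0.0.1:9669")
      .withConenctionRetry(2)
      .build()
    val deleteVertexConfig = WriteNebulaVertexConfig
      .builder()
      .withSpace("test_write_string")
      .withTag("person_connector")
      .withVidField("id")                 // assumed id column of df
      .withWriteMode(WriteMode.DELETE)    // delete the listed vertices instead of inserting them
      .withBatch(5)
      .build()
    df.write.nebula(connectionConfig, deleteVertexConfig).writeVertices()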
+ */ + +package com.vesoft.nebula.connector.writer + +import com.vesoft.nebula.client.graph.data.ResultSet +import com.vesoft.nebula.connector.Address +import com.vesoft.nebula.connector.mock.{NebulaGraphMock, SparkMock} +import com.vesoft.nebula.connector.nebula.GraphProvider +import org.apache.log4j.BasicConfigurator +import org.scalatest.BeforeAndAfterAll +import org.scalatest.funsuite.AnyFunSuite + +class WriteInsertSuite extends AnyFunSuite with BeforeAndAfterAll { + BasicConfigurator.configure() + + override def beforeAll(): Unit = { + val graphMock = new NebulaGraphMock + graphMock.mockStringIdGraphSchema() + graphMock.mockIntIdGraphSchema() + graphMock.close() + Thread.sleep(10000) + } + + test("write vertex into test_write_string space with insert mode") { + SparkMock.writeVertex() + val addresses: List[Address] = List(new Address("127.0.0.1", 9669)) + val graphProvider = new GraphProvider(addresses, 3000) + + graphProvider.switchSpace("root", "nebula", "test_write_string") + val createIndexResult: ResultSet = graphProvider.submit( + "use test_write_string; " + + "create tag index if not exists person_index on person_connector(col1(20));") + Thread.sleep(5000) + graphProvider.submit("rebuild tag index person_index;") + + Thread.sleep(5000) + + graphProvider.submit("use test_write_string;") + val resultSet: ResultSet = + graphProvider.submit("match (v:person_connector) return v;") + assert(resultSet.isSucceeded) + assert(resultSet.getColumnNames.size() == 1) + assert(resultSet.getRows.size() == 13) + } + + test("write edge into test_write_string space with insert mode") { + SparkMock.writeEdge() + + val addresses: List[Address] = List(new Address("127.0.0.1", 9669)) + val graphProvider = new GraphProvider(addresses, 3000) + + graphProvider.switchSpace("root", "nebula", "test_write_string") + val createIndexResult: ResultSet = graphProvider.submit( + "use test_write_string; " + + "create edge index if not exists friend_index on friend_connector(col1(20));") + Thread.sleep(5000) + graphProvider.submit("rebuild edge index friend_index;") + + Thread.sleep(5000) + + graphProvider.submit("use test_write_string;") + val resultSet: ResultSet = + graphProvider.submit("match (v:person_connector)-[e:friend_connector] -> () return e;") + assert(resultSet.isSucceeded) + assert(resultSet.getColumnNames.size() == 1) + assert(resultSet.getRows.size() == 13) + } +} diff --git a/pom.xml b/pom.xml index e0c2e1df..5e3752dc 100644 --- a/pom.xml +++ b/pom.xml @@ -48,6 +48,7 @@ nebula-spark-connector nebula-spark-connector_2.2 + nebula-spark-connector_3.0 example nebula-spark-common From b90a5c17737961b17ca611fe8e3f86abd49b46f0 Mon Sep 17 00:00:00 2001 From: Anqi Date: Fri, 18 Nov 2022 09:57:45 +0800 Subject: [PATCH 02/13] add connector for spark3 & extract common code --- .../connector/NebulaSparkReaderExample.scala | 5 +- nebula-spark-common/pom.xml | 2 +- .../nebula/connector/NebulaOptions.scala | 62 ++-- .../vesoft/nebula/connector/NebulaUtils.scala | 71 +++++ .../connector/reader/NebulaReader.scala | 295 ++++++++++++++++++ .../nebula/connector/NebulaDataSource.scala | 8 +- .../reader/NebulaEdgePartitionReader.scala | 63 +--- .../reader/NebulaPartitionReader.scala | 138 +------- .../connector/reader/NebulaSourceReader.scala | 81 +---- .../reader/NebulaVertexPartitionReader.scala | 66 +--- .../nebula/connector/NebulaDataSource.scala | 9 +- .../reader/NebulaEdgePartitionReader.scala | 18 ++ .../connector/reader/NebulaEdgeReader.scala | 77 ----- .../connector/reader/NebulaIterator.scala | 141 +-------- 
.../nebula/connector/reader/NebulaRDD.scala | 4 +- .../connector/reader/NebulaRelation.scala | 84 +---- .../reader/NebulaVertexPartitionReader.scala | 19 ++ .../connector/reader/NebulaVertexReader.scala | 78 ----- nebula-spark-connector_3.0/pom.xml | 96 +++--- .../nebula/connector/NebulaDataSource.scala | 68 +--- .../reader/NebulaEdgePartitionReader.scala | 63 +--- .../reader/NebulaPartitionReader.scala | 137 +------- .../reader/NebulaVertexPartitionReader.scala | 70 +---- .../connector/reader/SimpleScanBuilder.scala | 2 - 24 files changed, 541 insertions(+), 1116 deletions(-) create mode 100644 nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/reader/NebulaReader.scala create mode 100644 nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaEdgePartitionReader.scala delete mode 100644 nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaEdgeReader.scala create mode 100644 nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaVertexPartitionReader.scala delete mode 100644 nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaVertexReader.scala diff --git a/example/src/main/scala/com/vesoft/nebula/examples/connector/NebulaSparkReaderExample.scala b/example/src/main/scala/com/vesoft/nebula/examples/connector/NebulaSparkReaderExample.scala index 8a4a94b4..e78f7e8d 100644 --- a/example/src/main/scala/com/vesoft/nebula/examples/connector/NebulaSparkReaderExample.scala +++ b/example/src/main/scala/com/vesoft/nebula/examples/connector/NebulaSparkReaderExample.scala @@ -40,11 +40,10 @@ object NebulaSparkReaderExample { } def readVertex(spark: SparkSession): Unit = { - LOG.info("start to read nebula vertices") val config = NebulaConnectionConfig .builder() - .withMetaAddress("127.0.0.1:9559") + .withMetaAddress("192.168.8.171:9559") .withConenctionRetry(2) .build() val nebulaReadVertexConfig: ReadNebulaConfig = ReadNebulaConfig @@ -52,7 +51,7 @@ object NebulaSparkReaderExample { .withSpace("test") .withLabel("person") .withNoColumn(false) - .withReturnCols(List("birthday")) + .withReturnCols(List()) .withLimit(10) .withPartitionNum(10) .build() diff --git a/nebula-spark-common/pom.xml b/nebula-spark-common/pom.xml index 9eea5315..ca7d5281 100644 --- a/nebula-spark-common/pom.xml +++ b/nebula-spark-common/pom.xml @@ -22,7 +22,7 @@ 2.11 2.4.4 - 2.11.12 + 2.12.10 diff --git a/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/NebulaOptions.scala b/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/NebulaOptions.scala index 21195c5c..2bb630cf 100644 --- a/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/NebulaOptions.scala +++ b/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/NebulaOptions.scala @@ -15,22 +15,18 @@ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import scala.collection.mutable.ListBuffer -class NebulaOptions(@transient val parameters: CaseInsensitiveMap[String])( - operaType: OperaType.Value) - extends Serializable - with Logging { +class NebulaOptions(@transient val parameters: CaseInsensitiveMap[String]) extends Serializable { import NebulaOptions._ def this(parameters: Map[String, String], operaType: OperaType.Value) = - this(CaseInsensitiveMap(parameters))(operaType) + this(CaseInsensitiveMap(parameters)) def this(hostAndPorts: String, spaceName: String, dataType: String, label: String, - parameters: Map[String, String], - operaType: OperaType.Value) = { + parameters: Map[String, String]) 
= { this( CaseInsensitiveMap( parameters ++ Map( @@ -39,8 +35,9 @@ class NebulaOptions(@transient val parameters: CaseInsensitiveMap[String])( NebulaOptions.TYPE -> dataType, NebulaOptions.LABEL -> label )) - )(operaType) + ) } + val operaType = OperaType.withName(parameters(OPERATE_TYPE)) /** * Return property with all options @@ -104,21 +101,24 @@ class NebulaOptions(@transient val parameters: CaseInsensitiveMap[String])( val label: String = parameters(LABEL) /** read parameters */ - var returnCols: String = _ - var partitionNums: String = _ - var noColumn: Boolean = _ - var limit: Int = _ - var ngql: String = _ + var returnCols: String = _ + var partitionNums: String = _ + var noColumn: Boolean = _ + var limit: Int = _ + var pushDownFiltersEnabled: Boolean = _ + var ngql: String = _ if (operaType == OperaType.READ) { returnCols = parameters(RETURN_COLS) noColumn = parameters.getOrElse(NO_COLUMN, false).toString.toBoolean partitionNums = parameters(PARTITION_NUMBER) limit = parameters.getOrElse(LIMIT, DEFAULT_LIMIT).toString.toInt - ngql = parameters.getOrElse(NGQL,EMPTY_STRING) - ngql = parameters.getOrElse(NGQL,EMPTY_STRING) - if(ngql!=EMPTY_STRING){ + // TODO explore the pushDownFiltersEnabled parameter to users + pushDownFiltersEnabled = parameters.getOrElse(PUSHDOWN_FILTERS_ENABLE, false).toString.toBoolean + ngql = parameters.getOrElse(NGQL, EMPTY_STRING) + ngql = parameters.getOrElse(NGQL, EMPTY_STRING) + if (ngql != EMPTY_STRING) { require(parameters.isDefinedAt(GRAPH_ADDRESS), - s"option $GRAPH_ADDRESS is required for ngql and can not be blank") + s"option $GRAPH_ADDRESS is required for ngql and can not be blank") graphAddress = parameters(GRAPH_ADDRESS) } } @@ -187,13 +187,11 @@ class NebulaOptions(@transient val parameters: CaseInsensitiveMap[String])( def getMetaAddress: List[Address] = { val hostPorts: ListBuffer[Address] = new ListBuffer[Address] - metaAddress - .split(",") - .foreach(hostPort => { - // check host & port by getting HostAndPort - val addr = HostAndPort.fromString(hostPort) - hostPorts.append((addr.getHostText, addr.getPort)) - }) + for (hostPort <- metaAddress.split(",")) { + // check host & port by getting HostAndPort + val addr = HostAndPort.fromString(hostPort) + hostPorts.append((addr.getHostText, addr.getPort)) + } hostPorts.toList } @@ -211,9 +209,6 @@ class NebulaOptions(@transient val parameters: CaseInsensitiveMap[String])( } -class NebulaOptionsInWrite(@transient override val parameters: CaseInsensitiveMap[String]) - extends NebulaOptions(parameters)(OperaType.WRITE) {} - object NebulaOptions { /** nebula common config */ @@ -237,14 +232,17 @@ object NebulaOptions { val CA_SIGN_PARAM: String = "caSignParam" val SELF_SIGN_PARAM: String = "selfSignParam" + val OPERATE_TYPE: String = "operateType" + /** read config */ - val RETURN_COLS: String = "returnCols" - val NO_COLUMN: String = "noColumn" - val PARTITION_NUMBER: String = "partitionNumber" - val LIMIT: String = "limit" + val RETURN_COLS: String = "returnCols" + val NO_COLUMN: String = "noColumn" + val PARTITION_NUMBER: String = "partitionNumber" + val LIMIT: String = "limit" + val PUSHDOWN_FILTERS_ENABLE: String = "pushDownFiltersEnable" /** read by ngql **/ - val NGQL: String = "ngql" + val NGQL: String = "ngql" /** write config */ val RATE_LIMIT: String = "rateLimit" diff --git a/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/NebulaUtils.scala b/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/NebulaUtils.scala index 47ad2d4c..bf48e2a6 100644 --- 
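With this change the operate type travels inside the options map (OPERATE_TYPE) instead of being passed as a separate constructor argument, so every call site has to put an operateType entry into the parameters before constructing NebulaOptions. The sketch below illustrates that under stated assumptions: the literal keys mirror the constants defined in the NebulaOptions object (metaAddress, spaceName, type, label, returnCols, partitionNumber) and the string "READ" is assumed to round-trip through OperaType.withName.

    // illustrative only: building read options with the operate type carried in the map
    import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap

    val parameters = Map(
      "metaAddress"     -> "127.0.0.1:9559",
      "spaceName"       -> "test",
      "type"            -> "vertex",
      "label"           -> "person",
      "returnCols"      -> "",
      "partitionNumber" -> "10",
      "operateType"     -> "READ"   // parsed by OperaType.withName(parameters(OPERATE_TYPE))
    )
    val nebulaOptions = new NebulaOptions(CaseInsensitiveMap(parameters))
    assert(nebulaOptions.operaType == OperaType.READ)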
a/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/NebulaUtils.scala +++ b/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/NebulaUtils.scala @@ -7,22 +7,27 @@ package com.vesoft.nebula.connector import com.vesoft.nebula.PropertyType import com.vesoft.nebula.client.graph.data.{DateTimeWrapper, DurationWrapper, TimeWrapper} +import com.vesoft.nebula.connector.nebula.MetaProvider import com.vesoft.nebula.meta.{ColumnDef, ColumnTypeDef} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types.{ BooleanType, DataType, + DataTypes, DoubleType, FloatType, IntegerType, LongType, StringType, + StructField, StructType, TimestampType } import org.apache.spark.unsafe.types.UTF8String import org.slf4j.LoggerFactory +import scala.collection.mutable.ListBuffer + object NebulaUtils { private val LOG = LoggerFactory.getLogger(this.getClass) @@ -156,4 +161,70 @@ object NebulaUtils { s } + /** + * return the dataset's schema. Schema includes configured cols in returnCols or includes all properties in nebula. + */ + def getSchema(nebulaOptions: NebulaOptions): StructType = { + val returnCols = nebulaOptions.getReturnCols + val noColumn = nebulaOptions.noColumn + val fields: ListBuffer[StructField] = new ListBuffer[StructField] + val metaProvider = new MetaProvider( + nebulaOptions.getMetaAddress, + nebulaOptions.timeout, + nebulaOptions.connectionRetry, + nebulaOptions.executionRetry, + nebulaOptions.enableMetaSSL, + nebulaOptions.sslSignType, + nebulaOptions.caSignParam, + nebulaOptions.selfSignParam + ) + + import scala.collection.JavaConverters._ + var schemaCols: Seq[ColumnDef] = Seq() + val isVertex = DataTypeEnum.VERTEX.toString.equalsIgnoreCase(nebulaOptions.dataType) + + // construct vertex or edge default prop + if (isVertex) { + fields.append(DataTypes.createStructField("_vertexId", DataTypes.StringType, false)) + } else { + fields.append(DataTypes.createStructField("_srcId", DataTypes.StringType, false)) + fields.append(DataTypes.createStructField("_dstId", DataTypes.StringType, false)) + fields.append(DataTypes.createStructField("_rank", DataTypes.LongType, false)) + } + + var dataSchema: StructType = null + // read no column + if (noColumn) { + dataSchema = new StructType(fields.toArray) + return dataSchema + } + // get tag schema or edge schema + val schema = if (isVertex) { + metaProvider.getTag(nebulaOptions.spaceName, nebulaOptions.label) + } else { + metaProvider.getEdge(nebulaOptions.spaceName, nebulaOptions.label) + } + + schemaCols = schema.columns.asScala + + // read all columns + if (returnCols.isEmpty) { + schemaCols.foreach(columnDef => { + LOG.info(s"prop name ${new String(columnDef.getName)}, type ${columnDef.getType.getType} ") + fields.append( + DataTypes.createStructField(new String(columnDef.getName), + NebulaUtils.convertDataType(columnDef.getType), + true)) + }) + } else { + for (col: String <- returnCols) { + fields.append( + DataTypes + .createStructField(col, NebulaUtils.getColDataType(schemaCols.toList, col), true)) + } + } + dataSchema = new StructType(fields.toArray) + dataSchema + } + } diff --git a/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/reader/NebulaReader.scala b/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/reader/NebulaReader.scala new file mode 100644 index 00000000..d48e4034 --- /dev/null +++ b/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/reader/NebulaReader.scala @@ -0,0 +1,295 @@ +/* Copyright (c) 2022 vesoft inc. All rights reserved. 
+ * + * This source code is licensed under Apache 2.0 License. + */ + +package com.vesoft.nebula.connector.reader + +import com.vesoft.nebula.client.graph.data.{ + CASignedSSLParam, + HostAddress, + SSLParam, + SelfSignedSSLParam, + ValueWrapper +} +import com.vesoft.nebula.client.storage.StorageClient +import com.vesoft.nebula.client.storage.data.BaseTableRow +import com.vesoft.nebula.client.storage.scan.{ + ScanEdgeResult, + ScanEdgeResultIterator, + ScanVertexResult, + ScanVertexResultIterator +} +import com.vesoft.nebula.connector.NebulaUtils.NebulaValueGetter +import com.vesoft.nebula.connector.{NebulaOptions, NebulaUtils, PartitionUtils} +import com.vesoft.nebula.connector.exception.GraphConnectException +import com.vesoft.nebula.connector.nebula.MetaProvider +import com.vesoft.nebula.connector.ssl.SSLSignType +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow +import org.apache.spark.sql.types.StructType +import org.slf4j.{Logger, LoggerFactory} + +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.collection.mutable.ListBuffer + +trait NebulaReader { + private val LOG: Logger = LoggerFactory.getLogger(this.getClass) + + private var metaProvider: MetaProvider = _ + private var schema: StructType = _ + + protected var dataIterator: Iterator[BaseTableRow] = _ + protected var scanPartIterator: Iterator[Integer] = _ + protected var resultValues: mutable.ListBuffer[List[Object]] = mutable.ListBuffer[List[Object]]() + protected var storageClient: StorageClient = _ + protected var nebulaOptions: NebulaOptions = _ + + private var vertexResponseIterator: ScanVertexResultIterator = _ + private var edgeResponseIterator: ScanEdgeResultIterator = _ + + /** + * init the reader: init metaProvider, storageClient + */ + def init(index: Int, nebulaOptions: NebulaOptions, schema: StructType) { + this.schema = schema + this.nebulaOptions = nebulaOptions + + metaProvider = new MetaProvider( + nebulaOptions.getMetaAddress, + nebulaOptions.timeout, + nebulaOptions.connectionRetry, + nebulaOptions.executionRetry, + nebulaOptions.enableMetaSSL, + nebulaOptions.sslSignType, + nebulaOptions.caSignParam, + nebulaOptions.selfSignParam + ) + val address: ListBuffer[HostAddress] = new ListBuffer[HostAddress] + + for (addr <- nebulaOptions.getMetaAddress) { + address.append(new HostAddress(addr._1, addr._2)) + } + + var sslParam: SSLParam = null + if (nebulaOptions.enableStorageSSL) { + SSLSignType.withName(nebulaOptions.sslSignType) match { + case SSLSignType.CA => { + val caSSLSignParams = nebulaOptions.caSignParam + sslParam = new CASignedSSLParam(caSSLSignParams.caCrtFilePath, + caSSLSignParams.crtFilePath, + caSSLSignParams.keyFilePath) + } + case SSLSignType.SELF => { + val selfSSLSignParams = nebulaOptions.selfSignParam + sslParam = new SelfSignedSSLParam(selfSSLSignParams.crtFilePath, + selfSSLSignParams.keyFilePath, + selfSSLSignParams.password) + } + case _ => throw new IllegalArgumentException("ssl sign type is not supported") + } + this.storageClient = new StorageClient(address.asJava, + nebulaOptions.timeout, + nebulaOptions.connectionRetry, + nebulaOptions.executionRetry, + true, + sslParam) + } else { + this.storageClient = new StorageClient(address.asJava, nebulaOptions.timeout) + } + + if (!storageClient.connect()) { + throw new GraphConnectException("storage connect failed.") + } + // allocate scanPart to this partition + val totalPart = 
metaProvider.getPartitionNumber(nebulaOptions.spaceName) + + // index starts with 1 + val scanParts = PartitionUtils.getScanParts(index, totalPart, nebulaOptions.partitionNums.toInt) + LOG.info(s"partition index: ${index}, scanParts: ${scanParts.toString}") + scanPartIterator = scanParts.iterator + } + + /** + * resolve the vertex/edge data to InternalRow + */ + protected def getRow(): InternalRow = { + val resultSet: Array[ValueWrapper] = + dataIterator.next().getValues.toArray.map(v => v.asInstanceOf[ValueWrapper]) + val getters: Array[NebulaValueGetter] = NebulaUtils.makeGetters(schema) + val mutableRow = new SpecificInternalRow(schema.fields.map(x => x.dataType)) + + for (i <- getters.indices) { + val value: ValueWrapper = resultSet(i) + var resolved = false + if (value.isNull) { + mutableRow.setNullAt(i) + resolved = true + } + if (value.isString) { + getters(i).apply(value.asString(), mutableRow, i) + resolved = true + } + if (value.isDate) { + getters(i).apply(value.asDate(), mutableRow, i) + resolved = true + } + if (value.isTime) { + getters(i).apply(value.asTime(), mutableRow, i) + resolved = true + } + if (value.isDateTime) { + getters(i).apply(value.asDateTime(), mutableRow, i) + resolved = true + } + if (value.isLong) { + getters(i).apply(value.asLong(), mutableRow, i) + } + if (value.isBoolean) { + getters(i).apply(value.asBoolean(), mutableRow, i) + } + if (value.isDouble) { + getters(i).apply(value.asDouble(), mutableRow, i) + } + if (value.isGeography) { + getters(i).apply(value.asGeography(), mutableRow, i) + } + if (value.isDuration) { + getters(i).apply(value.asDuration(), mutableRow, i) + } + } + mutableRow + } + + /** + * if the scan response has next vertex + */ + protected def hasNextVertexRow: Boolean = { + { + if (dataIterator == null && vertexResponseIterator == null && !scanPartIterator.hasNext) + return false + + var continue: Boolean = false + var break: Boolean = false + while ((dataIterator == null || !dataIterator.hasNext) && !break) { + resultValues.clear() + continue = false + if (vertexResponseIterator == null || !vertexResponseIterator.hasNext) { + if (scanPartIterator.hasNext) { + try { + if (nebulaOptions.noColumn) { + vertexResponseIterator = storageClient.scanVertex(nebulaOptions.spaceName, + scanPartIterator.next(), + nebulaOptions.label, + nebulaOptions.limit, + 0, + Long.MaxValue, + true, + true) + } else { + vertexResponseIterator = storageClient.scanVertex( + nebulaOptions.spaceName, + scanPartIterator.next(), + nebulaOptions.label, + nebulaOptions.getReturnCols.asJava, + nebulaOptions.limit, + 0, + Long.MaxValue, + true, + true) + } + } catch { + case e: Exception => + LOG.error(s"Exception scanning vertex ${nebulaOptions.label}", e) + storageClient.close() + throw new Exception(e.getMessage, e) + } + // jump to the next loop + continue = true + } + // break while loop + break = !continue + } else { + val next: ScanVertexResult = vertexResponseIterator.next + if (!next.isEmpty) { + dataIterator = next.getVertexTableRows.iterator().asScala + } + } + } + + if (dataIterator == null) { + return false + } + dataIterator.hasNext + } + } + + /** + * if the scan response has next edge + */ + protected def hasNextEdgeRow: Boolean = { + if (dataIterator == null && edgeResponseIterator == null && !scanPartIterator.hasNext) + return false + + var continue: Boolean = false + var break: Boolean = false + while ((dataIterator == null || !dataIterator.hasNext) && !break) { + resultValues.clear() + continue = false + if (edgeResponseIterator == null || 
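The scanParts allocation in init() above is what decides which Nebula storage parts each Spark partition reads. Assuming PartitionUtils.getScanParts keeps its existing striding behaviour (partition indexes start at 1 and advance by the Spark partition count), the distribution for a space with 10 storage parts read by 3 Spark partitions would look roughly like this:

    // illustrative only, under the striding assumption above
    // spark partition 1 -> nebula parts 1, 4, 7, 10
    // spark partition 2 -> nebula parts 2, 5, 8
    // spark partition 3 -> nebula parts 3, 6, 9
    val parts = PartitionUtils.getScanParts(2, 10, 3) // (partition index, total nebula parts, spark partitions)
    // expected: 2, 5, 8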
!edgeResponseIterator.hasNext) { + if (scanPartIterator.hasNext) { + try { + if (nebulaOptions.noColumn) { + edgeResponseIterator = storageClient.scanEdge(nebulaOptions.spaceName, + scanPartIterator.next(), + nebulaOptions.label, + nebulaOptions.limit, + 0L, + Long.MaxValue, + true, + true) + } else { + edgeResponseIterator = storageClient.scanEdge(nebulaOptions.spaceName, + scanPartIterator.next(), + nebulaOptions.label, + nebulaOptions.getReturnCols.asJava, + nebulaOptions.limit, + 0, + Long.MaxValue, + true, + true) + } + } catch { + case e: Exception => + LOG.error(s"Exception scanning vertex ${nebulaOptions.label}", e) + storageClient.close() + throw new Exception(e.getMessage, e) + } + // jump to the next loop + continue = true + } + // break while loop + break = !continue + } else { + val next: ScanEdgeResult = edgeResponseIterator.next + if (!next.isEmpty) { + dataIterator = next.getEdgeTableRows.iterator().asScala + } + } + } + + if (dataIterator == null) { + return false + } + dataIterator.hasNext + } + + /** + * close the reader + */ + protected def closeReader(): Unit = { + metaProvider.close() + storageClient.close() + } +} diff --git a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala index c298a890..4233ec81 100644 --- a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala +++ b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala @@ -42,7 +42,7 @@ class NebulaDataSource * Creates a {@link DataSourceReader} to scan the data from Nebula Graph. */ override def createReader(options: DataSourceOptions): DataSourceReader = { - val nebulaOptions = getNebulaOptions(options, OperaType.READ) + val nebulaOptions = getNebulaOptions(options) val dataType = nebulaOptions.dataType LOG.info("create reader") @@ -65,7 +65,7 @@ class NebulaDataSource mode: SaveMode, options: DataSourceOptions): Optional[DataSourceWriter] = { - val nebulaOptions = getNebulaOptions(options, OperaType.WRITE) + val nebulaOptions = getNebulaOptions(options) val dataType = nebulaOptions.dataType if (mode == SaveMode.Ignore || mode == SaveMode.ErrorIfExists) { LOG.warn(s"Currently do not support mode") @@ -140,12 +140,12 @@ class NebulaDataSource /** * construct nebula options with DataSourceOptions */ - def getNebulaOptions(options: DataSourceOptions, operateType: OperaType.Value): NebulaOptions = { + def getNebulaOptions(options: DataSourceOptions): NebulaOptions = { var parameters: Map[String, String] = Map() for (entry: Entry[String, String] <- options.asMap().entrySet) { parameters += (entry.getKey -> entry.getValue) } - val nebulaOptions = new NebulaOptions(CaseInsensitiveMap(parameters))(operateType) + val nebulaOptions = new NebulaOptions(CaseInsensitiveMap(parameters)) nebulaOptions } } diff --git a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaEdgePartitionReader.scala b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaEdgePartitionReader.scala index b54f5a14..605a36dc 100644 --- a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaEdgePartitionReader.scala +++ b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaEdgePartitionReader.scala @@ -5,72 +5,11 @@ package com.vesoft.nebula.connector.reader -import com.vesoft.nebula.client.storage.scan.{ScanEdgeResult, ScanEdgeResultIterator} import 
com.vesoft.nebula.connector.NebulaOptions import org.apache.spark.sql.types.StructType -import org.slf4j.{Logger, LoggerFactory} -import scala.collection.JavaConverters._ class NebulaEdgePartitionReader(index: Int, nebulaOptions: NebulaOptions, schema: StructType) extends NebulaPartitionReader(index, nebulaOptions, schema) { - private val LOG: Logger = LoggerFactory.getLogger(this.getClass) - private var responseIterator: ScanEdgeResultIterator = _ - - override def next(): Boolean = { - if (dataIterator == null && responseIterator == null && !scanPartIterator.hasNext) - return false - - var continue: Boolean = false - var break: Boolean = false - while ((dataIterator == null || !dataIterator.hasNext) && !break) { - resultValues.clear() - continue = false - if (responseIterator == null || !responseIterator.hasNext) { - if (scanPartIterator.hasNext) { - try { - if (nebulaOptions.noColumn) { - responseIterator = storageClient.scanEdge(nebulaOptions.spaceName, - scanPartIterator.next(), - nebulaOptions.label, - nebulaOptions.limit, - 0L, - Long.MaxValue, - true, - true) - } else { - responseIterator = storageClient.scanEdge(nebulaOptions.spaceName, - scanPartIterator.next(), - nebulaOptions.label, - nebulaOptions.getReturnCols.asJava, - nebulaOptions.limit, - 0, - Long.MaxValue, - true, - true) - } - } catch { - case e: Exception => - LOG.error(s"Exception scanning vertex ${nebulaOptions.label}", e) - storageClient.close() - throw new Exception(e.getMessage, e) - } - // jump to the next loop - continue = true - } - // break while loop - break = !continue - } else { - val next: ScanEdgeResult = responseIterator.next - if (!next.isEmpty) { - dataIterator = next.getEdgeTableRows.iterator().asScala - } - } - } - - if (dataIterator == null) { - return false - } - dataIterator.hasNext - } + override def next(): Boolean = hasNextEdgeRow } diff --git a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartitionReader.scala b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartitionReader.scala index e72f3460..c491662d 100644 --- a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartitionReader.scala +++ b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartitionReader.scala @@ -5,44 +5,19 @@ package com.vesoft.nebula.connector.reader -import com.vesoft.nebula.client.graph.data.{ - CASignedSSLParam, - HostAddress, - SSLParam, - SelfSignedSSLParam, - ValueWrapper -} -import com.vesoft.nebula.client.storage.StorageClient -import com.vesoft.nebula.client.storage.data.BaseTableRow -import com.vesoft.nebula.connector.NebulaUtils.NebulaValueGetter -import com.vesoft.nebula.connector.{NebulaOptions, NebulaUtils, PartitionUtils} -import com.vesoft.nebula.connector.exception.GraphConnectException -import com.vesoft.nebula.connector.nebula.MetaProvider -import com.vesoft.nebula.connector.ssl.SSLSignType +import com.vesoft.nebula.connector.{NebulaOptions} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow + import org.apache.spark.sql.sources.v2.reader.InputPartitionReader import org.apache.spark.sql.types.StructType import org.slf4j.{Logger, LoggerFactory} -import scala.collection.JavaConverters._ -import scala.collection.mutable -import scala.collection.mutable.ListBuffer - /** * Read nebula data for each spark partition */ -abstract class NebulaPartitionReader extends InputPartitionReader[InternalRow] { +abstract class 
NebulaPartitionReader extends InputPartitionReader[InternalRow] with NebulaReader { private val LOG: Logger = LoggerFactory.getLogger(this.getClass) - private var metaProvider: MetaProvider = _ - private var schema: StructType = _ - - protected var dataIterator: Iterator[BaseTableRow] = _ - protected var scanPartIterator: Iterator[Integer] = _ - protected var resultValues: mutable.ListBuffer[List[Object]] = mutable.ListBuffer[List[Object]]() - protected var storageClient: StorageClient = _ - /** * @param index identifier for spark partition * @param nebulaOptions nebula Options @@ -50,113 +25,12 @@ abstract class NebulaPartitionReader extends InputPartitionReader[InternalRow] { */ def this(index: Int, nebulaOptions: NebulaOptions, schema: StructType) { this() - this.schema = schema - - metaProvider = new MetaProvider( - nebulaOptions.getMetaAddress, - nebulaOptions.timeout, - nebulaOptions.connectionRetry, - nebulaOptions.executionRetry, - nebulaOptions.enableMetaSSL, - nebulaOptions.sslSignType, - nebulaOptions.caSignParam, - nebulaOptions.selfSignParam - ) - val address: ListBuffer[HostAddress] = new ListBuffer[HostAddress] - - for (addr <- nebulaOptions.getMetaAddress) { - address.append(new HostAddress(addr._1, addr._2)) - } - - var sslParam: SSLParam = null - if (nebulaOptions.enableStorageSSL) { - SSLSignType.withName(nebulaOptions.sslSignType) match { - case SSLSignType.CA => { - val caSSLSignParams = nebulaOptions.caSignParam - sslParam = new CASignedSSLParam(caSSLSignParams.caCrtFilePath, - caSSLSignParams.crtFilePath, - caSSLSignParams.keyFilePath) - } - case SSLSignType.SELF => { - val selfSSLSignParams = nebulaOptions.selfSignParam - sslParam = new SelfSignedSSLParam(selfSSLSignParams.crtFilePath, - selfSSLSignParams.keyFilePath, - selfSSLSignParams.password) - } - case _ => throw new IllegalArgumentException("ssl sign type is not supported") - } - this.storageClient = new StorageClient(address.asJava, - nebulaOptions.timeout, - nebulaOptions.connectionRetry, - nebulaOptions.executionRetry, - true, - sslParam) - } else { - this.storageClient = new StorageClient(address.asJava, nebulaOptions.timeout) - } - - if (!storageClient.connect()) { - throw new GraphConnectException("storage connect failed.") - } - // allocate scanPart to this partition - val totalPart = metaProvider.getPartitionNumber(nebulaOptions.spaceName) - - // index starts with 1 - val scanParts = PartitionUtils.getScanParts(index, totalPart, nebulaOptions.partitionNums.toInt) - LOG.info(s"partition index: ${index}, scanParts: ${scanParts.toString}") - scanPartIterator = scanParts.iterator + super.init(index, nebulaOptions, schema) } - override def get(): InternalRow = { - val resultSet: Array[ValueWrapper] = - dataIterator.next().getValues.toArray.map(v => v.asInstanceOf[ValueWrapper]) - val getters: Array[NebulaValueGetter] = NebulaUtils.makeGetters(schema) - val mutableRow = new SpecificInternalRow(schema.fields.map(x => x.dataType)) - - for (i <- getters.indices) { - val value: ValueWrapper = resultSet(i) - var resolved = false - if (value.isNull) { - mutableRow.setNullAt(i) - resolved = true - } - if (value.isString) { - getters(i).apply(value.asString(), mutableRow, i) - resolved = true - } - if (value.isDate) { - getters(i).apply(value.asDate(), mutableRow, i) - resolved = true - } - if (value.isTime) { - getters(i).apply(value.asTime(), mutableRow, i) - resolved = true - } - if (value.isDateTime) { - getters(i).apply(value.asDateTime(), mutableRow, i) - resolved = true - } - if (value.isLong) { - 
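As a reading aid, the methods kept in this class line up with Spark's InputPartitionReader contract: the runtime calls next() until it returns false, calls get() after each successful next(), and finally close(). After the refactor these simply delegate to the shared NebulaReader trait. The loop below is a schematic of how Spark consumes the reader, not connector code.

    // schematic only: driving an InputPartitionReader[InternalRow]
    val reader = new NebulaVertexPartitionReader(index, nebulaOptions, schema)
    try {
      while (reader.next()) {        // hasNextVertexRow / hasNextEdgeRow underneath
        val row = reader.get()       // NebulaReader.getRow()
        // hand row to Spark
      }
    } finally {
      reader.close()                 // NebulaReader.closeReader()
    }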
getters(i).apply(value.asLong(), mutableRow, i) - } - if (value.isBoolean) { - getters(i).apply(value.asBoolean(), mutableRow, i) - } - if (value.isDouble) { - getters(i).apply(value.asDouble(), mutableRow, i) - } - if (value.isGeography) { - getters(i).apply(value.asGeography(), mutableRow, i) - } - if (value.isDuration) { - getters(i).apply(value.asDuration(), mutableRow, i) - } - } - mutableRow - } + override def get(): InternalRow = super.getRow() override def close(): Unit = { - metaProvider.close() - storageClient.close() + super.closeReader() } } diff --git a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaSourceReader.scala b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaSourceReader.scala index 0ca8ee54..d859570e 100644 --- a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaSourceReader.scala +++ b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaSourceReader.scala @@ -7,15 +7,12 @@ package com.vesoft.nebula.connector.reader import java.util -import com.vesoft.nebula.connector.nebula.MetaProvider -import com.vesoft.nebula.connector.{DataTypeEnum, NebulaOptions, NebulaUtils} -import com.vesoft.nebula.meta.ColumnDef +import com.vesoft.nebula.connector.{NebulaOptions, NebulaUtils} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition} -import org.apache.spark.sql.types.{DataTypes, StructField, StructType} +import org.apache.spark.sql.types.{StructType} import org.slf4j.LoggerFactory -import scala.collection.mutable.ListBuffer import scala.collection.JavaConverters._ /** @@ -27,78 +24,16 @@ abstract class NebulaSourceReader(nebulaOptions: NebulaOptions) extends DataSour private var datasetSchema: StructType = _ override def readSchema(): StructType = { - datasetSchema = getSchema(nebulaOptions) + if (datasetSchema == null) { + datasetSchema = NebulaUtils.getSchema(nebulaOptions) + } + LOG.info(s"dataset's schema: $datasetSchema") datasetSchema } - protected def getSchema: StructType = getSchema(nebulaOptions) - - /** - * return the dataset's schema. Schema includes configured cols in returnCols or includes all properties in nebula. 
- */ - def getSchema(nebulaOptions: NebulaOptions): StructType = { - val returnCols = nebulaOptions.getReturnCols - val noColumn = nebulaOptions.noColumn - val fields: ListBuffer[StructField] = new ListBuffer[StructField] - val metaProvider = new MetaProvider( - nebulaOptions.getMetaAddress, - nebulaOptions.timeout, - nebulaOptions.connectionRetry, - nebulaOptions.executionRetry, - nebulaOptions.enableMetaSSL, - nebulaOptions.sslSignType, - nebulaOptions.caSignParam, - nebulaOptions.selfSignParam - ) - - import scala.collection.JavaConverters._ - var schemaCols: Seq[ColumnDef] = Seq() - val isVertex = DataTypeEnum.VERTEX.toString.equalsIgnoreCase(nebulaOptions.dataType) - - // construct vertex or edge default prop - if (isVertex) { - fields.append(DataTypes.createStructField("_vertexId", DataTypes.StringType, false)) - } else { - fields.append(DataTypes.createStructField("_srcId", DataTypes.StringType, false)) - fields.append(DataTypes.createStructField("_dstId", DataTypes.StringType, false)) - fields.append(DataTypes.createStructField("_rank", DataTypes.LongType, false)) - } - - var dataSchema: StructType = null - // read no column - if (noColumn) { - dataSchema = new StructType(fields.toArray) - return dataSchema - } - // get tag schema or edge schema - val schema = if (isVertex) { - metaProvider.getTag(nebulaOptions.spaceName, nebulaOptions.label) - } else { - metaProvider.getEdge(nebulaOptions.spaceName, nebulaOptions.label) - } - - schemaCols = schema.columns.asScala - - // read all columns - if (returnCols.isEmpty) { - schemaCols.foreach(columnDef => { - LOG.info(s"prop name ${new String(columnDef.getName)}, type ${columnDef.getType.getType} ") - fields.append( - DataTypes.createStructField(new String(columnDef.getName), - NebulaUtils.convertDataType(columnDef.getType), - true)) - }) - } else { - for (col: String <- returnCols) { - fields.append( - DataTypes - .createStructField(col, NebulaUtils.getColDataType(schemaCols.toList, col), true)) - } - } - dataSchema = new StructType(fields.toArray) - dataSchema - } + protected def getSchema: StructType = + if (datasetSchema == null) NebulaUtils.getSchema(nebulaOptions) else datasetSchema } /** diff --git a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaVertexPartitionReader.scala b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaVertexPartitionReader.scala index 3466b1dd..d6d056d2 100644 --- a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaVertexPartitionReader.scala +++ b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaVertexPartitionReader.scala @@ -5,75 +5,11 @@ package com.vesoft.nebula.connector.reader -import com.vesoft.nebula.client.storage.scan.{ScanVertexResult, ScanVertexResultIterator} import com.vesoft.nebula.connector.NebulaOptions import org.apache.spark.sql.types.StructType -import org.slf4j.{Logger, LoggerFactory} - -import scala.collection.JavaConverters._ class NebulaVertexPartitionReader(index: Int, nebulaOptions: NebulaOptions, schema: StructType) extends NebulaPartitionReader(index, nebulaOptions, schema) { - private val LOG: Logger = LoggerFactory.getLogger(this.getClass) - - private var responseIterator: ScanVertexResultIterator = _ - - override def next(): Boolean = { - if (dataIterator == null && responseIterator == null && !scanPartIterator.hasNext) - return false - - var continue: Boolean = false - var break: Boolean = false - while ((dataIterator == null || !dataIterator.hasNext) && 
!break) { - resultValues.clear() - continue = false - if (responseIterator == null || !responseIterator.hasNext) { - if (scanPartIterator.hasNext) { - try { - if (nebulaOptions.noColumn) { - responseIterator = storageClient.scanVertex(nebulaOptions.spaceName, - scanPartIterator.next(), - nebulaOptions.label, - nebulaOptions.limit, - 0, - Long.MaxValue, - true, - true) - } else { - responseIterator = storageClient.scanVertex(nebulaOptions.spaceName, - scanPartIterator.next(), - nebulaOptions.label, - nebulaOptions.getReturnCols.asJava, - nebulaOptions.limit, - 0, - Long.MaxValue, - true, - true) - } - } catch { - case e: Exception => - LOG.error(s"Exception scanning vertex ${nebulaOptions.label}", e) - storageClient.close() - throw new Exception(e.getMessage, e) - } - // jump to the next loop - continue = true - } - // break while loop - break = !continue - } else { - val next: ScanVertexResult = responseIterator.next - if (!next.isEmpty) { - dataIterator = next.getVertexTableRows.iterator().asScala - } - } - } - - if (dataIterator == null) { - return false - } - dataIterator.hasNext - } - + override def next(): Boolean = hasNextVertexRow } diff --git a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala index 8d978501..6019dc3c 100644 --- a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala +++ b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala @@ -43,7 +43,7 @@ class NebulaDataSource */ override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { - val nebulaOptions = getNebulaOptions(parameters, OperaType.READ) + val nebulaOptions = getNebulaOptions(parameters) LOG.info("create relation") LOG.info(s"options ${parameters}") @@ -59,7 +59,7 @@ class NebulaDataSource parameters: Map[String, String], data: DataFrame): BaseRelation = { - val nebulaOptions = getNebulaOptions(parameters, OperaType.WRITE) + val nebulaOptions = getNebulaOptions(parameters) if (mode == SaveMode.Ignore || mode == SaveMode.ErrorIfExists) { LOG.warn(s"Currently do not support mode") } @@ -78,9 +78,8 @@ class NebulaDataSource /** * construct nebula options with DataSourceOptions */ - def getNebulaOptions(options: Map[String, String], - operateType: OperaType.Value): NebulaOptions = { - val nebulaOptions = new NebulaOptions(CaseInsensitiveMap(options))(operateType) + def getNebulaOptions(options: Map[String, String]): NebulaOptions = { + val nebulaOptions = new NebulaOptions(CaseInsensitiveMap(options)) nebulaOptions } diff --git a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaEdgePartitionReader.scala b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaEdgePartitionReader.scala new file mode 100644 index 00000000..ef53ac54 --- /dev/null +++ b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaEdgePartitionReader.scala @@ -0,0 +1,18 @@ +/* Copyright (c) 2020 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. 
+ */ + +package com.vesoft.nebula.connector.reader + +import com.vesoft.nebula.connector.NebulaOptions +import org.apache.spark.Partition +import org.apache.spark.sql.types.StructType +import org.slf4j.{Logger, LoggerFactory} + +class NebulaEdgePartitionReader(index: Partition, nebulaOptions: NebulaOptions, schema: StructType) + extends NebulaIterator(index, nebulaOptions, schema) { + private val LOG: Logger = LoggerFactory.getLogger(this.getClass) + + override def hasNext(): Boolean = hasNextEdgeRow +} diff --git a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaEdgeReader.scala b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaEdgeReader.scala deleted file mode 100644 index 45fab6c8..00000000 --- a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaEdgeReader.scala +++ /dev/null @@ -1,77 +0,0 @@ -/* Copyright (c) 2022 vesoft inc. All rights reserved. - * - * This source code is licensed under Apache 2.0 License. - */ - -package com.vesoft.nebula.connector.reader - -import com.vesoft.nebula.client.storage.scan.{ScanEdgeResult, ScanEdgeResultIterator} -import com.vesoft.nebula.connector.NebulaOptions -import org.apache.spark.Partition -import org.apache.spark.sql.types.StructType -import org.slf4j.LoggerFactory -import scala.collection.JavaConverters._ - -class NebulaEdgeReader(split: Partition, nebulaOptions: NebulaOptions, schema: StructType) - extends NebulaIterator(split, nebulaOptions, schema) { - private val LOG = LoggerFactory.getLogger(this.getClass) - - private var responseIterator: ScanEdgeResultIterator = _ - - override def hasNext: Boolean = { - if (dataIterator == null && responseIterator == null && !scanPartIterator.hasNext) - return false - - var continue: Boolean = false - var break: Boolean = false - while ((dataIterator == null || !dataIterator.hasNext) && !break) { - resultValues.clear() - continue = false - if (responseIterator == null || !responseIterator.hasNext) { - if (scanPartIterator.hasNext) { - try { - if (nebulaOptions.noColumn) { - responseIterator = storageClient.scanEdge(nebulaOptions.spaceName, - scanPartIterator.next(), - nebulaOptions.label, - nebulaOptions.limit, - 0L, - Long.MaxValue, - true, - true) - } else { - responseIterator = storageClient.scanEdge(nebulaOptions.spaceName, - scanPartIterator.next(), - nebulaOptions.label, - nebulaOptions.getReturnCols.asJava, - nebulaOptions.limit, - 0, - Long.MaxValue, - true, - true) - } - } catch { - case e: Exception => - LOG.error(s"Exception scanning vertex ${nebulaOptions.label}", e) - storageClient.close() - throw new Exception(e.getMessage, e) - } - // jump to the next loop - continue = true - } - // break while loop - break = !continue - } else { - val next: ScanEdgeResult = responseIterator.next - if (!next.isEmpty) { - dataIterator = next.getEdgeTableRows.iterator().asScala - } - } - } - - if (dataIterator == null) { - return false - } - dataIterator.hasNext - } -} diff --git a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaIterator.scala b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaIterator.scala index dc7b085e..d03e5f50 100644 --- a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaIterator.scala +++ b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaIterator.scala @@ -5,163 +5,30 @@ package com.vesoft.nebula.connector.reader -import 
com.vesoft.nebula.client.graph.data.{ - CASignedSSLParam, - HostAddress, - SSLParam, - SelfSignedSSLParam, - ValueWrapper -} -import com.vesoft.nebula.client.storage.data.BaseTableRow -import com.vesoft.nebula.client.storage.StorageClient import com.vesoft.nebula.connector.{NebulaOptions, NebulaUtils, PartitionUtils} -import com.vesoft.nebula.connector.NebulaUtils.NebulaValueGetter -import com.vesoft.nebula.connector.exception.GraphConnectException -import com.vesoft.nebula.connector.nebula.MetaProvider -import com.vesoft.nebula.connector.ssl.SSLSignType + import org.apache.spark.Partition import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow import org.apache.spark.sql.types.StructType -import org.slf4j.{Logger, LoggerFactory} - -import scala.collection.mutable -import scala.collection.mutable.ListBuffer -import scala.collection.JavaConverters._ /** - * @todo * iterator for nebula vertex or edge data * convert each vertex data or edge data to Spark SQL's Row */ -abstract class NebulaIterator extends Iterator[InternalRow] { - - private val LOG: Logger = LoggerFactory.getLogger(classOf[NebulaIterator]) - - private var metaProvider: MetaProvider = _ - private var schema: StructType = _ - - protected var dataIterator: Iterator[BaseTableRow] = _ - protected var scanPartIterator: Iterator[Integer] = _ - protected var resultValues: mutable.ListBuffer[List[Object]] = mutable.ListBuffer[List[Object]]() - protected var storageClient: StorageClient = _ +abstract class NebulaIterator extends Iterator[InternalRow] with NebulaReader { def this(index: Partition, nebulaOptions: NebulaOptions, schema: StructType) { this() - this.schema = schema - - metaProvider = new MetaProvider( - nebulaOptions.getMetaAddress, - nebulaOptions.timeout, - nebulaOptions.connectionRetry, - nebulaOptions.executionRetry, - nebulaOptions.enableMetaSSL, - nebulaOptions.sslSignType, - nebulaOptions.caSignParam, - nebulaOptions.selfSignParam - ) - val address: ListBuffer[HostAddress] = new ListBuffer[HostAddress] - - for (addr <- nebulaOptions.getMetaAddress) { - address.append(new HostAddress(addr._1, addr._2)) - } - - var sslParam: SSLParam = null - if (nebulaOptions.enableStorageSSL) { - SSLSignType.withName(nebulaOptions.sslSignType) match { - case SSLSignType.CA => { - val caSSLSignParams = nebulaOptions.caSignParam - sslParam = new CASignedSSLParam(caSSLSignParams.caCrtFilePath, - caSSLSignParams.crtFilePath, - caSSLSignParams.keyFilePath) - } - case SSLSignType.SELF => { - val selfSSLSignParams = nebulaOptions.selfSignParam - sslParam = new SelfSignedSSLParam(selfSSLSignParams.crtFilePath, - selfSSLSignParams.keyFilePath, - selfSSLSignParams.password) - } - case _ => throw new IllegalArgumentException("ssl sign type is not supported") - } - this.storageClient = new StorageClient(address.asJava, - nebulaOptions.timeout, - nebulaOptions.connectionRetry, - nebulaOptions.executionRetry, - true, - sslParam) - } else { - this.storageClient = new StorageClient(address.asJava, nebulaOptions.timeout) - } - - if (!storageClient.connect()) { - throw new GraphConnectException("storage connect failed.") - } - // allocate scanPart to this partition - val totalPart = metaProvider.getPartitionNumber(nebulaOptions.spaceName) - - val nebulaPartition = index.asInstanceOf[NebulaPartition] - val scanParts = - nebulaPartition.getScanParts(totalPart, nebulaOptions.partitionNums.toInt) - LOG.info(s"partition index: ${index}, scanParts: ${scanParts.toString}") - scanPartIterator = 
scanParts.iterator + super.init(index.index, nebulaOptions, schema) } /** - * @todo * whether this iterator can provide another element. */ override def hasNext: Boolean /** - * @todo * Produces the next vertex or edge of this iterator. */ - override def next(): InternalRow = { - val resultSet: Array[ValueWrapper] = - dataIterator.next().getValues.toArray.map(v => v.asInstanceOf[ValueWrapper]) - val getters: Array[NebulaValueGetter] = NebulaUtils.makeGetters(schema) - val mutableRow = new SpecificInternalRow(schema.fields.map(x => x.dataType)) - - for (i <- getters.indices) { - val value: ValueWrapper = resultSet(i) - var resolved = false - if (value.isNull) { - mutableRow.setNullAt(i) - resolved = true - } - if (value.isString) { - getters(i).apply(value.asString(), mutableRow, i) - resolved = true - } - if (value.isDate) { - getters(i).apply(value.asDate(), mutableRow, i) - resolved = true - } - if (value.isTime) { - getters(i).apply(value.asTime(), mutableRow, i) - resolved = true - } - if (value.isDateTime) { - getters(i).apply(value.asDateTime(), mutableRow, i) - resolved = true - } - if (value.isLong) { - getters(i).apply(value.asLong(), mutableRow, i) - } - if (value.isBoolean) { - getters(i).apply(value.asBoolean(), mutableRow, i) - } - if (value.isDouble) { - getters(i).apply(value.asDouble(), mutableRow, i) - } - if (value.isGeography) { - getters(i).apply(value.asGeography(), mutableRow, i) - } - if (value.isDuration) { - getters(i).apply(value.asDuration(), mutableRow, i) - } - } - mutableRow - } - + override def next(): InternalRow = super.getRow() } diff --git a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaRDD.scala b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaRDD.scala index bbe5118c..3e29b822 100644 --- a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaRDD.scala +++ b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaRDD.scala @@ -27,8 +27,8 @@ class NebulaRDD(val sqlContext: SQLContext, var nebulaOptions: NebulaOptions, sc override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { val dataType = nebulaOptions.dataType if (DataTypeEnum.VERTEX.toString.equalsIgnoreCase(dataType)) - new NebulaVertexReader(split, nebulaOptions, schema) - else new NebulaEdgeReader(split, nebulaOptions, schema) + new NebulaVertexPartitionReader(split, nebulaOptions, schema) + else new NebulaEdgePartitionReader(split, nebulaOptions, schema) } override def getPartitions = { diff --git a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaRelation.scala b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaRelation.scala index 53733061..129f1b4d 100644 --- a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaRelation.scala +++ b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaRelation.scala @@ -5,95 +5,35 @@ package com.vesoft.nebula.connector.reader -import com.vesoft.nebula.connector.nebula.MetaProvider -import com.vesoft.nebula.connector.{DataTypeEnum, NebulaOptions, NebulaUtils} -import com.vesoft.nebula.meta.ColumnDef +import com.vesoft.nebula.connector.{NebulaOptions, NebulaUtils} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.sources.{BaseRelation, TableScan} -import org.apache.spark.sql.types.{DataType, DataTypes, StructField, 
StructType} +import org.apache.spark.sql.types.{StructType} import org.slf4j.LoggerFactory -import scala.collection.mutable.ListBuffer - case class NebulaRelation(override val sqlContext: SQLContext, nebulaOptions: NebulaOptions) extends BaseRelation with TableScan { private val LOG = LoggerFactory.getLogger(this.getClass) - protected lazy val datasetSchema: StructType = getSchema(nebulaOptions) + protected var datasetSchema: StructType = _ + NebulaUtils.getSchema(nebulaOptions) override val needConversion: Boolean = false - override def schema: StructType = getSchema(nebulaOptions) - - /** - * return the dataset's schema. Schema includes configured cols in returnCols or includes all properties in nebula. - */ - private def getSchema(nebulaOptions: NebulaOptions): StructType = { - val returnCols = nebulaOptions.getReturnCols - val noColumn = nebulaOptions.noColumn - val fields: ListBuffer[StructField] = new ListBuffer[StructField] - val metaProvider = new MetaProvider( - nebulaOptions.getMetaAddress, - nebulaOptions.timeout, - nebulaOptions.connectionRetry, - nebulaOptions.executionRetry, - nebulaOptions.enableMetaSSL, - nebulaOptions.sslSignType, - nebulaOptions.caSignParam, - nebulaOptions.selfSignParam - ) - - import scala.collection.JavaConverters._ - var schemaCols: Seq[ColumnDef] = Seq() - val isVertex = DataTypeEnum.VERTEX.toString.equalsIgnoreCase(nebulaOptions.dataType) - - // construct vertex or edge default prop - if (isVertex) { - fields.append(DataTypes.createStructField("_vertexId", DataTypes.StringType, false)) - } else { - fields.append(DataTypes.createStructField("_srcId", DataTypes.StringType, false)) - fields.append(DataTypes.createStructField("_dstId", DataTypes.StringType, false)) - fields.append(DataTypes.createStructField("_rank", DataTypes.LongType, false)) - } - - var dataSchema: StructType = null - // read no column - if (noColumn) { - dataSchema = new StructType(fields.toArray) - return dataSchema + override def schema: StructType = { + if (datasetSchema == null) { + datasetSchema = NebulaUtils.getSchema(nebulaOptions) } - // get tag schema or edge schema - val schema = if (isVertex) { - metaProvider.getTag(nebulaOptions.spaceName, nebulaOptions.label) - } else { - metaProvider.getEdge(nebulaOptions.spaceName, nebulaOptions.label) - } - - schemaCols = schema.columns.asScala - - // read all columns - if (returnCols.isEmpty) { - schemaCols.foreach(columnDef => { - LOG.info(s"prop name ${new String(columnDef.getName)}, type ${columnDef.getType.getType} ") - fields.append( - DataTypes.createStructField(new String(columnDef.getName), - NebulaUtils.convertDataType(columnDef.getType), - true)) - }) - } else { - for (col: String <- returnCols) { - fields.append( - DataTypes - .createStructField(col, NebulaUtils.getColDataType(schemaCols.toList, col), true)) - } - } - dataSchema = new StructType(fields.toArray) - dataSchema + datasetSchema } override def buildScan(): RDD[Row] = { + + if (datasetSchema == null) { + datasetSchema = NebulaUtils.getSchema(nebulaOptions) + } if (nebulaOptions.ngql != null && nebulaOptions.ngql.nonEmpty) { new NebulaNgqlRDD(sqlContext, nebulaOptions, datasetSchema).asInstanceOf[RDD[Row]] } else { diff --git a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaVertexPartitionReader.scala b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaVertexPartitionReader.scala new file mode 100644 index 00000000..2f4aadb5 --- /dev/null +++ 
b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaVertexPartitionReader.scala @@ -0,0 +1,19 @@ +/* Copyright (c) 2020 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. + */ + +package com.vesoft.nebula.connector.reader + +import com.vesoft.nebula.connector.NebulaOptions +import org.apache.spark.Partition +import org.apache.spark.sql.types.StructType + +class NebulaVertexPartitionReader(index: Partition, + nebulaOptions: NebulaOptions, + schema: StructType) + extends NebulaIterator(index, nebulaOptions, schema) { + + override def hasNext: Boolean = hasNextVertexRow + +} diff --git a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaVertexReader.scala b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaVertexReader.scala deleted file mode 100644 index c27dc339..00000000 --- a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaVertexReader.scala +++ /dev/null @@ -1,78 +0,0 @@ -/* Copyright (c) 2022 vesoft inc. All rights reserved. - * - * This source code is licensed under Apache 2.0 License. - */ - -package com.vesoft.nebula.connector.reader - -import com.vesoft.nebula.client.storage.scan.{ScanVertexResult, ScanVertexResultIterator} -import com.vesoft.nebula.connector.NebulaOptions -import org.apache.spark.Partition -import org.apache.spark.sql.types.StructType -import org.slf4j.LoggerFactory -import scala.collection.JavaConverters._ - -class NebulaVertexReader(split: Partition, nebulaOptions: NebulaOptions, schema: StructType) - extends NebulaIterator(split, nebulaOptions, schema) { - - private val LOG = LoggerFactory.getLogger(this.getClass) - - private var responseIterator: ScanVertexResultIterator = _ - - override def hasNext: Boolean = { - if (dataIterator == null && responseIterator == null && !scanPartIterator.hasNext) - return false - - var continue: Boolean = false - var break: Boolean = false - while ((dataIterator == null || !dataIterator.hasNext) && !break) { - resultValues.clear() - continue = false - if (responseIterator == null || !responseIterator.hasNext) { - if (scanPartIterator.hasNext) { - try { - if (nebulaOptions.noColumn) { - responseIterator = storageClient.scanVertex(nebulaOptions.spaceName, - scanPartIterator.next(), - nebulaOptions.label, - nebulaOptions.limit, - 0, - Long.MaxValue, - true, - true) - } else { - responseIterator = storageClient.scanVertex(nebulaOptions.spaceName, - scanPartIterator.next(), - nebulaOptions.label, - nebulaOptions.getReturnCols.asJava, - nebulaOptions.limit, - 0, - Long.MaxValue, - true, - true) - } - } catch { - case e: Exception => - LOG.error(s"Exception scanning vertex ${nebulaOptions.label}", e) - storageClient.close() - throw new Exception(e.getMessage, e) - } - // jump to the next loop - continue = true - } - // break while loop - break = !continue - } else { - val next: ScanVertexResult = responseIterator.next - if (!next.isEmpty) { - dataIterator = next.getVertexTableRows.iterator().asScala - } - } - } - - if (dataIterator == null) { - return false - } - dataIterator.hasNext - } -} diff --git a/nebula-spark-connector_3.0/pom.xml b/nebula-spark-connector_3.0/pom.xml index bcfae8ec..fae668eb 100644 --- a/nebula-spark-connector_3.0/pom.xml +++ b/nebula-spark-connector_3.0/pom.xml @@ -15,6 +15,10 @@ 1.8 1.8 + 2.12 + 3.0.0 + 2.12.10 + 3.2.3 @@ -22,6 +26,45 @@ com.vesoft nebula-spark-common ${project.version} + + + spark-core_2.11 + org.apache.spark + + + 
scalatest-funsuite_2.11 + org.scalatest + + + spark-graphx_2.11 + org.apache.spark + + + spark-sql_2.11 + org.apache.spark + + + + + org.apache.spark + spark-core_${scala.binary.version} + ${spark.version} + + + org.apache.spark + spark-sql_${scala.binary.version} + ${spark.version} + + + org.apache.spark + spark-graphx_${scala.binary.version} + ${spark.version} + + + + org.scalatest + scalatest-funsuite_2.12 + ${scalatest.version} @@ -74,48 +117,25 @@ - org.apache.maven.plugins - maven-shade-plugin - 3.2.1 + maven-assembly-plugin + 2.5.3 package - shade + single - - false - - - org.apache.spark:* - org.apache.hadoop:* - org.apache.hive:* - log4j:log4j - org.apache.orc:* - xml-apis:xml-apis - javax.inject:javax.inject - org.spark-project.hive:hive-exec - stax:stax-api - org.glassfish.hk2.external:aopalliance-repackaged - - - - - - *:* - - com/vesoft/tools/** - META-INF/*.SF - META-INF/*.DSA - META-INF/*.RSA - - - - + + + jar-with-dependencies + + ${project.artifactId}-${project.version}-jar-with-dependencies + false + @@ -139,7 +159,6 @@ - com/vesoft/tools/** META-INF/*.SF META-INF/*.DSA META-INF/*.RSA @@ -151,11 +170,6 @@ testCompile - - - com/vesoft/tools/** - - @@ -240,10 +254,6 @@ org.apache.maven.plugins maven-javadoc-plugin 3.2.0 - - com.facebook.thrift:com.facebook.thrift.* - - attach-javadocs diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala index 0c6ef613..b7764e8d 100644 --- a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala +++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala @@ -50,7 +50,7 @@ class NebulaDataSource extends TableProvider with DataSourceRegister { if (schema == null) { nebulaOptions = getNebulaOptions(caseInsensitiveStringMap) if (nebulaOptions.operaType == OperaType.READ) { - schema = getSchema(nebulaOptions) + schema = NebulaUtils.getSchema(nebulaOptions) } else { schema = new StructType() } @@ -64,72 +64,6 @@ class NebulaDataSource extends TableProvider with DataSourceRegister { new NebulaTable(tableSchema, nebulaOptions) } - /** - * return the dataset's schema. Schema includes configured cols in returnCols or includes all properties in nebula. 
- */ - private def getSchema(nebulaOptions: NebulaOptions): StructType = { - val returnCols = nebulaOptions.getReturnCols - val noColumn = nebulaOptions.noColumn - val fields: ListBuffer[StructField] = new ListBuffer[StructField] - val metaProvider = new MetaProvider( - nebulaOptions.getMetaAddress, - nebulaOptions.timeout, - nebulaOptions.connectionRetry, - nebulaOptions.executionRetry, - nebulaOptions.enableMetaSSL, - nebulaOptions.sslSignType, - nebulaOptions.caSignParam, - nebulaOptions.selfSignParam - ) - - import scala.collection.JavaConverters._ - var schemaCols: Seq[ColumnDef] = Seq() - val isVertex = DataTypeEnum.VERTEX.toString.equalsIgnoreCase(nebulaOptions.dataType) - - // construct vertex or edge default prop - if (isVertex) { - fields.append(DataTypes.createStructField("_vertexId", DataTypes.StringType, false)) - } else { - fields.append(DataTypes.createStructField("_srcId", DataTypes.StringType, false)) - fields.append(DataTypes.createStructField("_dstId", DataTypes.StringType, false)) - fields.append(DataTypes.createStructField("_rank", DataTypes.LongType, false)) - } - - var dataSchema: StructType = null - // read no column - if (noColumn) { - dataSchema = new StructType(fields.toArray) - return dataSchema - } - // get tag schema or edge schema - val schema = if (isVertex) { - metaProvider.getTag(nebulaOptions.spaceName, nebulaOptions.label) - } else { - metaProvider.getEdge(nebulaOptions.spaceName, nebulaOptions.label) - } - - schemaCols = schema.columns.asScala - - // read all columns - if (returnCols.isEmpty) { - schemaCols.foreach(columnDef => { - LOG.info(s"prop name ${new String(columnDef.getName)}, type ${columnDef.getType.getType} ") - fields.append( - DataTypes.createStructField(new String(columnDef.getName), - NebulaUtils.convertDataType(columnDef.getType), - true)) - }) - } else { - for (col: String <- returnCols) { - fields.append( - DataTypes - .createStructField(col, NebulaUtils.getColDataType(schemaCols.toList, col), true)) - } - } - dataSchema = new StructType(fields.toArray) - dataSchema - } - /** * construct nebula options with DataSourceOptions */ diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaEdgePartitionReader.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaEdgePartitionReader.scala index d1a607d9..e1062d49 100644 --- a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaEdgePartitionReader.scala +++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaEdgePartitionReader.scala @@ -5,72 +5,11 @@ package com.vesoft.nebula.connector.reader -import com.vesoft.nebula.client.storage.scan.{ScanEdgeResult, ScanEdgeResultIterator} import com.vesoft.nebula.connector.NebulaOptions import org.apache.spark.sql.types.StructType -import org.slf4j.{Logger, LoggerFactory} -import scala.collection.JavaConverters._ class NebulaEdgePartitionReader(index: Int, nebulaOptions: NebulaOptions, schema: StructType) extends NebulaPartitionReader(index, nebulaOptions, schema) { - private val LOG: Logger = LoggerFactory.getLogger(this.getClass) - private var responseIterator: ScanEdgeResultIterator = _ - - override def next(): Boolean = { - if (dataIterator == null && responseIterator == null && !scanPartIterator.hasNext) - return false - - var continue: Boolean = false - var break: Boolean = false - while ((dataIterator == null || !dataIterator.hasNext) && !break) { - resultValues.clear() - continue = false - if 
(responseIterator == null || !responseIterator.hasNext) { - if (scanPartIterator.hasNext) { - try { - if (nebulaOptions.noColumn) { - responseIterator = storageClient.scanEdge(nebulaOptions.spaceName, - scanPartIterator.next(), - nebulaOptions.label, - nebulaOptions.limit, - 0L, - Long.MaxValue, - true, - true) - } else { - responseIterator = storageClient.scanEdge(nebulaOptions.spaceName, - scanPartIterator.next(), - nebulaOptions.label, - nebulaOptions.getReturnCols.asJava, - nebulaOptions.limit, - 0, - Long.MaxValue, - true, - true) - } - } catch { - case e: Exception => - LOG.error(s"Exception scanning vertex ${nebulaOptions.label}", e) - storageClient.close() - throw new Exception(e.getMessage, e) - } - // jump to the next loop - continue = true - } - // break while loop - break = !continue - } else { - val next: ScanEdgeResult = responseIterator.next - if (!next.isEmpty) { - dataIterator = next.getEdgeTableRows.iterator().asScala - } - } - } - - if (dataIterator == null) { - return false - } - dataIterator.hasNext - } + override def next(): Boolean = hasNextEdgeRow } diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartitionReader.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartitionReader.scala index a6fb5704..6b9878db 100644 --- a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartitionReader.scala +++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartitionReader.scala @@ -5,44 +5,18 @@ package com.vesoft.nebula.connector.reader -import com.vesoft.nebula.client.graph.data.{ - CASignedSSLParam, - HostAddress, - SSLParam, - SelfSignedSSLParam, - ValueWrapper -} -import com.vesoft.nebula.client.storage.StorageClient -import com.vesoft.nebula.client.storage.data.BaseTableRow -import com.vesoft.nebula.connector.NebulaUtils.NebulaValueGetter -import com.vesoft.nebula.connector.{NebulaOptions, NebulaUtils, PartitionUtils} -import com.vesoft.nebula.connector.exception.GraphConnectException -import com.vesoft.nebula.connector.nebula.MetaProvider -import com.vesoft.nebula.connector.ssl.SSLSignType +import com.vesoft.nebula.connector.{NebulaOptions} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow import org.apache.spark.sql.connector.read.PartitionReader import org.apache.spark.sql.types.StructType import org.slf4j.{Logger, LoggerFactory} -import scala.collection.JavaConverters._ -import scala.collection.mutable -import scala.collection.mutable.ListBuffer - /** * Read nebula data for each spark partition */ -abstract class NebulaPartitionReader extends PartitionReader[InternalRow] { +abstract class NebulaPartitionReader extends PartitionReader[InternalRow] with NebulaReader { private val LOG: Logger = LoggerFactory.getLogger(this.getClass) - private var metaProvider: MetaProvider = _ - private var schema: StructType = _ - - protected var dataIterator: Iterator[BaseTableRow] = _ - protected var scanPartIterator: Iterator[Integer] = _ - protected var resultValues: mutable.ListBuffer[List[Object]] = mutable.ListBuffer[List[Object]]() - protected var storageClient: StorageClient = _ - /** * @param index identifier for spark partition * @param nebulaOptions nebula Options @@ -50,113 +24,12 @@ abstract class NebulaPartitionReader extends PartitionReader[InternalRow] { */ def this(index: Int, nebulaOptions: NebulaOptions, schema: StructType) { this() - this.schema 
= schema - - metaProvider = new MetaProvider( - nebulaOptions.getMetaAddress, - nebulaOptions.timeout, - nebulaOptions.connectionRetry, - nebulaOptions.executionRetry, - nebulaOptions.enableMetaSSL, - nebulaOptions.sslSignType, - nebulaOptions.caSignParam, - nebulaOptions.selfSignParam - ) - val address: ListBuffer[HostAddress] = new ListBuffer[HostAddress] - - for (addr <- nebulaOptions.getMetaAddress) { - address.append(new HostAddress(addr._1, addr._2)) - } - - var sslParam: SSLParam = null - if (nebulaOptions.enableStorageSSL) { - SSLSignType.withName(nebulaOptions.sslSignType) match { - case SSLSignType.CA => { - val caSSLSignParams = nebulaOptions.caSignParam - sslParam = new CASignedSSLParam(caSSLSignParams.caCrtFilePath, - caSSLSignParams.crtFilePath, - caSSLSignParams.keyFilePath) - } - case SSLSignType.SELF => { - val selfSSLSignParams = nebulaOptions.selfSignParam - sslParam = new SelfSignedSSLParam(selfSSLSignParams.crtFilePath, - selfSSLSignParams.keyFilePath, - selfSSLSignParams.password) - } - case _ => throw new IllegalArgumentException("ssl sign type is not supported") - } - this.storageClient = new StorageClient(address.asJava, - nebulaOptions.timeout, - nebulaOptions.connectionRetry, - nebulaOptions.executionRetry, - true, - sslParam) - } else { - this.storageClient = new StorageClient(address.asJava, nebulaOptions.timeout) - } - - if (!storageClient.connect()) { - throw new GraphConnectException("storage connect failed.") - } - // allocate scanPart to this partition - val totalPart = metaProvider.getPartitionNumber(nebulaOptions.spaceName) - - // index starts with 1 - val scanParts = PartitionUtils.getScanParts(index, totalPart, nebulaOptions.partitionNums.toInt) - LOG.info(s"partition index: ${index}, scanParts: ${scanParts.toString}") - scanPartIterator = scanParts.iterator + super.init(index, nebulaOptions, schema) } - override def get(): InternalRow = { - val resultSet: Array[ValueWrapper] = - dataIterator.next().getValues.toArray.map(v => v.asInstanceOf[ValueWrapper]) - val getters: Array[NebulaValueGetter] = NebulaUtils.makeGetters(schema) - val mutableRow = new SpecificInternalRow(schema.fields.map(x => x.dataType)) - - for (i <- getters.indices) { - val value: ValueWrapper = resultSet(i) - var resolved = false - if (value.isNull) { - mutableRow.setNullAt(i) - resolved = true - } - if (value.isString) { - getters(i).apply(value.asString(), mutableRow, i) - resolved = true - } - if (value.isDate) { - getters(i).apply(value.asDate(), mutableRow, i) - resolved = true - } - if (value.isTime) { - getters(i).apply(value.asTime(), mutableRow, i) - resolved = true - } - if (value.isDateTime) { - getters(i).apply(value.asDateTime(), mutableRow, i) - resolved = true - } - if (value.isLong) { - getters(i).apply(value.asLong(), mutableRow, i) - } - if (value.isBoolean) { - getters(i).apply(value.asBoolean(), mutableRow, i) - } - if (value.isDouble) { - getters(i).apply(value.asDouble(), mutableRow, i) - } - if (value.isGeography) { - getters(i).apply(value.asGeography(), mutableRow, i) - } - if (value.isDuration) { - getters(i).apply(value.asDuration(), mutableRow, i) - } - } - mutableRow - } + override def get(): InternalRow = super.getRow() override def close(): Unit = { - metaProvider.close() - storageClient.close() + super.closeReader() } } diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaVertexPartitionReader.scala 
b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaVertexPartitionReader.scala index 3466b1dd..da5b02d2 100644 --- a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaVertexPartitionReader.scala +++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaVertexPartitionReader.scala @@ -5,75 +5,11 @@ package com.vesoft.nebula.connector.reader -import com.vesoft.nebula.client.storage.scan.{ScanVertexResult, ScanVertexResultIterator} import com.vesoft.nebula.connector.NebulaOptions import org.apache.spark.sql.types.StructType -import org.slf4j.{Logger, LoggerFactory} -import scala.collection.JavaConverters._ - -class NebulaVertexPartitionReader(index: Int, nebulaOptions: NebulaOptions, schema: StructType) - extends NebulaPartitionReader(index, nebulaOptions, schema) { - - private val LOG: Logger = LoggerFactory.getLogger(this.getClass) - - private var responseIterator: ScanVertexResultIterator = _ - - override def next(): Boolean = { - if (dataIterator == null && responseIterator == null && !scanPartIterator.hasNext) - return false - - var continue: Boolean = false - var break: Boolean = false - while ((dataIterator == null || !dataIterator.hasNext) && !break) { - resultValues.clear() - continue = false - if (responseIterator == null || !responseIterator.hasNext) { - if (scanPartIterator.hasNext) { - try { - if (nebulaOptions.noColumn) { - responseIterator = storageClient.scanVertex(nebulaOptions.spaceName, - scanPartIterator.next(), - nebulaOptions.label, - nebulaOptions.limit, - 0, - Long.MaxValue, - true, - true) - } else { - responseIterator = storageClient.scanVertex(nebulaOptions.spaceName, - scanPartIterator.next(), - nebulaOptions.label, - nebulaOptions.getReturnCols.asJava, - nebulaOptions.limit, - 0, - Long.MaxValue, - true, - true) - } - } catch { - case e: Exception => - LOG.error(s"Exception scanning vertex ${nebulaOptions.label}", e) - storageClient.close() - throw new Exception(e.getMessage, e) - } - // jump to the next loop - continue = true - } - // break while loop - break = !continue - } else { - val next: ScanVertexResult = responseIterator.next - if (!next.isEmpty) { - dataIterator = next.getVertexTableRows.iterator().asScala - } - } - } - - if (dataIterator == null) { - return false - } - dataIterator.hasNext - } +class NebulaVertexPartitionReader(split: Int, nebulaOptions: NebulaOptions, schema: StructType) + extends NebulaPartitionReader(split, nebulaOptions, schema) { + override def next(): Boolean = hasNextVertexRow } diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/SimpleScanBuilder.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/SimpleScanBuilder.scala index 7009838b..20769571 100644 --- a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/SimpleScanBuilder.scala +++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/SimpleScanBuilder.scala @@ -46,8 +46,6 @@ class SimpleScanBuilder(nebulaOptions: NebulaOptions, schema: StructType) override def pruneColumns(requiredColumns: StructType): Unit = { if (!nebulaOptions.pushDownFiltersEnabled || requiredColumns == schema) { new StructType() - } else { - requiredColumns } } } From 7c5df59126b388b047cbe1bb4a3fa2d1721f6e15 Mon Sep 17 00:00:00 2001 From: Anqi Date: Fri, 18 Nov 2022 14:52:37 +0800 Subject: [PATCH 03/13] fix incompatible for spark --- .github/workflows/snapshot.yml | 21 +++++++++- 
.../connector/NebulaSparkReaderExample.scala | 6 +-- .../connector/NebulaSparkWriterExample.scala | 15 ++++--- nebula-spark-common/pom.xml | 6 +-- .../connector/reader/NebulaReader.scala | 8 +--- .../com/vesoft/nebula/connector/package.scala | 4 ++ .../reader/NebulaPartitionReader.scala | 9 +++-- .../com/vesoft/nebula/connector/package.scala | 4 ++ .../reader/NebulaEdgePartitionReader.scala | 1 - .../connector/reader/NebulaIterator.scala | 11 +++++- nebula-spark-connector_3.0/pom.xml | 39 ------------------- .../connector/reader/NebulaPartition.scala | 10 ----- .../reader/NebulaPartitionReader.scala | 8 +++- .../connector/reader/SimpleScanBuilder.scala | 19 +++++---- pom.xml | 35 +++++++++++++++++ 15 files changed, 111 insertions(+), 85 deletions(-) delete mode 100644 nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartition.scala diff --git a/.github/workflows/snapshot.yml b/.github/workflows/snapshot.yml index 2fee6328..0159d49e 100644 --- a/.github/workflows/snapshot.yml +++ b/.github/workflows/snapshot.yml @@ -40,10 +40,29 @@ jobs: popd popd - - name: Deploy SNAPSHOT to Sonatype + - name: Deploy SNAPSHOT for spark2.4 to Sonatype uses: samuelmeuli/action-maven-publish@v1 with: gpg_private_key: ${{ secrets.JAVA_GPG_PRIVATE_KEY }} gpg_passphrase: ${{ secrets.JAVA_GPG_PASSPHRASE }} nexus_username: ${{ secrets.OSSRH_USERNAME }} nexus_password: ${{ secrets.OSSRH_TOKEN }} + maven_args: -pl nebula-spark-connector -am -Pscala-2.11 -Pspark-2.4 + + - name: Deploy SNAPSHOT for spark2.2 to Sonatype + uses: samuelmeuli/action-maven-publish@v1 + with: + gpg_private_key: ${{ secrets.JAVA_GPG_PRIVATE_KEY }} + gpg_passphrase: ${{ secrets.JAVA_GPG_PASSPHRASE }} + nexus_username: ${{ secrets.OSSRH_USERNAME }} + nexus_password: ${{ secrets.OSSRH_TOKEN }} + maven_args: -pl nebula-spark-connector_2.2 -am -Pscala-2.11 -Pspark-2.2 + + - name: Deploy SNAPSHOT for spark3.0 to Sonatype + uses: samuelmeuli/action-maven-publish@v1 + with: + gpg_private_key: ${{ secrets.JAVA_GPG_PRIVATE_KEY }} + gpg_passphrase: ${{ secrets.JAVA_GPG_PASSPHRASE }} + nexus_username: ${{ secrets.OSSRH_USERNAME }} + nexus_password: ${{ secrets.OSSRH_TOKEN }} + maven_args: -pl nebula-spark-connector_3.0 -am -Pscala-2.12 -Pspark-3.0 diff --git a/example/src/main/scala/com/vesoft/nebula/examples/connector/NebulaSparkReaderExample.scala b/example/src/main/scala/com/vesoft/nebula/examples/connector/NebulaSparkReaderExample.scala index e78f7e8d..3d5a07d2 100644 --- a/example/src/main/scala/com/vesoft/nebula/examples/connector/NebulaSparkReaderExample.scala +++ b/example/src/main/scala/com/vesoft/nebula/examples/connector/NebulaSparkReaderExample.scala @@ -62,12 +62,10 @@ object NebulaSparkReaderExample { } def readEdges(spark: SparkSession): Unit = { - LOG.info("start to read nebula edges") - val config = NebulaConnectionConfig .builder() - .withMetaAddress("127.0.0.1:9559") + .withMetaAddress("192.168.8.171:9559") .withTimeout(6000) .withConenctionRetry(2) .build() @@ -76,7 +74,7 @@ object NebulaSparkReaderExample { .withSpace("test") .withLabel("knows") .withNoColumn(false) - .withReturnCols(List("degree")) + .withReturnCols(List()) .withLimit(10) .withPartitionNum(10) .build() diff --git a/example/src/main/scala/com/vesoft/nebula/examples/connector/NebulaSparkWriterExample.scala b/example/src/main/scala/com/vesoft/nebula/examples/connector/NebulaSparkWriterExample.scala index 279935f9..ccde7fdf 100644 --- a/example/src/main/scala/com/vesoft/nebula/examples/connector/NebulaSparkWriterExample.scala +++ 
b/example/src/main/scala/com/vesoft/nebula/examples/connector/NebulaSparkWriterExample.scala @@ -94,16 +94,22 @@ object NebulaSparkWriterExample { * if your withVidAsProp is true, then tag schema also should have property name: id */ def writeVertex(spark: SparkSession): Unit = { - LOG.info("start to write nebula vertices") - val df = spark.read.json("example/src/main/resources/vertex") + val df = spark.read.json("vertex") df.show() - val config = getNebulaConnectionConfig() + val config = + NebulaConnectionConfig + .builder() + .withMetaAddress("192.168.8.171:9559") + .withGraphAddress("192.168.8.171:9669") + .withConenctionRetry(2) + .build() val nebulaWriteVertexConfig: WriteNebulaVertexConfig = WriteNebulaVertexConfig .builder() .withSpace("test") .withTag("person") .withVidField("id") + .withWriteMode(WriteMode.DELETE) .withVidAsProp(false) .withBatch(1000) .build() @@ -117,8 +123,7 @@ object NebulaSparkWriterExample { * if your withRankAsProperty is true, then edge schema also should have property name: degree */ def writeEdge(spark: SparkSession): Unit = { - LOG.info("start to write nebula edges") - val df = spark.read.json("example/src/main/resources/edge") + val df = spark.read.json("edge") df.show() df.persist(StorageLevel.MEMORY_AND_DISK_SER) diff --git a/nebula-spark-common/pom.xml b/nebula-spark-common/pom.xml index ca7d5281..d0a761fa 100644 --- a/nebula-spark-common/pom.xml +++ b/nebula-spark-common/pom.xml @@ -19,10 +19,6 @@ 3.2.3 4.13.1 1.13 - - 2.11 - 2.4.4 - 2.12.10 @@ -57,7 +53,7 @@ org.scalatest - scalatest-funsuite_2.11 + scalatest-funsuite_${scala.binary.version} ${scalatest.version} test diff --git a/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/reader/NebulaReader.scala b/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/reader/NebulaReader.scala index d48e4034..054819a9 100644 --- a/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/reader/NebulaReader.scala +++ b/nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/reader/NebulaReader.scala @@ -52,7 +52,7 @@ trait NebulaReader { /** * init the reader: init metaProvider, storageClient */ - def init(index: Int, nebulaOptions: NebulaOptions, schema: StructType) { + def init(index: Int, nebulaOptions: NebulaOptions, schema: StructType): Int = { this.schema = schema this.nebulaOptions = nebulaOptions @@ -104,11 +104,7 @@ trait NebulaReader { } // allocate scanPart to this partition val totalPart = metaProvider.getPartitionNumber(nebulaOptions.spaceName) - - // index starts with 1 - val scanParts = PartitionUtils.getScanParts(index, totalPart, nebulaOptions.partitionNums.toInt) - LOG.info(s"partition index: ${index}, scanParts: ${scanParts.toString}") - scanPartIterator = scanParts.iterator + totalPart } /** diff --git a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/package.scala b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/package.scala index e9f122c1..0ee3b8bb 100644 --- a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/package.scala +++ b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/package.scala @@ -47,6 +47,7 @@ package object connector { val dfReader = reader .format(classOf[NebulaDataSource].getName) .option(NebulaOptions.TYPE, DataTypeEnum.VERTEX.toString) + .option(NebulaOptions.OPERATE_TYPE, OperaType.READ.toString) .option(NebulaOptions.SPACE_NAME, readConfig.getSpace) .option(NebulaOptions.LABEL, readConfig.getLabel) .option(NebulaOptions.PARTITION_NUMBER, 
readConfig.getPartitionNum) @@ -84,6 +85,7 @@ package object connector { val dfReader = reader .format(classOf[NebulaDataSource].getName) .option(NebulaOptions.TYPE, DataTypeEnum.EDGE.toString) + .option(NebulaOptions.OPERATE_TYPE, OperaType.READ.toString) .option(NebulaOptions.SPACE_NAME, readConfig.getSpace) .option(NebulaOptions.LABEL, readConfig.getLabel) .option(NebulaOptions.RETURN_COLS, readConfig.getReturnCols.mkString(",")) @@ -235,6 +237,7 @@ package object connector { .format(classOf[NebulaDataSource].getName) .mode(SaveMode.Overwrite) .option(NebulaOptions.TYPE, DataTypeEnum.VERTEX.toString) + .option(NebulaOptions.OPERATE_TYPE, OperaType.WRITE.toString) .option(NebulaOptions.SPACE_NAME, writeConfig.getSpace) .option(NebulaOptions.LABEL, writeConfig.getTagName) .option(NebulaOptions.USER_NAME, writeConfig.getUser) @@ -278,6 +281,7 @@ package object connector { .format(classOf[NebulaDataSource].getName) .mode(SaveMode.Overwrite) .option(NebulaOptions.TYPE, DataTypeEnum.EDGE.toString) + .option(NebulaOptions.OPERATE_TYPE, OperaType.WRITE.toString) .option(NebulaOptions.SPACE_NAME, writeConfig.getSpace) .option(NebulaOptions.USER_NAME, writeConfig.getUser) .option(NebulaOptions.PASSWD, writeConfig.getPasswd) diff --git a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartitionReader.scala b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartitionReader.scala index c491662d..06e06a59 100644 --- a/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartitionReader.scala +++ b/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartitionReader.scala @@ -5,9 +5,8 @@ package com.vesoft.nebula.connector.reader -import com.vesoft.nebula.connector.{NebulaOptions} +import com.vesoft.nebula.connector.{NebulaOptions, PartitionUtils} import org.apache.spark.sql.catalyst.InternalRow - import org.apache.spark.sql.sources.v2.reader.InputPartitionReader import org.apache.spark.sql.types.StructType import org.slf4j.{Logger, LoggerFactory} @@ -25,7 +24,11 @@ abstract class NebulaPartitionReader extends InputPartitionReader[InternalRow] w */ def this(index: Int, nebulaOptions: NebulaOptions, schema: StructType) { this() - super.init(index, nebulaOptions, schema) + val totalPart = super.init(index, nebulaOptions, schema) + // index starts with 1 + val scanParts = PartitionUtils.getScanParts(index, totalPart, nebulaOptions.partitionNums.toInt) + LOG.info(s"partition index: ${index}, scanParts: ${scanParts.toString}") + scanPartIterator = scanParts.iterator } override def get(): InternalRow = super.getRow() diff --git a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/package.scala b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/package.scala index 21c59f77..8c616727 100644 --- a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/package.scala +++ b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/package.scala @@ -110,6 +110,7 @@ package object connector { val dfReader = reader .format(classOf[NebulaDataSource].getName) .option(NebulaOptions.TYPE, DataTypeEnum.VERTEX.toString) + .option(NebulaOptions.OPERATE_TYPE, OperaType.READ.toString) .option(NebulaOptions.SPACE_NAME, readConfig.getSpace) .option(NebulaOptions.LABEL, readConfig.getLabel) .option(NebulaOptions.PARTITION_NUMBER, readConfig.getPartitionNum) @@ -147,6 +148,7 @@ package object connector { val dfReader = reader 
.format(classOf[NebulaDataSource].getName) .option(NebulaOptions.TYPE, DataTypeEnum.EDGE.toString) + .option(NebulaOptions.OPERATE_TYPE, OperaType.READ.toString) .option(NebulaOptions.SPACE_NAME, readConfig.getSpace) .option(NebulaOptions.LABEL, readConfig.getLabel) .option(NebulaOptions.RETURN_COLS, readConfig.getReturnCols.mkString(",")) @@ -298,6 +300,7 @@ package object connector { .format(classOf[NebulaDataSource].getName) .mode(SaveMode.Overwrite) .option(NebulaOptions.TYPE, DataTypeEnum.VERTEX.toString) + .option(NebulaOptions.OPERATE_TYPE, OperaType.WRITE.toString) .option(NebulaOptions.SPACE_NAME, writeConfig.getSpace) .option(NebulaOptions.LABEL, writeConfig.getTagName) .option(NebulaOptions.USER_NAME, writeConfig.getUser) @@ -340,6 +343,7 @@ package object connector { .format(classOf[NebulaDataSource].getName) .mode(SaveMode.Overwrite) .option(NebulaOptions.TYPE, DataTypeEnum.EDGE.toString) + .option(NebulaOptions.OPERATE_TYPE, OperaType.WRITE.toString) .option(NebulaOptions.SPACE_NAME, writeConfig.getSpace) .option(NebulaOptions.USER_NAME, writeConfig.getUser) .option(NebulaOptions.PASSWD, writeConfig.getPasswd) diff --git a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaEdgePartitionReader.scala b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaEdgePartitionReader.scala index ef53ac54..7b458006 100644 --- a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaEdgePartitionReader.scala +++ b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaEdgePartitionReader.scala @@ -12,7 +12,6 @@ import org.slf4j.{Logger, LoggerFactory} class NebulaEdgePartitionReader(index: Partition, nebulaOptions: NebulaOptions, schema: StructType) extends NebulaIterator(index, nebulaOptions, schema) { - private val LOG: Logger = LoggerFactory.getLogger(this.getClass) override def hasNext(): Boolean = hasNextEdgeRow } diff --git a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaIterator.scala b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaIterator.scala index d03e5f50..3c182aa8 100644 --- a/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaIterator.scala +++ b/nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaIterator.scala @@ -6,20 +6,27 @@ package com.vesoft.nebula.connector.reader import com.vesoft.nebula.connector.{NebulaOptions, NebulaUtils, PartitionUtils} - import org.apache.spark.Partition import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types.StructType +import org.slf4j.{Logger, LoggerFactory} /** * iterator for nebula vertex or edge data * convert each vertex data or edge data to Spark SQL's Row */ abstract class NebulaIterator extends Iterator[InternalRow] with NebulaReader { + private val LOG: Logger = LoggerFactory.getLogger(this.getClass) def this(index: Partition, nebulaOptions: NebulaOptions, schema: StructType) { this() - super.init(index.index, nebulaOptions, schema) + val totalPart = super.init(index.index, nebulaOptions, schema) + // index starts with 0 + val nebulaPartition = index.asInstanceOf[NebulaPartition] + val scanParts = + nebulaPartition.getScanParts(totalPart, nebulaOptions.partitionNums.toInt) + LOG.info(s"partition index: ${index}, scanParts: ${scanParts.toString}") + scanPartIterator = scanParts.iterator } /** diff --git a/nebula-spark-connector_3.0/pom.xml 
b/nebula-spark-connector_3.0/pom.xml index fae668eb..8888e1cc 100644 --- a/nebula-spark-connector_3.0/pom.xml +++ b/nebula-spark-connector_3.0/pom.xml @@ -26,45 +26,6 @@ com.vesoft nebula-spark-common ${project.version} - - - spark-core_2.11 - org.apache.spark - - - scalatest-funsuite_2.11 - org.scalatest - - - spark-graphx_2.11 - org.apache.spark - - - spark-sql_2.11 - org.apache.spark - - - - - org.apache.spark - spark-core_${scala.binary.version} - ${spark.version} - - - org.apache.spark - spark-sql_${scala.binary.version} - ${spark.version} - - - org.apache.spark - spark-graphx_${scala.binary.version} - ${spark.version} - - - - org.scalatest - scalatest-funsuite_2.12 - ${scalatest.version} diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartition.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartition.scala deleted file mode 100644 index d296376b..00000000 --- a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartition.scala +++ /dev/null @@ -1,10 +0,0 @@ -/* Copyright (c) 2020 vesoft inc. All rights reserved. - * - * This source code is licensed under Apache 2.0 License. - */ - -package com.vesoft.nebula.connector.reader - -import org.apache.spark.sql.connector.read.InputPartition - -case class NebulaPartition(partition: Int) extends InputPartition diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartitionReader.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartitionReader.scala index 6b9878db..44fdf698 100644 --- a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartitionReader.scala +++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartitionReader.scala @@ -5,7 +5,7 @@ package com.vesoft.nebula.connector.reader -import com.vesoft.nebula.connector.{NebulaOptions} +import com.vesoft.nebula.connector.{NebulaOptions, PartitionUtils} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.read.PartitionReader import org.apache.spark.sql.types.StructType @@ -24,7 +24,11 @@ abstract class NebulaPartitionReader extends PartitionReader[InternalRow] with N */ def this(index: Int, nebulaOptions: NebulaOptions, schema: StructType) { this() - super.init(index, nebulaOptions, schema) + val totalPart = super.init(index, nebulaOptions, schema) + // index starts with 1 + val scanParts = PartitionUtils.getScanParts(index, totalPart, nebulaOptions.partitionNums.toInt) + LOG.info(s"partition index: ${index}, scanParts: ${scanParts.toString}") + scanPartIterator = scanParts.iterator } override def get(): InternalRow = super.getRow() diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/SimpleScanBuilder.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/SimpleScanBuilder.scala index 20769571..f7515917 100644 --- a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/SimpleScanBuilder.scala +++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/SimpleScanBuilder.scala @@ -31,7 +31,7 @@ class SimpleScanBuilder(nebulaOptions: NebulaOptions, schema: StructType) private var filters: Array[Filter] = Array[Filter]() override def build(): Scan = { - new SimpleScan(nebulaOptions, 10, schema) + new SimpleScan(nebulaOptions, nebulaOptions.partitionNums.toInt, schema) } 
override def pushFilters(pushFilters: Array[Filter]): Array[Filter] = { @@ -56,12 +56,12 @@ class SimpleScan(nebulaOptions: NebulaOptions, nebulaTotalPart: Int, schema: Str override def toBatch: Batch = this override def planInputPartitions(): Array[InputPartition] = { - val partitionSize = nebulaTotalPart - val inputPartitions: java.util.List[InputPartition] = new util.ArrayList[InputPartition]() - for (i <- 1 to partitionSize) { - inputPartitions.add(NebulaPartition(i)) - } - inputPartitions.asScala.toArray + val partitionSize = nebulaTotalPart + val inputPartitions = for (i <- 1 to partitionSize) + yield { + NebulaPartition(i) + } + inputPartitions.map(_.asInstanceOf[InputPartition]).toArray } override def readSchema(): StructType = schema @@ -69,3 +69,8 @@ class SimpleScan(nebulaOptions: NebulaOptions, nebulaTotalPart: Int, schema: Str override def createReaderFactory(): PartitionReaderFactory = new NebulaPartitionReaderFactory(nebulaOptions, schema) } + +/** + * An identifier for a partition in an NebulaRDD. + */ +case class NebulaPartition(partition: Int) extends InputPartition diff --git a/pom.xml b/pom.xml index 5e3752dc..9abdf0fe 100644 --- a/pom.xml +++ b/pom.xml @@ -11,6 +11,9 @@ UTF-8 + 2.4.4 + 2.11 + 2.11.12 3.2.3 @@ -127,6 +130,38 @@ + + scala-2.11 + + 2.11.12 + 2.11 + + + + scala-2.12 + + 2.12.10 + 2.12 + + + + spark-2.2 + + 2.2.0 + + + + spark-2.4 + + 2.4.4 + + + + spark-3.0 + + 3.0.0 + + From f0059c8a07b0a16ded677f78e21e094f5b99b98c Mon Sep 17 00:00:00 2001 From: Anqi Date: Fri, 18 Nov 2022 14:56:20 +0800 Subject: [PATCH 04/13] update maven command --- .github/workflows/pull_request.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml index fda804df..a5c34dd5 100644 --- a/.github/workflows/pull_request.yml +++ b/.github/workflows/pull_request.yml @@ -45,6 +45,9 @@ jobs: - name: Build with Maven run: | - mvn -B package + mvn -B package -pl nebula-spark-connector_2.2 -am -Pscala-2.11 -Pspark-2.2 + mvn -B package -pl nebula-spark-connector -am -Pscala-2.11 -Pspark-2.4 + mvn -B package -pl nebula-spark-connector_3.0 -am -Pscala-2.12 -Pspark-3.0 + - uses: codecov/codecov-action@v2 From 51b3a52f1f2d5ae09e9a03cc4a948e4b8dbbc0de Mon Sep 17 00:00:00 2001 From: Anqi Date: Fri, 18 Nov 2022 16:36:35 +0800 Subject: [PATCH 05/13] sleep after create schema for test --- .../com/vesoft/nebula/connector/mock/NebulaGraphMock.scala | 2 ++ .../com/vesoft/nebula/connector/mock/NebulaGraphMock.scala | 2 ++ .../com/vesoft/nebula/connector/mock/NebulaGraphMock.scala | 2 ++ 3 files changed, 6 insertions(+) diff --git a/nebula-spark-connector/src/test/scala/com/vesoft/nebula/connector/mock/NebulaGraphMock.scala b/nebula-spark-connector/src/test/scala/com/vesoft/nebula/connector/mock/NebulaGraphMock.scala index b8f0e72e..46c8f502 100644 --- a/nebula-spark-connector/src/test/scala/com/vesoft/nebula/connector/mock/NebulaGraphMock.scala +++ b/nebula-spark-connector/src/test/scala/com/vesoft/nebula/connector/mock/NebulaGraphMock.scala @@ -169,6 +169,7 @@ class NebulaGraphMock { LOG.error("create string type space failed," + createResp.getErrorMessage) sys.exit(-1) } + Thread.sleep(10000) } def mockIntIdGraphSchema(): Unit = { @@ -184,6 +185,7 @@ class NebulaGraphMock { LOG.error("create int type space failed," + createResp.getErrorMessage) sys.exit(-1) } + Thread.sleep(10000) } def close(): Unit = { diff --git a/nebula-spark-connector_2.2/src/test/scala/com/vesoft/nebula/connector/mock/NebulaGraphMock.scala 
b/nebula-spark-connector_2.2/src/test/scala/com/vesoft/nebula/connector/mock/NebulaGraphMock.scala index b8f0e72e..46c8f502 100644 --- a/nebula-spark-connector_2.2/src/test/scala/com/vesoft/nebula/connector/mock/NebulaGraphMock.scala +++ b/nebula-spark-connector_2.2/src/test/scala/com/vesoft/nebula/connector/mock/NebulaGraphMock.scala @@ -169,6 +169,7 @@ class NebulaGraphMock { LOG.error("create string type space failed," + createResp.getErrorMessage) sys.exit(-1) } + Thread.sleep(10000) } def mockIntIdGraphSchema(): Unit = { @@ -184,6 +185,7 @@ class NebulaGraphMock { LOG.error("create int type space failed," + createResp.getErrorMessage) sys.exit(-1) } + Thread.sleep(10000) } def close(): Unit = { diff --git a/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/mock/NebulaGraphMock.scala b/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/mock/NebulaGraphMock.scala index b8f0e72e..46c8f502 100644 --- a/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/mock/NebulaGraphMock.scala +++ b/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/mock/NebulaGraphMock.scala @@ -169,6 +169,7 @@ class NebulaGraphMock { LOG.error("create string type space failed," + createResp.getErrorMessage) sys.exit(-1) } + Thread.sleep(10000) } def mockIntIdGraphSchema(): Unit = { @@ -184,6 +185,7 @@ class NebulaGraphMock { LOG.error("create int type space failed," + createResp.getErrorMessage) sys.exit(-1) } + Thread.sleep(10000) } def close(): Unit = { From 5688ed00e715e36ad642ace388cb48de1ba4434c Mon Sep 17 00:00:00 2001 From: Anqi Date: Mon, 21 Nov 2022 11:13:15 +0800 Subject: [PATCH 06/13] update scalatest plugin --- nebula-spark-connector/pom.xml | 2 +- nebula-spark-connector_2.2/pom.xml | 2 +- nebula-spark-connector_3.0/pom.xml | 7 +++++++ 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/nebula-spark-connector/pom.xml b/nebula-spark-connector/pom.xml index 08a65f42..1174f703 100644 --- a/nebula-spark-connector/pom.xml +++ b/nebula-spark-connector/pom.xml @@ -47,7 +47,7 @@ org.scalatest scalatest-funsuite_2.11 - ${scalatest.version} + 3.2.3 test diff --git a/nebula-spark-connector_2.2/pom.xml b/nebula-spark-connector_2.2/pom.xml index 26e46cc3..b68fa0c4 100644 --- a/nebula-spark-connector_2.2/pom.xml +++ b/nebula-spark-connector_2.2/pom.xml @@ -46,7 +46,7 @@ org.scalatest scalatest-funsuite_2.11 - ${scalatest.version} + 3.2.3 test diff --git a/nebula-spark-connector_3.0/pom.xml b/nebula-spark-connector_3.0/pom.xml index 8888e1cc..d298e569 100644 --- a/nebula-spark-connector_3.0/pom.xml +++ b/nebula-spark-connector_3.0/pom.xml @@ -27,6 +27,13 @@ nebula-spark-common ${project.version} + + + org.scalatest + scalatest-funsuite_2.12 + 3.2.3 + test + From a96dc109251d367aa4b38e06db51837a2f7f6353 Mon Sep 17 00:00:00 2001 From: Anqi Date: Mon, 21 Nov 2022 16:30:05 +0800 Subject: [PATCH 07/13] upadte default scala version --- pom.xml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pom.xml b/pom.xml index 9abdf0fe..cbc42452 100644 --- a/pom.xml +++ b/pom.xml @@ -12,8 +12,8 @@ UTF-8 2.4.4 - 2.11 - 2.11.12 + 2.12 + 2.12.10 3.2.3 From 7b7965df03668c90568a20c81be1d2165027ca69 Mon Sep 17 00:00:00 2001 From: Anqi Date: Thu, 5 Jan 2023 14:10:10 +0800 Subject: [PATCH 08/13] update scala version --- nebula-spark-connector_3.0/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nebula-spark-connector_3.0/pom.xml b/nebula-spark-connector_3.0/pom.xml index d298e569..a160810d 100644 --- 
a/nebula-spark-connector_3.0/pom.xml +++ b/nebula-spark-connector_3.0/pom.xml @@ -111,7 +111,7 @@ maven-scala-plugin 2.15.2 - 2.11.12 + 2.12.10 -target:jvm-1.8 From 2b2379c019a6e609cd5707649a91d4cf10ec6dd4 Mon Sep 17 00:00:00 2001 From: Anqi Date: Thu, 5 Jan 2023 16:03:54 +0800 Subject: [PATCH 09/13] fix options for writer --- .../scala/com/vesoft/nebula/connector/NebulaDataSource.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala index b7764e8d..a5c629bf 100644 --- a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala +++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala @@ -61,6 +61,9 @@ class NebulaDataSource extends TableProvider with DataSourceRegister { override def getTable(tableSchema: StructType, transforms: Array[Transform], map: util.Map[String, String]): Table = { + if (nebulaOptions == null) { + nebulaOptions = getNebulaOptions(new CaseInsensitiveStringMap(map)) + } new NebulaTable(tableSchema, nebulaOptions) } From 18ecfdc81a3299b3000d22a0d0be5ac73172126a Mon Sep 17 00:00:00 2001 From: Anqi Date: Fri, 24 Feb 2023 11:27:02 +0800 Subject: [PATCH 10/13] update pom --- nebula-spark-connector_3.0/pom.xml | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/nebula-spark-connector_3.0/pom.xml b/nebula-spark-connector_3.0/pom.xml index a160810d..989a87c8 100644 --- a/nebula-spark-connector_3.0/pom.xml +++ b/nebula-spark-connector_3.0/pom.xml @@ -16,12 +16,30 @@ 1.8 1.8 2.12 - 3.0.0 + 3.0.0 2.12.10 3.2.3 + + org.apache.spark + spark-core_2.12 + ${spark3.0.version} + provided + + + org.apache.spark + spark-sql_2.12 + ${spark3.0.version} + provided + + + org.apache.spark + spark-graphx_2.12 + ${spark3.0.version} + provided + com.vesoft nebula-spark-common From b6c89b02be0c7954123cc60dfea6682b9c230b15 Mon Sep 17 00:00:00 2001 From: Anqi Date: Fri, 24 Feb 2023 11:32:17 +0800 Subject: [PATCH 11/13] add spark version validate for spark3 --- .../com/vesoft/nebula/connector/package.scala | 3 +++ .../connector/SparkVersionValidateSuite.scala | 21 +++++++++++++++++++ 2 files changed, 24 insertions(+) create mode 100644 nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/SparkVersionValidateSuite.scala diff --git a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/package.scala b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/package.scala index b3cc5eed..79c4ba87 100644 --- a/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/package.scala +++ b/nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/package.scala @@ -6,6 +6,7 @@ package com.vesoft.nebula.connector import com.vesoft.nebula.connector.ssl.SSLSignType +import com.vesoft.nebula.connector.utils.SparkValidate import org.apache.spark.rdd.RDD import org.apache.spark.sql.{ DataFrame, @@ -30,6 +31,7 @@ package object connector { def nebula(connectionConfig: NebulaConnectionConfig, readConfig: ReadNebulaConfig): NebulaDataFrameReader = { + SparkValidate.validate("3.0.*", "3.1.*", "3.2.*", "3.3.*") this.connectionConfig = connectionConfig this.readConfig = readConfig this @@ -179,6 +181,7 @@ package object connector { */ def nebula(connectionConfig: NebulaConnectionConfig, writeNebulaConfig: WriteNebulaConfig): NebulaDataFrameWriter = { 
+ SparkValidate.validate("3.0.*", "3.1.*", "3.2.*", "3.3.*") this.connectionConfig = connectionConfig this.writeNebulaConfig = writeNebulaConfig this diff --git a/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/SparkVersionValidateSuite.scala b/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/SparkVersionValidateSuite.scala new file mode 100644 index 00000000..81a8a39c --- /dev/null +++ b/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/SparkVersionValidateSuite.scala @@ -0,0 +1,21 @@ +/* Copyright (c) 2022 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. + */ + +package com.vesoft.nebula.connector + +import com.vesoft.nebula.connector.utils.SparkValidate +import org.apache.spark.sql.SparkSession +import org.scalatest.funsuite.AnyFunSuite + +class SparkVersionValidateSuite extends AnyFunSuite { + test("spark version validate") { + try { + val version = SparkSession.getActiveSession.map(_.version).getOrElse("UNKNOWN") + SparkValidate.validate("3.0.*", "3.1.*", "3.2.*", "3.3.*") + } catch { + case e: Exception => assert(false) + } + } +} From e74ec31f039e7a9e6e6d0bd9f02b28237970d905 Mon Sep 17 00:00:00 2001 From: Anqi Date: Fri, 24 Feb 2023 17:12:14 +0800 Subject: [PATCH 12/13] update action --- .github/workflows/pull_request.yml | 6 +++--- .github/workflows/release.yml | 21 ++++++++++++++++++++- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml index a5c34dd5..daf6e83c 100644 --- a/.github/workflows/pull_request.yml +++ b/.github/workflows/pull_request.yml @@ -45,9 +45,9 @@ jobs: - name: Build with Maven run: | - mvn -B package -pl nebula-spark-connector_2.2 -am -Pscala-2.11 -Pspark-2.2 - mvn -B package -pl nebula-spark-connector -am -Pscala-2.11 -Pspark-2.4 - mvn -B package -pl nebula-spark-connector_3.0 -am -Pscala-2.12 -Pspark-3.0 + mvn clean package -pl nebula-spark-connector_2.2 -am -Pscala-2.11 -Pspark-2.2 + mvn clean package -pl nebula-spark-connector -am -Pscala-2.11 -Pspark-2.4 + mvn clean package -pl nebula-spark-connector_3.0 -am -Pscala-2.12 -Pspark-3.0 - uses: codecov/codecov-action@v2 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 7132a0a7..f7563646 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -38,10 +38,29 @@ jobs: popd popd - - name: Deploy release to Maven + - name: Deploy release for spark2.4 to Maven uses: samuelmeuli/action-maven-publish@v1 with: gpg_private_key: ${{ secrets.JAVA_GPG_PRIVATE_KEY }} gpg_passphrase: ${{ secrets.JAVA_GPG_PASSPHRASE }} nexus_username: ${{ secrets.OSSRH_USERNAME }} nexus_password: ${{ secrets.OSSRH_TOKEN }} + maven_args: -pl nebula-spark-connector -am -Pscala-2.11 -Pspark-2.4 + + - name: Deploy release for spark2.2 to Maven + uses: samuelmeuli/action-maven-publish@v1 + with: + gpg_private_key: ${{ secrets.JAVA_GPG_PRIVATE_KEY }} + gpg_passphrase: ${{ secrets.JAVA_GPG_PASSPHRASE }} + nexus_username: ${{ secrets.OSSRH_USERNAME }} + nexus_password: ${{ secrets.OSSRH_TOKEN }} + maven_args: -pl nebula-spark-connector_2.2 -am -Pscala-2.11 -Pspark-2.2 + + - name: Deploy release for spark3.0 to Maven + uses: samuelmeuli/action-maven-publish@v1 + with: + gpg_private_key: ${{ secrets.JAVA_GPG_PRIVATE_KEY }} + gpg_passphrase: ${{ secrets.JAVA_GPG_PASSPHRASE }} + nexus_username: ${{ secrets.OSSRH_USERNAME }} + nexus_password: ${{ secrets.OSSRH_TOKEN }} + maven_args: -pl 
nebula-spark-connector_3.0 -am -Pscala-2.12 -Pspark-3.0 From 9628e330f40b0a1018d79550bf59e612dcd74c16 Mon Sep 17 00:00:00 2001 From: Anqi Date: Fri, 24 Feb 2023 17:12:28 +0800 Subject: [PATCH 13/13] update match statement --- .../vesoft/nebula/connector/writer/WriteInsertSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/writer/WriteInsertSuite.scala b/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/writer/WriteInsertSuite.scala index 3cd6c7d6..8856aa29 100644 --- a/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/writer/WriteInsertSuite.scala +++ b/nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/writer/WriteInsertSuite.scala @@ -27,7 +27,7 @@ class WriteInsertSuite extends AnyFunSuite with BeforeAndAfterAll { test("write vertex into test_write_string space with insert mode") { SparkMock.writeVertex() val addresses: List[Address] = List(new Address("127.0.0.1", 9669)) - val graphProvider = new GraphProvider(addresses, 3000) + val graphProvider = new GraphProvider(addresses, 3000) graphProvider.switchSpace("root", "nebula", "test_write_string") val createIndexResult: ResultSet = graphProvider.submit( @@ -50,7 +50,7 @@ class WriteInsertSuite extends AnyFunSuite with BeforeAndAfterAll { SparkMock.writeEdge() val addresses: List[Address] = List(new Address("127.0.0.1", 9669)) - val graphProvider = new GraphProvider(addresses, 3000) + val graphProvider = new GraphProvider(addresses, 3000) graphProvider.switchSpace("root", "nebula", "test_write_string") val createIndexResult: ResultSet = graphProvider.submit( @@ -63,7 +63,7 @@ class WriteInsertSuite extends AnyFunSuite with BeforeAndAfterAll { graphProvider.submit("use test_write_string;") val resultSet: ResultSet = - graphProvider.submit("match (v:person_connector)-[e:friend_connector] -> () return e;") + graphProvider.submit("match (v:person_connector)-[e:friend_connector]-> () return e;") assert(resultSet.isSucceeded) assert(resultSet.getColumnNames.size() == 1) assert(resultSet.getRows.size() == 13)
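The profile-based build that this patch series introduces means each connector module is compiled against its own Spark and Scala versions. As a local usage sketch, mirroring the exact commands these patches wire into the pull_request, snapshot, and release workflows, the three variants can be packaged with:

    # Spark 2.2 connector (Scala 2.11)
    mvn clean package -pl nebula-spark-connector_2.2 -am -Pscala-2.11 -Pspark-2.2

    # Spark 2.4 connector (Scala 2.11)
    mvn clean package -pl nebula-spark-connector -am -Pscala-2.11 -Pspark-2.4

    # Spark 3.0 connector (Scala 2.12)
    mvn clean package -pl nebula-spark-connector_3.0 -am -Pscala-2.12 -Pspark-3.0

Selecting the matching -Pscala-* and -Pspark-* profile pair is what keeps the spark.version and scala.binary.version properties consistent across nebula-spark-common and the connector module being built.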