support custom partitioner for nebula when generating sst files (#49)
* support custom partitioner for nebula when generating sst files

* support custom partitioner for nebula when generating sst files

* exclude jackson-core

* add test
Nicole00 committed Jan 5, 2022
1 parent 70bcbf0 commit 78fb290
Showing 24 changed files with 281 additions and 88 deletions.
@@ -395,7 +395,8 @@ object Configs {
val localPath = getOptOrElse(tagConfig, "local.path")
val remotePath = getOptOrElse(tagConfig, "remote.path")

val partition = getOrElse(tagConfig, "partition", DEFAULT_PARTITION)
val partition = getOrElse(tagConfig, "partition", DEFAULT_PARTITION)
val repartitionWithNebula = getOrElse(tagConfig, "repartitionWithNebula", false)

LOG.info(s"name ${tagName} batch ${batch}")
val entry = TagConfigEntry(tagName,
@@ -407,7 +408,8 @@ object Configs {
policyOpt,
batch,
partition,
checkPointPath)
checkPointPath,
repartitionWithNebula)
LOG.info(s"Tag Config: ${entry}")
tags += entry
}
@@ -510,6 +512,8 @@ object Configs {
val localPath = getOptOrElse(edgeConfig, "path.local")
val remotePath = getOptOrElse(edgeConfig, "path.remote")

val repartitionWithNebula = getOrElse(edgeConfig, "repartitionWithNebula", false)

val entry = EdgeConfigEntry(
edgeName,
sourceConfig,
@@ -526,7 +530,8 @@ object Configs {
longitude,
batch,
partition,
checkPointPath
checkPointPath,
repartitionWithNebula
)
LOG.info(s"Edge Config: ${entry}")
edges += entry
@@ -59,7 +59,8 @@ case class TagConfigEntry(override val name: String,
vertexPolicy: Option[KeyPolicy.Value],
override val batch: Int,
override val partition: Int,
override val checkPointPath: Option[String])
override val checkPointPath: Option[String],
repartitionWithNebula: Boolean = false)
extends SchemaConfigEntry {
require(name.trim.nonEmpty && vertexField.trim.nonEmpty && batch > 0)

@@ -108,7 +109,8 @@ case class EdgeConfigEntry(override val name: String,
longitude: Option[String],
override val batch: Int,
override val partition: Int,
override val checkPointPath: Option[String])
override val checkPointPath: Option[String],
repartitionWithNebula: Boolean = false)
extends SchemaConfigEntry {
require(
name.trim.nonEmpty && sourceField.trim.nonEmpty &&
@@ -32,8 +32,12 @@ case class FileBaseSinkConfigEntry(override val category: SinkCategory.Value,
remotePath: String,
fsName: Option[String])
extends DataSinkConfigEntry {

override def toString: String = {
s"File sink: from ${localPath} to ${fsName.get}${remotePath}"
val fullRemotePath =
if (fsName.isDefined) s"${fsName.get}$remotePath"
else remotePath
s"File sink: from ${localPath} to $fullRemotePath"
}
}
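
The toString fix matters when no HDFS namenode is configured: fsName.get used to throw on None. A minimal sketch of the two branches, using the constructor order visible in FileBaseWriterSuite further down; the values themselves are made up for illustration:

import com.vesoft.exchange.common.config.{FileBaseSinkConfigEntry, SinkCategory}

// Hypothetical values, only to show both branches of the new fullRemotePath logic.
val withFs    = FileBaseSinkConfigEntry(SinkCategory.SST, "/tmp/sst", "/sst", Some("hdfs://nn1:9000"))
val withoutFs = FileBaseSinkConfigEntry(SinkCategory.SST, "/tmp/sst", "/sst", None)

println(withFs)    // File sink: from /tmp/sst to hdfs://nn1:9000/sst
println(withoutFs) // File sink: from /tmp/sst to /sst (previously fsName.get threw here)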

@@ -5,7 +5,8 @@

package com.vesoft.exchange.common.processor

import com.vesoft.exchange.common.utils.{HDFSUtils, NebulaUtils}
import com.vesoft.exchange.common.VidType
import com.vesoft.exchange.common.utils.{HDFSUtils, NebulaPartitioner, NebulaUtils}
import com.vesoft.exchange.common.utils.NebulaUtils.DEFAULT_EMPTY_VALUE
import com.vesoft.nebula.{
Coordinate,
@@ -21,7 +22,7 @@ import com.vesoft.nebula.{
Value
}
import org.apache.log4j.Logger
import org.apache.spark.sql.Row
import org.apache.spark.sql.{DataFrame, Dataset, Encoders, Row, SparkSession}

import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer
@@ -230,4 +231,19 @@ trait Processor extends Serializable {
else assert(assertion = false, context)
}

def customRepartition(spark: SparkSession,
data: Dataset[(Array[Byte], Array[Byte])],
partitionNum: Int): Dataset[(Array[Byte], Array[Byte])] = {
import spark.implicits._
data.rdd
.partitionBy(new NebulaPartitioner(partitionNum))
.map(kv => SSTData(kv._1, kv._2))
.toDF()
.map { row =>
(row.getAs[Array[Byte]](0), row.getAs[Array[Byte]](1))
}(Encoders.tuple(Encoders.BINARY, Encoders.BINARY))
}

}

case class SSTData(key: Array[Byte], value: Array[Byte])
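
customRepartition drops to the RDD API because Dataset has no key-based partitionBy; the pairs come back wrapped in SSTData so they can be lifted into a Dataset again, and the final map is a narrow transformation, so the Nebula-aligned partitioning survives. A sketch of the intended call site, not a standalone program: encodedKeyValues is a hypothetical stand-in for the encoded key/value Dataset a processor builds, and partitionNum is the Nebula space's partition count (see EdgeProcessor further down for the real call):

// Inside a class that mixes in Processor; `encodedKeyValues` is a hypothetical
// Dataset[(Array[Byte], Array[Byte])] produced by the encode step.
val aligned = customRepartition(spark, encodedKeyValues, partitionNum)

aligned
  .toDF("key", "value")
  .sortWithinPartitions("key")                   // RocksDB's SstFileWriter needs ascending keys
  .foreachPartition { iterator: Iterator[Row] =>
    // each Spark partition now holds the keys of exactly one Nebula part,
    // so the SST file written from it never overlaps another part's files
    ()
  }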
@@ -0,0 +1,26 @@
/* Copyright (c) 2021 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License.
*/

package com.vesoft.exchange.common.utils

import java.nio.{ByteBuffer, ByteOrder}
import org.apache.spark.Partitioner

class NebulaPartitioner(partitions: Int) extends Partitioner {
require(partitions >= 0, s"Number of partitions ($partitions) cannot be negative.")

override def numPartitions: Int = partitions

override def getPartition(key: Any): Int = {
var part = ByteBuffer
.wrap(key.asInstanceOf[Array[Byte]], 0, 4)
.order(ByteOrder.nativeOrder)
.getInt >> 8
if (part <= 0) {
part = part + partitions
}
part - 1
}
}
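
The partitioner relies on the layout NebulaGraph's codec uses for the first four bytes of a vertex/edge key: partId << 8 | keyType, written in native byte order, so reading the little-endian int and shifting right by 8 recovers the part id. Nebula parts are numbered from 1, hence the final part - 1 to get a 0-based Spark partition index, and the part <= 0 branch compensates when the signed shift goes non-positive. A small self-contained check under that assumption (the 0x01 key type and the prefix layout come from the Nebula codec, not from this diff):

import java.nio.{ByteBuffer, ByteOrder}
import com.vesoft.exchange.common.utils.NebulaPartitioner

// Build the 4-byte key prefix as (partId << 8) | keyType in native byte order,
// then check that the partitioner maps it to the 0-based Spark partition partId - 1.
val partId  = 7
val keyType = 0x01
val prefix: Array[Byte] = ByteBuffer
  .allocate(4)
  .order(ByteOrder.nativeOrder)
  .putInt((partId << 8) | keyType)
  .array()

val partitioner = new NebulaPartitioner(100)
assert(partitioner.getPartition(prefix) == partId - 1) // Nebula part 7 -> Spark partition 6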
@@ -19,9 +19,8 @@ import org.slf4j.LoggerFactory
/**
* NebulaSSTWriter
*/
class NebulaSSTWriter extends Writer {
var isOpen = false
var path: String = _
class NebulaSSTWriter(path: String) extends Writer {
var isOpen = false

private val LOG = LoggerFactory.getLogger(getClass)

@@ -40,11 +39,6 @@ class NebulaSSTWriter extends Writer {
val env = new EnvOptions()
var writer: SstFileWriter = _

def withPath(path: String): NebulaSSTWriter = {
this.path = path
this
}

override def prepare(): Unit = {
writer = new SstFileWriter(env, options)
writer.open(path)
@@ -64,6 +58,11 @@ class NebulaSSTWriter extends Writer {
env.close()
}

}

class GenerateSstFile extends Serializable {
private val LOG = LoggerFactory.getLogger(getClass)

def writeSstFiles(iterator: Iterator[Row],
fileBaseConfig: FileBaseSinkConfigEntry,
partitionNum: Int,
@@ -97,7 +96,8 @@ class NebulaSSTWriter extends Writer {
}
currentPart = part
val tmp = s"$localPath/$currentPart-$taskID.sst"
withPath(tmp).prepare()
writer = new NebulaSSTWriter(tmp)
writer.prepare()
}
writer.write(key, value)
}
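
With the path moved into the constructor, the builder-style withPath is gone and the per-partition loop lives in GenerateSstFile, leaving NebulaSSTWriter a thin wrapper around RocksDB's SstFileWriter. A minimal standalone sketch, assuming write and close keep the signatures implied by the loop above:

// Hypothetical standalone use; in Exchange the writer is driven per Spark
// partition by GenerateSstFile.writeSstFiles.
val sstWriter = new NebulaSSTWriter("/tmp/1-0.sst")
sstWriter.prepare()                              // opens a RocksDB SstFileWriter on the given path

val key: Array[Byte]   = Array(1, 2, 3).map(_.toByte)
val value: Array[Byte] = Array(4, 5, 6).map(_.toByte)
sstWriter.write(key, value)                      // keys must be written in ascending order
sstWriter.close()                                // flushes the .sst file and releases RocksDB handles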
@@ -5,7 +5,10 @@

package scala.com.vesoft.nebula.exchange.processor

import java.util

import com.vesoft.exchange.common.processor.Processor
import com.vesoft.exchange.common.utils.NebulaPartitioner
import com.vesoft.nebula.{
Coordinate,
Date,
@@ -19,6 +22,8 @@ import com.vesoft.nebula.{
Time,
Value
}
import org.apache.spark.TaskContext
import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types.{
BooleanType,
@@ -217,4 +222,27 @@ class ProcessorSuite extends Processor {
printChoice(true, "nothing")
assertThrows[AssertionError](printChoice(false, "assert failed"))
}

@Test
def customRepartitionSuite(): Unit = {
val spark = SparkSession.builder().master("local").getOrCreate()
import spark.implicits._
val key1 = "01d80100546f6d000000000000000000000000000000000002000000"
val key2 = "017b000030313233343536373839e4070000"
val value = "abc"
val data: Dataset[(Array[Byte], Array[Byte])] = spark.sparkContext
.parallelize(List(key1.getBytes(), key2.getBytes()))
.map(line => (line, value.getBytes()))
.toDF("key", "value")
.map { row =>
(row.getAs[Array[Byte]](0), row.getAs[Array[Byte]](1))
}(Encoders.tuple(Encoders.BINARY, Encoders.BINARY))

val result = customRepartition(spark, data, 100)
val partitioner = new NebulaPartitioner(100)
result.map { row =>
assert(partitioner.getPartition(row._1) == TaskContext.getPartitionId())
""
}.collect() // run the job so the partition assertions are actually evaluated
}
}
@@ -0,0 +1,50 @@
/* Copyright (c) 2021 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License.
*/

package com.vesoft.exchange.common.writer

import com.vesoft.exchange.common.config.{FileBaseSinkConfigEntry, SinkCategory}
import org.apache.spark.sql.{Dataset, Encoders, Row, SparkSession}
import org.junit.Test

class FileBaseWriterSuite {

@Test
def writeSstFilesSuite(): Unit = {
val spark = SparkSession.builder().master("local").getOrCreate()
import spark.implicits._
// generate byte[] key using encoder's getVertexKey, space:"test", tag: "person"
val key1 = "01a40200310000000000000000000000000000000000000002000000" // id: "1"
val key2 = "01170000320000000000000000000000000000000000000002000000" // id: "2"
val key3 = "01fe0000330000000000000000000000000000000000000002000000" // id: "3"
val key4 = "01a90300340000000000000000000000000000000000000002000000" // id: "4"
val key5 = "01220200350000000000000000000000000000000000000002000000" // id: "5"
val value = "abc"
// construct test dataset
val data: Dataset[(Array[Byte], Array[Byte])] = spark.sparkContext
.parallelize(
List(key1.getBytes(), key2.getBytes(), key3.getBytes(), key4.getBytes(), key5.getBytes()))
.map(line => (line, value.getBytes()))
.toDF("key", "value")
.map { row =>
(row.getAs[Array[Byte]](0), row.getAs[Array[Byte]](1))
}(Encoders.tuple(Encoders.BINARY, Encoders.BINARY))

val generateSstFile = new GenerateSstFile

val fileBaseConfig =
FileBaseSinkConfigEntry(SinkCategory.SST, "/tmp", "/tmp/remote", None)
val batchFailure = spark.sparkContext.longAccumulator(s"batchFailure.test")

data
.toDF("key", "value")
.sortWithinPartitions("key")
.foreachPartition { iterator: Iterator[Row] =>
generateSstFile.writeSstFiles(iterator, fileBaseConfig, 10, null, batchFailure)
}
assert(batchFailure.value == 0)
}

}
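
For reference, the loop in GenerateSstFile names each output file $localPath/<part>-<taskID>.sst, so after this test /tmp should contain one .sst file per (Nebula part, Spark task) pair seen in the data. A quick, hypothetical way to inspect that locally (the exact part numbers depend on the key bytes above):

import java.io.File

// List the SST files the test leaves behind in the local sink path.
new File("/tmp")
  .listFiles()
  .filter(_.getName.endsWith(".sst"))
  .foreach(f => println(f.getName))   // e.g. "<part>-<taskID>.sst"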
@@ -140,13 +140,15 @@ object Exchange {
spark.sparkContext.longAccumulator(s"batchFailure.${tagConfig.name}")

val processor = new VerticesProcessor(
spark,
repartition(data.get, tagConfig.partition, tagConfig.dataSourceConfigEntry.category),
tagConfig,
fieldKeys,
nebulaKeys,
configs,
batchSuccess,
batchFailure)
batchFailure
)
processor.process()
val costTime = ((System.currentTimeMillis() - startTime) / 1000.0).formatted("%.2f")
LOG.info(s"import for tag ${tagConfig.name} cost time: ${costTime} s")
@@ -185,6 +187,7 @@ object Exchange {
val batchFailure = spark.sparkContext.longAccumulator(s"batchFailure.${edgeConfig.name}")

val processor = new EdgeProcessor(
spark,
repartition(data.get, edgeConfig.partition, edgeConfig.dataSourceConfigEntry.category),
edgeConfig,
fieldKeys,
@@ -19,21 +19,22 @@ import com.vesoft.exchange.common.config.{
import com.vesoft.exchange.common.processor.Processor
import com.vesoft.exchange.common.utils.NebulaUtils
import com.vesoft.exchange.common.utils.NebulaUtils.DEFAULT_EMPTY_VALUE
import com.vesoft.exchange.common.writer.{NebulaGraphClientWriter, NebulaSSTWriter}
import com.vesoft.exchange.common.writer.{GenerateSstFile, NebulaGraphClientWriter, NebulaSSTWriter}
import com.vesoft.exchange.common.VidType
import com.vesoft.nebula.encoder.NebulaCodecImpl
import com.vesoft.nebula.meta.EdgeItem
import org.apache.commons.codec.digest.MurmurHash2
import org.apache.log4j.Logger
import org.apache.spark.TaskContext
import org.apache.spark.sql.streaming.Trigger
import org.apache.spark.sql.{DataFrame, Encoders, Row}
import org.apache.spark.sql.{DataFrame, Encoders, Row, SparkSession}
import org.apache.spark.util.LongAccumulator

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

class EdgeProcessor(data: DataFrame,
class EdgeProcessor(spark: SparkSession,
data: DataFrame,
edgeConfig: EdgeConfigEntry,
fieldKeys: List[String],
nebulaKeys: List[String],
@@ -113,7 +114,7 @@ class EdgeProcessor(data: DataFrame,
} else {
data.dropDuplicates(edgeConfig.sourceField, edgeConfig.targetField)
}
distintData
var sstKeyValueData = distintData
.mapPartitions { iter =>
iter.map { row =>
encodeEdge(row, partitionNum, vidType, spaceVidLen, edgeItem, fieldTypeMap)
@@ -122,15 +123,22 @@ class EdgeProcessor(data: DataFrame,
.flatMap(line => {
List((line._1, line._3), (line._2, line._3))
})(Encoders.tuple(Encoders.BINARY, Encoders.BINARY))

// repartition the dataframe according to the nebula part, to make sure the sst files for one part have no overlap
if (edgeConfig.repartitionWithNebula) {
sstKeyValueData = customRepartition(spark, sstKeyValueData, partitionNum)
}

sstKeyValueData
.toDF("key", "value")
.sortWithinPartitions("key")
.foreachPartition { iterator: Iterator[Row] =>
val sstFileWriter = new NebulaSSTWriter
sstFileWriter.writeSstFiles(iterator,
fileBaseConfig,
partitionNum,
namenode,
batchFailure)
val generateSstFile = new GenerateSstFile
generateSstFile.writeSstFiles(iterator,
fileBaseConfig,
partitionNum,
namenode,
batchFailure)
}
} else {
val streamFlag = data.isStreaming
