Support custom partitioner for Nebula when generating SST files #49

Merged (4 commits) on Jan 5, 2022
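The new behaviour is opt-in per tag/edge. A minimal sketch of how the flag is resolved, mirroring the Configs hunks below; the key name `repartitionWithNebula` comes from this diff, while the helper below and the use of Typesafe Config objects are assumptions for illustration:

```scala
import com.typesafe.config.Config

// Illustrative only: the flag defaults to false when "repartitionWithNebula"
// is absent, and the resolved value is carried on TagConfigEntry /
// EdgeConfigEntry into the processors that generate SST files.
def readRepartitionWithNebula(tagOrEdgeConfig: Config): Boolean =
  if (tagOrEdgeConfig.hasPath("repartitionWithNebula"))
    tagOrEdgeConfig.getBoolean("repartitionWithNebula")
  else
    false
```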
@@ -395,7 +395,8 @@ object Configs {
val localPath = getOptOrElse(tagConfig, "local.path")
val remotePath = getOptOrElse(tagConfig, "remote.path")

val partition = getOrElse(tagConfig, "partition", DEFAULT_PARTITION)
val repartitionWithNebula = getOrElse(tagConfig, "repartitionWithNebula", false)

LOG.info(s"name ${tagName} batch ${batch}")
val entry = TagConfigEntry(tagName,
@@ -407,7 +408,8 @@ object Configs {
policyOpt,
batch,
partition,
checkPointPath)
checkPointPath,
repartitionWithNebula)
LOG.info(s"Tag Config: ${entry}")
tags += entry
}
@@ -510,6 +512,8 @@ object Configs {
val localPath = getOptOrElse(edgeConfig, "path.local")
val remotePath = getOptOrElse(edgeConfig, "path.remote")

val repartitionWithNebula = getOrElse(edgeConfig, "repartitionWithNebula", false)

val entry = EdgeConfigEntry(
edgeName,
sourceConfig,
@@ -526,7 +530,8 @@ object Configs {
longitude,
batch,
partition,
checkPointPath
checkPointPath,
repartitionWithNebula
)
LOG.info(s"Edge Config: ${entry}")
edges += entry
@@ -59,7 +59,8 @@ case class TagConfigEntry(override val name: String,
vertexPolicy: Option[KeyPolicy.Value],
override val batch: Int,
override val partition: Int,
override val checkPointPath: Option[String])
override val checkPointPath: Option[String],
repartitionWithNebula: Boolean = false)
extends SchemaConfigEntry {
require(name.trim.nonEmpty && vertexField.trim.nonEmpty && batch > 0)

@@ -108,7 +109,8 @@ case class EdgeConfigEntry(override val name: String,
longitude: Option[String],
override val batch: Int,
override val partition: Int,
override val checkPointPath: Option[String])
override val checkPointPath: Option[String],
repartitionWithNebula: Boolean = false)
extends SchemaConfigEntry {
require(
name.trim.nonEmpty && sourceField.trim.nonEmpty &&
@@ -32,8 +32,12 @@ case class FileBaseSinkConfigEntry(override val category: SinkCategory.Value,
remotePath: String,
fsName: Option[String])
extends DataSinkConfigEntry {

override def toString: String = {
s"File sink: from ${localPath} to ${fsName.get}${remotePath}"
val fullRemotePath =
if (fsName.isDefined) s"${fsName.get}$remotePath"
else remotePath
s"File sink: from ${localPath} to $fullRemotePath"
}
}

@@ -5,7 +5,8 @@

package com.vesoft.exchange.common.processor

import com.vesoft.exchange.common.utils.{HDFSUtils, NebulaUtils}
import com.vesoft.exchange.common.VidType
import com.vesoft.exchange.common.utils.{HDFSUtils, NebulaPartitioner, NebulaUtils}
import com.vesoft.exchange.common.utils.NebulaUtils.DEFAULT_EMPTY_VALUE
import com.vesoft.nebula.{
Coordinate,
@@ -21,7 +22,7 @@ import com.vesoft.nebula.{
Value
}
import org.apache.log4j.Logger
import org.apache.spark.sql.Row
import org.apache.spark.sql.{DataFrame, Dataset, Encoders, Row, SparkSession}

import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer
@@ -230,4 +231,19 @@ trait Processor extends Serializable {
else assert(assertion = false, context)
}

def customRepartition(spark: SparkSession,
data: Dataset[(Array[Byte], Array[Byte])],
partitionNum: Int): Dataset[(Array[Byte], Array[Byte])] = {
import spark.implicits._
data.rdd
Contributor commented on this line:

    Why not use repartition directly?

@Nicole00 (Contributor Author) replied on Jan 18, 2022:

    DataFrame doesn't have a custom repartition function; that is an RDD API.

(A sketch illustrating this follows at the end of this file's diff.)

.partitionBy(new NebulaPartitioner(partitionNum))
.map(kv => SSTData(kv._1, kv._2))
.toDF()
.map { row =>
(row.getAs[Array[Byte]](0), row.getAs[Array[Byte]](1))
}(Encoders.tuple(Encoders.BINARY, Encoders.BINARY))
}

}

case class SSTData(key: Array[Byte], value: Array[Byte])
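To expand on the reply in the review thread above: `Dataset.repartition` only takes a target partition count or columns (round-robin, hash or range partitioning), while a custom `org.apache.spark.Partitioner` such as `NebulaPartitioner` can only be applied through the pair-RDD `partitionBy` API, which is why `customRepartition` drops down to `data.rdd`. A minimal sketch of the two calls; the method names here are illustrative:

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Dataset
import com.vesoft.exchange.common.utils.NebulaPartitioner

// Dataset API: no overload accepts a custom Partitioner.
def roundRobin(kv: Dataset[(Array[Byte], Array[Byte])]): Dataset[(Array[Byte], Array[Byte])] =
  kv.repartition(100) // rows are spread without looking at their Nebula part id

// RDD API: partitionBy routes every (key, value) pair to the Spark partition
// chosen by the Partitioner, i.e. the partition matching its Nebula part id.
def byNebulaPart(kv: RDD[(Array[Byte], Array[Byte])]): RDD[(Array[Byte], Array[Byte])] =
  kv.partitionBy(new NebulaPartitioner(100))
```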
@@ -0,0 +1,26 @@
/* Copyright (c) 2021 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License.
*/

package com.vesoft.exchange.common.utils

import java.nio.{ByteBuffer, ByteOrder}
import org.apache.spark.Partitioner

class NebulaPartitioner(partitions: Int) extends Partitioner {
require(partitions >= 0, s"Number of partitions ($partitions) cannot be negative.")

override def numPartitions: Int = partitions

override def getPartition(key: Any): Int = {
var part = ByteBuffer
.wrap(key.asInstanceOf[Array[Byte]], 0, 4)
.order(ByteOrder.nativeOrder)
.getInt >> 8
if (part <= 0) {
part = part + partitions
}
part - 1
}
}
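A quick worked example of the mapping above, assuming the key-prefix layout that the shift implies (the first four bytes hold `partId << 8 | keyType`); the concrete values are illustrative:

```scala
import java.nio.{ByteBuffer, ByteOrder}
import com.vesoft.exchange.common.utils.NebulaPartitioner

val partitioner = new NebulaPartitioner(10)

// Build a 4-byte key prefix for Nebula part 3 with key type 0x01, using the
// same byte order the partitioner reads with.
val prefix = ByteBuffer
  .allocate(4)
  .order(ByteOrder.nativeOrder)
  .putInt((3 << 8) | 0x01)
  .array()

// Nebula parts are 1-based, Spark partitions 0-based: part 3 maps to index 2.
assert(partitioner.getPartition(prefix) == 2)
```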
@@ -19,9 +19,8 @@ import org.slf4j.LoggerFactory
/**
* NebulaSSTWriter
*/
class NebulaSSTWriter extends Writer {
var isOpen = false
var path: String = _
class NebulaSSTWriter(path: String) extends Writer {
var isOpen = false

private val LOG = LoggerFactory.getLogger(getClass)

@@ -40,11 +39,6 @@ class NebulaSSTWriter extends Writer {
val env = new EnvOptions()
var writer: SstFileWriter = _

def withPath(path: String): NebulaSSTWriter = {
this.path = path
this
}

override def prepare(): Unit = {
writer = new SstFileWriter(env, options)
writer.open(path)
@@ -64,6 +58,11 @@ class NebulaSSTWriter extends Writer {
env.close()
}

}

class GenerateSstFile extends Serializable {
private val LOG = LoggerFactory.getLogger(getClass)

def writeSstFiles(iterator: Iterator[Row],
fileBaseConfig: FileBaseSinkConfigEntry,
partitionNum: Int,
@@ -97,7 +96,8 @@ class NebulaSSTWriter extends Writer {
}
currentPart = part
val tmp = s"$localPath/$currentPart-$taskID.sst"
withPath(tmp).prepare()
writer = new NebulaSSTWriter(tmp)
writer.prepare()
}
writer.write(key, value)
}
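With this refactor the SST file path becomes a constructor argument (the old `withPath` builder is gone) and the per-partition write loop moves into the new `GenerateSstFile` class. A minimal usage sketch of the refactored writer, following how `GenerateSstFile` calls it above; the path and bytes are placeholders, and keys must be written in ascending order, which is why callers sort with `sortWithinPartitions("key")` first:

```scala
import com.vesoft.exchange.common.writer.NebulaSSTWriter

// Placeholder key/value bytes; real keys come from NebulaCodecImpl.
val key: Array[Byte]   = "key".getBytes()
val value: Array[Byte] = "abc".getBytes()

val writer = new NebulaSSTWriter("/tmp/1-0.sst") // path is now fixed at construction
writer.prepare()           // opens the underlying RocksDB SstFileWriter on that path
writer.write(key, value)   // append entries in sorted key order
writer.close()             // finish the SST file and release the RocksDB handles
```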
@@ -5,7 +5,10 @@

package scala.com.vesoft.nebula.exchange.processor

import java.util

import com.vesoft.exchange.common.processor.Processor
import com.vesoft.exchange.common.utils.NebulaPartitioner
import com.vesoft.nebula.{
Coordinate,
Date,
@@ -19,6 +22,8 @@ import com.vesoft.nebula.{
Time,
Value
}
import org.apache.spark.TaskContext
import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types.{
BooleanType,
@@ -217,4 +222,27 @@ class ProcessorSuite extends Processor {
printChoice(true, "nothing")
assertThrows[AssertionError](printChoice(false, "assert failed"))
}

@Test
def customRepartitionSuite(): Unit = {
val spark = SparkSession.builder().master("local").getOrCreate()
import spark.implicits._
val key1 = "01d80100546f6d000000000000000000000000000000000002000000"
val key2 = "017b000030313233343536373839e4070000"
val value = "abc"
val data: Dataset[(Array[Byte], Array[Byte])] = spark.sparkContext
.parallelize(List(key1.getBytes(), key2.getBytes()))
.map(line => (line, value.getBytes()))
.toDF("key", "value")
.map { row =>
(row.getAs[Array[Byte]](0), row.getAs[Array[Byte]](1))
}(Encoders.tuple(Encoders.BINARY, Encoders.BINARY))

val result = customRepartition(spark, data, 100)
val partitioner = new NebulaPartitioner(100)
result.map { row =>
assert(partitioner.getPartition(row._1) == TaskContext.getPartitionId())
""
}
}
}
@@ -0,0 +1,50 @@
/* Copyright (c) 2021 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License.
*/

package com.vesoft.exchange.common.writer

import com.vesoft.exchange.common.config.{FileBaseSinkConfigEntry, SinkCategory}
import org.apache.spark.sql.{Dataset, Encoders, Row, SparkSession}
import org.junit.Test

class FileBaseWriterSuite {

@Test
def writeSstFilesSuite(): Unit = {
val spark = SparkSession.builder().master("local").getOrCreate()
import spark.implicits._
// generate byte[] key using encoder's getVertexKey, space:"test", tag: "person"
val key1 = "01a40200310000000000000000000000000000000000000002000000" // id: "1"
val key2 = "01170000320000000000000000000000000000000000000002000000" // id: "2"
val key3 = "01fe0000330000000000000000000000000000000000000002000000" // id: "3"
val key4 = "01a90300340000000000000000000000000000000000000002000000" // id: "4"
val key5 = "01220200350000000000000000000000000000000000000002000000" // id: "5"
val value = "abc"
// construct test dataset
val data: Dataset[(Array[Byte], Array[Byte])] = spark.sparkContext
.parallelize(
List(key1.getBytes(), key2.getBytes(), key3.getBytes(), key4.getBytes(), key5.getBytes()))
.map(line => (line, value.getBytes()))
.toDF("key", "value")
.map { row =>
(row.getAs[Array[Byte]](0), row.getAs[Array[Byte]](1))
}(Encoders.tuple(Encoders.BINARY, Encoders.BINARY))

val generateSstFile = new GenerateSstFile

val fileBaseConfig =
FileBaseSinkConfigEntry(SinkCategory.SST, "/tmp", "/tmp/remote", None)
val batchFailure = spark.sparkContext.longAccumulator(s"batchFailure.test")

data
.toDF("key", "value")
.sortWithinPartitions("key")
.foreachPartition { iterator: Iterator[Row] =>
generateSstFile.writeSstFiles(iterator, fileBaseConfig, 10, null, batchFailure)
}
assert(batchFailure.value == 0)
}

}
@@ -140,13 +140,15 @@ object Exchange {
spark.sparkContext.longAccumulator(s"batchFailure.${tagConfig.name}")

val processor = new VerticesProcessor(
spark,
repartition(data.get, tagConfig.partition, tagConfig.dataSourceConfigEntry.category),
tagConfig,
fieldKeys,
nebulaKeys,
configs,
batchSuccess,
batchFailure)
batchFailure
)
processor.process()
val costTime = ((System.currentTimeMillis() - startTime) / 1000.0).formatted("%.2f")
LOG.info(s"import for tag ${tagConfig.name} cost time: ${costTime} s")
@@ -185,6 +187,7 @@ object Exchange {
val batchFailure = spark.sparkContext.longAccumulator(s"batchFailure.${edgeConfig.name}")

val processor = new EdgeProcessor(
spark,
repartition(data.get, edgeConfig.partition, edgeConfig.dataSourceConfigEntry.category),
edgeConfig,
fieldKeys,
@@ -19,21 +19,22 @@ import com.vesoft.exchange.common.config.{
import com.vesoft.exchange.common.processor.Processor
import com.vesoft.exchange.common.utils.NebulaUtils
import com.vesoft.exchange.common.utils.NebulaUtils.DEFAULT_EMPTY_VALUE
import com.vesoft.exchange.common.writer.{NebulaGraphClientWriter, NebulaSSTWriter}
import com.vesoft.exchange.common.writer.{GenerateSstFile, NebulaGraphClientWriter, NebulaSSTWriter}
import com.vesoft.exchange.common.VidType
import com.vesoft.nebula.encoder.NebulaCodecImpl
import com.vesoft.nebula.meta.EdgeItem
import org.apache.commons.codec.digest.MurmurHash2
import org.apache.log4j.Logger
import org.apache.spark.TaskContext
import org.apache.spark.sql.streaming.Trigger
import org.apache.spark.sql.{DataFrame, Encoders, Row}
import org.apache.spark.sql.{DataFrame, Encoders, Row, SparkSession}
import org.apache.spark.util.LongAccumulator

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

class EdgeProcessor(data: DataFrame,
class EdgeProcessor(spark: SparkSession,
data: DataFrame,
edgeConfig: EdgeConfigEntry,
fieldKeys: List[String],
nebulaKeys: List[String],
@@ -113,7 +114,7 @@ class EdgeProcessor(data: DataFrame,
} else {
data.dropDuplicates(edgeConfig.sourceField, edgeConfig.targetField)
}
distintData
var sstKeyValueData = distintData
.mapPartitions { iter =>
iter.map { row =>
encodeEdge(row, partitionNum, vidType, spaceVidLen, edgeItem, fieldTypeMap)
@@ -122,15 +123,22 @@
.flatMap(line => {
List((line._1, line._3), (line._2, line._3))
})(Encoders.tuple(Encoders.BINARY, Encoders.BINARY))

// repartition the dataframe according to the Nebula part id, to make sure the SST files for one part have no overlap
if (edgeConfig.repartitionWithNebula) {
sstKeyValueData = customRepartition(spark, sstKeyValueData, partitionNum)
}

sstKeyValueData
.toDF("key", "value")
.sortWithinPartitions("key")
.foreachPartition { iterator: Iterator[Row] =>
val sstFileWriter = new NebulaSSTWriter
sstFileWriter.writeSstFiles(iterator,
fileBaseConfig,
partitionNum,
namenode,
batchFailure)
val generateSstFile = new GenerateSstFile
generateSstFile.writeSstFiles(iterator,
fileBaseConfig,
partitionNum,
namenode,
batchFailure)
}
} else {
val streamFlag = data.isStreaming