diff --git a/README-CN.md b/README-CN.md index b83a3ac..f8e3768 100644 --- a/README-CN.md +++ b/README-CN.md @@ -51,9 +51,6 @@ nebula-algorithm 是一款基于 [GraphX](https://spark.apache.org/graphx/) 的 ``` ${SPARK_HOME}/bin/spark-submit --master --class com.vesoft.nebula.algorithm.Main nebula-algorithm-3.0-SNAPSHOT.jar -p application.conf ``` - * 使用限制 - - Nebula Algorithm 算法包未自动对字符串 id 进行编码,因此采用第一种方式执行图算法时,边的源点和目标点必须是整数(Nebula Space 的 vid_type 可以是 String 类型,但数据必须是整数)。 * 使用方法2:调用 nebula-algorithm 算法接口 在 `nebula-algorithm` 的 `lib` 库中提供了10+种常用图计算算法,可通过编程调用的形式调用算法。 @@ -75,7 +72,8 @@ nebula-algorithm 是一款基于 [GraphX](https://spark.apache.org/graphx/) 的 val prConfig = new PRConfig(5, 1.0) val prResult = PageRankAlgo.apply(spark, data, prConfig, false) ``` - * 如果你的节点 id 是 String 类型,可以参考 PageRank 的 [Example](https://github.com/vesoft-inc/nebula-algorithm/blob/master/example/src/main/scala/com/vesoft/nebula/algorithm/PageRankExample.scala) 。 + * 如果你的节点 id 是 String 类型,可以参考 [examples](https://github.com/vesoft-inc/nebula-algorithm/tree/master/example/src/main/scala/com/vesoft/nebula/algorithm/DegreeStaticExample.scala). + 该 Example 进行了 id 转换,将 String 类型 id 编码为 Long 类型的 id , 并在算法结果中将 Long 类型 id 解码为原始的 String 类型 id 。 其他算法的调用方法见[测试示例](https://github.com/vesoft-inc/nebula-algorithm/tree/master/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib) 。 diff --git a/README.md b/README.md index ba4370e..a9f8795 100644 --- a/README.md +++ b/README.md @@ -60,10 +60,6 @@ You could submit the entire spark application or invoke algorithms in `lib` libr ${SPARK_HOME}/bin/spark-submit --master --class com.vesoft.nebula.algorithm.Main nebula-algorithm-3.0—SNAPSHOT.jar -p application.conf ``` - * Limitation - - Due to Nebula Algorithm jar does not encode string id, thus during the algorithm execution, the source and target of edges must be in Type Int (The `vid_type` in Nebula Space could be String, while data must be in Type Int). 
- * Option2: Call nebula-algorithm interface Now there are 10+ algorithms provided in `lib` from `nebula-algorithm`, which could be invoked in a programming fashion as below: @@ -87,7 +83,7 @@ You could submit the entire spark application or invoke algorithms in `lib` libr val prResult = PageRankAlgo.apply(spark, data, prConfig, false) ``` - If your vertex ids are Strings, see [Pagerank Example](https://github.com/vesoft-inc/nebula-algorithm/blob/master/example/src/main/scala/com/vesoft/nebula/algorithm/PageRankExample.scala) for how to encoding and decoding them. + If your vertex ids are Strings, please set the algo config with encodeId = true. see [examples](https://github.com/vesoft-inc/nebula-algorithm/tree/master/example/src/main/scala/com/vesoft/nebula/algorithm/DegreeStaticExample.scala) For examples of other algorithms, see [examples](https://github.com/vesoft-inc/nebula-algorithm/tree/master/example/src/main/scala/com/vesoft/nebula/algorithm) > Note: The first column of DataFrame in the application represents the source vertices, the second represents the target vertices and the third represents edges' weight. 
diff --git a/example/pom.xml b/example/pom.xml index beda679..f3956c6 100644 --- a/example/pom.xml +++ b/example/pom.xml @@ -182,7 +182,7 @@ com.vesoft nebula-algorithm - 3.0.0 + 3.0-SNAPSHOT diff --git a/example/src/main/scala/com/vesoft/nebula/algorithm/DegreeStaticExample.scala b/example/src/main/scala/com/vesoft/nebula/algorithm/DegreeStaticExample.scala index 6badb57..19d799d 100644 --- a/example/src/main/scala/com/vesoft/nebula/algorithm/DegreeStaticExample.scala +++ b/example/src/main/scala/com/vesoft/nebula/algorithm/DegreeStaticExample.scala @@ -6,7 +6,8 @@ package com.vesoft.nebula.algorithm import com.facebook.thrift.protocol.TCompactProtocol -import com.vesoft.nebula.algorithm.lib.{DegreeStaticAlgo} +import com.vesoft.nebula.algorithm.config.DegreeStaticConfig +import com.vesoft.nebula.algorithm.lib.DegreeStaticAlgo import org.apache.spark.SparkConf import org.apache.spark.sql.{DataFrame, SparkSession} @@ -22,15 +23,22 @@ object DegreeStaticExample { .config(sparkConf) .getOrCreate() - // val csvDF = ReadData.readCsvData(spark) // val nebulaDF = ReadData.readNebulaData(spark) val journalDF = ReadData.readLiveJournalData(spark) - degree(spark, journalDF) + + val csvDF = ReadData.readStringCsvData(spark) + degreeForStringId(spark, csvDF) } def degree(spark: SparkSession, df: DataFrame): Unit = { val degree = DegreeStaticAlgo.apply(spark, df) degree.show() } + + def degreeForStringId(spark: SparkSession, df: DataFrame): Unit = { + val degreeConfig = new DegreeStaticConfig(true) + val degree = DegreeStaticAlgo.apply(spark, df, degreeConfig) + degree.show() + } } diff --git a/nebula-algorithm/src/main/resources/application.conf b/nebula-algorithm/src/main/resources/application.conf index f10ad45..40ebe75 100644 --- a/nebula-algorithm/src/main/resources/application.conf +++ b/nebula-algorithm/src/main/resources/application.conf @@ -113,7 +113,9 @@ } # Vertex degree statistics parameter - degreestatic: {} + degreestatic: { + encodeId:false + } # KCore 
parameter kcore:{ @@ -123,7 +125,9 @@ } # Trianglecount parameter - trianglecount:{} + trianglecount:{ + encodeId:false + } # graphTriangleCount parameter graphtrianglecount:{} @@ -189,6 +193,7 @@ # JaccardAlgo parameter jaccard:{ tol: 1.0 + encodeId:false } } } diff --git a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/AlgoConfig.scala b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/AlgoConfig.scala index 9844e07..9811185 100644 --- a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/AlgoConfig.scala +++ b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/AlgoConfig.scala @@ -5,6 +5,7 @@ package com.vesoft.nebula.algorithm.config +import com.vesoft.nebula.algorithm.config.JaccardConfig.encodeId import org.apache.spark.graphx.VertexId case class PRConfig(maxIter: Int, resetProb: Double, encodeId: Boolean = false) @@ -110,24 +111,30 @@ object LouvainConfig { /** * degree static */ -case class DegreeStaticConfig(degree: Boolean, - inDegree: Boolean, - outDegree: Boolean, - encodeId: Boolean = false) +case class DegreeStaticConfig(encodeId: Boolean = false) object DegreeStaticConfig { - var degree: Boolean = false - var inDegree: Boolean = false - var outDegree: Boolean = false - var encodeId: Boolean = false + var encodeId: Boolean = false def getDegreeStaticConfig(configs: Configs): DegreeStaticConfig = { val degreeConfig = configs.algorithmConfig.map - degree = ConfigUtil.getOrElseBoolean(degreeConfig, "algorithm.degreestatic.degree", false) - inDegree = ConfigUtil.getOrElseBoolean(degreeConfig, "algorithm.degreestatic.indegree", false) - outDegree = ConfigUtil.getOrElseBoolean(degreeConfig, "algorithm.degreestatic.outdegree", false) encodeId = ConfigUtil.getOrElseBoolean(degreeConfig, "algorithm.degreestatic.encodeId", false) - DegreeStaticConfig(degree, inDegree, outDegree, encodeId) + DegreeStaticConfig(encodeId) + } +} + +/** + * graph triangle count + */ +case class 
TriangleConfig(encodeId: Boolean = false) + +object TriangleConfig { + var encodeId: Boolean = false + def getTriangleConfig(configs: Configs): TriangleConfig = { + val triangleConfig = configs.algorithmConfig.map + encodeId = + ConfigUtil.getOrElseBoolean(triangleConfig, "algorithm.trianglecount.encodeId", false) + TriangleConfig(encodeId) } } @@ -321,14 +328,16 @@ object Node2vecConfig { /** * Jaccard */ -case class JaccardConfig(tol: Double) +case class JaccardConfig(tol: Double, encodeId: Boolean = false) object JaccardConfig { - var tol: Double = _ + var tol: Double = _ + var encodeId: Boolean = false def getJaccardConfig(configs: Configs): JaccardConfig = { val jaccardConfig = configs.algorithmConfig.map tol = jaccardConfig("algorithm.jaccard.tol").toDouble - JaccardConfig(tol) + encodeId = ConfigUtil.getOrElseBoolean(jaccardConfig, "algorithm.jaccard.encodeId", false) + JaccardConfig(tol, encodeId) } } diff --git a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/lib/BfsAlgo.scala b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/lib/BfsAlgo.scala index ce639c5..8765f93 100644 --- a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/lib/BfsAlgo.scala +++ b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/lib/BfsAlgo.scala @@ -50,9 +50,9 @@ object BfsAlgo { .orderBy(col(AlgoConstants.BFS_RESULT_COL)) if (bfsConfig.encodeId) { - DecodeUtil.convertAlgoId2StringId(algoResult, encodeIdDf) + DecodeUtil.convertAlgoId2StringId(algoResult, encodeIdDf).coalesce(1) } else { - algoResult + algoResult.coalesce(1) } } diff --git a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/lib/DegreeStaticAlgo.scala b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/lib/DegreeStaticAlgo.scala index 3a63c71..7721ecc 100644 --- a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/lib/DegreeStaticAlgo.scala +++ b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/lib/DegreeStaticAlgo.scala @@ -5,8 +5,8 
@@ package com.vesoft.nebula.algorithm.lib -import com.vesoft.nebula.algorithm.config.AlgoConstants -import com.vesoft.nebula.algorithm.utils.NebulaUtil +import com.vesoft.nebula.algorithm.config.{AlgoConstants, DegreeStaticConfig} +import com.vesoft.nebula.algorithm.utils.{DecodeUtil, NebulaUtil} import org.apache.log4j.Logger import org.apache.spark.graphx.{Graph, VertexRDD} import org.apache.spark.rdd.RDD @@ -22,9 +22,18 @@ object DegreeStaticAlgo { /** * run the pagerank algorithm for nebula graph */ - def apply(spark: SparkSession, dataset: Dataset[Row]): DataFrame = { + def apply(spark: SparkSession, + dataset: Dataset[Row], + degreeConfig: DegreeStaticConfig = new DegreeStaticConfig): DataFrame = { + var encodeIdDf: DataFrame = null - val graph: Graph[None.type, Double] = NebulaUtil.loadInitGraph(dataset, false) + val graph: Graph[None.type, Double] = if (degreeConfig.encodeId) { + val (data, encodeId) = DecodeUtil.convertStringId2LongId(dataset, false) + encodeIdDf = encodeId + NebulaUtil.loadInitGraph(data, false) + } else { + NebulaUtil.loadInitGraph(dataset, false) + } val degreeResultRDD = execute(graph) @@ -38,7 +47,11 @@ object DegreeStaticAlgo { val algoResult = spark.sqlContext .createDataFrame(degreeResultRDD, schema) - algoResult + if (degreeConfig.encodeId) { + DecodeUtil.convertAlgoId2StringId(algoResult, encodeIdDf) + } else { + algoResult + } } def execute(graph: Graph[None.type, Double]): RDD[Row] = { diff --git a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/lib/JaccardAlgo.scala b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/lib/JaccardAlgo.scala index 01c1351..8339d70 100644 --- a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/lib/JaccardAlgo.scala +++ b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/lib/JaccardAlgo.scala @@ -6,7 +6,9 @@ package com.vesoft.nebula.algorithm.lib import com.vesoft.nebula.algorithm.config.JaccardConfig +import com.vesoft.nebula.algorithm.utils.{DecodeUtil, 
NebulaUtil} import org.apache.log4j.Logger +import org.apache.spark.graphx.Graph import org.apache.spark.ml.feature.{ CountVectorizer, CountVectorizerModel, @@ -29,7 +31,16 @@ object JaccardAlgo { */ def apply(spark: SparkSession, dataset: Dataset[Row], jaccardConfig: JaccardConfig): DataFrame = { - val jaccardResult: RDD[Row] = execute(spark, dataset, jaccardConfig.tol) + var encodeIdDf: DataFrame = null + var data: DataFrame = dataset + + if (jaccardConfig.encodeId) { + val (encodeData, encodeId) = DecodeUtil.convertStringId2LongId(dataset, false) + encodeIdDf = encodeId + data = encodeData + } + + val jaccardResult: RDD[Row] = execute(spark, data, jaccardConfig.tol) val schema = StructType( List( @@ -38,7 +49,13 @@ object JaccardAlgo { StructField("similarity", DoubleType, nullable = true) )) val algoResult = spark.sqlContext.createDataFrame(jaccardResult, schema) - algoResult + + if (jaccardConfig.encodeId) { + DecodeUtil.convertIds2String(algoResult, encodeIdDf, "srcId", "dstId") + } else { + algoResult + } + } def execute(spark: SparkSession, dataset: Dataset[Row], tol: Double): RDD[Row] = { diff --git a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/lib/TriangleCountAlgo.scala b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/lib/TriangleCountAlgo.scala index 47b9b8a..377c5a7 100644 --- a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/lib/TriangleCountAlgo.scala +++ b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/lib/TriangleCountAlgo.scala @@ -5,8 +5,8 @@ package com.vesoft.nebula.algorithm.lib -import com.vesoft.nebula.algorithm.config.AlgoConstants -import com.vesoft.nebula.algorithm.utils.NebulaUtil +import com.vesoft.nebula.algorithm.config.{AlgoConstants, TriangleConfig} +import com.vesoft.nebula.algorithm.utils.{DecodeUtil, NebulaUtil} import org.apache.log4j.Logger import org.apache.spark.graphx.{Graph, VertexRDD} import org.apache.spark.graphx.lib.TriangleCount @@ -24,9 +24,19 @@ object 
TriangleCountAlgo { * * compute each vertex's triangle count */ - def apply(spark: SparkSession, dataset: Dataset[Row]): DataFrame = { + def apply(spark: SparkSession, + dataset: Dataset[Row], + triangleConfig: TriangleConfig = new TriangleConfig): DataFrame = { - val graph: Graph[None.type, Double] = NebulaUtil.loadInitGraph(dataset, false) + var encodeIdDf: DataFrame = null + + val graph: Graph[None.type, Double] = if (triangleConfig.encodeId) { + val (data, encodeId) = DecodeUtil.convertStringId2LongId(dataset, false) + encodeIdDf = encodeId + NebulaUtil.loadInitGraph(data, false) + } else { + NebulaUtil.loadInitGraph(dataset, false) + } val triangleResultRDD = execute(graph) @@ -38,7 +48,11 @@ object TriangleCountAlgo { val algoResult = spark.sqlContext .createDataFrame(triangleResultRDD, schema) - algoResult + if (triangleConfig.encodeId) { + DecodeUtil.convertAlgoId2StringId(algoResult, encodeIdDf) + } else { + algoResult + } } def execute(graph: Graph[None.type, Double]): RDD[Row] = { diff --git a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/utils/DecodeUtil.scala b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/utils/DecodeUtil.scala index 9343a24..f540240 100644 --- a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/utils/DecodeUtil.scala +++ b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/utils/DecodeUtil.scala @@ -73,4 +73,21 @@ object DecodeUtil { .drop(algoProp) .withColumnRenamed(ORIGIN_ID_COL, algoProp) } + + def convertIds2String(dataframe: DataFrame, + encodeId: DataFrame, + srcCol: String, + dstCol: String): DataFrame = { + encodeId + .join(dataframe) + .where(col(ENCODE_ID_COL) === col(srcCol)) + .drop(ENCODE_ID_COL) + .drop(srcCol) + .withColumnRenamed(ORIGIN_ID_COL, srcCol) + .join(encodeId) + .where(col(dstCol) === col(ENCODE_ID_COL)) + .drop(ENCODE_ID_COL) + .drop(dstCol) + .withColumnRenamed(ORIGIN_ID_COL, dstCol) + } } diff --git 
a/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/DegreeStaticAlgoSuite.scala b/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/DegreeStaticAlgoSuite.scala index bcf5b8f..29cf6f3 100644 --- a/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/DegreeStaticAlgoSuite.scala +++ b/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/DegreeStaticAlgoSuite.scala @@ -5,13 +5,15 @@ package com.vesoft.nebula.algorithm.lib +import com.vesoft.nebula.algorithm.config.DegreeStaticConfig import org.apache.spark.sql.SparkSession import org.junit.Test class DegreeStaticAlgoSuite { @Test def degreeStaticAlgoSuite(): Unit = { - val spark = SparkSession.builder().master("local").getOrCreate() + val spark = + SparkSession.builder().master("local").config("spark.sql.shuffle.partitions", 5).getOrCreate() val data = spark.read.option("header", true).csv("src/test/resources/edge.csv") val result = DegreeStaticAlgo.apply(spark, data) assert(result.count() == 4) @@ -20,5 +22,9 @@ class DegreeStaticAlgoSuite { assert(row.get(2).toString.toInt == 4) assert(row.get(3).toString.toInt == 4) }) + + val config = DegreeStaticConfig(true) + val encodeResult = DegreeStaticAlgo.apply(spark, data, config) + assert(encodeResult.count() == 4) } } diff --git a/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/HanpSuite.scala b/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/HanpSuite.scala index ed4bbfc..9a4acf1 100644 --- a/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/HanpSuite.scala +++ b/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/HanpSuite.scala @@ -14,10 +14,14 @@ import org.junit.Test class HanpSuite { @Test def hanpSuite() = { - val spark = SparkSession.builder().master("local").getOrCreate() + val spark = + SparkSession.builder().master("local").config("spark.sql.shuffle.partitions", 5).getOrCreate() val data = spark.read.option("header", true).csv("src/test/resources/edge.csv") 
val hanpConfig = new HanpConfig(0.1, 10, 1.0) val result = HanpAlgo.apply(spark, data, hanpConfig, false) assert(result.count() == 4) + + val encodeHanpConfig = new HanpConfig(0.1, 10, 1.0, true) + assert(HanpAlgo.apply(spark, data, encodeHanpConfig, false).count() == 4) } } diff --git a/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/JaccardAlgoSuite.scala b/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/JaccardAlgoSuite.scala index e4a7a52..2e2ab35 100644 --- a/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/JaccardAlgoSuite.scala +++ b/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/JaccardAlgoSuite.scala @@ -12,11 +12,16 @@ import org.junit.Test class JaccardAlgoSuite { @Test def kcoreSuite(): Unit = { - val spark = SparkSession.builder().master("local").getOrCreate() + val spark = + SparkSession.builder().master("local").config("spark.sql.shuffle.partitions", 5).getOrCreate() val data = spark.read.option("header", true).csv("src/test/resources/edge.csv") val jaccardConfig = new JaccardConfig(0.01) val jaccardResult = JaccardAlgo.apply(spark, data, jaccardConfig) jaccardResult.show() assert(jaccardResult.count() == 6) + + val encodeJaccardConfig = new JaccardConfig(0.01, true) + val encodeJaccardResult = JaccardAlgo.apply(spark, data, encodeJaccardConfig) + assert(encodeJaccardResult.count() == 6) } } diff --git a/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/KCoreAlgoSuite.scala b/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/KCoreAlgoSuite.scala index feabe70..1a4370f 100644 --- a/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/KCoreAlgoSuite.scala +++ b/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/KCoreAlgoSuite.scala @@ -12,10 +12,14 @@ import org.junit.Test class KCoreAlgoSuite { @Test def kcoreSuite(): Unit = { - val spark = SparkSession.builder().master("local").getOrCreate() + val spark = + 
SparkSession.builder().master("local").config("spark.sql.shuffle.partitions", 5).getOrCreate() val data = spark.read.option("header", true).csv("src/test/resources/edge.csv") val kcoreConfig = new KCoreConfig(10, 3) val kcoreResult = KCoreAlgo.apply(spark, data, kcoreConfig) assert(kcoreResult.count() == 4) + + val encodeKcoreConfig = new KCoreConfig(10, 3, true) + assert(KCoreAlgo.apply(spark, data, encodeKcoreConfig).count() == 4) } } diff --git a/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/LabelPropagationAlgoSuite.scala b/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/LabelPropagationAlgoSuite.scala index 8b7d585..7d27cf4 100644 --- a/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/LabelPropagationAlgoSuite.scala +++ b/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/LabelPropagationAlgoSuite.scala @@ -12,7 +12,8 @@ import org.junit.Test class LabelPropagationAlgoSuite { @Test def lpaAlgoSuite(): Unit = { - val spark = SparkSession.builder().master("local").getOrCreate() + val spark = + SparkSession.builder().master("local").config("spark.sql.shuffle.partitions", 5).getOrCreate() val data = spark.read.option("header", true).csv("src/test/resources/edge.csv") val lpaConfig = new LPAConfig(5) val result = LabelPropagationAlgo.apply(spark, data, lpaConfig, false) @@ -20,5 +21,8 @@ class LabelPropagationAlgoSuite { assert(result.count() == 4) result.foreach(row => { assert(row.get(1).toString.toInt == 1) }) + + val encodeLpaConfig = new LPAConfig(5, true) + assert(LabelPropagationAlgo.apply(spark, data, encodeLpaConfig, false).count() == 4) } } diff --git a/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/PageRankAlgoSuite.scala b/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/PageRankAlgoSuite.scala index 54ec0dc..6808fdc 100644 --- a/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/PageRankAlgoSuite.scala +++ 
b/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/PageRankAlgoSuite.scala @@ -12,10 +12,14 @@ import org.junit.Test class PageRankAlgoSuite { @Test def pageRankSuite(): Unit = { - val spark = SparkSession.builder().master("local").getOrCreate() + val spark = + SparkSession.builder().master("local").config("spark.sql.shuffle.partitions", 5).getOrCreate() val data = spark.read.option("header", true).csv("src/test/resources/edge.csv") val prConfig = new PRConfig(5, 1.0) val prResult = PageRankAlgo.apply(spark, data, prConfig, false) assert(prResult.count() == 4) + + val encodePrConfig = new PRConfig(5, 1.0, true) + assert(PageRankAlgo.apply(spark, data, encodePrConfig, false).count() == 4) } } diff --git a/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/SCCAlgoSuite.scala b/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/SCCAlgoSuite.scala index a55fb5d..fcdfe7b 100644 --- a/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/SCCAlgoSuite.scala +++ b/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/SCCAlgoSuite.scala @@ -12,10 +12,14 @@ import org.junit.Test class SCCAlgoSuite { @Test def sccAlgoSuite(): Unit = { - val spark = SparkSession.builder().master("local").getOrCreate() + val spark = + SparkSession.builder().master("local").config("spark.sql.shuffle.partitions", 5).getOrCreate() val data = spark.read.option("header", true).csv("src/test/resources/edge.csv") val sccConfig = new CcConfig(5) val sccResult = StronglyConnectedComponentsAlgo.apply(spark, data, sccConfig, true) assert(sccResult.count() == 4) + + val encodeSccConfig = new CcConfig(5, true) + assert(StronglyConnectedComponentsAlgo.apply(spark, data, encodeSccConfig, true).count() == 4) } } diff --git a/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/TrangleCountSuite.scala b/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/TrangleCountSuite.scala index b93f8d9..9068c30 100644 --- 
a/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/TrangleCountSuite.scala +++ b/nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/TrangleCountSuite.scala @@ -5,13 +5,15 @@ package com.vesoft.nebula.algorithm.lib +import com.vesoft.nebula.algorithm.config.TriangleConfig import org.apache.spark.sql.SparkSession import org.junit.Test class TrangleCountSuite { @Test def trangleCountSuite(): Unit = { - val spark = SparkSession.builder().master("local").getOrCreate() + val spark = + SparkSession.builder().master("local").config("spark.sql.shuffle.partitions", 5).getOrCreate() val data = spark.read.option("header", true).csv("src/test/resources/edge.csv") val trangleCountResult = TriangleCountAlgo.apply(spark, data) assert(trangleCountResult.count() == 4) @@ -19,5 +21,8 @@ class TrangleCountSuite { trangleCountResult.foreach(row => { assert(row.get(1) == 3) }) + + val triangleConfig = TriangleConfig(true) + assert(TriangleCountAlgo.apply(spark, data, triangleConfig).count() == 4) } }