From 099c8c9ff3cdf786a082f8ea2a4def21a8a229d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=90=D1=80=D1=82=D1=91=D0=BC=20=D0=9A=D0=BE=D1=80=D1=81?= =?UTF-8?q?=D0=B0=D0=BA=D0=BE=D0=B2?= Date: Sat, 18 Nov 2023 08:27:22 +0300 Subject: [PATCH] Refactor getOrElseBoolean --- .../nebula/algorithm/config/AlgoConfig.scala | 45 ++++++++----------- 1 file changed, 18 insertions(+), 27 deletions(-) diff --git a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/AlgoConfig.scala b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/AlgoConfig.scala index cba8814..0fe4fa4 100644 --- a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/AlgoConfig.scala +++ b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/AlgoConfig.scala @@ -26,7 +26,7 @@ object PRConfig { if (prConfig.contains("algorithm.pagerank.resetProb")) prConfig("algorithm.pagerank.resetProb").toDouble else 0.15 - encodeId = ConfigUtil.getOrElseBoolean(prConfig, "algorithm.pagerank.encodeId", false) + encodeId = ConfigUtil.getOrElseBoolean(prConfig, "algorithm.pagerank.encodeId") PRConfig(maxIter, resetProb, encodeId) } } @@ -44,7 +44,7 @@ object LPAConfig { val lpaConfig = configs.algorithmConfig.map maxIter = lpaConfig("algorithm.labelpropagation.maxIter").toInt - encodeId = ConfigUtil.getOrElseBoolean(lpaConfig, "algorithm.labelpropagation.encodeId", false) + encodeId = ConfigUtil.getOrElseBoolean(lpaConfig, "algorithm.labelpropagation.encodeId") LPAConfig(maxIter, encodeId) } } @@ -60,7 +60,7 @@ object CcConfig { def getCcConfig(configs: Configs): CcConfig = { val ccConfig = configs.algorithmConfig.map - encodeId = ConfigUtil.getOrElseBoolean(ccConfig, "algorithm.connectedcomponent.encodeId", false) + encodeId = ConfigUtil.getOrElseBoolean(ccConfig, "algorithm.connectedcomponent.encodeId") maxIter = ccConfig("algorithm.connectedcomponent.maxIter").toInt CcConfig(maxIter, encodeId) @@ -80,7 +80,7 @@ object ShortestPathConfig { val spConfig = configs.algorithmConfig.map landmarks = spConfig("algorithm.shortestpaths.landmarks").split(",").toSeq.map(_.toLong) - encodeId = ConfigUtil.getOrElseBoolean(spConfig, "algorithm.shortestpaths.encodeId", false) + encodeId = ConfigUtil.getOrElseBoolean(spConfig, "algorithm.shortestpaths.encodeId") ShortestPathConfig(landmarks, encodeId) } } @@ -102,7 +102,7 @@ object LouvainConfig { maxIter = louvainConfig("algorithm.louvain.maxIter").toInt internalIter = louvainConfig("algorithm.louvain.internalIter").toInt tol = louvainConfig("algorithm.louvain.tol").toDouble - encodeId = ConfigUtil.getOrElseBoolean(louvainConfig, "algorithm.louvain.encodeId", false) + encodeId = ConfigUtil.getOrElseBoolean(louvainConfig, "algorithm.louvain.encodeId") LouvainConfig(maxIter, internalIter, tol, encodeId) } @@ -118,7 +118,7 @@ object DegreeStaticConfig { def getDegreeStaticConfig(configs: Configs): DegreeStaticConfig = { val degreeConfig = configs.algorithmConfig.map - encodeId = ConfigUtil.getOrElseBoolean(degreeConfig, "algorithm.degreestatic.encodeId", false) + encodeId = ConfigUtil.getOrElseBoolean(degreeConfig, "algorithm.degreestatic.encodeId") DegreeStaticConfig(encodeId) } } @@ -132,8 +132,7 @@ object TriangleConfig { var encodeId: Boolean = false def getTriangleConfig(configs: Configs): TriangleConfig = { val triangleConfig = configs.algorithmConfig.map - encodeId = - ConfigUtil.getOrElseBoolean(triangleConfig, "algorithm.trianglecount.encodeId", false) + encodeId = ConfigUtil.getOrElseBoolean(triangleConfig, "algorithm.trianglecount.encodeId") TriangleConfig(encodeId) } } @@ -152,7 +151,7 @@ object KCoreConfig { val kCoreConfig = configs.algorithmConfig.map maxIter = kCoreConfig("algorithm.kcore.maxIter").toInt degree = kCoreConfig("algorithm.kcore.degree").toInt - encodeId = ConfigUtil.getOrElseBoolean(kCoreConfig, "algorithm.kcore.encodeId", false) + encodeId = ConfigUtil.getOrElseBoolean(kCoreConfig, "algorithm.kcore.encodeId") KCoreConfig(maxIter, degree, false) } } @@ -169,8 +168,7 @@ object BetweennessConfig { def getBetweennessConfig(configs: Configs): BetweennessConfig = { val betweennessConfig = configs.algorithmConfig.map maxIter = betweennessConfig("algorithm.betweenness.maxIter").toInt - encodeId = - ConfigUtil.getOrElseBoolean(betweennessConfig, "algorithm.betweenness.encodeId", false) + encodeId = ConfigUtil.getOrElseBoolean(betweennessConfig, "algorithm.betweenness.encodeId") BetweennessConfig(maxIter, encodeId) } } @@ -190,9 +188,8 @@ object CoefficientConfig { algoType = coefficientConfig("algorithm.clusteringcoefficient.type") assert(algoType.equalsIgnoreCase("local") || algoType.equalsIgnoreCase("global"), "ClusteringCoefficient only support local or global type.") - encodeId = ConfigUtil.getOrElseBoolean(coefficientConfig, - "algorithm.clusteringcoefficient.encodeId", - false) + encodeId = + ConfigUtil.getOrElseBoolean(coefficientConfig, "algorithm.clusteringcoefficient.encodeId") CoefficientConfig(algoType, encodeId) } } @@ -210,7 +207,7 @@ object BfsConfig { val bfsConfig = configs.algorithmConfig.map maxIter = bfsConfig("algorithm.bfs.maxIter").toInt root = bfsConfig("algorithm.bfs.root").toString - encodeId = ConfigUtil.getOrElseBoolean(bfsConfig, "algorithm.bfs.encodeId", false) + encodeId = ConfigUtil.getOrElseBoolean(bfsConfig, "algorithm.bfs.encodeId") BfsConfig(maxIter, root, encodeId) } } @@ -228,7 +225,7 @@ object DfsConfig { val dfsConfig = configs.algorithmConfig.map maxIter = dfsConfig("algorithm.dfs.maxIter").toInt root = dfsConfig("algorithm.dfs.root").toString - encodeId = ConfigUtil.getOrElseBoolean(dfsConfig, "algorithm.dfs.encodeId", false) + encodeId = ConfigUtil.getOrElseBoolean(dfsConfig, "algorithm.dfs.encodeId") DfsConfig(maxIter, root, encodeId) } } @@ -251,7 +248,7 @@ object HanpConfig { hopAttenuation = hanpConfig("algorithm.hanp.hopAttenuation").toDouble maxIter = hanpConfig("algorithm.hanp.maxIter").toInt preference = hanpConfig("algorithm.hanp.preference").toDouble - encodeId = ConfigUtil.getOrElseBoolean(hanpConfig, "algorithm.hanp.encodeId", false) + encodeId = ConfigUtil.getOrElseBoolean(hanpConfig, "algorithm.hanp.encodeId") HanpConfig(hopAttenuation, maxIter, preference, encodeId) } } @@ -306,7 +303,7 @@ object Node2vecConfig { degree = node2vecConfig("algorithm.node2vec.degree").toInt embSeparate = node2vecConfig("algorithm.node2vec.embSeparate") modelPath = node2vecConfig("algorithm.node2vec.modelPath") - encodeId = ConfigUtil.getOrElseBoolean(node2vecConfig, "algorithm.node2vec.encodeId", false) + encodeId = ConfigUtil.getOrElseBoolean(node2vecConfig, "algorithm.node2vec.encodeId") Node2vecConfig(maxIter, lr, dataNumPartition, @@ -336,7 +333,7 @@ object JaccardConfig { def getJaccardConfig(configs: Configs): JaccardConfig = { val jaccardConfig = configs.algorithmConfig.map tol = jaccardConfig("algorithm.jaccard.tol").toDouble - encodeId = ConfigUtil.getOrElseBoolean(jaccardConfig, "algorithm.jaccard.encodeId", false) + encodeId = ConfigUtil.getOrElseBoolean(jaccardConfig, "algorithm.jaccard.encodeId") JaccardConfig(tol, encodeId) } } @@ -351,12 +348,6 @@ object AlgoConfig { } object ConfigUtil { - def getOrElseBoolean(config: Map[String, String], key: String, defaultValue: Boolean): Boolean = { - if (config.contains(key)) { - config(key).toBoolean - } else { - defaultValue - } - } - + def getOrElseBoolean(config: Map[String, String], key: String): Boolean = + config.get(key).exists(_.toBoolean) }