In [0]:
// Session shared by every later cell. getOrCreate() returns the already-active session
// if there is one (e.g. provided by the notebook) instead of starting another.
import org.apache.spark.sql.SparkSession

val spark= SparkSession
    .builder()
    .appName("graphX")
    .getOrCreate()

In [1]:
// spark-xml data source (adds the `.xml(...)` reader method).
import com.databricks.spark.xml._

// One row per <MedlineCitation> element, read from every .xml file under /medline on HDFS.
// No schema is supplied, so the reader infers it (inspected in the next cell).
val rawXML = spark.read.option("rowTag", "MedlineCitation").xml("hdfs://namenode:8020/medline/*.xml")

In [2]:
// Inspect the inferred citation schema.
rawXML.printSchema()

In [3]:
// Needed for the $"col" syntax and for .toDF/.toDS on RDDs in later cells.
import spark.implicits._
// The MeshHeading array of each citation (name keeps the original "Headling" spelling).
val meshHeadlingList = rawXML.select("MeshHeadingList.MeshHeading")

In [4]:
meshHeadlingList.printSchema()

In [5]:
// One row per MeSH heading: explode emits a row per array element and no rows for a
// null/empty array, so the link back to the owning citation is lost from here on.
// NOTE(review): `explode`, `col`, `desc` are not imported in this file — assumes the
// interpreter pre-imports org.apache.spark.sql.functions._ (Zeppelin does); confirm.
val MeshHeadlingElems = meshHeadlingList.withColumn("data", explode($"MeshHeading")).select("data")

In [6]:
MeshHeadlingElems.printSchema()

In [7]:
// Keep only the DescriptorName struct of each heading.
val descriptorName = MeshHeadlingElems.select(MeshHeadlingElems.col("data.DescriptorName"))
descriptorName.printSchema()

In [8]:
// spark-xml exposes XML attributes with a "_" prefix and element text as "_VALUE":
// _MajorTopicYN is the attribute flag, _VALUE the descriptor name (renamed to "topic").
val parsedDF = descriptorName.select(descriptorName.col("DescriptorName._MajorTopicYN"),
                                    descriptorName.col("DescriptorName._VALUE").as("topic"))

In [9]:
// `z` is presumably the Zeppelin context; z.show renders a DataFrame as a table.
z.show(parsedDF)

In [10]:
// Descriptor rows flagged as a major topic.
val majorTopic = parsedDF.filter(col("_MajorTopicYN") === "Y")

In [11]:
z.show(majorTopic)

In [12]:
// Number of major-topic rows per topic name (columns: topic, count).
val topicDist = majorTopic.groupBy("topic").count()

In [13]:
z.show(topicDist.orderBy(desc("count")))

In [14]:
// One List[String] per citation: the names of that citation's major-topic descriptors.
// The previous version split each *single* descriptor name on "," (each majorTopic row is
// one descriptor), turning e.g. "Neoplasms, Experimental" into two fragments. The pairs
// built from `topics` were therefore fragments of one name, not topics that co-occur in
// a citation. Going back to rawXML keeps the per-citation grouping.
// Citations without a MeshHeadingList or without major topics yield an empty list.
val topics = rawXML
    .select($"MeshHeadingList.MeshHeading.DescriptorName".as("descriptors"))
    .rdd
    .map { row =>
        val descriptors = Option(row.getSeq[org.apache.spark.sql.Row](0)).getOrElse(Nil)
        descriptors
            .filter(d => d != null && d.getAs[String]("_MajorTopicYN") == "Y")
            .flatMap(d => Option(d.getAs[String]("_VALUE")))
            // A topic listed twice in one citation must not pair with itself later.
            .distinct
            .toList
    }

In [15]:
// One row per topic occurrence (flattened `topics`, not deduplicated).
val onlyTopics =  topics.flatMap(mesh => mesh).toDF("topic")

In [16]:
// Every 2-element combination of each element of `topics`. Sorting first gives each
// unordered pair a single canonical order, so GROUP BY merges (a, b) with (b, a).
val topicPairs = topics.flatMap(t => {t.sorted.combinations(2)}).toDF("pairs")
topicPairs.createOrReplaceTempView("topic_pairs")
// Occurrence count of each distinct pair (columns: pairs, cnt).
val cooccurs = spark.sql("""
    SELECT pairs, COUNT(*) cnt
    FROM topic_pairs
    GROUP BY pairs""")


In [17]:
// Print the 10 most frequent pairs on the driver.
cooccurs.createOrReplaceTempView("cooccurs")
spark.sql("""
    SELECT pairs, cnt
    FROM cooccurs
    ORDER BY cnt DESC
    LIMIT 10""").collect().foreach(println)

In [18]:
import java.nio.charset.StandardCharsets
import java.security.MessageDigest

/**
 * Derives a 64-bit GraphX vertex id from a topic name.
 *
 * The name is UTF-8 encoded and MD5-hashed; the first eight digest bytes are read as a
 * little-endian Long (byte 0 least significant, byte 7 most significant).
 */
def hashID(str: String): Long = {
    val digest = MessageDigest.getInstance("MD5").digest(str.getBytes(StandardCharsets.UTF_8))
    // foldRight visits byte 7 first, so after the remaining seven shifts it sits in the top 8 bits.
    digest.take(8).foldRight(0L) { (byte, acc) => (acc << 8) | (byte & 0xFFL) }
}

In [19]:
import org.apache.spark.sql.Row

// Vertex table: one (hash, topic) row per *distinct* topic. onlyTopics has one row per
// occurrence, so without distinct() every topic would be hashed and emitted once per
// citation mentioning it, inflating the table that feeds Graph() below.
val vertices = onlyTopics.distinct().map{ case Row(topic: String) => (hashID(topic), topic) }.toDF("hash", "topic")
vertices.show(false)

In [20]:
import org.apache.spark.graphx._

// One edge per co-occurring pair, weighted by its count. The two hashed ids are sorted so
// the smaller id is always the source. Rows not matching the pattern raise a MatchError.
val edges = cooccurs.map{ case Row(pairs: Seq[_], cnt: Long) =>
    val ids = pairs.map(_.toString).map(hashID).sorted
    Edge(ids(0), ids(1), cnt)
}

In [21]:
// (vertexId, topicName) pairs for GraphX.
val vertexRDD = vertices.rdd.map{
    case Row(hash: Long, topic: String) => (hash, topic)
}
// Topic co-occurrence graph; cached because several later cells traverse it.
val topicGraph = Graph(vertexRDD, edges.rdd)
topicGraph.cache()

In [22]:
// Labels every vertex with the lowest vertex id in its connected component.
val connectedComponentGraph = topicGraph.connectedComponents()

In [23]:
// vid = vertex id, cid = component id.
val componentDF = connectedComponentGraph.vertices.toDF("vid", "cid")
componentDF.show(false)

In [24]:
// Rows per vertex id, smallest counts first — presumably a check that vids are unique.
componentDF.groupBy("vid").count().orderBy("count").show()

In [25]:
// Size of each component; the cell's value is the number of components.
val componentCounts = componentDF.groupBy("cid").count()
componentCounts.count()

In [26]:
z.show(componentCounts.orderBy(desc("count")))

In [27]:
// Members of one specific component. The id was observed in an earlier run and is only
// meaningful for the same input data. cid is a bigint column, so compare against a Long
// literal. Comparing with a string literal relies on implicit type coercion, which
// depending on the Spark version may go through double and cannot represent 19-digit ids
// exactly.
val testGraphConnect = componentDF.filter(col("cid") === -9080598076986622448L)
testGraphConnect.show

In [28]:
// Attach topic names to those component members; distinct() removes duplicate vertex rows.
val joinExp = vertices.col("hash") === testGraphConnect.col("vid")
val joinWithVertexName = testGraphConnect.join(vertices, joinExp).distinct()

joinWithVertexName.show

In [29]:
z.show( topicDist.filter($"topic".contains("Cell")) )

In [30]:
// Degree of every vertex that has at least one edge (GraphX omits isolated vertices).
val degrees: VertexRDD[Int] = topicGraph.degrees.cache()

In [31]:
val testVertexDegree = degrees.toDF("vertexID", "degree")

// vertexID is a bigint column: use a Long literal for an exact comparison (see above).
testVertexDegree.where(col("vertexID") === -2024180124655511532L).show
println(testVertexDegree.count)

In [32]:
testVertexDegree.describe("degree").show

In [33]:
// Total number of vertices in the topic graph.
val topicGraphVerteciesCount = topicGraph.vertices.count()

In [34]:
// Entries of `topics` that contain exactly one topic (they contribute no pairs).
val singleTopics = topics.filter(x => x.size == 1)
singleTopics.count()

In [35]:
val singleTopicsDistinct = singleTopics.flatMap(topic => topic).distinct().toDS()
singleTopicsDistinct.count()

In [36]:
// Every topic that appears in at least one pair; the except() count is the number of
// topics that only ever appear on their own (i.e. vertices with no edges).
val singleTopicInPairs = topicPairs.flatMap(_.getAs[Seq[String]](0))
singleTopicsDistinct.except(singleTopicInPairs).count()

In [37]:
// Sanity check: all vertices = vertices with degree >= 1 (rows of `degrees`)
// + topics that never occur in a pair.
topicGraphVerteciesCount == (testVertexDegree.select("degree").count() + singleTopicsDistinct.except(singleTopicInPairs).count())

In [38]:
// (topic name, degree) for every vertex with at least one edge.
// degrees is already VertexRDD[Int], so .toInt is a no-op.
val namesAndDegrees = degrees.innerJoin(topicGraph.vertices) {
    (topicId, degree, name) => (name, degree.toInt)
}.values.toDF("topic", "degree")

In [39]:
z.show( namesAndDegrees.orderBy(desc("degree")) )

In [40]:
// T = total number of citations (documents): the grand total N of each 2x2 contingency
// table scored by chiSq. rawXML has one row per <MedlineCitation>. majorTopic has one row
// per (citation, major topic) and overcounts documents whenever a citation lists several
// major topics.
// No broadcast: T is a single Long captured by value in the closures that use it. The
// previous `sc.broadcast(T)` also discarded the returned handle, so it was never read.
val T = rawXML.count()

In [41]:
// (vertexId, topic count) pairs: the same hashID keys as topicGraph, with the number of
// major-topic rows for that topic as the vertex attribute.
val topicDistRdd = topicDist.map{
    case Row(topic: String, count: Long) => (hashID(topic), count)
}.rdd

In [42]:
// Same edges as topicGraph, but vertices now carry topic counts instead of names.
// NOTE(review): an edge endpoint missing from topicDistRdd would get Graph()'s default
// vertex attribute (null) — confirm every endpoint's topic is present in topicDist.
val topicDistGraph = Graph(topicDistRdd, topicGraph.edges)

In [43]:
/**
 * Yates-corrected chi-squared statistic for the 2x2 contingency table of two topics A and B.
 *
 * @param YY number of observations containing both A and B
 * @param YB number of observations containing B
 * @param YA number of observations containing A
 * @param T  total number of observations
 * @return the statistic, or 0.0 when a margin of the table is zero (A or B occurs in none
 *         or in all of the observations). In that case the statistic is undefined and the
 *         original code divided by zero.
 */
def chiSq(YY: Long, YB: Long, YA: Long, T: Long): Double = {
    val NB = T - YB
    val NA = T - YA
    val YN = YA - YY
    val NY = YB - YY
    val NN = T - NY - YN - YY
    // Multiply in Double: the product of four counts easily exceeds Long.MaxValue
    // (1e5 * 1e7 * 1e5 * 1e7 = 1e24), which previously wrapped around silently.
    val denominator = YA.toDouble * NA * YB * NB
    if (denominator == 0.0) {
        0.0
    } else {
        // Yates continuity correction, clamped at 0 so it can only shrink |ad - bc| and
        // never turns an almost-independent table into a spurious positive statistic.
        val inner = math.max(0.0, math.abs(YY.toDouble * NN - YN.toDouble * NY) - T / 2.0)
        T * inner * inner / denominator
    }
}

In [44]:
// https://spark.apache.org/docs/latest/api/scala/org/apache/spark/graphx/EdgeTriplet.html
// Tabular view of every triplet: both endpoint ids, their topic counts (vertex attributes)
// and the co-occurrence count on the edge. The tuple is now built in the same order as the
// column names. Previously srcAttr and srcId were swapped, so the "srcId" column held a
// count and the "srcAttr" column held a vertex id.
val topicDistGraphTriplet = topicDistGraph.triplets.map(triplet =>
    (triplet.srcId, triplet.srcAttr, triplet.attr, triplet.dstId, triplet.dstAttr))
    .toDF("srcId", "srcAttr", "attr", "dstId", "dstAttr")
z.show(topicDistGraphTriplet)

In [45]:
// Replace each edge's co-occurrence count with the chi-squared statistic of its pair
// (edge count = YY, endpoint topic counts = YB/YA, T = total); show summary stats.
val chiSquaredGraph = topicDistGraph.mapTriplets(triplet => {
    chiSq(triplet.attr, triplet.srcAttr, triplet.dstAttr, T)
})
chiSquaredGraph.edges.map(x => x.attr).stats()

In [46]:
// Keep only edges with statistic > 19.5 (roughly the p = 1e-5 critical value of a
// chi-squared distribution with 1 degree of freedom). subgraph keeps all vertices here.
val interesting = chiSquaredGraph.subgraph(
    triplet => triplet.attr > 19.5)
interesting.edges.count

In [47]:
// Component sizes of the filtered graph; the cell's value is the number of components.
val interestingComponentGraph = interesting.connectedComponents()
val icDF = interestingComponentGraph.vertices.toDF("vid", "cid")
val icCountDF = icDF.groupBy("cid").count()
icCountDF.count()

In [48]:
icCountDF.orderBy(desc("count")).show

In [49]:
// Degree distribution of the filtered graph (vertices with no remaining edge are omitted).
val interestingDegrees = interesting.degrees.cache()
interestingDegrees.map(_._2).stats()

In [50]:
// Highest-degree topics by name.
interestingDegrees.innerJoin(topicGraph.vertices) {
    (topicId, degree, name) => (name, degree)
}.values.toDF("topic", "degree").orderBy(desc("degree")).show

In [51]:
// Number of triangles through each vertex.
val triCountGraph = interesting.triangleCount()
triCountGraph.vertices.map(x => x._2).stats()

In [52]:
// Maximum possible triangles through a vertex of degree d: d choose 2.
val maxTrisGraph = interestingDegrees.mapValues(d => d * (d-1) / 2.0)

In [53]:
// Local clustering coefficient per vertex with degree >= 1 (0 when no triangle is possible).
val clusterCoef = triCountGraph.vertices.innerJoin(maxTrisGraph) { 
    (vertexId, triCount, maxTris) => {if (maxTris == 0) 0 else triCount / maxTris}
}

In [54]:
// Average clustering coefficient over all vertices of `interesting`; vertices without
// edges are absent from clusterCoef and so count as 0 in this average.
clusterCoef.map(_._2).sum() / interesting.vertices.count()

In [55]:
/**
 * Pointwise minimum of two (vertex id -> hop distance) maps.
 *
 * The result holds every key of either map; a key present in both keeps the smaller
 * distance. Used as the Pregel message combiner and by the vertex program `update`.
 */
def mergeMaps(m1: Map[VertexId, Int], m2: Map[VertexId, Int]) : Map[VertexId, Int] = {
    m2.foldLeft(m1) { case (acc, (vertex, dist)) =>
        acc.get(vertex) match {
            case Some(known) if known <= dist => acc
            case _                            => acc.updated(vertex, dist)
        }
    }
}

In [56]:
/** Pregel vertex program: new state = pointwise-minimum merge of old state and message (the id is unused). */
def update(id: VertexId, state: Map[VertexId, Int], msg: Map[VertexId, Int]) =
    mergeMaps(state, msg)

In [57]:
/**
 * Decides whether vertex `bid` should receive a message from its neighbour.
 *
 * `a` is the neighbour's distance map; each distance is extended by one hop. If merging
 * those extended distances into `b` (the state of `bid`) would change `b`, one message
 * `(bid, extended map)` is produced; otherwise the iterator is empty.
 */
def checkIncrement(a: Map[VertexId, Int], b: Map[VertexId, Int], bid: VertexId) = {
    val oneHopFurther = a.map { case (vertex, dist) => vertex -> (dist + 1) }
    val improvesTarget = mergeMaps(oneHopFurther, b) != b
    if (improvesTarget) Iterator((bid, oneHopFurther)) else Iterator.empty
}

In [58]:
/**
 * Pregel send-message function: applies checkIncrement in both directions of the edge,
 * so distance information propagates src -> dst and dst -> src.
 */
def iterate(e: EdgeTriplet[Map[VertexId, Int], _]) = {
    val towardsDst = checkIncrement(e.srcAttr, e.dstAttr, e.dstId)
    towardsDst ++ checkIncrement(e.dstAttr, e.srcAttr, e.srcId)
}

In [59]:
// Sample about 2% of the vertex ids of `interesting` (without replacement, fixed seed 1729
// for reproducibility) and collect them to the driver as a Set.
val fraction = 0.02
val replacement = false
val sample = interesting.vertices.map(v => v._1).sample(replacement, fraction, 1729L)
val ids = sample.collect().toSet

In [60]:
// Initial state: each sampled vertex knows it is 0 hops from itself; others know nothing.
val mapGraph = interesting.mapVertices((id, _) => {
    if (ids.contains(id)) {
        Map(id -> 0)
    } else {
        Map[VertexId, Int]()
    }
})

In [61]:
// Pregel with an empty initial message: distances grow by one per hop (checkIncrement)
// and are min-merged (mergeMaps), so iteration stops once no map can improve. Each
// vertex then holds the hop count to every sampled vertex it can reach.
val start = Map[VertexId, Int]()
val res = mapGraph.pregel(start)(update, iterate, mergeMaps)

In [62]:
// Flatten the Pregel result into (vertex, sampled vertex, hop count) triples.
// Each pair is written in canonical order (smaller id first). A path between two sampled
// vertices is recorded at both ends, and canonical ordering lets distinct() collapse it
// to one row. Previously both branches returned (id, k, v), so the ordering did nothing
// and such pairs were counted twice.
val paths = res.vertices.flatMap{ case(id, m) =>
    m.map { case(k, v) =>
        if (id < k) {
            (id, k, v)
        } else {
            (k, id, v)
        }
    }
}.distinct()
paths.cache()

In [63]:
// Columns: the two vertex ids of a path and its length in hops.
val pathDF = paths.toDF("SrcVertexId", "DstVertexId", "PathLen")

In [64]:
z.show(pathDF)

In [65]:
// PathLen > 0 drops each sampled vertex's zero-length entry for itself.
pathDF.filter(col("PathLen") > 0).describe("PathLen").show()

In [66]:
// Histogram of path lengths.
z.show(pathDF.groupBy("PathLen").count())