In [1]:
!apt-get install openjdk-8-jdk-headless -qq > /dev/null
!wget -q https://archive.apache.org/dist/spark/spark-3.4.1/spark-3.4.1-bin-hadoop3.tgz
!tar xf spark-3.4.1-bin-hadoop3.tgz
!pip install -q findspark

In [1]:
import os

# Tell the Spark tooling where the JDK and the unpacked Spark distribution live.
os.environ.update({
    "JAVA_HOME": "/usr/lib/jvm/java-8-openjdk-amd64",
    "SPARK_HOME": "/content/spark-3.4.1-bin-hadoop3",
})

In [2]:
import findspark
# Make the pyspark package of the Spark distribution (SPARK_HOME, set in the
# previous cell) importable from this kernel.
findspark.init()

In [3]:
# Record the interpreter version used for this run.
!python --version

Python 3.11.11



# MinHashLSH Algorithm

In [4]:
import numpy as np
import random
from pyspark import StorageLevel
from pyspark.sql import SparkSession, DataFrame, Row
from pyspark.ml.linalg import Vectors, VectorUDT
import pyspark.sql.functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import ArrayType, StringType, IntegerType, FloatType

import matplotlib.pyplot as plt
import time

In [5]:
# Start (or reuse) the Spark session used by the rest of the notebook.
spark = (
    SparkSession.builder
    .appName("Progress 2")
    .getOrCreate()
)

# One document per line of the input file.
documents = spark.read.text('WebOfScience-5736.txt', lineSep="\n")

In [6]:
class MinHashLSH:
    """Approximate similar-document search using MinHash signatures and banded LSH.

    Pipeline (executed by :meth:`run`):
    normalise text -> character k-shingles -> multi-hot sparse vectors ->
    MinHash signatures -> one bucket id per band.

    :meth:`approxNearestNeighbors` pushes a query string through the same
    pipeline, keeps the documents that share at least one bucket id with it and
    ranks them by exact Jaccard similarity of their shingle sets.
    """

    def __init__(self,
                 documents: DataFrame,
                 numHashTables: int = 100,
                 numBands: int = 50,
                 k: int = 4):
        """
        Args:
            documents: single-column DataFrame with one document per row.
            numHashTables: number of MinHash functions (signature length).
            numBands: number of bands the signature is split into for LSH.
            k: shingle length in characters.
        """
        self.documents = self.processDocuments(documents)

        # Populated by run().
        self.vocab = None
        self.vocab_size = None
        self.bool_vectors = None
        self.buckets = None

        # config
        self.k = k
        self.numBands = numBands
        # Modulus of the hash functions; it is below 2**31, so bucket ids fit
        # the IntegerType column produced by locality_sensity_hashing().
        self.HASH_PRIME = 2038074743
        self.numHashTables = numHashTables

    @staticmethod
    def processDocuments(documents: DataFrame) -> DataFrame:
        """Attach a 0-based row index and normalise the text.

        Args:
            documents: single-column DataFrame of raw text.

        Returns:
            DataFrame with columns ``id`` and ``value``; ``value`` is
            lower-cased and stripped of every character that is not a letter,
            a digit or whitespace.
        """
        indexed = documents.rdd.zipWithIndex() \
                           .map(lambda row_index: (row_index[1],) + tuple(row_index[0])) \
                           .toDF(['id', 'value'])

        # Raw string: the pattern is unchanged, but the backslash is no longer
        # an invalid Python escape sequence.
        return indexed.withColumn('value', F.lower(F.col('value'))) \
                      .withColumn('value', F.regexp_replace(F.col('value'), r'[^a-zA-Z0-9\s]', ''))

    def _explode_shingles(self, documents: DataFrame) -> DataFrame:
        """Return one ``(id, shingle)`` row per character k-shingle of every document.

        Documents shorter than ``k`` characters produce no rows.
        """
        # Copy to a local so the UDF closure does not capture `self`.
        k = self.k

        @F.udf(returnType=ArrayType(StringType()))
        def shingle_udf(text):
            if text is None:
                return []
            return [text[i:i + k] for i in range(len(text) - k + 1)]

        return documents.select('id', F.explode(shingle_udf(F.col('value'))).alias('shingle'))

    def build_vocab(self, shingle: DataFrame) -> DataFrame:
        """Assign a contiguous 0-based ``shingle_id`` to every distinct shingle.

        Side effect: stores the number of distinct shingles in ``self.vocab_size``.

        Returns:
            DataFrame with columns ``shingle`` and ``shingle_id``.
        """
        # The window has no partitioning, so row_number() numbers all distinct
        # shingles globally in sorted order (ids 0 .. vocab_size - 1).
        windowSpec = Window.orderBy('shingle')
        vocab = shingle.select('shingle').distinct() \
                       .withColumn('shingle_id', F.row_number().over(windowSpec) - 1)
        self.vocab_size = vocab.count()

        return vocab

    def multiHotEncoding(self, shingle: DataFrame) -> DataFrame:
        """Encode each document's shingle set as a sparse 0/1 vector over the vocabulary.

        Shingles that are not in ``self.vocab`` are ignored, and a document
        without any known shingle yields no row (an all-zero vector has no
        MinHash signature).

        Returns:
            DataFrame with columns ``id`` and ``features``.
        """
        vocab = self.vocab
        vocab_size = self.vocab_size

        # Inner join: out-of-vocabulary shingles cannot contribute an id.
        shingle_ids = shingle.join(vocab, on='shingle', how='inner') \
                             .groupby('id') \
                             .agg(F.collect_set('shingle_id').alias('shingle_ids'))

        @F.udf(returnType=VectorUDT())
        def one_hot_encode(ids):
            return Vectors.sparse(vocab_size, sorted(ids), [1] * len(ids))

        return shingle_ids.withColumn('features', one_hot_encode(F.col('shingle_ids'))) \
                          .select('id', 'features')

    def shingling(self, documents: DataFrame) -> DataFrame:
        """Shingle the corpus, build the vocabulary from it and return the boolean vectors.

        Side effect: sets ``self.vocab`` (with a broadcast hint) and ``self.vocab_size``.
        """
        shingle = self._explode_shingles(documents)
        self.vocab = F.broadcast(self.build_vocab(shingle))

        return self.multiHotEncoding(shingle)

    def minhashing(self, bool_vectors: DataFrame) -> DataFrame:
        """Append a ``hashes`` column holding the MinHash signature of ``features``.

        The signature is an array of ``numHashTables`` one-element dense
        vectors; entry j is the minimum of ``(a_j * (i + 1) + b_j) % HASH_PRIME``
        over the non-zero indices ``i`` of the feature vector.
        """
        prime = self.HASH_PRIME

        # Private generator with a fixed seed: the coefficients are reproducible
        # without reseeding the process-wide `random` module.
        rng = random.Random(1205)
        # `a` is drawn from [1, prime - 1] so it is never congruent to 0 (mod prime).
        randCoefs = [
            (rng.randint(1, prime - 1), rng.randint(0, prime - 1))
            for _ in range(self.numHashTables)
        ]

        @F.udf(returnType=ArrayType(VectorUDT()))
        def minHash(features):
            indices = features.indices.tolist()
            if not indices:
                raise ValueError("MinHash requires at least one non-zero entry in 'features'")
            return [
                Vectors.dense([min((a * (i + 1) + b) % prime for i in indices)])
                for a, b in randCoefs
            ]

        return bool_vectors.withColumn('hashes', minHash(F.col('features')))

    def locality_sensity_hashing(self, signatures: DataFrame) -> DataFrame:
        """Split each signature into ``numBands`` bands and hash every band to a bucket id.

        Args:
            signatures: DataFrame with a ``hashes`` column as produced by
                :meth:`minhashing`.

        Returns:
            The input with ``hashes`` renamed to ``signature`` and a new
            ``bucketID`` column (array with one integer per band).

        Raises:
            ValueError: if ``numBands`` is not within ``[1, numHashTables]``.
        """
        HASH_PRIME = self.HASH_PRIME
        numBands = self.numBands
        signa_length = self.numHashTables

        if not 1 <= numBands <= signa_length:
            raise ValueError(
                f"numBands must be between 1 and numHashTables ({signa_length}), got {numBands}"
            )

        # Band b covers signature[bounds[b]:bounds[b + 1]]. This yields exactly
        # numBands bands; when numHashTables is not a multiple of numBands the
        # band sizes differ by at most one.
        bounds = [b * signa_length // numBands for b in range(numBands + 1)]

        @F.udf(ArrayType(IntegerType()))
        def hash_bands(signature):
            hashed_bands = []
            for band_no in range(numBands):
                band = tuple(signature[bounds[band_no]:bounds[band_no + 1]])
                # The band position is part of the hashed key, so equal bucket
                # ids mean "same values in the same band" rather than a match
                # between two different bands.
                hashed_bands.append(hash((band_no,) + band) % HASH_PRIME)
            return hashed_bands

        return signatures.withColumn('bucketID', hash_bands(F.col('hashes'))) \
                         .withColumnRenamed('hashes', 'signature')

    def run(self):
        """Run the full pipeline on the corpus and materialise the results.

        Populates ``self.bool_vectors`` and ``self.buckets`` (both persisted
        with MEMORY_AND_DISK). Returns nothing.
        """
        self.bool_vectors = self.shingling(self.documents).persist(StorageLevel.MEMORY_AND_DISK)

        signature = self.minhashing(self.bool_vectors)

        self.buckets = self.locality_sensity_hashing(signature).persist(StorageLevel.MEMORY_AND_DISK)
        # Force evaluation so later queries hit the persisted data.
        self.buckets.count()

    def approxNearestNeighbors(self, key: str, numLim: int) -> DataFrame:
        """Return the ``numLim`` indexed documents most similar to the text ``key``.

        Candidates are the documents sharing at least one bucket id with the
        query; they are ranked by Jaccard similarity of their shingle sets.
        Query shingles that are not in the corpus vocabulary are ignored.

        Args:
            key: query text.
            numLim: maximum number of rows to return.

        Returns:
            DataFrame with columns ``id``, ``features`` and ``distCol``.
            Despite its name, ``distCol`` holds the Jaccard *similarity*
            (higher means closer); rows are sorted by it in descending order.

        Raises:
            RuntimeError: if :meth:`run` has not been called yet.
        """
        if self.buckets is None or self.vocab is None:
            raise RuntimeError("run() must be called before approxNearestNeighbors()")

        # Use the session that owns the corpus rather than a notebook global.
        session = self.documents.sparkSession
        key_docs = self.processDocuments(session.createDataFrame([(key,)], ["value"]))

        # Same pipeline as the corpus: shingling -> MinHash -> LSH.
        key_vectors = self.multiHotEncoding(self._explode_shingles(key_docs))
        key_signature = self.minhashing(key_vectors)
        # Only the renamed columns are kept, so the cross join below does not
        # carry two columns called 'signature'.
        key_bucket = self.locality_sensity_hashing(key_signature).select(
            F.col('id').alias('key_id'),
            F.col('features').alias('key_features'),
            F.col('bucketID').alias('key_bucketID'),
        )

        # Find pairs
        candidatePairs = self.buckets.crossJoin(F.broadcast(key_bucket)) \
                                     .filter(F.arrays_overlap(F.col('bucketID'), F.col('key_bucketID')))

        @F.udf(FloatType())
        def Jaccard_Similarity(candidate, key):
            # The non-zero indices of a sparse vector are its shingle-id set.
            candidate_ids = set(candidate.indices)
            key_ids = set(key.indices)

            union = len(candidate_ids | key_ids)
            if union == 0:
                return 0.0
            return len(candidate_ids & key_ids) / union

        scored = candidatePairs.withColumn(
            'distCol', Jaccard_Similarity(F.col('features'), F.col('key_features'))
        )

        # Return the top n nearest neighbors
        return scored.select('id', 'features', 'distCol') \
                     .sort(F.col('distCol').desc()) \
                     .limit(numLim)

In [7]:
mh = MinHashLSH(documents)

# run() returns None (its results are stored on mh.bool_vectors / mh.buckets),
# so the return value is not bound to a name.
# perf_counter() is a monotonic clock meant for measuring elapsed time.
start_time = time.perf_counter()
mh.run()
end_time = time.perf_counter()

running_time = end_time - start_time
running_time

405.8739447593689

In [8]:
# Query 1: a short three-line excerpt; the same sentences also occur inside
# the longer `document2` query defined further down.
document1 = """
Australian pathosystems, providing the first comprehensive compilation of information
for this continent, covering the phytoplasmas, host plants, vectors and diseases.
Of the 33 16Sr groups reported internationally
"""

In [9]:
# Top-3 candidates for the short query. Despite its name, `distCol` holds the
# Jaccard similarity (higher means more similar), sorted in descending order.
result_1 = mh.approxNearestNeighbors(document1, numLim = 3)
result_1.show()

+----+--------------------+-----------+
|  id|            features|    distCol|
+----+--------------------+-----------+
|   0|(103134,[706,1228...| 0.17419963|
|5473|(103134,[2604,260...|0.093147755|
|3527|(103134,[2725,273...| 0.08979592|
+----+--------------------+-----------+



In [10]:
# Query 2: a much longer passage; it contains the sentences used in `document1`.
document2 = """
Phytoplasmas are insect-vectored bacteria that cause disease in a wide range of plant species.
The increasing availability of molecular DNA analyses, expertise and additional methods in recent years
has led to a proliferation of discoveries of phytoplasma-plant host associations and in
the numbers of taxonomic groupings for phytoplasmas. The widespread use of common names based on
the diseases with which they are associated, as well as separate phenetic and taxonomic systems
for classifying phytoplasmas based on variation at the 16S rRNA-encoding gene, complicates interpretation of the literature.
We explore this issue and related trends through a focus on Australian pathosystems, providing the first comprehensive compilation of information
for this continent, covering the phytoplasmas, host plants, vectors and diseases.
Of the 33 16Sr groups reported internationally, only groups I, II, III, X, XI and XII have been recorded in Australia
and this highlights the need for ongoing biosecurity measures to prevent the introduction of additional pathogen groups.
Many of the phytoplasmas reported in Australia have not been sufficiently well studied
to assign them to 16Sr groups so it is likely that unrecognized groups and sub-groups are present.
Wide host plant ranges are apparent among well studied phytoplasmas,
with multiple crop and non-crop species infected by some. Disease management is further complicated by the fact
that putative vectors have been identified for few phytoplasmas, especially in Australia.
Despite rapid progress in recent years using molecular approaches, phytoplasmas remain
the least well studied group of plant pathogens, making them a "crouching tiger" disease threat.
"""

In [11]:
# Top-10 candidates for the long query; `distCol` is the Jaccard similarity
# (higher means more similar), sorted in descending order.
result_2 = mh.approxNearestNeighbors(document2, numLim = 10)
result_2.show()

+----+--------------------+----------+
|  id|            features|   distCol|
+----+--------------------+----------+
|   0|(103134,[706,1228...| 0.9830508|
|2532|(103134,[729,2604...|0.20167653|
|1494|(103134,[20,605,9...|0.19393939|
|5081|(103134,[2603,261...|0.19218911|
|3778|(103134,[256,2459...|0.18880779|
|2234|(103134,[729,2607...|0.18124643|
| 994|(103134,[2604,261...|0.18037602|
|3193|(103134,[2604,260...| 0.1797235|
|3553|(103134,[2603,260...|0.17739318|
|5043|(103134,[882,1714...|0.17715618|
+----+--------------------+----------+

