In [24]:
"""
The traditional LSH approach to a hypothetical many-to-many document similarity task.
The objective is to bucket similar documents together. The implementation is done through
the `LSH` class, which leverages `dask.bag` functionality and methods to
parallelize the banding technique, specifically the map (hash function) and reduce
(bucketing) tasks.

Note: importing the model automatically initializes a dask client.

BY: Mike Dorosan, 2022
"""

import dask.bag as db

import numpy as np
import matplotlib.pyplot as plt


class LSH():
    """The LSH class for a many-to-many document similarity task.

    The signature of every document is cut into equally sized bands; each
    band is hashed and documents whose band hashes collide are grouped
    into the same bucket.

    Attributes
    ----------
    signature : 2-D np.array or dask.bag.Bag
        document minhash signatures, either an array with dimension
        n (samples) by m (signature size) or a bag of
        (set/doc index, signature) tuples
    bands : int
        number of bands
    r : int
        number of rows per band derived from bands
    hash_functions : list, default=None
        a list of hash functions with size equivalent to the number of
        bands (or a single function shared by all bands). If None, the
        native python hash function is applied.
    band_dict : dict
        dictionary with band labels as keys and `dask.bag` of
        (set/doc index, signature band) tuples as values
    band_buckets : dict
        dictionary with band labels as keys and, as values, the
        (hash value, list of document indices) pairs of that band; a
        `dask.bag` when lazy, a list once computed

    Methods
    -------
    make_bands(bands)
        Slice the signatures into `bands` bands.
    get_buckets(hash_functions=None, compute=False)
        Hash every band and group document indices by hash value.
    plot_thresh(display_thresh=True, ax=None, **kwargs)
        Plot the S-curve for the chosen number of bands and rows.
    """

    def __init__(self, signature):
        """Initialize class

        Parameters
        ----------
        signature : 2-D np.array, or dask.bag
            document minhash signatures with dimension n (samples) by
            m (signature size), or a dask.bag of
            (set/doc index, signature) tuples
        """
        self.signature = signature
        self.bands = None  # number of bands
        self.r = None  # rows per band, band size
        self.hash_functions = None
        self.band_dict = {}
        self.band_buckets = {}

    @staticmethod
    def _rows_per_band(signature_size, bands):
        """Return the band width when a signature is cut into `bands` bands.

        Raises
        ------
        ValueError
            if `bands` is smaller than 1 or does not divide
            `signature_size` evenly.
        """
        # Explicit raises instead of `assert`, which is stripped under -O.
        if bands < 1:
            raise ValueError("Number of bands must be a positive integer.")
        if signature_size % bands != 0:
            raise ValueError("Number of bands not a factor of signature size.")
        return int(signature_size // bands)

    def make_bands(self, bands):
        """Takes in the desired number of `bands` as a parameter and returns
        a dictionary with band labels as keys and `dask.bag` of (set/document
        index, signature band) tuples

        Parameters
        ----------
        bands : int
            desired number of bands

        Returns
        -------
        band_dict : dict
            dictionary with band labels as keys and `dask.bag` of
            (set/doc index, signature band) tuples as values

        Raises
        ------
        TypeError
            if the input signature is neither a numpy array nor a dask bag.
        ValueError
            if `bands` is not a positive factor of the signature size, if
            the array is not 2-D, or if no signature can be read from the
            bag.
        """
        signature = self.signature

        if isinstance(signature, np.ndarray):
            print("Input signature a numpy array.")
            if signature.ndim != 2:
                raise ValueError(
                    "Input signature array must be 2-D "
                    "(samples by signature size).")
            n_docs, signature_size = signature.shape
            rows = self._rows_per_band(signature_size, bands)

            def build_band(start, stop):
                # Pair every row (document) with its index; the index range
                # is the number of documents, not the signature size.
                return db.from_sequence(
                    list(zip(range(n_docs), signature[:, start:stop])),
                    npartitions=1)

        elif isinstance(signature, db.core.Bag):
            print("Input signature a dask bag.")
            head = signature.take(1)  # peek at one record for the size
            if not head:
                raise ValueError(
                    "Could not read a signature from the input bag.")
            signature_size = len(head[0][1])
            rows = self._rows_per_band(signature_size, bands)

            def build_band(start, stop):
                # `start`/`stop` are locals of this call, so each lazily
                # evaluated lambda keeps its own slice bounds and does not
                # capture `self`.
                return signature.map(
                    lambda x: (x[0], np.array(x[1][start:stop])))

        else:
            raise TypeError(
                "Input signature not a dask.bag.core.Bag or a numpy.ndarray")

        self.bands = bands
        self.r = rows
        # Rebuild from scratch so bands of a previous call do not linger.
        self.band_dict = {
            band_label: build_band(start, start + rows)
            for band_label, start in enumerate(
                range(0, signature_size, rows))
        }

        return self.band_dict

    def get_buckets(self, hash_functions=None, compute=False):
        """This method implements the map-reduce step of the traditional
        banding technique. Specifically, signature slices of each band are
        hashed using `hash_functions` (map). The document indices are then
        grouped according to their hash values (reduce).

        Parameters
        ----------
        hash_functions : list, default=None
            a list of hash functions with size equivalent to the number of
            bands, or a single-element list whose function is shared by
            all bands. Each function receives the band slice as bytes.
            If None, the native python hash function is applied.
        compute : bool, default=False
            if True, every band's bag is computed and stored as a list;
            otherwise the lazy `dask.bag` is stored.

        Returns
        -------
        band_buckets : dict
            dictionary with band labels as keys and the
            (hash value, list of document indices) pairs of that band as
            values

        Raises
        ------
        ValueError
            if more than one hash function is given but fewer than the
            number of bands.
        """
        # NOTE(review): the built-in hash() of bytes is salted per process
        # unless PYTHONHASHSEED is fixed -- confirm all workers share a seed
        # or pass explicit, deterministic hash functions.
        self.hash_functions = hash_functions or [hash]
        n_hashers = len(self.hash_functions)
        if 1 < n_hashers < len(self.band_dict):
            raise ValueError(
                "Number of hash functions must be 1 or at least the "
                "number of bands.")

        # Start clean so buckets of an earlier banding are not returned.
        self.band_buckets = {}

        for index, (band_label, band_bag) in enumerate(
                self.band_dict.items()):
            hasher = self.hash_functions[index if n_hashers > 1 else 0]
            grouped = (
                band_bag
                # Bind `hasher` as a default argument: the bag is lazy, so a
                # plain closure would see only the last loop value.
                .map(lambda x, hasher=hasher:
                     (x[0], hasher(x[1].tobytes())))
                .groupby(lambda x: x[1])  # groupby hash value
                # keep only the document indices of each group
                .map(lambda x: (x[0], [doc for doc, _ in x[1]]))
            )
            self.band_buckets[band_label] = (
                grouped.compute() if compute else grouped)

        return self.band_buckets

    def _prob_of_s(self, s):
        """Return the probability of similarity s given b and r"""
        return 1 - (1 - s**self.r)**self.bands

    def _get_approx_thresh(self):
        """Return approximate similarity threshold for chosen b and r"""
        return (1 / self.bands) ** (1 / self.r)

    def plot_thresh(self, display_thresh=True, ax=None, **kwargs):
        """Plots the threshold plot according to number of bands.

        Parameters
        ----------
        display_thresh : bool, default=True
            whether to display emphasis on the similarity threshold or not.

        ax : matplotlib.pyplot Axis, default=None
            Axis for plotting. If None, use internally generated Axis object.

        **kwargs : keyword arguments for the matplotlib.pyplot.plot() function.

        Returns
        -------
        ax : matplotlib.pyplot Axis object
        """
        s_list = np.linspace(0, 1, num=50)
        p_list = np.array([self._prob_of_s(s) for s in s_list])

        if ax is None:
            _, ax = plt.subplots(figsize=(10, 5))

        ax.plot(s_list, p_list, **kwargs)

        if display_thresh:
            thresh = self._get_approx_thresh()
            ax.axvline(thresh, color='black', linestyle='--',
                       label=f'Similarity Threshold: {thresh:.2f}')

        ax.set_title('Probability of becoming a candidate given a similarity\nThe S-curve',
                     fontsize=15)
        ax.set_ylabel('Probability', fontsize=13)
        ax.set_xlabel('Jaccard Similarity of Documents', fontsize=13)
        # Only draw a legend when there is something labelled to show.
        handles, _ = ax.get_legend_handles_labels()
        if handles:
            ax.legend()
        self.ax_ = ax
        return self.ax_


In [2]:
from sklearn.datasets import fetch_20newsgroups

# Load the news group dataset without the headers, footers and quotes
# (training subset only).
newsgroup = fetch_20newsgroups(
    subset='train', remove=('headers', 'footers', 'quotes'))
# Keep just the document texts; they are cleaned as strings further below.
newsgroup_data = newsgroup['data']

In [3]:
from dask.distributed import Client
import dask.bag as db

# Start a dask.distributed client with default settings. The recorded output
# below shows the HTTP server falling back to another port, which suggests a
# cluster was already running in this session.
client = Client()

Perhaps you already have a cluster running?
Hosting the HTTP server on port 46863 instead


In [4]:
# Wrap the corpus in a dask bag of (document index, raw text) tuples.
newsgroup_bag = db.from_sequence(
    zip(range(len(newsgroup_data)), newsgroup_data))

In [5]:
import re


def clean_text(text):
    """Normalize a raw document string.

    Characters that are neither word characters nor whitespace are dropped,
    every whitespace character is turned into a plain space, and the result
    is lower-cased.
    """
    without_symbols = re.sub(r'[^\w\s]', r'', text)
    spaced = re.sub(r'\s', r' ', without_symbols)
    return spaced.lower()

In [6]:
# Lazily clean every document, keeping its index: (index, cleaned text).
newsgroup_bag_cleaned = newsgroup_bag.map(lambda x: (x[0], clean_text(x[1])))

In [7]:
import sys
# NOTE(review): hard-coded, user-specific checkout path; the import below
# only works on a machine where the `alis` package lives at this location.
sys.path.append('/home/mdorosan/2022/alis')

from alis.feature_extraction import MinhashLSH

In [8]:
# NOTE(review): parameter semantics are defined by alis' MinhashLSH, which is
# not visible here; presumably num_hash is the signature length (10), which
# would match the 2 bands x 5 rows used further below -- confirm.
minhasher = MinhashLSH(shingle_size=3, num_shingle_bucket=12, num_hash=10,
                       hash_size=2**12)

In [10]:
# Compute minhash signatures for the cleaned documents; a later cell confirms
# the result is a dask bag.
newsgroup_signatures = minhasher.transform(newsgroup_bag_cleaned)

In [11]:
# Peek at the first 10 signatures; this triggers the actual computation. In
# the recorded run it ended in the deserialization errors and KilledWorker
# shown below.
newsgroup_signatures.take(10)

distributed.protocol.core - CRITICAL - Failed to deserialize
Traceback (most recent call last):
  File "/opt/conda/lib/python3.9/site-packages/distributed/protocol/core.py", line 111, in loads
    return msgpack.loads(
  File "msgpack/_unpacker.pyx", line 195, in msgpack._cmsgpack.unpackb
  File "/opt/conda/lib/python3.9/site-packages/distributed/protocol/core.py", line 103, in _decode_default
    return merge_and_deserialize(
  File "/opt/conda/lib/python3.9/site-packages/distributed/protocol/serialize.py", line 475, in merge_and_deserialize
    return deserialize(header, merged_frames, deserializers=deserializers)
  File "/opt/conda/lib/python3.9/site-packages/distributed/protocol/serialize.py", line 391, in deserialize
    deserialize(
  File "/opt/conda/lib/python3.9/site-packages/distributed/protocol/serialize.py", line 407, in deserialize
    return loads(header, frames)
  File "/opt/conda/lib/python3.9/site-packages/distributed/protocol/serialize.py", line 86, in pickle_loads
  

distributed.protocol.core - CRITICAL - Failed to deserialize
Traceback (most recent call last):
  File "/opt/conda/lib/python3.9/site-packages/distributed/protocol/core.py", line 111, in loads
    return msgpack.loads(
  File "msgpack/_unpacker.pyx", line 195, in msgpack._cmsgpack.unpackb
  File "/opt/conda/lib/python3.9/site-packages/distributed/protocol/core.py", line 103, in _decode_default
    return merge_and_deserialize(
  File "/opt/conda/lib/python3.9/site-packages/distributed/protocol/serialize.py", line 475, in merge_and_deserialize
    return deserialize(header, merged_frames, deserializers=deserializers)
  File "/opt/conda/lib/python3.9/site-packages/distributed/protocol/serialize.py", line 391, in deserialize
    deserialize(
  File "/opt/conda/lib/python3.9/site-packages/distributed/protocol/serialize.py", line 407, in deserialize
    return loads(header, frames)
  File "/opt/conda/lib/python3.9/site-packages/distributed/protocol/serialize.py", line 86, in pickle_loads
  

KilledWorker: ("('lambda-take-a196030acdc87e23370f2c05c41978b7', 0)", <WorkerState 'tcp://10.233.69.154:39385', name: 5, memory: 0, processing: 1>)

In [14]:
# Sanity check: same dask-bag test that LSH.make_bands uses to pick its branch.
type(newsgroup_signatures) == db.core.Bag

True

In [15]:
# Sanity check: the type LSH.make_bands compares against for array input.
type(np.array([1, 2]))

numpy.ndarray

In [25]:
# Run the banding pipeline on the signature bag: split into 2 bands, then
# hash each band and group document indices by hash value (lazy, since
# compute is left at its default of False).
lsh = LSH(newsgroup_signatures)
lsh.make_bands(bands=2)
print("Rows per band: ", lsh.r)
print("Number of bands: ", lsh.bands)
buckets = lsh.get_buckets()
# One entry per band label.
print("Group of buckets: ", len(buckets.keys()))

display(buckets)

Input signature a dask bag.


distributed.protocol.core - CRITICAL - Failed to deserialize
Traceback (most recent call last):
  File "/opt/conda/lib/python3.9/site-packages/distributed/protocol/core.py", line 111, in loads
    return msgpack.loads(
  File "msgpack/_unpacker.pyx", line 195, in msgpack._cmsgpack.unpackb
  File "/opt/conda/lib/python3.9/site-packages/distributed/protocol/core.py", line 103, in _decode_default
    return merge_and_deserialize(
  File "/opt/conda/lib/python3.9/site-packages/distributed/protocol/serialize.py", line 475, in merge_and_deserialize
    return deserialize(header, merged_frames, deserializers=deserializers)
  File "/opt/conda/lib/python3.9/site-packages/distributed/protocol/serialize.py", line 391, in deserialize
    deserialize(
  File "/opt/conda/lib/python3.9/site-packages/distributed/protocol/serialize.py", line 407, in deserialize
    return loads(header, frames)
  File "/opt/conda/lib/python3.9/site-packages/distributed/protocol/serialize.py", line 86, in pickle_loads
  

distributed.protocol.core - CRITICAL - Failed to deserialize
Traceback (most recent call last):
  File "/opt/conda/lib/python3.9/site-packages/distributed/protocol/core.py", line 111, in loads
    return msgpack.loads(
  File "msgpack/_unpacker.pyx", line 195, in msgpack._cmsgpack.unpackb
  File "/opt/conda/lib/python3.9/site-packages/distributed/protocol/core.py", line 103, in _decode_default
    return merge_and_deserialize(
  File "/opt/conda/lib/python3.9/site-packages/distributed/protocol/serialize.py", line 475, in merge_and_deserialize
    return deserialize(header, merged_frames, deserializers=deserializers)
  File "/opt/conda/lib/python3.9/site-packages/distributed/protocol/serialize.py", line 391, in deserialize
    deserialize(
  File "/opt/conda/lib/python3.9/site-packages/distributed/protocol/serialize.py", line 407, in deserialize
    return loads(header, frames)
  File "/opt/conda/lib/python3.9/site-packages/distributed/protocol/serialize.py", line 86, in pickle_loads
  

KilledWorker: ("('lambda-take-d59ef40784b0e980e319c23fcf41151d', 0)", <WorkerState 'tcp://10.233.69.154:46767', name: 1, memory: 0, processing: 1>)