In [1]:
# !pip install --upgrade pip --quiet
# !pip install pyspark --quiet
# !pip install -U -q PyDrive --quiet
# !pip install numpy pandas --quiet

In [2]:
# import os

# os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-11-openjdk-amd64/"

In [3]:
import platform

# Choose the Spark master, data root and extra Spark settings from the host
# name: hosts whose name starts with "driver" get the cluster settings, any
# other host gets the local ones.
if platform.node().startswith("driver"):
    MASTER = "spark://master:7077"
    FILE_PATH = "hdfs://master:9000/user/tudny/"
    SPARK_CONFIG = {}
else:
    MASTER = "local[*]"
    FILE_PATH = ""
    SPARK_CONFIG = {"spark.driver.memory": "16g", "spark.executor.memory": "16g"}

print(MASTER, FILE_PATH, SPARK_CONFIG)

spark://master:7077 hdfs://master:9000/user/tudny/ {}


In [4]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import (
    ArrayType,
    IntegerType,
    LongType,
    StructType,
    StructField,
)

# Create (or reuse) the Spark session for this notebook, applying the
# host-dependent settings from SPARK_CONFIG on top of the fixed UI port.
spark = (
    SparkSession.builder
    .appName("PDD PZ1")
    .master(MASTER)
    .config(key="spark.ui.port", value="4050")
    .config(map=SPARK_CONFIG)
    .getOrCreate()
)

# Shorthand for the underlying SparkContext (used for broadcasts below).
sc = spark.sparkContext

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/05/14 21:15:21 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [5]:
from datetime import datetime

# Wall-clock start of the run; the final cell subtracts it from the end time.
start_time = datetime.now()

In [6]:
# Load every JSON file under data/fasta/ into one DataFrame; later cells read
# its `name` and `value` columns.
fasta = spark.read.json(f"{FILE_PATH}data/fasta/*.json")



In [7]:
# Window length used by `_lsh` when cutting a sequence into shingles.
SHINGLES = 5
# BANDS * ROWS is the number of min-hash functions generated below.
BANDS = 20
# Number of consecutive min-hash values hashed together into one band.
ROWS = 5

In [8]:
import numpy as np

# Parameters of the min-hash scheme.
BASE = 26  # radix used to turn a shingle into a single integer
BIG_ENOUGH_PRIME = 1_000_000_033  # modulus of the (a * x + b) hash functions
NO_HASH_FUNCTIONS = ROWS * BANDS  # one min-hash value per band row
np.random.seed(42)  # makes the generated hash parameters reproducible


# Positional weights BASE**(SHINGLES-1), ..., BASE**1, BASE**0, broadcast so
# the UDF can read them through `BASE_VECTOR.value`.
BASE_VECTOR = sc.broadcast(
    np.array([BASE**power for power in reversed(range(SHINGLES))])
)


def make_hash_functions(no_hash_functions: int, max_value: int):
    """Draw random ``(a, b)`` parameter pairs for linear hash functions.

    Values come from NumPy's global random generator, so the result depends on
    the seed set before the call.

    Args:
        no_hash_functions: How many ``(a, b)`` pairs to generate.
        max_value: Exclusive upper bound for both parameters.

    Returns:
        A list of ``(a, b)`` tuples with ``1 <= a < max_value`` and
        ``0 <= b < max_value``.
    """
    # All multipliers are drawn first, then all offsets, in that order.
    multipliers = np.random.randint(1, max_value, size=(no_hash_functions,))
    offsets = np.random.randint(0, max_value, size=(no_hash_functions,))
    return [(a, b) for a, b in zip(multipliers, offsets)]


# Generate one (a, b) pair per min-hash function and broadcast the list; `_lsh`
# reads it through `hash_functions_params.value`.
hash_functions_params = sc.broadcast(
    make_hash_functions(NO_HASH_FUNCTIONS, BIG_ENOUGH_PRIME)
)

In [9]:
def _lsh(value: str) -> list[tuple[int, int]]:
    """Compute the LSH band signatures of one sequence.

    The sequence is cut into overlapping shingles of ``SHINGLES`` bytes, each
    shingle is encoded as one integer (its byte values, offset by ``ord("A")``,
    weighted by ``BASE_VECTOR``), the shingles are min-hashed with every
    ``(a, b)`` pair in ``hash_functions_params`` and the resulting signature is
    split into ``BANDS`` bands of ``ROWS`` values each.

    Args:
        value: The sequence text. NOTE(review): presumably upper-case ASCII
            letters, given the ``ord("A")`` offset - confirm against the data.

    Returns:
        A list of ``(band_index, band_hash)`` pairs, one per band, where
        ``band_hash`` is Python's ``hash()`` of the tuple of that band's
        min-hash values. The list is empty when ``value`` is ``None`` or too
        short to contain a single shingle.
    """
    if value is None:
        return []

    codes = (np.frombuffer(value.encode(), dtype=np.uint8) - ord("A")).astype(np.int64)
    shingle_count = len(codes) - SHINGLES + 1
    if shingle_count <= 0:
        # No complete shingle fits, so there is nothing to min-hash.
        return []

    # Row i holds the i-th character of every shingle, so column j is shingle j.
    shingle_columns = np.array(
        [codes[i : i + shingle_count] for i in range(SHINGLES)]
    )
    # One integer per shingle: positional value of its characters.
    indexes = BASE_VECTOR.value @ shingle_columns
    min_hashing = np.array(
        [
            np.min((indexes * a + b) % BIG_ENOUGH_PRIME)
            for a, b in hash_functions_params.value
        ]
    )
    # Band boundaries follow BANDS/ROWS instead of a hard-coded signature
    # length, so changing either constant keeps the banding consistent.
    return [
        (band, hash(tuple(min_hashing[band * ROWS : (band + 1) * ROWS])))
        for band in range(BANDS)
    ]

In [10]:
# Wrap `_lsh` as a Spark UDF returning an array of (band, hash) structs.
# `hash` is declared as LongType because `_lsh` fills it with Python's built-in
# `hash()`, which is a 64-bit value on 64-bit builds and does not fit a 32-bit
# IntegerType field.
lsh = F.udf(
    _lsh,
    ArrayType(
        StructType(
            [StructField("band", IntegerType()), StructField("hash", LongType())]
        )
    ),
)
# One row per (sequence name, band signature): the UDF result is exploded so
# every band of every sequence becomes its own `lsh` struct.
name_lsh_pairs = (
    fasta.withColumn("lsh_ed", lsh(F.col("value")))
    .withColumn("lsh", F.explode(F.col("lsh_ed")))
    .select("name", "lsh")
)

In [11]:
# Candidate pairs: two names that share at least one identical (band, hash)
# struct. Keeping only name_1 < name_2 drops self-pairs and mirrored
# duplicates; the final de-duplication collapses pairs that collide in
# several bands.
all_candidates = (
    name_lsh_pairs.withColumnRenamed("name", "name_1")
    .join(
        name_lsh_pairs.withColumnRenamed("name", "name_2"),
        on=["lsh"],
        how="inner",
    )
    .where(F.col("name_2") > F.col("name_1"))
    .select("name_1", "name_2")
    .dropDuplicates()
)

In [12]:
def n_choose_2(_n: int) -> int:
    """Return the number of unordered pairs that can be formed from ``_n`` items."""
    ordered_pairs = _n * (_n - 1)
    # Every unordered pair is counted twice among the ordered ones.
    return ordered_pairs // 2


# Reference grouping: every column of every row read from the JSON file
# becomes one (group, names) record.
group_definitions = (
    spark.read.json(f"{FILE_PATH}data/group_definition.json")
    .rdd.flatMap(lambda x: x.asDict().items())
    .toDF(["group", "names"])
)

# One row per (group, name) membership.
group_mapping = group_definitions.withColumn("name", F.explode(F.col("names"))).drop(
    "names"
)

# Size of each group and the number of unordered pairs inside it.
group_counts = group_definitions.withColumn("count", F.size(F.col("names"))).withColumn(
    "n_choose_2", F.floor(F.col("count") * (F.col("count") - 1) / 2)
)

# Both totals are taken from one aggregation so `group_counts` is evaluated by
# a single Spark action instead of two.
total_names_count, all_matching_pairs_count = group_counts.agg(
    F.sum("count"), F.sum("n_choose_2")
).collect()[0]
# NOTE(review): treating the sum of group sizes as the number of distinct names
# assumes every name belongs to exactly one group - confirm against the data.
all_pairs_count = n_choose_2(total_names_count)
all_non_matching_pairs_count = all_pairs_count - all_matching_pairs_count

print(f"Total pairs: {all_pairs_count}")
print(f"Matching pairs: {all_matching_pairs_count}")
print(f"Non-matching pairs: {all_non_matching_pairs_count}")

# Two renamed copies of the membership table, one for each side of a pair.
group_mapping_1 = group_mapping.withColumnRenamed("name", "name_1").withColumnRenamed(
    "group", "group_1"
)
group_mapping_2 = group_mapping.withColumnRenamed("name", "name_2").withColumnRenamed(
    "group", "group_2"
)



Total pairs: 315293716
Matching pairs: 31616541
Non-matching pairs: 283677175


                                                                                

In [13]:
# Attach the reference group of each side of every candidate pair
# (`group_1` for `name_1`, `group_2` for `name_2`).
all_candidates_mapped_to_clusters = (
    all_candidates
    .join(group_mapping_1, on=["name_1"], how="inner")
    .join(group_mapping_2, on=["name_2"], how="inner")
)

In [14]:
# Count, in one pass, all candidate pairs and those whose two names share a
# reference group; `when` without `otherwise` yields null for the rest, which
# `count` ignores.
count_calculated, matching_count_calculated = (
    all_candidates_mapped_to_clusters
    .agg(
        F.count("*"),
        F.count(F.when(F.col("group_2") == F.col("group_1"), 1)),
    )
    .collect()[0]
)
non_matching_count_calculated = count_calculated - matching_count_calculated



In [15]:
# Report the LSH candidate counts next to the reference counts and the
# resulting rates. The [Real] labels name the quantity each line prints
# (total / matching / non-matching) rather than repeating one label.
print("=" * 100)
print(f"[Calc] Total count: {count_calculated}")
print(f"[Calc] Matching count: {matching_count_calculated}")
print(f"[Calc] Non-matching count: {non_matching_count_calculated}")
print("=" * 100)
print(f"[Real] Total count: {all_pairs_count}")
print(f"[Real] Matching count: {all_matching_pairs_count}")
print(f"[Real] Non-matching count: {all_non_matching_pairs_count}")
print("=" * 100)
# Share of reference matching pairs that LSH proposed as candidates.
print(f"True positive rate: {matching_count_calculated / all_matching_pairs_count}")
# Share of reference non-matching pairs that LSH wrongly proposed.
print(f"False positive rate: {non_matching_count_calculated / all_non_matching_pairs_count}")
print("=" * 100)

[Calc] Total count: 9316924
[Calc] Matching count: 9025012
[Calc] Non-matching count: 291912
[Real] Total from reference data: 315293716
[Real] Total from reference data: 31616541
[Real] Total from reference data: 283677175
True positive rate: 0.28545222578270024
False positive rate: 0.0010290288600060967


In [16]:
end_time = datetime.now()
time_diff = end_time - start_time

# `timedelta.seconds` holds only the sub-day remainder, so the whole duration
# is taken from total_seconds() to keep runs of a day or longer correct.
elapsed_minutes, elapsed_seconds = divmod(int(time_diff.total_seconds()), 60)
print(f"Execution took {elapsed_minutes} minutes and {elapsed_seconds} seconds")

Execution took 17 minutes and 48 seconds


On my personal PC this takes ~1.5 minutes to execute.