# Exercise 2

## Imports

In [52]:
# Shingle size
k = 9

# Number of bands
b = 13

# Number of rows per band
r = 11

# Min-hash: number of hash functions
num_functions = b*r

# Seed for the random number generator
seed = 123

# Similarity threshold
similarity_threshold = 0.85

In [53]:
random.seed(seed)

## Spark Initialization

In [54]:
spark = SparkSession.builder \
    .appName('LSH') \
    .config('spark.master', 'local[*]') \
    .getOrCreate()

## Prepare the Data

In [55]:
# TODO: Configure partitions for speedup?
df = spark.read \
    .option('header', True) \
    .json('./data/covid_news_small.json.bz2')

                                                                                

## Pipeline

### Generate shingles

In [56]:
# TODO: ignore punctuation? use different shingling strategy? (see 'Further fun' slide, which is the last, of 3b)
@F.udf(returnType=ArrayType(IntegerType(), False))
def generate_shingles(text: str):
    shingles = (text[idx:idx+k] for idx in range(len(text) - k + 1))
    # Get last 32 bits in order to have 4-byte integers (Python allows arbitrarily large integers)
    to_integer = lambda s: hash(s) & ((1 << 32) - 1)
    return sorted(set(to_integer(shingle_str) for shingle_str in shingles))

In [57]:
df_shingles = df \
    .drop('url') \
    .withColumn('shingles', generate_shingles('text')) \
    .drop('text')

### Min-hash

In [58]:
# Assumes the values to hash are 4-byte integers
def generate_universal_hash_family(K: int) -> List[Callable[[int], int]]:
    N = 1 << 32
    p = 2305843009213693951

    parameters = set()
    while (len(parameters) < K):
        parameters |= {(random.randint(1, N), random.randint(0, N)) for _ in range(K - len(parameters))}
    
    return [partial(lambda x, a, b, p, N: ((a * x + b) % p) % N, a=a, b=b, p=p, N=N) for a, b in parameters]

In [59]:
hash_family = generate_universal_hash_family(num_functions)

In [60]:
@F.udf(returnType=ArrayType(IntegerType(), False))
def calculate_min_hash(shingles: List[int]):
    return [min(h(shingle) for shingle in shingles) for h in hash_family]

In [61]:
df_minhash = df_shingles.withColumn('min_hash', calculate_min_hash('shingles')).drop('shingles')

### LSH

In [62]:
@F.udf(returnType=ArrayType(ArrayType(IntegerType(), False), False))
def generate_even_slices(minhashes: List[int]):
    return [minhashes[i:i+b] for i in range(0, num_functions, b)]

In [63]:
df_bands = df_minhash \
    .withColumn('min_hash_slices', generate_even_slices('min_hash')) \
    .select('tweet_id', *(F.hash(F.col('min_hash_slices')[band]).alias(f'band_{band}') for band in range(b))) 

In [64]:
df_bands_lst = [
    df_bands
        .select('tweet_id', f'band_{band}')
        .groupby(f'band_{band}')
        .agg(F.collect_list('tweet_id'))
        .withColumnRenamed('collect_list(tweet_id)', 'candidates')
        .withColumn('candidates', F.array_sort('candidates'))
    for band in range(b)
]

In [65]:
@F.udf(returnType=ArrayType(ArrayType(StringType(), False), False))
def combine_pairs(elems: Iterable[Any]):
    return list(combinations(elems, 2))

In [66]:
df_bands_lst = [
    d.select(F.explode(combine_pairs('candidates')).alias('candidate_pair'))
    for d in df_bands_lst
]

In [67]:
df_candidate_pairs = spark.createDataFrame([], schema=StructType([StructField(name='candidate_pair', dataType=ArrayType(StringType(), False), nullable=False)]))

for d in df_bands_lst:
    df_candidate_pairs = df_candidate_pairs.union(d)

df_candidate_pairs = df_candidate_pairs.distinct()

Remove false positives.

In [70]:
df_candidate_pairs = df_candidate_pairs \
    .join(df_minhash, df_minhash['tweet_id'] == F.col('candidate_pair')[0]) \
    .withColumnRenamed('min_hash', 'min_hash_0') \
    .drop('tweet_id') \
    .join(df_minhash, df_minhash['tweet_id'] == F.col('candidate_pair')[1]) \
    .withColumnRenamed('min_hash', 'min_hash_1') \
    .drop('tweet_id') \
    .withColumn('similarity', F.size(F.array_intersect('min_hash_0', 'min_hash_1')) / F.size(F.array_union('min_hash_0', 'min_hash_1'))) \
    .filter(F.col('similarity') >= similarity_threshold)

Save results.

In [71]:
fname_candidate_pairs = f'candidate_pairs_{r}_{b}'
if not os.path.exists(fname_candidate_pairs):
    df_candidate_pairs.write.mode('overwrite').parquet(path=fname_candidate_pairs, compression='gzip')

df_candidate_pairs = spark.read.parquet(fname_candidate_pairs)

[Stage 4:>    (0 + 2) / 2][Stage 5:>    (0 + 2) / 2][Stage 6:>    (0 + 2) / 2]2]

23/03/18 18:08:29 ERROR Executor: Exception in task 1.0 in stage 4.0 (TID 9)
org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/tmp/ipykernel_8845/1550504526.py", line 3, in calculate_min_hash
  File "/tmp/ipykernel_8845/1550504526.py", line 3, in <listcomp>
ValueError: min() arg is an empty sequence

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:552)
	at org.apache.spark.sql.execution.python.PythonUDFRunner$$anon$2.read(PythonUDFRunner.scala:86)
	at org.apache.spark.sql.execution.python.PythonUDFRunner$$anon$2.read(PythonUDFRunner.scala:68)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:505)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:491)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.has

[Stage 4:>    (0 + 1) / 2][Stage 5:>    (0 + 1) / 2][Stage 6:>    (0 + 2) / 2]

23/03/18 18:08:29 WARN TaskSetManager: Lost task 1.0 in stage 5.0 (TID 11) (localhost executor driver): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/tmp/ipykernel_8845/1550504526.py", line 3, in calculate_min_hash
  File "/tmp/ipykernel_8845/1550504526.py", line 3, in <listcomp>
ValueError: min() arg is an empty sequence

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:552)
	at org.apache.spark.sql.execution.python.PythonUDFRunner$$anon$2.read(PythonUDFRunner.scala:86)
	at org.apache.spark.sql.execution.python.PythonUDFRunner$$anon$2.read(PythonUDFRunner.scala:68)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:505)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:491)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.colle

PythonException: 
  An exception was thrown from the Python worker. Please see the stack trace below.
Traceback (most recent call last):
  File "/tmp/ipykernel_8845/1550504526.py", line 3, in calculate_min_hash
  File "/tmp/ipykernel_8845/1550504526.py", line 3, in <listcomp>
ValueError: min() arg is an empty sequence


23/03/18 18:08:30 WARN TaskSetManager: Lost task 1.0 in stage 6.0 (TID 13) (localhost executor driver): TaskKilled (Stage cancelled)


[Stage 4:>    (0 + 1) / 2][Stage 5:>    (0 + 1) / 2][Stage 6:>    (0 + 1) / 2]

23/03/18 18:08:32 WARN PythonUDFRunner: Incomplete task 0.0 in stage 6 (TID 12) interrupted: Attempting to kill Python Worker
23/03/18 18:08:32 WARN PythonUDFRunner: Incomplete task 0.0 in stage 5 (TID 10) interrupted: Attempting to kill Python Worker
23/03/18 18:08:32 WARN PythonUDFRunner: Incomplete task 0.0 in stage 4 (TID 8) interrupted: Attempting to kill Python Worker
23/03/18 18:08:32 WARN TaskSetManager: Lost task 0.0 in stage 5.0 (TID 10) (localhost executor driver): TaskKilled (Stage cancelled)
23/03/18 18:08:32 WARN TaskSetManager: Lost task 0.0 in stage 6.0 (TID 12) (localhost executor driver): TaskKilled (Stage cancelled)
23/03/18 18:08:32 WARN TaskSetManager: Lost task 0.0 in stage 4.0 (TID 8) (localhost executor driver): TaskKilled (Stage cancelled)


Get similar articles.

In [None]:
def get_similar_articles(tweet_id: str) -> List[str]:
    rows = df_candidate_pairs \
        .filter(F.col('candidate_pair').contains(tweet_id)) \
        .select(F.array_remove('candidate_pair', tweet_id)) \
        .select(F.col('candidate_pair')[0].alias('similar_article')) \
        .collect()

    return [row.similar_article for row in rows]