# Assignment 1

## Imports

In [12]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import IntegerType, ArrayType
from itertools import combinations
from typing import Iterable, Any

## Spark initialization

In [13]:
spark = SparkSession.builder \
    .appName('SandboxAssign1') \
    .config('spark.master', 'local[4]') \
    .getOrCreate()

## Prepare the data

In [14]:
df = spark.read \
    .option('header', True) \
    .csv('conditions.csv.gz') \
    .drop('START', 'STOP', 'ENCOUNTER') \
    .withColumn('CODE', col('CODE').cast(IntegerType()))

In [15]:
code_description_df = df \
    .select('CODE', 'DESCRIPTION') \
    .distinct()

In [16]:
df = df.drop('DESCRIPTION').distinct()

## A-priori algorithm

In [17]:
support_threshold = 1000

### First pass

In [38]:
frequent_diseases_k1 = df.select('CODE', lit(1).alias('COUNT')) \
    .groupBy('CODE') \
    .sum('COUNT') \
    .withColumnRenamed('sum(count)', 'COUNT') \
    .filter(col('COUNT') >= support_threshold)

IndentationError: unexpected indent (3005710820.py, line 6)

In [19]:
frequent_diseases_k1.write.mode('overwrite').option('header', True).csv(path='frequent_diseases_k1', compression='gzip')

[Stage 16:>                                                         (0 + 4) / 5]

23/02/24 19:32:09 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
23/02/24 19:32:09 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
23/02/24 19:32:09 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
23/02/24 19:32:09 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
23/02/24 19:32:09 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
23/02/24 19:32:09 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
23/02/24 19:32:09 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
23/02/24 19:32:09 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.


                                                                                

In [20]:
frequent_diseases_k1_set = {r.CODE for r in frequent_diseases_k1.select('CODE').collect()}

[Stage 22:>                                                         (0 + 4) / 5]

23/02/24 19:32:41 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
23/02/24 19:32:41 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
23/02/24 19:32:41 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
23/02/24 19:32:41 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
23/02/24 19:32:41 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
23/02/24 19:32:41 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
23/02/24 19:32:41 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
23/02/24 19:32:41 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.


                                                                                

### Second pass

In [21]:
@udf(returnType=ArrayType(ArrayType(IntegerType()), False))
def combine_pairs(elems: Iterable[Any]):
    return list(combinations(elems, 2))

In [28]:
frequent_diseases_k2 = df \
    .filter(col('CODE').isin(frequent_diseases_k1_set)) \
    .groupBy('PATIENT') \
    .agg(collect_list('CODE')) \
    .withColumn('CODE', combine_pairs('collect_list(CODE)')) \
    .select(explode('CODE').alias('CODE_PAIR'), lit(1).alias('COUNT')) \
    .groupBy('CODE_PAIR') \
    .sum('COUNT') \
    .withColumnRenamed('sum(COUNT)', 'COUNT') \
    .filter(col('COUNT') >= support_threshold) 

In [30]:
frequent_diseases_k2_set = {tuple(r.CODE_PAIR) for r in frequent_diseases_k2.select('CODE_PAIR').collect()}

                                                                                

In [None]:
frequent_diseases_k2.write.mode('overwrite').option('header', True).csv(path='frequent_diseases_k2', compression='gzip')

In [32]:
len(frequent_diseases_k2_set)

3558

In [31]:
len(frequent_diseases_k1_set)

125

### Third Pass

In [44]:
candidate_triples = frequent_diseases_k1.select('CODE').crossJoin(frequent_diseases_k2.select('CODE_PAIR'))

In [48]:
candidate_triples.show()

[Stage 113:>                (0 + 1) / 1][Stage 115:>                (0 + 3) / 5]

23/02/24 20:45:46 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
23/02/24 20:45:46 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
23/02/24 20:45:46 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
23/02/24 20:45:46 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.


[Stage 115:>                                                        (0 + 4) / 5]

23/02/24 20:45:46 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.


[Stage 115:>                (0 + 4) / 5][Stage 117:>                (0 + 0) / 5]

23/02/24 20:45:47 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
23/02/24 20:45:47 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.




+---------+--------------------+
|     CODE|           CODE_PAIR|
+---------+--------------------+
| 44054006|[10509002, 307731...|
| 87433001|[10509002, 307731...|
|230690007|[10509002, 307731...|
|398254007|[10509002, 307731...|
| 79586000|[10509002, 307731...|
|403192003|[10509002, 307731...|
| 92691004|[10509002, 307731...|
| 10509002|[10509002, 307731...|
|422034002|[10509002, 307731...|
| 75498004|[10509002, 307731...|
|363406005|[10509002, 307731...|
|367498001|[10509002, 307731...|
| 55680006|[10509002, 307731...|
| 82423001|[10509002, 307731...|
| 39848009|[10509002, 307731...|
| 62564004|[10509002, 307731...|
|431856006|[10509002, 307731...|
| 47505003|[10509002, 307731...|
|254837009|[10509002, 307731...|
|196416002|[10509002, 307731...|
+---------+--------------------+
only showing top 20 rows



                                                                                

In [49]:
candidate_triples.select(concat(array('CODE'), 'CODE_PAIR').alias('CODE_TRIPLE')).show()



23/02/24 20:50:44 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
23/02/24 20:50:44 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
23/02/24 20:50:44 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
23/02/24 20:50:44 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
23/02/24 20:50:44 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
23/02/24 20:50:44 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.


[Stage 133:>                (0 + 4) / 5][Stage 136:>                (0 + 0) / 5]

23/02/24 20:50:46 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
23/02/24 20:50:46 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.




+--------------------+
|         CODE_TRIPLE|
+--------------------+
|[44054006, 105090...|
|[87433001, 105090...|
|[230690007, 10509...|
|[398254007, 10509...|
|[79586000, 105090...|
|[403192003, 10509...|
|[92691004, 105090...|
|[10509002, 105090...|
|[422034002, 10509...|
|[75498004, 105090...|
|[363406005, 10509...|
|[367498001, 10509...|
|[55680006, 105090...|
|[82423001, 105090...|
|[39848009, 105090...|
|[62564004, 105090...|
|[431856006, 10509...|
|[47505003, 105090...|
|[254837009, 10509...|
|[196416002, 10509...|
+--------------------+
only showing top 20 rows



                                                                                

In [51]:
frequent_diseases_k3 = df \
    .filter(col('CODE').isin(frequent_diseases_k1_set)) \
    .groupBy('PATIENT') \
    .agg(collect_list('CODE')) \
    .select('collect_list(CODE)') \
    .show()
    



+--------------------+
|  collect_list(CODE)|
+--------------------+
|[70704007, 192127...|
|[15777000, 558220...|
|[10509002, 157770...|
|[72892002, 105090...|
|[15777000, 271737...|
|[198992004, 43878...|
|[72892002, 162864...|
|[195662009, 40055...|
|[162864005, 30773...|
|[59621000, 444814...|
|[88805009, 254837...|
|[70704007, 65363002]|
|[55822004, 444814...|
|[53741008, 653630...|
|[195662009, 15777...|
|[15777000, 271737...|
|[444814009, 10509...|
|[36971009, 128613...|
|[703151001, 39848...|
|[87433001, 429007...|
+--------------------+
only showing top 20 rows



                                                                                

Example output

```
...
{Diabetes, Neoplasm} -> {Colon polyp}: 0.2000, ...
...
```