# Assignment 1

## Imports

In [66]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import IntegerType, ArrayType
from itertools import combinations
from typing import Iterable, Any

## Spark initialization

In [2]:
spark = SparkSession.builder \
    .appName('SandboxAssign1') \
    .config('spark.master', 'local[4]') \
    .getOrCreate()

23/02/23 10:31:21 WARN Utils: Your hostname, martinho-SATELLITE-L50-B resolves to a loopback address: 127.0.1.1; using 192.168.36.118 instead (on interface wlx200db038271f)
23/02/23 10:31:21 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
23/02/23 10:31:22 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


## Prepare the data

In [52]:
df = spark.read \
    .option('header', True) \
    .csv('conditions.csv.gz') \
    .drop('START', 'STOP', 'ENCOUNTER') \
    .withColumn('CODE', col('CODE').cast(IntegerType()))

In [53]:
code_description_df = df \
    .select('CODE', 'DESCRIPTION') \
    .distinct()

In [54]:
df = df.drop('DESCRIPTION').distinct()

## A-priori algorithm

In [24]:
support_threshold = 1000

### First pass

In [33]:
frequent_diseases_k1 = df.select('CODE', lit(1).alias('COUNT')) \
    .groupBy('CODE') \
    .sum('COUNT') \
    .withColumnRenamed('sum(count)', 'COUNT') \
    .filter(col('COUNT') >= support_threshold)

In [35]:
frequent_diseases_k1.write.option('header', True).csv(path='frequent_diseases_k1', compression='gzip')

                                                                                

In [40]:
frequent_diseases_k1_set = {r.CODE for r in frequent_diseases_k1.select('CODE').collect()}

                                                                                

### Second pass

In [67]:
@udf(returnType=ArrayType(ArrayType(IntegerType()), False))
def combine_pairs(elems: Iterable[Any]):
    return list(combinations(elems, 2))

In [65]:
df \
    .filter(col('CODE').isin(frequent_diseases_k1_set)) \
    .groupBy('PATIENT') \
    .agg(combine_pairs('CODE')) \
    .select(explode('CODE').alias('CODE_PAIR')) \
    .show()

[Stage 91:>                                                         (0 + 1) / 1]

+--------------------+--------------------+
|             PATIENT|   collect_set(CODE)|
+--------------------+--------------------+
|0004846e-17b1-41e...|[39848009, 703151...|
|001ef919-eb0d-416...|[43878008, 105090...|
|001fe56a-7fbb-414...|[428251008, 40319...|
|00227504-0f42-433...|[446096008, 15607...|
|0023510e-a1dc-4c8...|[65966004, 443165...|
|002b39a7-06a9-47f...|[271737000, 10509...|
|0031d0c5-36b4-44f...|[301011002, 16286...|
|0037f64d-5bcc-4fe...|[36971009, 400550...|
|003c9bab-d697-4d0...| [44465007, 1734006]|
|003f1af5-2084-492...|[43878008, 653630...|
|00494652-a2f2-4f5...|[19169002, 684960...|
|004c94f3-afa9-49d...|[230265002, 36749...|
|005558bb-6418-416...|[19169002, 105090...|
|00565bf0-a9ea-4f9...|[36971009, 400550...|
|005d1271-4646-449...|[428251008, 44054...|
|005e8575-2646-42e...|[79586000, 191690...|
|0061ad68-a3fa-46b...|[7200002, 9269100...|
|007faced-6e1d-4ac...|[55680006, 271737...|
|008f6f56-f26e-473...|[19169002, 105090...|
|008fafd3-7642-40c...|[43878008,

                                                                                

Example output

```
...
{Diabetes, Neoplasm} -> {Colon polyp}: 0.2000, ...
...
```