# Assignment 1

## Imports

In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import IntegerType, ArrayType
from itertools import combinations
from typing import Iterable, Any

## Spark initialization

In [2]:
spark = SparkSession.builder \
    .appName('SandboxAssign1') \
    .config('spark.master', 'local[4]') \
    .getOrCreate()

23/02/24 18:47:45 WARN Utils: Your hostname, martinho-SATELLITE-L50-B resolves to a loopback address: 127.0.1.1; using 192.168.1.66 instead (on interface enp8s0)
23/02/24 18:47:45 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
23/02/24 18:47:58 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


## Prepare the data

In [3]:
df = spark.read \
    .option('header', True) \
    .csv('conditions.csv.gz') \
    .drop('START', 'STOP', 'ENCOUNTER') \
    .withColumn('CODE', col('CODE').cast(IntegerType()))

                                                                                

In [4]:
code_description_df = df \
    .select('CODE', 'DESCRIPTION') \
    .distinct()

In [5]:
df = df.drop('DESCRIPTION').distinct()

## A-priori algorithm

In [6]:
support_threshold = 1000

### First pass

In [7]:
frequent_diseases_k1 = df.select('CODE', lit(1).alias('COUNT')) \
    .groupBy('CODE') \
    .sum('COUNT') \
    .withColumnRenamed('sum(count)', 'COUNT') \
    .filter(col('COUNT') >= support_threshold)

In [8]:
frequent_diseases_k1.write.mode('overwrite').option('header', True).csv(path='frequent_diseases_k1', compression='gzip')

                                                                                

In [9]:
frequent_diseases_k1_set = {r.CODE for r in frequent_diseases_k1.select('CODE').collect()}

                                                                                

### Second pass

In [10]:
@udf(returnType=ArrayType(ArrayType(IntegerType()), False))
def combine_pairs(elems: Iterable[Any]):
    return list(combinations(elems, 2))

In [20]:
final_pass_k1 = df \
    .filter(col('CODE').isin(frequent_diseases_k1_set)) \
    .groupBy('PATIENT') \
    .agg(collect_list('CODE')) \
    .withColumn('CODE', combine_pairs('collect_list(CODE)')) \
    .select(explode('CODE').alias('CODE_PAIR'), lit(1).alias('COUNT')) \
    .groupBy('CODE_PAIR') \
    .sum('COUNT') \
    .withColumnRenamed('sum(COUNT)', 'COUNT')



+--------------------+-----+
|           CODE_PAIR|COUNT|
+--------------------+-----+
|[127013003, 23069...| 2609|
|[230690007, 12690...| 2870|
|[15777000, 44465007]|15974|
|[19169002, 79586000]|12985|
|[44054006, 196416...| 2128|
|[444814009, 39921...| 8152|
|[443165006, 26929...| 1887|
|[239872002, 42825...|  773|
|[65966004, 15777000]| 7194|
|[88805009, 16114001]| 1222|
|[162864005, 87433...| 4285|
|[44465007, 44054006]| 3462|
|[162864005, 25463...| 3197|
|[403190006, 12690...|  353|
|[254837009, 58150...|  466|
|[230690007, 82423...| 1666|
|[307731004, 70704...|  356|
|[62106007, 275272...|  207|
|[16114001, 65966004]|  844|
|[84757009, 19169002]| 1467|
+--------------------+-----+
only showing top 20 rows



                                                                                

Example output

```
...
{Diabetes, Neoplasm} -> {Colon polyp}: 0.2000, ...
...
```