# Assignment 1

## Imports

In [1]:
import os.path
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import StringType, ArrayType
from itertools import combinations
from typing import Iterable, Any

## Spark initialization

In [2]:
# Start (or reuse) a Spark session running in local mode.
spark = (
    SparkSession.builder
    .appName('SandboxAssign1')
    .config('spark.master', 'local[*]')
    .getOrCreate()
)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


23/03/02 18:56:51 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


## Prepare the data

In [3]:
# Load the conditions CSV (first line is the header) and discard the
# columns that the analysis below never reads.
df = (
    spark.read
    .csv('conditions.csv.gz', header=True)
    .drop('START', 'STOP', 'ENCOUNTER')
)

In [4]:
# Lookup table: one row per unique (CODE, DESCRIPTION) combination.
code_description_df = (
    df
    .select('CODE', 'DESCRIPTION')
    .distinct()
)

In [5]:
# DESCRIPTION is preserved in code_description_df; drop it here and
# de-duplicate the remaining rows.
df = (
    df
    .drop('DESCRIPTION')
    .distinct()
)

## A-priori algorithm

In [6]:
# Minimum count an itemset must reach to be kept as frequent.
support_threshold = 1_000

### First pass

In [7]:
# Pass 1: count the rows per condition code and keep the codes whose count
# reaches support_threshold.
#
# The result is cached on disk. NOTE: the cache is keyed only by its path, so
# delete the 'frequent_diseases_k1' directory after changing support_threshold
# or the input data, otherwise the stale result is reused.
if not os.path.exists('frequent_diseases_k1'):
    (
        df
        .groupBy('CODE')
        .count()
        .withColumnRenamed('count', 'COUNT')
        .filter(col('COUNT') >= support_threshold)
        .write.mode('overwrite')
        .parquet(path='frequent_diseases_k1', compression='gzip')
    )

# Always read the materialized result back. Keeping the lazy DataFrame that
# was just written would make every later action (e.g. the collect() that
# builds the code set) recompute the whole aggregation from the CSV.
frequent_diseases_k1 = spark.read.parquet('frequent_diseases_k1')

In [8]:
# Pull the frequent codes to the driver as a plain Python set.
frequent_diseases_k1_set = set(
    row['CODE'] for row in frequent_diseases_k1.select('CODE').collect()
)

In [None]:
# Number of frequent single codes (131 was recorded on a previous run).
len(frequent_diseases_k1_set)   # 131

### Second pass

In [9]:
@udf(returnType=ArrayType(ArrayType(StringType(), False), False))
def combine_pairs(elems: Iterable[Any]):
    """Return every 2-element combination of ``elems``.

    Pairs are emitted in the order produced by ``itertools.combinations``,
    i.e. following the order of the input; they are not sorted here.
    """
    pairs = [pair for pair in combinations(elems, 2)]
    return pairs

In [11]:
# Preview: the collected code list and the generated pairs for one patient.
(
    df
    .filter(col('CODE').isin(frequent_diseases_k1_set))
    .groupBy('PATIENT')
    .agg(collect_list('CODE'))
    .withColumn('CODE_PAIRS', combine_pairs('collect_list(CODE)'))
    .first()
)

                                                                                

Row(PATIENT='0000584e-c2d3-436b-8bd2-16294a3889b3', collect_list(CODE)=['195662009', '128613002', '703151001', '70704007', '65363002', '192127007', '232353008', '444814009'], CODE_PAIRS=[['195662009', '128613002'], ['195662009', '703151001'], ['195662009', '70704007'], ['195662009', '65363002'], ['195662009', '192127007'], ['195662009', '232353008'], ['195662009', '444814009'], ['128613002', '703151001'], ['128613002', '70704007'], ['128613002', '65363002'], ['128613002', '192127007'], ['128613002', '232353008'], ['128613002', '444814009'], ['703151001', '70704007'], ['703151001', '65363002'], ['703151001', '192127007'], ['703151001', '232353008'], ['703151001', '444814009'], ['70704007', '65363002'], ['70704007', '192127007'], ['70704007', '232353008'], ['70704007', '444814009'], ['65363002', '192127007'], ['65363002', '232353008'], ['65363002', '444814009'], ['192127007', '232353008'], ['192127007', '444814009'], ['232353008', '444814009']])

In [12]:
# Keep one example row (a single patient) to sanity-check the pair generation.
df_example = (
    df
    .filter(col('CODE').isin(frequent_diseases_k1_set))
    .groupBy('PATIENT')
    .agg(collect_list('CODE'))
    .withColumn('CODE_PAIRS', combine_pairs('collect_list(CODE)'))
    .first()
)

                                                                                

In [14]:
# Number of frequent codes collected for the example patient.
len(df_example['collect_list(CODE)'])

8

In [13]:
# Number of generated pairs; n codes yield n*(n-1)/2 pairs (8 codes -> 28).
len(df_example.CODE_PAIRS)

28

In [35]:
# Disabled sandbox: shows how explode() turns an array-of-pairs column into
# one row per pair. Flip the guard to True to run it.
if False:
    from pyspark.sql import Row

    sample_lists = [[1, 2, 3, 8], [4, 5, 6], [1, 5, 7]]

    df_test = spark.createDataFrame([
        Row(a=idx, b=items, c=list(combinations(items, 2)))
        for idx, items in enumerate(sample_lists)
    ])

    df_test.withColumn('c', explode('c')).show()

In [17]:
# Pass 2: count co-occurring pairs of frequent codes per patient and keep the
# pairs whose count reaches support_threshold.
#
# The result is cached on disk. NOTE: the cache is keyed only by its path, so
# delete the 'frequent_diseases_k2' directory after changing support_threshold
# or the input data, otherwise the stale result is reused.
if not os.path.exists('frequent_diseases_k2'):
    (
        df
        # Only codes that are frequent on their own can be part of a frequent pair.
        .filter(col('CODE').isin(frequent_diseases_k1_set))
        .groupBy('PATIENT')
        .agg(collect_list('CODE').alias('CODES'))
        .withColumn('CODE_PAIRS', combine_pairs('CODES'))
        # One row per (patient, pair); PATIENT is not needed for the count.
        .select(explode('CODE_PAIRS').alias('CODE_PAIR'))
        # Sort each pair so (a, b) and (b, a) fall into the same group.
        .withColumn('CODE_PAIR', array_sort('CODE_PAIR'))
        .groupBy('CODE_PAIR')
        .count()
        .withColumnRenamed('count', 'COUNT')
        .filter(col('COUNT') >= support_threshold)
        .write.mode('overwrite')
        .parquet(path='frequent_diseases_k2', compression='gzip')
    )

# Always read the materialized result back. Keeping the lazy DataFrame that
# was just written would make every later action recompute the whole
# UDF + explode + aggregation pipeline.
frequent_diseases_k2 = spark.read.parquet('frequent_diseases_k2')

                                                                                

In [20]:
# Pull the frequent pairs to the driver as a set of hashable tuples.
frequent_diseases_k2_set = set(
    tuple(row['CODE_PAIR'])
    for row in frequent_diseases_k2.select('CODE_PAIR').collect()
)

                                                                                

In [21]:
# Number of frequent code pairs (2940 was recorded on a previous run).
len(frequent_diseases_k2_set)   # 2940

2940

### Third pass

In [42]:
@udf(returnType=ArrayType(ArrayType(StringType(), False), False))
def combine_triples(elems: Iterable[Any]):
    """Return the candidate triples of ``elems`` whose pairs are all frequent.

    A 3-element combination is kept only if each of its three 2-element
    subsets is a member of ``frequent_diseases_k2_set``.

    Args:
        elems: the codes of one patient; ``None`` is treated as empty.

    Returns:
        A list of 3-tuples, each in ascending order.
    """
    if elems is None:
        return []

    # frequent_diseases_k2_set stores each pair in sorted order, so the
    # membership tests below only match when the triple is sorted too.
    # Sort here instead of relying on the caller to pre-sort the list;
    # for an already sorted input the result is unchanged.
    ordered = sorted(elems)

    return [
        triple for triple in combinations(ordered, 3)
        if all(pair in frequent_diseases_k2_set for pair in combinations(triple, 2))
    ]

In [43]:
# Pass 3: count co-occurring triples of frequent codes per patient and keep
# the triples whose count reaches support_threshold. Candidate triples are
# pruned inside combine_triples using the frequent pairs.
#
# The result is cached on disk. NOTE: the cache is keyed only by its path, so
# delete the 'frequent_diseases_k3' directory after changing support_threshold
# or the input data, otherwise the stale result is reused.
if not os.path.exists('frequent_diseases_k3'):
    (
        df
        .filter(col('CODE').isin(frequent_diseases_k1_set))
        .groupBy('PATIENT')
        .agg(collect_list('CODE').alias('CODES'))
        # Sorted input gives every triple a canonical element order, so the
        # same itemset always lands in the same group.
        .withColumn('CODES', array_sort('CODES'))
        .withColumn('CODE_TRIPLES', combine_triples('CODES'))
        # One row per (patient, triple); PATIENT is not needed for the count.
        .select(explode('CODE_TRIPLES').alias('CODE_TRIPLE'))
        .groupBy('CODE_TRIPLE')
        .count()
        .withColumnRenamed('count', 'COUNT')
        .filter(col('COUNT') >= support_threshold)
        .write.mode('overwrite')
        .parquet(path='frequent_diseases_k3', compression='gzip')
    )

# Always read the materialized result back. Keeping the lazy DataFrame that
# was just written would make every later action recompute the whole
# UDF + explode + aggregation pipeline.
frequent_diseases_k3 = spark.read.parquet('frequent_diseases_k3')

In [44]:
# Pull the frequent triples to the driver as a set of hashable tuples.
frequent_diseases_k3_set = set(
    tuple(row['CODE_TRIPLE'])
    for row in frequent_diseases_k3.select('CODE_TRIPLE').collect()
)

                                                                                

In [45]:
# Number of frequent code triples (13395 was recorded on a previous run).
len(frequent_diseases_k3_set) # 13395

13395

Example output

```
...
{Diabetes, Neoplasm} -> {Colon polyp}: 0.2000, ...
...
```