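# Assignment 1

Find the condition codes, code pairs and code triples that at least `support_threshold` patients share in `conditions.csv.gz`, using the A-priori algorithm with PySpark.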

## Imports

In [1]:
import os.path
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import StringType, ArrayType
from itertools import combinations
from typing import Iterable, Any

## Spark initialization

In [2]:
# Local Spark session using every available core; reused if one already exists in this kernel.
spark = (
    SparkSession.builder
    .appName('SandboxAssign1')
    .master('local[*]')
    .getOrCreate()
)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


23/03/02 20:10:54 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
23/03/02 20:10:55 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


## Prepare the data

In [3]:
# Raw conditions table: gzip-compressed CSV whose first line is a header row.
# START / STOP / ENCOUNTER are not used by the analysis, so they are dropped up front.
df = (
    spark.read
    .csv('conditions.csv.gz', header=True)
    .drop('START', 'STOP', 'ENCOUNTER')
)

In [4]:
# CODE -> DESCRIPTION lookup (one row per distinct pair), captured before DESCRIPTION is
# dropped from `df`. Presumably used later to print itemsets with readable names — TODO confirm.
code_description_df = df.select('CODE', 'DESCRIPTION').distinct()

In [5]:
# De-duplicate after dropping DESCRIPTION so a condition recorded several times for the same
# patient contributes one row. Assumes the remaining columns are PATIENT and CODE — verify against
# the CSV header; any other leftover column would take part in the de-duplication too.
df = df.drop('DESCRIPTION').dropDuplicates()

## A-priori algorithm

In [6]:
# Minimum support count: an itemset is kept by each pass only when its COUNT (number of
# per-patient rows/baskets containing it) is at least this value.
support_threshold = 1_000

### First pass

In [7]:
# First pass: for each condition CODE, count the rows of the de-duplicated `df` it appears in,
# and keep the codes whose count reaches `support_threshold`.
# The result is cached as gzip Parquet so re-running the notebook skips the Spark job.
# The cache is trusted only when the `_SUCCESS` marker (written by Spark's default Hadoop output
# committer when a write job commits) is present. A directory left behind by an interrupted write
# has no marker, so it is recomputed instead of being read back as a partial result.
# NOTE: the cache is not keyed on `support_threshold`; delete the directory after changing it.
frequent_diseases_k1_path = 'frequent_diseases_k1'

if not os.path.exists(os.path.join(frequent_diseases_k1_path, '_SUCCESS')):
    frequent_diseases_k1 = (
        df
        .groupBy('CODE')
        .count()
        .withColumnRenamed('count', 'COUNT')
        .filter(col('COUNT') >= support_threshold)
    )

    frequent_diseases_k1.write.mode('overwrite').parquet(path=frequent_diseases_k1_path, compression='gzip')

frequent_diseases_k1 = spark.read.parquet(frequent_diseases_k1_path)

In [8]:
# Frequent single codes collected to the driver; the next passes only consider codes in this set
# (A-priori: an itemset containing an infrequent code cannot be frequent).
frequent_diseases_k1_set = set(
    row['CODE'] for row in frequent_diseases_k1.select('CODE').collect()
)

In [9]:
# Number of frequent single codes; the recorded run (support_threshold = 1000) shows 131.
len(frequent_diseases_k1_set)   # 131

131

### Second pass

In [10]:
@udf(returnType=ArrayType(ArrayType(StringType(), False), False))
def combine_pairs(elems: Iterable[Any]):
    """Return every unordered pair of distinct items in ``elems``.

    Used as a Spark UDF on one patient's collected list of condition codes.
    The items are de-duplicated and sorted before pairing. As a result, each pair is
    emitted in one canonical (ascending) order, so ``(a, b)`` and ``(b, a)`` are
    never counted separately. A code that appears twice in the list never yields
    a self-pair ``(a, a)``.

    :param elems: items of one basket (here: the condition codes of one patient).
    :return: list of 2-tuples; Spark stores each one as an array of strings.
    """
    return list(combinations(sorted(set(elems)), 2))

In [12]:
# Optional sanity check of the `combine_pairs` UDF on a tiny in-memory DataFrame.
# Disabled by default because it launches an extra Spark job; set the flag to True to run it.
run_combine_pairs_check = False

if run_combine_pairs_check:
    from pyspark.sql import Row

    # String items, matching the UDF's declared ArrayType(StringType) result.
    df_test = spark.createDataFrame([
        Row(a=0, b=['1', '2', '3', '8']),
        Row(a=1, b=['4', '5', '6']),
        Row(a=2, b=['1', '5', '7']),
    ])

    # Apply the UDF under test, then explode to one pair per row.
    pairs_test = df_test.withColumn('c', explode(combine_pairs('b')))
    pairs_test.show()

    # Inputs are already sorted and distinct, so the expected pairs are the plain 2-combinations.
    observed = sorted((r.a, tuple(r.c)) for r in pairs_test.collect())
    expected = sorted((r.a, pair) for r in df_test.collect() for pair in combinations(r.b, 2))
    assert observed == expected, (observed, expected)

In [13]:
# Second pass: count candidate pairs of frequent codes over per-patient baskets.
# A-priori pruning: only codes in `frequent_diseases_k1_set` can belong to a frequent pair, so all
# other codes are filtered out before the (quadratic per patient) pair generation.
# `array_sort` gives every unordered pair a single representation before counting.
# The cache is trusted only when the `_SUCCESS` marker written by Spark's default output committer
# on a committed write is present; an interrupted write leaves none and is recomputed.
# NOTE: the cache is not keyed on `support_threshold`; delete the directory after changing it.
frequent_diseases_k2_path = 'frequent_diseases_k2'

if not os.path.exists(os.path.join(frequent_diseases_k2_path, '_SUCCESS')):
    frequent_diseases_k2 = (
        df
        .filter(col('CODE').isin(frequent_diseases_k1_set))
        .groupBy('PATIENT')
        .agg(collect_list('CODE').alias('CODES'))
        .withColumn('CODE_PAIR', explode(combine_pairs('CODES')))
        .select(array_sort('CODE_PAIR').alias('CODE_PAIR'))
        .groupBy('CODE_PAIR')
        .count()
        .withColumnRenamed('count', 'COUNT')
        .filter(col('COUNT') >= support_threshold)
    )

    frequent_diseases_k2.write.mode('overwrite').parquet(path=frequent_diseases_k2_path, compression='gzip')

frequent_diseases_k2 = spark.read.parquet(frequent_diseases_k2_path)

In [14]:
# Frequent pairs on the driver as tuples (sorted by `array_sort` in the second pass);
# `combine_triples` looks candidate sub-pairs up in this set.
frequent_diseases_k2_set = set(
    tuple(row['CODE_PAIR']) for row in frequent_diseases_k2.select('CODE_PAIR').collect()
)

In [15]:
# Number of frequent code pairs; the recorded run (support_threshold = 1000) shows 2940.
len(frequent_diseases_k2_set)   # 2940

2940

### Third pass

In [16]:
@udf(returnType=ArrayType(ArrayType(StringType(), False), False))
def combine_triples(elems: Iterable[Any]):
    """Return the candidate triples of ``elems`` whose three sub-pairs are all frequent.

    A-priori pruning: a triple can only be frequent if each of its pairs is, so any
    triple with a pair missing from the global ``frequent_diseases_k2_set`` is dropped.

    The items are de-duplicated and sorted first. This makes every generated triple,
    and every pair taken from it, ascending — the same order as the pairs stored in
    ``frequent_diseases_k2_set`` — without relying on the caller to sort the list.

    Candidates are filtered lazily, so the full list of C(n, 3) triples is never built
    in memory for a patient with many codes.

    :param elems: items of one basket (here: the condition codes of one patient).
    :return: list of 3-tuples; Spark stores each one as an array of strings.
    """
    return [
        triple
        for triple in combinations(sorted(set(elems)), 3)
        # combinations() of an ascending triple yields (t0, t1), (t0, t2), (t1, t2).
        if all(pair in frequent_diseases_k2_set for pair in combinations(triple, 2))
    ]

In [17]:
# Third pass: count candidate triples over per-patient baskets of frequent codes.
# Each basket is sorted before `combine_triples`, so the generated triples and their sub-pairs
# are ascending, matching the sorted pairs in `frequent_diseases_k2_set`.
# The cache is trusted only when the `_SUCCESS` marker written by Spark's default output committer
# on a committed write is present; an interrupted write leaves none and is recomputed.
# NOTE: the cache is not keyed on `support_threshold`; delete the directory after changing it.
frequent_diseases_k3_path = 'frequent_diseases_k3'

if not os.path.exists(os.path.join(frequent_diseases_k3_path, '_SUCCESS')):
    frequent_diseases_k3 = (
        df
        .filter(col('CODE').isin(frequent_diseases_k1_set))
        .groupBy('PATIENT')
        .agg(array_sort(collect_list('CODE')).alias('CODES'))
        .withColumn('CODE_TRIPLE', explode(combine_triples('CODES')))
        .groupBy('CODE_TRIPLE')
        .count()
        .withColumnRenamed('count', 'COUNT')
        .filter(col('COUNT') >= support_threshold)
    )

    frequent_diseases_k3.write.mode('overwrite').parquet(path=frequent_diseases_k3_path, compression='gzip')

frequent_diseases_k3 = spark.read.parquet(frequent_diseases_k3_path)

In [18]:
# Frequent triples on the driver as tuples of codes.
frequent_diseases_k3_set = set(
    tuple(row['CODE_TRIPLE']) for row in frequent_diseases_k3.select('CODE_TRIPLE').collect()
)

In [19]:
# Number of frequent code triples; the recorded run (support_threshold = 1000) shows 13395.
len(frequent_diseases_k3_set)   # 13395

13395

### Most frequent pairs and triples

In [20]:
# Ten most frequent pairs by support count (the order among equal counts is not fixed by the key).
frequent_diseases_k2.orderBy(col('COUNT').desc()).take(10)

[Row(CODE_PAIR=['195662009', '444814009'], COUNT=343651),
 Row(CODE_PAIR=['10509002', '444814009'], COUNT=302516),
 Row(CODE_PAIR=['15777000', '271737000'], COUNT=289176),
 Row(CODE_PAIR=['162864005', '444814009'], COUNT=243812),
 Row(CODE_PAIR=['271737000', '444814009'], COUNT=236847),
 Row(CODE_PAIR=['15777000', '444814009'], COUNT=236320),
 Row(CODE_PAIR=['10509002', '195662009'], COUNT=211065),
 Row(CODE_PAIR=['444814009', '59621000'], COUNT=203450),
 Row(CODE_PAIR=['162864005', '195662009'], COUNT=167438),
 Row(CODE_PAIR=['40055000', '444814009'], COUNT=165530)]

In [22]:
# Ten most frequent triples by support count (the order among equal counts is not fixed by the key).
frequent_diseases_k3.orderBy(col('COUNT').desc()).take(10)

[Row(CODE_TRIPLE=['15777000', '271737000', '444814009'], COUNT=192819),
 Row(CODE_TRIPLE=['10509002', '195662009', '444814009'], COUNT=139174),
 Row(CODE_TRIPLE=['15777000', '195662009', '271737000'], COUNT=132583),
 Row(CODE_TRIPLE=['10509002', '15777000', '271737000'], COUNT=115510),
 Row(CODE_TRIPLE=['162864005', '195662009', '444814009'], COUNT=111860),
 Row(CODE_TRIPLE=['195662009', '271737000', '444814009'], COUNT=108560),
 Row(CODE_TRIPLE=['15777000', '195662009', '444814009'], COUNT=108083),
 Row(CODE_TRIPLE=['15777000', '271737000', '59621000'], COUNT=99818),
 Row(CODE_TRIPLE=['10509002', '162864005', '444814009'], COUNT=97384),
 Row(CODE_TRIPLE=['10509002', '271737000', '444814009'], COUNT=94793)]

Example of the expected association-rule output format (the values shown are illustrative):

```
...
{Diabetes, Neoplasm} -> {Colon polyp}: 0.2000, ...
...
```