# Assignment 1

## Imports

In [None]:
import os.path
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import StringType, ArrayType
from itertools import combinations, chain
from typing import Iterable, Any

## Spark initialization

In [None]:
# Build (or reuse) the Spark session for this notebook; the master is set
# to 'local[*]', i.e. Spark runs locally.
spark = (
    SparkSession.builder
    .appName('SandboxAssign1')
    .config('spark.master', 'local[*]')
    .getOrCreate()
)

## Prepare the data

In [None]:
# Load the gzipped conditions CSV (first line is the header) and discard
# the columns the analysis below never references.
df = (
    spark.read
    .option('header', True)
    .csv('conditions.csv.gz')
    .drop('START', 'STOP', 'ENCOUNTER')
)

In [None]:
# Lookup table with one row per distinct (CODE, DESCRIPTION) combination.
code_description_df = df.select('CODE', 'DESCRIPTION').distinct()

In [None]:
# Drop the description and de-duplicate the remaining columns
# (presumably leaving one row per PATIENT/CODE — TODO confirm the CSV has
# no further columns).
df = df.drop('DESCRIPTION').distinct()

## A-priori algorithm

In [None]:
# Minimum COUNT an itemset needs to be kept by the `>= support_threshold`
# filters of the passes below.
support_threshold = 1000

### First pass

In [None]:
# Compute the frequent single codes only when no cached result is on disk.
if not os.path.exists('frequent_diseases_k1'):
    frequent_diseases_k1 = (
        df
        .groupBy('CODE')
        .count()
        .withColumnRenamed('count', 'COUNT')
        .filter(col('COUNT') >= support_threshold)
    )

    # Persist the result so a later run can skip the aggregation.
    frequent_diseases_k1.write.mode('overwrite').parquet(
        path='frequent_diseases_k1',
        compression='gzip',
    )

# Always continue from the Parquet copy, whether freshly written or cached.
frequent_diseases_k1 = spark.read.parquet('frequent_diseases_k1')

In [None]:
# Collect the frequent codes to the driver as a plain Python set.
frequent_diseases_k1_set = {
    row.CODE
    for row in frequent_diseases_k1.select('CODE').collect()
}

In [None]:
# Number of frequent single codes; the trailing note records the value the
# author observed.
len(frequent_diseases_k1_set)   # 131

### Second pass

In [None]:
@udf(returnType=ArrayType(ArrayType(StringType(), False), False))
def combine_pairs(elems: Iterable[Any]):
    """Return every 2-element combination of ``elems`` as a list.

    The pairs come out in the order ``itertools.combinations`` yields them,
    which follows the input order of ``elems``; nothing is sorted here.
    """
    pairs = combinations(elems, 2)
    return list(pairs)

In [None]:
# Disabled scratch check: shows what exploding a column of pair
# combinations looks like on a tiny hand-made DataFrame.
if False:
    from pyspark.sql import Row

    samples = [[1, 2, 3, 8], [4, 5, 6], [1, 5, 7]]
    df_test = spark.createDataFrame([
        Row(a=index, b=items, c=list(combinations(items, 2)))
        for index, items in enumerate(samples)
    ])

    df_test.withColumn('c', explode('c')).show()

In [None]:
# Count how often each sorted code pair occurs across patients, unless a
# cached result already exists on disk.
if not os.path.exists('frequent_diseases_k2'):
    frequent_diseases_k2 = (
        df
        # Only codes that passed the first pass are considered.
        .filter(col('CODE').isin(frequent_diseases_k1_set))
        # One row per patient, holding the list of that patient's codes.
        .groupBy('PATIENT')
        .agg(collect_list('CODE'))
        # Expand the code list into all of its 2-combinations ...
        .withColumn('CODE_PAIRS', combine_pairs('collect_list(CODE)'))
        .select('PATIENT', 'CODE_PAIRS')
        # ... and emit one row per (patient, pair).
        .withColumn('CODE_PAIR', explode('CODE_PAIRS'))
        .drop('CODE_PAIRS')
        # Sort inside each pair so [a, b] and [b, a] fall into one group.
        .withColumn('CODE_PAIR', array_sort('CODE_PAIR'))
        .groupBy('CODE_PAIR')
        .count()
        .withColumnRenamed('count', 'COUNT')
        .filter(col('COUNT') >= support_threshold)
    )

    # Persist the result so a later run can skip the aggregation.
    frequent_diseases_k2.write.mode('overwrite').parquet(
        path='frequent_diseases_k2',
        compression='gzip',
    )

# Always continue from the Parquet copy, whether freshly written or cached.
frequent_diseases_k2 = spark.read.parquet('frequent_diseases_k2')

In [None]:
# Collect the frequent pairs to the driver; each array becomes a tuple so
# it is hashable and can be stored in a set.
frequent_diseases_k2_set = {
    tuple(row.CODE_PAIR)
    for row in frequent_diseases_k2.select('CODE_PAIR').collect()
}

In [None]:
# Number of frequent pairs; the trailing note records the value the author
# observed.
len(frequent_diseases_k2_set)   # 2940

### Third pass

In [None]:
@udf(returnType=ArrayType(ArrayType(StringType(), False), False))
def combine_triples(elems: Iterable[Any]):
    """Return the 3-combinations of ``elems`` whose pairs are all frequent.

    A triple is kept only when each of its three 2-element sub-combinations,
    taken in the triple's own element order, is a member of
    ``frequent_diseases_k2_set``.
    """
    return [
        triple
        for triple in combinations(elems, 3)
        if all(
            pair in frequent_diseases_k2_set
            for pair in combinations(triple, 2)
        )
    ]

In [None]:
# Count how often each code triple occurs across patients, unless a cached
# result already exists on disk.
if not os.path.exists('frequent_diseases_k3'):
    frequent_diseases_k3 = (
        df
        # Only codes that passed the first pass are considered.
        .filter(col('CODE').isin(frequent_diseases_k1_set))
        # One row per patient, holding the list of that patient's codes.
        .groupBy('PATIENT')
        .agg(collect_list('CODE'))
        # Sort the codes first so the generated triples (and the pairs
        # looked up inside combine_triples) are in sorted order.
        .withColumn('collect_list(CODE)', array_sort('collect_list(CODE)'))
        .withColumn('CODE_TRIPLES', combine_triples('collect_list(CODE)'))
        .select('PATIENT', 'CODE_TRIPLES')
        # One row per (patient, triple).
        .withColumn('CODE_TRIPLE', explode('CODE_TRIPLES'))
        .drop('CODE_TRIPLES')
        .groupBy('CODE_TRIPLE')
        .count()
        .withColumnRenamed('count', 'COUNT')
        .filter(col('COUNT') >= support_threshold)
    )

    # Persist the result so a later run can skip the aggregation.
    frequent_diseases_k3.write.mode('overwrite').parquet(
        path='frequent_diseases_k3',
        compression='gzip',
    )

# Always continue from the Parquet copy, whether freshly written or cached.
frequent_diseases_k3 = spark.read.parquet('frequent_diseases_k3')

In [None]:
# Collect the frequent triples to the driver; each array becomes a tuple so
# it is hashable and can be stored in a set.
frequent_diseases_k3_set = {
    tuple(row.CODE_TRIPLE)
    for row in frequent_diseases_k3.select('CODE_TRIPLE').collect()
}

In [None]:
# Number of frequent triples; the trailing note records the value the
# author observed.
len(frequent_diseases_k3_set)   # 13395

### Most frequent

In [None]:
# The ten frequent pairs with the highest COUNT.
frequent_diseases_k2.sort('COUNT', ascending=False).take(10)

In [None]:
# The ten frequent triples with the highest COUNT.
frequent_diseases_k3.sort('COUNT', ascending=False).take(10)

Example output (expected format of a mined rule):

```
...
{Diabetes, Neoplasm} -> {Colon polyp}: 0.2000, ...
...
```

### Rules

In [118]:
@udf(returnType=ArrayType(ArrayType(ArrayType(StringType(), False), False), False))
def inner_subsets(itemset: Iterable[str]):
    """Split ``itemset`` into every (subset, complement) pair.

    Each non-empty proper subset of the distinct items is paired with the
    remaining items, and both halves are returned as sorted lists. An
    itemset with fewer than two distinct items yields an empty list.

    Args:
        itemset: The items of one itemset; duplicates are ignored.

    Returns:
        A list of ``(subset, complement)`` tuples of sorted lists, ordered
        by subset size and then by the sorted order of the items.
    """
    # The annotation uses ``Iterable`` because this file imports it
    # explicitly; ``List`` is not imported here and would only resolve if
    # the wildcard pyspark import happened to expose it.
    items = sorted(set(itemset))
    universe = set(items)
    # Enumerating from the sorted items keeps the result order independent
    # of set iteration order (which varies with string hashing).
    return [
        (list(subset), sorted(universe - set(subset)))
        for size in range(1, len(items))
        for subset in combinations(items, size)
    ]

In [129]:
# Turn every frequent pair into its two single-code rules
# (RULE_1 -> RULE_2) and attach the support count of each side.
rules_k2 = (
    frequent_diseases_k2
    # Replace the pair by its (subset, complement) splits, one row each.
    .withColumn('CODE_PAIR', inner_subsets('CODE_PAIR'))
    .withColumn('CODE_PAIR', explode('CODE_PAIR'))
    .withColumnRenamed('COUNT', 'COUNT_PAIR')
    .select(
        col('CODE_PAIR')[0].alias('RULE_1'),
        col('CODE_PAIR')[1].alias('RULE_2'),
        'COUNT_PAIR',
    )
    # Support of the single code in RULE_1.
    .join(frequent_diseases_k1, frequent_diseases_k1['CODE'] == col('RULE_1')[0], 'inner')
    .withColumnRenamed('COUNT', 'COUNT_1')
    .drop('CODE')
    # Support of the single code in RULE_2.
    .join(frequent_diseases_k1, frequent_diseases_k1['CODE'] == col('RULE_2')[0], 'inner')
    .withColumnRenamed('COUNT', 'COUNT_2')
    .drop('CODE')
)

In [None]:
# Number of distinct patients; used below to turn counts into fractions.
n_patients = df.select('PATIENT').distinct().count()

In [None]:
# Score every pair rule and print those whose STANDARDISED_LIFT is >= 0.2.
(
    rules_k2
    # COUNT_PAIR relative to the support of RULE_1.
    .withColumn('CONFIDENCE', col('COUNT_PAIR') / col('COUNT_1'))
    # Confidence minus the fraction of patients matching RULE_2.
    .withColumn('INTEREST', col('CONFIDENCE') - col('COUNT_2') / n_patients)
    # Confidence divided by the fraction of patients matching RULE_2.
    .withColumn('LIFT', n_patients * col('CONFIDENCE') / col('COUNT_2'))
    # (LIFT - lower) / (upper - lower); the "lower" term is written out
    # twice, once in the numerator and once in the denominator.
    .withColumn(
        'STANDARDISED_LIFT',
        (
            col('LIFT')
            - array_max(array(
                (col('COUNT_1') + col('COUNT_2')) / n_patients - 1,
                lit(1 / n_patients),
            )) / (col('COUNT_1') * col('COUNT_2') / (n_patients ** 2))
        )
        /
        (
            (n_patients / array_max(array(col('COUNT_1'), col('COUNT_2'))))
            - array_max(array(
                (col('COUNT_1') + col('COUNT_2')) / n_patients - 1,
                lit(1 / n_patients),
            )) / (col('COUNT_1') * col('COUNT_2') / (n_patients ** 2))
        ),
    )
    .filter(col('STANDARDISED_LIFT') >= 0.2)
    .show(truncate=False)
)


In [151]:
# Turn every frequent triple into its (RULE_1 -> RULE_2) splits and attach
# the support count of each side. A side holds either one code (looked up
# in frequent_diseases_k1) or two codes (looked up in frequent_diseases_k2);
# left joins against both tables are merged with coalesce.
#
# The DataFrame itself is bound to `rules_k3`; `.show()` returns None, so
# it is called separately instead of ending the assigned chain.
rules_k3 = (
    frequent_diseases_k3
    # Replace the triple by its (subset, complement) splits, one row each.
    .withColumn('CODE_TRIPLE', inner_subsets('CODE_TRIPLE'))
    .withColumn('CODE_TRIPLE', explode('CODE_TRIPLE'))
    .withColumnRenamed('COUNT', 'COUNT_TRIPLE')
    .select(
        col('CODE_TRIPLE')[0].alias('RULE_1'),
        col('CODE_TRIPLE')[1].alias('RULE_2'),
        'COUNT_TRIPLE',
    )

    # Support of RULE_1: single-code match first, pair match as fallback.
    .join(frequent_diseases_k1, array(frequent_diseases_k1['CODE']) == col('RULE_1'), 'left')
    .withColumnRenamed('COUNT', 'COUNT_1')
    .drop('CODE')
    .join(frequent_diseases_k2, frequent_diseases_k2['CODE_PAIR'] == col('RULE_1'), 'left')
    .withColumnRenamed('COUNT', 'COUNT_1_OTHER')
    .drop('CODE_PAIR')
    .withColumn('COUNT_1', coalesce('COUNT_1', 'COUNT_1_OTHER'))
    .drop('COUNT_1_OTHER')

    # Support of RULE_2: single-code match first, pair match as fallback.
    .join(frequent_diseases_k1, array(frequent_diseases_k1['CODE']) == col('RULE_2'), 'left')
    .withColumnRenamed('COUNT', 'COUNT_2')
    .drop('CODE')
    .join(frequent_diseases_k2, frequent_diseases_k2['CODE_PAIR'] == col('RULE_2'), 'left')
    .withColumnRenamed('COUNT', 'COUNT_2_OTHER')
    .drop('CODE_PAIR')
    .withColumn('COUNT_2', coalesce('COUNT_2', 'COUNT_2_OTHER'))
    .drop('COUNT_2_OTHER')
)

rules_k3.show(truncate=False)

+----------------------------+----------------------------+------------+-------+-------+
|RULE_1                      |RULE_2                      |COUNT_TRIPLE|COUNT_1|COUNT_2|
+----------------------------+----------------------------+------------+-------+-------+
|[368581000119106]           |[302870006, 422034002]      |10003       |26779  |20413  |
|[422034002]                 |[302870006, 368581000119106]|10003       |20456  |26722  |
|[302870006]                 |[368581000119106, 422034002]|10003       |75992  |10004  |
|[368581000119106, 422034002]|[302870006]                 |10003       |10004  |75992  |
|[302870006, 368581000119106]|[422034002]                 |10003       |26722  |20456  |
|[302870006, 422034002]      |[368581000119106]           |10003       |20413  |26779  |
|[444814009]                 |[59621000, 80394007]        |13500       |751940 |20376  |
|[59621000]                  |[444814009, 80394007]       |13500       |305134 |23558  |
|[80394007]          

In [150]:
# Spot check: look up the stored count of one specific code pair.
frequent_diseases_k2.filter(
    (frequent_diseases_k2['CODE_PAIR'][0] == '15777000')
    & (frequent_diseases_k2['CODE_PAIR'][1] == '271737000')
).show()

+--------------------+------+
|           CODE_PAIR| COUNT|
+--------------------+------+
|[15777000, 271737...|289176|
+--------------------+------+

