# Assignment 1

## Imports

In [22]:
import os.path
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import StringType, ArrayType
from itertools import combinations, chain
from typing import Iterable, Any, List

## Spark initialization

In [74]:
# Local Spark session using all available cores. getOrCreate() returns the
# already-running session when this cell is executed again.
spark = (
    SparkSession.builder
    .appName('SandboxAssign1')
    .config('spark.master', 'local[*]')
    .getOrCreate()
)

## Prepare the data

In [67]:
# Conditions table (CSV with header row). START, STOP and ENCOUNTER are not
# used for itemset mining, so they are dropped right away.
conditions_path = './data/conditions.csv.gz'
df = (
    spark.read
    .option('header', True)
    .csv(conditions_path)
    .drop('START', 'STOP', 'ENCOUNTER')
)

In [70]:
# Distinct (CODE, DESCRIPTION) pairs. Kept as a named DataFrame because a later
# cell calls `code_description_df.count()`; it must be defined here, while `df`
# still has the DESCRIPTION column.
code_description_df = df.select('CODE', 'DESCRIPTION').distinct()

# Lookup code -> human-readable description, for making rules readable.
# NOTE: the recorded outputs (160 distinct pairs vs 159 distinct codes) suggest
# at least one code carries more than one description. Sorting before
# collecting makes the kept description deterministic (the lexicographically
# last one wins) instead of depending on partition order.
code_description_map = {
    r.CODE: r.DESCRIPTION
    for r in code_description_df.orderBy('CODE', 'DESCRIPTION').collect()
}

                                                                                

In [73]:
# Inspect every distinct (CODE, DESCRIPTION) pair.
code_description_rows = df.select('CODE', 'DESCRIPTION').distinct().collect()
code_description_rows

                                                                                

[Row(CODE='431855005', DESCRIPTION='Chronic kidney disease stage 1 (disorder)'),
 Row(CODE='75498004', DESCRIPTION='Acute bacterial sinusitis (disorder)'),
 Row(CODE='93761005', DESCRIPTION='Primary malignant neoplasm of colon'),
 Row(CODE='254632001', DESCRIPTION='Small cell carcinoma of lung (disorder)'),
 Row(CODE='59621000', DESCRIPTION='Hypertension'),
 Row(CODE='86849004', DESCRIPTION='Suicidal deliberate poisoning'),
 Row(CODE='40055000', DESCRIPTION='Chronic sinusitis (disorder)'),
 Row(CODE='62106007', DESCRIPTION='Concussion with no loss of consciousness'),
 Row(CODE='69896004', DESCRIPTION='Rheumatoid arthritis'),
 Row(CODE='15777000', DESCRIPTION='Prediabetes'),
 Row(CODE='11218009', DESCRIPTION='Infection caused by Pseudomonas aeruginosa'),
 Row(CODE='197927001', DESCRIPTION='Recurrent urinary tract infection'),
 Row(CODE='448813005', DESCRIPTION='Sepsis caused by Pseudomonas (disorder)'),
 Row(CODE='403192003', DESCRIPTION='Third degree burn'),
 Row(CODE='88805009', DESCR

In [72]:
n_distinct_codes = len(code_description_map)
n_distinct_codes

159

In [5]:
# From here on only which patient has which condition matters: drop the
# description and deduplicate the remaining rows (presumably just PATIENT and
# CODE -- confirm against the CSV header).
df = (
    df
    .drop('DESCRIPTION')
    .distinct()
)

## A-priori algorithm

In [6]:
# Minimum occurrence count an itemset needs to be kept as frequent.
support_threshold = 1_000

### First pass

In [7]:
# Cache path is keyed on the support threshold; otherwise changing
# `support_threshold` would silently reuse results computed with the old value.
frequent_diseases_k1_path = f'frequent_diseases_k1_support{support_threshold}'

if not os.path.exists(frequent_diseases_k1_path):
    # Support of each single code = number of (deduplicated) rows mentioning it.
    frequent_diseases_k1 = (
        df
        .groupBy('CODE')
        .count()
        .withColumnRenamed('count', 'COUNT')
        .filter(col('COUNT') >= support_threshold)
    )

    frequent_diseases_k1.write.mode('overwrite').parquet(path=frequent_diseases_k1_path, compression='gzip')

frequent_diseases_k1 = spark.read.parquet(frequent_diseases_k1_path)

                                                                                

In [8]:
# Plain Python set of frequent single codes, used to pre-filter later passes.
frequent_diseases_k1_set = set(
    row['CODE'] for row in frequent_diseases_k1.select('CODE').collect()
)

                                                                                

In [9]:
n_frequent_k1 = len(frequent_diseases_k1_set)
n_frequent_k1   # 131

131

### Second pass

In [10]:
@udf(returnType=ArrayType(ArrayType(StringType(), False), False))
def combine_pairs(elems: Iterable[Any]):
    """Return every unordered 2-element combination of `elems`.

    Members of each pair keep the order in which they appear in `elems`.
    """
    return [pair for pair in combinations(elems, 2)]

In [11]:
# Optional sanity check, off by default: run the real `combine_pairs` UDF on a
# tiny hand-made frame and show the exploded pairs. String codes are used
# because the UDF is declared to return arrays of strings.
RUN_COMBINE_PAIRS_DEMO = False

if RUN_COMBINE_PAIRS_DEMO:
    from pyspark.sql import Row

    df_test = spark.createDataFrame([
        Row(a=0, b=['1', '2', '3', '8']),
        Row(a=1, b=['4', '5', '6']),
        Row(a=2, b=['1', '5', '7']),
    ])

    df_test.withColumn('c', explode(combine_pairs('b'))).show()

In [12]:
# Cache path is keyed on the support threshold so a changed threshold is never
# served from a stale cache.
frequent_diseases_k2_path = f'frequent_diseases_k2_support{support_threshold}'

if not os.path.exists(frequent_diseases_k2_path):
    frequent_diseases_k2 = (
        df
        # A-priori pruning: only codes that are frequent on their own can be
        # part of a frequent pair, so drop all other codes first.
        .filter(col('CODE').isin(frequent_diseases_k1_set))
        .groupBy('PATIENT')
        .agg(collect_list('CODE'))
        .withColumn('CODE_PAIRS', combine_pairs('collect_list(CODE)'))
        .select('PATIENT', 'CODE_PAIRS')
        .withColumn('CODE_PAIR', explode('CODE_PAIRS'))
        .drop('CODE_PAIRS')
        # Canonical order inside each pair so (a, b) and (b, a) are counted together.
        .withColumn('CODE_PAIR', array_sort('CODE_PAIR'))
        .groupBy('CODE_PAIR')
        .count()
        .withColumnRenamed('count', 'COUNT')
        .filter(col('COUNT') >= support_threshold)
    )

    frequent_diseases_k2.write.mode('overwrite').parquet(path=frequent_diseases_k2_path, compression='gzip')

frequent_diseases_k2 = spark.read.parquet(frequent_diseases_k2_path)

In [13]:
# Frequent pairs as hashable (sorted) tuples, for fast membership tests in the third pass.
frequent_diseases_k2_set = set(
    tuple(row['CODE_PAIR'])
    for row in frequent_diseases_k2.select('CODE_PAIR').collect()
)

                                                                                

In [14]:
n_frequent_k2 = len(frequent_diseases_k2_set)
n_frequent_k2   # 2940

2940

### Third pass

In [15]:
@udf(returnType=ArrayType(ArrayType(StringType(), False), False))
def combine_triples(elems: Iterable[Any]):
    """Return the candidate 3-itemsets contained in one patient's codes.

    A-priori candidate pruning: a triple is kept only if all three of its
    2-subsets are in `frequent_diseases_k2_set`.

    The codes are sorted here so every generated triple, and every pair taken
    from it, is in ascending order. That matches the tuples stored in
    `frequent_diseases_k2_set`, which come from `array_sort`; UTF-8 byte order
    and Python code-point order agree. The lookup therefore no longer depends
    on the caller pre-sorting `elems`.
    """
    frequent_pairs = frequent_diseases_k2_set
    return [
        triple
        for triple in combinations(sorted(elems), 3)
        if all(pair in frequent_pairs for pair in combinations(triple, 2))
    ]

In [16]:
# Cache path is keyed on the support threshold so a changed threshold is never
# served from a stale cache.
frequent_diseases_k3_path = f'frequent_diseases_k3_support{support_threshold}'

if not os.path.exists(frequent_diseases_k3_path):
    frequent_diseases_k3 = (
        df
        # Only frequent single codes can appear in a frequent triple.
        .filter(col('CODE').isin(frequent_diseases_k1_set))
        .groupBy('PATIENT')
        .agg(collect_list('CODE'))
        # Sorted codes => generated triples are in canonical (ascending) order.
        .withColumn('collect_list(CODE)', array_sort('collect_list(CODE)'))
        .withColumn('CODE_TRIPLES', combine_triples('collect_list(CODE)'))
        .select('PATIENT', 'CODE_TRIPLES')
        .withColumn('CODE_TRIPLE', explode('CODE_TRIPLES'))
        .drop('CODE_TRIPLES')
        .groupBy('CODE_TRIPLE')
        .count()
        .withColumnRenamed('count', 'COUNT')
        .filter(col('COUNT') >= support_threshold)
    )

    frequent_diseases_k3.write.mode('overwrite').parquet(path=frequent_diseases_k3_path, compression='gzip')

frequent_diseases_k3 = spark.read.parquet(frequent_diseases_k3_path)

In [17]:
# Frequent triples as hashable tuples.
frequent_diseases_k3_set = set(
    tuple(row['CODE_TRIPLE'])
    for row in frequent_diseases_k3.select('CODE_TRIPLE').collect()
)

                                                                                

In [18]:
n_frequent_k3 = len(frequent_diseases_k3_set)
n_frequent_k3   # 13390

13390

### Most frequent itemsets

In [39]:
# Top-10 frequent pairs by support, tab-separated with a header row.
with open('most_frequent_k2.txt', 'w') as out:
    out.write('pair\tcount\n')
    top_pairs = frequent_diseases_k2.sort('COUNT', ascending=False).take(10)
    out.write('\n'.join(f'{row.CODE_PAIR}\t{row.COUNT}' for row in top_pairs) + '\n')

                                                                                

In [41]:
# Top-10 frequent triples by support, tab-separated with a header row.
# The header names the column 'triple' (it previously said 'pair').
with open('most_frequent_k3.txt', 'w') as f:
    print('triple\tcount', file=f)
    print(*(
            f'{r.CODE_TRIPLE}\t{r.COUNT}' for r in
            frequent_diseases_k3.sort('COUNT', ascending=False).take(10)
        ), sep='\n', file=f)

                                                                                

### Rules

In [25]:
# Number of distinct patients appearing in the (deduplicated) conditions table;
# the denominator for the probability estimates in the rule metrics.
n_patients = df.select('PATIENT').dropDuplicates().count()

                                                                                

In [23]:
@udf(returnType=ArrayType(ArrayType(ArrayType(StringType(), False), False), False))
def inner_subsets(itemset: List[str]):
    """Split an itemset into all (antecedent, consequent) pairs of complementary,
    non-empty proper subsets. Both sides are returned as sorted lists."""
    items = set(itemset)
    splits = []
    for size in range(1, len(items)):
        for left in combinations(items, size):
            left_set = set(left)
            splits.append((sorted(left_set), sorted(items - left_set)))
    return splits

In [24]:
# Expand each frequent pair {a, b} into the rules a -> b and b -> a, then attach
# the single-code supports of the antecedent (COUNT_1) and consequent (COUNT_2).
rules_k2 = (
    frequent_diseases_k2
    .withColumn('CODE_PAIR', inner_subsets('CODE_PAIR'))
    .withColumn('CODE_PAIR', explode('CODE_PAIR'))
    .withColumnRenamed('COUNT', 'COUNT_PAIR')
    .select(
        col('CODE_PAIR')[0].alias('RULE_1'),
        col('CODE_PAIR')[1].alias('RULE_2'),
        'COUNT_PAIR',
    )
    # For pairs both sides are one-element arrays, so element [0] is the code.
    .join(frequent_diseases_k1, frequent_diseases_k1['CODE'] == col('RULE_1')[0], 'inner')
    .withColumnRenamed('COUNT', 'COUNT_1')
    .drop('CODE')
    .join(frequent_diseases_k1, frequent_diseases_k1['CODE'] == col('RULE_2')[0], 'inner')
    .withColumnRenamed('COUNT', 'COUNT_2')
    .drop('CODE')
)

In [35]:
# Standardised lift rescales lift between its attainable bounds given the two
# marginal supports (P(X) = COUNT_X / n_patients):
#   lower = max(P(A) + P(B) - 1, 1/n) / (P(A) * P(B))
#   upper = 1 / max(P(A), P(B))
#   standardised = (lift - lower) / (upper - lower)
_k2_lower_bound = array_max(array(
    (col('COUNT_1') + col('COUNT_2')) / n_patients - 1,
    lit(1 / n_patients)
)) / (col('COUNT_1') * col('COUNT_2') / (n_patients ** 2))
_k2_upper_bound = n_patients / array_max(array(col('COUNT_1'), col('COUNT_2')))

rules_k2_metrics = (
    rules_k2
    .withColumn('CONFIDENCE', col('COUNT_PAIR') / col('COUNT_1'))
    .withColumn('INTEREST', col('CONFIDENCE') - col('COUNT_2') / n_patients)
    .withColumn('LIFT', n_patients * col('CONFIDENCE') / col('COUNT_2'))
    .withColumn(
        'STANDARDISED_LIFT',
        (col('LIFT') - _k2_lower_bound) / (_k2_upper_bound - _k2_lower_bound),
    )
    .filter(col('STANDARDISED_LIFT') >= 0.2)
    .sort('STANDARDISED_LIFT', ascending=False)
)

In [29]:
# Association rules from frequent triples. inner_subsets splits each triple
# {a, b, c} into every (antecedent, consequent) pair of complementary non-empty
# subsets, so each side (RULE_1 / RULE_2) holds one or two codes.
# The support of each side is attached with two left joins:
#   - a one-code side can match `array(CODE)` from frequent_diseases_k1,
#   - a two-code side can match CODE_PAIR from frequent_diseases_k2,
# and at most one of the two lookups can match a given side, so coalesce()
# merges them into COUNT_1 / COUNT_2.
# NOTE: the two-code lookup relies on RULE_1/RULE_2 (sorted by inner_subsets)
# being in the same order as CODE_PAIR (sorted with array_sort when k2 was built).
# Each join is followed by dropping the joined key column before the same
# DataFrame is joined again; keep that order when editing.
rules_k3 = frequent_diseases_k3 \
    .withColumn('CODE_TRIPLE', inner_subsets('CODE_TRIPLE')) \
    .withColumn('CODE_TRIPLE', explode('CODE_TRIPLE')) \
    .withColumnRenamed('COUNT', 'COUNT_TRIPLE') \
    .select(col('CODE_TRIPLE')[0].alias('RULE_1'), col('CODE_TRIPLE')[1].alias('RULE_2'), 'COUNT_TRIPLE') \
    \
    .join(frequent_diseases_k1, array(frequent_diseases_k1['CODE']) == col('RULE_1'), 'left') \
    .withColumnRenamed('COUNT', 'COUNT_1') \
    .drop('CODE') \
    .join(frequent_diseases_k2, frequent_diseases_k2['CODE_PAIR'] == col('RULE_1'), 'left') \
    .withColumnRenamed('COUNT', 'COUNT_1_OTHER') \
    .drop('CODE_PAIR') \
    .withColumn('COUNT_1', coalesce('COUNT_1', 'COUNT_1_OTHER')) \
    .drop('COUNT_1_OTHER') \
    \
    .join(frequent_diseases_k1, array(frequent_diseases_k1['CODE']) == col('RULE_2'), 'left') \
    .withColumnRenamed('COUNT', 'COUNT_2') \
    .drop('CODE') \
    .join(frequent_diseases_k2, frequent_diseases_k2['CODE_PAIR'] == col('RULE_2'), 'left') \
    .withColumnRenamed('COUNT', 'COUNT_2_OTHER') \
    .drop('CODE_PAIR') \
    .withColumn('COUNT_2', coalesce('COUNT_2', 'COUNT_2_OTHER')) \
    .drop('COUNT_2_OTHER')

In [34]:
# Same metrics as for the pair rules; COUNT_TRIPLE is the joint support.
# Bounds of lift given the marginal supports COUNT_1 and COUNT_2:
#   lower = max(P(A) + P(B) - 1, 1/n) / (P(A) * P(B))
#   upper = 1 / max(P(A), P(B))
_k3_lower_bound = array_max(array(
    (col('COUNT_1') + col('COUNT_2')) / n_patients - 1,
    lit(1 / n_patients)
)) / (col('COUNT_1') * col('COUNT_2') / (n_patients ** 2))
_k3_upper_bound = n_patients / array_max(array(col('COUNT_1'), col('COUNT_2')))

rules_k3_metrics = (
    rules_k3
    .withColumn('CONFIDENCE', col('COUNT_TRIPLE') / col('COUNT_1'))
    .withColumn('INTEREST', col('CONFIDENCE') - col('COUNT_2') / n_patients)
    .withColumn('LIFT', n_patients * col('CONFIDENCE') / col('COUNT_2'))
    .withColumn(
        'STANDARDISED_LIFT',
        (col('LIFT') - _k3_lower_bound) / (_k3_upper_bound - _k3_lower_bound),
    )
    .filter(col('STANDARDISED_LIFT') >= 0.2)
    .sort('STANDARDISED_LIFT', ascending=False)
)

### Printing

In [36]:
# First 20 rules (show's default row count), full cell contents.
rules_k2_metrics.show(n=20, truncate=False)



+----------------+----------------+----------+-------+-------+-------------------+-------------------+------------------+------------------+
|RULE_1          |RULE_2          |COUNT_PAIR|COUNT_1|COUNT_2|CONFIDENCE         |INTEREST           |LIFT              |STANDARDISED_LIFT |
+----------------+----------------+----------+-------+-------+-------------------+-------------------+------------------+------------------+
|[44054006]      |[422034002]     |20456     |77306  |20456  |0.26461076759889274|0.24693938821884232|14.973973559620212|1.0000000000000002|
|[1551000119108] |[1501000119109] |3035      |11705  |3035   |0.25929090132422045|0.2566690477644603 |98.89602733874415 |1.0000000000000002|
|[72892002]      |[398254007]     |22959     |205390 |22959  |0.11178246263206583|0.09194880995380139|5.635999805248552 |1.0000000000000002|
|[44054006]      |[1551000119108] |11705     |77306  |11705  |0.1514112746746695 |0.14129964504798342|14.973973559620212|1.0000000000000002|
|[67811000119

                                                                                

In [64]:
@udf(returnType=StringType())
def format_rule(rule_1: List[str], rule_2: List[str], *values: List[Any]):
    """Render a rule as '{a, b} -> {c}: v1, v2, ...' for display."""
    antecedent = ", ".join(rule_1)
    consequent = ", ".join(rule_2)
    metrics = ", ".join(str(value) for value in values)
    return f'{{{antecedent}}} -> {{{consequent}}}: {metrics}'

In [69]:
n_code_description_pairs = code_description_df.count()
n_code_description_pairs

                                                                                

160

In [65]:
# TODO: map code to description (e.g. via code_description_map)

formatted_rule = format_rule('RULE_1', 'RULE_2', 'STANDARDISED_LIFT', 'LIFT', 'CONFIDENCE', 'INTEREST')
rules_k2_metrics.select(formatted_rule).show(truncate=False)



+------------------------------------------------------------------------------------------------------------------+
|format_rule(RULE_1, RULE_2, STANDARDISED_LIFT, LIFT, CONFIDENCE, INTEREST)                                        |
+------------------------------------------------------------------------------------------------------------------+
|{1551000119108} -> {1501000119109}: 1.0000000000000002, 98.89602733874415, 0.25929090132422045, 0.2566690477644603|
|{44054006} -> {422034002}: 1.0000000000000002, 14.973973559620212, 0.26461076759889274, 0.24693938821884232       |
|{44054006} -> {1551000119108}: 1.0000000000000002, 14.973973559620212, 0.1514112746746695, 0.14129964504798342    |
|{72892002} -> {398254007}: 1.0000000000000002, 5.635999805248552, 0.11178246263206583, 0.09194880995380139        |
|{443165006} -> {64859006}: 1.0, 20.9323158713224, 1.0, 0.9522269773613528                                         |
|{64859006} -> {443165006}: 1.0, 20.9323158713224, 0.31225475127

                                                                                

Example of the target output format (condition codes mapped to their descriptions):

```
...
{Diabetes, Neoplasm} -> {Colon polyp}: 0.2000, ...
...
```