In [2]:
from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession

from pyspark.sql.functions import collect_list, count

from operator import add

In [3]:
# Spark Constants
APP_NAME = 'assignment1'
MASTER = 'local[*]'

# Column Constants
PATIENT_COLUMN = "PATIENT"
CODE_COLUMN = "CODE"

CONDITIONS_COLUMN = "CONDITIONS"

# Input Constants
INPUT_FILE = 'conditions-sample.csv'
SUPPORT_THRESHOLD = 1000

In [4]:
conf = SparkConf().setAppName(APP_NAME).setMaster(MASTER)
sc = SparkContext.getOrCreate(conf=conf)

spark = SparkSession.builder.appName(APP_NAME).master(MASTER).getOrCreate()

23/03/09 10:05:21 WARN Utils: Your hostname, pedro-duarte resolves to a loopback address: 127.0.1.1; using 192.168.39.114 instead (on interface wlp2s0)
23/03/09 10:05:21 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


23/03/09 10:05:23 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [5]:
ds = spark.read.csv(INPUT_FILE, header=True, inferSchema=True)
ds.schema

                                                                                

StructType([StructField('START', TimestampType(), True), StructField('STOP', TimestampType(), True), StructField('PATIENT', StringType(), True), StructField('ENCOUNTER', StringType(), True), StructField('CODE', LongType(), True), StructField('DESCRIPTION', StringType(), True)])

In [6]:
conditions = ds.select(CODE_COLUMN).distinct()
conditions.show(5)

[Stage 2:>                                                          (0 + 4) / 4]

+---------+
|     CODE|
+---------+
| 74400008|
|449868002|
| 38822007|
|254637007|
| 47693006|
+---------+
only showing top 5 rows



                                                                                

In [7]:
patient_conditions = ds.groupBy(PATIENT_COLUMN).agg(collect_list(CODE_COLUMN).alias(CONDITIONS_COLUMN))
patient_conditions.show(5)



+--------------------+--------------------+
|             PATIENT|          CONDITIONS|
+--------------------+--------------------+
|000157a3-5aca-4f2...|[367498001, 15777...|
|00016311-4402-47d...|[43878008, 728920...|
|00047353-ffa1-4c9...|[162864005, 19169...|
|000da4b5-ce45-466...|[162864005, 44481...|
|0015bf87-c0c6-401...|[15777000, 271737...|
+--------------------+--------------------+
only showing top 5 rows



                                                                                

In [8]:
elems_k1 = ds.groupBy(CODE_COLUMN).agg(count("*").alias("COUNT")).filter("count > 1000")
elems_k1.show(5)

[Stage 8:>                                                          (0 + 4) / 4]

+---------+-----+
|     CODE|COUNT|
+---------+-----+
| 74400008| 3234|
| 38822007| 1273|
| 75498004| 4099|
| 16114001| 2098|
|283371005| 2043|
+---------+-----+
only showing top 5 rows



                                                                                

In [9]:
rdd = patient_conditions.rdd \
  .flatMap(lambda v: [(c, 1) for c in v[CONDITIONS_COLUMN]]) \
  .reduceByKey(add)

rdd.collect()

                                                                                

[(15777000, 20702),
 (271737000, 20753),
 (55822004, 7901),
 (43878008, 9962),
 (198992004, 1395),
 (79586000, 1712),
 (446096008, 1863),
 (307731004, 1018),
 (40055000, 14933),
 (59621000, 17966),
 (263102004, 2062),
 (410429000, 2565),
 (92691004, 1648),
 (49436004, 2475),
 (53741008, 4028),
 (75498004, 4099),
 (74400008, 3234),
 (428251008, 3234),
 (26929004, 2057),
 (65966004, 2697),
 (62564004, 1042),
 (156073000, 2447),
 (239720000, 332),
 (408512008, 266),
 (444448004, 501),
 (370247008, 2065),
 (1551000119108, 678),
 (232353008, 1842),
 (241929008, 1088),
 (283385000, 2185),
 (262574004, 461),
 (713197008, 1480),
 (40275004, 367),
 (69896004, 201),
 (314994000, 303),
 (424132000, 772),
 (370143000, 363),
 (236077008, 268),
 (235919008, 145),
 (190905008, 20),
 (94260004, 124),
 (65710008, 10),
 (157141000119108, 14),
 (707577004, 10),
 (45816000, 32),
 (86849004, 7),
 (225444004, 9),
 (367498001, 919),
 (444814009, 72504),
 (162864005, 21513),
 (283371005, 2043),
 (195662009, 4

In [10]:
import itertools

rdd = patient_conditions.rdd \
  .flatMap(lambda v: [(c, 1) for c in itertools.combinations(v[CONDITIONS_COLUMN], 3)]) \
  .reduceByKey(add) \
  .filter(lambda x: x[1] > SUPPORT_THRESHOLD)

rdd.collect()

                                                                                

[((72892002, 10509002, 444814009), 3001),
 ((72892002, 72892002, 444814009), 16089),
 ((10509002, 72892002, 444814009), 2680),
 ((271737000, 444814009, 72892002), 2200),
 ((162864005, 19169002, 10509002), 1081),
 ((15777000, 271737000, 444814009), 15393),
 ((15777000, 271737000, 195662009), 8708),
 ((15777000, 444814009, 10509002), 4846),
 ((15777000, 195662009, 10509002), 2638),
 ((271737000, 444814009, 10509002), 4476),
 ((271737000, 195662009, 10509002), 2465),
 ((162864005, 19169002, 72892002), 1018),
 ((72892002, 72892002, 195662009), 9647),
 ((72892002, 19169002, 195662009), 1426),
 ((72892002, 19169002, 444814009), 2509),
 ((72892002, 398254007, 72892002), 2961),
 ((19169002, 72892002, 195662009), 2270),
 ((19169002, 72892002, 444814009), 4155),
 ((195662009, 72892002, 72892002), 6617),
 ((40055000, 59621000, 444814009), 1867),
 ((40055000, 162864005, 10509002), 1572),
 ((40055000, 444814009, 10509002), 3341),
 ((59621000, 162864005, 10509002), 2330),
 ((59621000, 444814009, 105

In [11]:
exclusions_k1 = patient_conditions.rdd \
  .flatMap(lambda v: [(c, 1) for c in itertools.combinations(v[CONDITIONS_COLUMN], 1)]) \
  .reduceByKey(add) \
  .filter(lambda x: x[1] < SUPPORT_THRESHOLD) \
  .toDF().select('_1').collect()

patient_conditions.rdd \
  .flatMap(lambda v: [(c, 1) for c in itertools.combinations(v[CONDITIONS_COLUMN], 2)]) \
  .reduceByKey(add) \
  .filter(lambda x: x[1] < SUPPORT_THRESHOLD) \
  .toDF().select('_1').show(5)

                                                                                

+--------------------+
|                  _1|
+--------------------+
|{10509002, 198992...|
|{162864005, 28337...|
|{283371005, 19566...|
|{72892002, 79586000}|
|{24079001, 195662...|
+--------------------+
only showing top 5 rows

