# Preprocessing of the Elders data

### Requirements
We have to split the original CSV file into multiple tables. With the saveAsTable functionality, datasets are automatically stored in the "spark-warehouse" folder at the base directory (the save cells at the end of this notebook write Parquet files to that folder explicitly).

#### Cleaning
* 

#### Tables
* Prescription dataset, cleaning most of the not-so-relevant variables from the base table

#### File structure
* Use "elders/" as the base directory for new tables
    * prescriptions (should they include possible drug descriptions?)
    * patients
    * prescribers
    * drugs

### Dataset splits


# Initialize the spark runtime

In [None]:
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import spark.implicits._

// Name of the dataset; also used as the sub-folder for the derived tables.
val dsetname = "elders"
// Root folder that the save cells at the end of the notebook write to.
val basedir = s"spark-warehouse/$dsetname/"
// Look up the effective Spark configuration of this session.
sc.getConf.getAll

// Switch for the optional exploration cells (extra counts and previews).
val execExtras = false

## Open the raw elders dataset
#### It needs preprocessing, all fields are Strings, and there's a lot of currently unused information that we discard

In [None]:
// Load the raw semicolon-separated export with a header row and keep it cached,
// because most of the cells below scan it again.
// NOTE(review): the path says "hospitalization" although this notebook targets the elders data - confirm.
val rawElders = spark.read
    .options(Map("header" -> "true", "delimiter" -> ";"))
    .csv("datasets/hospitalization")
    .cache()
// Show the column names and types as loaded.
rawElders.printSchema()

In [None]:
// Optional sanity counts on the raw data. Every count launches a Spark job,
// so the whole cell is skipped unless execExtras is set.
if (!execExtras) {
    print("Set to no-print")
} else {
    // Number of raw rows satisfying all given SQL predicates (all rows when none are given).
    def rowsWhere(predicates: String*): Long =
        predicates.foldLeft(rawElders)((df, predicate) => df.where(predicate)).count

    val hasDiff = "Diff_UtleveringDato is not null"
    val hasDeathYear = "PasientDodsAr is not null"

    println(rowsWhere())
    println(rowsWhere("ATCKode is not null", "length(ATCKode) == 7"))
    println()
    println(rowsWhere(hasDiff))
    // Disabled extra checks (they reference rawHospitalization, which this notebook does not define):
    //println(rawHospitalization.select("ATCKode").where("length(ATCKode) == 7").distinct.count)
    //println(rawHospitalization.select("ATCKode").where("length(ATCKode) < 7").distinct.count)
    //rawHospitalization.select("ATCKode").where("length(ATCKode) < 7").distinct.show()
    //rawHospitalization.select("PasientDodsAr").groupBy("PasientDodsAr").count.show
    println()
    println(rowsWhere(hasDeathYear))
    println(rowsWhere("PasientDodsMnd is not null"))
    println(rowsWhere(hasDeathYear, "PasientDodsMnd is null"))
    println()
    println(rowsWhere(hasDiff, "PasientDodsAr is null"))
    println(rowsWhere(hasDeathYear, "Diff_UtleveringDato is null"))
    println(rowsWhere(hasDiff, hasDeathYear))
}

# Info about raw data
* ~62M rows (61.930.313) Before cleaning
* ~21M rows without (Diff_UtleveringDato) field


* 1211 different full (5-level) ATC codes
* 62 ATC codes using less than 5 levels


* If a prescription contains the year of death (PasientDodsAr), it always contains the month of death (PasientDodsMnd) as well
    * This means that we don't have to filter out this case


### Useful data
* ~60M (59.903.331) prescriptions with a full ATC code (with all 5 ATC levels)

### Useless data
* ~4600 rows containing only patientID (PasientLopeNr, all other fields are null)



In [None]:
// Print the raw column list again as a reference for the field notes below.
rawElders.printSchema()

### Find out the meaning of the following fields before they can be incorporated
* Hjemmel [Now included]
* HjemmelNr [Now included]
* Kategori [Now included]
* KategoriNr [Now included]
* OrdinasjonAntallPakninger [Now included]
* OrdinasjonAntallDDD [Now included]
#### For the following fields the meaning is known
* ICD
* ICPC

In [None]:
// How many different Hjemmel texts occur?
val hjemmel = rawElders.select("Hjemmel").filter($"Hjemmel".isNotNull)
println("Distinct Hjemmel")
println(hjemmel.distinct().count())

// How many different HjemmelNr values occur (after casting the string to an integer)?
val hjemmelnr = rawElders
    .select($"HjemmelNr".cast("integer"))
    .filter($"HjemmelNr".isNotNull)
println("Distinct hjemmelnr")
println(hjemmelnr.distinct().count())

// Frequency of every (Hjemmel, HjemmelNr) pair, most frequent first.
val hjemmel_hjemmelnr = rawElders
    .select($"Hjemmel", $"Hjemmelnr".cast("integer"))
    .groupBy("Hjemmel", "HjemmelNr")
    .count()
    .orderBy($"count".desc)
hjemmel_hjemmelnr.show(false)

In [None]:
// Frequency of every (Kategori, KategoriNr) pair, ordered by KategoriNr.
// KategoriNr is not cast here, so the ordering is that of the raw column values.
val kategori_kategorinr = rawElders.select($"Kategori", $"KategoriNr")
kategori_kategorinr
    .groupBy("Kategori", "KategoriNr")
    .count()
    .orderBy($"KategoriNr".asc)
    .show()

In [None]:
// Distinct combinations of the two ordination quantity fields.
val ordantpak_ordantddd = rawElders
    .select($"OrdinasjonAntallPakninger", $"OrdinasjonAntallDDD").distinct

//ordantpak_ordantddd.show

// Columns inspected together: the ordination quantities, the DDD unit and value, and the ATC code.
val dddColumns = Seq(
    $"OrdinasjonAntallPakninger",
    $"OrdinasjonAntallDDD",
    $"ATCKodeDDDEnhet",
    $"ATCKodeDDDVerdi",
    $"ATCKode"
)

// Frequency of every combination of those columns, ordered by ATC code
// (alternative ordering: .sort(desc("count"))).
rawElders
    .select(dddColumns: _*)
    .groupBy(dddColumns: _*)
    .count
    .sort(asc("ATCKode"))
    .show

// Print every distinct DDD unit. foreach instead of map: only the side effect is wanted,
// so no throwaway Array[Unit] is built.
rawElders.select($"ATCKodeDDDEnhet").where("ATCKodeDDDEnhet is not null").distinct.collect.foreach(print)


## Sanitize the data, removing some rows that make no sense
Rows that lack the following information are removed
* Prescriptions with no ATC code
* Prescriptions of patients with no ID
* Prescriptions of patients with no birthyear
* Prescriptions where the ATC code is shorter than 5 levels

### Cast the respective fields into types that make more sense
* The fields should correlate with the field types and names of the prescription data

#### The ICD and ICPC numbers are currently not kept, these will require some preprocessing to get in order

##### Fields that are currently not kept after sanitization
* |-- PasientUtenID: string (nullable = true)  :  Patients without ID are removed from dataset
* |-- OrdinasjonAntallPakninger: string (nullable = true)
* |-- OrdinasjonAntallDDD: string (nullable = true)
* |-- Kategori: string (nullable = true)
* |-- KategoriNr: string (nullable = true)
* |-- Hjemmel: string (nullable = true)
* |-- HjemmelNr: string (nullable = true)
* |-- RefusjonKodeICDNr: string (nullable = true)
* |-- RefusjonKodeICPCNr: string (nullable = true)
* |-- VareNr: string (nullable = true)


In [None]:
// Drop rows that cannot be used and cast the string columns to proper types.
// Rows are kept only when they have a birth year, a 7-character ATC code and a
// non-null PasientUtenID value.
// NOTE(review): the PasientUtenID filter only tests that the field is present; confirm
// that this is what "remove patients with no ID" is meant to express.
val sanitized = rawElders
    .where("PasientFodtAr is not null")
    .where("ATCKode is not null")
    .where("LENGTH(ATCKode) == 7")
    .where("PasientUtenID is not null")
    .select(
        $"PasientLopeNr".as("id"),
        $"PasientFodtAr".cast("integer").as("birthyear"),
        $"PasientKjonn".cast("integer").as("gender"),
        $"PasientBostedFylkeNr".cast("integer").as("fylke_id"),
        // The name column stays a string: casting it to integer would turn every
        // non-numeric name into null.
        $"PasientBostedFylkeNavn".as("fylke_name"),
        $"PasientDodsAr".cast("integer").as("year_of_death"),
        $"PasientDodsMnd".cast("integer").as("month_of_death"),
        $"ATCKode".as("drugcode"),
        $"ATCKodeDDDVerdi".as("DDD_value"),
        $"ATCKodeDDDEnhet".as("DDD_unit"),
        $"VareNavn",
        // Year and date strings are parsed with their patterns and turned into timestamps.
        unix_timestamp($"UtleveringsAar", "yyyy").cast(TimestampType).as("prescription_year"),
        unix_timestamp($"UtleveringsDato", "yyyy.MM.dd").cast(TimestampType).as("timestamp"),
        $"Diff_UtleveringDato".cast("integer"),
        $"ForskriverLopeNr".as("prescriber_id"),
        $"ForskriverFodtAr".cast("integer").as("prescriber_birthyear"),
        $"ForskriverKjonn".cast("integer").as("prescriber_gender"),
        $"ForskriverUtenID".cast("integer").as("prescriber_no_id"),
        $"Hjemmel",
        $"Hjemmelnr".cast("integer"),
        $"Kategori",
        $"KategoriNr".cast("integer"),
        $"OrdinasjonAntallPakninger".cast("float"),
        $"OrdinasjonAntallDDD".cast("float")
    )
    // Co-locate each patient's rows and order them in time within the partition.
    .repartition($"id")
    .sortWithinPartitions($"id", $"timestamp", $"Diff_UtleveringDato")
//before
rawElders.printSchema
//after
sanitized.printSchema

### The script below verifies that prescription dates are represented in two different ways
___ 
Look at the following table

```
+-------------+--------------+-------------------+-------------------+
|year_of_death|month_of_death|          timestamp|Diff_UtleveringDato|
+-------------+--------------+-------------------+-------------------+
|         2015|            12|2013-01-23 00:00:00|               null|
|         2016|            11|2014-10-20 00:00:00|               null|
|         2015|            11|2014-10-06 00:00:00|               null|
|         2015|             5|2013-01-02 00:00:00|               null|
|         2016|             7|               null|                396|
+-------------+--------------+-------------------+-------------------+
```

* Observe that there is an XOR relation between timestamp and Diff_UtleveringDato

This may be verified by setting execExtras to true and running the next cell

In [None]:
// Show fields using timestamp vs Diff_UtleveringDato, restricted to rows with a year of death,
// and report how the two date representations are distributed.
if (execExtras) {
    val timestamp_cmp_df = sanitized
        .select($"year_of_death", $"month_of_death", $"timestamp", $"Diff_UtleveringDato")
        .where("year_of_death is not null")
    timestamp_cmp_df.show(5)

    val timestamp_not_null = timestamp_cmp_df.where("timestamp is not null").count
    val timestamp_null = timestamp_cmp_df.where("timestamp is null").count
    // Rows that break the XOR relation: both representations present, or both missing.
    val xor_violations = timestamp_cmp_df
        .where($"timestamp".isNull === $"Diff_UtleveringDato".isNull)
        .count

    // Print the results; previously they were computed and silently dropped.
    println(s"rows with a timestamp: $timestamp_not_null")
    println(s"rows without a timestamp: $timestamp_null")
    println(s"rows violating the XOR relation: $xor_violations")
} else {
    println("not executed")
}

## Parse the death timestamp

In [None]:
%run "src/scala/udf_elders.scala"

In [None]:


// parse_death_timestamp comes from the script loaded in the previous cell; the cells below
// read a death_timestamp column from its result.
val elders_processed = parse_death_timestamp(sanitized)
//sanitized.printSchema
//elders_processed.printSchema

// Preview a few patients that have a death timestamp.
elders_processed
    .select($"id", $"death_timestamp")
    .filter($"death_timestamp".isNotNull)
    .show(5)

// Row counts with and without a death timestamp.
elders_processed.filter($"death_timestamp".isNotNull).count()
elders_processed.filter($"death_timestamp".isNull).count()

# Transform the dataset into multiple DataFrames
## Raw Dataframe, Also parse the death timestamp

## Prescription Table

In [None]:

// Columns that make up the prescription table, in output order.
val prescriptionColumns = Seq(
    "id",
    "birthyear",
    "gender",
    "drugcode",
    "DDD_value",
    "DDD_unit",
    "prescription_year",
    "timestamp",
    "Diff_UtleveringDato",
    "prescriber_id",
    "OrdinasjonAntallPakninger",
    "OrdinasjonAntallDDD"
)
// One row per dispensed prescription, restricted to the columns above.
val prescriptions = elders_processed.select(prescriptionColumns.map(col): _*)

// Optional preview.
if (execExtras) prescriptions.show(5)

## Patients Table
#### Will need to figure out how to properly concatenate dates (year+month of death) in this format


In [None]:
// One row per distinct (id, birthyear, gender, death_timestamp) combination.
val patients = elders_processed
    .select("id", "birthyear", "gender", "death_timestamp")
    .distinct()

// Register the table under a name that SQL queries (also from the %%python cells) can use.
patients.createOrReplaceTempView("hosp_patients")


In [None]:
// Split the patients view into living and deceased patients.
// The views are registered unconditionally (registering a view runs no Spark job) because the
// Python plotting cells further down query hosp_patients_live / hosp_patients_dead.
// The patients table has no year_of_death column, so the split uses death_timestamp.
// NOTE(review): assumes parse_death_timestamp leaves death_timestamp null when no death is recorded - confirm.
spark.sql("select * from hosp_patients")
    .where("death_timestamp is null")
    .createOrReplaceTempView("hosp_patients_live")
spark.sql("select * from hosp_patients")
    .where("death_timestamp is not null")
    .createOrReplaceTempView("hosp_patients_dead")

if (execExtras) {
    val live_patients_count = spark.sql("select * from hosp_patients_live").count
    val dead_patients_count = spark.sql("select * from hosp_patients_dead").count
    val total_patients_count = patients.count
    // Expected to be 0: every patient is either alive or dead.
    total_patients_count - live_patients_count - dead_patients_count
}

## Drug table

In [None]:
// Number of rows per (ATC code, product name) pair for 7-character ATC codes, ordered by ATC code.
val drugs = elders_processed
    .filter(length($"drugcode") === 7)
    .groupBy("drugcode", "VareNavn")
    .count()
    .orderBy("drugcode")

// Optional preview without truncating long product names.
if (execExtras) drugs.show(5, false)

# Save all tables

In [None]:
// Write the complete processed dataset (all variables) as Parquet under basedir,
// replacing the output of any earlier run.
elders_processed.write.mode("overwrite").parquet(s"${basedir}all")
println("ok")

In [None]:
// Write the patients table (a subset of the variables) as Parquet under basedir,
// replacing the output of any earlier run.
patients.write.mode("overwrite").parquet(s"${basedir}patients")
println("ok")

In [None]:
// Write the prescription table as Parquet under basedir, replacing the output of any earlier run.
prescriptions.write.mode("overwrite").parquet(s"${basedir}prescriptions")
println("ok")

In [None]:
// Write the drugs table as Parquet under basedir, replacing the output of any earlier run.
drugs.write.mode("overwrite").parquet(s"${basedir}drugs")
println("ok")

___



### Which genders do the PasientKjonn (gender) values correspond to?

PasientKjonn values: Male=>1, Female=>2, as verified below


In [None]:
// Which genders do the PasientKjonn values correspond to?
// Uses the `sanitized` DataFrame built earlier in this notebook; the previous name
// (sanitizedHospitalization) is not defined anywhere in it.
val atc_with_gender = sanitized.select($"drugcode", $"gender")

// Prescriptions per gender code for sex hormones (ATC codes starting with G03) ...
atc_with_gender.where($"drugcode".startsWith("G03")).groupBy("gender").count.show
// ... and for androgens specifically (ATC codes starting with G03B).
atc_with_gender.where($"drugcode".startsWith("G03B")).groupBy("gender").count.show


## Product variants of the same drug(ATC code)
* not counting dosage or product number, but rather only the product name

In [None]:
// Number of distinct product names per ATC code.
// Uses the `sanitized` DataFrame built earlier in this notebook; the previous name
// (sanitizedHospitalization) is not defined anywhere in it.
// After groupBy("drugcode") every ATC code occurs once, so no trailing distinct is needed;
// the old one also discarded the ordering established just before it.
val drug_variants_with_count = sanitized
    .select($"drugcode", $"VareNavn")
    .distinct
    .groupBy("drugcode")
    .count
    .sort("drugcode")

// Expose the counts, largest first, to the %%python cells.
drug_variants_with_count.sort(desc("count")).createOrReplaceTempView("n_drug_variants")

In [None]:
// Average number of distinct product names per ATC code in the raw data.
// Reads `rawElders`, the raw DataFrame loaded at the top of this notebook; the previous
// name (rawHospitalization) is not defined anywhere in it.
val count_distinct_atc_with_product_name = rawElders.select($"ATCKode", $"VareNavn").distinct.count
val count_distinct_atc_codes = rawElders.select($"ATCKode").distinct.count
println("mean number of distinct products per drug ATC code:")
println(count_distinct_atc_with_product_name.toFloat / count_distinct_atc_codes.toFloat)


### Let's plot some distributions with python

In [None]:
%%python
# Evaluate the `spark` name on the Python side so the cell shows the session object
# (presumably a check that the Python interpreter shares the Scala cells' session - confirm).
spark

In [None]:
%%python
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import Image
import tempfile

In [None]:
%%python
#magic function for plotting, since the spylon-kernel needs tempfiles to display images
def plotfig_magic():
    """Save the current matplotlib figure to a temporary PNG and return it as an IPython Image.

    Image(filename=...) loads the file contents when the object is created, so the
    temporary file is deleted again before returning. Previously the file was recreated
    by name after NamedTemporaryFile had already removed it, and was then never cleaned up.
    """
    import os  # local import: only needed for the temp-file handling below

    # mkstemp creates the file and hands back an open descriptor plus its path.
    fd, path = tempfile.mkstemp(suffix=".png")
    os.close(fd)  # savefig opens the path itself
    try:
        plt.savefig(path)
        return Image(filename=path)
    finally:
        os.remove(path)

In [None]:
%%python
# List the tables/views visible from Python, then pull the per-ATC-code product counts into pandas.
spark.sql("show tables").show()
variants_view = spark.sql("select * from n_drug_variants")
count_drug_variants_df = variants_view.toPandas()
# First 20 rows only.
print(count_drug_variants_df.head(20))

In [None]:
// List the registered tables/views, then show every distinct product name recorded for
// ATC code N05AH03 (up to 68 rows, untruncated).
// Uses the `sanitized` DataFrame built earlier in this notebook; the previous name
// (sanitizedHospitalization) is not defined anywhere in it.
spark.sql("show tables").show
sanitized
    .select($"drugcode", $"VareNavn")
    .where($"drugcode" === "N05AH03")
    .distinct
    .show(68, false)

## Variants of the same drug (same ATC code)

In [None]:
%%python
# Line plot of the number of product names per ATC code, in the order of the DataFrame.
plt.clf()
variant_plot_args = dict(
    x="drugcode",
    y="count",
    label="Number of product names per ATC code",
)
count_drug_variants_df.plot(**variant_plot_args)
plt.title("drugs, ordered by highers to lowest number of products per ATC code")
retval = plotfig_magic()

## Birth year distribution of live and dead patients

In [None]:
%%python

#spark.sql("show tables").show()
def _birthyear_counts(view_name):
    """Return the number of patients per birth year in the given temp view, ordered by birth year, as a pandas DataFrame."""
    patients_df = spark.sql("select * from " + view_name)
    return patients_df.groupBy("birthyear").count().sort("birthyear").toPandas()

totalDF = _birthyear_counts("hosp_patients")
liveDF = _birthyear_counts("hosp_patients_live")
deadDF = _birthyear_counts("hosp_patients_dead")
print("ok")

In [None]:
%%python

# Overlay the per-birth-year counts of living and deceased patients on one axis,
# using the first 50 / 46 rows of the respective tables.
plt.clf()
live_rows = liveDF.head(50)
dead_rows = deadDF.head(46)
ax = live_rows.plot(x="birthyear", y="count", label="Live patients")
dead_rows.plot(ax=ax, x="birthyear", y="count", label="Dead patients")
plt.title("Number of live and dead hospitalized patients")
retval = plotfig_magic()

### Zoom in on 1900-1930

In [None]:
%%python

# Same overlay as the previous plot, limited to the first 24 / 21 rows of the tables.
plt.clf()
live_rows = liveDF.head(24)
dead_rows = deadDF.head(21)
ax = live_rows.plot(x="birthyear", y="count", label="Live patients")
dead_rows.plot(ax=ax, x="birthyear", y="count", label="Dead patients")
plt.title("Number of live and dead hospitalized patients")
retval = plotfig_magic()

## Death rate by age

In [None]:
%%python

# Combine the living, deceased and total patient counts into one table keyed by birth year.
# The former set_index("birthyear") calls were removed: their results were discarded, so they
# never changed the frames, and the merges below join on the birthyear column anyway.
deadDFcopy = deadDF.copy()
deadDFcopy.columns = ["birthyear", "count_deceased"]

liveDFcopy = liveDF.copy()

totalDFcopy = totalDF.copy().rename(columns={"count": "total"})

# Outer joins keep birth years that are missing from one of the inputs; their gaps become 0.
distribution_totalDF = (
    liveDFcopy.merge(deadDFcopy, how="outer", on="birthyear")
    .merge(totalDFcopy, how="outer", on="birthyear")
    .fillna(0.0)
    .sort_values("birthyear")
)

#distribution_totalDF["total"] = distribution_totalDF["count"] +  distribution_totalDF["count_deceased"]
print(distribution_totalDF[40:50])
print(distribution_totalDF["total"].sum())

In [None]:
%%python
# Share of deceased patients per birth year: count_deceased / total, stored in the "count" column.
plt.clf()
death_share = distribution_totalDF["count_deceased"].div(distribution_totalDF["total"])
ratesDF = distribution_totalDF.copy()
ratesDF["count"] = death_share
#print(ratesDF)
#print(rates)
ratesDF.plot(x="birthyear", y="count")
plt.title("")
retval = plotfig_magic()

In [None]:
// Print the schema of the raw data. `rawElders` is the raw DataFrame loaded at the top of
// this notebook; the previous name (rawHospitalization) is not defined anywhere in it.
rawElders.printSchema

In [None]:
// Show three distinct non-null Hjemmel texts from a 5% random sample of the raw data.
// `rawElders` is the raw DataFrame loaded at the top of this notebook; the previous name
// (rawHospitalization) is not defined anywhere in it.
rawElders.select("Hjemmel").sample(0.05).where("hjemmel is not null").distinct.show(3,false)

In [None]:
// Number of raw rows per Diff_UtleveringDato value (cast to integer), highest value first,
// registered as a temp view for the %%python cells.
// `rawElders` is the raw DataFrame loaded at the top of this notebook; the previous name
// (rawHospitalization) is not defined anywhere in it.
rawElders.select($"Diff_UtleveringDato".cast("integer"))
    .where("Diff_UtleveringDato is not null")
    .groupBy("Diff_UtleveringDato")
    .count
    .sort(desc("Diff_UtleveringDato"))
    .createOrReplaceTempView("drug_handout_diff")

In [None]:
%%python
# Pull the per-day counts from the drug_handout_diff temp view into pandas.
handout_view = spark.sql("select * from drug_handout_diff")
drug_handoutDF = handout_view.toPandas()

### Diff_UtleveringDato
* Day 0 is the first hospitalization for a given patient in 2013
* The following graph is aggregated across the whole population
* 

In [None]:
%%python
# Plot the row count for every Diff_UtleveringDato value.
plt.clf()

#print(drug_handoutDF)
handout_plot_args = dict(x="Diff_UtleveringDato", y="count")
ax = drug_handoutDF.plot(**handout_plot_args)
plt.title("TBD")

retval = plotfig_magic()

In [None]:
%%python
#print(drug_handoutDF)
# Same plot, restricted to rows 600-899 of the table.
handout_slice = drug_handoutDF.iloc[600:900]
ax = handout_slice.plot(x="Diff_UtleveringDato", y="count")
plt.title("TBD")

retval = plotfig_magic()

In [None]:
// Show Diff_UtleveringDato for the first 20 rows that have no timestamp.
// Uses the `sanitized` DataFrame built earlier in this notebook; the previous name
// (sanitizedHospitalization) is not defined anywhere in it.
sanitized.where("timestamp is null").select($"Diff_UtleveringDato").show(20)

In [None]:
// Most frequent ATC codes among the rows that have no timestamp.
// Uses the `sanitized` DataFrame built earlier in this notebook; the previous name
// (sanitizedHospitalization) is not defined anywhere in it.
val top_drugs_by_count = sanitized
    .where("timestamp is null")
    .select("drugcode")
    .groupBy("drugcode")
    .count
    .sort(desc("count"))

top_drugs_by_count.show()
top_drugs_by_count.take(5)

In [None]:
// Row count per Diff_UtleveringDato value for ATC code B01AC06, highest value first,
// registered as a temp view for the %%python cells.
// `rawElders` is the raw DataFrame loaded at the top of this notebook; the previous name
// (rawHospitalization) is not defined anywhere in it.
rawElders.where($"ATCKode" === "B01AC06")
    .select($"Diff_UtleveringDato".cast("integer"))
    .where("Diff_UtleveringDato is not null")
    .groupBy("Diff_UtleveringDato")
    .count
    .sort(desc("Diff_UtleveringDato"))
    .createOrReplaceTempView("drug_handout_diff_1")

In [None]:
%%python

# Load the drug_handout_diff_1 temp view into pandas and plot count against Diff_UtleveringDato.
handout_view_1 = spark.sql("select * from drug_handout_diff_1")
drug_handoutDF1 = handout_view_1.toPandas()

#print(drug_handoutDF)
ax = drug_handoutDF1.plot(y="count", x="Diff_UtleveringDato")
plt.title("TBD")

retval = plotfig_magic()

## Distribution of prescriptions by birthyear

In [None]:
// Row count per Diff_UtleveringDato value for ATC code C07AB02, highest value first,
// registered as a temp view for the %%python cells.
// `rawElders` is the raw DataFrame loaded at the top of this notebook; the previous name
// (rawHospitalization) is not defined anywhere in it.
rawElders.where($"ATCKode" === "C07AB02")
    .select($"Diff_UtleveringDato".cast("integer"))
    .where("Diff_UtleveringDato is not null")
    .groupBy("Diff_UtleveringDato")
    .count
    .sort(desc("Diff_UtleveringDato"))
    .createOrReplaceTempView("drug_handout_diff_2")



In [None]:
%%python

# Load the drug_handout_diff_2 temp view into pandas and plot count against Diff_UtleveringDato.
handout_view_2 = spark.sql("select * from drug_handout_diff_2")
drug_handoutDF2 = handout_view_2.toPandas()

#print(drug_handoutDF)
ax = drug_handoutDF2.plot(y="count", x="Diff_UtleveringDato")
plt.title("TBD")

retval = plotfig_magic()