## Load datasets

In [1]:
import numpy as np
import pandas as pd

oo2012 = pd.read_csv('Data/R301_OO/OO2012.csv', dtype=str)
oo2013 = pd.read_csv('Data/R301_OO/OO2013.csv', dtype=str)

oo = pd.concat([oo2012, oo2013])

In [2]:
cd = pd.read_csv('Data/CD.csv', dtype=str)
cd = cd[cd['APPL_DATE'] > '20120000']

In [3]:
print(oo.columns)
print(len(oo))

Index(['FEE_YM', 'APPL_TYPE', 'HOSP_ID', 'APPL_DATE', 'CASE_TYPE', 'SEQ_NO',
       'ORDER_TYPE', 'DRUG_NO', 'DRUG_USE', 'DRUG_FRE', 'UNIT_PRICE',
       'TOTAL_QTY', 'TOTAL_AMT', 'SEQ'],
      dtype='object')
7120120


In [4]:
print(cd.columns)
print(len(cd))

Index(['FEE_YM', 'APPL_TYPE', 'HOSP_ID', 'APPL_DATE', 'CASE_TYPE', 'SEQ_NO',
       'CURE_ITEM_NO1', 'CURE_ITEM_NO2', 'CURE_ITEM_NO3', 'CURE_ITEM_NO4',
       'FUNC_TYPE', 'FUNC_DATE', 'ID_BIRTHDAY', 'ID', 'ACODE_ICD9_1',
       'ACODE_ICD9_2', 'ACODE_ICD9_3', 'ICD_OP_CODE', 'DRUG_DAY', 'MED_TYPE',
       'ID_SEX'],
      dtype='object')
1305343


## Data Processing

In [5]:
pk = ['FEE_YM', 'APPL_TYPE', 'HOSP_ID', 'APPL_DATE', 'SEQ_NO', 'CASE_TYPE']
variables = ['ID_SEX', 'ID_BIRTHDAY', 'ACODE_ICD9_1', 'DRUG_NO', 'APPL_DATE']

In [6]:
# cd: disease
# oo: medicine
# Using left join to join 2 tables
oo_cd = pd.merge(oo, cd, on=pk, how='left')
print('Length of joined table:', len(oo_cd))

Length of joined table: 7120120


In [7]:
oo_cd = oo_cd.dropna(subset=['ACODE_ICD9_1'])
print('Table Length after droping NaN value:', len(oo_cd))

Table Length after droping NaN value: 7113238


In [8]:
# Only use attribute ['ID_SEX', 'ID_BIRTHDAY', 'ACODE_ICD9_1', 'DRUG_NO', 'APPL_DATE']
oo_cd = oo_cd[variables]

# Alternating attribute name
oo_cd.columns = ['G', 'A', 'D', 'M', 'AD']
oo_cd.head()

Unnamed: 0,G,A,D,M,AD
0,F,1976-04-01,786.5,00203B,20130105
1,F,1976-04-01,786.5,A001085277,20130105
2,F,1976-04-01,786.5,A040130100,20130105
3,F,1976-04-01,786.5,AC03355212,20130105
4,F,1976-04-01,786.5,AC14025100,20130105


In [9]:
from tqdm import tqdm

def get_age(birthday, appl_date):
    # Calculate the age from birthday to apply date
    # [:4] -> return the 4 first digits as the year
    return int(appl_date[:4]) - int(birthday[:4])

ages = oo_cd['A'].values
appl_dates = oo_cd['AD'].values
ages_range = []

for i in tqdm(range(len(ages))):
    if type(ages[i]) == str:
        age = get_age(ages[i], appl_dates[i])
        age_text = '>65'
        if age < 18:
            age_text = '<18'
        elif age <= 30:
            age_text = '18<=30'
        elif age <= 50:
            age_text = '31<=50'
        elif age <= 65:
            age_text = '51<=65'
        else:
            age_text = '>65'

        ages_range.append(age)

oo_cd['A'] = ages_range

100%|██████████| 7113238/7113238 [00:16<00:00, 419360.13it/s]


In [10]:
# remove AD columns
oo_cd = oo_cd.drop(columns='AD')
oo_cd.head()

Unnamed: 0,G,A,D,M
0,F,37,786.5,00203B
1,F,37,786.5,A001085277
2,F,37,786.5,A040130100
3,F,37,786.5,AC03355212
4,F,37,786.5,AC14025100


In [11]:
# Read ICD data then create a mapping: disease code -> disease name (english, chinese)
icd9 = pd.read_csv('Data/ICD9.csv', dtype=str)
icd9_mapping = {}
for row in icd9.values:
    icd9_mapping[row[2]] = [row[0], row[1]]

icd_org = pd.read_csv('Data/ICD_original.csv', dtype=str)

for row in icd_org.values:
    icd9_mapping[row[0]] = [row[1], row[2]]

In [12]:
# Read DRUG data then create a mapping: medicine code -> medicine name
drug = pd.read_csv('Data/DRUG.csv', dtype=str)
drug_mapping = {}
for row in drug.values:
    drug_mapping[row[1]] = row[3]

## Model

In [13]:
from pgmpy.estimators import MaximumLikelihoodEstimator
from pgmpy.models import BayesianModel
from pgmpy.inference import VariableElimination

model = BayesianModel([('A', 'D'),
                       ('G', 'D')])

print('Learning CPD using Maximum likelihood estimators')
model.fit(oo_cd, estimator=MaximumLikelihoodEstimator)

print('Inferencing with Bayesian Network')
infer = VariableElimination(model)

Learning CPD using Maximum likelihood estimators
Inferencing with Bayesian Network


## Phase #1: Disease Inference

In [92]:
from prettytable import PrettyTable

def disease_inference(evidence, K):
    ans = infer.query(['D'], evidence=evidence, show_progress=False)
    d = []
    
    for k, v in zip(ans.state_names['D'], ans.values):
        d.append([v, k])
        if str(k) not in icd9_mapping:
            icd9_mapping[str(k)] = ['', '']

    diseases = np.array(sorted(d, reverse=True))[:K]

    cond = ''
    for k, v in evidence.items():
        cond += f"{k}='{v}'"

    table = PrettyTable(['No.', 'Code', 'Disease English Name', 'Prob (%)'])
    table.align["Code"] = "l"
    table.align["Disease English Name"] = "l"
    
    print(f"P(D|{cond}):")
    for idx, (p, disease) in enumerate(diseases):
        table.add_row([idx+1, disease, icd9_mapping[str(disease)][0], f"{float(p)*100:.2f}"])
    print(table)
    return diseases


#### P(D|A=Age)

In [93]:
K = 10
evidence = {'A': 10}
diseases = disease_inference(evidence, K)


P(D|A='10'):
+-----+-------+---------------------------------------------------+----------+
| No. | Code  | Disease English Name                              | Prob (%) |
+-----+-------+---------------------------------------------------+----------+
|  1  | 465.9 | Acute upper respiretory infections of unspecified |  12.92   |
|  2  | 466.0 | Acute bronchitis                                  |   5.97   |
|  3  | 461.9 | Acute sinusitis, unspecified                      |   5.44   |
|  4  | 460.  |                                                   |   5.08   |
|  5  | 463.  |                                                   |   4.17   |
|  6  | 367.1 | Myopia                                            |   3.53   |
|  7  | 462.  |                                                   |   2.93   |
|  8  | 786.2 | Cough                                             |   2.83   |
|  9  | 521.0 | Dental caries                                     |   2.80   |
|  10 | 477.9 | Allergic rhinitis cause

#### P(D|G=Gender)

In [90]:
K = 10
evidence = {'G': 'M'}
diseases = disease_inference(evidence, K)

P(D|G='M'):
+-----+--------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------+
| No. | Code   | Disease English Name                                                                                                                                                    | Prob (%) |
+-----+--------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------+
|  1  | 465.9  | Acute upper respiretory infections of unspecified                                                                                                                       |   6.23   |
|  2  | 401.9  | Essential hypertension, unspecified                                                                                                                                     |   3.06   

#### P(D|A=Age, G=Gender)

In [91]:
K = 10
evidence = {'A': 10, 'G': 'M'}
diseases = disease_inference(evidence, K)

P(D|A='10'G='M'):
+-----+--------+-------------------------------------------------------+----------+
| No. | Code   | Disease English Name                                  | Prob (%) |
+-----+--------+-------------------------------------------------------+----------+
|  1  | 465.9  | Acute upper respiretory infections of unspecified     |  12.51   |
|  2  | 466.0  | Acute bronchitis                                      |   6.14   |
|  3  | 461.9  | Acute sinusitis, unspecified                          |   5.73   |
|  4  | 460.   |                                                       |   5.68   |
|  5  | 204.00 | Acute lymphoid leukemia, without mention of remission |   3.90   |
|  6  | 463.   |                                                       |   3.84   |
|  7  | 477.9  | Allergic rhinitis cause unspecified                   |   3.07   |
|  8  | 521.0  | Dental caries                                         |   2.83   |
|  9  | 786.2  | Cough                                    

## Phase #2: Medicine Inference

#### Model

In [67]:
evidence = {'A': 10, 'G': 'M'}
filtered_data = oo_cd[oo_cd['A']==evidence['A']][oo_cd['G']==evidence['G']]

model2 = BayesianModel([('D', 'M')])

print('Learning CPD using Maximum likelihood estimators')
model2.fit(filtered_data, estimator=MaximumLikelihoodEstimator)

print('Inferencing with Bayesian Network')
infer_n = VariableElimination(model2)

Learning CPD using Maximum likelihood estimators
Inferencing with Bayesian Network


  


In [96]:
for idx, (p, disease) in enumerate(diseases):
    if float(p) > 0:
        meds = infer_n.query(['M'], evidence={'D': disease}, show_progress=False)
        m = []
        for k, v in zip(meds.state_names['M'], meds.values):
            m.append([v, k])
        medicines = np.array(sorted(m, reverse=True)[:10])
        print(f'{idx+1:2d}. {disease} : {float(p)*100:.2f}%')
        table = PrettyTable(['No.', 'Code', 'Prob (%)'])
        table.align["Code"] = "l"
    
        for idx, (p2, med) in enumerate(medicines):
            if float(p2) > 0:
                table.add_row([idx+1, med, f"{float(p2)*100:.2f}"])
        print(table)
        print()

 1. 465.9 : 12.92%
+-----+------------+----------+
| No. | Code       | Prob (%) |
+-----+------------+----------+
|  1  | 05203C     |  10.14   |
|  2  | 00110C     |   7.13   |
|  3  | MA2        |   6.24   |
|  4  | MA1        |   4.72   |
|  5  | 00112C     |   2.13   |
|  6  | 00109C     |   1.69   |
|  7  | 00111C     |   1.54   |
|  8  | A040130100 |   1.12   |
|  9  | AC48934151 |   0.93   |
|  10 | A032492100 |   0.84   |
+-----+------------+----------+

 2. 466.0 : 5.97%
+-----+------------+----------+
| No. | Code       | Prob (%) |
+-----+------------+----------+
|  1  | 05203C     |  10.31   |
|  2  | MA2        |   7.04   |
|  3  | 00110C     |   6.59   |
|  4  | MA1        |   3.72   |
|  5  | 00112C     |   2.29   |
|  6  | 00109C     |   1.32   |
|  7  | 00114C     |   1.20   |
|  8  | AC43362151 |   1.15   |
|  9  | 00224C     |   1.03   |
|  10 | AC48934151 |   0.86   |
+-----+------------+----------+

 3. 461.9 : 5.44%
+-----+------------+----------+
| No. | Code   