In [1]:
#-*- coding: utf-8 -*-

import pandas as pd
import numpy as np
import matplotlib as mpl
%matplotlib notebook
#import seaborn as sns

from sklearn.impute import SimpleImputer
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from sklearn import tree

import graphviz

from tqdm import tqdm_notebook

# 그래프에서 마이너스 폰트 깨지는 문제에 대한 대처
mpl.rcParams['axes.unicode_minus'] = False

### multi class label 시도
 - HN16_ALL.sas7bdat : 설문지 데이터
 - hn16_ffq.sas7bdat : 섭취 데이터

In [2]:
def loadData() :

    '''
    데이터 로드
    '''
    df_all = pd.read_sas("./HN16_ALL.sas7bdat", format = 'sas7bdat', encoding='iso-8859-1')
    print("df_all.shape : ", df_all.shape)
    
    df_ffq = pd.read_sas("./hn16_ffq.sas7bdat", format = 'sas7bdat', encoding='iso-8859-1')
    print("df_ffq.shape : ", df_ffq.shape)
    
    '''
    공통 컬럼 추출 및 Merge key 설정
    '''
    intersection_columns = list(set(df_all.columns.values).intersection(df_ffq.columns.values))
    print("intersection_columns : ", intersection_columns)
    
    if(df_all.ID.name in intersection_columns) :
        intersection_columns.remove(df_all.ID.name)
    
    df_ffq.drop(intersection_columns, axis=1, inplace=True)
    print("intersection_columns -> df_ffq.shape : ", df_ffq.shape)
    
    df_merge = pd.merge(df_all, df_ffq, how='left', on=[df_all.ID.name])
    print("df_merge.shape : ", df_merge.shape)
    
    return df_merge

df_data = loadData()
df_data.head()

df_all.shape :  (8150, 768)
df_ffq.shape :  (8150, 434)
intersection_columns :  ['year', 'ho_incm', 'psu', 'incm', 'apt_t', 'wt_pft', 'ID', 'kstrata', 'age', 'age_month', 'wt_ntr', 'wt_hmnt', 'wt_pfhmnt', 'wt_pfnt', 'wt_itvex', 'wt_pfhm', 'wt_hm', 'wt_tot', 'mod_d', 'occp', 'ID_fam', 'sex', 'region', 'town_t', 'edu', 'wt_hs']
intersection_columns -> df_ffq.shape :  (8150, 409)
df_merge.shape :  (8150, 1176)


Unnamed: 0,mod_d,ID,ID_fam,year,region,town_t,apt_t,psu,sex,age,...,FQ_FE,FQ_NA,FQ_K,FQ_VA,FQ_RETIN,FQ_CAROT,FQ_B1,FQ_B2,FQ_NIAC,FQ_VITC
0,2018.02.01.,A651172801,A6511728,2016.0,1.0,1.0,2.0,A651,1.0,61.0,...,,,,,,,,,,
1,2018.02.01.,A651172802,A6511728,2016.0,1.0,1.0,2.0,A651,2.0,57.0,...,,,,,,,,,,
2,2018.02.01.,A651183001,A6511830,2016.0,1.0,1.0,2.0,A651,2.0,41.0,...,12.879218,3671.647791,2226.426572,459.311661,84.990392,2158.129699,1.628015,1.13383,13.896055,79.150355
3,2018.02.01.,A651183002,A6511830,2016.0,1.0,1.0,2.0,A651,1.0,39.0,...,,,,,,,,,,
4,2018.02.01.,A651183003,A6511830,2016.0,1.0,1.0,2.0,A651,2.0,38.0,...,,,,,,,,,,


In [3]:
def filter(df_in) :
    
    
    '''
    데이터 필터링
    '''
    # 19세 이상
    df_in = df_in.loc[(18 < df_in.age), :]
    print("filter age : ", df_in.shape)
    
    # 혈압치료중인 대상 제외
    df_in = df_in.loc[(1 < df_in.DI1_pt) | (df_in.DI1_pt < 1), :]
    print("filter 혈압치료여부", df_in.shape)
    
    df_in = df_in.dropna(axis=1, how='all')
    print("dropna how=all : ", df_in.shape)
    
    # 고혈압 유병여부 NaN값 대상 삭제
    df_in = df_in.dropna(subset=[df_in.HE_HP.name])
    print("dropna how=subset=['HE_HP'] : ", df_in.shape)
    
    '''
    NaN 처리
    '''
    #imp = SimpleImputer(missing_values=pd.np.nan, strategy="most_frequent")
    imp = SimpleImputer(strategy="most_frequent")
    df_in = pd.DataFrame(imp.fit_transform(df_in), columns=df_in.columns)
    
    '''
    convert numeric
    '''
    df_in = df_in.convert_objects(convert_numeric=True)
    
    return df_in

df_filter = filter(df_data)
df_filter.head()

filter age :  (6382, 1176)
filter 혈압치료여부 (4727, 1176)
dropna how=all :  (4727, 1122)
dropna how=subset=['HE_HP'] :  (4708, 1122)


For all other conversions use the data-type specific converters pd.to_datetime, pd.to_timedelta and pd.to_numeric.


Unnamed: 0,mod_d,ID,ID_fam,year,region,town_t,apt_t,psu,sex,age,...,FQ_FE,FQ_NA,FQ_K,FQ_VA,FQ_RETIN,FQ_CAROT,FQ_B1,FQ_B2,FQ_NIAC,FQ_VITC
0,2018.02.01.,A651172802,A6511728,2016.0,1.0,1.0,2.0,A651,2.0,57.0,...,1.999242,279.399765,223.900051,29.07613,0.0,76.955805,0.184688,0.087371,1.782046,2.35699
1,2018.02.01.,A651183001,A6511830,2016.0,1.0,1.0,2.0,A651,2.0,41.0,...,12.879218,3671.647791,2226.426572,459.311661,84.990392,2158.129699,1.628015,1.13383,13.896055,79.150355
2,2018.02.01.,A651183002,A6511830,2016.0,1.0,1.0,2.0,A651,1.0,39.0,...,1.999242,279.399765,223.900051,29.07613,0.0,76.955805,0.184688,0.087371,1.782046,2.35699
3,2018.02.01.,A651183003,A6511830,2016.0,1.0,1.0,2.0,A651,2.0,38.0,...,1.999242,279.399765,223.900051,29.07613,0.0,76.955805,0.184688,0.087371,1.782046,2.35699
4,2018.02.01.,A651194902,A6511949,2016.0,1.0,1.0,2.0,A651,2.0,44.0,...,11.622384,3837.131006,2196.359827,585.326618,67.557719,3038.498438,1.467602,1.293351,12.230508,65.639278


In [4]:
# index 재생성
df_filter = df_filter.set_index(df_filter.ID.name).reset_index()
df_filter.head()

Unnamed: 0,ID,mod_d,ID_fam,year,region,town_t,apt_t,psu,sex,age,...,FQ_FE,FQ_NA,FQ_K,FQ_VA,FQ_RETIN,FQ_CAROT,FQ_B1,FQ_B2,FQ_NIAC,FQ_VITC
0,A651172802,2018.02.01.,A6511728,2016.0,1.0,1.0,2.0,A651,2.0,57.0,...,1.999242,279.399765,223.900051,29.07613,0.0,76.955805,0.184688,0.087371,1.782046,2.35699
1,A651183001,2018.02.01.,A6511830,2016.0,1.0,1.0,2.0,A651,2.0,41.0,...,12.879218,3671.647791,2226.426572,459.311661,84.990392,2158.129699,1.628015,1.13383,13.896055,79.150355
2,A651183002,2018.02.01.,A6511830,2016.0,1.0,1.0,2.0,A651,1.0,39.0,...,1.999242,279.399765,223.900051,29.07613,0.0,76.955805,0.184688,0.087371,1.782046,2.35699
3,A651183003,2018.02.01.,A6511830,2016.0,1.0,1.0,2.0,A651,2.0,38.0,...,1.999242,279.399765,223.900051,29.07613,0.0,76.955805,0.184688,0.087371,1.782046,2.35699
4,A651194902,2018.02.01.,A6511949,2016.0,1.0,1.0,2.0,A651,2.0,44.0,...,11.622384,3837.131006,2196.359827,585.326618,67.557719,3038.498438,1.467602,1.293351,12.230508,65.639278


In [5]:
'''
고혈압 유병여부
    1. 정상
    2. 고혈압전단계
    3. 고혈압
classification을 위해3. 고혈압을 2.로 변경하여 이산형으로 수정
'''

#df_data.loc[df_data.HE_HP == 3.0, 'HE_HP'] = 2.0
print("### HE_HP count###\n", pd.value_counts(df_filter.HE_HP.values, sort=False).sort_index()) # 고혈압 유병 여부

### HE_HP count###
 1.0    2640
2.0    1468
3.0     600
dtype: int64


In [6]:
def runDTR(in_df, x_name_list, y_name) :

    arrX = in_df[x_name_list]
    arrY = in_df[y_name]

    #print("x na : ", arrX1[x_name].isna().sum())
    #print("y na : ", arrY.isna().sum())
    #arrY = arrY.fillna(-1)

    model = DecisionTreeClassifier(
        criterion = 'gini',
        max_depth=4
    ).fit(arrX,arrY)
    
    pdf_name = "./pdf_multi_class/%s_feature-%s" % (y_name,x_name_list.__len__())
    print(pdf_name)

    dot_data = tree.export_graphviz(model,
                                    out_file=None, 
                                    feature_names=x_name_list, 
                                    # class_names=['정상', '고혈압전단계', '고혈압']
                                    class_names=['1','2','3']
                                   )

    graph = graphviz.Source(dot_data) 
    graph.render(pdf_name)
    #graph.render("%s+%s" % (x_name,y_name))

    modelPrediction = model.predict(arrX)
    accuracyRate = accuracy_score(
        y_true=arrY
    ,   y_pred=modelPrediction
    )

    #print('정확도 = ', accuracyRate)

    return accuracyRate

In [7]:
# HE_HP(고혈압 유병여부) 제외
skip_list = [
    df_filter.HE_HP.name
    #, df_data.ID.name
]

# accuracyRate 담을 변수 준비
accuracyRate_list = []
for y_name in skip_list :
    for x_name in tqdm_notebook(df_filter.columns):
        if x_name in skip_list : continue
        if df_filter[x_name].dtype == 'object' : continue
        #accuracyRate = runDTR(df_filter, [x_name], y_name)
        accuracyRate = None
        accuracyRate_list.append([x_name, y_name, accuracyRate])
    
# accuracyRate_list > dataframe으로 변환해서 csv파일로 저장   
df_accuracyRate = pd.DataFrame(accuracyRate_list)
df_accuracyRate.to_csv("accuracyRate_list_multi_class.csv")

HBox(children=(IntProgress(value=0, max=1122), HTML(value='')))




In [8]:
#pd.value_counts(df_filter[df_filter.HE_Uacid.name]).sort_index()
print(df_filter[df_filter.HE_HP.name].dtype)
print(df_filter[df_filter.HE_HP.name][:5])

print(df_filter[df_filter.HE_Uacid.name].dtype)
print(df_filter[df_filter.HE_Uacid.name][:5])

Y = df_filter.HE_HP.name
X = [
      df_filter.HE_Uacid.name
    , df_filter.ainc_1.name
    , df_filter.BS6_2.name
]

runDTR(df_filter, X, Y)

float64
0    2.0
1    1.0
2    3.0
3    1.0
4    1.0
Name: HE_HP, dtype: float64
float64
0    4.0
1    5.0
2    6.8
3    3.9
4    3.7
Name: HE_Uacid, dtype: float64
./pdf_multi_class/HE_HP_feature-3


0.5783772302463891

In [41]:
check_count = 0
for name in df_filter.columns :
    
    # 2 < count < 20
    value_counts = pd.value_counts(getattr(df_filter, name)).sort_index()
    key_count = len(value_counts.keys())
    if(key_count < 3 or 20 < key_count) : continue
    
    # int 변환이 가능한 것
    try :
      [int(k) for k in value_counts.keys()]
    except : 
        #print("%s pass" % name)
        continue
        
    print("\n", value_counts)
    check_count += 1

check_count    


 1.0      930
2.0      311
3.0      234
4.0      260
5.0      158
6.0      156
7.0       87
8.0      109
9.0     1113
10.0     126
11.0     141
12.0     178
13.0     154
14.0     156
15.0     232
16.0     269
17.0      94
Name: region, dtype: int64

 1.0    1154
2.0    1196
3.0    1178
4.0    1180
Name: incm, dtype: int64

 1.0     683
2.0    1136
3.0    1409
4.0    1480
Name: ho_incm, dtype: int64

 1.0     668
2.0     393
3.0    1531
4.0    2116
Name: edu, dtype: int64

 1.0     675
2.0     550
3.0     627
4.0     145
5.0     464
6.0     348
7.0    1899
Name: occp, dtype: int64

 1.0     444
2.0    1127
3.0    1282
4.0    1367
5.0     386
6.0     102
Name: cfam, dtype: int64

 1.0     444
2.0     834
3.0      66
4.0    2352
5.0     393
6.0     251
7.0     368
Name: genertn, dtype: int64

 10.0     264
20.0    4441
99.0       3
Name: allownc, dtype: int64

 1.0    1489
2.0    2531
3.0     686
9.0       2
Name: house, dtype: int64

 1.0    1508
2.0    2751
3.0     238
4.0     176
5.0 

Name: DL1_pt, dtype: int64

 0.0    3789
1.0     707
9.0     212
Name: DJ8_dg, dtype: int64

 0.0      93
1.0     614
8.0    3789
9.0     212
Name: DJ8_pr, dtype: int64

 0.0     483
1.0     224
8.0    3789
9.0     212
Name: DJ8_pt, dtype: int64

 0.0    4201
1.0     295
9.0     212
Name: DJ6_dg, dtype: int64

 0.0     174
1.0     121
8.0    4201
9.0     212
Name: DJ6_pr, dtype: int64

 0.0     258
1.0      37
8.0    4201
9.0     212
Name: DJ6_pt, dtype: int64

 0.0    4247
1.0     248
9.0     213
Name: DH4_dg, dtype: int64

 0.0     209
1.0      39
8.0    4247
9.0     213
Name: DH4_pr, dtype: int64

 0.0     232
1.0      16
8.0    4247
9.0     213
Name: DH4_pt, dtype: int64

 0.0    4155
1.0     340
9.0     213
Name: DH2_dg, dtype: int64

 0.0     181
1.0     159
8.0    4155
9.0     213
Name: DH2_pr, dtype: int64

 0.0     249
1.0      91
8.0    4155
9.0     213
Name: DH2_pt, dtype: int64

 0.0    4437
1.0      58
9.0     213
Name: DH3_dg, dtype: int64

 0.0      12
1.0      46
8.0   

Name: BP7, dtype: int64

 1.0     106
2.0    1707
3.0    2820
9.0      75
Name: BS1_1, dtype: int64

 1.0     774
2.0     148
3.0     891
8.0    2820
9.0      75
Name: BS3_1, dtype: int64

 1.0       11
2.0        9
3.0       15
4.0        5
5.0       12
7.0        6
8.0        1
10.0      18
12.0       1
15.0      25
18.0       1
20.0      37
22.0       1
25.0       1
28.0       1
30.0       4
88.0    4485
99.0      75
Name: BS3_3, dtype: int64

 0.0      696
1.0       48
2.0       36
3.0       23
4.0        5
5.0       21
6.0       26
7.0        6
8.0        7
9.0        5
10.0      15
11.0       3
88.0    3742
99.0      75
Name: BS6_2_2, dtype: int64

 0.0      649
1.0       43
2.0       26
3.0       28
4.0       12
5.0       26
6.0       46
7.0        9
8.0       14
9.0        9
10.0      24
11.0       5
88.0    3742
99.0      75
Name: BS6_4_2, dtype: int64

 1.0     207
2.0     288
3.0     141
4.0     286
8.0    3711
9.0      75
Name: BS5_4, dtype: int64

 1.0     540
2.0     382


Name: FF_RICE, dtype: int64

 1.0      882
2.0     3075
3.0       92
4.0       28
88.0     182
99.0     449
Name: FA_RICE, dtype: int64

 1.0      316
2.0       95
3.0      152
4.0      139
5.0      268
6.0       93
7.0      666
8.0     2259
9.0      271
99.0     449
Name: FF_BARLEY, dtype: int64

 1.0      954
2.0     2897
3.0       77
4.0       15
88.0     316
99.0     449
Name: FA_BARLEY, dtype: int64

 1.0      779
2.0      578
3.0     2167
4.0      500
5.0      228
6.0        5
7.0        2
99.0     449
Name: FF_BIBIM, dtype: int64

 1.0      300
2.0     3024
3.0      156
88.0     779
99.0     449
Name: FA_BIBIM, dtype: int64

 1.0     2382
2.0      759
3.0      678
4.0      299
5.0      129
6.0        7
7.0        3
8.0        2
99.0     449
Name: FF_GIMBAB, dtype: int64

 1.0       96
2.0     2634
3.0      228
4.0      304
88.0     997
99.0     449
Name: FA_GIMBAB, dtype: int64

 1.0     2804
2.0      870
3.0      483
4.0       85
5.0       16
7.0        1
99.0     449
Name: FF_

Name: FF_S_EGG, dtype: int64

 1.0       90
2.0     2567
3.0      745
88.0     857
99.0     449
Name: FA_S_EGG, dtype: int64

 1.0      425
2.0      695
3.0     2322
4.0      563
5.0      246
6.0        8
99.0     449
Name: FF_R_PORK, dtype: int64

 1.0      374
2.0     2874
3.0      586
88.0     425
99.0     449
Name: FA_R_PORK, dtype: int64

 1.0     2812
2.0      969
3.0      331
4.0      112
5.0       30
6.0        3
7.0        2
99.0     449
Name: FF_S_PORK, dtype: int64

 1.0      295
2.0      746
3.0      406
88.0    2812
99.0     449
Name: FA_S_PORK, dtype: int64

 1.0      711
2.0     2362
3.0      741
4.0      318
5.0      120
6.0        7
99.0     449
Name: FF_F_PORK, dtype: int64

 1.0      505
2.0     2532
3.0      511
88.0     711
99.0     449
Name: FA_F_PORK, dtype: int64

 1.0     2560
2.0      895
3.0      530
4.0      220
5.0       52
6.0        2
99.0     449
Name: FF_C_PORK, dtype: int64

 1.0      326
2.0      905
3.0      468
88.0    2560
99.0     449
Name: FA_C_P

Name: FA_SVEG, dtype: int64

 1.0     3107
2.0      497
3.0      361
4.0      177
5.0       88
6.0       20
7.0        7
8.0        2
99.0     449
Name: FF_ROOT, dtype: int64

 1.0      288
2.0      648
3.0      216
88.0    3107
99.0     449
Name: FA_ROOT, dtype: int64

 1.0     2437
2.0      964
3.0      580
4.0      207
5.0       70
6.0        1
99.0     449
Name: FF_PCAKE, dtype: int64

 1.0      417
2.0      545
3.0      860
88.0    2437
99.0     449
Name: FA_PCAKE, dtype: int64

 1.0     3093
2.0      842
3.0      267
4.0       50
5.0        7
99.0     449
Name: FF_F_VEG, dtype: int64

 1.0       74
2.0      383
3.0      709
88.0    3093
99.0     449
Name: FA_F_VEG, dtype: int64

 1.0     2392
2.0      714
3.0      614
4.0      350
5.0      170
6.0       12
7.0        5
8.0        1
9.0        1
99.0     449
Name: FF_MUSHRO, dtype: int64

 1.0      126
2.0      749
3.0      992
88.0    2392
99.0     449
Name: FA_MUSHRO, dtype: int64

 1.0      328
2.0      303
3.0      512
4.0    

Name: FF_TANG_YR, dtype: int64

 1.0      270
2.0      864
3.0     2955
88.0     170
99.0     449
Name: FA_TANG, dtype: int64

 1.0      192
2.0     4067
99.0     449
Name: FS_BANANA, dtype: int64

 1.0      595
2.0      533
3.0     2027
4.0      547
5.0      425
6.0       72
7.0       58
8.0        2
99.0     449
Name: FF_BANANA, dtype: int64

 1.0     2513
2.0       47
3.0      640
4.0      532
5.0      402
6.0       70
7.0       53
8.0        2
99.0     449
Name: FF_BANANA_YR, dtype: int64

 1.0       96
2.0     3168
3.0      400
88.0     595
99.0     449
Name: FA_BANANA, dtype: int64

 1.0     1160
2.0     3099
99.0     449
Name: FS_CITRUS, dtype: int64

 1.0     2363
2.0      587
3.0      486
4.0      392
5.0      331
6.0       46
7.0       46
8.0        6
9.0        2
99.0     449
Name: FF_CITRUS, dtype: int64

 1.0     3472
3.0      496
4.0      201
5.0       84
6.0        5
7.0        1
99.0     449
Name: FF_CITRUS_YR, dtype: int64

 1.0      345
2.0     2687
3.0      249
88.0 

Name: FQ_OKIMCH, dtype: int64

 0.000000     2841
0.232558      418
0.581395      495
1.000000      398
3.000000      363
5.500000       66
7.000000      100
14.000000      19
21.000000       8
Name: FQ_SVEG, dtype: int64

 0.000000     3556
0.232558      497
0.581395      361
1.000000      177
3.000000       88
5.500000       20
7.000000        7
14.000000       2
Name: FQ_ROOT, dtype: int64

 0.000000    2886
0.232558     964
0.581395     580
1.000000     207
3.000000      70
5.500000       1
Name: FQ_PCAKE, dtype: int64

 0.000000    3542
0.232558     842
0.581395     267
1.000000      50
3.000000       7
Name: FQ_F_VEG, dtype: int64

 0.000000     2841
0.232558      714
0.581395      614
1.000000      350
3.000000      170
5.500000       12
7.000000        5
14.000000       1
21.000000       1
Name: FQ_MUSHRO, dtype: int64

 0.000000      328
0.232558      303
0.581395      512
1.000000      537
3.000000     2656
5.500000      210
7.000000      125
14.000000      31
21.000000      

Name: FQ_CHOCO, dtype: int64

 0.000000    2782
0.232558     585
0.581395     693
1.000000     408
3.000000     210
5.500000      24
7.000000       6
Name: FQ_ICECM, dtype: int64

 0.000000    3729
0.232558     443
0.581395     272
1.000000     122
3.000000      90
5.500000      24
7.000000      28
Name: FQ_PEANUT, dtype: int64

 0.000000    3911
0.232558     555
0.581395     187
1.000000      38
3.000000      10
5.500000       4
7.000000       3
Name: FQ_CHNUT, dtype: int64

 0.000000     3382
0.232558      343
0.581395      297
1.000000      221
3.000000      374
5.500000       52
7.000000       36
14.000000       2
21.000000       1
Name: FQ_SOJU, dtype: int64

 0.000000     3245
0.232558      372
0.581395      392
1.000000      301
3.000000      339
5.500000       39
7.000000       19
14.000000       1
Name: FQ_BEER, dtype: int64

 0.000000    4272
0.232558     218
0.581395     111
1.000000      53
3.000000      46
5.500000       4
7.000000       4
Name: FQ_RWINE, dtype: int64


820