In [1]:
import keras
from keras.models import model_from_json
from keras.utils import np_utils

from sklearn.metrics import confusion_matrix
import pandas as pd
import numpy as np
import nilearn.image as nilimg

import os
import sys

from plotly.offline import iplot, init_notebook_mode, plot
import plotly.graph_objs as go
import plotly.figure_factory as ff

Using TensorFlow backend.


In [2]:
# Set up plotly offline rendering for the iplot() calls later in this notebook.
init_notebook_mode(connected=True)

In [3]:
# Input locations. Each can be overridden with an environment variable so the notebook
# can run on machines where the shares are mounted elsewhere; defaults are the lab paths.
directory = os.environ.get('GENDER_GM_NIFTI_DIR', "/Volumes/matlab_share/nifti_img/3DEGIR_SAG/merged")  # one sub-folder per patient
model_dir = os.environ.get('GENDER_GM_MODEL_DIR', "/Volumes/share/log/learning/gender_gm/result/10202229")  # model.json + weights.hdf5
csv_dir = os.environ.get('GENDER_GM_CSV_DIR', '/Volumes/matlab_share/csv')  # patient_list.csv + property.csv
# Volumes are zero-padded along their last axis to this many slices; larger volumes are rejected.
SLICE_NUM = 185

In [4]:
def encode_label(y):
    """One-hot encode labels after shifting them down by one.

    Presumably `y` holds 1-based codes (1 or 2); after subtracting 1 they are
    expanded into a two-column categorical matrix via Keras' to_categorical.
    """
    zero_based = y - 1
    return np_utils.to_categorical(zero_based, num_classes=2)

In [5]:
# Read the serialized Keras model architecture; the context manager closes the file promptly.
with open(os.path.join(model_dir, 'model.json')) as model_json_file:
    json_string = model_json_file.read()

In [6]:
# Rebuild the network architecture from the JSON description; weights are loaded separately.
model = model_from_json(json_string)

In [7]:
# Load the trained weights stored next to model.json in model_dir.
model.load_weights(os.path.join(model_dir, 'weights.hdf5'))

In [8]:
# Patient roster (no header row): image folder name, patient id, age, sex.
# pid is kept as a string so it matches the '0002' column of property.csv.
patient_columns = ['dir', 'pid', 'page', 'psex']
patients_list = pd.read_csv(
    os.path.join(csv_dir, 'patient_list.csv'),
    header=None,
    index_col=None,
    names=patient_columns,
    dtype={'pid': 'object'},
)

In [9]:
# Per-patient properties with a header row; column '0002' (patient id) is read as
# a string so ids keep their exact text form when compared with patients_list.pid.
properties_path = os.path.join(csv_dir, 'property.csv')
properties_list = pd.read_csv(
    properties_path,
    header=0,
    index_col=None,
    dtype={'0002': 'object'},
)

In [10]:
def _to_model_input(volume, slice_num):
    """Add a batch axis to a volume and zero-pad its last axis to `slice_num` slices.

    Parameters
    ----------
    volume : ndarray whose first three dimensions are kept (any further
        dimensions must have size 1, as required by the reshape).
    slice_num : int, target length of the last axis.

    Returns
    -------
    ndarray of shape (1, volume.shape[0], volume.shape[1], slice_num).

    Raises
    ------
    ValueError
        If the volume already has more than `slice_num` slices.
    """
    batch = np.reshape(volume, (1,) + tuple(volume.shape[:3]))
    missing = slice_num - batch.shape[3]
    if missing < 0:
        raise ValueError('Data Size Error: This Nifti file size is ' + str(batch.shape))
    return np.pad(batch, ((0, 0), (0, 0), (0, 0), (0, missing)), mode='constant')


prediction_rows = []  # (true label, predicted label), both shifted to 0-based

for _, patient_row in patients_list.iterrows():
    patient_dir = os.path.join(directory, patient_row['dir'])

    try:
        nifti_files = os.listdir(patient_dir)
    except FileNotFoundError:
        continue  # no image folder for this patient

    # Look up the recorded sex before the expensive smoothing step so that patients
    # missing from property.csv are skipped cheaply.
    sex_values = properties_list.loc[properties_list['0002'] == patient_row['pid'], '0007'].values
    if len(sex_values) == 0:
        continue
    personal_sex = sex_values[0]

    # sorted() makes the choice deterministic if several files match.
    gm_files = sorted(f for f in nifti_files if f.startswith('c12018') and f.endswith('.nii'))
    if not gm_files:
        sys.stderr.write('No c12018*.nii file in ' + patient_dir + '\n')
        continue

    smoothed = nilimg.smooth_img(os.path.join(patient_dir, gm_files[0]), fwhm='fast')
    model_input = _to_model_input(smoothed.get_fdata(), SLICE_NUM)

    # NOTE(review): predict_classes exists in standalone Keras / older tf.keras only;
    # on newer versions use np.argmax(model.predict(model_input), axis=-1).
    predicted = model.predict_classes(model_input)
    prediction_rows.append((int(personal_sex - 1), int(predicted[0])))

# Column 0: true label, column 1: predicted label. Built once instead of np.append per row.
predict_list = np.array(prediction_rows, dtype=np.float64).reshape(-1, 2)

X: 予測ラベル (predicted label)　Y: 正解ラベル (true label)

In [11]:
# Rows = true label, columns = predicted label, in label order 0 (Men), 1 (Women).
# labels=[0, 1] keeps the matrix 2x2 even when one class never appears,
# which the two-label heatmap below requires.
cm = confusion_matrix(predict_list[:, 0], predict_list[:, 1].round(0), labels=[0, 1])

In [12]:
# Annotated heatmap of `cm`: rows (y) are true labels, columns (x) are predicted labels.
heatmap_colors = [[0.0, '#f2f2f2'], [1.0, '#05b29c']]
trace = ff.create_annotated_heatmap(
    z=cm,
    x=['Men', 'Women'],
    y=['Men', 'Women'],
    colorscale=heatmap_colors,
    showscale=True,
)

# Bigger numbers inside the cells, bigger text everywhere else.
for annotation in trace.layout.annotations:
    annotation.font.size = 25
trace.layout.font.size = 18
trace.layout.xaxis.title = 'Predicted Labels'

# Interactive HTML copy, then a second page requesting a 1200x750 SVG export.
plot(trace, filename="html/plot_confusionmatrix_2dgm_gender6.html", auto_open=False)
plot(trace, filename="svg/plot_confusionmatrix_2dgm_gender6.html",
     image_height=750, image_width=1200, image='svg', auto_open=False)

'file:///Users/yoshilab/PycharmProjects/data_plot/svg/plot_confusionmatrix_2dgm_gender6.html'

In [13]:
# Show the per-run confusion-matrix heatmap inline.
iplot(trace)

`plot_confusionmatrix_2dgm_gender` の各値を合算した Confusion Matrix  
変数名: 正解ラベル + 予測ラベル の頭文字 (例: `MW` = 正解 Men, 予測 Women)

In [14]:
# Counts transcribed from the plot_confusionmatrix_2dgm_gender results, one entry per result.
# Names are <true label><predicted label>, e.g. MW = true Men predicted as Women.
MW = [4, 3, 5, 5, 4]
MM = [192, 194, 196, 188, 187]
WW = [146, 147, 145, 145, 146]
WM = [8, 6, 4, 12, 13]
# Axis labels for the summed confusion matrix.
cm_label = ['Men', 'Women']

In [15]:
# Summed confusion matrix: row = true label (Men, Women), column = predicted label (Men, Women).
total_cm = [[sum(counts) for counts in row] for row in ((MM, MW), (WM, WW))]

In [16]:
# Annotated heatmap of the summed matrix: rows (y) are true labels, columns (x) predicted.
trace2 = ff.create_annotated_heatmap(z=total_cm,
                                     x=cm_label,
                                     y=cm_label,
                                     colorscale=[[0.0, '#f2f2f2'], [1.0, '#05b29c']],
                                     showscale=True
                                     )
trace2.layout.font.size = 18

# Enlarge this figure's own cell annotations (do not depend on the earlier `trace`).
for annotation in trace2.layout.annotations:
    annotation.font.size = 25

trace2.layout.xaxis.title = 'Predicted Labels'

# Interactive HTML copy, then a second page requesting a 1200x750 SVG export.
plot(trace2, filename="html/plot_total_confusionmatrix_2dgm_gender.html", auto_open=False)
plot(trace2, filename="svg/plot_total_confusionmatrix_2dgm_gender.html", image_height=750, image_width=1200, image='svg', auto_open=False)

'file:///Users/yoshilab/PycharmProjects/data_plot/svg/plot_total_confusionmatrix_2dgm_gender.html'

In [17]:
# Show the summed confusion-matrix heatmap inline.
iplot(trace2)