# 多遗传疾病非端到端多分类模型

In [None]:
import os
import json
from sklearn.model_selection  import cross_val_score, train_test_split
from sklearn.neighbors import KNeighborsClassifier
from deepface import DeepFace
from sklearn.metrics import classification_report

设置数据源路径以及保存模型位置

In [6]:
NORMALIZED_IMG_DIR = './dataset/'
clf_model_name = 'mutil-disease.pkl'

读取所有遗传病类型

In [7]:
class_list = os.listdir(NORMALIZED_IMG_DIR)

引入FaceNet

In [8]:
model = DeepFace.build_model('Facenet')

使用FaceNet获取人脸embedding特征向量

In [9]:
embeddings = {}

In [10]:
for class_name in class_list:
#     embeddings[class_name] = DeepFace.represent(os.listdir(os.path.join(NORMALIZED_IMG_DIR, class_name)), model_name = 'Facenet', enforce_detection=False)
    embeddings[class_name] = []
    for file_name in os.listdir(os.path.join(NORMALIZED_IMG_DIR, class_name)):
        print(f"getting {file_name}'s embedding")
        embeddings[class_name].append(
            DeepFace.represent(os.path.join(NORMALIZED_IMG_DIR, class_name, file_name), model_name = 'Facenet', model=model, enforce_detection=False)
        )

getting 0.jpg's embedding
getting 1.jpg's embedding
getting 10.jpg's embedding
getting 11.jpg's embedding
getting 12.jpg's embedding
getting 13.jpg's embedding
getting 14.jpg's embedding
getting 15.jpg's embedding
getting 16.jpg's embedding
getting 17.jpg's embedding
getting 18.jpg's embedding
getting 19.jpg's embedding
getting 2.jpg's embedding
getting 20.jpg's embedding
getting 21.jpg's embedding
getting 23.jpg's embedding
getting 3.jpg's embedding
getting 4.jpg's embedding
getting 5.jpg's embedding
getting 6.jpg's embedding
getting 7.jpg's embedding
getting 8.jpg's embedding
getting 9.jpg's embedding
getting 1.jpg's embedding
getting 10.jpg's embedding
getting 101.jpg's embedding
getting 102.jpg's embedding
getting 103.jpg's embedding
getting 104.jpg's embedding
getting 105.jpg's embedding
getting 106.jpg's embedding
getting 107.jpg's embedding
getting 108.jpg's embedding
getting 109.jpg's embedding
getting 11.jpg's embedding
getting 110.jpg's embedding
getting 111.jpg's embedding
g

构造训练集和测试集

In [11]:
X = []
Y = []
for key, value in embeddings.items():
    X = X + value
    Y = Y + [key] * len(value)

In [30]:
Xtrain, Xtest, Ytrain, Ytest = train_test_split(X, Y, test_size=0.15)

使用KNN作为分类器，使用网格搜索搜索最优参数

In [31]:
k_range = range(1, 10)
best_k = -1
best_score = -1
#循环，取k=1到k=31，查看误差效果 
for k in k_range:
    knn = KNeighborsClassifier(
        n_neighbors=k,
        weights='distance'
    )
    #cv参数决定数据集划分比例，这里是按照5:1划分训练集和测试集
    score = cross_val_score(knn, Xtrain, Ytrain, cv=8, scoring='accuracy').mean()
    if score > best_score:
        best_score = score
        best_k = k
    print(f'knn with {k} neightbors: {score}')
print(f'best: acc: {best_score} k: {best_k}')

knn with 1 neightbors: 0.7228002070393376
knn with 2 neightbors: 0.7228002070393376
knn with 3 neightbors: 0.7353002070393375
knn with 4 neightbors: 0.7388198757763975
knn with 5 neightbors: 0.7477743271221532
knn with 6 neightbors: 0.7406573498964804
knn with 7 neightbors: 0.7442287784679089
knn with 8 neightbors: 0.7442805383022775
knn with 9 neightbors: 0.7388716356107661
best: acc: 0.7477743271221532 k: 5


使用最优参数训练KNN

In [32]:
knn = KNeighborsClassifier(n_neighbors=best_k, weights='distance')
knn.fit(Xtrain, Ytrain)

KNeighborsClassifier(weights='distance')

评估KNN

In [33]:
knn.score(Xtest, Ytest)

0.6767676767676768

In [35]:
y_pred = knn.predict(Xtest)

In [36]:
print(classification_report(Ytest, y_pred))

              precision    recall  f1-score   support

    Angelman       0.00      0.00      0.00         2
       Apert       0.76      0.72      0.74        40
   Fragile_X       0.60      0.35      0.44        17
      normal       0.63      0.80      0.70        40

    accuracy                           0.68        99
   macro avg       0.50      0.47      0.47        99
weighted avg       0.66      0.68      0.66        99



In [37]:
from sklearn.metrics import precision_score, accuracy_score, recall_score, f1_score
print('accuracy:{}'.format(accuracy_score(Ytest, y_pred)))
print('precision:{}'.format(precision_score(Ytest, y_pred, average='micro')))
print('recall:{}'.format(recall_score(Ytest, y_pred,average='micro')))
print('f1-score:{}'.format(f1_score(Ytest, y_pred,average='micro')))

accuracy:0.6767676767676768
precision:0.6767676767676768
recall:0.6767676767676768
f1-score:0.6767676767676768


发现主要是一类疾病预测效果差，删去后重新训练模型，评估模型
发现效果有明显提升

In [38]:
del embeddings['Fragile_X']

In [45]:
X = []
Y = []
for key, value in embeddings.items():
    X = X + value
    Y = Y + [key] * len(value)

In [90]:
Xtrain, Xtest, Ytrain, Ytest = train_test_split(X, Y, test_size=0.1)

In [91]:
knn = KNeighborsClassifier(n_neighbors=2, weights='distance')
knn.fit(Xtrain, Ytrain)

KNeighborsClassifier(n_neighbors=2, weights='distance')

In [92]:
knn.score(Xtest, Ytest)

0.8947368421052632

In [93]:
y_pred = knn.predict(Xtest)

In [94]:
print(classification_report(Ytest, y_pred))

              precision    recall  f1-score   support

    Angelman       0.00      0.00      0.00         0
       Apert       0.90      0.86      0.88        21
      normal       0.92      0.92      0.92        36

    accuracy                           0.89        57
   macro avg       0.61      0.59      0.60        57
weighted avg       0.91      0.89      0.90        57



In [95]:
from sklearn.metrics import precision_score, accuracy_score, recall_score, f1_score
print('accuracy:{}'.format(accuracy_score(Ytest, y_pred)))
print('precision:{}'.format(precision_score(Ytest, y_pred, average='micro')))
print('recall:{}'.format(recall_score(Ytest, y_pred,average='micro')))
print('f1-score:{}'.format(f1_score(Ytest, y_pred,average='micro')))

accuracy:0.8947368421052632
precision:0.8947368421052632
recall:0.8947368421052632
f1-score:0.8947368421052632


In [76]:
print(classification_report(Ytest, y_pred))

              precision    recall  f1-score   support

    angelman       1.00      0.20      0.33         5
       apert       0.77      0.74      0.76        46
      normal       0.92      0.95      0.93       194

    accuracy                           0.89       245
   macro avg       0.90      0.63      0.67       245
weighted avg       0.89      0.89      0.89       245

