## 精准率与召回率的平衡（Precision-Recall Trade-off）

In [22]:
import numpy as np
import matplotlib.pyplot as plt

In [23]:
from sklearn import datasets

digits = datasets.load_digits()
X = digits.data

# Turn the 10-class digits problem into a skewed binary one:
# digit 9 is the positive class (1), every other digit is negative (0).
# zeros_like keeps digits.target's dtype, just as .copy() would.
y = np.zeros_like(digits.target)
y[digits.target == 9] = 1

In [24]:
from sklearn.model_selection import train_test_split

# Hold out a test set (scikit-learn's default 25% split); the fixed seed keeps
# the split, and therefore every metric printed below, reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, random_state=666
)

In [25]:
# LogisticRegression.predict uses a fixed cut-off: for a binary problem a
# sample is labelled 1 when its decision score is > 0 (i.e. predicted
# probability > 0.5). predict() accepts no threshold argument, so to move
# the cut-off we use decision_function() instead.
# Note: decision_function() returns the raw linear score theta^T . x_b
# (the log-odds), NOT the probability y_hat = sigmoid(score).
from sklearn.linear_model import LogisticRegression

# NOTE(review): scikit-learn 0.22 changed the default solver from liblinear to
# lbfgs; on these unscaled pixel features lbfgs may warn about convergence.
# Pin `solver`/`max_iter` if the outputs below must be reproduced exactly.
log_reg = LogisticRegression()
log_reg.fit(X_train, y_train)
y_predict = log_reg.predict(X_test)

In [26]:
from sklearn.metrics import f1_score

# F1 = harmonic mean of precision and recall, at the default threshold (0).
f1_score(y_true=y_test, y_pred=y_predict)

0.8674698795180723

In [27]:
from sklearn.metrics import confusion_matrix

# Rows are true labels, columns predicted labels: [[TN, FP], [FN, TP]].
confusion_matrix(y_true=y_test, y_pred=y_predict)

array([[403,   2],
       [  9,  36]], dtype=int64)

In [28]:
from sklearn.metrics import precision_score

# Precision = TP / (TP + FP) at the default threshold.
precision_score(y_true=y_test, y_pred=y_predict)

0.9473684210526315

In [29]:
from sklearn.metrics import recall_score

# Recall = TP / (TP + FN) at the default threshold.
recall_score(y_true=y_test, y_pred=y_predict)

0.8

In [30]:
# Raw decision scores of the first ten test samples.
log_reg.decision_function(X=X_test)[:10]

array([-21.45600241, -32.94974484, -16.40358049, -79.91453954,
       -48.16618205, -24.21675714, -44.76256848, -24.22873941,
        -1.22419553, -19.07705062])

In [31]:
# Predicted labels for the same ten samples; each score above is negative,
# so each label here is 0.
log_reg.predict(X=X_test)[:10]

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [32]:
# Cache the decision scores for the whole test set so that different
# thresholds can be tried without re-evaluating the model.
decision_scores = log_reg.decision_function(X=X_test)

In [33]:
# Lower end of the score range (useful when choosing candidate thresholds).
decision_scores.min()

-85.76642438512

In [34]:
# Upper end of the score range.
decision_scores.max()

19.975142566998983

In [35]:
# Raise the threshold from 0 to 5: fewer samples are labelled positive,
# which is expected to raise precision at the cost of recall.
y_predict_2 = (decision_scores >= 5).astype('int')

In [36]:
# Confusion matrix at threshold 5 ([[TN, FP], [FN, TP]]).
confusion_matrix(y_true=y_test, y_pred=y_predict_2)

array([[404,   1],
       [ 21,  24]], dtype=int64)

In [37]:
# Precision at threshold 5.
precision_score(y_true=y_test, y_pred=y_predict_2)

0.96

In [38]:
# Recall at threshold 5.
recall_score(y_true=y_test, y_pred=y_predict_2)

0.5333333333333333

In [39]:
# Lower the threshold from 0 to -5: more samples are labelled positive,
# which is expected to raise recall at the cost of precision.
y_predict_3 = (decision_scores >= -5).astype('int')

In [40]:
# Confusion matrix at threshold -5 ([[TN, FP], [FN, TP]]).
confusion_matrix(y_true=y_test, y_pred=y_predict_3)

array([[389,  16],
       [  5,  40]], dtype=int64)

In [41]:
# Precision at threshold -5.
precision_score(y_true=y_test, y_pred=y_predict_3)

0.7142857142857143

In [42]:
# Recall at threshold -5.
recall_score(y_true=y_test, y_pred=y_predict_3)

0.8888888888888888