In [1]:
# Standard library
import re
import pickle
# Third-party
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import precision_score, recall_score, f1_score, make_scorer
from sklearn.model_selection import train_test_split
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

In [2]:
# Load the pickled dataset; later cells read data['X'] and data['y'] from it.
# NOTE(review): pickle.load can execute arbitrary code during unpickling, so
# only load this file when it comes from a trusted source.
# The context manager closes the file handle as soon as loading finishes
# (the bare open(...) call previously left it open).
with open('诗句.dat', 'rb') as data_file:
    data = pickle.load(data_file)

In [3]:
# Character-level count vectorizer: features are single characters and
# adjacent character pairs, with the vocabulary capped at 5000 entries.
vectorizer = CountVectorizer(
    analyzer='char',       # tokenize by character rather than by word
    ngram_range=(1, 2),    # character unigrams and bigrams
    max_features=5000,     # upper bound on the vocabulary size
    lowercase=False,       # leave the input text's case untouched
)

In [4]:
# Learn the n-gram vocabulary from the texts stored under data['X'].
# fit() is the last expression of the cell, so its return value (the
# vectorizer's repr) is what the notebook displays below.
vectorizer.fit(data['X'])

CountVectorizer(analyzer='char', binary=False, decode_error='strict',
        dtype=<class 'numpy.int64'>, encoding='utf-8', input='content',
        lowercase=False, max_df=1.0, max_features=5000, min_df=1,
        ngram_range=(1, 2), preprocessor=None, stop_words=None,
        strip_accents=None, token_pattern='(?u)\\b\\w\\w+\\b',
        tokenizer=None, vocabulary=None)

In [5]:
# Number of features learned by the fitted vectorizer.
# vocabulary_ maps each n-gram to its column index, so its length equals the
# feature count. It is used instead of get_feature_names(), which was
# deprecated in scikit-learn 1.0 and removed in 1.2.
len(vectorizer.vocabulary_)

5000

In [6]:
def get_sentence(vectorizer, s):
    """Vectorize a single string and return its feature row.

    ``s`` is wrapped in a one-element list and passed to
    ``vectorizer.transform``; the result is converted with ``.toarray()``
    and its first (only) row is returned.
    """
    rows = vectorizer.transform([s]).toarray()
    return rows[0]

In [7]:
# Vectorize the whole corpus in one batched transform() call and densify it
# once. This yields the same (n_samples, n_features) count matrix as
# transforming each sentence separately and stacking the rows, but avoids one
# transform() call per sentence and an intermediate Python list of dense rows.
X_data = vectorizer.transform(data['X']).toarray()

In [8]:
# Labels stored under data['y'], copied into a NumPy array so they can be
# split alongside X_data.
y_data = np.array(data['y'])

In [9]:
# Hold out 20% of the samples as a test set; the fixed random_state makes the
# split reproducible across runs.
X_train, X_test, y_train, y_test = train_test_split(
    X_data,
    y_data,
    random_state=0,
    test_size=0.2,
)

In [10]:
# Sanity check: show the shapes of the four arrays produced by the split.
print(*(split.shape for split in (X_train, X_test, y_train, y_test)))

(12126, 5000) (3032, 5000) (12126,) (3032,)


In [12]:
# Linear discriminant analysis classifier, constructed with the library's
# default settings (no arguments are overridden here).
clf = LinearDiscriminantAnalysis()

In [13]:
# Fit the classifier on the training split. fit() is the last expression of
# the cell, so the estimator's repr is displayed as the cell output.
clf.fit(X_train, y_train)



LinearDiscriminantAnalysis(n_components=None, priors=None, shrinkage=None,
              solver='svd', store_covariance=False, tol=0.0001)

In [14]:
# Report precision, recall and F1 (each called with its default arguments),
# first on the training split and then on the held-out test split.
pred_train = clf.predict(X_train)
for template, metric in (
    ('train precision: {}', precision_score),
    ('train recall: {}', recall_score),
    ('train f1: {}', f1_score),
):
    print(template.format(metric(y_train, pred_train)))

# Test predictions are computed only after the training metrics are printed,
# matching the original evaluation order.
pred_test = clf.predict(X_test)
for template, metric in (
    ('test precision: {}', precision_score),
    ('test recall: {}', recall_score),
    ('test f1: {}', f1_score),
):
    print(template.format(metric(y_test, pred_test)))

train precision: 0.9422288992078667
train recall: 0.956069844789357
train f1: 0.9490989131930113
test precision: 0.8095496473141617
test recall: 0.8086720867208672
test f1: 0.8091106290672451
