In [1]:
import pandas as pd

# 挑战所需训练数据集，复制链接粘贴到浏览器即可下载
df = pd.read_csv(
    "./mushrooms.csv"
)
df.head()

Unnamed: 0,class,cap-shape,cap-surface,cap-color,bruises,odor,gill-attachment,gill-spacing,gill-size,gill-color,...,stalk-surface-below-ring,stalk-color-above-ring,stalk-color-below-ring,veil-type,veil-color,ring-number,ring-type,spore-print-color,population,habitat
0,p,x,s,w,f,c,f,c,n,p,...,s,w,w,p,w,o,p,n,s,d
1,p,x,s,e,f,s,f,c,n,b,...,k,w,w,p,w,o,e,w,v,p
2,p,k,s,e,f,y,f,c,n,b,...,s,p,p,p,w,o,e,w,v,d
3,p,f,f,g,f,f,f,c,b,p,...,k,b,n,p,w,o,l,h,y,g
4,e,f,f,n,f,n,f,w,b,h,...,s,w,w,p,w,o,e,k,s,g


In [2]:
y = df["class"]  # 目标
X = df.drop(columns="class")
# X = pd.get_dummies(df.iloc[:, 1:])  # 读取特征并独热编码
# iloc[:, 1:] 取所有行，列从第一列开始

Unnamed: 0,cap-shape,cap-surface,cap-color,bruises,odor,gill-attachment,gill-spacing,gill-size,gill-color,stalk-shape,...,stalk-surface-below-ring,stalk-color-above-ring,stalk-color-below-ring,veil-type,veil-color,ring-number,ring-type,spore-print-color,population,habitat
0,x,s,w,f,c,f,c,n,p,e,...,s,w,w,p,w,o,p,n,s,d
1,x,s,e,f,s,f,c,n,b,t,...,k,w,w,p,w,o,e,w,v,p
2,k,s,e,f,y,f,c,n,b,t,...,s,p,p,p,w,o,e,w,v,d
3,f,f,g,f,f,f,c,b,p,e,...,k,b,n,p,w,o,l,h,y,g
4,f,f,n,f,n,f,w,b,h,t,...,s,w,w,p,w,o,e,k,s,g


In [3]:
X = pd.get_dummies(X)  # 独热编码
X.head()

Unnamed: 0,cap-shape_b,cap-shape_c,cap-shape_f,cap-shape_k,cap-shape_s,cap-shape_x,cap-surface_f,cap-surface_g,cap-surface_s,cap-surface_y,...,population_s,population_v,population_y,habitat_d,habitat_g,habitat_l,habitat_m,habitat_p,habitat_u,habitat_w
0,False,False,False,False,False,True,False,False,True,False,...,True,False,False,True,False,False,False,False,False,False
1,False,False,False,False,False,True,False,False,True,False,...,False,True,False,False,False,False,False,True,False,False
2,False,False,False,True,False,False,False,False,True,False,...,False,True,False,True,False,False,False,False,False,False
3,False,False,True,False,False,False,True,False,False,False,...,False,False,True,False,True,False,False,False,False,False
4,False,False,True,False,False,False,True,False,False,False,...,True,False,False,False,True,False,False,False,False,False


In [5]:
import numpy as np
from sklearn.model_selection import cross_val_score
from sklearn.ensemble import RandomForestClassifier

model = RandomForestClassifier()  # 随机森林
np.mean(cross_val_score(model, X, y, cv=5))  # 5 次交叉验证求平均

1.0

In [7]:
import joblib

model.fit(X, y)  # 训练模型
joblib.dump(model, "mushrooms.pkl")  # 保存模型

['mushrooms.pkl']

In [11]:
%%writefile mushrooms-predict.py
# 构建 Flask Web
# 将此单元格代码写入 predict.py 文件方便后面执行
import joblib
import pandas as pd
from flask import Flask, request, jsonify

app = Flask(__name__)

@app.route("/", methods=["POST"])  # 请求方法为 POST
def inference():
    query_df = pd.DataFrame(request.json)  # 将 JSON 变为 DataFrame
    
    df = pd.read_csv("./mushrooms.csv")  # 读取数据
    X = pd.get_dummies(df.iloc[:, 1:])  # 读取特征并独热编码
    query = pd.get_dummies(query_df).reindex(columns=X.columns, fill_value=0)  # 将请求数据 DataFrame 处理成独热编码样式
    clf = joblib.load('mushrooms.pkl')  # 加载模型
    prediction = clf.predict(query)  # 模型推理
    return jsonify({"prediction": list(prediction)})  # 返回推理结果

Writing mushrooms-predict.py


In [12]:
# Notebook 中必须以子进程才能正常启动 Flask
import time
import subprocess as sp

# 启动子进程执行 Flask app
server = sp.Popen("FLASK_APP=mushrooms-predict.py flask run", shell=True)
time.sleep(5)  # 等待 5 秒保证 Flask 启动成功
server

<Popen: returncode: None args: 'FLASK_APP=mushrooms-predict.py flask run'>

In [15]:
import json

# 从测试数据中取 1 条用于测试推理
df_test = pd.read_csv("./mushrooms_test.csv")
sample_data = df.sample(1).to_json(orient='records')
sample_json = json.loads(sample_data)
sample_json

[{'class': 'p',
  'cap-shape': 'f',
  'cap-surface': 'y',
  'cap-color': 'y',
  'bruises': 'f',
  'odor': 'f',
  'gill-attachment': 'f',
  'gill-spacing': 'c',
  'gill-size': 'b',
  'gill-color': 'p',
  'stalk-shape': 'e',
  'stalk-root': 'b',
  'stalk-surface-above-ring': 'k',
  'stalk-surface-below-ring': 'k',
  'stalk-color-above-ring': 'b',
  'stalk-color-below-ring': 'p',
  'veil-type': 'p',
  'veil-color': 'w',
  'ring-number': 'o',
  'ring-type': 'l',
  'spore-print-color': 'h',
  'population': 'v',
  'habitat': 'd'}]

In [16]:
import requests

requests.post(url="http://localhost:5000", json=sample_json).content  # 建立 POST 请求，并发送数据请求

b'{"prediction":["yes"]}\n'

In [17]:
server.terminate()  # 结束子进程，关闭端口占用