# 蘑菇分类模型部署和推理

In [1]:
# 导入数据
import pandas as pd

# 挑战所需训练数据集，复制链接粘贴到浏览器即可下载
df = pd.read_csv("mushrooms.csv")
df.head()

Unnamed: 0,class,cap-shape,cap-surface,cap-color,bruises,odor,gill-attachment,gill-spacing,gill-size,gill-color,...,stalk-surface-below-ring,stalk-color-above-ring,stalk-color-below-ring,veil-type,veil-color,ring-number,ring-type,spore-print-color,population,habitat
0,p,x,s,w,f,c,f,c,n,p,...,s,w,w,p,w,o,p,n,s,d
1,p,x,s,e,f,s,f,c,n,b,...,k,w,w,p,w,o,e,w,v,p
2,p,k,s,e,f,y,f,c,n,b,...,s,p,p,p,w,o,e,w,v,d
3,p,f,f,g,f,f,f,c,b,p,...,k,b,n,p,w,o,l,h,y,g
4,e,f,f,n,f,n,f,w,b,h,...,s,w,w,p,w,o,e,k,s,g


In [2]:
# 模型训练和保存
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
import joblib

X = pd.get_dummies(df.iloc[:, 1:])  # 读取特征并独热编码
y = df['class']  # 目标值

model = RandomForestClassifier()  # 随机森林
print(cross_val_score(model, X, y, cv=5).mean())  # 交叉验证结果

model.fit(X, y)  # 训练模型
joblib.dump(model, "mushrooms.pkl")  # 保存模型
print("model saved.")

1.0
model saved.


In [5]:
%%writefile predict.py
# 将此单元格代码写入 predict.py 文件方便后面执行
import joblib
import pandas as pd
from flask import Flask, request, jsonify

app = Flask(__name__)

@app.route("/", methods=["POST"])  # 请求方法为 POST
def inference():
    query_df = pd.DataFrame(request.json)  # 将 JSON 变为 DataFrame
    
    df = pd.read_csv("mushrooms.csv")  # 读取数据
    X = pd.get_dummies(df.iloc[:, 1:])  # 读取特征并独热编码
    query = pd.get_dummies(query_df).reindex(columns=X.columns, fill_value=0)  # 将请求数据 DataFrame 处理成独热编码样式
    
    clf = joblib.load('mushrooms.pkl')  # 加载模型
    prediction = clf.predict(query)  # 模型推理
    return jsonify({"prediction": list(prediction)})  # 返回推理结果

Writing predict.py


In [6]:
# Notebook 中必须以子进程才能正常启动 Flask
import time
import subprocess as sp

# 启动子进程执行 Flask app
server = sp.Popen("FLASK_APP=predict.py flask run", shell=True)
time.sleep(5)  # 等待 5 秒保证 Flask 启动成功
server

 * Serving Flask app 'predict.py'
 * Debug mode: off


 * Running on http://127.0.0.1:5000
Press CTRL+C to quit


<Popen: returncode: None args: 'FLASK_APP=predict.py flask run'>

In [8]:
import json

# 从测试数据中取 1 条用于测试推理
df_test = pd.read_csv("mushrooms_test.csv")
sample_data = df.sample(1).to_json(orient='records')
sample_json = json.loads(sample_data)
sample_json

[{'class': 'e',
  'cap-shape': 'k',
  'cap-surface': 'y',
  'cap-color': 'e',
  'bruises': 't',
  'odor': 'n',
  'gill-attachment': 'f',
  'gill-spacing': 'c',
  'gill-size': 'b',
  'gill-color': 'w',
  'stalk-shape': 'e',
  'stalk-root': '?',
  'stalk-surface-above-ring': 's',
  'stalk-surface-below-ring': 's',
  'stalk-color-above-ring': 'e',
  'stalk-color-below-ring': 'e',
  'veil-type': 'p',
  'veil-color': 'w',
  'ring-number': 't',
  'ring-type': 'e',
  'spore-print-color': 'w',
  'population': 'c',
  'habitat': 'w'}]

In [9]:
import requests

requests.post(url="http://localhost:5000", json=sample_json).content  # 建立 POST 请求，并发送数据请求

127.0.0.1 - - [27/Nov/2025 14:24:43] "POST / HTTP/1.1" 200 -


b'{"prediction":["e"]}\n'