In [None]:
from flask import Flask, render_template, request, jsonify, send_from_directory, url_for
import os
from werkzeug.utils import secure_filename
from tensorflow.keras.models import load_model
import numpy as np
from tensorflow.keras.preprocessing import image
import pickle
from datetime import datetime
import pymysql
from contextlib import closing
import tensorflow as tf

# Create the Flask application
app = Flask(__name__)

# Upload settings: target directory, accepted image extensions and size cap
UPLOAD_FOLDER = 'static/images/'
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif'}
MAX_CONTENT_LENGTH = 16 * 1024 * 1024  # limit uploads to at most 16 MB
app.config.update(
    UPLOAD_FOLDER=UPLOAD_FOLDER,
    MAX_CONTENT_LENGTH=MAX_CONTENT_LENGTH,
)

# Define the Cast layer
class CastLayer(tf.keras.layers.Layer):
    """Keras layer whose forward pass casts its inputs to float32."""

    def call(self, inputs):
        """Return ``inputs`` converted to ``tf.float32``."""
        return tf.cast(inputs, tf.float32)

# Load every model once at import time so they are not reloaded on each request
with tf.keras.utils.custom_object_scope({'Cast': CastLayer}):
    flower_model = load_model('models/flower_model_updated.h5')  # input size (224, 224, 3)
fashion_model = load_model('models/fashion_model.h5')  # input size (150, 150, 3)
animal_model = load_model('models/cats_and_dogs_optimized_model.h5')  # input size (150, 150, 3)

# Load the Titanic survival-prediction model
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load model files from a trusted source.
with open('models/titanic_model.pkl', 'rb') as f:
    titanic_model = pickle.load(f)

# Database connection settings.
# Host, user, password and database name can each be overridden through an
# environment variable so real credentials need not be committed to source
# control; the fallbacks are the original development values.
DB_CONFIG = {
    'host': os.environ.get('DB_HOST', 'localhost'),
    'user': os.environ.get('DB_USER', 'root'),
    'password': os.environ.get('DB_PASSWORD', 'abc123'),
    'database': os.environ.get('DB_NAME', 'deep_learning_app'),
    'charset': 'utf8mb4',  # use the utf8mb4 encoding
}

# Check whether the file extension is allowed
def allowed_file(filename):
    """Return True when ``filename`` has an extension in ALLOWED_EXTENSIONS.

    The extension is the text after the last dot, compared case-insensitively.
    Names without any dot are rejected.
    """
    _, dot, extension = filename.rpartition('.')
    if not dot:
        return False
    return extension.lower() in ALLOWED_EXTENSIONS

# Shared image preprocessing and prediction helper
def process_and_predict(model, image_path, color_mode='rgb'):
    """Load an image sized for ``model`` and return ``model.predict`` on it.

    The target (height, width) is read from ``model.input_shape``; pixel
    values are divided by 255.0 and a leading batch axis is added before
    the prediction call.
    """
    height_width = model.input_shape[1:3]
    loaded = image.load_img(
        image_path,
        target_size=height_width,
        color_mode=color_mode,
    )
    scaled = image.img_to_array(loaded) / 255.0  # normalise
    batch = scaled[np.newaxis, ...]  # add the batch dimension
    return model.predict(batch)

# Classification helper
def classify_image(model, image_path, classes):
    """Return the entry of ``classes`` at the argmax of the model's prediction.

    The image is loaded as grayscale when the model's last input dimension
    is 1, and as RGB otherwise.
    """
    if model.input_shape[-1] == 1:
        color_mode = 'grayscale'
    else:
        color_mode = 'rgb'
    scores = process_and_predict(model, image_path, color_mode)
    best_index = np.argmax(scores)
    return classes[best_index]

# Validate an uploaded image and save it
def process_image(file):
    """Validate an uploaded image and store it in the upload directory.

    Args:
        file: The uploaded file object (exposes ``filename`` and ``save``).

    Returns:
        ``(filename, file_path)`` for the stored file, or ``(None, None)``
        when no file was supplied or its extension is not allowed.
    """
    import uuid  # local import: only needed for the fallback name below

    if not file or not allowed_file(file.filename):
        return None, None

    filename = secure_filename(file.filename)
    if not allowed_file(filename):
        # secure_filename() drops non-ASCII characters and leading dots, so a
        # name with a fully non-ASCII stem collapses to just "jpg"/"png": the
        # extension is lost and all such uploads would overwrite one another.
        # Fall back to a unique generated name that keeps the validated
        # extension.
        extension = file.filename.rsplit('.', 1)[1].lower()
        filename = f"{uuid.uuid4().hex}.{extension}"

    upload_dir = app.config['UPLOAD_FOLDER']
    os.makedirs(upload_dir, exist_ok=True)  # first upload must not fail on a missing directory
    file_path = os.path.join(upload_dir, filename)
    file.save(file_path)
    return filename, file_path

# Fetch the surviving passengers
def get_survived_passengers():
    """Return all ``titanic_passengers`` rows whose ``survival_prob`` exceeds 0.5.

    Rows are fetched through a ``DictCursor``; the connection and cursor are
    closed before returning.
    """
    query = "SELECT * FROM titanic_passengers WHERE survival_prob > 0.5"
    with closing(pymysql.connect(**DB_CONFIG)) as db:
        with closing(db.cursor(pymysql.cursors.DictCursor)) as dict_cursor:
            dict_cursor.execute(query)
            rows = dict_cursor.fetchall()
    return rows

# Route - root path
@app.route('/')
def home_page():
    """Render the ``index.html`` template."""
    return render_template('index.html')

# Route - list the Titanic passengers stored as survivors
@app.route('/survived_passengers')
def survived_passengers():
    """Render ``survived_passengers.html`` with the rows from the database."""
    return render_template(
        'survived_passengers.html',
        survivors=get_survived_passengers(),
    )

# Route - Titanic survival prediction
@app.route('/titanic_predict', methods=['GET', 'POST'])
def titanic_predict():
    """Show the Titanic form and, on POST, the predicted survival percentage.

    Form fields are converted in a fixed order; a ``ValueError`` raised while
    converting or predicting is reported to the template as ``error_message``.
    """
    # Form fields in the order they are fed to the model, with their converters
    field_casts = (
        ('pclass', int),
        ('age', float),
        ('sex', int),
        ('sibsp', int),
        ('parch', int),
        ('fare', float),
        ('embarked', int),
    )
    view = {'submitted_data': None, 'prediction': None, 'error_message': None}

    if request.method == 'POST':
        try:
            # Read and convert the form input; missing fields default to 0
            parsed = {}
            for field, cast in field_casts:
                parsed[field] = cast(request.form.get(field, 0))
            view['submitted_data'] = parsed

            # Run the prediction on a single-row feature matrix
            row = np.array([[parsed[field] for field, _ in field_casts]])
            survive_chance = titanic_model.predict_proba(row)[0][1]  # probability of class index 1

            # Express the probability as a percentage
            view['prediction'] = round(survive_chance * 100, 2)
        except ValueError as exc:
            view['error_message'] = f"输入值错误: {str(exc)}"

    return render_template('titanic_predict.html', **view)


# Route - cat/dog image classification
@app.route('/animal_classify', methods=['GET', 'POST'])
def animal_classify_page():
    """Show the cat/dog upload form and, on POST, classify the uploaded image.

    Returns:
        The rendered ``animal_classify.html`` page (with the prediction and
        image URL after a successful upload), or a JSON error object when no
        file was selected or the file type is not allowed.
    """
    if request.method != 'POST':
        return render_template('animal_classify.html')

    file = request.files.get('file')
    if not file or file.filename == '':
        return jsonify({"error": "没有选择文件"})

    # Validate and save the upload
    filename, file_path = process_image(file)
    if not filename:
        # Previously a rejected upload silently re-rendered the empty form;
        # report it the same way as a missing file instead.
        return jsonify({"error": "不支持的文件类型"})

    # Preprocess the image and run the prediction
    prediction = process_and_predict(animal_model, file_path)

    # A single output column is used directly as the "Dog" score; otherwise
    # the second column is used.
    if prediction.shape[1] == 1:
        dog_score = prediction[0][0]
    else:
        dog_score = prediction[0][1]
    result = 'Dog' if dog_score > 0.5 else 'Cat'

    # Return the prediction together with the stored image location
    return render_template(
        'animal_classify.html',
        prediction_text=f'预测结果: {result}',
        image_file=filename,
        image_url=url_for('uploaded_file', filename=filename),
    )


# Route - flower image classification
@app.route('/flower_classify', methods=['GET', 'POST'])
def flower_classify_page():
    """Show the flower upload form and, on POST, classify the uploaded image.

    Returns:
        The rendered ``flower_classify.html`` page (with the predicted label
        and image URL after a successful upload), or a JSON error object when
        no file was selected or the file type is not allowed.
    """
    if request.method != 'POST':
        return render_template('flower_classify.html')

    file = request.files.get('file')
    if not file or file.filename == '':
        return jsonify({"error": "没有选择文件"})

    filename, file_path = process_image(file)
    if not filename:
        # Previously a rejected upload silently re-rendered the empty form;
        # report it the same way as a missing file instead.
        return jsonify({"error": "不支持的文件类型"})

    # Labels indexed by the argmax of the model output
    classes = ['雏菊', '蒲公英', '玫瑰', '向日葵', '郁金香']
    prediction = classify_image(flower_model, file_path, classes)
    return render_template(
        'flower_classify.html',
        prediction_text=prediction,
        image_file=filename,
        image_url=url_for('uploaded_file', filename=filename),
    )

# Route - fashion item image classification
@app.route('/fashion_classify', methods=['GET', 'POST'])
def fashion_classify_page():
    """Show the fashion upload form and, on POST, classify every uploaded image.

    Each accepted image is saved and classified, and one row per image
    (filename, category, upload time) is inserted into ``fashion_images``.

    Returns:
        The rendered ``fashion_classify.html`` page (with ``results`` after a
        POST), or a JSON error object when no usable file was submitted.
    """
    if request.method != 'POST':
        return render_template('fashion_classify.html')

    files = request.files.getlist('files')
    # Browsers submit an empty file input as one part with an empty filename,
    # so check for that as well as for a missing field.
    if not files or all(not upload.filename for upload in files):
        return jsonify({"error": "没有选择文件"})

    # Labels indexed by the argmax of the model output (built once, not per file)
    classes = ['T恤', '裤子', '套头衫', '连衣裙', '外套', '凉鞋', '衬衫', '运动鞋', '包', '靴子']

    results = []
    rows = []
    for file in files:
        filename, file_path = process_image(file)
        if not filename:
            continue  # unsupported or empty entry: skip it, keep the rest
        prediction = classify_image(fashion_model, file_path, classes)
        rows.append((filename, prediction, datetime.now()))
        results.append({"filename": filename, "prediction": prediction})

    if not rows:
        return jsonify({"error": "不支持的文件类型"})

    # Store all images and categories through a single connection instead of
    # opening and closing one connection per uploaded file.
    insert_sql = (
        "INSERT INTO fashion_images (filename, category, upload_time) "
        "VALUES (%s, %s, %s)"
    )
    with closing(pymysql.connect(**DB_CONFIG)) as connection:
        with closing(connection.cursor()) as cursor:
            cursor.executemany(insert_sql, rows)
        connection.commit()

    return render_template('fashion_classify.html', results=results)

# Route - serve an uploaded image
@app.route('/uploads/<filename>')
def uploaded_file(filename):
    """Send ``filename`` from the directory configured as ``UPLOAD_FOLDER``."""
    return send_from_directory(app.config['UPLOAD_FOLDER'], filename)

# Start the Flask application when this file is executed as a script
if __name__ == '__main__':
    app.run(debug=False)  # debug mode explicitly disabled




 * Serving Flask app '__main__' (lazy loading)
 * Environment: production
   Use a production WSGI server instead.
 * Debug mode: off


INFO:werkzeug: * Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)
INFO:werkzeug:127.0.0.1 - - [15/Nov/2024 11:02:35] "GET / HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [15/Nov/2024 11:02:35] "GET /static/css/style.css HTTP/1.1" 304 -
INFO:werkzeug:127.0.0.1 - - [15/Nov/2024 11:02:36] "GET /static/icon/favicon.ico HTTP/1.1" 404 -
