In [1]:
from flask import Flask, request, render_template, jsonify
import torch
import os
import pickle as pkl
from importlib import import_module

# 设置默认参数
UNK, PAD = '<UNK>', '<PAD>'
dataset_name = "THUCNews"  # 设置数据集名称
key = {
    0: '财经',
    1: '房产',
    2: '股票',
    3: '教育',
    4: '科技',
    5: '社会',
    6: '政治',
    7: '体育',
    8: '游戏',
    9: '娱乐'
}


# 预定义两个模型名称
MODEL_NAMES = ["TextCNN", "TextRNN"]  

# 模型和配置字典
models = {}
configs = {}

# 加载模型函数
def init_model(model_name):
    if model_name in models:
        return models[model_name], configs[model_name]
    
    x = import_module('models.' + model_name)
    config = x.Config(dataset_name, embedding='random')
    
    if os.path.exists(config.vocab_path):
        vocab = pkl.load(open(config.vocab_path, 'rb'))
        config.n_vocab = len(vocab)
    
    model = x.Model(config).to(config.device)
    model.load_state_dict(torch.load(config.save_path, map_location=torch.device('cuda')))
    model.eval()

    # 缓存模型和配置
    models[model_name] = model
    configs[model_name] = config
    
    return model, config

# 初始化 Flask 应用
app = Flask(__name__)

# 在启动时加载 TextCNN 和 TextRNN 模型
for model_name in MODEL_NAMES:
    init_model(model_name)

def build_predict_text(text, use_word, config, vocab):
    if use_word:
        tokenizer = lambda x: x.split(' ')
    else:
        tokenizer = lambda x: [y for y in x]

    token = tokenizer(text)
    seq_len = len(token)
    pad_size = config.pad_size
    if pad_size:
        if len(token) < pad_size:
            token.extend([PAD] * (pad_size - len(token)))
        else:
            token = token[:pad_size]
            seq_len = pad_size

    words_line = []
    for word in token:
        words_line.append(vocab.get(word, vocab.get(UNK)))

    ids = torch.LongTensor([words_line]).to(config.device)
    seq_len = torch.LongTensor(seq_len).to(config.device)

    return ids, seq_len

def predict(text, model, config, vocab):
    data = build_predict_text(text, use_word=False, config=config, vocab=vocab)
    with torch.no_grad():
        outputs = model(data)
        probabilities = torch.softmax(outputs, dim=1)
        predicted_index = torch.argmax(probabilities)
        predicted_label = key[int(predicted_index)]
        predicted_probability = round(probabilities[0, predicted_index].item() * 100, 2)

    all_probabilities = {key[i]: round(probabilities[0, i].item() * 100, 2) for i in range(len(key))}
    return predicted_label, predicted_probability, all_probabilities


# 首页路由，显示文本输入表单
@app.route('/')
def home():
    return render_template('index.html')

# 预测路由，处理表单提交的文本
@app.route('/predict', methods=['POST'])
def make_prediction():
    # 获取用户选择的模型
    selected_model = request.form.get('model', 'TextCNN')
    
    # 获取相应的模型和配置
    model, config = init_model(selected_model)

    # 获取模型的词汇表
    vocab = pkl.load(open(config.vocab_path, 'rb'))

    text = request.form['text']
    if text:
        label, probability, all_probs = predict(text, model, config, vocab)
        return render_template('index.html', label=label, probability=probability, all_probs=all_probs, input_text=text, selected_model=selected_model)
    else:
        return render_template('index.html', error="请输入文本进行预测。", input_text="", selected_model=selected_model)

# 启动 Flask 应用
if __name__ == "__main__":
    app.run()


 * Serving Flask app '__main__'
 * Debug mode: off


  model.load_state_dict(torch.load(config.save_path, map_location=torch.device('cuda')))
 * Running on http://127.0.0.1:5000
Press CTRL+C to quit
127.0.0.1 - - [06/Jan/2025 19:19:08] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [06/Jan/2025 19:19:08] "GET /static/pic/background.jpg HTTP/1.1" 404 -
127.0.0.1 - - [06/Jan/2025 19:19:08] "GET /favicon.ico HTTP/1.1" 404 -
127.0.0.1 - - [06/Jan/2025 19:19:11] "POST /predict HTTP/1.1" 200 -
127.0.0.1 - - [06/Jan/2025 19:19:11] "GET /static/pic/background.jpg HTTP/1.1" 404 -


In [1]:
import tensorflow as tf

log_file = r"D:\stu\python_stu\nlp\文本分类组\TextClassification\THUCNews\log\TextCNN\11-10_22.53\events.out.tfevents.1731250402.zerone"

# 读取日志文件内容
for event in tf.compat.v1.train.summary_iterator(log_file):
    for value in event.summary.value:
        print(f"{value.tag}: {value.simple_value}")



Instructions for updating:
Use eager execution and: 
`tf.data.TFRecordDataset(path)`
loss/train: 2.3394477367401123
loss/dev: 2.21150803565979
acc/train: 0.125
acc/dev: 0.24809999763965607
loss/train: 1.023474931716919
loss/dev: 0.6886829733848572
acc/train: 0.6953125
acc/dev: 0.7874000072479248
loss/train: 0.9682983756065369
loss/dev: 0.5889559388160706
acc/train: 0.734375
acc/dev: 0.8140000104904175
loss/train: 0.6265930533409119
loss/dev: 0.5263707041740417
acc/train: 0.78125
acc/dev: 0.8389999866485596
loss/train: 1.0306557416915894
loss/dev: 0.5194095373153687
acc/train: 0.75
acc/dev: 0.8422999978065491
loss/train: 0.46502459049224854
loss/dev: 0.4972032606601715
acc/train: 0.84375
acc/dev: 0.8468999862670898
loss/train: 0.6193289160728455
loss/dev: 0.47560593485832214
acc/train: 0.8046875
acc/dev: 0.8550000190734863
loss/train: 0.693721354007721
loss/dev: 0.4547249972820282
acc/train: 0.7734375
acc/dev: 0.8600999712944031
loss/train: 0.5941688418388367
loss/dev: 0.44516620039939

In [2]:
import tensorflow as tf

log_file = r"D:\stu\python_stu\nlp\文本分类组\TextClassification\THUCNews\log\TextRNN\11-12_10.52\events.out.tfevents.1731379926.zerone"

# 读取日志文件内容
for event in tf.compat.v1.train.summary_iterator(log_file):
    for value in event.summary.value:
        print(f"{value.tag}: {value.simple_value}")


loss/train: 2.301801919937134
loss/dev: 2.2840688228607178
acc/train: 0.1171875
acc/dev: 0.22609999775886536
loss/train: 0.652897834777832
loss/dev: 0.7139697670936584
acc/train: 0.7734375
acc/dev: 0.7696999907493591
loss/train: 0.7688279747962952
loss/dev: 0.5700805187225342
acc/train: 0.7421875
acc/dev: 0.819100022315979
loss/train: 0.3927404582500458
loss/dev: 0.5227805376052856
acc/train: 0.875
acc/dev: 0.8356000185012817
loss/train: 0.6387857794761658
loss/dev: 0.4629606008529663
acc/train: 0.78125
acc/dev: 0.8547000288963318
loss/train: 0.35441115498542786
loss/dev: 0.4396161437034607
acc/train: 0.890625
acc/dev: 0.8640000224113464
loss/train: 0.4423081874847412
loss/dev: 0.43192365765571594
acc/train: 0.8515625
acc/dev: 0.8654999732971191
loss/train: 0.32340118288993835
loss/dev: 0.3998771607875824
acc/train: 0.8828125
acc/dev: 0.8716999888420105
loss/train: 0.3911683261394501
loss/dev: 0.3977448046207428
acc/train: 0.890625
acc/dev: 0.8737000226974487
loss/train: 0.373064279556