In [1]:
import os

import torch
from d2l import torch as d2l
from flask import Flask, request, jsonify, render_template
from torchvision import models, transforms

from TFRS import load_image_classification_data, evaluate_accuracy, image_classification_train, image_classification_test, image_classification_predict

# Data processing
SEED = 42
# Seed torch's global RNG so weight initialisation (and any torch-based shuffling) is reproducible.
torch.manual_seed(SEED)

train_path = './data/TFRS/TFRS_train'
test_path = './data/TFRS/TFRS_val'


def make_transform():
    """Return the image preprocessing pipeline used for training, evaluation and prediction.

    Resizes to 256x256, converts to a float32 tensor and normalizes every channel
    with mean 0.5 / std 0.5 (maps inputs in [0, 1] to [-1, 1]).
    """
    return transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.ConvertImageDtype(torch.float32),
        transforms.Normalize(mean=0.5, std=0.5),
    ])


# Two independent pipeline objects, so training-only augmentation can be added later
# without touching evaluation/prediction. NOTE(review): no augmentation is applied yet.
train_transform = make_transform()
test_transform = make_transform()

# Device: chosen before the loaders so memory pinning is only requested when a GPU exists.
device = d2l.try_gpu()

batch_size = 8
num_workers = 0
# Pinned host memory only speeds up host-to-GPU copies; without a GPU it just triggers a warning.
pin_memory = device.type == 'cuda'
train_iter = load_image_classification_data(train_path, train_transform, batch_size, num_workers, pin_memory)
test_iter = load_image_classification_data(test_path, test_transform, batch_size, num_workers, pin_memory)
labels = train_iter.dataset.classes
# Each dataset derives its own class list; presumably (ImageFolder-style) indices follow that list,
# so a mismatch would silently mis-score evaluation. Fail loudly instead.
assert test_iter.dataset.classes == labels, 'train and validation class lists differ'

# Neural network
net = models.resnet18(num_classes=len(labels))
model_path = './model/TFRS_model.pth'

In [None]:
# Training
lr, num_epochs = 0.01, 3
image_classification_train(net, train_iter, test_iter, num_epochs, lr, device)
# torch.save does not create missing parent directories; ensure the checkpoint folder exists.
os.makedirs(os.path.dirname(model_path) or '.', exist_ok=True)
torch.save(net.state_dict(), model_path)

In [None]:
# Testing
# weights_only=True restricts unpickling to tensors and plain containers (no arbitrary objects),
# which is all a state_dict needs.
net.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
image_classification_test(net, labels, test_iter, device)
evaluate_accuracy(net, test_iter, device)

In [None]:
# Prediction: serve the trained network over HTTP.
app = Flask(__name__)
net.load_state_dict(torch.load(model_path, weights_only=True, map_location=device))
# load_state_dict copies values into the existing parameters but does not move the module,
# so place it on the target device explicitly (harmless if a helper also does this).
net.to(device)
# Inference behaviour (e.g. BatchNorm uses running statistics). In a fresh kernel that skips
# the training/testing cells, the network would otherwise still be in its default train mode.
net.eval()


@app.route('/', methods=['GET'])
def home():
    """Render and return the ``index.html`` template for GET requests to ``/``."""
    page = render_template('index.html')
    return page


@app.route('/predict', methods=['POST'])
def predict():
    """Classify an uploaded image and return ``{"predicted_class": ...}`` as JSON.

    Expects a multipart upload with the file under the ``image`` form field.
    If that field is missing or has an empty filename (e.g. no file selected in
    an HTML form), responds with HTTP 400 and a JSON ``error`` message instead.
    Every response, including errors, carries ``Access-Control-Allow-Origin: *``
    so a page served from another origin can read it.
    """
    image = request.files.get('image')
    if image is None or not image.filename:
        response = jsonify({'error': "missing file field 'image'"})
        response.status_code = 400
    else:
        predicted_class = image_classification_predict(net, labels, image, test_transform, device)
        response = jsonify({'predicted_class': predicted_class})
    # Permissive CORS so a front end hosted elsewhere can call this endpoint.
    response.headers['Access-Control-Allow-Origin'] = '*'
    return response


# Start Flask's built-in development server with default settings; this call blocks the
# kernel until the server is interrupted, so it should stay the last statement of the notebook.
app.run()