# 直接下载 `safetensors` 文件，然后对外提供 chat completion 服务

https://zhuanlan.zhihu.com/p/702475296

https://modelscope.cn/models/qwen/Qwen2-7B-Instruct/files


In [None]:
from flask import Flask, request, jsonify, Response
from flask_cors import CORS
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import logging
import json

# Flask application object; CORS(app) applies flask_cors' defaults to it.
app = Flask(__name__)
CORS(app)

# Configure logging (done before anything below emits a log record).
logging.basicConfig(level=logging.INFO)

# Prefer the GPU, but fall back to the CPU instead of failing in `.to(device)`
# on hosts where CUDA is not available.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cpu":
    logging.warning("CUDA is not available; running the model on CPU")

# Load the tokenizer and model once when the application starts.
# Both are read from the current working directory ("./").
# NOTE(review): trust_remote_code=True executes Python code shipped with the
# checkpoint -- only point this at model files you trust.
tokenizer = AutoTokenizer.from_pretrained("./", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "./",
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True
).to(device).eval()

@app.route('/v1/chat/completions', methods=['POST'])
def chat_completions():
    """Answer a chat request with a server-sent-event stream of characters.

    Fields read from the JSON request body:
        messages: chat messages handed to the tokenizer's chat template
            (default ``[]``).
        top_k: optional top-k sampling cutoff; only forwarded to
            ``model.generate`` when present (default ``None``).
        top_p: nucleus-sampling threshold (default ``1``).
        temperature: sampling temperature (default ``0.3``). A numeric value
            ``<= 0`` selects greedy decoding instead of sampling.
        max_length: value passed to ``model.generate`` as ``max_length``
            (default ``2500``).

    Returns:
        A ``text/event-stream`` response whose events each carry one character
        of the reply as ``{"choices": [{"delta": {"content": ...}}]}``, or a
        400 JSON error when the body is not a JSON object.
    """
    # silent=True yields None for a missing/invalid JSON body instead of
    # raising, so malformed requests get one uniform 400 below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    messages = data.get('messages', [])
    top_k = data.get('top_k', None)
    top_p = data.get('top_p', 1)
    temperature = data.get('temperature', 0.3)
    max_length = data.get('max_length', 2500)
    logging.info(f"Received data: {json.dumps(data, ensure_ascii=False)}")

    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    )
    logging.info(f"messages: {messages}")
    inputs = inputs.to(device)

    # Sampling rejects a non-positive temperature, and clients commonly send
    # temperature=0 to ask for deterministic output, so decode greedily then.
    use_sampling = not (
        isinstance(temperature, (int, float)) and temperature <= 0
    )
    gen_kwargs = {
        "max_length": max_length,
        "do_sample": use_sampling,
    }
    if use_sampling:
        gen_kwargs["top_p"] = top_p
        gen_kwargs["temperature"] = temperature
        if top_k is not None:
            # Only override top_k when the client actually supplied one.
            gen_kwargs["top_k"] = top_k

    # Not true streaming: the whole reply is generated first and then replayed
    # one character at a time. This keeps it compatible with both glm4 and
    # qwen2.
    def generate_response():
        # Keep no_grad scoped to the generate call so the grad-mode change is
        # not held open while this generator is suspended at a yield.
        with torch.no_grad():
            outputs = model.generate(**inputs, **gen_kwargs)
        # Drop the prompt tokens so only the newly generated part is decoded.
        prompt_len = inputs['input_ids'].shape[1]
        response_text = tokenizer.decode(
            outputs[0, prompt_len:], skip_special_tokens=True
        )
        for char in response_text:
            logging.info(f"Sending character: {char}")
            payload = json.dumps({'choices': [{'delta': {'content': char}}]})
            yield f"data: {payload}\n\n"

    return Response(generate_response(), content_type='text/event-stream')

# Script entry point: start Flask's built-in server, bound to the loopback
# interface on port 5000.
if __name__ == "__main__":
    app.run(
        host="127.0.0.1",
        port=5000,
    )