-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
71 lines (52 loc) · 1.64 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from email.policy import default
from fastapi import FastAPI, File, UploadFile, HTTPException
from utils.video import process_video, get_frames, create_tmp_path
from utils.onnx import get_session, oxx_inference
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
app = FastAPI()

# One ONNX session is created at import time and shared by all requests.
onnx_session = get_session()

# Origins for development and production clients.
origins = ["http://localhost:3000", "https://sign2text.com"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    # The API's endpoint is registered with POST, so POST must be allowed here;
    # otherwise preflighted cross-origin POST requests are rejected.
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
    # Let browsers cache the preflight response for one hour.
    max_age=3600,
)
# Close session on server shutdown.
@app.on_event("shutdown")
async def shutdown():
    """Drop the module-level ONNX session when the application shuts down."""
    # `del` is a binding operation: without the `global` declaration Python
    # treats `onnx_session` as a local name and raises UnboundLocalError
    # instead of releasing the module-level session.
    global onnx_session
    del onnx_session
class TargetModel(BaseModel):
    """Response schema (as shown in the OpenAPI docs) for a classification result."""

    # Class returned by the model for the submitted video.
    target: str
@app.post(
    "/sign",
    tags=["video"],
    description="Video classification",
    response_model=TargetModel,
)
async def inference(video: UploadFile = File(default=None)) -> TargetModel:
    """Classify an uploaded video.

    Args:
        video (UploadFile, optional): Video to run inference on. Only
            uploads with content type ``video/mp4`` are accepted.

    Raises:
        HTTPException: 400 if no file was uploaded, or if the uploaded
            file's content type is not ``video/mp4``.

    Returns:
        TargetModel: Class of the video.
    """
    # The parameter defaults to None, so a request without a file reaches
    # this point; reject it explicitly instead of failing on attribute access.
    if video is None:
        raise HTTPException(status_code=400, detail="No video file uploaded")

    if video.content_type not in ("video/mp4",):
        raise HTTPException(status_code=400, detail="File must be mp4 video")

    # Get video frames.
    file_tmp_path = await create_tmp_path(video)
    frames = get_frames(file_tmp_path)

    # Process video for inference.
    frames = await process_video(frames)

    # Input to the model.
    target = oxx_inference(frames, onnx_session)

    return TargetModel(target=target)