-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
115 lines (86 loc) · 2.81 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import json
from typing import List
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI
from starlette import status
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import JSONResponse
from model import ErrorResponse, SearchQuery, SearchResponse, \
SpellCheckResponse, ContentQuery, ContentResponse
from search import generate_fuzzy_model, bm25_searcher
from sparql import construct, select, constructBatch
# Environment configuration is loaded first, before the application object
# is created.
load_dotenv()

app = FastAPI()

STATIC_PATH = "/static"


def _placeholder_spell_checker(ret):
    """Stand-in spell checker: ignore the input and return an empty string."""
    return ""


# Rebound to the real spell checker by the startup hook; until then every
# input is "corrected" to the empty string.
SPELL_CHECKER = _placeholder_spell_checker
class Dummy:
    """Placeholder search engine whose ``search`` never yields any hits."""

    def search(self, query, cutoff):
        """Ignore ``query`` and ``cutoff`` and return an empty string.

        Wrapping the return value in ``list(...)`` produces an empty list.
        """
        no_hits = ""
        return no_hits
# Module-wide search backend. Starts out as the placeholder and is rebound
# by the startup hook.
SEARCH_ENGINE = Dummy()

# Origins handed to the CORS middleware.
origins = ["*"]
@app.on_event("startup")
async def startup_event():
    """Replace the placeholder spell checker and search engine at startup.

    Rebinds the module-level ``SPELL_CHECKER`` and ``SEARCH_ENGINE`` to the
    objects returned by ``generate_fuzzy_model()`` and ``bm25_searcher()``.
    """
    global SPELL_CHECKER
    global SEARCH_ENGINE

    SPELL_CHECKER = generate_fuzzy_model()
    SEARCH_ENGINE = bm25_searcher()
# CORS policy for the whole application.
# NOTE(review): `origins` is ["*"] while allow_credentials is True. Confirm
# that credentialed requests from every origin are really intended; the
# FastAPI CORS documentation advises listing explicit origins when
# credentials are allowed.
cors_settings = {
    "allow_origins": origins,
    "allow_credentials": True,
    "allow_methods": ["POST", "GET"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **cors_settings)
@app.get("/")
async def read_root():
    """Answer ``GET /`` with the fixed payload ``{"code": 200}``."""
    payload = {"code": 200}
    return payload
def computeContent(query: ContentQuery):
    """Fetch the content for ``query`` and return it as decoded JSON.

    ``query.type`` and ``query.id`` are forwarded to ``construct``; whatever
    it returns is parsed with ``json.loads`` and the parsed object is
    returned.
    """
    return json.loads(construct(query.type, query.id))
def computeSearchResult(result: str):
    """Run ``constructBatch`` on ``result`` and return its decoded JSON.

    The raw value returned by ``constructBatch`` is printed to stdout before
    it is parsed with ``json.loads``.
    """
    batch_payload = constructBatch(result)
    print(batch_payload)
    decoded = json.loads(batch_payload)
    return decoded
@app.post("/search", response_model=SearchResponse)
async def search(query: SearchQuery):
    """Spell-check the query, run the search engine, and describe the hits.

    Steps:
      1. ``query.content`` is passed through ``SPELL_CHECKER``; ``changed``
         records whether the corrected text differs from the input.
      2. The corrected text is searched with ``cutoff=10`` and every hit's
         ``score`` is converted to ``float``.
      3. The first hit (if any) is expanded via ``computeContent``.
      4. The first token of every hit id, prefixed with ``:``, is joined with
         spaces and passed to ``computeSearchResult``.

    When the engine returns no hits, the top result carries empty content,
    type and id instead of raising ``IndexError`` on ``result[0]``.

    Returns:
        A ``SearchResponse`` built from the status code 200, the hits, the
        batch description, the top result, the corrected query and the
        ``changed`` flag.
    """
    spell_checked = SPELL_CHECKER(query.content)
    changed = not (spell_checked == query.content)

    result = list(SEARCH_ENGINE.search(spell_checked, cutoff=10))
    for hit in result:
        hit['score'] = float(hit['score'])

    if result:
        # A hit id is split on whitespace into exactly two tokens:
        # the content id followed by its type.
        content_id, content_type = result[0]['id'].split()
    else:
        # NOTE(review): empty strings are placeholders for "no top hit";
        # confirm they are acceptable to SearchResponse.
        content_id, content_type = "", ""
    print(result)

    top_result = {
        # NOTE(review): ContentQuery is built positionally as (id, type);
        # verify that order against its definition in model.
        "content": (computeContent(ContentQuery(content_id, content_type))
                    if result
                    else dict()),
        "type": content_type,
        "id": content_id,
    }
    # With no hits this passes an empty string to computeSearchResult.
    # NOTE(review): confirm constructBatch tolerates an empty batch.
    desc = computeSearchResult(
        " ".join([f":{res['id'].split()[0]}" for res in result])
    )
    return SearchResponse(200, result, desc, top_result, spell_checked, changed)
@app.post("/spellcheck", response_model=SpellCheckResponse)
async def spellcheck(query: SearchQuery):
    """Spell-check ``query.content`` and report whether it was altered.

    Returns a ``SpellCheckResponse`` built from the status code 200, the
    output of ``SPELL_CHECKER`` and a flag that is true when that output
    does not equal the submitted text.
    """
    submitted = query.content
    corrected = SPELL_CHECKER(submitted)
    unchanged = corrected == submitted
    return SpellCheckResponse(200, corrected, not unchanged)
@app.post("/content", response_model=ContentResponse)
async def content(query: ContentQuery):
    """Return the content for ``query`` wrapped in a ``ContentResponse``.

    The payload comes from ``computeContent(query)`` and is paired with the
    status code 200.
    """
    return ContentResponse(200, computeContent(query))
def common_error(err: Exception):
    """
    Build the JSON response used for failed requests.

    The body is an ``ErrorResponse`` created from the fixed label
    "invalid request" and the string form of ``err``, converted with
    ``.dict()``; the HTTP status is 404.
    """
    detail = f"{str(err)}"
    body = ErrorResponse("invalid request", detail).dict()
    return JSONResponse(status_code=status.HTTP_404_NOT_FOUND, content=body)
if __name__ == "__main__":
    # Direct-execution entry point: serve "main:app" with uvicorn on all
    # interfaces, port 8080, with auto-reload switched on.
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8080,
        log_level="info",
        reload=True,
    )