-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
76 lines (53 loc) · 2.28 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from fastapi import FastAPI, Request, UploadFile
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
import numpy as np
import json
from contextlib import asynccontextmanager
from torchvision.models import resnet50, ResNet50_Weights
from fastapi.responses import RedirectResponse
from fastapi import status
from utils import vectorize_image, cosine_similarity
import shutil
# Module-level registry of loaded model objects. `lifespan` fills it in before
# the app starts serving and clears it on shutdown; the upload/search handlers
# pass it to `vectorize_image`.
model_dict = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model before the app runs and release it afterwards.

    Everything before the ``yield`` builds an ImageNet-pretrained ResNet-50,
    switches it to evaluation mode, and stores both the network and its
    ``avgpool`` sub-module in ``model_dict``. Everything after the ``yield``
    empties ``model_dict`` again.
    """
    backbone = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
    model_dict["resnet50"] = backbone
    backbone.eval()
    # Name-based lookup in the sub-module mapping: yields None instead of
    # raising if no 'avgpool' entry exists.
    model_dict["extraction_layer"] = backbone._modules.get('avgpool')

    yield

    # Drop the references so the model objects can be reclaimed.
    model_dict.clear()
# Application instance; `lifespan` is registered to run the model setup and
# cleanup around the app's lifetime.
app = FastAPI(lifespan=lifespan)
# Serve files from the local "static" directory under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")
# Jinja2 templates are loaded from the local "templates" directory.
templates = Jinja2Templates(directory="templates")
@app.get("/", response_class=HTMLResponse)
async def read_item(request: Request):
    """Render the ``index.html`` template for the root URL."""
    page = templates.TemplateResponse(request=request, name="index.html")
    return page
@app.post("/search")
async def query_endpoint(request: Request, file: UploadFile):
    """Rank the stored image vectors against an uploaded query image.

    The upload is vectorized with ``vectorize_image``, scored against every
    vector in ``vector_store.json`` with ``cosine_similarity``, and the three
    highest-scoring entries are rendered through the ``result.html`` template.

    Args:
        request: Incoming request, forwarded to the template renderer.
        file: Uploaded query image.

    Returns:
        The rendered ``result.html`` response. Its ``results`` context value
        maps store keys to their scores, highest score first.
    """
    top_k = 3  # number of best matches handed to the template
    embedding = vectorize_image(file.file, model_dict)

    # The JSON handle gets its own name so the `file` parameter is never
    # rebound. A store that does not exist yet is treated as empty, which
    # renders a page with no matches instead of an unhandled error.
    try:
        with open("vector_store.json", 'r') as store:
            stored_vectors = json.load(store)
    except FileNotFoundError:
        stored_vectors = {}

    scores = {
        name: cosine_similarity(np.array(vector), embedding)
        for name, vector in stored_vectors.items()
    }
    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)[:top_k]
    return templates.TemplateResponse(
        request=request, context={"results": dict(ranked)}, name="result.html"
    )
@app.post("/vectorize/")
async def upload_file(file: UploadFile):
    """Save an uploaded image and record its vector in the vector store.

    The image is written to ``static/images/<name>``, vectorized with
    ``vectorize_image``, and the resulting vector is stored under ``<name>``
    in ``vector_store.json``.

    Args:
        file: Uploaded image. Only the final path component of its
            client-supplied filename is used.

    Returns:
        A dict with a single ``"message"`` key confirming success.

    Raises:
        HTTPException: 400 if the upload carries no usable filename.
    """
    import os

    from fastapi import HTTPException

    # The filename comes from the client. Keep only its last path component
    # (treating backslashes as separators too) so a name such as
    # "../../main.py" cannot make the write land outside static/images.
    filename = os.path.basename((file.filename or "").replace("\\", "/"))
    if filename in ("", ".", ".."):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Upload has no usable filename",
        )

    with open(f"static/images/{filename}", "wb") as image_out:
        shutil.copyfileobj(file.file, image_out)

    # copyfileobj leaves the upload stream at end-of-file. Rewind it so the
    # vectorizer reads the image from the start.
    # NOTE(review): vectorize_image is defined in utils and not visible here;
    # this avoids relying on it to rewind the stream itself.
    file.file.seek(0)
    embedding = vectorize_image(file.file, model_dict)

    # Start from an empty store when the JSON file does not exist yet, so the
    # very first upload can create it.
    try:
        with open("vector_store.json", 'r') as store:
            data = json.load(store)
    except FileNotFoundError:
        data = {}

    # JSON cannot hold the embedding object directly; store it as a plain list.
    data[filename] = embedding.tolist()

    with open("vector_store.json", 'w') as store:
        json.dump(data, store, indent=4)

    return {"message": "Image Successfully Vectorized"}