Skip to content
This repository has been archived by the owner on May 5, 2023. It is now read-only.

Commit

Permalink
fix: closes #19
Browse files Browse the repository at this point in the history
  • Loading branch information
paulbricman committed Feb 14, 2022
1 parent bf093f4 commit 0ac7b31
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 27 deletions.
52 changes: 27 additions & 25 deletions backend/main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from fastapi import FastAPI, Request, Header
from fastapi import Depends, FastAPI, Request, Header
from security import auth
from util import find, rank, save, get_authorized_thoughts, remove, dump
from sentence_transformers import SentenceTransformer
from fastapi.datastructures import UploadFile
from fastapi import FastAPI, File, Form
from fastapi.responses import FileResponse, ORJSONResponse
from fastapi.security import HTTPBearer, HTTPBasicCredentials
from pathlib import Path
from microverses import create_microverse, remove_microverse, list_microverses
from slowapi import Limiter, _rate_limit_exceeded_handler
Expand All @@ -13,6 +14,7 @@
from slowapi.errors import RateLimitExceeded


security = HTTPBearer()
limiter = Limiter(key_func=get_remote_address, default_limits=['30/minute'])
app = FastAPI()
app.state.limiter = limiter
Expand All @@ -33,7 +35,7 @@ async def find_text_handler(
return_embeddings: bool = False,
silent: bool = False,
request: Request = None,
authorization: str = Header(None)
authorization: HTTPBasicCredentials = Depends(security)
):
return find(
'text',
Expand All @@ -42,7 +44,7 @@ async def find_text_handler(
activation,
noise,
return_embeddings,
auth(authorization),
auth(authorization.credentials),
text_encoder,
text_image_encoder,
silent
Expand All @@ -58,7 +60,7 @@ async def find_image_handler(
return_embeddings: bool = Form(False),
silent: bool = Form(False),
request: Request = None,
authorization: str = Header(None)
authorization: HTTPBasicCredentials = Depends(security)
):
query = await query.read()
return find(
Expand All @@ -68,67 +70,67 @@ async def find_image_handler(
activation,
noise,
return_embeddings,
auth(authorization),
auth(authorization.credentials),
text_encoder,
text_image_encoder,
silent
)


@app.get('/save')
async def save_text_handler(query: str, request: Request, authorization: str = Header(None)):
return save('text', query, auth(authorization),
async def save_text_handler(query: str, request: Request, authorization: HTTPBasicCredentials = Depends(security)):
return save('text', query, auth(authorization.credentials),
text_encoder, text_image_encoder)


@app.post('/save')
async def save_image_handler(query: UploadFile = File(...), request: Request = None, authorization: str = Header(None)):
async def save_image_handler(query: UploadFile = File(...), request: Request = None, authorization: HTTPBasicCredentials = Depends(security)):
query = await query.read()
results = save('image', query, auth(authorization),
results = save('image', query, auth(authorization.credentials),
text_encoder, text_image_encoder)
return results


@app.get('/remove')
async def remove_handler(filename: str, request: Request, authorization: str = Header(None)):
return remove(auth(authorization), filename)
async def remove_handler(filename: str, request: Request, authorization: HTTPBasicCredentials = Depends(security)):
return remove(auth(authorization.credentials), filename)


@app.get('/dump')
async def save_text_handler(request: Request, authorization: str = Header(None)):
return dump(auth(authorization))
async def save_text_handler(request: Request, authorization: HTTPBasicCredentials = Depends(security)):
return dump(auth(authorization.credentials))


@app.get('/static')
@limiter.limit("200/minute")
async def static_handler(filename: str, request: Request, authorization: str = Header(None)):
async def static_handler(filename: str, request: Request, authorization: HTTPBasicCredentials = Depends(security)):
knowledge_base_path = Path('..') / 'knowledge'
thoughts = get_authorized_thoughts(auth(authorization))
thoughts = get_authorized_thoughts(auth(authorization.credentials))
if filename in [e['filename'] for e in thoughts]:
return FileResponse(knowledge_base_path / filename)


@app.get('/microverse/create')
async def microverse_create_handler(query: str, request: Request, authorization: str = Header(None)):
return create_microverse('text', query, auth(authorization), text_encoder, text_image_encoder)
async def microverse_create_handler(query: str, request: Request, authorization: HTTPBasicCredentials = Depends(security)):
return create_microverse('text', query, auth(authorization.credentials), text_encoder, text_image_encoder)


@app.post('/microverse/create')
async def microverse_create_handler(query: UploadFile = File(...), request: Request = None, authorization: str = Header(None)):
async def microverse_create_handler(query: UploadFile = File(...), request: Request = None, authorization: HTTPBasicCredentials = Depends(security)):
query = await query.read()
return create_microverse('image', query, auth(authorization), text_encoder, text_image_encoder)
return create_microverse('image', query, auth(authorization.credentials), text_encoder, text_image_encoder)


@app.get('/microverse/remove')
async def microverse_remove_handler(microverse: str, request: Request, authorization: str = Header(None)):
return remove_microverse(auth(authorization), microverse)
async def microverse_remove_handler(microverse: str, request: Request, authorization: HTTPBasicCredentials = Depends(security)):
return remove_microverse(auth(authorization.credentials), microverse)


@app.get('/microverse/list')
async def microverse_list_handler(request: Request, authorization: str = Header(None)):
return list_microverses(auth(authorization))
async def microverse_list_handler(request: Request, authorization: HTTPBasicCredentials = Depends(security)):
return list_microverses(auth(authorization.credentials))


@app.get('/custodian/check')
async def check_custodian(request: Request, authorization: str = Header(None)):
return auth(authorization)
async def check_custodian(request: Request, authorization: HTTPBasicCredentials = Depends(security)):
return auth(authorization.credentials)
3 changes: 1 addition & 2 deletions backend/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@


def auth(token):
if token == None or not token.startswith('Bearer '):
if not token:
return {
'custodian': False
}

token = token.replace('Bearer ', '')
path = Path('..') / 'knowledge' / 'records.json'

if not path.exists():
Expand Down

0 comments on commit 0ac7b31

Please sign in to comment.