From 195c6f9ae475000b76efc5fdcdd04125e7cef625 Mon Sep 17 00:00:00 2001 From: veoco Date: Sun, 11 Dec 2022 14:03:46 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=20aiosqlite=20=E5=BC=82?= =?UTF-8?q?=E6=AD=A5=E6=95=B0=E6=8D=AE=E5=BA=93=E9=A9=B1=E5=8A=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- requirements.txt | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/requirements.txt b/requirements.txt index 6c1558dd0..95c6f01ab 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ -fastapi==0.88.0 -python-multipart==0.0.5 +fastapi[all]==0.88.0 +aiosqlite==0.17.0 SQLAlchemy==1.4.44 -uvicorn==0.20.0 From f4e6dfef0d18e8aa243dd0b3e22c2c01c11ab12c Mon Sep 17 00:00:00 2001 From: veoco Date: Sun, 11 Dec 2022 14:04:25 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=BA=93=E6=93=8D?= =?UTF-8?q?=E4=BD=9C=E6=94=B9=E4=B8=BA=E5=BC=82=E6=AD=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- database.py | 18 +++++++---- main.py | 86 +++++++++++++++++++++++++++-------------------------- 2 files changed, 56 insertions(+), 48 deletions(-) diff --git a/database.py b/database.py index ad7d458fe..4275b326e 100644 --- a/database.py +++ b/database.py @@ -1,17 +1,23 @@ import datetime -from sqlalchemy import create_engine, DateTime -from sqlalchemy.orm import sessionmaker +from sqlalchemy import Boolean, Column, Integer, String, DateTime from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy import Boolean, Column, Integer, String +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.ext.asyncio.session import AsyncSession + + +engine = create_async_engine("sqlite+aiosqlite:///database.db") -engine = create_engine('sqlite:///database.db', connect_args={"check_same_thread": False}) Base = declarative_base() -SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + + +async def get_session(): + async with AsyncSession(engine, expire_on_commit=False) as s: + yield s class Codes(Base): - __tablename__ = 'codes' + __tablename__ = "codes" id = Column(Integer, primary_key=True, index=True) code = Column(String(10), unique=True, index=True) key = Column(String(30), unique=True, index=True) diff --git a/main.py b/main.py index 14e6c2fe2..58a15102a 100644 --- a/main.py +++ b/main.py @@ -2,19 +2,23 @@ import os import uuid import threading +import random + from fastapi import FastAPI, Depends, UploadFile, Form, File -from sqlalchemy import or_ -from sqlalchemy.orm import Session from starlette.requests import Request from starlette.responses import HTMLResponse -import random - from starlette.staticfiles import StaticFiles -import database -from database import engine, SessionLocal, Base +from sqlalchemy import or_, select, update, delete, create_engine +from sqlalchemy import select, update, delete +from sqlalchemy.ext.asyncio.session import AsyncSession + +from database import engine, get_session, Base, Codes + +engine = create_engine('sqlite:///database.db', connect_args={"check_same_thread": False}) Base.metadata.create_all(bind=engine) + app = FastAPI() if not os.path.exists('./static'): os.makedirs('./static') @@ -58,17 +62,9 @@ def delete_file(files): os.remove('.' + file['text']) -def get_db(): - db = SessionLocal() - try: - yield db - finally: - db.close() - - -def get_code(db: Session = Depends(get_db)): +async def get_code(s: AsyncSession): code = random.randint(10000, 99999) - while db.query(database.Codes).filter(database.Codes.code == code).first(): + while (await s.execute(select(Codes.id).where(Codes.code == code))).scalar(): code = random.randint(10000, 99999) return str(code) @@ -94,21 +90,23 @@ async def admin(): @app.post(f'/{admin_address}') -async def admin_post(request: Request, db: Session = Depends(get_db)): +async def admin_post(request: Request, s: AsyncSession = Depends(get_session)): if request.headers.get('pwd') == admin_password: - codes = db.query(database.Codes).all() + query = select(Codes) + codes = (await s.execute(query)).scalars().all() return {'code': 200, 'msg': '查询成功', 'data': codes} else: return {'code': 404, 'msg': '密码错误'} @app.delete(f'/{admin_address}') -async def admin_delete(request: Request, code: str, db: Session = Depends(get_db)): +async def admin_delete(request: Request, code: str, s: AsyncSession = Depends(get_session)): if request.headers.get('pwd') == admin_password: - file = db.query(database.Codes).filter(database.Codes.code == code).first() + query = select(Codes).where(Codes.code == code) + file = (await s.execute(query)).scalars().first() threading.Thread(target=delete_file, args=([{'type': file.type, 'text': file.text}],)).start() - db.delete(file) - db.commit() + await s.delete(file) + await s.commit() return {'code': 200, 'msg': '删除成功'} else: return {'code': 404, 'msg': '密码错误'} @@ -138,20 +136,24 @@ def ip_error(ip): @app.post('/') -async def index(request: Request, code: str, db: Session = Depends(get_db)): +async def index(request: Request, code: str, s: AsyncSession = Depends(get_session)): ip = request.client.host if not check_ip(ip): return {'code': 404, 'msg': '错误次数过多,请稍后再试'} - info = db.query(database.Codes).filter(database.Codes.code == code).first() + query = select(Codes).where(Codes.code == code) + info = (await s.execute(query)).scalars().first() if not info: return {'code': 404, 'msg': f'取件码错误,错误{error_count - ip_error(ip)}次将被禁止10分钟'} if info.exp_time < datetime.datetime.now() or info.count == 0: threading.Thread(target=delete_file, args=([{'type': info.type, 'text': info.text}],)).start() - db.delete(info) - db.commit() + await s.delete(info) + await s.commit() return {'code': 404, 'msg': '取件码已过期,请联系寄件人'} - info.count -= 1 - db.commit() + + count = info.count - 1 + query = update(Codes).where(Codes.id == info.id).values(count=count) + await s.execute(query) + await s.commit() return { 'code': 200, 'msg': '取件成功,请点击"取"查看', @@ -161,17 +163,17 @@ async def index(request: Request, code: str, db: Session = Depends(get_db)): @app.post('/share') async def share(text: str = Form(default=None), style: str = Form(default='2'), value: int = Form(default=1), - file: UploadFile = File(default=None), db: Session = Depends(get_db)): - exps = db.query(database.Codes).filter( - or_( - database.Codes.exp_time < datetime.datetime.now(), - database.Codes.count == 0 - ) - ) - threading.Thread(target=delete_file, args=([[{'type': old.type, 'text': old.text}] for old in exps.all()],)).start() - exps.delete() - db.commit() - code = get_code(db) + file: UploadFile = File(default=None), s: AsyncSession = Depends(get_session)): + query = select(Codes).where(or_(Codes.exp_time < datetime.datetime.now(), Codes.count == 0)) + exps = (await s.execute(query)).scalars().all() + threading.Thread(target=delete_file, args=([[{'type': old.type, 'text': old.text}] for old in exps],)).start() + + exps_ids = [exp.id for exp in exps] + query = delete(Codes).where(Codes.id.in_(exps_ids)) + await s.execute(query) + await s.commit() + + code = await get_code(s) if style == '2': if value > 7: return {'code': 404, 'msg': '最大有效天数为7天'} @@ -192,7 +194,7 @@ async def share(text: str = Form(default=None), style: str = Form(default='2'), return {'code': 404, 'msg': '文件过大'} else: size, _text, _type, name = len(text), text, 'text', '文本分享' - info = database.Codes( + info = Codes( code=code, text=_text, size=size, @@ -202,8 +204,8 @@ async def share(text: str = Form(default=None), style: str = Form(default='2'), exp_time=exp_time, key=key ) - db.add(info) - db.commit() + s.add(info) + await s.commit() return { 'code': 200, 'msg': '分享成功,请点击文件箱查看取件码',