Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions database.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
import datetime

from sqlalchemy import create_engine, DateTime
from sqlalchemy.orm import sessionmaker
from sqlalchemy import Boolean, Column, Integer, String, DateTime
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Boolean, Column, Integer, String
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.ext.asyncio.session import AsyncSession


engine = create_async_engine("sqlite+aiosqlite:///database.db")

engine = create_engine('sqlite:///database.db', connect_args={"check_same_thread": False})
Base = declarative_base()
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


async def get_session():
async with AsyncSession(engine, expire_on_commit=False) as s:
yield s


class Codes(Base):
__tablename__ = 'codes'
__tablename__ = "codes"
id = Column(Integer, primary_key=True, index=True)
code = Column(String(10), unique=True, index=True)
key = Column(String(30), unique=True, index=True)
Expand Down
85 changes: 44 additions & 41 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,24 @@
import os
import uuid
import threading
import random

from fastapi import FastAPI, Depends, UploadFile, Form, File
from sqlalchemy import or_
from sqlalchemy.orm import Session
from starlette.requests import Request
from starlette.responses import HTMLResponse, FileResponse
import random

from starlette.staticfiles import StaticFiles

import database
from database import engine, SessionLocal, Base
from sqlalchemy import or_, select, update, delete, create_engine
from sqlalchemy import select, update, delete
from sqlalchemy.ext.asyncio.session import AsyncSession

from database import engine, get_session, Base, Codes


engine = create_engine('sqlite:///database.db', connect_args={"check_same_thread": False})
Base.metadata.create_all(bind=engine)

app = FastAPI()
if not os.path.exists('./static'):
os.makedirs('./static')
Expand Down Expand Up @@ -58,17 +63,9 @@ def delete_file(files):
os.remove('.' + file['text'])


def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()


def get_code(db: Session = Depends(get_db)):
async def get_code(s: AsyncSession):
code = random.randint(10000, 99999)
while db.query(database.Codes).filter(database.Codes.code == code).first():
while (await s.execute(select(Codes.id).where(Codes.code == code))).scalar():
code = random.randint(10000, 99999)
return str(code)

Expand All @@ -94,21 +91,23 @@ async def admin():


@app.post(f'/{admin_address}')
async def admin_post(request: Request, db: Session = Depends(get_db)):
async def admin_post(request: Request, s: AsyncSession = Depends(get_session)):
if request.headers.get('pwd') == admin_password:
codes = db.query(database.Codes).all()
query = select(Codes)
codes = (await s.execute(query)).scalars().all()
return {'code': 200, 'msg': '查询成功', 'data': codes}
else:
return {'code': 404, 'msg': '密码错误'}


@app.delete(f'/{admin_address}')
async def admin_delete(request: Request, code: str, db: Session = Depends(get_db)):
async def admin_delete(request: Request, code: str, s: AsyncSession = Depends(get_session)):
if request.headers.get('pwd') == admin_password:
file = db.query(database.Codes).filter(database.Codes.code == code).first()
query = select(Codes).where(Codes.code == code)
file = (await s.execute(query)).scalars().first()
threading.Thread(target=delete_file, args=([{'type': file.type, 'text': file.text}],)).start()
db.delete(file)
db.commit()
await s.delete(file)
await s.commit()
return {'code': 200, 'msg': '删除成功'}
else:
return {'code': 404, 'msg': '密码错误'}
Expand Down Expand Up @@ -150,22 +149,26 @@ async def get_file(code: str, db: Session = Depends(get_db)):


@app.post('/')
async def index(request: Request, code: str, db: Session = Depends(get_db)):
async def index(request: Request, code: str, s: AsyncSession = Depends(get_session)):
ip = request.client.host
if not check_ip(ip):
return {'code': 404, 'msg': '错误次数过多,请稍后再试'}
info = db.query(database.Codes).filter(database.Codes.code == code).first()
query = select(Codes).where(Codes.code == code)
info = (await s.execute(query)).scalars().first()
if not info:
return {'code': 404, 'msg': f'取件码错误,错误{error_count - ip_error(ip)}次将被禁止10分钟'}
if info.exp_time < datetime.datetime.now() or info.count == 0:
threading.Thread(target=delete_file, args=([{'type': info.type, 'text': info.text}],)).start()
db.delete(info)
db.commit()
await s.delete(info)
await s.commit()
return {'code': 404, 'msg': '取件码已过期,请联系寄件人'}
info.count -= 1
db.commit()
count = info.count - 1
query = update(Codes).where(Codes.id == info.id).values(count=count)
await s.execute(query)
await s.commit()
if info.type != 'text':
info.text = f'/select?code={code}'

return {
'code': 200,
'msg': '取件成功,请点击"取"查看',
Expand All @@ -175,17 +178,17 @@ async def index(request: Request, code: str, db: Session = Depends(get_db)):

@app.post('/share')
async def share(text: str = Form(default=None), style: str = Form(default='2'), value: int = Form(default=1),
file: UploadFile = File(default=None), db: Session = Depends(get_db)):
exps = db.query(database.Codes).filter(
or_(
database.Codes.exp_time < datetime.datetime.now(),
database.Codes.count == 0
)
)
threading.Thread(target=delete_file, args=([[{'type': old.type, 'text': old.text}] for old in exps.all()],)).start()
exps.delete()
db.commit()
code = get_code(db)
file: UploadFile = File(default=None), s: AsyncSession = Depends(get_session)):
query = select(Codes).where(or_(Codes.exp_time < datetime.datetime.now(), Codes.count == 0))
exps = (await s.execute(query)).scalars().all()
threading.Thread(target=delete_file, args=([[{'type': old.type, 'text': old.text}] for old in exps],)).start()

exps_ids = [exp.id for exp in exps]
query = delete(Codes).where(Codes.id.in_(exps_ids))
await s.execute(query)
await s.commit()

code = await get_code(s)
if style == '2':
if value > 7:
return {'code': 404, 'msg': '最大有效天数为7天'}
Expand All @@ -206,7 +209,7 @@ async def share(text: str = Form(default=None), style: str = Form(default='2'),
return {'code': 404, 'msg': '文件过大'}
else:
size, _text, _type, name = len(text), text, 'text', '文本分享'
info = database.Codes(
info = Codes(
code=code,
text=_text,
size=size,
Expand All @@ -216,8 +219,8 @@ async def share(text: str = Form(default=None), style: str = Form(default='2'),
exp_time=exp_time,
key=key
)
db.add(info)
db.commit()
s.add(info)
await s.commit()
return {
'code': 200,
'msg': '分享成功,请点击文件箱查看取件码',
Expand Down
5 changes: 2 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
fastapi==0.88.0
python-multipart==0.0.5
fastapi[all]==0.88.0
aiosqlite==0.17.0
SQLAlchemy==1.4.44
uvicorn==0.20.0