Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 54 additions & 58 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import uuid
import threading
import random
import asyncio
from pathlib import Path

from fastapi import FastAPI, Depends, UploadFile, Form, File
from starlette.requests import Request
from starlette.responses import HTMLResponse, FileResponse
Expand All @@ -11,55 +14,56 @@
from sqlalchemy import or_, select, update, delete
from sqlalchemy.ext.asyncio.session import AsyncSession

from database import get_session, Codes, init_models
import settings
from database import get_session, Codes, init_models, engine

app = FastAPI(debug=settings.DEBUG)

DATA_ROOT = Path(settings.DATA_ROOT)
if not DATA_ROOT.exists():
DATA_ROOT.mkdir(parents=True)

app = FastAPI()
if not os.path.exists('./static'):
os.makedirs('./static')
app.mount("/static", StaticFiles(directory="static"), name="static")
STATIC_URL = settings.STATIC_URL
app.mount(STATIC_URL, StaticFiles(directory=DATA_ROOT), name="static")


@app.on_event('startup')
async def startup():
await init_models()


############################################
# 需要修改的参数
# 允许错误次数
error_count = 5
# 禁止分钟数
error_minute = 10
# 后台地址
admin_address = 'admin'
# 管理密码
admin_password = 'admin'
# 文件大小限制 10M
file_size_limit = 1024 * 1024 * 10
# 系统标题
title = '文件快递柜'
# 系统描述
description = 'FileCodeBox,文件快递柜,口令传送箱,匿名口令分享文本,文件,图片,视频,音频,压缩包等文件'
# 系统关键字
keywords = 'FileCodeBox,文件快递柜,口令传送箱,匿名口令分享文本,文件,图片,视频,音频,压缩包等文件'
############################################
asyncio.create_task(delete_expire_files())

index_html = open('templates/index.html', 'r', encoding='utf-8').read() \
.replace('{{title}}', title) \
.replace('{{description}}', description) \
.replace('{{keywords}}', keywords)
.replace('{{title}}', settings.TITLE) \
.replace('{{description}}', settings.DESCRIPTION) \
.replace('{{keywords}}', settings.KEYWORDS)
admin_html = open('templates/admin.html', 'r', encoding='utf-8').read() \
.replace('{{title}}', title) \
.replace('{{description}}', description) \
.replace('{{keywords}}', keywords)
.replace('{{title}}', settings.TITLE) \
.replace('{{description}}', settings.DESCRIPTION) \
.replace('{{keywords}}', settings.KEYWORDS)

error_ip_count = {}


def delete_file(files):
for file in files:
if file['type'] != 'text':
os.remove('.' + file['text'])
os.remove(DATA_ROOT / file['text'].lstrip(STATIC_URL+'/'))


async def delete_expire_files():
while True:
async with AsyncSession(engine, expire_on_commit=False) as s:
query = select(Codes).where(or_(Codes.exp_time < datetime.datetime.now(), Codes.count == 0))
exps = (await s.execute(query)).scalars().all()
await asyncio.to_thread(delete_file, [{'type': old.type, 'text': old.text} for old in exps])

exps_ids = [exp.id for exp in exps]
query = delete(Codes).where(Codes.id.in_(exps_ids))
await s.execute(query)
await s.commit()

await asyncio.sleep(random.randint(60, 300))


async def get_code(s: AsyncSession):
Expand All @@ -73,38 +77,39 @@ def get_file_name(key, ext, file):
now = datetime.datetime.now()
file_bytes = file.file.read()
size = len(file_bytes)
if size > file_size_limit:
if size > settings.FILE_SIZE_LIMIT:
return size, '', '', ''
path = f'./static/upload/{now.year}/{now.month}/{now.day}/'
path = DATA_ROOT / f"upload/{now.year}/{now.month}/{now.day}/"
name = f'{key}.{ext}'
if not os.path.exists(path):
os.makedirs(path)
with open(f'{os.path.join(path, name)}', 'wb') as f:
if not path.exists():
path.mkdir(parents=True)
filepath = path / name
with open(filepath, 'wb') as f:
f.write(file_bytes)
return size, path[1:] + name, file.content_type, file.filename
return size, f"{STATIC_URL}/{filepath.relative_to(DATA_ROOT)}", file.content_type, file.filename


@app.get(f'/{admin_address}')
@app.get(f'/{settings.ADMIN_ADDRESS}')
async def admin():
return HTMLResponse(admin_html)


@app.post(f'/{admin_address}')
@app.post(f'/{settings.ADMIN_ADDRESS}')
async def admin_post(request: Request, s: AsyncSession = Depends(get_session)):
if request.headers.get('pwd') == admin_password:
if request.headers.get('pwd') == settings.ADMIN_PASSWORD:
query = select(Codes)
codes = (await s.execute(query)).scalars().all()
return {'code': 200, 'msg': '查询成功', 'data': codes}
else:
return {'code': 404, 'msg': '密码错误'}


@app.delete(f'/{admin_address}')
@app.delete(f'/{settings.ADMIN_ADDRESS}')
async def admin_delete(request: Request, code: str, s: AsyncSession = Depends(get_session)):
if request.headers.get('pwd') == admin_password:
if request.headers.get('pwd') == settings.ADMIN_PASSWORD:
query = select(Codes).where(Codes.code == code)
file = (await s.execute(query)).scalars().first()
threading.Thread(target=delete_file, args=([{'type': file.type, 'text': file.text}],)).start()
await asyncio.to_thread(delete_file, [{'type': file.type, 'text': file.text}])
await s.delete(file)
await s.commit()
return {'code': 200, 'msg': '删除成功'}
Expand All @@ -120,8 +125,8 @@ async def index():
def check_ip(ip):
# 检查ip是否被禁止
if ip in error_ip_count:
if error_ip_count[ip]['count'] >= error_count:
if error_ip_count[ip]['time'] + datetime.timedelta(minutes=error_minute) > datetime.datetime.now():
if error_ip_count[ip]['count'] >= settings.ERROR_COUNT:
if error_ip_count[ip]['time'] + datetime.timedelta(minutes=settings.ERROR_MINUTE) > datetime.datetime.now():
return False
else:
error_ip_count.pop(ip)
Expand All @@ -143,7 +148,7 @@ async def get_file(code: str, s: AsyncSession = Depends(get_session)):
if info.type == 'text':
return {'code': code, 'msg': '查询成功', 'data': info.text}
else:
return FileResponse('.' + info.text, filename=info.name)
return FileResponse(DATA_ROOT / info.text.lstrip(STATIC_URL+'/'), filename=info.name)
else:
return {'code': 404, 'msg': '口令不存在'}

Expand All @@ -156,7 +161,7 @@ async def index(request: Request, code: str, s: AsyncSession = Depends(get_sessi
query = select(Codes).where(Codes.code == code)
info = (await s.execute(query)).scalars().first()
if not info:
return {'code': 404, 'msg': f'取件码错误,错误{error_count - ip_error(ip)}次将被禁止10分钟'}
return {'code': 404, 'msg': f'取件码错误,错误{settings.ERROR_COUNT - ip_error(ip)}次将被禁止10分钟'}
if info.exp_time < datetime.datetime.now() or info.count == 0:
threading.Thread(target=delete_file, args=([{'type': info.type, 'text': info.text}],)).start()
await s.delete(info)
Expand All @@ -179,15 +184,6 @@ async def index(request: Request, code: str, s: AsyncSession = Depends(get_sessi
@app.post('/share')
async def share(text: str = Form(default=None), style: str = Form(default='2'), value: int = Form(default=1),
file: UploadFile = File(default=None), s: AsyncSession = Depends(get_session)):
query = select(Codes).where(or_(Codes.exp_time < datetime.datetime.now(), Codes.count == 0))
exps = (await s.execute(query)).scalars().all()
threading.Thread(target=delete_file, args=([[{'type': old.type, 'text': old.text}] for old in exps],)).start()

exps_ids = [exp.id for exp in exps]
query = delete(Codes).where(Codes.id.in_(exps_ids))
await s.execute(query)
await s.commit()

code = await get_code(s)
if style == '2':
if value > 7:
Expand All @@ -205,7 +201,7 @@ async def share(text: str = Form(default=None), style: str = Form(default='2'),
key = uuid.uuid4().hex
if file:
size, _text, _type, name = get_file_name(key, file.filename.split('.')[-1], file)
if size > file_size_limit:
if size > settings.FILE_SIZE_LIMIT:
return {'code': 404, 'msg': '文件过大'}
else:
size, _text, _type, name = len(text), text, 'text', '文本分享'
Expand Down
28 changes: 28 additions & 0 deletions settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from starlette.config import Config


config = Config(".env")

DEBUG = config('DEBUG', cast=bool, default=False)

DATABASE_URL = config('DATABASE_URL', cast=str, default="sqlite+aiosqlite:///database.db")

DATA_ROOT = config('DATA_ROOT', cast=str, default="./static")

STATIC_URL = config('STATIC_URL', cast=str, default="/static")

ERROR_COUNT = config('ERROR_COUNT', cast=int, default=5)

ERROR_MINUTE = config('ERROR_MINUTE', cast=int, default=10)

ADMIN_ADDRESS = config('ADMIN_ADDRESS', cast=str, default="admin")

ADMIN_PASSWORD = config('ADMIN_ADDRESS', cast=str, default="admin")

FILE_SIZE_LIMIT = config('FILE_SIZE_LIMIT', cast=int, default=1024 * 1024 * 10)

TITLE = config('TITLE', cast=str, default="文件快递柜")

DESCRIPTION = config('DESCRIPTION', cast=str, default="FileCodeBox,文件快递柜,口令传送箱,匿名口令分享文本,文件,图片,视频,音频,压缩包等文件")

KEYWORDS = config('TITLE', cast=str, default="FileCodeBox,文件快递柜,口令传送箱,匿名口令分享文本,文件,图片,视频,音频,压缩包等文件")