From bcf0be3d123a47868979d01e373fbe4026b7a080 Mon Sep 17 00:00:00 2001 From: veoco Date: Tue, 13 Dec 2022 12:45:28 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E8=B0=83=E6=95=B4=E6=96=87=E4=BB=B6?= =?UTF-8?q?=E5=A4=A7=E5=B0=8F=E8=AE=A1=E7=AE=97=E5=92=8C=E5=86=99=E5=85=A5?= =?UTF-8?q?=E6=96=B9=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 12 ++++++------ storage.py | 42 +++++++++++++++++++++++++++++------------- 2 files changed, 35 insertions(+), 19 deletions(-) diff --git a/main.py b/main.py index 366feb4bc..be822f6ca 100644 --- a/main.py +++ b/main.py @@ -4,7 +4,7 @@ import asyncio from pathlib import Path -from fastapi import FastAPI, Depends, UploadFile, Form, File, HTTPException +from fastapi import FastAPI, Depends, UploadFile, Form, File, HTTPException, BackgroundTasks from starlette.responses import HTMLResponse, FileResponse from starlette.staticfiles import StaticFiles @@ -130,8 +130,8 @@ async def index(code: str, ip: str = Depends(ip_limit), s: AsyncSession = Depend @app.post('/share') -async def share(text: str = Form(default=None), style: str = Form(default='2'), value: int = Form(default=1), - file: UploadFile = File(default=None), s: AsyncSession = Depends(get_session)): +async def share(background_tasks: BackgroundTasks, text: str = Form(default=None), style: str = Form(default='2'), + value: int = Form(default=1), file: UploadFile = File(default=None), s: AsyncSession = Depends(get_session)): code = await get_code(s) if style == '2': if value > 7: @@ -148,11 +148,11 @@ async def share(text: str = Form(default=None), style: str = Form(default='2'), exp_count = -1 key = uuid.uuid4().hex if file: - file_bytes = await file.read() - size = len(file_bytes) + size = await storage.get_size(file) if size > settings.FILE_SIZE_LIMIT: raise HTTPException(status_code=400, detail="文件过大") - _text, _type, name = await storage.save_file(file, file_bytes, key), file.content_type, file.filename + _text, _type, name = await storage.get_text(file, key), file.content_type, file.filename + background_tasks.add_task(storage.save_file, file, _text) else: size, _text, _type, name = len(text), text, 'text', '文本分享' info = Codes( diff --git a/storage.py b/storage.py index 851cbbb41..58a90deff 100644 --- a/storage.py +++ b/storage.py @@ -1,7 +1,10 @@ import os import asyncio -import datetime +from datetime import datetime from pathlib import Path +from typing import BinaryIO + +from fastapi import UploadFile import settings @@ -11,25 +14,38 @@ class FileSystemStorage: STATIC_URL = settings.STATIC_URL NAME = "filesystem" - async def get_filepath(self, path): - return self.DATA_ROOT / path.lstrip(self.STATIC_URL + '/') - - def _save(self, filepath, file_bytes): - with open(filepath, 'wb') as f: - f.write(file_bytes) + async def get_filepath(self, text: str): + return self.DATA_ROOT / text.lstrip(self.STATIC_URL + '/') - async def save_file(self, file, file_bytes, key): - now = datetime.datetime.now() - path = self.DATA_ROOT / f"upload/{now.year}/{now.month}/{now.day}/" + async def get_text(self, file: UploadFile, key: str): ext = file.filename.split('.')[-1] - name = f'{key}.{ext}' + now = datetime.now() + path = self.DATA_ROOT / f"upload/{now.year}/{now.month}/{now.day}/" if not path.exists(): path.mkdir(parents=True) - filepath = path / name - await asyncio.to_thread(self._save, filepath, file_bytes) + filepath = path / f'{key}.{ext}' text = f"{self.STATIC_URL}/{filepath.relative_to(self.DATA_ROOT)}" return text + async def get_size(self, file: UploadFile): + f = file.file + f.seek(0, os.SEEK_END) + size = f.tell() + f.seek(0, os.SEEK_SET) + return size + + def _save(self, filepath, file: BinaryIO): + with open(filepath, 'wb') as f: + chunk_size = 256 * 1024 + chunk = file.read(chunk_size) + while chunk: + f.write(chunk) + chunk = file.read(chunk_size) + + async def save_file(self, file: UploadFile, text: str): + filepath = await self.get_filepath(text) + await asyncio.to_thread(self._save, filepath, file.file) + async def delete_file(self, file): # 是文件就删除 if file['type'] != 'text': From 66d9e432bc64e59ede76247c95a9782eac612cb2 Mon Sep 17 00:00:00 2001 From: veoco Date: Tue, 13 Dec 2022 14:18:02 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E8=B0=83=E6=95=B4=E6=96=87=E4=BB=B6?= =?UTF-8?q?=E5=88=A0=E9=99=A4=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 19 +++++++++++++------ storage.py | 17 +++++++---------- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/main.py b/main.py index be822f6ca..5ba552538 100644 --- a/main.py +++ b/main.py @@ -51,9 +51,13 @@ async def delete_expire_files(): async with AsyncSession(engine, expire_on_commit=False) as s: query = select(Codes).where(or_(Codes.exp_time < datetime.datetime.now(), Codes.count == 0)) exps = (await s.execute(query)).scalars().all() - files = [{'type': old.type, 'text': old.text} for old in exps] + files = [] + exps_ids = [] + for exp in exps: + if exp.type != "text": + files.append(exp.text) + exps_ids.append(exp.id) await storage.delete_files(files) - exps_ids = [exp.id for exp in exps] query = delete(Codes).where(Codes.id.in_(exps_ids)) await s.execute(query) await s.commit() @@ -83,9 +87,11 @@ async def admin_post(s: AsyncSession = Depends(get_session)): async def admin_delete(code: str, s: AsyncSession = Depends(get_session)): query = select(Codes).where(Codes.code == code) file = (await s.execute(query)).scalars().first() - await storage.delete_file({'type': file.type, 'text': file.text}) - await s.delete(file) - await s.commit() + if file: + if file.type != 'text': + await storage.delete_file(file.text) + await s.delete(file) + await s.commit() return {'detail': '删除成功'} @@ -115,7 +121,8 @@ async def index(code: str, ip: str = Depends(ip_limit), s: AsyncSession = Depend error_count = settings.ERROR_COUNT - ip_limit.add_ip(ip) raise HTTPException(status_code=404, detail=f"取件码错误,错误{error_count}次将被禁止10分钟") if info.exp_time < datetime.datetime.now() or info.count == 0: - await storage.delete_file({'type': info.type, 'text': info.text}) + if info.type != "text": + await storage.delete_file(info.text) await s.delete(info) await s.commit() raise HTTPException(status_code=404, detail="取件码已过期,请联系寄件人") diff --git a/storage.py b/storage.py index 58a90deff..b992a9a34 100644 --- a/storage.py +++ b/storage.py @@ -46,16 +46,13 @@ async def save_file(self, file: UploadFile, text: str): filepath = await self.get_filepath(text) await asyncio.to_thread(self._save, filepath, file.file) - async def delete_file(self, file): - # 是文件就删除 - if file['type'] != 'text': - filepath = self.DATA_ROOT / file['text'].lstrip(self.STATIC_URL + '/') - await asyncio.to_thread(os.remove, filepath) - - async def delete_files(self, files): - for file in files: - if file['type'] != 'text': - await self.delete_file(file) + async def delete_file(self, text: str): + filepath = await self.get_filepath(text) + await asyncio.to_thread(os.remove, filepath) + + async def delete_files(self, texts): + tasks = [self.delete_file(text) for text in texts] + await asyncio.gather(*tasks) STORAGE_ENGINE = {