Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 19 additions & 12 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import asyncio
from pathlib import Path

from fastapi import FastAPI, Depends, UploadFile, Form, File, HTTPException
from fastapi import FastAPI, Depends, UploadFile, Form, File, HTTPException, BackgroundTasks
from starlette.responses import HTMLResponse, FileResponse
from starlette.staticfiles import StaticFiles

Expand Down Expand Up @@ -51,9 +51,13 @@ async def delete_expire_files():
async with AsyncSession(engine, expire_on_commit=False) as s:
query = select(Codes).where(or_(Codes.exp_time < datetime.datetime.now(), Codes.count == 0))
exps = (await s.execute(query)).scalars().all()
files = [{'type': old.type, 'text': old.text} for old in exps]
files = []
exps_ids = []
for exp in exps:
if exp.type != "text":
files.append(exp.text)
exps_ids.append(exp.id)
await storage.delete_files(files)
exps_ids = [exp.id for exp in exps]
query = delete(Codes).where(Codes.id.in_(exps_ids))
await s.execute(query)
await s.commit()
Expand Down Expand Up @@ -83,9 +87,11 @@ async def admin_post(s: AsyncSession = Depends(get_session)):
async def admin_delete(code: str, s: AsyncSession = Depends(get_session)):
query = select(Codes).where(Codes.code == code)
file = (await s.execute(query)).scalars().first()
await storage.delete_file({'type': file.type, 'text': file.text})
await s.delete(file)
await s.commit()
if file:
if file.type != 'text':
await storage.delete_file(file.text)
await s.delete(file)
await s.commit()
return {'detail': '删除成功'}


Expand Down Expand Up @@ -115,7 +121,8 @@ async def index(code: str, ip: str = Depends(ip_limit), s: AsyncSession = Depend
error_count = settings.ERROR_COUNT - ip_limit.add_ip(ip)
raise HTTPException(status_code=404, detail=f"取件码错误,错误{error_count}次将被禁止10分钟")
if info.exp_time < datetime.datetime.now() or info.count == 0:
await storage.delete_file({'type': info.type, 'text': info.text})
if info.type != "text":
await storage.delete_file(info.text)
await s.delete(info)
await s.commit()
raise HTTPException(status_code=404, detail="取件码已过期,请联系寄件人")
Expand All @@ -130,8 +137,8 @@ async def index(code: str, ip: str = Depends(ip_limit), s: AsyncSession = Depend


@app.post('/share')
async def share(text: str = Form(default=None), style: str = Form(default='2'), value: int = Form(default=1),
file: UploadFile = File(default=None), s: AsyncSession = Depends(get_session)):
async def share(background_tasks: BackgroundTasks, text: str = Form(default=None), style: str = Form(default='2'),
value: int = Form(default=1), file: UploadFile = File(default=None), s: AsyncSession = Depends(get_session)):
code = await get_code(s)
if style == '2':
if value > 7:
Expand All @@ -148,11 +155,11 @@ async def share(text: str = Form(default=None), style: str = Form(default='2'),
exp_count = -1
key = uuid.uuid4().hex
if file:
file_bytes = await file.read()
size = len(file_bytes)
size = await storage.get_size(file)
if size > settings.FILE_SIZE_LIMIT:
raise HTTPException(status_code=400, detail="文件过大")
_text, _type, name = await storage.save_file(file, file_bytes, key), file.content_type, file.filename
_text, _type, name = await storage.get_text(file, key), file.content_type, file.filename
background_tasks.add_task(storage.save_file, file, _text)
else:
size, _text, _type, name = len(text), text, 'text', '文本分享'
info = Codes(
Expand Down
57 changes: 35 additions & 22 deletions storage.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import os
import asyncio
import datetime
from datetime import datetime
from pathlib import Path
from typing import BinaryIO

from fastapi import UploadFile

import settings

Expand All @@ -11,35 +14,45 @@ class FileSystemStorage:
STATIC_URL = settings.STATIC_URL
NAME = "filesystem"

async def get_filepath(self, path):
return self.DATA_ROOT / path.lstrip(self.STATIC_URL + '/')

def _save(self, filepath, file_bytes):
with open(filepath, 'wb') as f:
f.write(file_bytes)
async def get_filepath(self, text: str):
return self.DATA_ROOT / text.lstrip(self.STATIC_URL + '/')

async def save_file(self, file, file_bytes, key):
now = datetime.datetime.now()
path = self.DATA_ROOT / f"upload/{now.year}/{now.month}/{now.day}/"
async def get_text(self, file: UploadFile, key: str):
ext = file.filename.split('.')[-1]
name = f'{key}.{ext}'
now = datetime.now()
path = self.DATA_ROOT / f"upload/{now.year}/{now.month}/{now.day}/"
if not path.exists():
path.mkdir(parents=True)
filepath = path / name
await asyncio.to_thread(self._save, filepath, file_bytes)
filepath = path / f'{key}.{ext}'
text = f"{self.STATIC_URL}/{filepath.relative_to(self.DATA_ROOT)}"
return text

async def delete_file(self, file):
# 是文件就删除
if file['type'] != 'text':
filepath = self.DATA_ROOT / file['text'].lstrip(self.STATIC_URL + '/')
await asyncio.to_thread(os.remove, filepath)
async def get_size(self, file: UploadFile):
f = file.file
f.seek(0, os.SEEK_END)
size = f.tell()
f.seek(0, os.SEEK_SET)
return size

async def delete_files(self, files):
for file in files:
if file['type'] != 'text':
await self.delete_file(file)
def _save(self, filepath, file: BinaryIO):
with open(filepath, 'wb') as f:
chunk_size = 256 * 1024
chunk = file.read(chunk_size)
while chunk:
f.write(chunk)
chunk = file.read(chunk_size)

async def save_file(self, file: UploadFile, text: str):
filepath = await self.get_filepath(text)
await asyncio.to_thread(self._save, filepath, file.file)

async def delete_file(self, text: str):
filepath = await self.get_filepath(text)
await asyncio.to_thread(os.remove, filepath)

async def delete_files(self, texts):
tasks = [self.delete_file(text) for text in texts]
await asyncio.gather(*tasks)


STORAGE_ENGINE = {
Expand Down