diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 6827646..fe2d845 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -13,9 +13,6 @@ jobs: uses: actions/checkout@v3 - name: Set up docker uses: docker-practice/actions-setup-docker@master - - name: Run postgres - run: | - docker run -d -p 5432:5432 -e POSTGRES_HOST_AUTH_METHOD=trust --name db-test postgres:15-alpine - uses: actions/setup-python@v4 with: python-version: '3.11' @@ -24,12 +21,18 @@ jobs: python -m ensurepip python -m pip install --upgrade pip pip install -r requirements.txt -r requirements.dev.txt + - name: Run postgres + run: | + make db - name: Migrate DB run: | - DB_DSN=postgresql://postgres@localhost:5432/postgres alembic upgrade head + make migrate + - name: Run redis + run: | + make redis - name: Build coverage file run: | - SECRET_KEY='fg' DB_DSN=postgresql://postgres@localhost:5432/postgres pytest --cache-clear --cov=print_service tests > pytest-coverage.txt + pytest --cache-clear --cov=print_service tests > pytest-coverage.txt - name: Print report if: always() run: | diff --git a/Makefile b/Makefile index af249bb..16cea6c 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ run: - source ./venv/bin/activate && uvicorn --reload --log-level debug print_service.fastapi:app + source ./venv/bin/activate && uvicorn --reload --log-level debug print_service.routes:app db: docker run -d -p 5432:5432 -e POSTGRES_HOST_AUTH_METHOD=trust --name db-print_service postgres:15 diff --git a/print_service/routes/base.py b/print_service/routes/base.py index 243a5b7..08adbff 100644 --- a/print_service/routes/base.py +++ b/print_service/routes/base.py @@ -8,6 +8,7 @@ from print_service import __version__ from print_service.routes.file import router as file_router from print_service.routes.user import router as user_router +from print_service.routes.qrprint import router as qrprint_router from print_service.settings import Settings, get_settings @@ -34,4 +35,5 @@ app.include_router(user_router, prefix='', tags=['User']) app.include_router(file_router, prefix='/file', tags=['File']) +app.include_router(qrprint_router, prefix='/qr', tags=['File']) app.mount('/static', StaticFiles(directory='static'), 'static') diff --git a/print_service/routes/file.py b/print_service/routes/file.py index 11fad48..b9c61da 100644 --- a/print_service/routes/file.py +++ b/print_service/routes/file.py @@ -15,7 +15,7 @@ from print_service.models import UnionMember from print_service.schema import BaseModel from print_service.settings import Settings, get_settings -from print_service.utils import generate_filename, generate_pin +from print_service.utils import generate_filename, generate_pin, get_file logger = logging.getLogger(__name__) @@ -240,27 +240,7 @@ async def print_file(pin: str, settings: Settings = Depends(get_settings)): бесконечное количество раз в течение 7 дней после загрузки (меняется в настройках сервера). """ - file_model = ( - db.session.query(FileModel) - .filter(func.upper(FileModel.pin) == pin.upper()) - .order_by(FileModel.created_at.desc()) - .one_or_none() - ) - if not file_model: - raise HTTPException(404, f'Pin {pin} not found') - - path = abspath(settings.STATIC_FOLDER) + '/' + file_model.file - if not exists(path): - raise HTTPException(415, 'File has not uploaded yet') - - return { - 'filename': file_model.file, - 'options': { - 'pages': file_model.option_pages or '', - 'copies': file_model.option_copies or 1, - 'two_sided': file_model.option_two_sided or False, - }, - } + return get_file(db.session, pin)[0] # endregion diff --git a/print_service/routes/qrprint.py b/print_service/routes/qrprint.py new file mode 100644 index 0000000..b64e3eb --- /dev/null +++ b/print_service/routes/qrprint.py @@ -0,0 +1,106 @@ +import json +import logging +import random +from asyncio import sleep +from datetime import datetime, timedelta + +from fastapi import APIRouter, Header, WebSocket, HTTPException +from fastapi_sqlalchemy import db +from pydantic import conlist +from redis import Redis + +from print_service.schema import BaseModel +from print_service.settings import Settings, get_settings +from print_service.utils import get_file + + +logger = logging.getLogger(__name__) +settings: Settings = get_settings() +router = APIRouter() + + +class InstantPrintCreate(BaseModel): + qr_token: str + files: conlist(str, min_items=1, max_items=10, unique_items=True) + + +class InstantPrintSender: + def __init__(self, settings: Settings = None) -> None: + settings = settings or get_settings() + self.redis: Redis = Redis.from_url(settings.REDIS_DSN) + + def send(self, qr_token: str, files: list[str]): + terminal = self.redis.get(qr_token) + if not terminal: + return None + self.redis.delete(qr_token) + old = self.redis.get(terminal) + if old: + return None + files = get_file(db.session, files) + self.redis.set(terminal, json.dumps({'files': files})) + return files + + +class InstantPrintFetcher: + def __init__(self, terminal_token: str, settings: Settings = None) -> None: + self.terminal_token = terminal_token + settings = settings or get_settings() + self.redis = Redis.from_url(settings.REDIS_DSN) + self.ttl = settings.QR_TOKEN_TTL + self.delay = settings.QR_TOKEN_DELAY + self.symbols = settings.QR_TOKEN_SYMBOLS + self.length = settings.QR_TOKEN_LENGTH + + def new_qr(self): + for _ in range(5): + qr_token = ''.join(random.choice(self.symbols) for _ in range(self.length)) + if not self.redis.get(qr_token): # If this qr already exists, generate new + break + self.redis.set(qr_token, self.terminal_token, ex=self.ttl+self.delay) # Send token to redis +ttl + return qr_token + + async def get_tasks(self) -> dict[str, list[str]]: + until = datetime.utcnow() + timedelta(seconds=self.ttl) + while datetime.utcnow() < until: + raw_value: bytes = self.redis.get(self.terminal_token) + if raw_value: + self.redis.delete(self.terminal_token) + break + await sleep(0.5) + else: + return {} + return json.loads(raw_value) + + def __aiter__(self): + return self + + async def __anext__(self): + value = await self.get_tasks() + qr_token = self.new_qr() + result = {"qr_token": qr_token, **value} + return result + + +redis_conn = InstantPrintSender() + + +@router.post("") +async def instant_print(options: InstantPrintCreate): + options.qr_token = options.qr_token.removeprefix(settings.QR_TOKEN_PREFIX) + if redis_conn.send(**options.dict()): + return {'status': 'ok'} + raise HTTPException(400, 'Terminal not found by qr') + + +@router.websocket("") +async def instant_print_terminal_connection( + websocket: WebSocket, + authorization: str = Header(), +): + await websocket.accept() + manager = InstantPrintFetcher(authorization.removeprefix("token ")) + await websocket.send_text(json.dumps({"qr_token": settings.QR_TOKEN_PREFIX + manager.new_qr()})) + async for task in manager: + task['qr_token'] = settings.QR_TOKEN_PREFIX + task['qr_token'] + await websocket.send_text(json.dumps(task)) diff --git a/print_service/settings.py b/print_service/settings.py index 6147d5e..8347b42 100644 --- a/print_service/settings.py +++ b/print_service/settings.py @@ -2,11 +2,12 @@ from functools import lru_cache from typing import List, Optional -from pydantic import BaseSettings, DirectoryPath, HttpUrl, PostgresDsn +from pydantic import BaseSettings, DirectoryPath, PostgresDsn, RedisDsn class Settings(BaseSettings): DB_DSN: PostgresDsn = 'postgresql://postgres@localhost:5432/postgres' + REDIS_DSN: RedisDsn = 'redis://localhost:6379/0' SECRET_KEY: Optional[str] = '42' @@ -25,6 +26,12 @@ class Settings(BaseSettings): CORS_ALLOW_METHODS: list[str] = ['*'] CORS_ALLOW_HEADERS: list[str] = ['*'] + QR_TOKEN_PREFIX: str = "" + QR_TOKEN_SYMBOLS: str = string.ascii_uppercase + string.digits + QR_TOKEN_LENGTH: int = 6 + QR_TOKEN_TTL: int = 30 # Show time of QR code in seconds + QR_TOKEN_DELAY: int = 5 # How long QR code valid after hide in seconds + class Config: env_file = '.env' diff --git a/print_service/utils/__init__.py b/print_service/utils/__init__.py index 737d60b..2408a09 100644 --- a/print_service/utils/__init__.py +++ b/print_service/utils/__init__.py @@ -1,10 +1,16 @@ import random import re from datetime import date, datetime, timedelta +from os.path import abspath, exists +from fastapi import File +from fastapi.exceptions import HTTPException +from sqlalchemy import func from sqlalchemy.orm.session import Session +from print_service import __version__ from print_service.models import File +from print_service.models import File as FileModel from print_service.settings import Settings, get_settings @@ -33,3 +39,31 @@ def generate_filename(original_filename: str): salt = ''.join(random.choice(settings.PIN_SYMBOLS) for i in range(128)) ext = re.findall(r'\w+', original_filename.split('.')[-1])[0] return f'{datestr}-{salt}.{ext}' + + +def get_file(dbsession, pin: str or list[str]): + pin = [pin.upper()] if isinstance(pin, str) else tuple(p.upper() for p in pin) + files: list[FileModel] = ( + dbsession.query(FileModel) + .filter(func.upper(FileModel.pin).in_(pin)) + .order_by(FileModel.created_at.desc()) + .all() + ) + if len(pin) != len(files): + raise HTTPException(404, f'{len(pin) - len(files)} file(s) not found') + + result = [] + for f in files: + path = abspath(settings.STATIC_FOLDER) + '/' + f.file + if not exists(path): + raise HTTPException(415, 'File has not uploaded yet') + + result.append({ + 'filename': f.file, + 'options': { + 'pages': f.option_pages or '', + 'copies': f.option_copies or 1, + 'two_sided': f.option_two_sided or False, + }, + }) + return result diff --git a/requirements.txt b/requirements.txt index 3da35c9..022bab1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,5 @@ fastapi fastapi-sqlalchemy python-multipart gunicorn +websockets +redis diff --git a/test_client.py b/test_client.py new file mode 100644 index 0000000..d502035 --- /dev/null +++ b/test_client.py @@ -0,0 +1,9 @@ +import asyncio +import websockets + +async def hello(): + async with websockets.connect("ws://localhost:8000/qr", extra_headers={"Authorization": 'token ADAQ-123456789'}) as websocket: + async for message in websocket: + print(message) + +asyncio.run(hello()) diff --git a/tests/test_routes/conftest.py b/tests/test_routes/conftest.py index 67ab459..4d032da 100644 --- a/tests/test_routes/conftest.py +++ b/tests/test_routes/conftest.py @@ -19,6 +19,7 @@ def union_member_user(dbsession): db_user = dbsession.query(UnionMember).filter(UnionMember.id == union_member['id']).one_or_none() assert db_user is not None dbsession.query(UnionMember).filter(UnionMember.id == union_member['id']).delete() + dbsession.flush() @pytest.fixture(scope='function') @@ -33,6 +34,7 @@ def uploaded_file_db(dbsession, union_member_user, client): db_file = dbsession.query(File).filter(File.pin == res.json()['pin']).one_or_none() yield db_file dbsession.query(File).filter(File.pin == res.json()['pin']).delete() + dbsession.flush() @pytest.fixture diff --git a/tests/test_routes/test_file.py b/tests/test_routes/test_file.py index 4b4ba39..7e46c72 100644 --- a/tests/test_routes/test_file.py +++ b/tests/test_routes/test_file.py @@ -1,8 +1,13 @@ -from print_service.settings import get_settings -from print_service.models import File -from starlette import status import json +import pytest +from fastapi import HTTPException +from starlette import status + +from print_service.models import File +from print_service.settings import get_settings +from print_service.utils import get_file + url = '/file' settings = get_settings() @@ -47,3 +52,28 @@ def test_get_file_mock_path(uploaded_file_os, client): def test_get_file_wrong_pin(uploaded_file_os, client): res = client.get(f"{url}/{uploaded_file_os.pin}test404") assert res.status_code == status.HTTP_404_NOT_FOUND + + +def test_get_file_func_1_not_exists(dbsession): + with pytest.raises(HTTPException): + get_file(dbsession, ['1']) + + +def test_get_file_func_1_not_uploaded(dbsession, uploaded_file_db): + with pytest.raises(HTTPException): + data = get_file(dbsession, [uploaded_file_db.pin]) + +def test_get_file_func_1_ok(dbsession, uploaded_file_os): + data = get_file(dbsession, [uploaded_file_os.pin]) + assert len(data) == 1 + assert data[0] == { + 'filename': uploaded_file_os.file, + 'options': { + 'pages': uploaded_file_os.option_pages or '', + 'copies': uploaded_file_os.option_copies or 1, + 'two_sided': uploaded_file_os.option_two_sided or False, + }, + } +def test_get_file_func_2_not_exists(dbsession, uploaded_file_os): + with pytest.raises(HTTPException): + data = get_file(dbsession, [uploaded_file_os.pin, '1']) diff --git a/tests/test_routes/test_qr.py b/tests/test_routes/test_qr.py new file mode 100644 index 0000000..9869c02 --- /dev/null +++ b/tests/test_routes/test_qr.py @@ -0,0 +1,35 @@ +import pytest +from fastapi.testclient import TestClient +from starlette.websockets import WebSocketDisconnect +from print_service.settings import get_settings + +settings = get_settings() +settings.STATIC_FOLDER = './static' + + +@pytest.mark.skip() +def test_ws_connect_ok(client: TestClient, uploaded_file_os): + with client.websocket_connect('/qr', headers={"authorization": "token 123"}) as ws: + data = ws.receive_json() + assert set(data.keys()) == set(['qr_token']) + t = data['qr_token'] + result = client.post('/qr', json={"qr_token": t, "files": [uploaded_file_os.pin]}) + data = ws.receive_json() + assert set(data.keys()) == set(['qr_token', 'files']) + assert data["qr_token"] != t + assert len(data["files"]) == 1 + assert data["files"][0] == { + 'filename': uploaded_file_os.file, + 'options': { + 'pages': uploaded_file_os.option_pages or '', + 'copies': uploaded_file_os.option_copies or 1, + 'two_sided': uploaded_file_os.option_two_sided or False, + }, + } + return + + +def test_ws_connect_notoken(client: TestClient): + with pytest.raises(WebSocketDisconnect): + with client.websocket_connect('/qr') as ws: + pass