diff --git a/tests/test_api.py b/tests/test_api.py index d3b90f93..4fc7777e 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -4,10 +4,11 @@ from fastapi.testclient import TestClient -from . import utils from tronbyt_server import db from tronbyt_server.models.app import App +from . import utils + def test_api(auth_client: TestClient, db_connection: sqlite3.Connection) -> None: # Create a device @@ -23,7 +24,7 @@ def test_api(auth_client: TestClient, db_connection: sqlite3.Connection) -> None follow_redirects=False, ) assert response.status_code == 302 - assert response.headers["location"] == "/" + assert response.headers["location"] == "http://testserver/" # Get user to find device_id user = db.get_user(db_connection, "testuser") @@ -91,7 +92,7 @@ def _setup_device_with_apps( follow_redirects=False, ) assert response.status_code == 302 - assert response.headers["location"] == "/" + assert response.headers["location"] == "http://testserver/" user = db.get_user(db_connection, "testuser") assert user diff --git a/tests/test_register+login.py b/tests/test_register+login.py index d996249c..05db842a 100644 --- a/tests/test_register+login.py +++ b/tests/test_register+login.py @@ -22,7 +22,7 @@ def test_register_login_logout(auth_client: TestClient) -> None: follow_redirects=False, ) assert response.status_code == 302 - assert response.headers["location"] == "/" + assert response.headers["location"] == "http://testserver/" response = auth_client.get("/auth/logout", follow_redirects=False) assert response.status_code == 302 diff --git a/tests/test_webp_upload.py b/tests/test_webp_upload.py index e61f4c84..f064ec34 100644 --- a/tests/test_webp_upload.py +++ b/tests/test_webp_upload.py @@ -1,5 +1,5 @@ -from io import BytesIO import shutil +from io import BytesIO from pathlib import Path from fastapi.testclient import TestClient @@ -52,7 +52,7 @@ def test_webp_upload_and_app_creation(auth_client: TestClient) -> None: follow_redirects=False, ) assert response.status_code == 302 - assert response.headers["location"] == "/" + assert response.headers["location"] == "http://testserver/" # 4. Check that the app is added and file is copied with db.get_db() as db_conn: diff --git a/tronbyt_server/routers/auth.py b/tronbyt_server/routers/auth.py index 2e3bf0d5..8492046c 100644 --- a/tronbyt_server/routers/auth.py +++ b/tronbyt_server/routers/auth.py @@ -8,10 +8,10 @@ from typing import Annotated from fastapi import APIRouter, Depends, Form, Request, status -from fastapi.responses import Response, RedirectResponse, JSONResponse +from fastapi.responses import JSONResponse, RedirectResponse, Response +from fastapi_babel import _ from pydantic import BaseModel from werkzeug.security import generate_password_hash -from fastapi_babel import _ import tronbyt_server.db as db from tronbyt_server import system_apps, version @@ -23,7 +23,7 @@ manager, ) from tronbyt_server.flash import flash -from tronbyt_server.models import User, ThemePreference +from tronbyt_server.models import ThemePreference, User from tronbyt_server.templates import templates router = APIRouter(prefix="/auth", tags=["auth"]) @@ -76,7 +76,6 @@ def post_register_owner( request: Request, password: str = Form(...), db_conn: sqlite3.Connection = Depends(get_db), - settings: Settings = Depends(get_settings), ) -> Response: """Handle owner registration.""" if db.has_users(db_conn): @@ -118,7 +117,7 @@ def get_register( """Render the user registration page.""" if not db.has_users(db_conn): return RedirectResponse( - url="/auth/register_owner", status_code=status.HTTP_302_FOUND + url=request.url_for("get_register_owner"), status_code=status.HTTP_302_FOUND ) if settings.ENABLE_USER_REGISTRATION != "1": if not user or user.username != "admin": @@ -198,7 +197,8 @@ def post_register( ), ) return RedirectResponse( - url="/auth/register", status_code=status.HTTP_302_FOUND + url=request.url_for("get_register"), + status_code=status.HTTP_302_FOUND, ) else: flash( @@ -276,7 +276,9 @@ def post_login( ) user = user_data - response = RedirectResponse(url="/", status_code=status.HTTP_302_FOUND) + response = RedirectResponse( + url=request.url_for("index"), status_code=status.HTTP_302_FOUND + ) # Set token expiration token_expires = timedelta(days=30) if form_data.remember else timedelta(minutes=60) @@ -327,7 +329,9 @@ def post_edit( authed_user.password = generate_password_hash(password) db.save_user(db_conn, authed_user) flash(request, _("Success")) - return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND) + return RedirectResponse( + url=request.url_for("index"), status_code=status.HTTP_302_FOUND + ) @router.get("/logout") @@ -389,4 +393,6 @@ def generate_api_key( flash(request, _("New API key generated successfully.")) else: flash(request, _("Failed to generate new API key.")) - return RedirectResponse(url="/auth/edit", status_code=status.HTTP_302_FOUND) + return RedirectResponse( + url=request.url_for("get_edit"), status_code=status.HTTP_302_FOUND + ) diff --git a/tronbyt_server/routers/manager.py b/tronbyt_server/routers/manager.py index d7fa9bb7..bf326c84 100644 --- a/tronbyt_server/routers/manager.py +++ b/tronbyt_server/routers/manager.py @@ -8,7 +8,7 @@ import string import time import uuid -from datetime import date, timedelta, datetime, timezone +from datetime import date, datetime, timedelta, timezone from pathlib import Path from random import randint from typing import Annotated, Any, cast @@ -36,26 +36,26 @@ from tronbyt_server.config import Settings, get_settings from tronbyt_server.dependencies import ( DeviceAndApp, + UserAndDevice, get_db, get_device_and_app, get_user_and_device, manager, - UserAndDevice, ) from tronbyt_server.flash import flash from tronbyt_server.models import ( DEFAULT_DEVICE_TYPE, App, + Brightness, Device, DeviceID, DeviceType, Location, + ProtocolType, RecurrencePattern, RecurrenceType, User, Weekday, - ProtocolType, - Brightness, parse_custom_brightness_scale, ) from tronbyt_server.pixlet import call_handler_with_config, get_schema @@ -539,7 +539,9 @@ def create_post( if db.save_user(db_conn, user) and not db.get_device_webp_dir(device.id).is_dir(): db.get_device_webp_dir(device.id).mkdir(parents=True) - return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND) + return RedirectResponse( + url=request.url_for("index"), status_code=status.HTTP_302_FOUND + ) @router.get("/{device_id}/update") @@ -549,7 +551,9 @@ def update( ) -> Response: device = deps.device if not device: - return RedirectResponse(url="/", status_code=status.HTTP_404_NOT_FOUND) + return RedirectResponse( + url=request.url_for("index"), status_code=status.HTTP_404_NOT_FOUND + ) default_img_url = request.url_for("next_app", device_id=device.id) default_ws_url = str(request.url_for("websocket_endpoint", device_id=device.id)) @@ -697,7 +701,9 @@ def update_post( device = user.devices.get(device_id) if not device: - return RedirectResponse(url="/", status_code=status.HTTP_404_NOT_FOUND) + return RedirectResponse( + url=request.url_for("index"), status_code=status.HTTP_404_NOT_FOUND + ) error = None if not form_data.name: @@ -808,23 +814,30 @@ def update_post( db.save_user(db_conn, user) - return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND) + return RedirectResponse( + url=request.url_for("index"), status_code=status.HTTP_302_FOUND + ) @router.post("/{device_id}/delete") def delete( + request: Request, device_id: DeviceID, user: User = Depends(manager), db_conn: sqlite3.Connection = Depends(get_db), ) -> Response: device = user.devices.get(device_id) if not device: - return RedirectResponse(url="/", status_code=status.HTTP_404_NOT_FOUND) + return RedirectResponse( + url=request.url_for("index"), status_code=status.HTTP_404_NOT_FOUND + ) user.devices.pop(device_id) db.save_user(db_conn, user) db.delete_device_dirs(device_id) - return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND) + return RedirectResponse( + url=request.url_for("index"), status_code=status.HTTP_302_FOUND + ) @router.get("/{device_id}/addapp") @@ -836,7 +849,9 @@ def addapp( ) -> Response: device = user.devices.get(device_id) if not device: - return RedirectResponse(url="/", status_code=status.HTTP_404_NOT_FOUND) + return RedirectResponse( + url=request.url_for("index"), status_code=status.HTTP_404_NOT_FOUND + ) custom_apps_list = db.get_apps_list(user.username) apps_list = db.get_apps_list("system") @@ -885,7 +900,9 @@ def addapp_post( ) -> Response: device = user.devices.get(device_id) if not device: - return RedirectResponse(url="/", status_code=status.HTTP_404_NOT_FOUND) + return RedirectResponse( + url=request.url_for("index"), status_code=status.HTTP_404_NOT_FOUND + ) if not name: flash(request, _("App name required.")) @@ -958,7 +975,9 @@ def addapp_post( source_path = Path(app.path) if source_path.exists(): shutil.copy(source_path, dest_path) - return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND) + return RedirectResponse( + url=request.url_for("index"), status_code=status.HTTP_302_FOUND + ) return RedirectResponse( url=f"{request.url_for('configapp', device_id=device_id, iname=iname)}?delete_on_cancel=true", @@ -1098,6 +1117,7 @@ def deleteupload( @router.post("/{device_id}/{iname}/delete") def deleteapp( + request: Request, device_and_app: DeviceAndApp = Depends(get_device_and_app), user: User = Depends(manager), db_conn: sqlite3.Connection = Depends(get_db), @@ -1116,7 +1136,9 @@ def deleteapp( device.apps.pop(iname) db.save_user(db_conn, user) - return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND) + return RedirectResponse( + url=request.url_for("index"), status_code=status.HTTP_302_FOUND + ) @router.post("/{device_id}/{iname}/toggle_pin") @@ -1145,7 +1167,9 @@ def toggle_pin( logger.error(f"Failed to toggle pin for device {device.id}: {e}") flash(request, _("Error updating pin status."), "error") - return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND) + return RedirectResponse( + url=request.url_for("index"), status_code=status.HTTP_302_FOUND + ) @router.post("/{device_id}/{iname}/duplicate") @@ -1237,7 +1261,9 @@ def duplicate_app( return Response("OK", status_code=status.HTTP_200_OK) else: flash(request, _("App duplicated successfully.")) - return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND) + return RedirectResponse( + url=request.url_for("index"), status_code=status.HTTP_302_FOUND + ) @router.get("/{device_id}/{iname}/updateapp") @@ -1367,7 +1393,9 @@ def updateapp_post( db.save_user(db_conn, user) - return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND) + return RedirectResponse( + url=request.url_for("index"), status_code=status.HTTP_302_FOUND + ) @router.post("/{device_id}/{iname}/toggle_enabled") @@ -1396,7 +1424,9 @@ def toggle_enabled( logger.error(f"Failed to toggle enabled for app {app.iname}: {e}") flash(request, _("Error saving changes."), "error") - return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND) + return RedirectResponse( + url=request.url_for("index"), status_code=status.HTTP_302_FOUND + ) @router.post("/{device_id}/{iname}/moveapp") @@ -1409,7 +1439,9 @@ def moveapp( ) -> Response: if direction not in ["up", "down", "top", "bottom"]: flash(request, _("Invalid direction.")) - return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND) + return RedirectResponse( + url=request.url_for("index"), status_code=status.HTTP_302_FOUND + ) device = device_and_app.device iname = device_and_app.app.iname @@ -1423,11 +1455,15 @@ def moveapp( if current_idx == -1: flash(request, _("App not found.")) - return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND) + return RedirectResponse( + url=request.url_for("index"), status_code=status.HTTP_302_FOUND + ) if direction == "up": if current_idx == 0: - return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND) + return RedirectResponse( + url=request.url_for("index"), status_code=status.HTTP_302_FOUND + ) target_idx = current_idx - 1 apps_list[current_idx], apps_list[target_idx] = ( apps_list[target_idx], @@ -1435,7 +1471,9 @@ def moveapp( ) elif direction == "down": if current_idx == len(apps_list) - 1: - return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND) + return RedirectResponse( + url=request.url_for("index"), status_code=status.HTTP_302_FOUND + ) target_idx = current_idx + 1 apps_list[current_idx], apps_list[target_idx] = ( apps_list[target_idx], @@ -1443,13 +1481,17 @@ def moveapp( ) elif direction == "top": if current_idx == 0: - return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND) + return RedirectResponse( + url=request.url_for("index"), status_code=status.HTTP_302_FOUND + ) # Move app to the top (index 0) app_to_move = apps_list.pop(current_idx) apps_list.insert(0, app_to_move) elif direction == "bottom": if current_idx == len(apps_list) - 1: - return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND) + return RedirectResponse( + url=request.url_for("index"), status_code=status.HTTP_302_FOUND + ) # Move app to the bottom (last index) app_to_move = apps_list.pop(current_idx) apps_list.append(app_to_move) @@ -1505,7 +1547,9 @@ async def configapp_post( app = device_and_app.app device_id = device.id if not app or not app.path: - return RedirectResponse(url="/", status_code=status.HTTP_404_NOT_FOUND) + return RedirectResponse( + url=request.url_for("index"), status_code=status.HTTP_404_NOT_FOUND + ) app.config = config db.save_user(db_conn, user) @@ -1535,7 +1579,9 @@ async def configapp_post( else: flash(request, _("Error Rendering App")) - return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND) + return RedirectResponse( + url=request.url_for("index"), status_code=status.HTTP_302_FOUND + ) @router.get("/{device_id}/{iname}/preview") @@ -1663,6 +1709,7 @@ def adminindex( @router.post("/admin/{username}/deleteuser") def deleteuser( + request: Request, username: str, user: User = Depends(manager), db_conn: sqlite3.Connection = Depends(get_db), @@ -1672,7 +1719,9 @@ def deleteuser( raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden") if username != "admin": db.delete_user(db_conn, username) - return RedirectResponse(url="/adminindex", status_code=status.HTTP_302_FOUND) + return RedirectResponse( + url=request.url_for("adminindex"), status_code=status.HTTP_302_FOUND + ) @router.get("/{device_id}/firmware") @@ -1760,8 +1809,12 @@ def set_user_repo( if set_repo(request, apps_path, user.app_repo_url, app_repo_url): user.app_repo_url = app_repo_url db.save_user(db_conn, user) - return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND) - return RedirectResponse(url="/auth/edit", status_code=status.HTTP_302_FOUND) + return RedirectResponse( + url=request.url_for("index"), status_code=status.HTTP_302_FOUND + ) + return RedirectResponse( + url=request.url_for("edit"), status_code=status.HTTP_302_FOUND + ) @router.post("/set_api_key") @@ -1774,10 +1827,14 @@ def set_api_key( """Set the user's API key.""" if not api_key: flash(request, _("API Key cannot be empty.")) - return RedirectResponse(url="/auth/edit", status_code=status.HTTP_302_FOUND) + return RedirectResponse( + url=request.url_for("edit"), status_code=status.HTTP_302_FOUND + ) user.api_key = api_key db.save_user(db_conn, user) - return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND) + return RedirectResponse( + url=request.url_for("index"), status_code=status.HTTP_302_FOUND + ) @router.post("/set_system_repo") @@ -1803,8 +1860,12 @@ def set_system_repo( user.system_repo_url = app_repo_url db.save_user(db_conn, user) system_apps.generate_apps_json(db.get_data_dir()) - return RedirectResponse(url="/auth/edit", status_code=status.HTTP_302_FOUND) - return RedirectResponse(url="/auth/edit", status_code=status.HTTP_302_FOUND) + return RedirectResponse( + url=request.url_for("edit"), status_code=status.HTTP_302_FOUND + ) + return RedirectResponse( + url=request.url_for("edit"), status_code=status.HTTP_302_FOUND + ) @router.post("/refresh_system_repo") @@ -1818,7 +1879,9 @@ def refresh_system_repo( # Directly update the system repo - it handles git pull internally system_apps.update_system_repo(db.get_data_dir()) flash(request, _("System repo updated successfully")) - return RedirectResponse(url="/auth/edit", status_code=status.HTTP_302_FOUND) + return RedirectResponse( + url=request.url_for("edit"), status_code=status.HTTP_302_FOUND + ) @router.post("/mark_app_broken/{app_name}") @@ -1982,7 +2045,7 @@ def update_firmware(request: Request, user: User = Depends(manager)) -> Response logger.error(f"Error updating firmware: {e}") flash(request, _("❌ Firmware update failed: {str(e)}"), "error") return RedirectResponse( - url="/auth/edit#firmware-management", + url=str(request.url_for("edit")) + "#firmware-management", status_code=status.HTTP_302_FOUND, ) @@ -1995,8 +2058,12 @@ def refresh_user_repo( """Refresh the user's custom app repository.""" apps_path = db.get_users_dir() / user.username / "apps" if set_repo(request, apps_path, user.app_repo_url, user.app_repo_url): - return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND) - return RedirectResponse(url="/auth/edit", status_code=status.HTTP_302_FOUND) + return RedirectResponse( + url=request.url_for("index"), status_code=status.HTTP_302_FOUND + ) + return RedirectResponse( + url=request.url_for("edit"), status_code=status.HTTP_302_FOUND + ) @router.get("/export_user_config") @@ -2090,7 +2157,9 @@ async def import_device_config_post( ) if device_config["id"] != device_id: flash(request, _("Not the same device id. Import skipped.")) - return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND) + return RedirectResponse( + url=request.url_for("index"), status_code=status.HTTP_302_FOUND + ) # Regenerate URLs with current server root try: @@ -2108,7 +2177,9 @@ async def import_device_config_post( user.devices[device.id] = device db.save_user(db_conn, user) flash(request, _("Device configuration imported successfully")) - return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND) + return RedirectResponse( + url=request.url_for("index"), status_code=status.HTTP_302_FOUND + ) except json.JSONDecodeError as e: flash(request, _("Error parsing JSON file: {error}").format(error=e)) return RedirectResponse( @@ -2133,17 +2204,23 @@ async def import_user_config( """Handle import of user configuration.""" if not file.filename: flash(request, _("No selected file")) - return RedirectResponse(url="/auth/edit", status_code=status.HTTP_302_FOUND) + return RedirectResponse( + url=request.url_for("edit"), status_code=status.HTTP_302_FOUND + ) if not file.filename.endswith(".json"): flash(request, _("Invalid file type. Please upload a JSON file.")) - return RedirectResponse(url="/auth/edit", status_code=status.HTTP_302_FOUND) + return RedirectResponse( + url=request.url_for("edit"), status_code=status.HTTP_302_FOUND + ) try: contents = await file.read() user_config_raw = json.loads(contents) if not isinstance(user_config_raw, dict): flash(request, _("Invalid JSON structure")) - return RedirectResponse(url="/auth/edit", status_code=status.HTTP_302_FOUND) + return RedirectResponse( + url=request.url_for("edit"), status_code=status.HTTP_302_FOUND + ) user_config = cast(dict[str, Any], user_config_raw) # Replace all user data except username and password @@ -2178,15 +2255,21 @@ async def import_user_config( db.save_user(db_conn, new_user) flash(request, _("User configuration imported successfully")) - return RedirectResponse(url="/auth/edit", status_code=status.HTTP_302_FOUND) + return RedirectResponse( + url=request.url_for("edit"), status_code=status.HTTP_302_FOUND + ) except json.JSONDecodeError as e: logging.error(f"JSON decode error during user config import: {e}") flash(request, _("Error parsing JSON file: {error}").format(error=e)) - return RedirectResponse(url="/auth/edit", status_code=status.HTTP_302_FOUND) + return RedirectResponse( + url=request.url_for("edit"), status_code=status.HTTP_302_FOUND + ) except Exception as e: logging.error(f"Error importing user config: {e}", exc_info=True) flash(request, _("Error importing config: {error}").format(error=str(e))) - return RedirectResponse(url="/auth/edit", status_code=status.HTTP_302_FOUND) + return RedirectResponse( + url=request.url_for("edit"), status_code=status.HTTP_302_FOUND + ) @router.get("/import_device", name="import_device") @@ -2207,10 +2290,14 @@ async def import_device_post( """Handle import of a new device.""" if not file.filename: flash(request, _("No selected file")) - return RedirectResponse(url="/import_device", status_code=status.HTTP_302_FOUND) + return RedirectResponse( + url=request.url_for("import_device"), status_code=status.HTTP_302_FOUND + ) if not file.filename.endswith(".json"): flash(request, _("Invalid file type. Please upload a JSON file.")) - return RedirectResponse(url="/import_device", status_code=status.HTTP_302_FOUND) + return RedirectResponse( + url=request.url_for("import_device"), status_code=status.HTTP_302_FOUND + ) try: contents = await file.read() @@ -2218,7 +2305,7 @@ async def import_device_post( if not isinstance(device_config_raw, dict): flash(request, _("Invalid JSON structure")) return RedirectResponse( - url="/import_device", status_code=status.HTTP_302_FOUND + url=request.url_for("import_device"), status_code=status.HTTP_302_FOUND ) device_config = cast(dict[str, Any], device_config_raw) @@ -2226,11 +2313,13 @@ async def import_device_post( if not device_id: flash(request, _("Device ID missing in config.")) return RedirectResponse( - url="/import_device", status_code=status.HTTP_302_FOUND + url=request.url_for("import_device"), status_code=status.HTTP_302_FOUND ) if device_id in user.devices or db.get_device_by_id(db_conn, device_id): flash(request, _("Device already exists. Import skipped.")) - return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND) + return RedirectResponse( + url=request.url_for("index"), status_code=status.HTTP_302_FOUND + ) # Regenerate URLs with current server root try: @@ -2241,17 +2330,21 @@ async def import_device_post( ) flash(request, _("Invalid device configuration file")) return RedirectResponse( - url="/import_device", status_code=status.HTTP_302_FOUND + url=request.url_for("import_device"), status_code=status.HTTP_302_FOUND ) device.img_url = str(request.url_for("next_app", device_id=device_id)) device.ws_url = str(request.url_for("websocket_endpoint", device_id=device_id)) user.devices[device.id] = device db.save_user(db_conn, user) flash(request, _("Device configuration imported successfully")) - return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND) + return RedirectResponse( + url=request.url_for("index"), status_code=status.HTTP_302_FOUND + ) except json.JSONDecodeError as e: flash(request, _("Error parsing JSON file: {error}").format(error=e)) - return RedirectResponse(url="/import_device", status_code=status.HTTP_302_FOUND) + return RedirectResponse( + url=request.url_for("import_device"), status_code=status.HTTP_302_FOUND + ) @router.get("/{device_id}/next", name="next_app")