Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@

from fastapi.testclient import TestClient

from . import utils
from tronbyt_server import db
from tronbyt_server.models.app import App

from . import utils


def test_api(auth_client: TestClient, db_connection: sqlite3.Connection) -> None:
# Create a device
Expand All @@ -23,7 +24,7 @@ def test_api(auth_client: TestClient, db_connection: sqlite3.Connection) -> None
follow_redirects=False,
)
assert response.status_code == 302
assert response.headers["location"] == "/"
assert response.headers["location"] == "http://testserver/"

# Get user to find device_id
user = db.get_user(db_connection, "testuser")
Expand Down Expand Up @@ -91,7 +92,7 @@ def _setup_device_with_apps(
follow_redirects=False,
)
assert response.status_code == 302
assert response.headers["location"] == "/"
assert response.headers["location"] == "http://testserver/"

user = db.get_user(db_connection, "testuser")
assert user
Expand Down
2 changes: 1 addition & 1 deletion tests/test_register+login.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_register_login_logout(auth_client: TestClient) -> None:
follow_redirects=False,
)
assert response.status_code == 302
assert response.headers["location"] == "/"
assert response.headers["location"] == "http://testserver/"

response = auth_client.get("/auth/logout", follow_redirects=False)
assert response.status_code == 302
Expand Down
4 changes: 2 additions & 2 deletions tests/test_webp_upload.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from io import BytesIO
import shutil
from io import BytesIO
from pathlib import Path

from fastapi.testclient import TestClient
Expand Down Expand Up @@ -52,7 +52,7 @@ def test_webp_upload_and_app_creation(auth_client: TestClient) -> None:
follow_redirects=False,
)
assert response.status_code == 302
assert response.headers["location"] == "/"
assert response.headers["location"] == "http://testserver/"

# 4. Check that the app is added and file is copied
with db.get_db() as db_conn:
Expand Down
24 changes: 15 additions & 9 deletions tronbyt_server/routers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
from typing import Annotated

from fastapi import APIRouter, Depends, Form, Request, status
from fastapi.responses import Response, RedirectResponse, JSONResponse
from fastapi.responses import JSONResponse, RedirectResponse, Response
from fastapi_babel import _
from pydantic import BaseModel
from werkzeug.security import generate_password_hash
from fastapi_babel import _

import tronbyt_server.db as db
from tronbyt_server import system_apps, version
Expand All @@ -23,7 +23,7 @@
manager,
)
from tronbyt_server.flash import flash
from tronbyt_server.models import User, ThemePreference
from tronbyt_server.models import ThemePreference, User
from tronbyt_server.templates import templates

router = APIRouter(prefix="/auth", tags=["auth"])
Expand Down Expand Up @@ -76,7 +76,6 @@ def post_register_owner(
request: Request,
password: str = Form(...),
db_conn: sqlite3.Connection = Depends(get_db),
settings: Settings = Depends(get_settings),
) -> Response:
"""Handle owner registration."""
if db.has_users(db_conn):
Expand Down Expand Up @@ -118,7 +117,7 @@ def get_register(
"""Render the user registration page."""
if not db.has_users(db_conn):
return RedirectResponse(
url="/auth/register_owner", status_code=status.HTTP_302_FOUND
url=request.url_for("get_register_owner"), status_code=status.HTTP_302_FOUND
)
if settings.ENABLE_USER_REGISTRATION != "1":
if not user or user.username != "admin":
Expand Down Expand Up @@ -198,7 +197,8 @@ def post_register(
),
)
return RedirectResponse(
url="/auth/register", status_code=status.HTTP_302_FOUND
url=request.url_for("get_register"),
status_code=status.HTTP_302_FOUND,
)
else:
flash(
Expand Down Expand Up @@ -276,7 +276,9 @@ def post_login(
)

user = user_data
response = RedirectResponse(url="/", status_code=status.HTTP_302_FOUND)
response = RedirectResponse(
url=request.url_for("index"), status_code=status.HTTP_302_FOUND
)

# Set token expiration
token_expires = timedelta(days=30) if form_data.remember else timedelta(minutes=60)
Expand Down Expand Up @@ -327,7 +329,9 @@ def post_edit(
authed_user.password = generate_password_hash(password)
db.save_user(db_conn, authed_user)
flash(request, _("Success"))
return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND)
return RedirectResponse(
url=request.url_for("index"), status_code=status.HTTP_302_FOUND
)


@router.get("/logout")
Expand Down Expand Up @@ -389,4 +393,6 @@ def generate_api_key(
flash(request, _("New API key generated successfully."))
else:
flash(request, _("Failed to generate new API key."))
return RedirectResponse(url="/auth/edit", status_code=status.HTTP_302_FOUND)
return RedirectResponse(
url=request.url_for("get_edit"), status_code=status.HTTP_302_FOUND
)
Loading