Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SQL cleanup thread or per requests #211

Merged
merged 8 commits into from
Feb 15, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/flask_session/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,12 @@ def _get_interface(self, app):
SESSION_SQLALCHEMY_BIND_KEY = config.get(
"SESSION_SQLALCHEMY_BIND_KEY", Defaults.SESSION_SQLALCHEMY_BIND_KEY
)
SESSION_CLEANUP_N_REQUESTS = config.get(
"SESSION_CLEANUP_N_REQUESTS", Defaults.SESSION_CLEANUP_N_REQUESTS
)
SESSION_CLEANUP_N_SECONDS = config.get(
"SESSION_CLEANUP_N_SECONDS", Defaults.SESSION_CLEANUP_N_SECONDS
)

common_params = {
"app": app,
Expand Down Expand Up @@ -147,6 +153,8 @@ def _get_interface(self, app):
sequence=SESSION_SQLALCHEMY_SEQUENCE,
schema=SESSION_SQLALCHEMY_SCHEMA,
bind_key=SESSION_SQLALCHEMY_BIND_KEY,
cleanup_n_requests=SESSION_CLEANUP_N_REQUESTS,
cleanup_n_seconds=SESSION_CLEANUP_N_SECONDS,
)
else:
raise RuntimeError(f"Unrecognized value for SESSION_TYPE: {SESSION_TYPE}")
Expand Down
4 changes: 4 additions & 0 deletions src/flask_session/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ class Defaults:
SESSION_PERMANENT = True
SESSION_SID_LENGTH = 3

# Clean up settings for non TTL backends (SQL, PostgreSQL, etc.)
SESSION_CLEANUP_N_REQUESTS = None
SESSION_CLEANUP_N_SECONDS = None

# Redis settings
SESSION_REDIS = None

Expand Down
81 changes: 80 additions & 1 deletion src/flask_session/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
except ImportError:
import pickle

import random
from datetime import datetime
from datetime import timedelta as TimeDelta
from threading import Thread
from typing import Any, Optional

from flask import Flask, Request, Response
Expand Down Expand Up @@ -100,13 +102,25 @@ def __init__(
use_signer: bool = Defaults.SESSION_USE_SIGNER,
permanent: bool = Defaults.SESSION_PERMANENT,
sid_length: int = Defaults.SESSION_SID_LENGTH,
cleanup_n_requests: Optional[int] = Defaults.SESSION_CLEANUP_N_REQUESTS,
cleanup_n_seconds: Optional[int] = Defaults.SESSION_CLEANUP_N_SECONDS,
):
self.app = app
self.key_prefix = key_prefix
self.use_signer = use_signer
self.permanent = permanent
self.sid_length = sid_length
self.has_same_site_capability = hasattr(self, "get_cookie_samesite")
self.cleanup_n_requests = cleanup_n_requests
self.cleanup_n_seconds = cleanup_n_seconds

# Cleanup settings for non-TTL databases only
if self.ttl:
self._register_cleanup_app_command()
if self.cleanup_n_seconds:
self._start_cleanup_thread(self.cleanup_n_seconds)
if self.cleanup_n_requests:
self.app.before_request(self._cleanup_per_requests)

def save_session(
self, app: Flask, session: ServerSideSession, response: Response
Expand Down Expand Up @@ -177,6 +191,42 @@ def open_session(self, app: Flask, request: Request) -> ServerSideSession:
sid = self._generate_sid(self.sid_length)
return self.session_class(sid=sid, permanent=self.permanent)

# CLEANUP METHODS FOR NON TTL DATABASES

def _register_cleanup_app_command(self):
"""
Register a custom Flask CLI command for cleaning up expired sessions.

Run the command with `flask session_cleanup`. Run with a cron job
or scheduler such as Heroku Scheduler to automatically clean up expired sessions.
"""

@self.app.cli.command("session_cleanup")
def session_cleanup():
with self.app.app_context():
self._delete_expired_sessions()

def _cleanup_n_requests(self) -> None:
"""Delete expired sessions approximately every N requests."""
if self.cleanup_n_seconds or (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this should run if self.cleanup_n_seconds is truthy. If that is specified, the clean up will run multiple times:

  • before every request
  • in the cleanup threads (one per uWSGI worker)

self.cleanup_n_requests and random.randint(0, self.cleanup_n_requests) == 0
):
self._delete_expired_sessions()

def _start_cleanup_thread(self, cleanup_n_seconds: int) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As mentioned in the discord discussion, this will start one thread per uWSGI worker.

I would personally not implement this at all and use the cleanup per requests instead. It's not really essential for the clean up to run exactly every N seconds and I would lean towards having fewer connections to the database.

"""Start a background thread to delete expired sessions approximately every N seconds."""

def cleanup():
with self.app.app_context():
while True:
self._delete_expired_sessions()
time.sleep(cleanup_n_seconds)

thread = Thread(target=cleanup, daemon=True)
thread.start()

# METHODS TO BE IMPLEMENTED BY SUBCLASSES

def _retrieve_session_data(self, store_id: str) -> Optional[dict]:
raise NotImplementedError()

Expand All @@ -188,6 +238,10 @@ def _upsert_session(
) -> None:
raise NotImplementedError()

def _delete_expired_sessions(self) -> None:
"""Delete expired sessions from the backend storage. Only required for non-TTL databases."""
pass


class RedisSessionInterface(ServerSideSessionInterface):
"""Uses the Redis key-value store as a session backend. (`redis-py` required)
Expand All @@ -207,6 +261,7 @@ class RedisSessionInterface(ServerSideSessionInterface):

serializer = pickle
session_class = RedisSession
ttl = True

def __init__(
self,
Expand Down Expand Up @@ -273,6 +328,7 @@ class MemcachedSessionInterface(ServerSideSessionInterface):

serializer = pickle
session_class = MemcachedSession
ttl = True

def __init__(
self,
Expand Down Expand Up @@ -365,6 +421,8 @@ class FileSystemSessionInterface(ServerSideSessionInterface):
"""

session_class = FileSystemSession
serializer = None
ttl = True

def __init__(
self,
Expand Down Expand Up @@ -425,6 +483,7 @@ class MongoDBSessionInterface(ServerSideSessionInterface):

serializer = pickle
session_class = MongoDBSession
ttl = True

def __init__(
self,
Expand Down Expand Up @@ -515,6 +574,11 @@ class SqlAlchemySessionInterface(ServerSideSessionInterface):
:param sequence: The sequence to use for the primary key if needed.
:param schema: The db schema to use
:param bind_key: The db bind key to use
:param cleanup_n_requests: Delete expired sessions approximately every N requests.
:param cleanup_n_seconds: Delete expired sessions approximately every N seconds.

.. versionadded:: 0.7
The `cleanup_n_requests` and `cleanup_n_seconds` parameters were added.

.. versionadded:: 0.6
The `sid_length`, `sequence`, `schema` and `bind_key` parameters were added.
Expand All @@ -525,6 +589,7 @@ class SqlAlchemySessionInterface(ServerSideSessionInterface):

serializer = pickle
session_class = SqlAlchemySession
non_ttl = True

def __init__(
self,
Expand All @@ -538,12 +603,14 @@ def __init__(
sequence: Optional[str] = Defaults.SESSION_SQLALCHEMY_SEQUENCE,
schema: Optional[str] = Defaults.SESSION_SQLALCHEMY_SCHEMA,
bind_key: Optional[str] = Defaults.SESSION_SQLALCHEMY_BIND_KEY,
cleanup_n_requests: Optional[int] = Defaults.SESSION_CLEANUP_N_REQUESTS,
cleanup_n_seconds: Optional[int] = Defaults.SESSION_CLEANUP_N_SECONDS,
):
self.app = app
if db is None:
from flask_sqlalchemy import SQLAlchemy

db = SQLAlchemy(app)

self.db = db
self.sequence = sequence
self.schema = schema
Expand Down Expand Up @@ -586,6 +653,18 @@ def __repr__(self):

self.sql_session_model = Session

def _delete_expired_sessions(self) -> None:
try:
self.db.session.query(self.sql_session_model).filter(
self.sql_session_model.expiry <= datetime.utcnow()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both here and in the session creation I would use the database's now() function. This would prevent issues due to time synchronization between multiple servers.

).delete(synchronize_session=False)
self.db.session.commit()
self.app.logger.info("Deleted expired sessions")
except Exception as e:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you need a db.session.rollback() when handling the exception?

self.app.logger.exception(
e, "Failed to delete expired sessions. Skipping..."
)

def _retrieve_session_data(self, store_id: str) -> Optional[dict]:
# Get the saved session (record) from the database
record = self.sql_session_model.query.filter_by(session_id=store_id).first()
Expand Down