Skip to content

Commit

Permalink
Use TTL on mongodb for expiration
Browse files Browse the repository at this point in the history
  • Loading branch information
Lxstr committed Jan 30, 2024
1 parent 86895b5 commit 9acee3c
Showing 1 changed file with 6 additions and 19 deletions.
25 changes: 6 additions & 19 deletions src/flask_session/sessions.py
Expand Up @@ -7,7 +7,7 @@
except ImportError:
import pickle

from datetime import datetime, timezone
from datetime import datetime

from flask.sessions import SessionInterface as FlaskSessionInterface
from flask.sessions import SessionMixin
Expand Down Expand Up @@ -477,31 +477,18 @@ def __init__(
self.client = client
self.store = client[db][collection]
self.use_deprecated_method = int(pymongo.version.split(".")[0]) < 4

# Create a TTL index on the expiration time, so that mongo can automatically delete expired sessions
self.store.create_index("expiration", expireAfterSeconds=0)

super().__init__(self.store, key_prefix, use_signer, permanent, sid_length)

def fetch_session(self, sid):
# Get the saved session (document) from the database
prefixed_session_id = self.key_prefix + sid
document = self.store.find_one({"id": prefixed_session_id})

# If the expiration time is less than or equal to the current time (expired), delete the document
if document is not None:
expiration_datetime = document.get("expiration")
# tz_aware mongodb fix
expiration_datetime_tz_aware = expiration_datetime.replace(
tzinfo=timezone.utc
)
now_datetime_tz_aware = datetime.utcnow().replace(tzinfo=timezone.utc)
if expiration_datetime is None or (
expiration_datetime_tz_aware <= now_datetime_tz_aware
):
if self.use_deprecated_method:
self.store.remove({"id": prefixed_session_id})
else:
self.store.delete_one({"id": prefixed_session_id})
document = None

# If the saved session still exists after checking for expiration, load the session data from the document
# If the saved session exists and has not auto-expired, load the session data from the document
if document is not None:
try:
session_data = self.serializer.loads(want_bytes(document["val"]))
Expand Down

0 comments on commit 9acee3c

Please sign in to comment.