Skip to content

Commit

Permalink
Move email token caching to database (close #30)
Browse files Browse the repository at this point in the history
  • Loading branch information
stfwn committed Sep 6, 2022
1 parent 70a71cf commit 5a22c0c
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 61 deletions.
38 changes: 25 additions & 13 deletions metaserver/api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from datetime import datetime
import json
import secrets

from fastapi import (
BackgroundTasks,
Expand All @@ -17,7 +18,14 @@

import metaserver.database.api as db
from metaserver import auth, config, email
from metaserver.database.models import Clan, Skin, User, UserClanLink, Server
from metaserver.database.models import (
Clan,
EmailToken,
Skin,
User,
UserClanLink,
Server,
)
from metaserver.database.utils import UserClanLinkDeletedReason
from metaserver.schemas import (
ClanCreate,
Expand Down Expand Up @@ -136,8 +144,9 @@ def user_register(
)
try:
user = db.commit_and_refresh(session, user)
mail_token = email.generate_token(user.id)
email.send_verification_email(recipient=user.username, token=mail_token)
email_token = EmailToken(user=user)
db.commit_and_refresh(session, email_token)
email.send_verification_email(recipient=user.username, token=email_token.key)
except IntegrityError as e:
if "display_name" in e._message():
raise HTTPException(status.HTTP_409_CONFLICT, "Display name taken")
Expand Down Expand Up @@ -168,8 +177,9 @@ def user_email_verify(
):
if user.verified_email:
raise HTTPException(status.HTTP_403_FORBIDDEN, "User already verified mail")
try:
if email.verify_token(user.id, mail_token):

if token := user.email_token:
if secrets.compare_digest(mail_token, token.key):
user.verified_email = datetime.utcnow()
db.commit_and_refresh(session, user)
background_tasks.add_task(db.set_user_last_online_now, session, user)
Expand All @@ -180,7 +190,7 @@ def user_email_verify(
raise HTTPException(
status.HTTP_403_FORBIDDEN, "Incorrect mail verification token"
)
except KeyError:
else:
raise HTTPException(
status.HTTP_403_FORBIDDEN,
"No mail verification token found for user (request a new one)",
Expand All @@ -195,13 +205,15 @@ def user_email_new_token(
):
if user.verified_email:
raise HTTPException(status.HTTP_403_FORBIDDEN, "User already verified mail")
if email.get_token_age_for_user(user.id) <= config.email_token_renew_timeout:
raise HTTPException(
status.HTTP_403_FORBIDDEN,
f"Wait at least {config.email_token_renew_timeout} seconds before requesting a new token",
)
mail_token = email.generate_token(user.id)
email.send_verification_email(user.username, mail_token)
if email_token := user.email_token:
if datetime.utcnow() - email_token.created <= config.email_token_renew_timeout:
raise HTTPException(
status.HTTP_403_FORBIDDEN,
f"Wait at least {config.email_token_renew_timeout.seconds} seconds before requesting a new token",
)
email_token.key = EmailToken.new_key()
db.commit_and_refresh(session, email_token)
email.send_verification_email(user.username, email_token.key)
return "Token sent"


Expand Down
2 changes: 1 addition & 1 deletion metaserver/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
dev_mode = True if os.environ.get("DEV") else False

# How long people have to wait between receiving an email token and requesting a new one.
email_token_renew_timeout = 30
email_token_renew_timeout = timedelta(seconds=30)

# The granularity of this format determines how long a proof is valid
proof_datetime_component_format = "%Y-%m-%dT%H:%M"
21 changes: 21 additions & 0 deletions metaserver/database/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from datetime import datetime
import random
import string
from typing import Literal, Optional

from sqlmodel import VARCHAR, Column, Field, JSON, Relationship, SQLModel, create_engine
Expand Down Expand Up @@ -93,6 +95,9 @@ class User(SQLModel, table=True):
last_online: datetime | None

clan_links: list[UserClanLink] = Relationship(back_populates="user")
email_token: Optional["EmailToken"] = Relationship(
back_populates="user", sa_relationship_kwargs={"uselist": False}
)
skin_links: list[UserSkinLink] = Relationship(back_populates="user")
servers: list["Server"] = Relationship(back_populates="user")

Expand Down Expand Up @@ -142,6 +147,22 @@ class Server(SQLModel, table=True):
user: User = Relationship(back_populates="servers")


class EmailToken(SQLModel, table=True):
created: datetime = Field(default_factory=datetime.utcnow, nullable=False)
user_id: int = Field(default=None, primary_key=True, foreign_key="user.id")
key: str = Field(default_factory=lambda: EmailToken.new_key())

user: User = Relationship(
back_populates="email_token", sa_relationship_kwargs={"uselist": False}
)

@staticmethod
def new_key():
key_length = 6
chars = string.ascii_uppercase.replace("I", "").replace("L", "")
return "".join(random.choice(chars) for i in range(key_length))


#########
# Match #
#########
Expand Down
40 changes: 1 addition & 39 deletions metaserver/email.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,11 @@
import requests

from metaserver import constants
import metaserver.database.api as db

# Create a new SES resource and specify a region.
ses = boto3.client("ses", region_name="eu-central-1")

CACHE_TTL = 60 * 5
TOKEN_CACHE = TTLCache(maxsize=10_000, ttl=CACHE_TTL)
TOKEN_CACHE_REVERSE = TTLCache(maxsize=10_000, ttl=CACHE_TTL)
GENERATION_TIME_FOR_USER_ID = TTLCache(maxsize=10_000, ttl=CACHE_TTL)


class DomainBlacklist:
def __contains__(self, key):
Expand All @@ -36,40 +32,6 @@ class DomainBlackListError(ValueError):
domain_blacklist = DomainBlacklist()


def generate_token(user_id: int) -> str:
# Invalidate old token if applicable
try:
old_token = TOKEN_CACHE_REVERSE[user_id]
del TOKEN_CACHE[old_token]
del TOKEN_CACHE_REVERSE[user_id]
except KeyError:
pass

# Generate new token
token = secrets.token_urlsafe(4).upper() # 4 bytes -> token has length 6
TOKEN_CACHE[token] = user_id
TOKEN_CACHE_REVERSE[user_id] = token
GENERATION_TIME_FOR_USER_ID[user_id] = time.monotonic()
return token


def verify_token(user_id: int, token: str) -> bool:
return TOKEN_CACHE[token] == user_id


def get_user_id_for_verification_token(token: str) -> int:
user_id = TOKEN_CACHE[token]
return user_id


def get_token_age_for_user(user_id: int) -> int | float:
"""Get how long ago a token was generated in seconds."""
try:
return int(time.monotonic() - GENERATION_TIME_FOR_USER_ID[user_id])
except KeyError:
return float("inf")


def send_verification_email(recipient: str, token: str):
charset = "UTF-8"
sender = "Savage Community Server <noreply@community-server.info>"
Expand Down
5 changes: 3 additions & 2 deletions tests/test_clans.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

def test_clan_registration(client: TestClient, user: dict, clan_icon: str):
# Register a new clan.
clan_name, clan_tag = "Zaitev's Snore Club", "Zzz"
clan_name, clan_tag = "Zaitev's Snore Club", "^123(Zzz"
response = client.post(
"/v1/clan/register",
json=dict(tag=clan_tag, name=clan_name, icon=clan_icon),
Expand Down Expand Up @@ -48,7 +48,8 @@ def test_clan_registration(client: TestClient, user: dict, clan_icon: str):
assert response[0] == clan
assert len(response) == 1

clan2 = dict(tag=clan_tag + "2", name=clan_name + "2", icon=clan_icon)
clan_tag2 = clan_tag[:-1] + "2"
clan2 = dict(tag=clan_tag2, name=clan_name + "2", icon=clan_icon)
response = client.post(
"/v1/clan/register",
json=clan2,
Expand Down
13 changes: 8 additions & 5 deletions tests/test_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from metaserver import email

from tests import utils


def test_user_registration(client: TestClient):
username = "foo@example.com"
Expand Down Expand Up @@ -41,7 +43,7 @@ def test_user_registration(client: TestClient):
assert response.status_code == 409

# Verify the user's email
mail_token = email.TOKEN_CACHE_REVERSE[user["id"]]
mail_token = utils.get_email_token_for_user_id(user["id"])
response = client.post(
"/v1/user/email/verify",
json=dict(mail_token=mail_token),
Expand Down Expand Up @@ -128,12 +130,13 @@ def test_user_mail_tokens(client: TestClient):
assert "wait" in response.json()["detail"].lower()

# Request new token
old_mail_token = email.TOKEN_CACHE_REVERSE[user["id"]]
get_big_number = lambda user_id: 1_000_000
email.get_token_age_for_user = get_big_number
old_mail_token = utils.get_email_token_for_user_id(user["id"])
utils.set_email_token_created_for_user_id_to_last_year(user["id"])
response = client.post("/v1/user/email/renew-token", auth=auth)
assert response.status_code == 200
new_mail_token = email.TOKEN_CACHE_REVERSE[user["id"]]

new_mail_token = utils.get_email_token_for_user_id(user["id"])
assert old_mail_token != new_mail_token

# Try to verify with old token
response = client.post(
Expand Down
28 changes: 27 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from datetime import datetime, timedelta

from fastapi.testclient import TestClient
from sqlmodel import select

from metaserver import email
import metaserver.database.api as db
from metaserver.database.models import EmailToken


def dict_without_key(d, k):
Expand All @@ -15,7 +20,12 @@ def register_user(client: TestClient, display_name: str, username: str, password
json=dict(username=username, display_name=display_name, password=password),
).json()

mail_token = email.TOKEN_CACHE_REVERSE[user["id"]]
session = next(db.get_session())
mail_token = (
session.exec(select(EmailToken).where(EmailToken.user_id == user["id"]))
.one()
.key
)

user = client.post(
"/v1/user/email/verify",
Expand All @@ -26,3 +36,19 @@ def register_user(client: TestClient, display_name: str, username: str, password
# Add credentials for testing purposes.
user["auth"] = (username, password)
return user


def get_email_token_for_user_id(user_id: int):
session = next(db.get_session())
return (
session.exec(select(EmailToken).where(EmailToken.user_id == user_id)).one().key
)


def set_email_token_created_for_user_id_to_last_year(user_id: int):
session = next(db.get_session())
email_token = session.exec(
select(EmailToken).where(EmailToken.user_id == user_id)
).one()
email_token.created = datetime.utcnow() - timedelta(days=365)
db.commit_and_refresh(session, email_token)

0 comments on commit 5a22c0c

Please sign in to comment.