Skip to content

Commit

Permalink
[v5.1.0] Refactor the rest of the repositories to use sqlalchemy core…
Browse files Browse the repository at this point in the history
… 1.4 (#632)

* Refactor the rest of the repositories to use sqlalchemy core 1.4

Co-authored-by: James Wilson <tsunyoku@users.noreply.github.com>

* fix types & nullables

* fmt

* add & migrate to thin database adapter, ensure all queries are compatible with sqlalchemy 2.0

* bugfix startup process

* Add `render_postcompile` flag to support `IN` clauses

* log db queries in debug mode

* remove default from logs.time (doesn't exist)

* Fix bugs in favorites table repo

* Bump minor version -- to 5.1.0

---------

Co-authored-by: James Wilson <tsunyoku@users.noreply.github.com>
Co-authored-by: tsunyoku <tsunyoku@gmail.com>
  • Loading branch information
3 people committed Feb 26, 2024
1 parent 0cc409f commit 26b8595
Show file tree
Hide file tree
Showing 30 changed files with 1,353 additions and 1,656 deletions.
Empty file added app/adapters/__init__.py
Empty file.
100 changes: 100 additions & 0 deletions app/adapters/database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from __future__ import annotations

from typing import Any
from typing import cast

from databases import Database as _Database
from databases.core import Transaction
from sqlalchemy.dialects.mysql.mysqldb import MySQLDialect_mysqldb
from sqlalchemy.sql.compiler import Compiled
from sqlalchemy.sql.expression import ClauseElement

from app import settings


class MySQLDialect(MySQLDialect_mysqldb):
default_paramstyle = "named"


DIALECT = MySQLDialect()

MySQLRow = dict[str, Any]
MySQLParams = dict[str, Any] | None
MySQLQuery = ClauseElement | str


class Database:
def __init__(self, url: str) -> None:
self._database = _Database(url)

async def connect(self) -> None:
await self._database.connect()

async def disconnect(self) -> None:
await self._database.disconnect()

def _compile(self, clause_element: ClauseElement) -> tuple[str, MySQLParams]:
compiled: Compiled = clause_element.compile(
dialect=DIALECT,
compile_kwargs={"render_postcompile": True},
)
if settings.DEBUG:
print(str(compiled), compiled.params)
return str(compiled), compiled.params

async def fetch_one(
self,
query: MySQLQuery,
params: MySQLParams = None,
) -> MySQLRow | None:
if isinstance(query, ClauseElement):
query, params = self._compile(query)

row = await self._database.fetch_one(query, params)
return dict(row._mapping) if row is not None else None

async def fetch_all(
self,
query: MySQLQuery,
params: MySQLParams = None,
) -> list[MySQLRow]:
if isinstance(query, ClauseElement):
query, params = self._compile(query)

rows = await self._database.fetch_all(query, params)
return [dict(row._mapping) for row in rows]

async def fetch_val(
self,
query: MySQLQuery,
params: MySQLParams = None,
column: Any = 0,
) -> Any:
if isinstance(query, ClauseElement):
query, params = self._compile(query)

val = await self._database.fetch_val(query, params, column)
return val

async def execute(self, query: MySQLQuery, params: MySQLParams = None) -> int:
if isinstance(query, ClauseElement):
query, params = self._compile(query)

rec_id = await self._database.execute(query, params)
return cast(int, rec_id)

# NOTE: this accepts str since current execute_many uses are not using alchemy.
# alchemy does execute_many in a single query so this method will be unneeded once raw SQL is not in use.
async def execute_many(self, query: str, params: list[MySQLParams]) -> None:
if isinstance(query, ClauseElement):
query, _ = self._compile(query)

await self._database.execute_many(query, params)

def transaction(
self,
*,
force_rollback: bool = False,
**kwargs: Any,
) -> Transaction:
return self._database.transaction(force_rollback=force_rollback, **kwargs)
2 changes: 1 addition & 1 deletion app/api/domains/cho.py
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,7 @@ async def handle_osu_login_request(
# country wasn't stored on registration.
log(f"Fixing {login_data['username']}'s country.", Ansi.LGREEN)

await users_repo.update(
await users_repo.partial_update(
id=user_info["id"],
country=geoloc["country"]["acronym"],
)
Expand Down
27 changes: 9 additions & 18 deletions app/api/domains/osu.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,19 +468,17 @@ async def osuSearchSetHandler(
return Response(b"") # invalid args

# Get all set data.
rec = await app.state.services.database.fetch_one(
bmapset = await app.state.services.database.fetch_one(
"SELECT DISTINCT set_id, artist, "
"title, status, creator, last_update "
f"FROM maps WHERE {k} = :v",
{"v": v},
)

if rec is None:
if bmapset is None:
# TODO: get from osu!
return Response(b"")

rating = 10.0 # TODO: real data
bmapset = dict(rec._mapping)

return Response(
(
Expand Down Expand Up @@ -979,7 +977,7 @@ async def osuSubmitModularSelector(

server_achievements = await achievements_usecases.fetch_many()
player_achievements = await user_achievements_usecases.fetch_many(
score.player.id,
user_id=score.player.id,
)

for server_achievement in server_achievements:
Expand Down Expand Up @@ -1184,17 +1182,14 @@ async def get_leaderboard_scores(
# TODO: customizability of the number of scores
query.append("ORDER BY _score DESC LIMIT 50")

score_rows = [
dict(r._mapping)
for r in await app.state.services.database.fetch_all(
" ".join(query),
params,
)
]
score_rows = await app.state.services.database.fetch_all(
" ".join(query),
params,
)

if score_rows: # None or []
# fetch player's personal best score
personal_best_score_rec = await app.state.services.database.fetch_one(
personal_best_score_row = await app.state.services.database.fetch_one(
f"SELECT id, {scoring_metric} AS _score, "
"max_combo, n50, n100, n300, "
"nmiss, nkatu, ngeki, perfect, mods, "
Expand All @@ -1206,9 +1201,7 @@ async def get_leaderboard_scores(
{"map_md5": map_md5, "mode": mode, "user_id": player.id},
)

if personal_best_score_rec is not None:
personal_best_score_row = dict(personal_best_score_rec._mapping)

if personal_best_score_row is not None:
# calculate the rank of the score.
p_best_rank = 1 + await app.state.services.database.fetch_val(
"SELECT COUNT(*) FROM scores s "
Expand All @@ -1226,8 +1219,6 @@ async def get_leaderboard_scores(

# attach rank to personal best row
personal_best_score_row["rank"] = p_best_rank
else:
personal_best_score_row = None
else:
score_rows = []
personal_best_score_row = None
Expand Down
2 changes: 1 addition & 1 deletion app/api/v1/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,7 +804,7 @@ async def api_get_replay(
'attachment; filename="{username} - '
"{artist} - {title} [{version}] "
'({play_time:%Y-%m-%d}).osr"'
).format(**dict(row._mapping)),
).format(**row),
},
)

Expand Down
44 changes: 20 additions & 24 deletions app/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ async def changename(ctx: Context) -> str | None:
return "Username already taken by another player."

# all checks passed, update their name
await users_repo.update(ctx.player.id, name=name)
await users_repo.partial_update(ctx.player.id, name=name)

ctx.player.enqueue(
app.packets.notification(f"Your username has been changed to {name}!"),
Expand Down Expand Up @@ -388,21 +388,17 @@ async def top(ctx: Context) -> str | None:
# !top rx!std
mode = GAMEMODE_REPR_LIST.index(ctx.args[0])

scores = [
dict(s._mapping)
for s in await app.state.services.database.fetch_all(
"SELECT s.pp, b.artist, b.title, b.version, b.set_id map_set_id, b.id map_id "
"FROM scores s "
"LEFT JOIN maps b ON b.md5 = s.map_md5 "
"WHERE s.userid = :user_id "
"AND s.mode = :mode "
"AND s.status = 2 "
"AND b.status in (2, 3) "
"ORDER BY s.pp DESC LIMIT 10",
{"user_id": player.id, "mode": mode},
)
]

scores = await app.state.services.database.fetch_all(
"SELECT s.pp, b.artist, b.title, b.version, b.set_id map_set_id, b.id map_id "
"FROM scores s "
"LEFT JOIN maps b ON b.md5 = s.map_md5 "
"WHERE s.userid = :user_id "
"AND s.mode = :mode "
"AND s.status = 2 "
"AND b.status in (2, 3) "
"ORDER BY s.pp DESC LIMIT 10",
{"user_id": player.id, "mode": mode},
)
if not scores:
return "No scores"

Expand Down Expand Up @@ -564,7 +560,7 @@ async def apikey(ctx: Context) -> str | None:
# generate new token
ctx.player.api_key = str(uuid.uuid4())

await users_repo.update(ctx.player.id, api_key=ctx.player.api_key)
await users_repo.partial_update(ctx.player.id, api_key=ctx.player.api_key)
app.state.sessions.api_keys[ctx.player.api_key] = ctx.player.id

return f"API key generated. Copy your api key from (this url)[http://{ctx.player.api_key}]."
Expand Down Expand Up @@ -654,7 +650,7 @@ async def _map(ctx: Context) -> str | None:
if ctx.args[1] == "set":
# update all maps in the set
for _bmap in bmap.set.maps:
await maps_repo.update(_bmap.id, status=new_status, frozen=True)
await maps_repo.partial_update(_bmap.id, status=new_status, frozen=True)

# make sure cache and db are synced about the newest change
for _bmap in app.state.cache.beatmapset[bmap.set_id].maps:
Expand All @@ -671,7 +667,7 @@ async def _map(ctx: Context) -> str | None:

else:
# update only map
await maps_repo.update(bmap.id, status=new_status, frozen=True)
await maps_repo.partial_update(bmap.id, status=new_status, frozen=True)

# make sure cache and db are synced about the newest change
if bmap.md5 in app.state.cache.beatmap:
Expand Down Expand Up @@ -2326,7 +2322,7 @@ async def clan_create(ctx: Context) -> str | None:
ctx.player.clan_id = new_clan["id"]
ctx.player.clan_priv = ClanPrivileges.Owner

await users_repo.update(
await users_repo.partial_update(
ctx.player.id,
clan_id=new_clan["id"],
clan_priv=ClanPrivileges.Owner,
Expand Down Expand Up @@ -2362,15 +2358,15 @@ async def clan_disband(ctx: Context) -> str | None:
if not clan:
return "You're not a member of a clan!"

await clans_repo.delete(clan["id"])
await clans_repo.delete_one(clan["id"])

# remove all members from the clan
clan_member_ids = [
clan_member["id"]
for clan_member in await users_repo.fetch_many(clan_id=clan["id"])
]
for member_id in clan_member_ids:
await users_repo.update(member_id, clan_id=0, clan_priv=0)
await users_repo.partial_update(member_id, clan_id=0, clan_priv=0)

member = app.state.sessions.players.get(id=member_id)
if member:
Expand Down Expand Up @@ -2423,15 +2419,15 @@ async def clan_leave(ctx: Context) -> str | None:

clan_members = await users_repo.fetch_many(clan_id=clan["id"])

await users_repo.update(ctx.player.id, clan_id=0, clan_priv=0)
await users_repo.partial_update(ctx.player.id, clan_id=0, clan_priv=0)
ctx.player.clan_id = None
ctx.player.clan_priv = None

clan_display_name = f"[{clan['tag']}] {clan['name']}"

if not clan_members:
# no members left, disband clan
await clans_repo.delete(clan["id"])
await clans_repo.delete_one(clan["id"])

# announce clan disbanding
announce_chan = app.state.sessions.channels.get_by_name("#announce")
Expand Down
3 changes: 1 addition & 2 deletions app/objects/beatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,8 +885,7 @@ async def _from_bsid_sql(cls, bsid: int) -> BeatmapSet | None:
)
.translate(IGNORED_BEATMAP_CHARS)
)

await maps_repo.update(bmap.id, filename=bmap.filename)
await maps_repo.partial_update(bmap.id, filename=bmap.filename)

bmap_set.maps.append(bmap)

Expand Down
12 changes: 6 additions & 6 deletions app/objects/player.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ async def update_privs(self, new: Privileges) -> None:
if "bancho_priv" in vars(self):
del self.bancho_priv # wipe cached_property

await users_repo.update(
await users_repo.partial_update(
id=self.id,
priv=self.priv,
)
Expand All @@ -424,7 +424,7 @@ async def add_privs(self, bits: Privileges) -> None:
if "bancho_priv" in vars(self):
del self.bancho_priv # wipe cached_property

await users_repo.update(
await users_repo.partial_update(
id=self.id,
priv=self.priv,
)
Expand All @@ -441,7 +441,7 @@ async def remove_privs(self, bits: Privileges) -> None:
if "bancho_priv" in vars(self):
del self.bancho_priv # wipe cached_property

await users_repo.update(
await users_repo.partial_update(
id=self.id,
priv=self.priv,
)
Expand Down Expand Up @@ -527,7 +527,7 @@ async def silence(self, admin: Player, duration: float, reason: str) -> None:
"""Silence `self` for `duration` seconds, and log to sql."""
self.silence_end = int(time.time() + duration)

await users_repo.update(
await users_repo.partial_update(
id=self.id,
silence_end=self.silence_end,
)
Expand Down Expand Up @@ -555,7 +555,7 @@ async def unsilence(self, admin: Player, reason: str) -> None:
"""Unsilence `self`, and log to sql."""
self.silence_end = int(time.time())

await users_repo.update(
await users_repo.partial_update(
id=self.id,
silence_end=self.silence_end,
)
Expand Down Expand Up @@ -973,7 +973,7 @@ async def stats_from_sql_full(self) -> None:

def update_latest_activity_soon(self) -> None:
"""Update the player's latest activity in the database."""
task = users_repo.update(
task = users_repo.partial_update(
id=self.id,
latest_activity=int(time.time()),
)
Expand Down
8 changes: 0 additions & 8 deletions app/repositories/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

from sqlalchemy.dialects.mysql.mysqldb import MySQLDialect_mysqldb
from sqlalchemy.orm import DeclarativeMeta
from sqlalchemy.orm import registry

Expand All @@ -14,10 +13,3 @@ class Base(metaclass=DeclarativeMeta):
metadata = mapper_registry.metadata

__init__ = mapper_registry.constructor


class MySQLDialect(MySQLDialect_mysqldb):
default_paramstyle = "named"


DIALECT = MySQLDialect()
Loading

0 comments on commit 26b8595

Please sign in to comment.