Skip to content

Commit

Permalink
Merge pull request from GHSA-xq59-7jf3-rjc6
Browse files Browse the repository at this point in the history
Co-authored-by: skelmis <ethan.mckee-harris@zxsecurity.co.nz>
  • Loading branch information
dantownsend and Skelmis committed Nov 10, 2023
1 parent e4946d8 commit 82679eb
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 2 deletions.
15 changes: 15 additions & 0 deletions piccolo/engine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import contextvars
import logging
import pprint
import string
import typing as t
from abc import ABCMeta, abstractmethod

Expand All @@ -15,6 +16,20 @@


logger = logging.getLogger(__name__)
# This is a set to speed up lookups from O(n) when
# using str vs O(1) when using set[str]
VALID_SAVEPOINT_CHARACTERS: t.Final[set[str]] = set(
string.ascii_letters + string.digits + "-" + "_"
)


def validate_savepoint_name(savepoint_name: str) -> None:
"""Validates a save point's name meets the required character set."""
if not all(i in VALID_SAVEPOINT_CHARACTERS for i in savepoint_name):
raise ValueError(
"Savepoint names can only contain the following characters:"
f" {VALID_SAVEPOINT_CHARACTERS}"
)


class Batch:
Expand Down
5 changes: 4 additions & 1 deletion piccolo/engine/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import typing as t
from dataclasses import dataclass

from piccolo.engine.base import Batch, Engine
from piccolo.engine.base import Batch, Engine, validate_savepoint_name
from piccolo.engine.exceptions import TransactionError
from piccolo.query.base import DDL, Query
from piccolo.querystring import QueryString
Expand Down Expand Up @@ -129,11 +129,13 @@ def __init__(self, name: str, transaction: PostgresTransaction):
self.transaction = transaction

async def rollback_to(self):
validate_savepoint_name(self.name)
await self.transaction.connection.execute(
f"ROLLBACK TO SAVEPOINT {self.name}"
)

async def release(self):
validate_savepoint_name(self.name)
await self.transaction.connection.execute(
f"RELEASE SAVEPOINT {self.name}"
)
Expand Down Expand Up @@ -236,6 +238,7 @@ def get_savepoint_id(self) -> int:

async def savepoint(self, name: t.Optional[str] = None) -> Savepoint:
name = name or f"savepoint_{self.get_savepoint_id()}"
validate_savepoint_name(name)
await self.connection.execute(f"SAVEPOINT {name}")
return Savepoint(name=name, transaction=self)

Expand Down
5 changes: 4 additions & 1 deletion piccolo/engine/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from dataclasses import dataclass
from decimal import Decimal

from piccolo.engine.base import Batch, Engine
from piccolo.engine.base import Batch, Engine, validate_savepoint_name
from piccolo.engine.exceptions import TransactionError
from piccolo.query.base import DDL, Query
from piccolo.querystring import QueryString
Expand Down Expand Up @@ -309,11 +309,13 @@ def __init__(self, name: str, transaction: SQLiteTransaction):
self.transaction = transaction

async def rollback_to(self):
validate_savepoint_name(self.name)
await self.transaction.connection.execute(
f"ROLLBACK TO SAVEPOINT {self.name}"
)

async def release(self):
validate_savepoint_name(self.name)
await self.transaction.connection.execute(
f"RELEASE SAVEPOINT {self.name}"
)
Expand Down Expand Up @@ -413,6 +415,7 @@ def get_savepoint_id(self) -> int:

async def savepoint(self, name: t.Optional[str] = None) -> Savepoint:
name = name or f"savepoint_{self.get_savepoint_id()}"
validate_savepoint_name(name)
await self.connection.execute(f"SAVEPOINT {name}")
return Savepoint(name=name, transaction=self)

Expand Down
13 changes: 13 additions & 0 deletions tests/engine/test_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import typing as t
from unittest import TestCase

import pytest

from piccolo.engine.postgres import Atomic
from piccolo.engine.sqlite import SQLiteEngine, TransactionType
from piccolo.table import drop_db_tables_sync
Expand Down Expand Up @@ -296,3 +298,14 @@ async def run_test():
self.assertListEqual(
Manager.select(Manager.name).run_sync(), [{"name": "Manager 1"}]
)

def test_savepoint_sqli_checks(self):
# Added to test the fix for GHSA-xq59-7jf3-rjc6
async def run_test():
async with Manager._meta.db.transaction() as transaction:
await transaction.savepoint(
"my_savepoint; SELECT * FROM Manager"
)

with pytest.raises(ValueError):
run_sync(run_test())

0 comments on commit 82679eb

Please sign in to comment.