Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update migration command to enable/disable transactions #828

Merged
merged 2 commits into from Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 15 additions & 1 deletion beanie/executors/migrate.py
Expand Up @@ -53,6 +53,7 @@ def __init__(self, **kwargs):
or self.get_from_toml("allow_index_dropping")
or False
)
self.use_transaction = bool(kwargs.get("use_transaction"))

@staticmethod
def get_env_value(field_name) -> Any:
Expand Down Expand Up @@ -111,7 +112,9 @@ async def run_migrate(settings: MigrationSettings):
direction=settings.direction, distance=settings.distance
)
await root.run(
mode=mode, allow_index_dropping=settings.allow_index_dropping
mode=mode,
allow_index_dropping=settings.allow_index_dropping,
use_transaction=settings.use_transaction,
)


Expand Down Expand Up @@ -160,13 +163,23 @@ async def run_migrate(settings: MigrationSettings):
default=False,
help="if allow-index-dropping is set, Beanie will drop indexes from your collection",
)
@click.option(
"--use-transaction/--no-use-transaction",
required=False,
default=True,
help="Enable or disable the use of transactions during migration. "
"When enabled (--use-transaction), Beanie uses transactions for migration, "
"which necessitates a replica set. When disabled (--no-use-transaction), "
"migrations occur without transactions.",
)
def migrate(
direction,
distance,
connection_uri,
database_name,
path,
allow_index_dropping,
use_transaction,
):
settings_kwargs = {}
if direction:
Expand All @@ -181,6 +194,7 @@ def migrate(
settings_kwargs["path"] = path
if allow_index_dropping:
settings_kwargs["allow_index_dropping"] = allow_index_dropping
settings_kwargs["use_transaction"] = use_transaction
settings = MigrationSettings(**settings_kwargs)

asyncio.run(run_migrate(settings))
Expand Down
3 changes: 3 additions & 0 deletions beanie/migrations/controllers/base.py
Expand Up @@ -5,6 +5,9 @@


class BaseMigrationController(ABC):
def __init__(self, function):
self.function = function

@abstractmethod
async def run(self, session):
pass
Expand Down
83 changes: 60 additions & 23 deletions beanie/migrations/runner.py
@@ -1,7 +1,9 @@
import logging
from importlib.machinery import SourceFileLoader
from pathlib import Path
from typing import Optional, Type
from typing import List, Optional, Type

from motor.motor_asyncio import AsyncIOMotorClientSession, AsyncIOMotorDatabase

from beanie.migrations.controllers.iterative import BaseMigrationController
from beanie.migrations.database import DBHandler
Expand Down Expand Up @@ -55,7 +57,12 @@ async def update_current_migration(self):
await self.clean_current_migration()
await MigrationLog(is_current=True, name=self.name).insert()

async def run(self, mode: RunningMode, allow_index_dropping: bool):
async def run(
self,
mode: RunningMode,
allow_index_dropping: bool,
use_transaction: bool,
):
"""
Migrate
Expand All @@ -71,7 +78,8 @@ async def run(self, mode: RunningMode, allow_index_dropping: bool):
logger.info("Running migrations forward without limit")
while True:
await migration_node.run_forward(
allow_index_dropping=allow_index_dropping
allow_index_dropping=allow_index_dropping,
use_transaction=use_transaction,
)
migration_node = migration_node.next_migration
if migration_node is None:
Expand All @@ -80,7 +88,8 @@ async def run(self, mode: RunningMode, allow_index_dropping: bool):
logger.info(f"Running {mode.distance} migrations forward")
for i in range(mode.distance):
await migration_node.run_forward(
allow_index_dropping=allow_index_dropping
allow_index_dropping=allow_index_dropping,
use_transaction=use_transaction,
)
migration_node = migration_node.next_migration
if migration_node is None:
Expand All @@ -91,7 +100,8 @@ async def run(self, mode: RunningMode, allow_index_dropping: bool):
logger.info("Running migrations backward without limit")
while True:
await migration_node.run_backward(
allow_index_dropping=allow_index_dropping
allow_index_dropping=allow_index_dropping,
use_transaction=use_transaction,
)
migration_node = migration_node.prev_migration
if migration_node is None:
Expand All @@ -100,30 +110,41 @@ async def run(self, mode: RunningMode, allow_index_dropping: bool):
logger.info(f"Running {mode.distance} migrations backward")
for i in range(mode.distance):
await migration_node.run_backward(
allow_index_dropping=allow_index_dropping
allow_index_dropping=allow_index_dropping,
use_transaction=use_transaction,
)
migration_node = migration_node.prev_migration
if migration_node is None:
break

async def run_forward(self, allow_index_dropping):
async def run_forward(
self, allow_index_dropping: bool, use_transaction: bool
):
if self.forward_class is not None:
await self.run_migration_class(
self.forward_class, allow_index_dropping=allow_index_dropping
self.forward_class,
allow_index_dropping=allow_index_dropping,
use_transaction=use_transaction,
)
await self.update_current_migration()

async def run_backward(self, allow_index_dropping):
async def run_backward(
self, allow_index_dropping: bool, use_transaction: bool
):
if self.backward_class is not None:
await self.run_migration_class(
self.backward_class, allow_index_dropping=allow_index_dropping
self.backward_class,
allow_index_dropping=allow_index_dropping,
use_transaction=use_transaction,
)
if self.prev_migration is not None:
await self.prev_migration.update_current_migration()
else:
await self.clean_current_migration()

async def run_migration_class(self, cls: Type, allow_index_dropping: bool):
async def run_migration_class(
self, cls: Type, allow_index_dropping: bool, use_transaction: bool
):
"""
Run Backward or Forward migration class
Expand All @@ -142,19 +163,35 @@ async def run_migration_class(self, cls: Type, allow_index_dropping: bool):
if client is None:
raise RuntimeError("client must not be None")
async with await client.start_session() as s:
async with s.start_transaction():
for migration in migrations:
for model in migration.models:
await init_beanie(
database=db,
document_models=[model], # type: ignore
allow_index_dropping=allow_index_dropping,
) # TODO this is slow
logger.info(
f"Running migration {migration.function.__name__} "
f"from module {self.name}"
if use_transaction:
async with s.start_transaction():
await self.run_migrations(
migrations, db, allow_index_dropping, s
)
await migration.run(session=s)
else:
await self.run_migrations(
migrations, db, allow_index_dropping, s
)

async def run_migrations(
self,
migrations: List[BaseMigrationController],
db: AsyncIOMotorDatabase,
allow_index_dropping: bool,
session: AsyncIOMotorClientSession,
) -> None:
for migration in migrations:
for model in migration.models:
await init_beanie(
database=db,
document_models=[model], # type: ignore
allow_index_dropping=allow_index_dropping,
) # TODO this is slow
logger.info(
f"Running migration {migration.function.__name__} "
f"from module {self.name}"
)
await migration.run(session=session)

@classmethod
async def build(cls, path: Path):
Expand Down
5 changes: 3 additions & 2 deletions docs/tutorial/migrations.md
@@ -1,6 +1,5 @@
## Attention!

Migrations use transactions inside. They work only with **MongoDB replica sets**

## Create

Expand All @@ -17,6 +16,8 @@ Each one contains instructions to roll migration respectively forward and backwa

## Run

**Attention**: By default, migrations use transactions. This approach only works with **MongoDB replica sets**. If you prefer to run migrations without transactions, pass the `--no-use-transaction` flag to the `migrate` command. However, be aware that this approach is risky, as there is no way to roll back migrations without transactions.

To roll one forward migration, run:

```shell
Expand All @@ -26,7 +27,7 @@ beanie migrate -uri 'mongodb+srv://user:pass@host/db' -p relative/path/to/migrat
To roll all forward migrations, run:

```shell
beanie migrate -uri 'mongodb+srv://user:pass@host/db' -p relative/path/to/migrations/directory/
beanie migrate -uri 'mongodb://user:pass@host' -db db -p relative/path/to/migrations/directory/
```

To roll one backward migration, run:
Expand Down
23 changes: 23 additions & 0 deletions tests/migrations/test_free_fall.py
Expand Up @@ -65,3 +65,26 @@ async def test_migration_free_fall(settings, notes, db):
assert inspection.status == InspectionStatuses.OK
note = await OldNote.find_one({})
assert note.name == "0"


async def test_migration_free_fall_no_use_transactions(settings, notes, db):
migration_settings = MigrationSettings(
connection_uri=settings.mongodb_dsn,
database_name=settings.mongodb_db_name,
path="tests/migrations/migrations_for_test/free_fall",
use_transaction=False,
)
await run_migrate(migration_settings)

await init_beanie(database=db, document_models=[Note])
inspection = await Note.inspect_collection()
assert inspection.status == InspectionStatuses.OK
note = await Note.find_one({})
assert note.title == "0"

migration_settings.direction = RunningDirections.BACKWARD
await run_migrate(migration_settings)
inspection = await OldNote.inspect_collection()
assert inspection.status == InspectionStatuses.OK
note = await OldNote.find_one({})
assert note.name == "0"