# core

> Core migrator code

In [None]:
#| default_exp core

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastlite import *
from apswutils.db import Database
import fastlite.kw

In [None]:
db = database(":memory:")

The `Migrator` class adds a `migrations` table to your database, which keeps track of your database schema changes:

In [None]:
#| export
class Migrator:
    "Keep registries of migrations/rollbacks and a `migrations` bookkeeping table in `db`."
    def __init__(self, db):
        # Integer id -> callable(db); filled by `add_migration` / `add_rollback`.
        self.migrations = {}
        self.rollbacks = {}
        self.db = db
        # One row per applied migration; only created when the table is missing.
        ddl = """
        CREATE TABLE IF NOT EXISTS migrations (
            id INTEGER PRIMARY KEY, name TEXT, inserted_at TEXT DEFAULT CURRENT_TIMESTAMP NOT NULL
        ) STRICT;
        """
        db.execute(ddl)
        # Row class used when recording migrations, then a handle on the table itself.
        self.Migration = db.t.migrations.dataclass()
        self.db_migrations = db.t.migrations

In [None]:
Migrator(db)
db.t

migrations

In [None]:
assert len(list(db.t)) == 1

In [None]:
#| export
@patch
def add_migration(self: Migrator, migration_id:int):
    "Decorator registering `migration(db)` under `migration_id`; the decorated function is returned unchanged."
    # `type(...) is int` (not isinstance) deliberately rejects `bool`.
    # NOTE(review): asserts are stripped under `python -O`.
    assert type(migration_id) is int, "migration_id must be an integer"
    def decorator(migration:callable):
        assert callable(migration), "migration must be a callable"
        self.migrations[migration_id] = migration
        # Return the function so the decorated name stays bound to it instead of `None`.
        return migration
    return decorator

@patch
def add_rollback(self: Migrator, rollback_id:int):
    "Decorator registering `rollback(db)` as the undo for the migration with the same id; returns the function unchanged."
    # `type(...) is int` (not isinstance) deliberately rejects `bool`.
    # NOTE(review): asserts are stripped under `python -O`.
    assert type(rollback_id) is int, "rollback_id must be an integer"
    def decorator(rollback:callable):
        assert callable(rollback), "rollback must be a callable"
        self.rollbacks[rollback_id] = rollback
        # Return the function so the decorated name stays bound to it instead of `None`.
        return rollback
    return decorator

In [None]:
#| export
@patch
def migrate(self: Migrator):
    "Apply every registered migration not yet recorded in the `migrations` table, in ascending id order."
    # TODO: make sure ids are in sequence
    # Read the applied ids once, from the same table handle we write to below.
    applied = {row.id for row in self.db_migrations()}
    for mid, migration in sorted(self.migrations.items()):
        if mid in applied:
            continue
        print(mid, migration.__name__)
        # Run the migration *before* recording it: if it raises, it is not marked
        # as applied and will be retried by the next `migrate()` call.
        migration(self.db)
        self.db_migrations.insert(self.Migration(id=mid, name=migration.__name__))

In [None]:
#| export
@patch
def last_applied_migration(self: Migrator):
    "Return the applied migration with the highest id, or `None` if no migration has been applied."
    rows = self.db_migrations('id = (SELECT MAX(id) FROM migrations)')
    # An empty table yields no rows; avoid a bare IndexError from `rows[0]`.
    return rows[0] if rows else None

In [None]:
#| export
@patch
def applied_migrations(self: Migrator):
    "Return all rows currently recorded in the `migrations` table."
    applied = self.db_migrations()
    return applied

In [None]:
#| export
@patch
def rollback(self: Migrator):
    "Undo the most recently applied migration with its registered rollback, then remove its `migrations` row."
    latest_migration = self.last_applied_migration()
    # Nothing applied yet: report instead of failing on a missing row.
    if latest_migration is None:
        print("No applied migrations to roll back")
        return
    last_id = latest_migration.id
    if last_id not in self.rollbacks:
        print(f"No rollback for the latest applied migration found: {latest_migration}")
        return

    rollback = self.rollbacks[last_id]
    print(last_id, rollback.__name__)

    # Run the rollback first: if it raises, the migration stays recorded as applied.
    rollback(self.db)
    self.db_migrations.delete(last_id)

Add migrations by decorating your functions like so:

In [None]:
m = Migrator(db)

# Each decorated function takes the database and performs one schema change;
# the decorator argument is its integer migration id.
@m.add_migration(0)
def init_db(db): db.q("CREATE TABLE cats (name PRIMARY KEY)")

@m.add_migration(1)
def add_dogs(db): db.q("CREATE TABLE dogs (name PRIMARY KEY)")

Running `m.migrate()` applies these migrations in ascending order of `migration_id`.

In [None]:
m.migrate()

0 init_db
1 add_dogs


In [None]:
db.t

cats, dogs, migrations

In [None]:
#| hide
assert len(list(db.t)) == 3

Running it again does nothing, because both migrations are already recorded as applied:

In [None]:
m.migrate()

What if you realize there is something wrong with the last migration? You can write a rollback function to undo it!
Make sure the `rollback_id` matches the corresponding `migration_id`. A rollback only undoes the most recently applied migration.

In [None]:
# Rollback id 1 undoes migration 1 (`add_dogs`).
@m.add_rollback(1)
def remove_dogs(db): db.q("DROP TABLE dogs")

In [None]:
m.applied_migrations()

[Migrations(id=0, name='init_db', inserted_at='2025-03-09 17:06:45'),
 Migrations(id=1, name='add_dogs', inserted_at='2025-03-09 17:06:45')]

In [None]:
m.last_applied_migration()

Migrations(id=1, name='add_dogs', inserted_at='2025-03-09 17:06:45')

In [None]:
m.rollback()

1 remove_dogs


In [None]:
m.applied_migrations()

[Migrations(id=0, name='init_db', inserted_at='2025-03-09 17:06:45')]

In [None]:
db.t

cats, migrations

In [None]:
assert len(m.applied_migrations()) == 1
assert len(list(db.t)) == 2

## Patch database directly

For a simpler API, let's add everything directly to `database`!

In [None]:
#| export
# Keep fastlite's original `database` before it is shadowed by the wrapper below.
# NOTE(review): re-running this cell after the wrapper is defined would capture the
# wrapper itself, which would then call itself forever — run cells in order.
_orig_database = database

In [None]:
#| export
def database(path, wal=True):
    "Open `path` via fastlite's `database`, attach a `Migrator`, and expose its methods on the returned db."
    db = _orig_database(path, wal)
    migrator = Migrator(db)
    db.migrator = migrator
    # Avoid name collision with transaction rollback
    db.rollback_migration = migrator.rollback
    # Every other migrator method is exposed on the db under its own name.
    for attr in ('add_migration', 'add_rollback', 'applied_migrations', 'last_applied_migration', 'migrate'):
        setattr(db, attr, getattr(migrator, attr))
    return db

In [None]:
db = database(":memory:")

In [None]:
# Same decorator API as `Migrator`, now available directly on the database object.
@db.add_migration(0)
def init_db(db): db.q("CREATE TABLE cats (name PRIMARY KEY)")

In [None]:
db.migrate()

0 init_db


In [None]:
db.rollback_migration()

No rollback for the latest applied migration found: Migrations(id=0, name='init_db', inserted_at='2025-03-09 17:48:14')


In [None]:
#| hide
import nbdev; nbdev.nbdev_export()