# Async SQLAlchemy with FastAPI

Learn how to use SQLAlchemy 2.0 with async/await for database operations.

## 1. Async SQLAlchemy Setup

Configure async engine and session factory.

In [None]:
import asyncio
from datetime import datetime

from sqlalchemy import Boolean, Column, DateTime, Integer, String, func
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import declarative_base, sessionmaker

# Async database URL (in-memory SQLite through the aiosqlite driver for this demo)
DATABASE_URL = "sqlite+aiosqlite:///:memory:"

# Create the async engine. SQLAlchemy 2.0 only has the 2.0-style API, so the
# legacy ``future=True`` flag is unnecessary and has been dropped.
engine = create_async_engine(
    DATABASE_URL,
    echo=False,  # Set to True to see SQL queries
)

# async_sessionmaker is the SQLAlchemy 2.0 factory for AsyncSession (its
# default class_); unlike sessionmaker(class_=AsyncSession) it is typed to
# produce AsyncSession instances.
AsyncSessionLocal = async_sessionmaker(
    engine,
    # Don't expire attributes on commit, so objects returned from a closed
    # session can still be read without triggering a reload.
    expire_on_commit=False,
)

# Base class for all models
Base = declarative_base()

print("✅ Async SQLAlchemy configured")
print(f"   Engine: {DATABASE_URL}")
print("   Session: AsyncSessionLocal")
print()

## 2. Define Models with Modern Type Hints

Declare columns with SQLAlchemy 2.0's `Mapped[]` type hints and `mapped_column()`.

In [None]:
from sqlalchemy.orm import Mapped, mapped_column, relationship
from typing import List
from sqlalchemy import ForeignKey

# User Model
class User(Base):
    """An application account that can own any number of ``MLModel`` rows."""

    __tablename__ = "users"

    id: Mapped[int] = mapped_column(primary_key=True)
    # Both identifiers are unique and indexed.
    email: Mapped[str] = mapped_column(String(255), index=True, unique=True)
    username: Mapped[str] = mapped_column(String(100), index=True, unique=True)
    # Stored exactly as supplied by the caller; nothing in this class hashes it.
    hashed_password: Mapped[str] = mapped_column(String(255))
    is_active: Mapped[bool] = mapped_column(default=True)
    # Python-side default evaluated at INSERT time (naive local time).
    created_at: Mapped[datetime] = mapped_column(default=datetime.now)

    # One-to-many side of MLModel.owner. The ORM cascade deletes a user's
    # models with the user, and removes models detached from this list.
    models: Mapped[list["MLModel"]] = relationship(
        back_populates="owner",
        cascade="all, delete-orphan",
    )

    def __repr__(self):
        return f"<User(id={self.id}, username={self.username})>"

# MLModel Model
class MLModel(Base):
    """A machine-learning model record belonging to exactly one ``User``."""

    __tablename__ = "ml_models"

    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String(255), index=True)
    # Free-form framework label, e.g. "sklearn" or "pytorch".
    framework: Mapped[str] = mapped_column(String(50))
    # Nullable: a model may be stored without an accuracy score.
    accuracy: Mapped[float | None] = mapped_column(nullable=True)
    # ondelete="CASCADE" adds ON DELETE CASCADE to the foreign-key constraint.
    owner_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"))
    created_at: Mapped[datetime] = mapped_column(default=datetime.now)

    # Many-to-one side of User.models.
    owner: Mapped["User"] = relationship(back_populates="models")

    def __repr__(self):
        return f"<MLModel(id={self.id}, name={self.name}, framework={self.framework})>"

# Summarize the models declared in this cell (the trailing "" yields a blank line).
summary_lines = (
    "✅ Models defined:",
    "   - User (id, email, username, is_active, created_at)",
    "   - MLModel (id, name, framework, accuracy, owner_id)",
    "",
)
print("\n".join(summary_lines))

## 3. Create Tables

Set up database schema asynchronously.

In [None]:
async def init_db():
    """Create all tables in the database."""
    async with engine.begin() as conn:
        # Create all tables defined in Base.metadata
        await conn.run_sync(Base.metadata.create_all)
    print("✅ Database tables created")

# Run the async function
await init_db()
print()

## 4. CREATE - Insert Data Asynchronously

Add new records to the database.

In [None]:
async def create_user(email: str, username: str, password_hash: str) -> User:
    """Insert a new, active user and return it.

    Args:
        email: The user's email address (unique).
        username: The user's username (unique).
        password_hash: Already-hashed password, stored verbatim.

    Returns:
        The persisted ``User``, refreshed from the database after commit.
    """
    new_user = User(
        email=email,
        username=username,
        hashed_password=password_hash,
        is_active=True,
    )
    async with AsyncSessionLocal() as session:
        session.add(new_user)
        await session.commit()
        # Reload the row so values produced during the INSERT are on the object.
        await session.refresh(new_user)
    return new_user

# Create users
user1 = await create_user(
    email="alice@example.com",
    username="alice",
    password_hash="$2b$12$..."
)
print(f"✅ Created: {user1}")

user2 = await create_user(
    email="bob@example.com",
    username="bob",
    password_hash="$2b$12$..."
)
print(f"✅ Created: {user2}")
print()

## 5. READ - Query Data Asynchronously

Retrieve records from the database.

In [None]:
from sqlalchemy import select

# Get single user by ID
async def get_user_by_id(user_id: int) -> User | None:
    """Return the user whose primary key is ``user_id``, or ``None`` if absent."""
    query = select(User).where(User.id == user_id)
    async with AsyncSessionLocal() as session:
        # session.scalar() returns the first column of the first row, or None.
        return await session.scalar(query)

# Get user
user = await get_user_by_id(1)
if user:
    print(f"✅ Found: {user}")
    print(f"   Email: {user.email}")
    print(f"   Created: {user.created_at}")
else:
    print("User not found")
print()

In [None]:
# Get user by email
async def get_user_by_email(email: str) -> User | None:
    """Return the user registered with ``email``, or ``None`` if there is none."""
    query = select(User).where(User.email == email)
    async with AsyncSessionLocal() as session:
        return await session.scalar(query)

user = await get_user_by_email("bob@example.com")
if user:
    print(f"✅ Found: {user}")
print()

In [None]:
# Get all users
async def get_all_users() -> list[User]:
    """Return every user, newest first.

    Users created with the same ``created_at`` value are ordered by descending
    ``id`` so the result order is deterministic.

    Returns:
        A list of ``User`` objects (empty if there are none).
    """
    async with AsyncSessionLocal() as session:
        stmt = select(User).order_by(User.created_at.desc(), User.id.desc())
        result = await session.execute(stmt)
        # ScalarResult.all() is typed as Sequence; return a list as annotated.
        return list(result.scalars().all())

users = await get_all_users()
print(f"✅ Total users: {len(users)}")
for user in users:
    print(f"   - {user.username} ({user.email})")
print()

## 6. UPDATE - Modify Data Asynchronously

Change existing records.

In [None]:
async def update_user(user_id: int, **kwargs) -> User | None:
    """Update column values on an existing user.

    Only mapped, non-primary-key column attributes can be changed. Any other
    keyword is ignored, including:

    * unknown names;
    * the ``id`` primary key;
    * relationships such as ``models``;
    * methods and internal ORM attributes.

    Args:
        user_id: Primary key of the user to update.
        **kwargs: New values keyed by column attribute name,
            e.g. ``is_active=False``.

    Returns:
        The refreshed ``User``, or ``None`` if no user has ``user_id``.
    """
    from sqlalchemy import inspect as sa_inspect

    # hasattr() would also accept methods, relationships and ORM internals
    # (e.g. ``_sa_instance_state``), letting callers clobber them. Whitelist
    # the real, non-primary-key column attributes instead.
    updatable = {
        prop.key
        for prop in sa_inspect(User).column_attrs
        if not any(column.primary_key for column in prop.columns)
    }

    async with AsyncSessionLocal() as session:
        stmt = select(User).where(User.id == user_id)
        result = await session.execute(stmt)
        user = result.scalars().first()

        if user is None:
            return None

        for key, value in kwargs.items():
            if key in updatable:
                setattr(user, key, value)

        await session.commit()
        await session.refresh(user)

        return user

print("Before update:")
user = await get_user_by_id(1)
print(f"  is_active: {user.is_active}")
print()

print("Updating user...")
updated = await update_user(1, is_active=False)
print("After update:")
print(f"  is_active: {updated.is_active}")
print()

## 7. DELETE - Remove Data Asynchronously

Delete records from the database.

In [None]:
async def delete_user(user_id: int) -> bool:
    """Delete the user whose primary key is ``user_id``.

    ``User.models`` is declared with ``cascade="all, delete-orphan"``, so the
    ORM also deletes the user's models.

    Returns:
        ``True`` if a user was deleted, ``False`` if no user matched.
    """
    async with AsyncSessionLocal() as session:
        target = await session.scalar(select(User).where(User.id == user_id))
        if target is None:
            return False
        await session.delete(target)
        await session.commit()
    return True

print("Users before deletion:")
users = await get_all_users()
print(f"  Count: {len(users)}")
print()

print("Deleting user with id=2...")
await delete_user(2)
print()

print("Users after deletion:")
users = await get_all_users()
print(f"  Count: {len(users)}")
for user in users:
    print(f"  - {user.username}")
print()

## 8. Relationships and Joins

Work with foreign keys and related objects.

In [None]:
# Create a model for a user
async def create_model_for_user(
    user_id: int,
    name: str,
    framework: str,
    accuracy: float | None = None,
) -> MLModel:
    """Insert an ``MLModel`` owned by ``user_id`` and return it.

    Args:
        user_id: Primary key of the owning user (stored in ``owner_id``).
        name: Name of the model.
        framework: Framework label such as ``"sklearn"`` or ``"pytorch"``.
        accuracy: Optional score; ``None`` stores NULL.

    Returns:
        The persisted model, refreshed from the database after commit.
    """
    # NOTE(review): user_id is not checked here; whether a missing user is
    # rejected depends on the database enforcing the foreign key. Confirm.
    async with AsyncSessionLocal() as session:
        model = MLModel(
            name=name,
            framework=framework,
            accuracy=accuracy,
            owner_id=user_id,
        )
        session.add(model)
        await session.commit()
        await session.refresh(model)
        return model

# Create models for user 1
model1 = await create_model_for_user(1, "Fraud Detector", "sklearn", 0.95)
model2 = await create_model_for_user(1, "Recommendation Engine", "pytorch", 0.92)
print(f"✅ Created models:")
print(f"   - {model1}")
print(f"   - {model2}")
print()

In [None]:
# Load user with their models
async def get_user_with_models(user_id: int) -> User | None:
    """Get user with all their models."""
    async with AsyncSessionLocal() as session:
        stmt = select(User).where(User.id == user_id)
        result = await session.execute(stmt)
        user = result.scalars().first()
        
        if user:
            # Access related models (lazy loaded in sync, eager in async with joinedload)
            print(f"User: {user.username}")
            print(f"Models ({len(user.models)} total):")
            for model in user.models:
                print(f"  - {model.name} ({model.framework}) - Accuracy: {model.accuracy}")
        
        return user

await get_user_with_models(1)
print()

## 9. Filtering and Pagination

Query with WHERE clauses, ordering, and pagination.

In [None]:
# Query with filters
async def get_models_by_framework(framework: str) -> list[MLModel]:
    """Return all models using ``framework``, newest first.

    Models sharing a ``created_at`` value are ordered by descending ``id`` so
    the result order is deterministic.
    """
    async with AsyncSessionLocal() as session:
        stmt = (
            select(MLModel)
            .where(MLModel.framework == framework)
            .order_by(MLModel.created_at.desc(), MLModel.id.desc())
        )
        result = await session.execute(stmt)
        # ScalarResult.all() is typed as Sequence; return a list as annotated.
        return list(result.scalars().all())

models = await get_models_by_framework("sklearn")
print(f"Models using sklearn: {len(models)}")
for model in models:
    print(f"  - {model.name}")
print()

In [None]:
# Pagination
async def get_models_paginated(skip: int = 0, limit: int = 10) -> tuple[list[MLModel], int]:
    """Return one page of models (newest first) plus the total model count.

    Args:
        skip: Number of rows to skip (SQL OFFSET).
        limit: Maximum number of rows on the page (SQL LIMIT).

    Returns:
        ``(models, total)`` where ``models`` is the requested page and
        ``total`` is the number of models in the table.
    """
    async with AsyncSessionLocal() as session:
        # COUNT(*) always yields exactly one row, so scalar_one() is an int,
        # never None (unlike scalar()).
        count_stmt = select(func.count()).select_from(MLModel)
        total = (await session.execute(count_stmt)).scalar_one()

        # OFFSET/LIMIT paging needs a total order; without the id
        # tie-breaker, rows sharing a created_at could move between pages.
        stmt = (
            select(MLModel)
            .order_by(MLModel.created_at.desc(), MLModel.id.desc())
            .offset(skip)
            .limit(limit)
        )
        result = await session.execute(stmt)
        models = list(result.scalars().all())

    return models, total

models, total = await get_models_paginated(skip=0, limit=2)
print(f"Page 1 (limit=2): {len(models)} models out of {total} total")
for model in models:
    print(f"  - {model.name}")
print()

## 10. Advanced Queries

Complex queries with conditions and aggregations.

In [None]:
from sqlalchemy import and_, or_

# Complex filter conditions
async def get_high_accuracy_models(min_accuracy: float = 0.9) -> list[MLModel]:
    """Return models whose accuracy is at least ``min_accuracy``, best first.

    Args:
        min_accuracy: Inclusive lower bound on ``MLModel.accuracy``.

    Returns:
        Matching models ordered by descending accuracy.
    """
    async with AsyncSessionLocal() as session:
        stmt = (
            select(MLModel)
            .where(
                and_(
                    MLModel.accuracy >= min_accuracy,
                    # Redundant in SQL (a comparison with NULL is never true),
                    # but makes explicit that unscored models are excluded.
                    MLModel.accuracy.is_not(None),
                )
            )
            .order_by(MLModel.accuracy.desc())
        )
        result = await session.execute(stmt)
        # ScalarResult.all() is typed as Sequence; return a list as annotated.
        return list(result.scalars().all())

models = await get_high_accuracy_models(min_accuracy=0.90)
print(f"High accuracy models (≥0.90): {len(models)}")
for model in models:
    print(f"  - {model.name}: {model.accuracy}")
print()

## 11. Database Dependency for FastAPI

Create a dependency function for use in route handlers.

In [None]:
# Illustrates how this setup plugs into FastAPI. The snippets are printed as
# text only; FastAPI itself is not imported in this notebook.
ROUTE_EXAMPLE = """@app.get("/users/{user_id}")
async def get_user(user_id: int, db: AsyncSession = Depends(get_db)):
    stmt = select(User).where(User.id == user_id)
    result = await db.execute(stmt)
    user = result.scalars().first()
    if not user:
        raise HTTPException(status_code=404, detail="User not found")
    return user
"""

GET_DB_EXAMPLE = """async def get_db():
    async with AsyncSessionLocal() as session:
        yield session
"""

# Empty strings print as blank lines, matching bare print() calls.
for chunk in (
    "FastAPI Database Dependency Pattern:",
    "",
    ROUTE_EXAMPLE,
    "",
    "Where get_db() is defined as:",
    "",
    GET_DB_EXAMPLE,
):
    print(chunk)

## Summary

**Async SQLAlchemy Key Concepts:**
- ✅ create_async_engine() for non-blocking database access
- ✅ AsyncSessionLocal for session management
- ✅ select() for type-safe query building
- ✅ Mapped[] type hints for modern Python syntax
- ✅ Relationships with cascade delete
- ✅ CRUD operations with async/await

**Best Practices:**
1. Use context managers (async with) for automatic cleanup
2. Always await database operations
3. Refresh objects after commit to reload database-generated values (autoincrement primary keys are already assigned when the session flushes)
4. Use 2.0-style `select()` statements instead of the legacy `session.query()` API for type-safe queries
5. Define repositories/functions for database access
6. Use FastAPI's Depends() for dependency injection

**Common Patterns:**
- One session per request (Depends(get_db))
- Always close sessions with async context managers
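- Eager-load related objects (e.g. `selectinload()`): `AsyncSession` cannot lazy-load relationships and raises `MissingGreenlet` if you touch an unloaded one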
- Implement pagination for large result sets