# Create a Lakebase database instance and the LangGraph checkpointing tables

In [0]:
# Install/upgrade the Databricks SDK and LangGraph with its Postgres checkpointer,
# then restart Python so the freshly installed versions are picked up.
# NOTE(review): langgraph is pinned to 0.3.4 but langgraph-checkpoint-postgres is
# not — consider pinning a compatible version so re-runs are reproducible.
%pip install --upgrade databricks-sdk
%pip install langgraph==0.3.4 langgraph-checkpoint-postgres
%restart_python

In [0]:
from databricks.sdk import WorkspaceClient
from databricks.sdk.errors import NotFound
from databricks.sdk.service.database import DatabaseInstance

# Lakebase database instance used throughout this notebook — change here.
INSTANCE_NAME = "stateful-agent-backend"

# Databricks SDK client, constructed with its default configuration/authentication.
w = WorkspaceClient()

# Identity of the user running the notebook; later cells use it as the
# Postgres user name when connecting to the instance.
current_user_row = spark.sql("SELECT current_user()").first()
username = current_user_row[0]

See the [Lakebase documentation](https://learn.microsoft.com/en-gb/azure/databricks/oltp/create/) for more details on creating a database instance.

In [0]:
# Create the database instance only if it does not exist yet. Looking it up
# first makes this cell safe to re-run (e.g. Restart & Run All): an existing
# instance is reused instead of the create call failing on a duplicate name.
try:
    instance = w.database.get_database_instance(name=INSTANCE_NAME)
    print(f"Using existing database instance: {instance.name}")
except NotFound:
    instance = w.database.create_database_instance(
        DatabaseInstance(
            name=INSTANCE_NAME,
            capacity="CU_1",
        )
    )
    print(f"Created database instance: {instance.name}")

### Note: the compute may take a little while to spin up
 - Once it is ready, we can check it by connecting and printing the Postgres version.

In [0]:
import time
import uuid

instance_name = INSTANCE_NAME

# Wait (bounded) until the instance reports AVAILABLE. Provisioning takes a
# while (see the note above), and without this wait "Run All" would go on to
# connect before the instance is usable. `state` is compared by its string value.
POLL_INTERVAL_S = 30
MAX_WAIT_S = 20 * 60
deadline = time.monotonic() + MAX_WAIT_S
instance = w.database.get_database_instance(name=instance_name)
while getattr(instance.state, "value", instance.state) != "AVAILABLE":
    if time.monotonic() >= deadline:
        raise TimeoutError(
            f"Database instance {instance_name!r} not AVAILABLE after "
            f"{MAX_WAIT_S}s (last state: {instance.state})"
        )
    print(f"Waiting for {instance_name!r} (state: {instance.state}) ...")
    time.sleep(POLL_INTERVAL_S)
    instance = w.database.get_database_instance(name=instance_name)

# Generate the login credential only once the instance is ready. Its token is
# used as the Postgres password in later cells.
# NOTE(review): these tokens are presumably short-lived — re-run this cell if
# later connections start failing with authentication errors.
cred = w.database.generate_database_credential(
    request_id=str(uuid.uuid4()),
    instance_names=[instance_name],
)

In [0]:
# Display the user name used as the Postgres user in the connections below.
# NOTE(review): this is likely the user's email address — clear the output
# before sharing the notebook.
username

In [0]:
import psycopg

# Sanity check: connect using the generated credential token as the password,
# then print the server version.
# This uses psycopg (v3), the same driver the table-check cells below use;
# psycopg2 is not among the packages installed at the top of the notebook.
# The context managers close the cursor and connection even if the query raises.
with psycopg.connect(
    host=instance.read_write_dns,
    dbname="databricks_postgres",
    user=username,
    password=cred.token,
    sslmode="require",
) as conn, conn.cursor() as cur:
    cur.execute("SELECT version()")
    version = cur.fetchone()[0]
print(version)

### Create the LangGraph checkpoint tables

In [0]:
# Display the instance's read/write host name — the host used by the
# connection string below.
instance.read_write_dns

In [0]:
from urllib.parse import quote

from langgraph.checkpoint.postgres import PostgresSaver

# Build the Postgres connection URI. The user name (e.g. containing '@') and
# the credential token are percent-encoded so that any URI-reserved characters
# cannot break the URI. For a plain email address this matches the previous
# manual '@' -> '%40' substitution.
DB_URI = (
    f"postgresql://{quote(username, safe='')}:{quote(cred.token, safe='')}"
    f"@{instance.read_write_dns}:5432/databricks_postgres?sslmode=require"
)

# from_conn_string is a context manager: it yields a connected PostgresSaver
# and closes the connection on exit. setup() creates the checkpoint tables.
with PostgresSaver.from_conn_string(DB_URI) as checkpointer:
    checkpointer.setup()

Check that the tables were created.

In [0]:
import psycopg

# Print the name of every table in the `public` schema; after setup() the
# LangGraph checkpoint tables should be listed here.
with psycopg.connect(DB_URI) as conn, conn.cursor() as cur:
    cur.execute("""
            SELECT table_name
            FROM information_schema.tables
            WHERE table_schema = 'public'
        """)
    tables = cur.fetchall()
    for (table_name,) in tables:
        print(table_name)

In [0]:
# Print each column of the `checkpoints` table as a tuple of
# (name, data type, nullability, default), in declaration order.
target_table = "checkpoints"
with psycopg.connect(DB_URI) as conn, conn.cursor() as cur:
    cur.execute("""
            SELECT column_name, data_type, is_nullable, column_default
            FROM information_schema.columns
            WHERE table_schema = 'public' AND table_name = %s
            ORDER BY ordinal_position
        """, (target_table,))
    columns = cur.fetchall()
    for column_row in columns:
        print(column_row)