In [None]:
import sys
import zipfile


def clear_modules(prefix):
    """Drop every ``sys.modules`` entry whose name starts with *prefix*.

    ``snowflake.connector.nanoarrow_arrow_iterator`` is always kept even when it matches
    (presumably because that compiled extension must not be re-imported — confirm).
    A later ``import`` of any removed name loads it again from ``sys.path``.
    """
    preserved = "snowflake.connector.nanoarrow_arrow_iterator"
    # Snapshot the matching names first; sys.modules must not change while being iterated.
    doomed = [name for name in sys.modules if name.startswith(prefix) and name != preserved]
    for name in doomed:
        del sys.modules[name]


# Purge already-imported snowflake.core modules. The freshly extracted copy prepended
# to sys.path below is then what later `import snowflake.core...` statements resolve to.
clear_modules("snowflake.core")

snowflake_core_zip_file_name = "snowflake_core.zip"
stage_name = "STAGE_PYTHON_TEST_NOTEBOOK"
zip_download_dir = "/tmp/zip"
snowflake_core_extract_dir = "/tmp/expanded/snowflake_core"
snowflake_core_src_dir = f"{snowflake_core_extract_dir}/src"

from snowflake.snowpark.context import get_active_session

session = get_active_session()

# Download the zipped package from the stage; exactly one file is expected back.
stage_file_path = f"@{stage_name}/zip/{snowflake_core_zip_file_name}"
get_status = session.file.get(stage_file_path, zip_download_dir)
if len(get_status) != 1:
    raise Exception(
        f"not able to load the snowflake_core: expected 1 file from {stage_file_path}, got {len(get_status)}"
    )

with zipfile.ZipFile(f"{zip_download_dir}/{snowflake_core_zip_file_name}", "r") as zip_ref:
    zip_ref.extractall(snowflake_core_extract_dir)

# Put the extracted sources first on sys.path. Existing copies of the entry are removed
# first, so re-running this cell does not accumulate duplicates.
sys.path[:] = [path_entry for path_entry in sys.path if path_entry != snowflake_core_src_dir]
sys.path.insert(0, snowflake_core_src_dir)

def get_snowflake_version(cursor):
    """Return the server version string from ``CURRENT_VERSION()``, whitespace-stripped.

    Runs the query on *cursor* and reads the first column of the first row.
    """
    first_row = cursor.execute("SELECT CURRENT_VERSION()").fetchone()
    raw_version = first_row[0]
    return raw_version.strip()

def is_prod_version(version_str) -> bool:
    """Return True when ``version_str`` looks like a production Snowflake version.

    Production versions contain only digits and dots (e.g. ``"8.40.1"``). Non-production
    versions contain letters or other symbols. An empty or missing version string is
    not treated as production.
    """
    # Explicit early return so a real bool is produced for "" / None;
    # `version_str and all(...)` would hand back the falsy input itself.
    if not version_str:
        return False
    return all(character.isdigit() or character == "." for character in version_str)

# Production Snowflake versions get a real compute pool instance family. Any other
# version string falls back to "FAKE" (presumably a test-only family — confirm).
# The cursor is a context manager here, so it is closed once the version is read.
with session.connection.cursor() as version_cursor:
    snowflake_version = get_snowflake_version(version_cursor)

if is_prod_version(snowflake_version):
    compute_pool_instance_family = "CPU_X64_XS"
else:
    compute_pool_instance_family = "FAKE"

In [None]:
from datetime import timedelta
# from typing import List

from snowflake.snowpark import Session
from snowflake.snowpark.functions import col
from snowflake.core import Root, CreateMode
from snowflake.core.database import Database
from snowflake.core.schema import Schema
from snowflake.core.stage import Stage
from snowflake.core.table import Table, TableColumn, PrimaryKey
from snowflake.core.task import StoredProcedureCall, Task
from snowflake.core.task.dagv1 import DAGOperation, DAG, DAGTask
from snowflake.core.warehouse import Warehouse
from snowflake.core._common import CreateMode

from snowflake.core import Root

root = Root(session)
# Switch roles through the Snowpark session, as the later cells do with use_role and
# use_warehouse, instead of opening a raw cursor that is never closed.
root.session.use_role("ACCOUNTADMIN")

In [None]:
# Create (or replace) the database and schema that hold the table, stage and tasks
# created further down.
database = root.databases.create(
    Database(name="PYTHON_API_DB"),
    mode=CreateMode.or_replace,
)
schema = database.schemas.create(Schema(name="PYTHON_API_SCHEMA"), mode=CreateMode.or_replace)

In [None]:
# (Re)create PYTHON_API_TABLE with a non-nullable TEMPERATURE column and a LOCATION
# column whose nullability is left at the library default.
table = schema.tables.create(
    Table(
        name="PYTHON_API_TABLE",
        columns=[
            TableColumn(name="TEMPERATURE", datatype="int", nullable=False),
            TableColumn(name="LOCATION", datatype="string"),
        ],
    ),
    mode=CreateMode.or_replace,
)

In [None]:
# Fetch the table's current definition so it can be edited locally and pushed back.
table_details = table.fetch()

In [None]:
# Display the fetched table definition as a plain dict.
table_details.to_dict()

In [None]:
# Add a non-nullable "elevation" primary-key column to the fetched definition.
# The check keeps the cell idempotent, so a re-run does not append the column twice.
# Snowflake stores unquoted identifiers in upper case, hence the case-insensitive comparison.
if not any(column.name.upper() == "ELEVATION" for column in table_details.columns):
    table_details.columns.append(
        TableColumn(
            name="elevation",
            datatype="int",
            nullable=False,
            constraints=[PrimaryKey()],
        )
    )

In [None]:
# Push the edited definition (including the appended column) back to the table.
table.create_or_update(table_details)

In [None]:
# Re-fetch the table to confirm the new column is now part of its definition.
table.fetch().to_dict()

In [None]:
# Account-level warehouse collection, used to create and list warehouses below.
warehouses = root.warehouses

In [None]:
# Define a SMALL warehouse, create (or replace) it, then make it the session's warehouse.
python_api_wh = Warehouse(name="PYTHON_API_WH", warehouse_size="SMALL", auto_suspend=500)
warehouse = warehouses.create(python_api_wh, mode=CreateMode.or_replace)
root.session.use_warehouse("PYTHON_API_WH")

In [None]:
# Fetch the warehouse's current properties and display them as a dict.
warehouse_details = warehouse.fetch()
warehouse_details.to_dict()

In [None]:
# List warehouses matching "PYTHON_API_WH" and display the first match.
# next() raises StopIteration if nothing matches.
warehouse_list = warehouses.iter(like="PYTHON_API_WH")
result = next(warehouse_list)
result.to_dict()

In [None]:
# Recreate PYTHON_API_WH at LARGE size; or_replace swaps out the SMALL warehouse created above.
warehouse = root.warehouses.create(
    Warehouse(name="PYTHON_API_WH", warehouse_size="LARGE", auto_suspend=500),
    mode=CreateMode.or_replace,
)

In [None]:
# Show the size reported for the recreated warehouse.
# NOTE(review): the size was set via `warehouse_size`; confirm `size` is the intended
# attribute on the fetched model.
warehouse.fetch().size

In [None]:
from snowflake.core import Root
from snowflake.core.database import Database

# Create (or replace) a database named my_db.
my_db = Database(name="my_db")
root.databases.create(my_db, mode=CreateMode.or_replace)

In [None]:
from snowflake.core import Root
from snowflake.core.database import Database

# Fetch my_db and print its properties.
my_db = root.databases["my_db"].fetch()
print(my_db.to_dict())

In [None]:
from snowflake.core import Root

# Print the names of databases matching the LIKE pattern "my%".
# The loop variable is database_temp so the global `database` (PYTHON_API_DB) used later is not clobbered.
databases = root.databases.iter(like="my%")
for database_temp in databases:
    print(database_temp.name)

In [None]:
from snowflake.core import Root
from snowflake.core.schema import Schema

# Create (or replace) my_schema inside my_db.
my_schema = Schema(name="my_schema")
root.databases["my_db"].schemas.create(my_schema, mode=CreateMode.or_replace)

In [None]:
from snowflake.core import Root
from snowflake.core.schema import Schema

# Fetch my_schema and print its properties.
my_schema = root.databases["my_db"].schemas["my_schema"].fetch()
print(my_schema.to_dict())

In [None]:
from snowflake.core import Root

# Print the name of every schema in my_db (no LIKE filter).
schema_list = root.databases["my_db"].schemas.iter()
for schema_obj in schema_list:
    print(schema_obj.name)

In [None]:
from snowflake.core import Root
from snowflake.core.table import Table, TableColumn

# Define my_table with a non-nullable c1 int column and a c2 string column, then create it in my_db.my_schema.
my_table = Table(
    name="my_table",
    columns=[
        TableColumn(name="c1", datatype="int", nullable=False),
        TableColumn(name="c2", datatype="string"),
    ],
)
root.databases["my_db"].schemas["my_schema"].tables.create(my_table, mode=CreateMode.or_replace)

In [None]:
from snowflake.core import Root
from snowflake.core.table import Table

# Fetch my_table's definition and display it as a dict.
my_table = root.databases["my_db"].schemas["my_schema"].tables["my_table"].fetch()
my_table.to_dict()

In [None]:
from snowflake.core import Root
from snowflake.core.table import PrimaryKey, Table, TableColumn

my_table = root.databases["my_db"].schemas["my_schema"].tables["my_table"].fetch()
# Append the c3 primary-key column only if the fetched table does not have it yet.
# This keeps the cell idempotent, so a re-run does not submit a duplicate column.
if not any(column.name.upper() == "C3" for column in my_table.columns):
    my_table.columns.append(TableColumn(name="c3", datatype="int", nullable=False, constraints=[PrimaryKey()]))

my_table_res = root.databases["my_db"].schemas["my_schema"].tables["my_table"]
my_table_res.create_or_update(my_table)

In [None]:
from snowflake.core import Root

# Print the names of tables in my_db.my_schema matching "my%".
tables = root.databases["my_db"].schemas["my_schema"].tables.iter(like="my%")
for table_obj in tables:
    print(table_obj.name)

In [None]:
from snowflake.core import Root
from snowflake.core.table import Table

# Resource handle for my_table; the delete call is left commented out, so the table is kept.
my_table_res = root.databases["my_db"].schemas["my_schema"].tables["my_table"]
# my_table_res.delete()

In [None]:
from snowflake.core import Root
from snowflake.core.warehouse import Warehouse

# Define a SMALL warehouse with auto_suspend=600 (seconds, per Snowflake's AUTO_SUSPEND) and create (or replace) it.
warehouses = root.warehouses
my_wh = Warehouse(name="my_wh", warehouse_size="SMALL", auto_suspend=600)
warehouses.create(my_wh, mode=CreateMode.or_replace)

In [None]:
from snowflake.core import Root
from snowflake.core.warehouse import Warehouse

# Fetch my_wh and print its properties.
my_wh = root.warehouses["my_wh"].fetch()
print(my_wh.to_dict())

In [None]:
from snowflake.core import Root
from snowflake.core.warehouse import Warehouse, WarehouseCollection

# Print the names of warehouses matching "my%".
warehouses: WarehouseCollection = root.warehouses
wh_iter = warehouses.iter(like="my%")  # returns a PagedIter[Warehouse]
for wh_obj in wh_iter:
    print(wh_obj.name)

In [None]:
from snowflake.core import Root
from snowflake.core.warehouse import Warehouse

# Exercise the warehouse lifecycle operations on my_wh: suspend, resume, then abort all queries.
my_wh_res = root.warehouses["my_wh"]

my_wh_res.suspend()
my_wh_res.resume()
my_wh_res.abort_all_queries()
# my_wh_res.delete()

In [None]:
from snowflake.core.user import User

# Create (or replace) a user named my_user.
my_user = User(name="my_user")
root.users.create(my_user, mode=CreateMode.or_replace)

In [None]:
from snowflake.core.user import User

# Fetch my_user and display its properties as a dict.
my_user = root.users["my_user"].fetch()
my_user.to_dict()

In [None]:
# Print the names of users matching "my%".
users = root.users.iter(like="my%")
for user in users:
    print(user.name)

In [None]:
# Resource handle for my_user; the delete call is left commented out.
my_user_res = root.users["my_user"]
# my_user_res.delete()

In [None]:
from snowflake.core.role import Role

# Create (or replace) a role named my_role.
my_role = Role(name="my_role")
root.roles.create(my_role, mode=CreateMode.or_replace)

In [None]:
from snowflake.core.grant import Grant
from snowflake.core.grant._grantee import Grantees
from snowflake.core.grant._privileges import Privileges
from snowflake.core.grant._securables import Securables

# Grant ACCOUNTADMIN to my_role.
# NOTE(review): this gives the demo role full account privileges; only acceptable on a throwaway test account.
root.grants.grant(
    Grant(
        grantee=Grantees.role(name="my_role"),
        securable=Securables.role(name="accountadmin"),
    )
)
# Grant my_role to the user running this notebook.
root.grants.grant(
    Grant(
        grantee=Grantees.user(name=session.get_current_user()),
        securable=Securables.role(name="my_role"),
    )
)

# Remember the role in effect before switching, so it can be restored afterwards.
current_role = session.get_current_role()
root.session.use_role("my_role")

In [None]:
# Print the name of every role visible to the current role (no LIKE filter).
role_list = root.roles.iter()
for role_obj in role_list:
    print(role_obj.name)

In [None]:
from snowflake.core.grant import Grant
from snowflake.core.grant._grantee import Grantees
from snowflake.core.grant._privileges import Privileges
from snowflake.core.grant._securables import Securables

# Grant CREATE DATABASE and CREATE WAREHOUSE on the current account to my_role.
root.grants.grant(
    Grant(
        grantee=Grantees.role(name="my_role"),
        securable=Securables.current_account,
        privileges=[Privileges.create_database, Privileges.create_warehouse],
    )
)

In [None]:
from snowflake.core.stage import Stage, StageEncryption

# Create (or replace) my_stage in my_db.my_schema with encryption type SNOWFLAKE_SSE.
my_stage = Stage(name="my_stage", encryption=StageEncryption(type="SNOWFLAKE_SSE"))
stages = root.databases["my_db"].schemas["my_schema"].stages
stages.create(my_stage, mode=CreateMode.or_replace)

In [None]:
from snowflake.core.stage import Stage

# Fetch my_stage and display its properties as a dict.
my_stage = root.databases["my_db"].schemas["my_schema"].stages["my_stage"].fetch()
my_stage.to_dict()

In [None]:
from snowflake.core.stage import Stage, StageCollection

# Print the names of stages in my_db.my_schema matching "my%".
stages: StageCollection = root.databases["my_db"].schemas["my_schema"].stages
stage_iter = stages.iter(like="my%")  # returns a PagedIter[Stage]
for stage_obj in stage_iter:
    print(stage_obj.name)

In [None]:
from snowflake.core.compute_pool import ComputePool

# Single-node compute pool using the instance family chosen from the server version in the first cell.
# if_not_exists (not or_replace) leaves an already existing pool untouched.
compute_pool = ComputePool(
    name="my_compute_pool", min_nodes=1, max_nodes=1, instance_family=compute_pool_instance_family, auto_resume=False
)
# Use the CreateMode enum like every other create call, rather than its raw string value.
root.compute_pools.create(compute_pool, mode=CreateMode.if_not_exists)

In [None]:
from snowflake.core.compute_pool import ComputePool

# Fetch my_compute_pool and display its properties as a dict.
compute_pool = root.compute_pools["my_compute_pool"].fetch()
compute_pool.to_dict()

In [None]:
# Print the names of compute pools matching "my%".
# NOTE(review): the loop variable rebinds the global `compute_pool` fetched in the previous cell.
compute_pools = root.compute_pools.iter(like="my%")
for compute_pool in compute_pools:
    print(compute_pool.name)

In [None]:
from snowflake.core.compute_pool import ComputePoolResource

# Exercise compute pool lifecycle operations: suspend, resume, then stop all services on it.
compute_pool_res = root.compute_pools["my_compute_pool"]
compute_pool_res.suspend()
compute_pool_res.resume()
compute_pool_res.stop_all_services()

In [None]:
from snowflake.core.image_repository import ImageRepository

# Create (or replace) an image repository named my_repo in my_db.my_schema.
my_repo = ImageRepository("my_repo")
root.databases["my_db"].schemas["my_schema"].image_repositories.create(my_repo, mode=CreateMode.or_replace)

In [None]:
from snowflake.core.image_repository import ImageRepository

# Fetch my_repo and print its owner.
my_repo_res = root.databases["my_db"].schemas["my_schema"].image_repositories["my_repo"]
my_repo = my_repo_res.fetch()
print(my_repo.owner)

In [None]:
# Print the name of every image repository in my_db.my_schema (no LIKE filter).
repo_list = root.databases["my_db"].schemas["my_schema"].image_repositories.iter()
for repo_obj in repo_list:
    print(repo_obj.name)

In [None]:
from snowflake.core.image_repository import ImageRepositoryResource

# Resource handle for my_repo; the delete call is left commented out.
my_repo_res = root.databases["my_db"].schemas["my_schema"].image_repositories["my_repo"]
# my_repo_res.delete()

In [None]:
# from snowflake.core.service import Service, ServiceSpec

# my_service = Service(name="my_service", min_instances=1, max_instances=2, compute_pool="my_compute_pool", spec=ServiceSpec("@my_stage/my_service_spec.yaml"))
# root.databases["my_db"].schemas["my_schema"].services.create(my_service)

In [None]:
# from textwrap import dedent
# from snowflake.core.service import Service, ServiceSpec

# spec_text = dedent(f"""\
#     spec:
#       containers:
#       - name: hello-world
#         image: repo/hello-world:latest
#       endpoints:
#       - name: hello-world-endpoint
#         port: 8080
#         public: true
#     """)

# my_service = Service(name="my_service", min_instances=1, max_instances=2, compute_pool="my_compute_pool", spec=ServiceSpec(spec_text))
# root.databases["my_db"].schemas["my_schema"].services.create(my_service)

In [None]:
# from snowflake.core.function import FunctionArgument, ServiceFunction

# root.databases["my_db"].schemas["my_schema"].functions.create(
#   ServiceFunction(
#     name="my-udf",
#     arguments=[
#         FunctionArgument(name="input", datatype="TEXT")
#     ],
#     returns="TEXT",
#     service="hello-world",
#     endpoint="'hello-world-endpoint'",
#     path="/hello-world-path",
#     max_batch_rows=5,
#   ),
#   mode = CreateMode.or_replace
# )

In [None]:
# result = root.databases["my_db"].schemas["my_schema"].functions["my-udf(TEXT)"].execute_function(["test"])
# print(result)

In [None]:
# from snowflake.core.service import Service

# my_service = root.databases["my_db"].schemas["my_schema"].services["my_service"].fetch()

In [None]:
# Print the names of services in my_db.my_schema matching "abc%".
# NOTE(review): the other listing cells use "my%", and the service-creation cells above
# are commented out, so this presumably prints nothing — confirm the pattern is intended.
services = root.databases["my_db"].schemas["my_schema"].services.iter(like="abc%")
for service_obj in services:
    print(service_obj.name)

In [None]:
# from snowflake.core.service import ServiceResource

# my_service_res = root.databases["my_db"].schemas["my_schema"].services["my_service"]

# my_service_res.suspend()
# my_service_res.resume()
# status = my_service_res.get_service_status(10)

In [None]:
# Create (or replace) TASKS_STAGE in PYTHON_API_DB.PYTHON_API_SCHEMA; it is the
# stage_location for the stored-procedure tasks defined below.
# This rebinds `stages`, which previously pointed at my_db.my_schema's stages.
stages = root.databases[database.name].schemas[schema.name].stages
stages.create(Stage(name="TASKS_STAGE"), mode=CreateMode.or_replace)

In [None]:
def trunc(session: Session, from_table: str, to_table: str, count: int) -> str:
    """Copy at most ``count`` rows of ``from_table`` into a new table ``to_table``.

    No ordering is applied before the limit, so which rows are copied is unspecified.
    Returns a fixed success message.

    NOTE(review): no save mode is passed to ``save_as_table``, so a repeated run may fail
    once ``to_table`` exists (Snowpark's default mode is "errorifexists") — confirm for a
    scheduled task.
    """
    limited_rows = session.table(from_table).limit(count)
    limited_rows.write.save_as_table(to_table)
    return "Truncated table successfully created!"


def filter_by_shipmode(session: Session, mode: str) -> str:
    """Save up to 10 TPC-H SF100 ``lineitem`` rows with ``L_SHIPMODE == mode`` into ``filter_table``.

    ``filter_table`` is unqualified, so it resolves against the session's current
    database/schema. Returns a fixed success message.

    NOTE(review): as in ``trunc``, no save mode is passed, so a repeated run may fail once
    ``filter_table`` exists — confirm.
    """
    lineitem = session.table("snowflake_sample_data.tpch_sf100.lineitem")
    matching_rows = lineitem.filter(col("L_SHIPMODE") == mode).limit(10)
    matching_rows.write.save_as_table("filter_table")
    return "Filter table successfully created!"

In [None]:
# Fully qualified name of the TASKS_STAGE created above, used as both tasks' stage_location
# (presumably where the stored-procedure code is uploaded — confirm).
tasks_stage_name = f"{database.name}.{schema.name}.TASKS_STAGE"

# task1 runs `trunc` as a stored procedure on PYTHON_API_WH every minute.
# NOTE(review): no values are supplied here for trunc's from_table/to_table/count
# (nor for filter_by_shipmode's mode below); confirm how the procedure call binds them.
task1 = Task(
    name="task_python_api_trunc",
    definition=StoredProcedureCall(
        func=trunc,
        stage_location=f"@{tasks_stage_name}",
        packages=["snowflake-snowpark-python"],
    ),
    warehouse="PYTHON_API_WH",
    schedule=timedelta(minutes=1),
)

# task2 has no schedule of its own; it is chained after task1 through `predecessors`
# when the tasks are created.
task2 = Task(
    name="task_python_api_filter",
    definition=StoredProcedureCall(
        func=filter_by_shipmode,
        stage_location=f"@{tasks_stage_name}",
        packages=["snowflake-snowpark-python"],
    ),
    warehouse="PYTHON_API_WH",
)

In [None]:
# create the task in the Snowflake database
tasks = schema.tasks
root.session.use_warehouse("PYTHON_API_WH")

trunc_task = tasks.create(task1, mode=CreateMode.or_replace)

# should be the fully qualified name
# task2.predecessors = [f"{trunc_task.database.name}.{trunc_task.schema.name}.{trunc_task.name}"]
task2.predecessors = [trunc_task.name]
filter_task = tasks.create(task2, mode=CreateMode.or_replace)

In [None]:
# Resume the scheduled root task so it starts running (Snowflake creates tasks suspended).
trunc_task.resume()

In [None]:
# Print each task in PYTHON_API_SCHEMA with its current state.
taskiter = tasks.iter()
for t in taskiter:
    print("Name: ", t.name, "| State: ", t.state)

In [None]:
# Suspend the root task again so its one-minute schedule stops firing.
trunc_task.suspend()

In [None]:
# trunc_task.delete()
# filter_task.delete()

In [None]:
# Build a daily two-task DAG (trunc, then filter_by_shipmode) and deploy it to PYTHON_API_SCHEMA.
dag_name = "python_api_dag"
dag = DAG(name=dag_name, schedule=timedelta(days=1))
with dag:
    dag_task1 = DAGTask(
        name="task_python_api_trunc",
        definition=StoredProcedureCall(
            func=trunc, stage_location=f"@{tasks_stage_name}", packages=["snowflake-snowpark-python"]
        ),
        warehouse="PYTHON_API_WH",
    )
    dag_task2 = DAGTask(
        name="task_python_api_filter",
        definition=StoredProcedureCall(
            func=filter_by_shipmode, stage_location=f"@{tasks_stage_name}", packages=["snowflake-snowpark-python"]
        ),
        warehouse="PYTHON_API_WH",
    )
    # `>>` makes dag_task2 downstream of dag_task1.
    dag_task1 >> dag_task2
# Deploy the DAG's tasks, replacing any existing ones (or_replace).
dag_op = DAGOperation(schema)
dag_op.deploy(dag, mode=CreateMode.or_replace)

In [None]:
# Trigger a run of the deployed DAG now, independent of its daily schedule.
dag_op.run(dag)

In [None]:
# dag_op.delete(dag)

In [None]:
# warehouse.delete()

In [None]:
from snowflake.core.stage import Stage

# Resource handle for my_stage; the delete call is left commented out.
my_stage_res = root.databases["my_db"].schemas["my_schema"].stages["my_stage"]
# my_stage_res.delete()

In [None]:
# Switch back to the role captured before `use_role("my_role")`, then get a handle
# on my_role (its delete call is left commented out).
root.session.use_role(current_role)
my_role_res = root.roles["my_role"]
# my_role_res.delete()

In [None]:
from snowflake.core import Root
from snowflake.core.schema import Schema

# Resource handle for my_schema; the delete call is left commented out.
my_schema_res = root.databases["my_db"].schemas["my_schema"]
# my_schema_res.delete()

In [None]:
from snowflake.core import Root
from snowflake.core.database import Database

# Resource handle for my_db; the delete call is left commented out.
my_db_res = root.databases["my_db"]
# my_db_res.delete()

In [None]:
# database.delete()