In [None]:
# Import python packages
import streamlit as st
import pandas as pd

# We can also use Snowpark for our analyses!
# get_active_session() hands back the session that is already active in this
# environment, so no connection parameters are supplied here.
from snowflake.snowpark.context import get_active_session
session = get_active_session()


### Creating compute pools

In [None]:
-- Creating compute pools required for the service jobs.
-- Every CREATE in this cell uses IF NOT EXISTS so the cell can be re-run:
-- an existing pool no longer raises "already exists", and the JOBS stage is
-- not recreated (recreating it would discard files already uploaded to it).

use role accountadmin;

-- Single-node pool on the CPU_X64_XS instance family
CREATE COMPUTE POOL IF NOT EXISTS pr_std_pool_xs
  MIN_NODES = 1
  MAX_NODES = 1
  INSTANCE_FAMILY = CPU_X64_XS;

DESCRIBE COMPUTE POOL PR_STD_POOL_XS;

-- Pool on the CPU_X64_S instance family that may grow from 1 to 2 nodes
CREATE COMPUTE POOL IF NOT EXISTS PR_STD_POOL_S
  MIN_NODES = 1
  MAX_NODES = 2
  INSTANCE_FAMILY = CPU_X64_S;

show compute pools like 'PR_STD_POOL_S';

-- You can use any role that you have created instead of SPCS_PSE_ROLE
-- NOTE(review): a later cell resumes these pools; if that cell runs as
-- SPCS_PSE_ROLE it presumably needs OPERATE in addition to USAGE - confirm.
grant usage on compute pool pr_std_pool_xs to role SPCS_PSE_ROLE;
grant usage on compute pool pr_std_pool_s to role SPCS_PSE_ROLE;

use role SPCS_PSE_ROLE;

-- Stage that holds the job configuration file read by create_job_tasks
CREATE STAGE IF NOT EXISTS JOBS DIRECTORY = (
    ENABLE = true);

CREATE IMAGE REPOSITORY IF NOT EXISTS IMAGES;

show image repositories;

show compute pools like 'pr%';




In [None]:
-- List the compute pools whose names match the PR_% pattern
SHOW COMPUTE POOLS LIKE 'PR_%';

In [None]:
-- Bring both pools back up in case they are currently suspended
ALTER COMPUTE POOL PR_STD_POOL_XS RESUME;
ALTER COMPUTE POOL PR_STD_POOL_S RESUME;


### Creating logging tables and a UDTF for tracking the task status of a DAG

In [None]:
-- Logging individual job status (one row per ExecuteJobService call).
-- NOTE: CREATE OR REPLACE empties this table every time the cell is re-run.
create or replace table jobs_run_stats( root_task_name string, task_name string, job_status string,GRAPH_RUN_ID string , graph_start_time timestamp_ltz, errors string, created_date datetime default current_timestamp());

-- Tracking all tasks part of the task graph (filled by the finalizer task).
-- IF NOT EXISTS lets this cell be re-run without failing on the existing table.
create table if not exists task_logging_stats (GRAPH_RUN_GROUP_ID varchar, NAME varchar,  STATE varchar , RETURN_VALUE varchar,QUERY_START_TIME varchar,COMPLETED_TIME varchar, DURATION_IN_SECS INT,ERROR_MESSAGE VARCHAR);

-- UDTF for getting the task status for the graph - TASK_GRAPH_RUN_STATS
--   ROOT_TASK_ID : identifier of the root task whose graph runs are reported
--   START_TIME   : lower bound of the scheduled-time window; the upper bound is "now"
-- Returns one row per task run found by INFORMATION_SCHEMA.TASK_HISTORY in that
-- window, with both timestamps rendered as 'YYYY-MM-DD HH24:MI:SS' text and the
-- run duration expressed in seconds.
create or replace function TASK_GRAPH_RUN_STATS(ROOT_TASK_ID string, START_TIME timestamp_ltz)
 returns table (GRAPH_RUN_GROUP_ID varchar, NAME varchar,  STATE varchar , RETURN_VALUE varchar,QUERY_START_TIME varchar,COMPLETED_TIME varchar, DURATION_IN_SECS INT,
 ERROR_MESSAGE VARCHAR)
as
$$
select
        GRAPH_RUN_GROUP_ID,
        NAME,
        STATE,
        RETURN_VALUE,
        to_varchar(QUERY_START_TIME, 'YYYY-MM-DD HH24:MI:SS') as QUERY_START_TIME,
        to_varchar(COMPLETED_TIME,'YYYY-MM-DD HH24:MI:SS') as COMPLETED_TIME,
        -- alias matches the DURATION_IN_SECS column declared in RETURNS TABLE
        timestampdiff('seconds', QUERY_START_TIME, COMPLETED_TIME) as DURATION_IN_SECS,
        ERROR_MESSAGE
    from
        table(INFORMATION_SCHEMA.TASK_HISTORY(
              ROOT_TASK_ID => ROOT_TASK_ID ::string,
              SCHEDULED_TIME_RANGE_START => START_TIME::timestamp_ltz,
              SCHEDULED_TIME_RANGE_END => current_timestamp()
      ))
$$
;


### Creating SP for calling the SPCS service job which spins up the container, runs it and terminates it

This code does the following :

- Accepts the name of the service job to be created, the image to run, the compute pool on which the service job will be executed, the result table name that is passed to the container as an input, and the retry count, which controls how many times the code retries the container before terminating gracefully.

- For every service job execution, we track whether its status is DONE or FAILED and log it to the jobs_run_stats table, together with the error details of any service job failure.

- This SP is invoked from another SP create_job_tasks which creates the task DAG based on the job config file. 


In [None]:


use role SPCS_PSE_ROLE;

-- ExecuteJobService: runs one SPCS job service and records its outcome.
--   service_name : name of the job service (any existing service of that name is dropped first)
--   image_name   : container image placed in the job specification
--   pool_name    : compute pool the job runs in
--   table_name   : passed to the container as --result_table
--   retry_count  : number of extra attempts made when the first run ends as FAILED
-- Returns the final service status, or an 'Error Occured..' message when the
-- bookkeeping itself fails. Every call writes one row to jobs_run_stats.
CREATE OR REPLACE PROCEDURE ExecuteJobService(service_name VARCHAR, image_name VARCHAR, pool_name VARCHAR,table_name VARCHAR,retry_count INT)
RETURNS VARCHAR
LANGUAGE PYTHON
RUNTIME_VERSION = '3.8'
PACKAGES = ('snowflake-snowpark-python')
HANDLER = 'create_job_service'
AS
$$
import logging

# One logger for the whole handler so all messages share the same name and level.
logger = logging.getLogger("service-job")
logger.setLevel(logging.INFO)


def _sql_literal(value):
    """Render *value* as a Snowflake string literal (NULL for None).

    Backslashes are escaped before quotes: Snowflake treats a backslash inside a
    single-quoted constant as an escape character, so container logs holding
    either character would otherwise corrupt the INSERT statement.
    """
    if value is None:
        return "NULL"
    text = str(value).replace("\\", "\\\\").replace("'", "\\'")
    return "'" + text + "'"


def _fetch_failure_logs(session, service_name):
    """Return the logs of container 'main' (instance 0) of the service, or ''."""
    rows = session.sql(f'''
            select SYSTEM$GET_SERVICE_LOGS('{service_name}', 0, 'main')::string as logs;
            ''').collect()
    return rows[0]['LOGS'] or ''


# Function which invokes the execute service job
def execute_job(session, service_name, image_name, pool_name, table_name):
    """Run the job service once and return the status reported for it.

    A failure of EXECUTE JOB SERVICE is logged rather than raised, because the
    status looked up afterwards is what the caller acts on.
    """
    # Drop the existing service if it exists
    session.sql(f'''DROP SERVICE if exists {service_name}''').collect()
    sql_qry = f'''
                        EXECUTE JOB SERVICE
                        IN COMPUTE POOL {pool_name}
                        NAME={service_name}
                        FROM SPECIFICATION  
                        '
                        spec:
                         container:
                         - name: main
                           image: {image_name}
                           env:
                             SNOWFLAKE_WAREHOUSE: xs_wh
                           args:
                           - "--query=select current_time() as time,''hello''"
                           - "--result_table={table_name}"
                        ';
                    '''

    try:
        session.sql(sql_qry).collect()
    except Exception as e:
        logger.error(f"An error occurred running the app in the container: {e}")

    # Looked up after the try block instead of returning from `finally`:
    # a return inside `finally` silently discards any in-flight exception.
    status_rows = session.sql(f''' SELECT    parse_json(SYSTEM$GET_SERVICE_STATUS('{service_name}'))[0]['status']::string as Status 
                                ''').collect()
    return status_rows[0]['STATUS']


# This is the main function call invoked in the SP handler
# This function calls execute_job to run the container with all the parameters required.
def create_job_service(session, service_name, image_name, pool_name, table_name, retry_count):
    """Run the job (with retries), log the result to jobs_run_stats, return the status."""
    logger.info("job_service")
    job_status = ''
    current_graph_run_id = ''
    try:
        # A NULL retry_count is treated as "no retries".
        max_retries = retry_count or 0

        # Execute the job service
        logger.info(
            f"Executing the Job [{service_name}] on pool [{pool_name}]"
        )
        job_status = execute_job(session, service_name, image_name, pool_name, table_name)

        # Retry mechanism. The retry count value comes from the config file per job.
        # Retries start only after a FAILED run and stop as soon as a run is DONE.
        if job_status == 'FAILED':
            for attempt in range(1, max_retries + 1):
                logger.info(
                    f"Retrying Executing the Job [{service_name}] on pool [{pool_name}] - [{attempt}]  out of {retry_count} times "
                )
                job_status = execute_job(session, service_name, image_name, pool_name, table_name)
                if job_status == 'DONE':
                    break

        job_errors = _fetch_failure_logs(session, service_name) if job_status == 'FAILED' else ''

        # Getting the DAG Task details. SYSTEM$TASK_RUNTIME_INFO can only work inside a task.
        task_info = session.sql("""select
                                SYSTEM$TASK_RUNTIME_INFO('CURRENT_ROOT_TASK_NAME')
                                root_task_name,
                                SYSTEM$TASK_RUNTIME_INFO('CURRENT_TASK_NAME') 
                                task_name,
                                SYSTEM$TASK_RUNTIME_INFO('CURRENT_TASK_GRAPH_RUN_GROUP_ID')
                                run_id,
                                SYSTEM$TASK_RUNTIME_INFO('CURRENT_TASK_GRAPH_ORIGINAL_SCHEDULED_TIMESTAMP')  dag_start_time
                            
                            """).collect()[0]
        current_graph_run_id = task_info.RUN_ID

        # Inserting job status into logging table
        session.sql(f'''
        INSERT INTO jobs_run_stats 
        (root_task_name,task_name,graph_run_id ,job_status,graph_start_time, errors ,created_date)
        SELECT {_sql_literal(task_info.ROOT_TASK_NAME)}
        ,{_sql_literal(task_info.TASK_NAME)}
        ,{_sql_literal(current_graph_run_id)}
        ,{_sql_literal(job_status)}
        ,{_sql_literal(task_info.DAG_START_TIME)}
        ,{_sql_literal(job_errors)}
        ,current_timestamp()
        ''').collect()

        return job_status
    except Exception as e:
        logger.exception(f"An error occurred: {e}")

        job_errors = ''
        if job_status == 'FAILED':
            try:
                job_errors = _fetch_failure_logs(session, service_name)
            except Exception as log_error:
                logger.error(f"Could not read the service logs: {log_error}")
        # The return value points the caller at the errors column, so make sure
        # it holds something: fall back to the exception text.
        if not job_errors:
            job_errors = str(e)

        session.sql(f"""
           INSERT INTO jobs_run_stats(task_name,errors,graph_run_id,job_status,created_date)
           SELECT {_sql_literal(service_name)},
           {_sql_literal(job_errors)},
           {_sql_literal(current_graph_run_id)},
           {_sql_literal(job_status)},
           current_timestamp()
            
                    """).collect()

        return f'Error Occured.. Refer the job error column - {e}'

$$;


### SP to create Snowflake task DAG - Fan-out and Fan-in Workflow Implementation

The code has the logic which creates the fan-in and fan-out workflow and does the following tasks:
- Based on the config file passed during its invocation it will create the task DAG (fan-out and fan-in scenario) and calls the SP created above with parameters fetched from the config file. The root task DAG is scheduled to run every 59 mins.
- Every task execution has the dependency on a specific task. Example T1 is dependent on root_task, T2 is dependent on root and T3 is dependent on T1 which implements the dependency workflow that is required.
- This code creates a finalizer task which tracks the status of all the tasks( failure or Success) and logs it into the table task_logging_stats. 

In [None]:
use role SPCS_PSE_ROLE;

-- create_job_tasks: builds the task DAG described by a JSON config file.
--   file_name : URL of the config file (a JSON array with one object per task)
-- Creates the root task, one child task per config entry (each one calls
-- ExecuteJobService), and the GET_GRAPH_STATS finalizer task, then enables the
-- whole graph. Returns 'done'.
CREATE OR REPLACE PROCEDURE create_job_tasks(file_name string)
RETURNS string
LANGUAGE PYTHON
RUNTIME_VERSION = '3.8'
PACKAGES = ('snowflake-snowpark-python')
HANDLER = 'create_jobservice_tasks'
AS
$$
from snowflake.snowpark.files import SnowflakeFile
import json

# Keys every task entry of the config file must provide.
REQUIRED_TASK_KEYS = (
    'task_name',
    'after_task_name',
    'job_name',
    'image_name',
    'compute_pool_name',
    'table_name',
    'retry_count',
)


def _sql_literal(value):
    """Render *value* as a single-quoted Snowflake string literal."""
    text = str(value).replace("\\", "\\\\").replace("'", "\\'")
    return "'" + text + "'"


def _load_task_config(file_name):
    """Read the JSON config and fail early, with a clear message, if it is malformed."""
    with SnowflakeFile.open(file_name) as config_file:
        tasks = json.load(config_file)

    if not isinstance(tasks, list):
        raise ValueError(f"{file_name} must contain a JSON array of task definitions")

    for position, task in enumerate(tasks):
        missing = [key for key in REQUIRED_TASK_KEYS if key not in task]
        if missing:
            raise ValueError(f"Task entry {position} in {file_name} is missing keys: {missing}")
    return tasks


def create_jobservice_tasks(session, file_name):
    """Create the root task, the per-job child tasks and the finalizer task."""
    parent_task_name = 'root_task'

    # Read and validate the config before touching any task, so a bad file
    # does not leave a half-built graph behind.
    tasks = _load_task_config(file_name)

    parent_task_sql = f'''CREATE OR REPLACE TASK {parent_task_name} 
              USER_TASK_MANAGED_INITIAL_WAREHOUSE_SIZE = 'XSMALL' 
              SCHEDULE = '59 MINUTE' 
      AS
      SELECT CURRENT_TIMESTAMP() ;'''

    session.sql(parent_task_sql).collect()
    print(parent_task_sql)

    # Tasks are created in file order.
    # NOTE(review): presumably a predecessor must appear in the file before any
    # task that names it in after_task_name - confirm.
    for task in tasks:
        call_args = ','.join([
            _sql_literal(task['job_name']),
            _sql_literal(task['image_name']),
            _sql_literal(task['compute_pool_name']),
            _sql_literal(task['table_name']),
            str(int(task['retry_count'])),
        ])
        task_sql = (
            f"CREATE  OR REPLACE TASK {task['task_name']} "
            "  WAREHOUSE = xs_wh "
            f"  AFTER {task['after_task_name']}  "
            f" AS CALL ExecuteJobService({call_args})"
        )
        session.sql(task_sql).collect()

        print(task_sql)

    # This is the Finalize task which gets the status for every task part of the DAG and loads into task_logging_stats table
    session.sql(f"""
              create or replace task GET_GRAPH_STATS
  warehouse = 'xs_wh'
  finalize = '{parent_task_name}'
  as
    declare
      ROOT_TASK_ID string;
      START_TIME timestamp_ltz;
      
    begin
      ROOT_TASK_ID := (call SYSTEM$TASK_RUNTIME_INFO('CURRENT_ROOT_TASK_UUID'));

      START_TIME := (call SYSTEM$TASK_RUNTIME_INFO('CURRENT_TASK_GRAPH_ORIGINAL_SCHEDULED_TIMESTAMP'));

      -- Insert into the logging table
      INSERT INTO task_logging_stats(GRAPH_RUN_GROUP_ID , NAME ,  STATE  , RETURN_VALUE ,QUERY_START_TIME ,COMPLETED_TIME , DURATION_IN_SECS ,
                                      ERROR_MESSAGE 
                                    )
      SELECT * FROM TABLE(TASK_GRAPH_RUN_STATS(:ROOT_TASK_ID, :START_TIME))  where NAME !='GET_GRAPH_STATS';

    end;
              """
              ).collect()

    session.sql('alter task GET_GRAPH_STATS resume').collect()

    # Resume the root task together with every task that depends on it.
    session.sql(f'''SELECT SYSTEM$TASK_DEPENDENTS_ENABLE('{parent_task_name}')''').collect()

    return 'done'

$$;




In [None]:
-- Creating stage for config.json.
-- IF NOT EXISTS (rather than OR REPLACE) keeps the JOBS stage created earlier
-- in this notebook, together with any jobconfig.json already uploaded to it.
CREATE STAGE IF NOT EXISTS JOBS DIRECTORY = (
    ENABLE = true);

In [None]:
-- Upload the jobconfig.json file to the JOBS stage, then list the stage to confirm it is there
LIST @jobs;

### Calling the Scheduling workflow SP

In the below cell we are invoking the scheduling workflow SP which accepts the jobconfig file which has all the  config required for the tasks to be created. Below is an extract from the config file. 

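[{
    "task_name":"t_myjob_1",
    "compute_pool_name":"PR_STD_POOL_XS",
    "job_name":"myjob_1",
    "image_name":"<path of your image in the IMAGES repository>",
    "table_name":"results_1",
    "retry_count":0,
    "after_task_name":"root_task"
   }
]

Each entry must also provide `image_name`, because the SP reads it when building the task. Add one object per task, separated by commas.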

In [None]:

-- Build the task DAG from jobconfig.json on the JOBS stage (passed as a scoped file URL)
CALL create_job_tasks(BUILD_SCOPED_FILE_URL(@jobs, 'jobconfig.json'));


In [None]:
-- Inspect the task DAG created for root_task. The predecessor column shows the task each one depends on
SELECT *
FROM TABLE(INFORMATION_SCHEMA.TASK_DEPENDENTS(TASK_NAME => 'root_task', RECURSIVE => TRUE));

In [None]:
--  The statements below simulate a failure, so that we can check which errors are captured in the logging tables and test the retry logic behaviour

-- This should fail the job3
-- NOTE(review): presumably job3 writes to RESULTS_3, so changing its columns makes the container's insert fail - confirm
ALTER TABLE RESULTS_3 DROP COLUMN "'HELLO'";
ALTER TABLE RESULTS_3 ADD COLUMN CREATEDATE DATETIME ;




In [None]:
-- Ten most recently logged job runs
SELECT TOP 10 *
FROM jobs_run_stats
ORDER BY created_date DESC;


In [None]:
-- Ten most recently started task runs, as logged by the finalizer task
SELECT TOP 10 *
FROM task_logging_stats
ORDER BY CAST(QUERY_START_TIME AS DATETIME) DESC;



In [None]:
-- View the tasks that depend on root_task (non-recursive lookup)
SELECT *
FROM TABLE(INFORMATION_SCHEMA.TASK_DEPENDENTS(TASK_NAME => 'root_task', RECURSIVE => FALSE));

#### This is the SP to drop all the tasks whose root task name is root_task.

In [None]:

-- drop_job_tasks: tears down the DAG rooted at root_task.
-- Suspends root_task, drops every task reported by TASK_DEPENDENTS for it and
-- finally drops the GET_GRAPH_STATS finalizer task. Returns 'Done'.
-- Runs with caller's rights.
CREATE OR REPLACE PROCEDURE drop_job_tasks()
RETURNS string
LANGUAGE PYTHON
RUNTIME_VERSION = '3.8'
PACKAGES = ('snowflake-snowpark-python')
HANDLER = 'drop_tasks'
execute as caller
AS
$$
def drop_tasks(session):
    """Suspend the root task, then drop the whole graph and its finalizer."""
    # The root task is suspended before the graph is modified.
    session.sql('alter task root_task suspend').collect()

    graph_tasks = session.sql(''' select name
        from table(information_schema.task_dependents(task_name => 'root_task', recursive => true))''').collect()

    for task_row in graph_tasks:
        print(task_row.NAME)
        # IF EXISTS: a task that has already disappeared must not abort the clean-up.
        session.sql(f'drop task if exists {task_row.NAME}').collect()

    # IF EXISTS also covers the case where the finalizer was already dropped
    # (for example by the loop above, or by an earlier partial clean-up).
    session.sql('drop task if exists GET_GRAPH_STATS').collect()
    return 'Done'
$$;



In [None]:
-- Remove the whole task DAG
CALL drop_job_tasks();

### The code below shows how you can use a YAML specification file to create service jobs.

In [None]:


-- CreateJobService: runs a job service from the specification file
-- my_job_spec.yaml stored on the @specs stage.
--   service_name : name of the job service (any existing service of that name is dropped first)
--   pool_name    : compute pool the job runs in
-- Returns the service status, or 'Error Occured.. Refer the logs' when any step fails.
CREATE OR REPLACE PROCEDURE CreateJobService(service_name VARCHAR, pool_name VARCHAR)
RETURNS VARCHAR
LANGUAGE PYTHON
RUNTIME_VERSION = '3.8'
PACKAGES = ('snowflake-snowpark-python')
HANDLER = 'create_job_service'
AS
$$
import logging

logger = logging.getLogger("service-job")
logger.setLevel(logging.INFO)


def create_job_service(session, service_name, pool_name):
    """Drop, re-run and report the status of the job service."""
    try:
        # Drop the existing service if it exists
        session.sql(f'''DROP SERVICE if exists {service_name}''').collect()

        # Execute the job service
        session.sql(f'''
                        EXECUTE JOB SERVICE
                        IN COMPUTE POOL {pool_name}
                        NAME={service_name}
                        FROM @specs
                        SPECIFICATION_FILE='my_job_spec.yaml' 
                    ''').collect()

        # Get the status of the job service
        status_rows = session.sql(f'''
                                    SELECT parse_json(SYSTEM$GET_SERVICE_STATUS('{service_name}'))[0]['status']::string as Status 
                                ''').collect()
        return status_rows[0]['STATUS']
    except Exception as e:
        # The return value sends the caller to the logs, so the failure has to
        # actually be logged (with traceback) rather than printed to stdout.
        logger.exception(f"An error occurred: {e}")
        return 'Error Occured.. Refer the logs'

$$;