In [1]:
import json
import os
from typing import Union, Dict, Optional

import yaml
from snowflake.snowpark import Session



In [2]:
def create_session(config: Union[str, os.PathLike, Dict[str, str]],
                   connection: Optional[str] = None) -> Session:

    """Establishes a snowpark connection to snowflake.

    Uses connection parameters, passed via .json config file or directly in python dict.

    Args:
        config (str/os.PathLike/Dict) : (Relative/absolute) path to .json config file or
                            dict of connection(s) params.
        connection (str) : Specific key of preferred connection parameters held in config.
                           Defaults to None, meaning a single set of connection parameters
                           should be passed.

    Returns:
        snowflake.snowpark.Session

    Raises:
        KeyError: If `connection` is given but is not a key of the loaded config.

    """

    if isinstance(config, (str, os.PathLike)):  # File path passed
        # JSON is UTF-8 by specification, so do not depend on the platform's
        # locale encoding when reading the credentials file.
        with open(os.fspath(config), encoding='utf-8') as f:
            connection_parameters = json.load(f)
    else:  # Dict of connections passed
        connection_parameters = config

    if connection:  # A specific key passed specifying connection params in config
        if connection not in connection_parameters:
            # Only the available keys are listed; values may hold credentials.
            raise KeyError(
                f"Connection '{connection}' not found in config; "
                f"available keys: {list(connection_parameters)}"
            )
        connection_parameters = connection_parameters[connection]

    return Session.builder.configs(connection_parameters).create()

# Path to a .json credentials file (similar to snowSQL's config). The default lives in
# the current user's home directory rather than one developer's absolute path; set the
# SNOWPARK_CONFIG_PATH / SNOWPARK_CONNECTION environment variables to override it.
SNOWPARK_CONFIG_PATH = os.environ.get('SNOWPARK_CONFIG_PATH',
                                      os.path.expanduser('~/.snowpark/config.json'))
SNOWPARK_CONNECTION = os.environ.get('SNOWPARK_CONNECTION', 'SCS')

session = create_session(config = SNOWPARK_CONFIG_PATH,
                         connection = SNOWPARK_CONNECTION)

In [3]:
# Read the account-specific settings from setup.yaml (resolved relative to the
# notebook's working directory) into a plain Python object.
with open('setup.yaml') as setup_file:
    account_specs = yaml.safe_load(setup_file)

In [4]:
# Set account-based constants (every key below must be present in setup.yaml)

# Token injected into the service specs by update_spec
HUGGINGFACE_TOKEN = account_specs['HUGGINGFACE_TOKEN']

# Session context
SNOW_ROLE = account_specs['SNOW_ROLE']
SNOW_DATABASE = account_specs['SNOW_DATABASE']
SNOW_SCHEMA = account_specs['SNOW_SCHEMA']
SNOW_WAREHOUSE = account_specs['SNOW_WAREHOUSE']

# Stages and image repository
FILES_STAGE = account_specs['FILES_STAGE']
DATA_STAGE = account_specs['DATA_STAGE']
SPEC_STAGE = account_specs['SPEC_STAGE']
UDF_STAGE = account_specs['UDF_STAGE']
IMAGE_REPOSITORY = account_specs['IMAGE_REPOSITORY']

# Tables
CHAT_LOG_TABLE = account_specs['CHAT_LOG_TABLE']
SOURCE_TABLE_ID = account_specs['SOURCE_TABLE_ID']

In [5]:
# Point the session at the configured role, database, schema and warehouse,
# in that order.
context_steps = (
    (session.use_role, SNOW_ROLE),
    (session.use_database, SNOW_DATABASE),
    (session.use_schema, SNOW_SCHEMA),
    (session.use_warehouse, SNOW_WAREHOUSE),
)
for activate, target in context_steps:
    activate(target)

In [6]:
# DDL for the objects the rest of the setup relies on. Every statement uses
# IF NOT EXISTS; they are executed one at a time in the order listed.
setup_statements = (
    f"CREATE IMAGE REPOSITORY IF NOT EXISTS {IMAGE_REPOSITORY}",
    f"CREATE STAGE IF NOT EXISTS {FILES_STAGE} DIRECTORY = ( ENABLE = true ) encryption = (type = 'SNOWFLAKE_SSE')",
    f"CREATE STAGE IF NOT EXISTS {DATA_STAGE} DIRECTORY = ( ENABLE = true ) encryption = (type = 'SNOWFLAKE_SSE')",
    f"CREATE STAGE IF NOT EXISTS {SPEC_STAGE}",
    f"CREATE STAGE IF NOT EXISTS {UDF_STAGE}",
    f"""CREATE TABLE IF NOT EXISTS {CHAT_LOG_TABLE}
                (RUN_ID string,
                TIMESTAMP timestamp_ltz,
                USER_PROMPT string,
                ASSISTANT_RESPONSE string,
                SOURCE_DOCUMENTS variant)""",
)
for statement in setup_statements:
    session.sql(statement).collect()

# Look the repository up again and keep the URL of the first row returned;
# update_spec uses it as the prefix for container image references.
repository_rows = session.sql(f"SHOW IMAGE REPOSITORIES LIKE '{IMAGE_REPOSITORY}'").collect()
REPOSITORY_URL = repository_rows[0].repository_url

In [7]:
def update_spec(filepath):
    """Rewrite a spec YAML file in place with the account-specific values.

    Only the first container of the spec is updated:
      * its ``image`` becomes ``{REPOSITORY_URL}/{container name}``;
      * any of the known keys already present in its ``env`` mapping are
        overwritten with the matching notebook constant (no new keys are added).

    Volumes named ``stage`` / ``data`` (case-insensitive) get their ``source``
    pointed at ``@FILES_STAGE`` / ``@DATA_STAGE`` respectively.

    The ``env`` and ``volumes`` sections are treated as optional; a spec without
    them is still processed instead of raising ``KeyError``.

    Args:
        filepath: Path of the spec YAML file to update. The file is overwritten
            with the result of ``yaml.dump``, so the file is re-serialized rather
            than patched.

    Raises:
        KeyError / IndexError / TypeError: If the file does not have the
            ``spec -> containers[0] -> name`` structure.
    """
    # Load the YAML data into a Python dictionary
    with open(filepath, 'r') as yaml_file:
        print(f"Updating {filepath}")
        spec = yaml.safe_load(yaml_file)

    container = spec['spec']['containers'][0]
    container['image'] = f"{REPOSITORY_URL}/{container['name']}"

    # NOTE: HUGGINGFACE_TOKEN ends up in plain text in the rewritten spec file.
    env_overrides = {
        "HUGGINGFACE_TOKEN": HUGGINGFACE_TOKEN,
        "SNOW_ROLE": SNOW_ROLE,
        "SNOW_DATABASE": SNOW_DATABASE,
        "SNOW_SCHEMA": SNOW_SCHEMA,
        "SNOW_WAREHOUSE": SNOW_WAREHOUSE,
        "PRODUCT_ID": SOURCE_TABLE_ID,
    }
    env = container.get('env') or {}  # section may be absent or empty
    for name, value in env_overrides.items():
        if name in env:  # only overwrite keys the spec already declares
            env[name] = value

    for volume in spec['spec'].get('volumes') or []:  # section may be absent or empty
        volume_name = volume['name'].lower()
        if volume_name == 'stage':
            volume['source'] = f'@{FILES_STAGE}'
        elif volume_name == 'data':
            volume['source'] = f'@{DATA_STAGE}'

    with open(filepath, 'w') as yaml_file:
        # sort_keys=False keeps the key order as loaded
        yaml.dump(spec, yaml_file, default_flow_style=False, sort_keys=False)

In [9]:
ignore_files = ['setup.yaml', 'system_prompt.yaml', 'repo_meta.yaml'] # These are non-spec yamls to ignore
for root, dirs, files in os.walk("../"):
    # Prune hidden directories in place so os.walk does not descend into them
    # (e.g. .ipynb_checkpoints or .git); YAML copies found there are not specs
    # and must not be rewritten or uploaded.
    dirs[:] = [d for d in dirs if not d.startswith('.')]
    for x in files:
        # Compare whole file names: a suffix match would also skip files that
        # merely end with an ignored name, such as 'my_setup.yaml'.
        if not x.endswith('.yaml') or x in ignore_files:
            continue
        spec_path = f'{root}/{x}'
        update_spec(spec_path)
        # Upload the updated spec as-is (uncompressed), replacing any previous copy
        session.file.put(spec_path,
                         f'@{SPEC_STAGE}',
                         auto_compress = False,
                         overwrite = True)

Updating ../app/chat.yaml
Updating ../vllm/vllm.yaml
Updating ../weaviate/text2vec.yaml
Updating ../weaviate/weaviate.yaml
Updating ../weaviate/jupyter.yaml
