# Connection

> Helps manage Snowpark connections to Snowflake

In [1]:
#| default_exp connection

In [2]:
#| hide
from nbdev.showdoc import *

In [3]:
#| export

import os
import logging
import yaml
import warnings

from typing import Optional, Dict
from snowflake.snowpark import Session
from snowflake.snowpark.context import get_active_session
from snowflake.snowpark.exceptions import SnowparkSessionException


# Quiet the Snowpark library's own logger: records below WARNING sent to
# 'snowflake.snowpark' (and child loggers without their own level) are dropped,
# so routine INFO/DEBUG output from the library does not flood the notebook.
logging.getLogger('snowflake.snowpark').setLevel(logging.WARNING)
 

In [4]:
#| export

class SnowparkConnection:
    """
    Manages Snowpark connection sessions, configuration, and lifecycle.

    On construction the connection options are taken from ``connection_config``
    when it is truthy; otherwise they are loaded from ``config_file`` and/or
    ``SNOWFLAKE_*`` environment variables. The instance then reuses the active
    Snowpark session if one can be obtained, or creates a new one.

    Attributes:
        connection_config (Dict[str, str]): Options passed to
            ``Session.builder.configs`` when a new session is created.
        session (Session): The Snowpark session managed by this instance.
    """

    def __init__(self, connection_config: Optional[Dict[str, str]] = None, config_file: str = 'snowflake_config.yaml'):
        """
        Resolve the connection configuration and obtain a session.

        Args:
            connection_config (Optional[Dict[str, str]]): Explicit connection
                options. ``None`` or an empty dict falls back to
                ``config_file`` / environment variables.
            config_file (str): Path to a YAML file with a top-level
                ``snowflake`` mapping.
        """
        # `or` (rather than an `is None` check) means an empty dict also
        # triggers the file/environment fallback.
        self.connection_config = connection_config or self.load_connection_config(config_file)
        self.session = self._get_active_or_new_session()

    def load_connection_config(self, yaml_file: str) -> Dict[str, str]:
        """
        Load the Snowflake connection configuration from a YAML file or environment variables.

        Values come from the ``snowflake`` mapping in ``yaml_file`` (if the file
        exists). Any of the standard keys that the file leaves missing or empty
        are filled from ``SNOWFLAKE_<KEY>`` environment variables. Keys that are
        set nowhere end up as ``None``, except ``role``, which defaults to
        ``'ACCOUNTADMIN'``.

        Args:
            yaml_file (str): The path to the YAML file.

        Returns:
            Dict[str, str]: The Snowflake connection configuration.

        Raises:
            yaml.YAMLError: If the file exists but is not valid YAML.
        """
        config: Dict[str, str] = {}
        if os.path.isfile(yaml_file):
            try:
                with open(yaml_file, 'r', encoding='utf-8') as file:
                    document = yaml.safe_load(file)
            except FileNotFoundError:
                # The file can disappear between the isfile() check and open().
                logging.warning(f"Configuration file '{yaml_file}' not found. Falling back to environment variables.")
            else:
                # safe_load returns None for an empty file, and the document or
                # its `snowflake` entry may be null / not a mapping; only a
                # mapping is usable, anything else falls back to the environment.
                section = document.get('snowflake') if isinstance(document, dict) else None
                if isinstance(section, dict):
                    config = dict(section)
                else:
                    logging.warning(f"No 'snowflake' mapping found in '{yaml_file}'. Falling back to environment variables.")

        # Fallback to environment variables if certain keys are missing or empty.
        # NOTE(review): with no role configured anywhere this silently selects
        # ACCOUNTADMIN, a highly privileged role; consider requiring an explicit role.
        config.update({
            'account': config.get('account') or os.getenv('SNOWFLAKE_ACCOUNT'),
            'user': config.get('user') or os.getenv('SNOWFLAKE_USER'),
            'password': config.get('password') or os.getenv('SNOWFLAKE_PASSWORD'),
            'role': config.get('role') or os.getenv('SNOWFLAKE_ROLE', 'ACCOUNTADMIN'),
            'warehouse': config.get('warehouse') or os.getenv('SNOWFLAKE_WAREHOUSE'),
            'database': config.get('database') or os.getenv('SNOWFLAKE_DATABASE'),
            'schema': config.get('schema') or os.getenv('SNOWFLAKE_SCHEMA')
        })
        return config

    def _get_active_or_new_session(self) -> Session:
        """
        Get the active Snowpark session or create a new one if none exists.

        If ``get_active_session`` raises ``SnowparkSessionException``, a new
        session is created from ``self.connection_config``.

        Returns:
            Session: The Snowpark session.
        """
        try:
            session = get_active_session()
        except SnowparkSessionException:
            return self.create_session()
        logging.info("Using active Snowpark session.")
        return session

    def create_session(self) -> Session:
        """
        Create a new Snowpark session using ``self.connection_config``.

        Returns:
            Session: The new Snowpark session.

        Raises:
            Exception: Whatever ``Session.builder...create()`` raised, re-raised
                unchanged after being logged.
        """
        try:
            session = Session.builder.configs(self.connection_config).create()
        except Exception as e:
            # Connect failures are not necessarily SnowparkSessionException
            # (e.g. errors from the underlying connector), so log every failure
            # here and re-raise it with its original type and traceback.
            logging.error(f"Error creating Snowpark session: {e}")
            raise
        logging.info("Snowpark session successfully created.")
        return session

    def get_session(self) -> Session:
        """
        Return the Snowpark session.

        Returns:
            Session: The Snowpark session.
        """
        return self.session

    def close_session(self) -> None:
        """
        Close the Snowpark session.

        ``SnowparkSessionException`` raised while closing is logged, not
        propagated. Note that ``self.session`` may be a pre-existing active
        session obtained via ``get_active_session``; it is closed all the same.
        """
        try:
            self.session.close()
        except SnowparkSessionException as e:
            logging.error(f"Error closing Snowpark session: {e}")


In [5]:
#| skip
from snowflake.snowpark.version import VERSION
from snowflake.snowpark.functions import col


# Open (or reuse) a Snowpark session against the CORTEX database / DEV schema.
# Account and credentials come from the environment; everything else is pinned.
connection = SnowparkConnection(
    connection_config=dict(
        user=os.getenv('SNOWFLAKE_USER'),
        password=os.getenv('SNOWFLAKE_PASSWORD'),
        account=os.getenv('SNOWFLAKE_ACCOUNT'),
        database='CORTEX',
        warehouse='CORTEX_WH',
        schema='DEV',
        role='CORTEX_USER_ROLE',  # Use the desired role
    )
)
session = connection.get_session()
session.sql_simplifier_enabled = True  # turn on Snowpark's SQL simplifier

# A single query yields both the current user and the server version.
snowflake_environment = session.sql('SELECT current_user(), current_version()').collect()
snowpark_version = VERSION

# Each label is paired with a zero-argument callable so that every value is
# fetched immediately before its line is printed (same order as inline prints).
report_lines = [
    ('User                        : ', lambda: snowflake_environment[0][0]),
    ('Role                        : ', session.get_current_role),
    ('Database                    : ', session.get_current_database),
    ('Schema                      : ', session.get_current_schema),
    ('Warehouse                   : ', session.get_current_warehouse),
    ('Snowflake version           : ', lambda: snowflake_environment[0][1]),
    ('Snowpark for Python version : ', lambda: '{}.{}.{}'.format(*snowpark_version[:3])),
]

print('\nConnection Established with the following parameters:')
for label, fetch_value in report_lines:
    print(f'{label}{fetch_value()}')



Connection Established with the following parameters:
User                        : JD_SERVICE_ACCOUNT_ADMIN
Role                        : "CORTEX_USER_ROLE"
Database                    : "CORTEX"
Schema                      : "DEV"
Warehouse                   : "CORTEX_WH"
Snowflake version           : 8.31.1
Snowpark for Python version : 1.19.0


In [7]:
#| hide
# Run nbdev's export step so the `#| export` cells are written out to the library package.
import nbdev
nbdev.nbdev_export()