In [0]:
#required packages

import pandas as pd, numpy as np, databricks.koalas as ks
from datetime import datetime, date, timedelta, tzinfo
import pytz
import pyspark.sql.functions as F
from pyspark import SparkConf, SparkContext
from pyspark.sql import SQLContext
from pyspark.sql.types import *
import matplotlib.pyplot as plt, seaborn as sns
from snowflake.sqlalchemy import URL
import sqlalchemy as sal
from sqlalchemy import create_engine
import snowflake.connector
%matplotlib inline
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization, hashes, hmac
import re
from databricks import sql
import logging as logger
from retrying import retry



In [0]:
# Notebook-wide configuration: Snowflake settings, secrets and ADLS locations.

# Snowflake: the secret scope is named after the owning user id.
user_nm = 'SPASU05'
scope_nm = f'SNF-DOPS-USER-DB-{user_nm}-SCP'

# Snowflake connection settings, one group of keys per environment
# (EDW_* and EDM_*), each named <ENV>_<role>_<setting>.
sf_connection_dict = dict(
    EDW_environment='edw',
    EDW_regular_url='abs_edw_prd.west-us-2.privatelink.snowflakecomputing.com',
    EDW_regular_database='DW_PRD',
    EDW_regular_schema='TEMP_DS',
    EDW_regular_warehouse='PROD_DATA_ANALYTICS_WH',
    EDW_regular_role='EDDM_DATA_SCIENTIST_GG',
    EDM_environment='edm',
    EDM_regular_url='abs_itds_prd.west-us-2.privatelink.snowflakecomputing.com',
    EDM_regular_database='EDM_FEATURES_PRD',
    EDM_regular_schema='SCRATCH_DS',
    EDM_regular_warehouse='EDM_DATASCIENCE_WH',
    EDM_regular_role='EDDM_DATA_SCIENTIST_GG',
)

# Snowflake credentials are pulled from the Databricks secret scope above;
# nothing sensitive is stored in the notebook itself.
snfKey = dbutils.secrets.get(scope=scope_nm, key="SnowEncPswdKey")
snfPass = dbutils.secrets.get(scope=scope_nm, key="SnowEncPswdPass")
snfUser = dbutils.secrets.get(scope=scope_nm, key="SnowUsername")

# ADLS / Databricks SQL endpoint: host, HTTP path, and the secret scope/key
# under which the access token is stored.
adls_connection_dict = dict(
    adls_server_hostname='adb-3055846038102621.1.azuredatabricks.net',
    http_path='sql/protocolv1/o/3055846038102621/0315-150338-eerie83',
    access_token_scope='scp_itds_devops',
    access_token_key='dasc-prod-databricks-01-dbx-key',
)

# db_work storage roots for the dev and prod environments.
pathADLSdev = 'abfs://absitdsdevwussa001@absitdsdevwussa001.dfs.core.windows.net/itds/dev/default/work/'
pathADLSprd = 'abfs://absitdsprodwussa001@absitdsprodwussa001.dfs.core.windows.net/default/work/datascience/'


In [0]:
def getKey(passwordKey, passPhrase):
  """Decrypt a PEM private key and return its body as one line of text.

  Parameters:
    passwordKey: PEM text of the passphrase-protected private key (str).
    passPhrase: passphrase that protects the key (str).

  Returns:
    The key re-serialised as unencrypted PKCS8 PEM, with the
    BEGIN/END PRIVATE KEY marker lines and every newline removed.
  """
  private_key = serialization.load_pem_private_key(
      passwordKey.encode(), password=passPhrase.encode(), backend=default_backend())

  pem_text = private_key.private_bytes(
      encoding=serialization.Encoding.PEM,
      format=serialization.PrivateFormat.PKCS8,
      encryption_algorithm=serialization.NoEncryption(),
  ).decode("UTF-8")

  # Drop the armour lines first, then flatten what is left onto a single line.
  armour_line = re.compile("-*(BEGIN|END) PRIVATE KEY-*\n")
  body_only = armour_line.sub("", pem_text)
  return body_only.replace("\n", "")

'''snowflake functions '''


def create_table_snowflake(df, snowflake_table, chunk_size=15000):
    """Create (or replace) a Snowflake table from a pandas DataFrame.

    The frame is uploaded in slices of ``chunk_size`` rows. Only the first
    slice uses ``if_exists='replace'``; every later slice is appended.
    (Previously each slice replaced the table, so only the final slice -
    possibly an empty one - was left in Snowflake.)

    Parameters:
      df: pandas DataFrame to upload.
      snowflake_table: name of the target table.
      chunk_size: number of rows written per ``to_sql`` call (default 15000).

    Raises:
      ValueError: if ``chunk_size`` is not positive.

    Note: the connection comes from ``snowflake_connector()``, which is not
    defined in this section of the notebook and must already exist.
    """
    if chunk_size <= 0:
        raise ValueError('chunk_size must be a positive integer')

    snow_engine = snowflake_connector()

    # Ceiling division; always at least one write so that an empty frame
    # still (re)creates the table, and never a trailing empty slice that
    # would follow real data.
    n_chunks = max(1, -(-len(df) // chunk_size))
    for i in range(n_chunks):
        start = i * chunk_size
        df.iloc[start:start + chunk_size].to_sql(
            snowflake_table,
            con=snow_engine,
            if_exists='replace' if i == 0 else 'append',
            index=False,
            index_label=None,
        )

        
@retry(stop_max_attempt_number=2)
def read_snowflake(env,sql, role='regular', sys='prd'):
  """Run a query against Snowflake and return the result as a Spark DataFrame.

  Parameters:
    env: 'edw' or 'edm'; selects which group of sf_connection_dict keys is used.
    sql: query text handed to the connector's "query" option.
    role: only 'regular' is configured.
    sys: accepted for backward compatibility; not used by this function.

  Returns:
    Spark DataFrame whose column names are all lower-cased.

  Raises:
    Exception: when no configuration exists for the env/role pair, or when
      the connection options cannot be built (missing dictionary key, key
      decryption failure, ...). The underlying cause is included in the
      message and chained instead of being reported as a missing
      configuration.

  Note: because of the retry decorator the whole function is attempted up to
  two times on any exception, configuration errors included.
  """
  if env == 'edw' and role == 'regular':
    prefix = 'EDW_regular_'
  elif env == 'edm' and role == 'regular':
    prefix = 'EDM_regular_'
  else:
    raise Exception('Configuration not found for environment {} and role {}.'.format(env, role))

  try:
    swOptions = dict(sfUrl=sf_connection_dict[prefix + 'url'],
                     sfUser=snfUser,
                     pem_private_key=getKey(snfKey, snfPass),
                     sfDatabase=sf_connection_dict[prefix + 'database'],
                     sfSchema=sf_connection_dict[prefix + 'schema'],
                     sfWarehouse=sf_connection_dict[prefix + 'warehouse'],
                     sfRole=sf_connection_dict[prefix + 'role'])
  except Exception as e:
    raise Exception('Could not build Snowflake connection options for environment {}: {}'.format(env, e)) from e

  df = spark.read\
    .format("snowflake")\
    .options(**swOptions)\
    .option("query", sql)\
    .load()
  # Normalise column names to lower case for downstream code.
  df = df.toDF(*[c.lower() for c in df.columns])
  return df

def write_snowflake(env, sql, table, write_mode, role='regular', sys='prd'):
  """Run a Spark SQL query and write its result to a Snowflake table.

  Parameters:
    env: 'edw' or 'edm'; selects which group of sf_connection_dict keys is used.
    sql: Spark SQL text; its result set is what gets written.
    table: target Snowflake table (connector option "dbtable").
    write_mode: Spark save mode passed to ``.mode()``.
    role: only 'regular' is configured.
    sys: accepted for backward compatibility; not used by this function.

  Returns:
    None.

  Raises:
    Exception: when no configuration exists for the env/role pair, or when
      the connection options cannot be built (missing dictionary key, key
      decryption failure, ...). The underlying cause is included in the
      message and chained instead of being reported as a missing
      configuration.
  """
  if env == 'edw' and role == 'regular':
    prefix = 'EDW_regular_'
  elif env == 'edm' and role == 'regular':
    prefix = 'EDM_regular_'
  else:
    raise Exception('Configuration not found for environment {} and role {}.'.format(env, role))

  try:
    swOptions = dict(sfUrl=sf_connection_dict[prefix + 'url'],
                     sfUser=snfUser,
                     pem_private_key=getKey(snfKey, snfPass),
                     sfDatabase=sf_connection_dict[prefix + 'database'],
                     sfSchema=sf_connection_dict[prefix + 'schema'],
                     sfWarehouse=sf_connection_dict[prefix + 'warehouse'],
                     sfRole=sf_connection_dict[prefix + 'role'])
  except Exception as e:
    raise Exception('Could not build Snowflake connection options for environment {}: {}'.format(env, e)) from e

  # The connector is told to truncate the target table and to skip the
  # staging table; see the Snowflake Spark connector docs for the exact
  # semantics of these two options.
  (spark.sql(sql)
     .write
     .mode(write_mode)
     .format("snowflake")
     .options(**swOptions)
     .option("truncate_table", "on")
     .option("usestagingtable", "off")
     .option("dbtable", table)
     .save()
  )

  return


def empty_snowflake_table_check(env, table):
  """Verify that a Snowflake table contains no rows.

  Parameters:
    env: Snowflake environment forwarded to read_snowflake ('edw' or 'edm').
    table: name of the table to probe.

  Returns:
    A confirmation message when the table is empty.

  Raises:
    Exception: when the table holds at least one row, or when the probe query
      fails (the underlying error is included in the message and chained).

  The previous version called read_snowflake(env, sf_connection_dict, sql),
  which put the settings dict in the ``sql`` slot and the query in ``role``,
  so the check could never succeed.
  """
  sql = """SELECT * FROM {} limit 100 """.format(table)
  try:
    record_existence = read_snowflake(env, sql).count()
  except Exception as e:
    raise Exception('Empty Table Pre-requsite not met for {}. Error: {}'.format(table, e)) from e

  print(record_existence)
  if record_existence != 0:
    raise Exception("""Empty Table Pre-requsite not met for {}. Table is not empty.""".format(table))
  return """Empty Table Pre-requsite met for {}""".format(table)
        


        
''' 
generic functions
convert spark df to koalas
'''

def read_sw_k(env, sql, role='regular'):
  """Run a Snowflake query via read_snowflake and return a Koalas DataFrame."""
  spark_frame = read_snowflake(env, sql, role)
  return spark_frame.to_koalas()

def read_sw_p(env, sql, role='regular'):
  """Run a Snowflake query via read_snowflake and return a pandas DataFrame."""
  spark_frame = read_snowflake(env, sql, role)
  return spark_frame.toPandas()

def list_to_string(x):
  """Sort the items of ``x``, convert each to str and join them with ', '."""
  ordered = sorted(x)
  return ', '.join(str(item) for item in ordered)




'''ADLS functions to access Hive tables stored as parquet '''


def get_connector(adls_connection_dict):
    """Open a Databricks SQL connection described by ``adls_connection_dict``.

    Parameters:
      adls_connection_dict: mapping with the keys 'adls_server_hostname',
        'http_path', 'access_token_scope' and 'access_token_key'.

    Returns:
      Whatever ``sql.connect`` returns for those settings.

    The server hostname is printed before connecting; the access token is
    read from the Databricks secret scope/key named in the mapping.
    """
    hostname = adls_connection_dict['adls_server_hostname']
    print(hostname)

    http_path = adls_connection_dict['http_path']
    token = dbutils.secrets.get(
        scope=adls_connection_dict['access_token_scope'],
        key=adls_connection_dict['access_token_key'],
    )
    return sql.connect(server_hostname=hostname, http_path=http_path, access_token=token)
  
def record_existence_check(table):
  """Verify that a Hive (db_work) table contains at least one row.

  Parameters:
    table: name of the table to count.

  Returns:
    A confirmation message when the row count is greater than zero.

  Raises:
    Exception: when the table has no rows, or when the count query fails
      (the underlying error is included in the message and chained).

  The emptiness check sits outside the try block so that its own exception
  is no longer caught and wrapped a second time, which used to repeat the
  "Pre-requsite not met" text inside the message.
  """
  try:
    record_existence = spark.sql(
        """SELECT count(*) as cnt FROM {}""".format(table)
    ).collect()[0][0]
  except Exception as e:
    raise Exception('Pre-requsite not met for {}. Error: {}'.format(table, e)) from e

  if record_existence > 0:
    return """Pre-requsite met for {}""".format(table)
  raise Exception("""Pre-requsite not met for {}. Record not found.""".format(table))
    
    
def update_adls_table_check(table, expected_date=None):
  """Check that a table's DESCRIBE DETAIL timestamp falls on the expected date.

  Parameters:
    table: name of the table to inspect.
    expected_date: date string in 'YYYY-MM-DD' form. When omitted, the
      notebook global ``Prediction_date`` is used; that global is not defined
      in this section of the notebook and must already exist.

  Returns:
    A confirmation message when the dates match.

  Raises:
    Exception: when the dates differ, or when the lookup fails (missing
      table, undefined Prediction_date, ...). The underlying error is
      included in the message and chained instead of being discarded.
  """
  try:
    if expected_date is None:
      expected_date = Prediction_date
    # Column 5 of the first DESCRIBE DETAIL row - presumably ``createdAt``;
    # TODO confirm that ``lastModified`` was not the intended column.
    detail_ts = spark.sql("""DESCRIBE DETAIL {}""".format(table)).collect()[0][5]
    check_dt = detail_ts.strftime('%Y-%m-%d')
  except Exception as e:
    raise Exception('Pre-requsite not met for {}. Error: {}'.format(table, e)) from e

  if check_dt == expected_date:
    return """Pre-requsite met for {}""".format(table)
  raise Exception("""Pre-requsite not met for {}. Table is not properly saved/replaced in adls gen 2 .""".format(table))

def write_append_parquet(path,name,df,write_mode):
  """Save a Spark DataFrame as parquet and register it as a db_work table.

  Parameters:
    path: storage root; ``name`` is appended to it to form the data location.
    name: relative name; '/' is replaced by '_' to form the table name.
    df: Spark DataFrame to write.
    write_mode: Spark save mode passed to ``.mode()``.
  """
  writer = df.write.format("parquet").mode(write_mode)
  writer = writer.option("path", f'{path}{name}')
  hive_table = f"db_work.{name.replace('/','_')}"
  writer.saveAsTable(hive_table)
    
def drop_hive_table(table):
    """Issue DROP TABLE IF EXISTS for ``table`` and return spark.sql's result."""
    return spark.sql('''DROP TABLE IF EXISTS {}'''.format(table))
  
  

In [0]:
# q = """select distinct w.division_id, w.warehouse_id, w.warehouse_nm,
# rog.rog_id, a.corp_item_cd as corporate_item_cd, desc_item as item_dsc, ctgry_cd as smic_category, pack_whse as ship_unit_pack_qty
# from scratch_ds.ia_cic_master a
# join edm_views_prd.dw_views.corporate_item c on (a.corp_item_cd = c.corporate_item_cd)
# join edm_views_prd.dw_views.supply_chain_item b
# on (c.corporate_item_integration_id = b.corporate_item_integration_id)
# join edm_views_prd.dw_views.warehouse w
# on (b.warehouse_id = w.warehouse_id and b.division_id = w.division_id)
# join edm_views_prd.dw_views.retail_order_group rog
# on (rog.division_id = w.division_id)
# where c.DW_CURRENT_VERSION_IND = TRUE
# and c.DW_LOGICAL_DELETE_IND = FALSE
# and b.DW_CURRENT_VERSION_IND = TRUE
# and b.DW_LOGICAL_DELETE_IND = FALSE
# and b.division_id = 32
# and w.distribution_center_id = 'WMEL'
# and w.DW_CURRENT_VERSION_IND = TRUE
# and w.DW_LOGICAL_DELETE_IND = FALSE
# and rog.DW_CURRENT_VERSION_IND = TRUE
# and rog.DW_LOGICAL_DELETE_IND = FALSE"""

# read_snowflake('edm', q, 'regular' ,'prd').count()