In [None]:
# Install the Redshift dialect for SQLAlchemy and SQLAlchemy itself, using an
# explicit path to the pip binary under the `rgsutils` conda env directory.
# NOTE(review): both installs are unpinned. This notebook passes plain SQL strings
# to execute(), which SQLAlchemy 2.0 no longer accepts -- consider pinning
# SQLAlchemy<2.0 here; confirm against the environment actually in use.
!/home/ec2-user/SageMaker/rgs/miniconda/envs/rgsutils/bin/pip install sqlalchemy-redshift
!/home/ec2-user/SageMaker/rgs/miniconda/envs/rgsutils/bin/pip install SQLAlchemy

# Reference notes on using SQLAlchemy from Python:
# https://www.pythonsheets.com/notes/python-sqlalchemy.html

In [1]:
import sqlalchemy as sa
import urllib.parse
import multiprocessing as mp
import uuid 
from pgdb import connect
from iamconnectioninfo import IamConnection
from benchmarkloadrunner import update_task_status

# Connection details come from the IamConnection helper; nothing is hard-coded here.
iamconnectioninfo = IamConnection()

# SQLAlchemy URL for the PyGreSQL driver. User name and password are URL-quoted so
# that reserved characters in them cannot break the URL.
conn_string = 'postgresql+pygresql://{user}:{password}@{host}/{database}'.format(
    user=urllib.parse.quote_plus(iamconnectioninfo.username),
    password=urllib.parse.quote_plus(iamconnectioninfo.password),
    host=iamconnectioninfo.hostname_plus_port,
    database=iamconnectioninfo.db,
)

# Every schema this notebook manages: each benchmark suite at each scale factor,
# all TPC-DS scales first, then all TPC-H scales.
managed_schema = tuple(
    {'schema_name': f'{suite}_{size}'}
    for suite in ('tpcds', 'tpch')
    for size in ('1gb', '10gb', '100gb', '1tb', '3tb', '10tb', '30tb')
)

# Location of the benchmark project and of the DDL script for each suite.
working_dir = '/home/ec2-user/SageMaker/derived-tpcds-tpch-benchmarks'
tpcds_ddl_file = working_dir + '/ddl/tpcds-ddl.sql'
tpch_ddl_file = working_dir + '/ddl/tpch-ddl.sql'

def clean_all_managed_datasets():
    """Drop every schema listed in ``managed_schema`` and recreate it empty.

    Steps, all on a single connection built from ``conn_string``:

    1. ``DROP SCHEMA IF EXISTS ... CASCADE`` for each managed schema, in one
       transaction.
    2. Print the schemas owned by the current user, then a separator line.
    3. ``CREATE SCHEMA IF NOT EXISTS ...`` for each managed schema, in one
       transaction.
    4. Print the schemas owned by the current user again.

    This is destructive: CASCADE removes every object inside the schemas.
    Any database error propagates to the caller; the open transaction is
    rolled back and the connection is released before it does.
    """
    owned_schemas_sql = 'SELECT nspname FROM pg_namespace WHERE nspowner=(select usesysid from pg_user where usename=current_user);'

    def _print_owned_schemas(connection):
        # One printed row per schema owned by the connected user.
        for row in connection.execute(owned_schemas_sql):
            print(row)

    engine = sa.create_engine(conn_string)
    # The context manager releases the connection even when a statement fails.
    with engine.connect() as conn:
        # Statements are issued on `conn` (not on `engine`) so that they really run
        # inside the transaction opened here; the block commits on success and
        # rolls back on error.
        with conn.begin():
            for schema in managed_schema:
                conn.execute(f"""DROP SCHEMA IF EXISTS {schema.get('schema_name')} CASCADE""")

        _print_owned_schemas(conn)
        print('=============')

        with conn.begin():
            for schema in managed_schema:
                conn.execute(f"""CREATE SCHEMA IF NOT EXISTS {schema.get('schema_name')} """)

        _print_owned_schemas(conn)
    
def get_sql_from_file(file):
    """Yield the SQL statements found in a ``.sql`` script, one at a time.

    The file is read line by line. Blank lines and lines whose first
    non-blank characters are ``--`` are skipped. The remaining lines are
    stripped of surrounding whitespace and accumulated; when a line ends with
    ``;`` the accumulated lines are joined with single spaces and yielded as
    one statement (terminating ``;`` included).

    Limitations: a trailing ``-- comment`` after code on the same line is not
    removed, and text after the last ``;`` in the file is never yielded.

    Args:
        file: Path of the SQL script to read.

    Yields:
        str: One complete SQL statement per iteration.

    Raises:
        OSError: If the file cannot be opened.
    """
    pending_lines = []

    # `with` closes the handle once the generator is exhausted or discarded.
    with open(file, 'r') as sql_file:
        # Iterate over the file object itself to get lines; iterating over
        # `.read()` would walk the text one character at a time.
        for raw_line in sql_file:
            text = raw_line.strip()

            # Ignore blank lines and full-line comments.
            if not text or text.startswith('--'):
                continue

            pending_lines.append(text)

            # A line ending with ';' completes the current statement.
            if text.endswith(';'):
                yield ' '.join(pending_lines)
                pending_lines = []
                
def load_all_ddl():
    """Run the benchmark DDL script inside every schema of ``managed_schema``.

    For each managed schema the search path is pointed at that schema and
    every statement from the matching DDL script is executed: schemas whose
    name starts with ``tpcds`` use ``tpcds_ddl_file``, all others use
    ``tpch_ddl_file``. All schemas are processed in one transaction, so a
    failing statement rolls back everything done by this call.

    After the commit, prints one ``(schema, table count)`` row per schema for
    the tables owned by the current user.

    Any database or file error propagates to the caller; the connection is
    released before it does.
    """
    engine = sa.create_engine(conn_string)
    # The context manager releases the connection even when a statement fails.
    with engine.connect() as conn:
        # Commits on success, rolls back if any DDL statement raises.
        with conn.begin():
            for schema in managed_schema:
                schema_name = schema.get('schema_name')
                # Unqualified table names in the DDL scripts resolve to this schema.
                conn.execute(f"""SET search_path TO {schema_name}""")

                ddl_file = tpcds_ddl_file if schema_name.startswith('tpcds') else tpch_ddl_file
                for stmt in get_sql_from_file(ddl_file):
                    conn.execute(stmt)

        # Summary: number of user-owned tables now present in each schema.
        result = conn.execute("""SELECT nspname, count(c.*) num_tables FROM pg_class c JOIN pg_namespace n on n.oid = c.relnamespace and n.nspowner = c.relowner where c.relkind='r' and c.relowner=(select usesysid from pg_user where usename=current_user) group by 1 order by 1;""")
        for row in result:
            print(row)


def load_managed_dataset(conn, dataset, scale, clean=False, num_worker_process=2):
    """Load every table of one benchmark dataset using parallel worker processes.

    The table names for ``dataset`` are placed on a joinable queue and
    ``num_worker_process`` processes running ``load_worker`` consume it. The
    call blocks until every queued table has been marked done, then stops the
    workers.

    Args:
        conn: Object with an ``execute(sql)`` method; only used when ``clean``
            is true, to truncate the target tables.
        dataset: ``'tpcds'`` or ``'tpch'``.
        scale: Scale suffix of the target schema (for example ``'1gb'``); the
            schema name is ``f'{dataset}_{scale}'``.
        clean: When true, ``TRUNCATE`` every target table before loading.
        num_worker_process: Number of worker processes to start (default 2);
            also recorded as ``task_concurrency`` in the task status.

    Returns:
        The value returned by ``update_task_status`` when the task is
        registered as 'inflight' (used by callers as the task uuid).

    Raises:
        ValueError: If ``dataset`` is not a known dataset or
            ``num_worker_process`` is smaller than 1.

    NOTE(review): if a worker raises while loading a table, that table is
    never marked done and the wait on the queue does not return -- confirm
    whether a timeout or failure status is wanted.
    """
    tables_by_dataset = {
        'tpcds': [
            'store_sales', 'catalog_sales', 'web_sales', 'web_returns',
            'store_returns', 'catalog_returns', 'call_center', 'catalog_page',
            'customer_address', 'customer', 'customer_demographics', 'date_dim',
            'household_demographics', 'income_band', 'inventory', 'item',
            'promotion', 'reason', 'ship_mode', 'store', 'time_dim', 'warehouse',
            'web_page', 'web_site'
        ],
        'tpch': [
            'nation', 'region', 'part', 'supplier', 'partsupp', 'customer',
            'orders', 'lineitem'
        ],
    }

    # Validate before touching the database or starting processes.
    if dataset not in tables_by_dataset:
        raise ValueError(
            f'unknown dataset {dataset!r}; expected one of {sorted(tables_by_dataset)}')
    if num_worker_process < 1:
        # With no workers nothing would ever drain the queue.
        raise ValueError(f'num_worker_process must be >= 1, got {num_worker_process!r}')

    tables = tables_by_dataset[dataset]
    schema = f'{dataset}_{scale}'

    if clean:
        for table in tables:
            conn.execute(f'TRUNCATE {schema}.{table}')

    table_queue = mp.JoinableQueue()
    for table in tables:
        table_queue.put(table)

    # Register the task as in flight before any worker starts.
    task_status = {'type': 'insert', 'sql': {'task_name': f'load_{dataset}', 'task_version': '1.0',
                                                 'task_path': '/home/ec2-user/SageMaker/derived-tpcds-tpch-benchmarks/scripts/CleanManagedDatasets.ipynb',
                                                 'task_concurrency': num_worker_process, 'task_status': 'inflight', }}
    task_uuid = update_task_status(task_status)

    processes = []
    try:
        for i in range(num_worker_process):
            worker_process = mp.Process(target=load_worker,
                                        args=(table_queue, dataset, scale, task_uuid),
                                        daemon=True,
                                        name=f'{dataset}_worker_process_{i}',
            )
            worker_process.start()
            processes.append(worker_process)

        # Returns once every queued table has been acknowledged with task_done().
        table_queue.join()
    finally:
        # Workers loop forever waiting on the queue, so they must be stopped
        # explicitly; otherwise each call leaves idle processes behind in the kernel.
        for worker_process in processes:
            worker_process.terminate()
            worker_process.join()

    return task_uuid

def clean_orphan_jobs():
    """Not implemented yet: does nothing and returns None.

    Presumably meant to remove tasks left in the 'inflight' state; the psql
    session kept below shows the manual cleanup used in the meantime.
    """
    # postgres=# delete from task_load_status where task_uuid in (select task_uuid from task_status where task_status='inflight');
    # DELETE 2
    # postgres=# delete from task_status where task_status='inflight';
    # DELETE 1
    return None

def load_worker(queue, data_set, scale, task_uuid):
    """Worker loop: COPY each queued table from S3 and record its load status.

    Runs forever; the parent process is expected to stop it. For every table
    name taken from ``queue`` it:

    1. inserts an 'inflight' row (with the COPY text) into ``task_load_status``
       in the local ``postgres`` status database;
    2. connects to the target database described by ``iamconnectioninfo``,
       sets the search path to ``<data_set>_<scale>``, runs the COPY and
       collects the COPY query id, the row count and the block count;
    3. updates this task's status row to 'complete' with those figures;
    4. calls ``queue.task_done()``.

    Args:
        queue: Joinable queue of table names to load.
        data_set: Dataset name; used in the schema name and the S3 path.
        scale: Scale suffix; used in the schema name and, upper-cased, in the
            S3 path.
        task_uuid: Identifier of the parent task; stored with, and used to
            select, the status rows written here.

    NOTE(review): an exception while loading a table ends this worker without
    calling ``task_done()`` for it, so a parent waiting on ``queue.join()``
    does not return -- confirm the intended failure handling.
    """
    # Local import so the block does not depend on a file-level import.
    from contextlib import closing

    # Loop invariants.
    schema = '{}_{}'.format(data_set, scale)
    bucket = 'redshift-managed-loads-datasets-us-east-1'
    task_id = str(task_uuid)

    while True:
        tbl = queue.get()
        print('Processing %s (MP: %s) ' % (tbl, mp.current_process().name))

        copy_sql = f"COPY {tbl} FROM 's3://{bucket}/dataset={data_set}/size={scale.upper()}/table={tbl}/{tbl}.manifest' iam_role '{iamconnectioninfo.iamrole}' gzip delimiter '|' COMPUPDATE OFF MANIFEST"

        # `with closing(...) as conn, conn:` -- the inner context commits (or rolls
        # back) the transaction, the outer one closes the connection. The
        # connection's own context manager only ends the transaction, so without
        # `closing` every table would leave connections open.
        with closing(connect(dbname='postgres', host='0.0.0.0', user='postgres')) as conn, conn:
            cursor = conn.cursor()
            # Values are passed as parameters so the driver quotes them (the COPY
            # text itself contains single quotes).
            cursor.execute(
                "INSERT INTO task_load_status(task_uuid,tablename,dataset,status,load_start,querytext) "
                "values(%s,%s,%s,'inflight',timezone('utc', now()),%s)",
                (task_id, tbl, schema, copy_sql))

        with closing(connect(database=iamconnectioninfo.db,
                             host=iamconnectioninfo.hostname_plus_port,
                             user=iamconnectioninfo.username,
                             password=iamconnectioninfo.password)) as conn, conn:
            cursor = conn.cursor()
            # Identifiers cannot be bound as parameters, hence string formatting.
            cursor.execute('set search_path to %s' % (schema))
            cursor.execute(copy_sql)

            # Each query returns one row with one column; read the value directly
            # instead of filtering digits out of the row's text form (which would
            # also drop a minus sign).
            cursor.execute('select pg_last_copy_id()')
            query_id = int(cursor.fetchone()[0])
            cursor.execute('select count(*) from %s' % (tbl))
            row_count = int(cursor.fetchone()[0])
            cursor.execute(
                'select count(*) from stv_blocklist where tbl=\'%s.%s\'::regclass::oid'
                % (schema, tbl))
            block_count = int(cursor.fetchone()[0])

        with closing(connect(dbname='postgres', host='0.0.0.0', user='postgres')) as conn, conn:
            cursor = conn.cursor()
            # Restrict the update to this task's row; matching on table and dataset
            # alone would also overwrite rows written by earlier runs.
            cursor.execute(
                "UPDATE task_load_status SET status='complete',load_end=timezone('utc', now()), "
                "query_id=%s,rows_d=%s, size_d=%s "
                "WHERE task_uuid=%s and tablename=%s and dataset=%s",
                (query_id, row_count, block_count, task_id, tbl, schema))

        queue.task_done()

In [2]:
# Drop every managed schema (CASCADE) and recreate it empty. Destructive: any
# benchmark data previously loaded into these schemas is removed.
clean_all_managed_datasets()
# Run the TPC-DS / TPC-H DDL scripts inside each schema; this must come after the
# schemas have been recreated by the call above.
load_all_ddl()

('tpcds_1gb',)
('tpcds_10gb',)
('tpcds_100gb',)
('tpcds_1tb',)
('tpcds_3tb',)
('tpcds_10tb',)
('tpcds_30tb',)
('tpch_1gb',)
('tpch_10gb',)
('tpch_100gb',)
('tpch_1tb',)
('tpch_3tb',)
('tpch_10tb',)
('tpch_30tb',)
('tpcds_100gb', 24)
('tpcds_10gb', 24)
('tpcds_10tb', 24)
('tpcds_1gb', 24)
('tpcds_1tb', 24)
('tpcds_30tb', 24)
('tpcds_3tb', 24)
('tpch_100gb', 8)
('tpch_10gb', 8)
('tpch_10tb', 8)
('tpch_1gb', 8)
('tpch_1tb', 8)
('tpch_30tb', 8)
('tpch_3tb', 8)


In [None]:
# Load TPC-DS 1GB: truncate the target tables, load them, then mark the task complete.
engine = sa.create_engine(conn_string)
task_uuid = load_managed_dataset(engine, dataset='tpcds', scale='1gb', clean=True)
task_status = dict(type='update', uuid=task_uuid, sql=dict(task_status='complete'))
update_task_status(task_status)

In [None]:
# Load TPC-DS 10GB: truncate the target tables, load them, then mark the task complete.
engine = sa.create_engine(conn_string)
task_uuid = load_managed_dataset(engine, dataset='tpcds', scale='10gb', clean=True)
task_status = dict(type='update', uuid=task_uuid, sql=dict(task_status='complete'))
update_task_status(task_status)

In [None]:
# Load TPC-DS 100GB: truncate the target tables, load them, then mark the task complete.
engine = sa.create_engine(conn_string)
task_uuid = load_managed_dataset(engine, dataset='tpcds', scale='100gb', clean=True)
task_status = dict(type='update', uuid=task_uuid, sql=dict(task_status='complete'))
update_task_status(task_status)

In [3]:
# Load TPC-H 1GB: truncate the target tables, load them, then mark the task complete.
engine = sa.create_engine(conn_string)
task_uuid = load_managed_dataset(engine, dataset='tpch', scale='1gb', clean=True)
task_status = dict(type='update', uuid=task_uuid, sql=dict(task_status='complete'))
update_task_status(task_status)

Processing nation (MP: tpch_worker_process_0) 
Processing region (MP: tpch_worker_process_1) 
Processing part (MP: tpch_worker_process_1) 
Processing supplier (MP: tpch_worker_process_0) 
Processing partsupp (MP: tpch_worker_process_0) 
Processing customer (MP: tpch_worker_process_1) 
Processing orders (MP: tpch_worker_process_1) 
Processing lineitem (MP: tpch_worker_process_0) 


In [None]:
# Load TPC-H 10GB: truncate the target tables, load them, then mark the task complete.
engine = sa.create_engine(conn_string)
task_uuid = load_managed_dataset(engine, dataset='tpch', scale='10gb', clean=True)
task_status = dict(type='update', uuid=task_uuid, sql=dict(task_status='complete'))
update_task_status(task_status)

Processing nation (MP: tpch_worker_process_0) 
Processing region (MP: tpch_worker_process_1) 
Processing part (MP: tpch_worker_process_0) 
Processing supplier (MP: tpch_worker_process_1) 


In [None]:
# Load TPC-H 100GB: truncate the target tables, load them, then mark the task complete.
engine = sa.create_engine(conn_string)
task_uuid = load_managed_dataset(engine, dataset='tpch', scale='100gb', clean=True)
task_status = dict(type='update', uuid=task_uuid, sql=dict(task_status='complete'))
update_task_status(task_status)