# Snowflake

### Config

In [29]:
import snowflake.connector
import dotenv
import os
import pandas as pd
import pyarrow.parquet as pq
import time
from tqdm.notebook import tqdm

# Load environment variables.
# To use, create a .env file in the root directory and insert the following:
# SNOWFLAKE_PASSWORD=your_password
# load_dotenv() must run before os.getenv(): without it the .env file is never
# read and only variables already present in the process environment are seen.
dotenv.load_dotenv()
SNOWFLAKE_PASSWORD = os.getenv("SNOWFLAKE_PASSWORD")

# Playback configuration
PLAYBACKSPEED = 10  # multiplier applied when the per-interval batch size is computed
FILEPATH = '../combined_sorted_redset_datasets.parquet'  # parquet file that is replayed
INTERVAL = 5  # seconds between the start of two consecutive batches


# Snowflake connection configuration (keyword arguments for
# snowflake.connector.connect)
config = {
    "account": "SFEDU02-KFB85562",
    "user": "BISON",
    "password": SNOWFLAKE_PASSWORD,
    "role": "TRAINING_ROLE",
    "warehouse": "ANIMAL_TASK_WH",
    "database": "CATCH_ME_IF_YOU_CAN",
    "schema": "PUBLIC"
}

### Create the table (replacing any existing one)

In [None]:
# Define the schema of the target table. CREATE OR REPLACE discards any
# existing live_queries table, so every run starts from an empty table.
create_table_query = """
CREATE OR REPLACE TABLE live_queries (
    instance_id VARCHAR,
    cluster_size INTEGER,
    user_id VARCHAR,
    database_id VARCHAR,
    query_id VARCHAR,
    arrival_timestamp TIMESTAMP,
    compile_duration_ms INTEGER,
    queue_duration_ms INTEGER,
    execution_duration_ms INTEGER,
    feature_fingerprint VARCHAR,
    was_aborted BOOLEAN,
    was_cached BOOLEAN,
    cache_source_query_id VARCHAR,
    query_type VARCHAR,
    num_permanent_tables_accessed INTEGER,
    num_external_tables_accessed INTEGER,
    num_system_tables_accessed INTEGER,
    read_table_ids VARCHAR,
    write_table_ids VARCHAR,
    mbytes_scanned INTEGER,
    mbytes_spilled INTEGER,
    num_joins INTEGER,
    num_scans INTEGER,
    num_aggregations INTEGER,
    dataset_type VARCHAR
);
"""

# Connect to Snowflake and create the table. The context managers close the
# cursor and the connection even when the statement raises, so a failed run
# does not leave a session open.
with snowflake.connector.connect(**config) as conn, conn.cursor() as cur:
    cur.execute(create_table_query)

print("Snowflake table created successfully.")

### Insert data into the table at a fixed interval to simulate streaming

In [None]:
# Length of the replayed time span: 90 days, expressed in seconds
_seconds = 90 * 24 * 60 * 60

# The row count is read from the parquet metadata, so no data is loaded here
_parquet_file = pq.ParquetFile(FILEPATH)
_entries = _parquet_file.metadata.num_rows
print(f"Entries: {_entries}")

# Rows per batch: the rows that fall into one INTERVAL-second slice of the
# span, scaled by the playback speed, and never fewer than one row
_intervals_in_span = _seconds / INTERVAL
_rows_per_interval = _entries / _intervals_in_span * PLAYBACKSPEED
_batchsize = max(1, int(_rows_per_interval))
print(f"Interval: {INTERVAL}\nBatch size: {_batchsize}")

# Target type of every known column. Columns that are not listed here are
# treated as strings. Kept at module level so it is built once rather than on
# every batch.
_COLUMN_TYPES = {
    'instance_id': 'str',
    'cluster_size': 'int',
    'user_id': 'str',
    'database_id': 'str',
    'query_id': 'str',
    'arrival_timestamp': 'datetime',
    'compile_duration_ms': 'int',
    'queue_duration_ms': 'int',
    'execution_duration_ms': 'int',
    'feature_fingerprint': 'str',
    'was_aborted': 'bool',
    'was_cached': 'bool',
    'cache_source_query_id': 'str',
    'query_type': 'str',
    'num_permanent_tables_accessed': 'int',
    'num_external_tables_accessed': 'int',
    'num_system_tables_accessed': 'int',
    'read_table_ids': 'str',
    'write_table_ids': 'str',
    'mbytes_scanned': 'int',
    'mbytes_spilled': 'int',
    'num_joins': 'int',
    'num_scans': 'int',
    'num_aggregations': 'int',
    'dataset_type': 'str',
}

# Explicit 64-bit dtypes. A bare 'int' is platform dependent (it follows the
# C long on NumPy < 2), which can overflow for large counters.
_NUMERIC_DTYPES = {'int': 'int64', 'float': 'float64'}


def process_batch(batch: pd.DataFrame) -> pd.DataFrame:
    """Clean and convert a single batch so it can be bound to an INSERT.

    Every column is converted according to ``_COLUMN_TYPES`` (unknown columns
    are handled as strings):

    * int / float: missing values become 0, then the column is cast to a
      64-bit dtype.
    * bool: missing values become False.
    * datetime: values are parsed and rendered as ``YYYY-MM-DD HH:MM:SS``
      strings; values that are missing or cannot be parsed become ``None``.
    * str: values are stringified and stripped; missing values and the
      literal texts ``nan`` / ``None`` / ``null`` become the empty string.

    Args:
        batch: DataFrame holding one batch of rows. It is modified in place.

    Returns:
        The same DataFrame object, with its columns converted.
    """
    for col in batch.columns:
        col_type = _COLUMN_TYPES.get(col, 'str')

        if col_type in _NUMERIC_DTYPES:
            batch[col] = batch[col].fillna(0).astype(_NUMERIC_DTYPES[col_type])
        elif col_type == 'bool':
            batch[col] = batch[col].fillna(False).astype(bool)
        elif col_type == 'datetime':
            parsed = pd.to_datetime(batch[col], errors='coerce')
            formatted = parsed.dt.strftime('%Y-%m-%d %H:%M:%S')
            # strftime turns NaT into a float NaN, which is not a usable
            # timestamp parameter. Keep only real strings and use None for
            # everything else so the value is sent as NULL.
            batch[col] = pd.Series(
                [value if isinstance(value, str) else None for value in formatted],
                index=batch.index,
                dtype=object,
            )
        else:  # string
            text = batch[col].astype(str).str.strip()
            text = text.replace({'nan': '', 'None': '', 'null': ''})
            batch[col] = text.fillna('')
    return batch

def main():
    """Replay the parquet file into Snowflake in fixed-interval batches.

    Opens one Snowflake connection and cursor, then walks FILEPATH in batches
    of ``_batchsize`` rows. Each batch is cleaned with ``process_batch`` and
    inserted into ``live_queries``; afterwards the loop sleeps so that
    consecutive batches start about ``INTERVAL`` seconds apart. When a full
    pass over the file is done, the table is emptied and the replay starts
    again.

    A batch that fails is reported on the progress bar and skipped; it does
    not stop the replay. The loop has no exit condition, so the function only
    ends through an exception (for example KeyboardInterrupt).
    """
    parquet_file = pq.ParquetFile(FILEPATH)
    # Ceiling division: "// + 1" counted one batch too many whenever the row
    # count was an exact multiple of the batch size. The value is only used as
    # the progress-bar total.
    total_batches = (_entries + _batchsize - 1) // _batchsize

    with snowflake.connector.connect(**config) as conn, conn.cursor() as cur:
        while True:
            # The context manager closes the progress bar even when the pass
            # is interrupted by an exception.
            with tqdm(
                parquet_file.iter_batches(batch_size=_batchsize),
                total=total_batches,
                desc="Processing Batches",
                unit="batch"
            ) as pbar:
                for i, arrow_batch in enumerate(pbar, start=1):
                    # Monotonic clock: durations are not affected by
                    # wall-clock adjustments.
                    start_time = time.monotonic()
                    # Reset the timings for every batch. Otherwise a failure
                    # before they are assigned leaves them unbound (first
                    # batch) or stale (later batches) when the postfix below
                    # is formatted.
                    process_time = 0.0
                    execute_time = 0.0

                    try:
                        # Convert Arrow batch to pandas DataFrame
                        raw_batch = arrow_batch.to_pandas()
                        start_processing = time.monotonic()
                        batch = process_batch(raw_batch)
                        process_time = time.monotonic() - start_processing

                        rows = list(batch.itertuples(index=False, name=None))

                        # An empty batch would produce an INSERT without
                        # values, so there is nothing to send.
                        if rows:
                            columns = ', '.join(batch.columns)
                            placeholders = ', '.join(['%s'] * len(batch.columns))
                            sql = f"INSERT INTO live_queries ({columns}) VALUES ({placeholders})"

                            # executemany takes the single-row template plus
                            # the list of rows, so the multi-row placeholder
                            # string and the flattened parameter list do not
                            # have to be built by hand.
                            start_execute = time.monotonic()
                            cur.executemany(sql, rows)
                            conn.commit()
                            execute_time = time.monotonic() - start_execute

                    except Exception as e:
                        # Best effort: report the failed batch and carry on
                        # with the next one.
                        pbar.write(f"Error processing batch {i}: {str(e)}")

                    # Sleep for whatever is left of the interval
                    elapsed = time.monotonic() - start_time
                    sleep_time = max(0, INTERVAL - elapsed)
                    time.sleep(sleep_time)

                    # Update progress bar postfix with some info
                    pbar.set_postfix_str(
                        f"Sleeptime: {sleep_time:.2f}s, process_time: {process_time:.2f}s, execute_time: {execute_time:.2f}s"
                    )

            # Once a full pass of the file is done, delete the table data and start again
            cur.execute("DELETE FROM live_queries")
            conn.commit()
            print("Table data deleted, starting again")

# Entry point: start the replay loop only when this code is run directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()