<a href="https://colab.research.google.com/github/username06983/geospatial_ETL_ELT/blob/main/data226_assignment6_etl.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import pandas as pd
import requests
from airflow import DAG
from airflow.models import Variable
from airflow.decorators import task
from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
from airflow.hooks.base import BaseHook
from airflow.utils.dates import days_ago

from datetime import timedelta
from datetime import datetime
import snowflake.connector

In [None]:
def return_snowflake_conn():
    """Open a Snowflake connection through Airflow's ``SnowflakeHook``.

    The hook is built from the Airflow connection id ``snowflake_conn``.

    Returns:
        The connection object returned by ``SnowflakeHook.get_conn()``.
        It is not closed here; the caller owns it.
    """
    return SnowflakeHook(snowflake_conn_id='snowflake_conn').get_conn()

In [None]:
@task
def extract_user_session_channel():
    """Read ``user_session_channel.csv`` from the S3 bucket into a DataFrame.

    The file is read with ``storage_options={"anon": True}``.

    Returns:
        pandas.DataFrame: the parsed CSV contents.

    Raises:
        ValueError: if the parsed frame is empty.
    """
    # Imported inside the task, as in the rest of this file's extract tasks.
    import pandas as pd

    source_uri = "s3://s3-geospatial/readonly/user_session_channel.csv"
    frame = pd.read_csv(source_uri, storage_options={"anon": True})

    # Guard clause: a non-empty frame is the success path.
    if not frame.empty:
        return frame

    raise ValueError("user_session_channel.csv is empty or not found.")


In [None]:
@task
def extract_session_timestamp():
    """Read ``session_timestamp.csv`` from the S3 bucket into a DataFrame.

    The file is read with ``storage_options={"anon": True}``.

    Returns:
        pandas.DataFrame: the parsed CSV contents.

    Raises:
        ValueError: if the parsed frame is empty.
    """
    import pandas as pd

    read_options = {"anon": True}
    frame = pd.read_csv(
        "s3://s3-geospatial/readonly/session_timestamp.csv",
        storage_options=read_options,
    )

    is_empty = frame.empty
    if is_empty:
        raise ValueError("session_timestamp.csv is empty or not found.")
    return frame

In [None]:
@task
def transform(df_channel, df_timestamp):
    """Check both extracted frames for duplicate ``sessionid`` values.

    The previous version computed the duplicate counts and then discarded
    them, returning ``None``. This version logs the counts and returns them
    so the result is visible in the task log and in XCom.

    Args:
        df_channel: DataFrame with a ``sessionid`` column (output of
            ``extract_user_session_channel``).
        df_timestamp: DataFrame with a ``sessionid`` column (output of
            ``extract_session_timestamp``).

    Returns:
        dict: one entry per input, keyed ``"user_session_channel"`` and
        ``"session_timestamp"``, each holding ``"rows"`` and
        ``"duplicate_sessionids"`` as plain Python ints.

    Raises:
        KeyError: if either frame has no ``sessionid`` column.
    """
    import logging

    log = logging.getLogger(__name__)

    inputs = (
        ("user_session_channel", df_channel),
        ("session_timestamp", df_timestamp),
    )

    summary = {}
    for label, frame in inputs:
        # duplicated() marks every repeat after the first occurrence;
        # int() converts the numpy integer into a plain Python int.
        dup_count = int(frame["sessionid"].duplicated().sum())
        summary[label] = {
            "rows": int(len(frame)),
            "duplicate_sessionids": dup_count,
        }
        if dup_count:
            log.warning("%s: %d duplicate sessionid value(s)", label, dup_count)
        else:
            log.info("%s: no duplicate sessionid values", label)

    return summary

In [None]:
@task
def load():
    """Full-refresh the two RAW tables in Snowflake from the staged S3 CSVs.

    Steps:
      1. Apply ``warehouse`` / ``database`` from the extras of the Airflow
         connection ``snowflake_conn``, when present.
      2. Run the DDL: schema ``RAW``, external stage ``RAW.BLOB_STAGE`` and
         both target tables.
      3. Inside one explicit transaction: delete all rows, ``COPY INTO`` both
         tables from the stage, count the rows, then commit.

    Two fixes over the previous version:
      * ``BEGIN`` is issued after the DDL. Snowflake DDL statements commit
        the open transaction implicitly, so a ``BEGIN`` placed before them
        left the DELETE/COPY running in autocommit and made the ``ROLLBACK``
        unable to restore deleted rows.
      * ``COPY INTO`` uses ``FORCE = TRUE``. ``DELETE`` does not clear the
        table's load metadata, so without it a re-run skips the unchanged
        files and leaves the tables empty.

    Returns:
        dict: ``{"user_session_channel": <row count>,
        "session_timestamp": <row count>}`` after the load.

    Raises:
        Exception: any error raised while talking to Snowflake is re-raised
        after the open transaction (if any) has been rolled back.
    """
    # Targets
    target_channel = "RAW.USER_SESSION_CHANNEL"
    target_timestamp = "RAW.SESSION_TIMESTAMP"

    conn = return_snowflake_conn()
    cur = conn.cursor()
    txn_open = False  # True only between BEGIN and COMMIT
    try:
        # Context from Airflow connection extras
        extras = (BaseHook.get_connection("snowflake_conn").extra_dejson or {})
        wh = extras.get("warehouse")
        db = extras.get("database")
        if wh:
            cur.execute(f"USE WAREHOUSE {wh}")
        if db:
            cur.execute(f"USE DATABASE {db}")

        # --- DDL: must run before BEGIN because each statement auto-commits.
        cur.execute("CREATE SCHEMA IF NOT EXISTS RAW;")
        cur.execute("USE SCHEMA RAW;")

        cur.execute("""
            CREATE OR REPLACE STAGE RAW.BLOB_STAGE
            URL = 's3://s3-geospatial/readonly/'
            FILE_FORMAT = (TYPE = CSV, SKIP_HEADER = 1, FIELD_OPTIONALLY_ENCLOSED_BY = '"');
        """)

        cur.execute(f"""
            CREATE TABLE IF NOT EXISTS {target_channel} (
                userId INT NOT NULL,
                sessionId VARCHAR(32) PRIMARY KEY,
                channel VARCHAR(32) DEFAULT 'direct'
            );
        """)
        cur.execute(f"""
            CREATE TABLE IF NOT EXISTS {target_timestamp} (
                sessionId VARCHAR(32) PRIMARY KEY,
                ts TIMESTAMP
            );
        """)

        # --- DML: delete + reload as a single atomic unit.
        cur.execute("BEGIN;")
        txn_open = True

        cur.execute(f"DELETE FROM {target_channel};")
        cur.execute(f"DELETE FROM {target_timestamp};")

        # FORCE = TRUE reloads the files even though the load metadata says
        # they were already loaded by an earlier run.
        cur.execute(f"""
            COPY INTO {target_channel}
            FROM @RAW.BLOB_STAGE/user_session_channel.csv
            ON_ERROR = 'ABORT_STATEMENT'
            FORCE = TRUE;
        """)

        cur.execute(f"""
            COPY INTO {target_timestamp}
            FROM @RAW.BLOB_STAGE/session_timestamp.csv
            ON_ERROR = 'ABORT_STATEMENT'
            FORCE = TRUE;
        """)

        cur.execute(f"SELECT COUNT(*) FROM {target_channel};")
        channel_cnt = cur.fetchone()[0]
        cur.execute(f"SELECT COUNT(*) FROM {target_timestamp};")
        ts_cnt = cur.fetchone()[0]

        cur.execute("COMMIT;")
        txn_open = False

        return {"user_session_channel": channel_cnt, "session_timestamp": ts_cnt}

    except Exception as e:
        # Only roll back when a transaction was actually started; failures
        # during the DDL phase have nothing to undo.
        if txn_open:
            cur.execute("ROLLBACK;")
        print("Error:", e)
        raise
    finally:
        cur.close()
        conn.close()

In [None]:
# DAG wiring: both extracts feed transform, and load runs only after
# transform has succeeded.
with DAG(
    dag_id="etl_geospatial",
    start_date=datetime(2024, 9, 21),
    schedule="30 2 * * *",  # cron: every day at 02:30
    catchup=False,
    tags=["geospatial", "snowflake", "etl"],
) as dag:
    channel_df = extract_user_session_channel()
    timestamp_df = extract_session_timestamp()

    # Previously load() had no upstream task, so it ran independently of
    # the extract/transform checks. Chain it behind transform explicitly.
    transform(channel_df, timestamp_df) >> load()