In [4]:
import os
from getpass import getpass
from urllib.parse import quote

import pandas as pd
import pymysql
from sqlalchemy import create_engine, inspect


In [9]:
# Extract Phase
def extract_from_local(csv_file):
    """Read the CSV at ``csv_file`` and return its contents as a DataFrame."""
    local_frame = pd.read_csv(csv_file)
    return local_frame



def extract_from_db(db_url, username, password, db_name, query="SELECT * FROM athlete_events"):
    """Extracts data from a remote MySQL database hosted on AWS RDS.

    Parameters:
        db_url (str): Database URL, typically the RDS endpoint.
        username (str): Username for the database.
        password (str): Password for the database.
        db_name (str): Database name.
        query (str): SQL statement to execute. Defaults to selecting every
            row of the ``athlete_events`` table.

    Returns:
        DataFrame: The rows returned by ``query``.
    """
    # Percent-encode the credentials: characters such as '@', ':' or '/' in a
    # password would otherwise be read as URL delimiters and break the parse.
    safe_user = quote(str(username), safe='')
    safe_password = quote(str(password), safe='')
    db_connection_str = f'mysql+pymysql://{safe_user}:{safe_password}@{db_url}/{db_name}'

    db_engine = create_engine(db_connection_str)
    try:
        return pd.read_sql(query, db_engine)
    finally:
        # Release the pooled connections even when the query fails.
        db_engine.dispose()



# Transform Phase
def transform(data):
    """Cleans the raw data and derives summary statistics.

    Steps:
        1. Drop rows containing any missing value.
        2. Lower-case all column names.
        3. Drop rows whose 'age', 'height' or 'weight' holds the placeholder
           strings 'NA' or ''.
        4. Cast those three columns to float.
        5. Compute overall averages and per-'year' averages of the three columns.

    Parameters:
        data (DataFrame): Raw data. After lower-casing, its columns must
            include 'age', 'height', 'weight' and 'year'. It is not modified.

    Returns:
        tuple: ``(transformed_data, stats_data, grouped_data)`` where
            transformed_data is the cleaned frame,
            stats_data is a one-row frame with 'avg_age', 'avg_height', 'avg_weight',
            grouped_data has one row per year with 'year' plus the same three averages.
    """
    numeric_columns = ['age', 'height', 'weight']
    placeholder_values = ['NA', '']

    # dropna() returns a new frame, so relabelling its columns leaves `data` untouched.
    cleaned = data.dropna()
    cleaned.columns = [c.lower() for c in cleaned.columns]

    # Rows where any numeric column still carries a textual placeholder.
    has_placeholder = cleaned[numeric_columns].isin(placeholder_values).any(axis=1)

    # astype() with a mapping returns a fresh frame; this avoids assigning
    # columns into a filtered selection (the SettingWithCopyWarning pattern).
    transformed_data = cleaned.loc[~has_placeholder].astype(
        {column: float for column in numeric_columns}
    )

    # Overall averages, stored as a single-row frame.
    averages = {f'avg_{column}': transformed_data[column].mean() for column in numeric_columns}
    stats_data = pd.DataFrame([averages])

    # Per-year averages, renamed so the columns read as averages.
    grouped_data = (
        transformed_data
        .groupby('year')[numeric_columns]
        .mean()
        .reset_index()
        .rename(columns={column: f'avg_{column}' for column in numeric_columns})
    )

    return transformed_data, stats_data, grouped_data

# Load Phase
def load_to_db(db_url, username, password, db_name, load_df, table_name, if_exists='replace', index=False):
    """
    Loads data to a remote database hosted on AWS RDS using bulk insert.

    The number of inserted rows is reported on stdout. For 'replace' it is
    simply ``len(load_df)``; otherwise it is the difference between the
    table's row count after and before the insert, where a table that does
    not exist yet counts as 0 rows.

    Parameters:
        db_url (str): Database URL, typically the RDS endpoint.
        username (str): Username for the RDS database.
        password (str): Password for the RDS database.
        db_name (str): Database name.
        load_df (DataFrame): DataFrame to load into the database.
        table_name (str): Name of the table where data will be inserted.
        if_exists (str): Behavior if the table already exists - 'fail', 'replace', 'append'. Default is 'replace'.
        index (bool): Whether to write DataFrame index as a column. Default is False.
    """
    # Percent-encode the credentials: characters such as '@', ':' or '/' in a
    # password would otherwise be read as URL delimiters and break the parse.
    safe_user = quote(str(username), safe='')
    safe_password = quote(str(password), safe='')
    db_connection_str = f'mysql+pymysql://{safe_user}:{safe_password}@{db_url}/{db_name}'

    db_engine = create_engine(db_connection_str)

    def count_rows():
        """Return the current row count of ``table_name`` (0 if the table is absent)."""
        # A fresh inspector per call: inspectors cache reflection results, and
        # the table may have been created since the previous check.
        if not inspect(db_engine).has_table(table_name):
            return 0
        # Quote the identifier with the dialect's own rules instead of
        # splicing the raw name into the SQL text.
        quoted_name = db_engine.dialect.identifier_preparer.quote(table_name)
        result = pd.read_sql(f"SELECT COUNT(*) as count FROM {quoted_name}", db_engine)
        return int(result.iloc[0]['count'])

    try:
        replacing = if_exists == 'replace'

        # Counting is only needed when existing rows survive the load.
        rows_before = 0 if replacing else count_rows()

        # Use pandas to_sql function to perform bulk insert
        load_df.to_sql(table_name, db_engine, if_exists=if_exists, index=index, method='multi')

        rows_inserted = len(load_df) if replacing else count_rows() - rows_before
    finally:
        # Release the pooled connections even when the load fails.
        db_engine.dispose()

    print(f"Data successfully loaded to {table_name}. Rows inserted: {rows_inserted}.")




In [10]:
# Path to the CSV file (alternative local source)
# csv_file_path = 'practice_dataset/athlete_events.csv'

# Connection settings. Each value can be overridden with the environment
# variable named below; the previous values remain as defaults.
RDS_URL = os.environ.get('RDS_URL', 'database-1.cvi0ecu6mury.ap-southeast-2.rds.amazonaws.com')
USERNAME = os.environ.get('RDS_USERNAME', 'admin')
DB_NAME = os.environ.get('RDS_DB_NAME', 'demo_db')

# The password is never stored in the notebook: read it from RDS_PASSWORD,
# or prompt for it (without echo) when the variable is not set.
PASSWORD = os.environ.get('RDS_PASSWORD') or getpass('RDS password: ')


# Running the ETL process
if __name__ == "__main__":
    # Step 1: Extract
    # df = extract_from_local(csv_file_path)
    df = extract_from_db(RDS_URL, USERNAME, PASSWORD, DB_NAME)
    print(f"Data Extract Completed, Length of Source DataFrame: {len(df)}")
    print(" ")

    # Step 2: Transform
    transformed_df, stats_df, grouped_df = transform(df)

    # Print the lengths of the DataFrames
    print("Data Transformation Completed:")
    print(f"Length of Transformed DataFrame: {len(transformed_df)}")
    print(f"Length of Stats DataFrame: {len(stats_df)}")
    print(f"Length of Grouped DataFrame: {len(grouped_df)}")
    print(" ")

    # Step 3: Load (only the cleaned row-level data is written back)
    target_table = 'trans_athlete_events'
    load_to_db(RDS_URL, USERNAME, PASSWORD, DB_NAME, transformed_df, target_table)
    print("Data Load Completed")


Data Extract Completed, Length of Source DataFrame: 99998
 
Data Transformation Completed:
Length of Transformed DataFrame: 75045
Length of Stats DataFrame: 1
Length of Grouped DataFrame: 35
 
Data successfully loaded to trans_athlete_events. Rows inserted: 75045.
Data Load Completed
