In [2]:
import getpass
import os

import numpy as np
import pandas as pd
import pymysql

# === 1. Load the CSV File ===
# The CSV is read before any database work, so a missing file stops the run
# without leaving an open MySQL connection (or a stray database) behind.
try:
    df_taxi = pd.read_csv('taxi_trip_pricing.csv')
except FileNotFoundError:
    print("Error: The file 'taxi_trip_pricing.csv' was not found.")
    # `exit()` is a convenience injected by the `site` module (missing under
    # `python -S`), and in a Jupyter kernel it can shut the kernel down.
    # Raising SystemExit stops a plain script the same way and is reported
    # as an ordinary error inside a notebook cell.
    raise SystemExit(1) from None
print("✅ CSV file loaded successfully into a pandas DataFrame!")

# === 2. Handle NaN values ===
# pymysql cannot send a float NaN ("nan can not be used with MySQL"), so missing
# values are replaced up front: numeric columns get None (sent as SQL NULL) and
# every other column gets an empty string.
# is_numeric_dtype covers every numeric width (float32, int32, ...), not only
# float64/int64.
for col in df_taxi.columns:
    if pd.api.types.is_numeric_dtype(df_taxi[col]):
        df_taxi[col] = df_taxi[col].replace({np.nan: None})
    else:
        df_taxi[col] = df_taxi[col].replace({np.nan: ''})
print("✅ NaN values handled and replaced.")

# === 3. MySQL Connection Setup ===
# Connection settings come from the environment, so no password is stored in
# the source/notebook. MYSQL_HOST and MYSQL_USER default to 'localhost' and
# 'root'. If MYSQL_PASSWORD is not set, the password is prompted for instead.
mysql_password = os.environ.get('MYSQL_PASSWORD')
if mysql_password is None:
    mysql_password = getpass.getpass('MySQL password: ')

connection = pymysql.connect(
    host=os.environ.get('MYSQL_HOST', 'localhost'),
    user=os.environ.get('MYSQL_USER', 'root'),
    password=mysql_password,
    autocommit=True,
)

cursor = connection.cursor()

# === 4. Create "TaxiData" Database ===
# A dedicated database holds this data; IF NOT EXISTS makes re-runs harmless.
cursor.execute("CREATE DATABASE IF NOT EXISTS TaxiData")
cursor.execute("USE TaxiData")

# === 5. Create Table Automatically ===
# This function automatically generates a CREATE TABLE SQL statement
# based on the DataFrame's columns and data types.
def create_table_from_df(df, table_name, db_cursor=None):
    """Create a MySQL table whose columns mirror the DataFrame's columns.

    Column names are made SQL-friendly by replacing spaces, dots and
    parentheses with underscores. Any backtick is doubled so the
    backtick-quoted identifier stays valid.

    Column types are inferred as follows:

    * integer dtype -> ``BIGINT``.
    * float dtype -> ``DOUBLE``. MySQL ``FLOAT`` is single precision and
      would silently round most values.
    * object dtype whose non-missing values are all numbers -> ``BIGINT``
      or ``DOUBLE``. A numeric column takes this form once its NaNs have
      been replaced with ``None``. Without this check it would wrongly
      become a string column.
    * anything else -> ``VARCHAR(255)``, or ``TEXT`` when some value is
      longer than 255 characters.

    The statement uses ``CREATE TABLE IF NOT EXISTS``, so an existing table
    is left untouched (its schema is not compared).

    Args:
        df: DataFrame whose columns define the table.
        table_name: Name of the table to create.
        db_cursor: Cursor used to execute the statement. Defaults to the
            module-level ``cursor``.

    Raises:
        ValueError: If ``df`` has no columns (MySQL rejects an empty column list).
    """
    if db_cursor is None:
        # Fall back to the module-level cursor opened during connection setup.
        db_cursor = cursor

    # Each of these characters becomes '_' in the SQL column name.
    to_underscore = str.maketrans(' .()', '____')

    def quote(identifier):
        # A backtick inside a backtick-quoted MySQL identifier must be doubled.
        return "`" + str(identifier).replace("`", "``") + "`"

    def sql_type(series):
        if pd.api.types.is_integer_dtype(series):
            return 'BIGINT'
        if pd.api.types.is_float_dtype(series):
            return 'DOUBLE'
        present = series.dropna()
        if series.dtype == object and len(present):
            kind = pd.api.types.infer_dtype(present, skipna=True)
            if kind == 'integer':
                return 'BIGINT'
            if kind in ('floating', 'mixed-integer-float'):
                return 'DOUBLE'
        longest = present.astype(str).str.len().max() if len(present) else 0
        return 'TEXT' if longest > 255 else 'VARCHAR(255)'

    # df.items() yields each column separately, even when names repeat.
    cols = [
        f"{quote(str(name).translate(to_underscore))} {sql_type(series)}"
        for name, series in df.items()
    ]
    if not cols:
        raise ValueError("Cannot create a table from a DataFrame with no columns")

    columns_sql = ", ".join(cols)
    create_sql = f"CREATE TABLE IF NOT EXISTS {quote(table_name)} ({columns_sql})"
    db_cursor.execute(create_sql)
    print(f"✅ Table '{table_name}' created successfully!")

# Build the 'taxi_trips' table from the DataFrame's column names and dtypes.
create_table_from_df(df=df_taxi, table_name='taxi_trips')

# === 6. Insert Data into Tables (using executemany for speed) ===
# This function handles the insertion of data from the DataFrame into the SQL table.
def insert_data(df, table_name, db_cursor=None):
    """Insert every row of ``df`` into the MySQL table ``table_name``.

    Column names are mapped to SQL names by replacing spaces, dots and
    parentheses with underscores (and doubling any backtick). These names
    must match the ones used when the table was created.

    Values are passed through ``%s`` placeholders, so pymysql escapes them
    rather than splicing them into the SQL text. All rows are sent with a
    single ``executemany`` call.

    Missing values still present in ``df`` (NaN, None, NaT, pd.NA) are sent
    as SQL NULL, because pymysql refuses to send a float NaN. Values already
    replaced upstream (e.g. empty strings) are sent as they are.

    Args:
        df: DataFrame whose rows are inserted. It is not modified.
        table_name: Target table, which must already exist.
        db_cursor: Cursor used for the insert. Defaults to the module-level
            ``cursor``.
    """
    if db_cursor is None:
        # Fall back to the module-level cursor opened during connection setup.
        db_cursor = cursor

    # Each of these characters becomes '_' in the SQL column name.
    to_underscore = str.maketrans(' .()', '____')

    def quote(identifier):
        # A backtick inside a backtick-quoted MySQL identifier must be doubled.
        return "`" + str(identifier).replace("`", "``") + "`"

    # Generate the column list for the INSERT statement.
    cols = ",".join(quote(str(col).translate(to_underscore)) for col in df.columns)

    # One placeholder per column; pymysql fills them in with escaped values.
    placeholders = ",".join(["%s"] * len(df.columns))

    insert_sql = f"INSERT INTO {quote(table_name)} ({cols}) VALUES ({placeholders})"

    # Cast to object first so None can be stored in numeric columns as well;
    # `where` keeps present values and puts None wherever notna() is False.
    rows_df = df.astype(object).where(df.notna(), None)
    # itertuples(name=None) already yields plain tuples.
    data = list(rows_df.itertuples(index=False, name=None))

    # Execute the batch insert.
    db_cursor.executemany(insert_sql, data)
    print(f"✅ {len(data)} rows inserted into '{table_name}' successfully!")

# Insert the data into the 'taxi_trips' table. The cursor and connection are
# closed in `finally`, so they are released even if the insert raises.
try:
    insert_data(df_taxi, 'taxi_trips')

    # === 7. Done ===
    print("✨ All data has been loaded into the 'TaxiData' database!")
finally:
    # Close the connection
    cursor.close()
    connection.close()


✅ CSV file loaded successfully into a pandas DataFrame!
✅ NaN values handled and replaced.
✅ Table 'taxi_trips' created successfully!
✅ 1000 rows inserted into 'taxi_trips' successfully!
✨ All data has been loaded into the 'TaxiData' database!
