<a href="https://colab.research.google.com/github/raphalencar/Coupon_AB_test_case/blob/main/Load_raw_data.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [13]:
# !rm -rf "/content/download_dir/"
# !rm -rf "/content/raw_parquet/"
# !rm -rf "/content/sample_data/"

### PySpark install

In [2]:
!pip install pyspark



### Imports

In [3]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark import StorageLevel

import os
import tarfile

### Utils

In [22]:
def download_data(url, local_path):
    """Download ``url`` to ``local_path`` unless a non-empty copy already exists.

    The payload is streamed to ``<local_path>.part`` and renamed onto
    ``local_path`` only after it has been received completely. A failed or
    interrupted attempt therefore never leaves an empty or partial file that a
    later call would mistake for a finished download, and a stale empty file
    at ``local_path`` does not prevent a retry.

    Args:
        url: Source URL of the file.
        local_path: Destination file path; its parent directory must exist.

    Returns:
        True if ``local_path`` holds a non-empty file when the call returns,
        False if the download failed or produced an empty file.
    """
    # Imported locally so the function does not depend on the notebook's
    # import cell, which does not load these modules.
    import shutil
    import urllib.request

    file_name = os.path.basename(local_path)

    if os.path.exists(local_path) and os.path.getsize(local_path) > 0:
        print(f"'{file_name}' already exists.")
        return True

    print(f"Loading '{file_name}'...")
    partial_path = f"{local_path}.part"
    try:
        with urllib.request.urlopen(url) as response, open(partial_path, "wb") as out_file:
            shutil.copyfileobj(response, out_file)

        downloaded_bytes = os.path.getsize(partial_path)
        if downloaded_bytes > 0:
            # Atomic on the same filesystem; also replaces a stale empty file.
            os.replace(partial_path, local_path)
            print(f"Downloaded '{file_name}' ({downloaded_bytes} bytes).")
            return True
        print(f"Download '{file_name}' failed or is empty.")
    except Exception as e:
        # Best-effort like the other loaders in this notebook: report and
        # signal the failure through the return value instead of raising.
        print(f"Download '{file_name}' failed or is empty. ({e})")

    # Never leave a partial file behind.
    if os.path.exists(partial_path):
        os.remove(partial_path)
    return False

def read_data(source_path, spark_read_func, parquet_output_path, timestamp_partition_cols=None, **kwargs):
    """Return the raw DataFrame for ``source_path``, using Parquet as a cache.

    When ``parquet_output_path`` exists and is not empty it is read directly
    through the global ``spark`` session. Otherwise the source is loaded and
    written by ``save_to_parquet``. A successfully obtained DataFrame has its
    schema and first five rows printed.

    Args:
        source_path: Path of the raw source file.
        spark_read_func: Reader forwarded to ``save_to_parquet``.
        parquet_output_path: Directory holding (or receiving) the Parquet copy.
        timestamp_partition_cols: Forwarded to ``save_to_parquet``.
        **kwargs: Extra reader options forwarded to ``save_to_parquet``.

    Returns:
        The DataFrame, or whatever falsy value ``save_to_parquet`` returned.
    """
    file_name = os.path.basename(source_path)
    df_label = f"{file_name.replace('.', '_')}_df"

    if not (os.path.exists(parquet_output_path) and os.listdir(parquet_output_path)):
        df_raw = save_to_parquet(
            None,
            source_path,
            file_name,
            spark_read_func,
            parquet_output_path,
            timestamp_partition_cols,
            **kwargs,
        )
    else:
        print(f"Reading '{df_label}' from existing Parquet: {parquet_output_path}")
        df_raw = spark.read.parquet(parquet_output_path)

    if not df_raw:
        return df_raw

    print(f"\nDataFrame '{df_label}' from Parquet (raw):")
    df_raw.printSchema()
    df_raw.show(5)
    return df_raw

def save_to_parquet(df, source_path, file_name, spark_read_func, parquet_output_path, timestamp_partition_cols, **kwargs):
    """Load a raw source file, write it as Parquet and return the Parquet-backed DataFrame.

    For every column listed in ``timestamp_partition_cols`` the helper
    ``prepare_df_for_partitioning`` derives ``<column>_year`` and
    ``<column>_month``, and the output is partitioned by those derived columns.
    Partitioning by the raw timestamp itself would create one directory per
    distinct value. A listed column that is missing from the data is skipped
    and contributes no partition column.

    Args:
        df: Ignored; kept for interface compatibility. The DataFrame is always
            loaded from ``source_path``.
        source_path: Path of the raw source file.
        file_name: Base name of the source, used only in log messages.
        spark_read_func: Callable ``(path, **kwargs) -> DataFrame`` such as
            ``spark.read.csv``.
        parquet_output_path: Destination directory, overwritten if it exists.
        timestamp_partition_cols: Optional list of date/timestamp column names
            to derive year/month partitions from; ``None`` or empty writes an
            unpartitioned dataset.
        **kwargs: Extra options forwarded to ``spark_read_func``.

    Returns:
        The DataFrame re-read from ``parquet_output_path``, or ``None`` when
        loading or writing fails (the error is printed).
    """
    df_label = f"{file_name.replace('.', '_')}_df"
    print(f"Reading '{df_label}' from raw source: {source_path}")
    try:
        df = spark_read_func(source_path, **kwargs)

        # Derive year/month columns per timestamp column and keep the returned
        # DataFrame: withColumn does not modify its input in place.
        partition_cols = []
        for ts_col in timestamp_partition_cols or []:
            derived_cols = [f"{ts_col}_year", f"{ts_col}_month"]
            df = prepare_df_for_partitioning(df, ts_col, derived_cols)
            # The helper returns the frame unchanged when ts_col is absent.
            partition_cols.extend(c for c in derived_cols if c in df.columns)

        print(f"Initial load of '{df_label}' from raw source completed.")

        print(f"Saving '{df_label}' to raw Parquet: {parquet_output_path}")
        df_writer = df.write.mode("overwrite")
        if partition_cols:
            print(f"Partitioning raw data by: {partition_cols}")
            df_writer = df_writer.partitionBy(*partition_cols)

        df_writer.parquet(parquet_output_path)
        print(f"Successfully saved '{df_label}' to raw Parquet.")

        return spark.read.parquet(parquet_output_path)

    except Exception as e:
        # Best-effort by design: callers treat None as "not loaded".
        print(f"Error loading '{file_name}' and saving to Parquet: {e}")
        return None

def prepare_df_for_partitioning(df_input, date_col_name, partition_cols_names):
    """Add year/month partition columns derived from a date column.

    Args:
        df_input: DataFrame to enrich.
        date_col_name: Name of the column holding the date/timestamp.
        partition_cols_names: Names of the columns to create. A name ending in
            ``_year`` receives the year and one ending in ``_month`` the month;
            any other name is ignored.

    Returns:
        A DataFrame with the requested partition columns, or ``df_input``
        itself when ``date_col_name`` is not one of its columns.
    """
    if date_col_name not in df_input.columns:
        print(f"Date column '{date_col_name}' not found for preprocessing for partitioning.")
        return df_input

    print(f"Preparing for raw partitioning: converting '{date_col_name}' and extracting date components for partitioning.")

    temp_ts_col = f"{date_col_name}_temp_ts"
    if df_input.schema[date_col_name].dataType in (TimestampType(), DateType()):
        # Already a date/timestamp: extract straight from the column.
        source_col = date_col_name
        df_output = df_input
    else:
        # Convert into a temporary timestamp column used only for extraction.
        source_col = temp_ts_col
        df_output = df_input.withColumn(temp_ts_col, to_timestamp(col(date_col_name)))

    # Suffix -> extraction function, checked in this order.
    extractors = (("_year", year), ("_month", month))
    for partition_col in partition_cols_names:
        for suffix, extract_component in extractors:
            if partition_col.endswith(suffix):
                df_output = df_output.withColumn(partition_col, extract_component(col(source_col)))
                break

    # The temporary timestamp column is not wanted in the raw Parquet output.
    if temp_ts_col in df_output.columns:
        df_output = df_output.drop(temp_ts_col)

    return df_output

def extract_data(archive_path, extract_dir, expected_extensions=['.csv', '.json']):
    """Extract the main data file of a ``.tar.gz`` archive and return its path.

    The "main" file is the member named after the archive (``<name>.tar.gz``
    -> ``<name><ext>``), tried in the order of ``expected_extensions``; any
    other member with an expected extension is only a fallback, in archive
    order. Hidden macOS metadata (``._*`` files, ``__MACOSX`` folders) is
    ignored wherever it sits in the archive, and members that would land
    outside ``extract_dir`` are skipped.

    Args:
        archive_path: Path of the ``.tar.gz`` archive.
        extract_dir: Directory to extract into (created when missing).
        expected_extensions: Accepted data-file extensions, most preferred
            first. The default list is never mutated.

    Returns:
        Path of the extracted (or previously extracted) non-empty data file,
        or ``None`` when nothing suitable is found or the archive is unreadable.
    """
    archive_name = os.path.basename(archive_path)
    print(f"Extracting '{archive_name}' to '{extract_dir}'...")
    try:
        # File names the archive's own name suggests, most preferred first.
        base_name = archive_name.replace('.tar.gz', '')
        preferred_names = [base_name + ext for ext in expected_extensions]
        lowered_extensions = tuple(ext.lower() for ext in expected_extensions)

        # Reuse an earlier extraction when a preferred file is already there.
        for preferred_name in preferred_names:
            candidate = os.path.join(extract_dir, preferred_name)
            if os.path.exists(candidate) and os.path.getsize(candidate) > 0:
                print(f"File '{os.path.basename(candidate)}' already extracted and is valid. Skipping extraction.")
                return candidate

        os.makedirs(extract_dir, exist_ok=True)
        extract_root = os.path.realpath(extract_dir)
        # Use the safe extraction filter on Python versions that provide it.
        extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}

        def extraction_priority(member):
            # Lower sorts first: preferred names by extension order, then the rest.
            member_base = os.path.basename(member.name)
            if member_base in preferred_names:
                return preferred_names.index(member_base)
            return len(preferred_names)

        with tarfile.open(archive_path, "r:gz") as tar:
            valid_members = []
            for member in tar.getmembers():
                if not member.isfile():
                    continue
                member_base = os.path.basename(member.name)
                # Judge the basename / path components, not the full path, so
                # './._x.csv' or 'dir/._x.csv' are also recognised as metadata.
                if member_base.startswith('._') or '__MACOSX' in member.name.split('/'):
                    continue
                if not member_base.lower().endswith(lowered_extensions):
                    continue
                valid_members.append(member)

            # sort() is stable, so fallbacks keep their archive order.
            valid_members.sort(key=extraction_priority)

            for member in valid_members:
                extracted_path = os.path.join(extract_dir, member.name)
                target_real = os.path.realpath(extracted_path)
                if os.path.commonpath([extract_root, target_real]) != extract_root:
                    print(f"Skipping '{member.name}': it would be extracted outside '{extract_dir}'.")
                    continue

                print(f"Extracting '{member.name}'...")
                tar.extract(member, path=extract_dir, **extract_kwargs)

                # Stop at the first data file that was extracted and is non-empty.
                if os.path.exists(extracted_path) and os.path.getsize(extracted_path) > 0:
                    print(f"File '{member.name}' extracted to '{extracted_path}'.")
                    print(f"Extract completed. File found: {extracted_path}")
                    return extracted_path

        print(f"No valid data files with extensions {expected_extensions} found inside the compressed file, ignoring metadata/hidden files.")
        return None

    except tarfile.ReadError:
        print(f"Error: The file '{archive_path}' is not a valid tar.gz archive or is corrupt.")
        return None
    except Exception as e:
        print(f"Error unpacking '{archive_path}': {e}")
        return None

def raw_data_info(raw_data):
    """Print a quick profile of a raw DataFrame.

    Prints the total row count, the number of nulls in every column and, when
    the schema has Double/Long/Integer columns, their ``describe()`` summary.

    Args:
        raw_data: DataFrame to profile, or ``None`` when the load failed (the
            loading cells bind ``None`` in that case).

    Returns:
        None.
    """
    if raw_data is None:
        # The loaders signal failure with None; report it instead of raising.
        print("No data to inspect: the DataFrame was not loaded.")
        return

    numeric_cols = [f.name for f in raw_data.schema
                    if isinstance(f.dataType, (DoubleType, LongType, IntegerType))]

    print(f"Total rows: {raw_data.count()}")
    print("\nNull values per column:")
    # when(...) yields NULL for non-null cells, so count() counts only the nulls.
    raw_data.select([count(when(col(c).isNull(), c)).alias(c) for c in raw_data.columns]).show()

    if len(numeric_cols) > 0:
        print("\nSummary - numeric columns:")
        raw_data.select(*numeric_cols).describe().show()

def cast_to_timestamp(df, columns):
    """Convert the listed columns to timestamps with ``to_timestamp``.

    Names absent from ``df`` and columns that are already of timestamp or
    date type are left untouched.

    Returns:
        The resulting DataFrame (``df`` itself when nothing was converted).
    """
    for name in columns:
        if name not in df.columns:
            continue
        if df.schema[name].dataType in (TimestampType(), DateType()):
            continue
        df = df.withColumn(name, to_timestamp(col(name)))
    return df

def cast_to_double(df, columns):
    """Cast the listed columns to ``DoubleType``.

    Names absent from ``df`` and columns that are already double are skipped.

    Returns:
        The resulting DataFrame (``df`` itself when nothing was cast).
    """
    for name in columns:
        needs_cast = name in df.columns and df.schema[name].dataType != DoubleType()
        if needs_cast:
            df = df.withColumn(name, col(name).cast(DoubleType()))
    return df

def cast_to_long(df, columns):
    """Cast the listed columns to ``LongType``.

    Names absent from ``df`` and columns that are already long are skipped.

    Returns:
        The resulting DataFrame (``df`` itself when nothing was cast).
    """
    for name in columns:
        if name not in df.columns:
            continue
        if df.schema[name].dataType == LongType():
            continue
        df = df.withColumn(name, col(name).cast(LongType()))
    return df

### Spark Session initialization

In [5]:
# Create (or reuse) the session used by every loader in this notebook,
# requesting 4g of memory for both the driver and the executors.
spark = (
    SparkSession.builder
    .appName("CouponABTest")
    .config("spark.executor.memory", "4g")
    .config("spark.driver.memory", "4g")
    # .config("spark.sql.shuffle.partitions", "100")
    .getOrCreate()
)

print("Spark Session initialized!")

Spark Session initialized!


### Load and inspect raw data

In [6]:
# S3 bucket prefix shared by every raw input file.
source_bucket_url = "https://data-architect-test-source.s3-sa-east-1.amazonaws.com"

order_url = f"{source_bucket_url}/order.json.gz"
consumer_url = f"{source_bucket_url}/consumer.csv.gz"
restaurant_url = f"{source_bucket_url}/restaurant.csv.gz"
ab_test_ref_url = f"{source_bucket_url}/ab_test_ref.tar.gz"

#### Create download and Parquet output directories

In [7]:
# Working directories: one for the downloaded archives, one for the raw
# Parquet copies. exist_ok makes the cell safe to re-run and avoids the
# check-then-create race; the variables are reused so path and name cannot drift.
download_dir = "./download_dir"
os.makedirs(download_dir, exist_ok=True)

raw_parquet_dir = "./raw_parquet"
os.makedirs(raw_parquet_dir, exist_ok=True)

#### orders

In [None]:
local_orders_path = os.path.join(download_dir, "order.json.gz")
orders_parquet_path = os.path.join(raw_parquet_dir, "orders.parquet")

# Load the orders only when the archive is available locally; otherwise
# orders_df_raw is None.
orders_df_raw = (
    read_data(local_orders_path, spark.read.json, orders_parquet_path, ["order_created_at"])
    if download_data(order_url, local_orders_path)
    else None
)

In [None]:
# Print the row count, per-column null counts and, when numeric columns exist,
# their summary statistics for the raw orders.
raw_data_info(orders_df_raw)

#### consumer

In [16]:
local_consumers_path = os.path.join(download_dir, "consumer.csv.gz")
consumers_parquet_path = os.path.join(raw_parquet_dir, "consumers.parquet")

# Load the consumers only when the archive is available locally; otherwise
# consumers_df_raw is None.
consumers_df_raw = (
    read_data(local_consumers_path, spark.read.csv, consumers_parquet_path, header=True)
    if download_data(consumer_url, local_consumers_path)
    else None
)

'consumer.csv.gz' already exists.
Reading 'consumer_csv_gz_df' from raw source: ./download_dir/consumer.csv.gz
Initial load of 'consumer_csv_gz_df' from raw source completed.
Saving 'consumer_csv_gz_df' to raw Parquet: ./raw_parquet/consumers.parquet
Successfully saved 'consumer_csv_gz_df' to raw Parquet.

DataFrame 'consumer_csv_gz_df' from Parquet (raw):
root
 |-- customer_id: string (nullable = true)
 |-- language: string (nullable = true)
 |-- created_at: string (nullable = true)
 |-- active: string (nullable = true)
 |-- customer_name: string (nullable = true)
 |-- customer_phone_area: string (nullable = true)
 |-- customer_phone_number: string (nullable = true)

+--------------------+--------+--------------------+------+-------------+-------------------+---------------------+
|         customer_id|language|          created_at|active|customer_name|customer_phone_area|customer_phone_number|
+--------------------+--------+--------------------+------+-------------+------------------

In [17]:
# Print the row count, per-column null counts and, when numeric columns exist,
# their summary statistics for the raw consumers.
raw_data_info(consumers_df_raw)

Total rows: 806156

Null values per column:
+-----------+--------+----------+------+-------------+-------------------+---------------------+
|customer_id|language|created_at|active|customer_name|customer_phone_area|customer_phone_number|
+-----------+--------+----------+------+-------------+-------------------+---------------------+
|          0|       0|         0|     0|            0|                  0|                    0|
+-----------+--------+----------+------+-------------+-------------------+---------------------+



#### restaurant

In [18]:
local_restaurants_path = os.path.join(download_dir, "restaurant.csv.gz")
restaurants_parquet_path = os.path.join(raw_parquet_dir, "restaurants.parquet")

# Load the restaurants only when the archive is available locally; otherwise
# restaurants_df_raw is None.
restaurants_df_raw = (
    read_data(local_restaurants_path, spark.read.csv, restaurants_parquet_path, header=True)
    if download_data(restaurant_url, local_restaurants_path)
    else None
)

'restaurant.csv.gz' already exists.
Reading 'restaurant_csv_gz_df' from raw source: ./download_dir/restaurant.csv.gz
Initial load of 'restaurant_csv_gz_df' from raw source completed.
Saving 'restaurant_csv_gz_df' to raw Parquet: ./raw_parquet/restaurants.parquet
Successfully saved 'restaurant_csv_gz_df' to raw Parquet.

DataFrame 'restaurant_csv_gz_df' from Parquet (raw):
root
 |-- id: string (nullable = true)
 |-- created_at: string (nullable = true)
 |-- enabled: string (nullable = true)
 |-- price_range: string (nullable = true)
 |-- average_ticket: string (nullable = true)
 |-- takeout_time: string (nullable = true)
 |-- delivery_time: string (nullable = true)
 |-- minimum_order_value: string (nullable = true)
 |-- merchant_zip_code: string (nullable = true)
 |-- merchant_city: string (nullable = true)
 |-- merchant_state: string (nullable = true)
 |-- merchant_country: string (nullable = true)

+--------------------+--------------------+-------+-----------+--------------+---------

In [19]:
# Print the row count, per-column null counts and, when numeric columns exist,
# their summary statistics for the raw restaurants.
raw_data_info(restaurants_df_raw)

Total rows: 7292

Null values per column:
+---+----------+-------+-----------+--------------+------------+-------------+-------------------+-----------------+-------------+--------------+----------------+
| id|created_at|enabled|price_range|average_ticket|takeout_time|delivery_time|minimum_order_value|merchant_zip_code|merchant_city|merchant_state|merchant_country|
+---+----------+-------+-----------+--------------+------------+-------------+-------------------+-----------------+-------------+--------------+----------------+
|  0|         0|      0|          0|             0|           0|            1|                 95|                0|            0|             0|               0|
+---+----------+-------+-----------+--------------+------------+-------------+-------------------+-----------------+-------------+--------------+----------------+



#### ab_test_ref

In [20]:
local_ab_test_tar_path = os.path.join(download_dir, "ab_test_ref.tar.gz")
local_ab_test_path = os.path.join(download_dir, "ab_test_ref.csv")
ab_test_ref_parquet_path = os.path.join(raw_parquet_dir, "ab_test_ref.parquet")
ab_test_df_raw = None

# Download -> extract -> read; each failure is reported and leaves
# ab_test_df_raw as None.
if not download_data(ab_test_ref_url, local_ab_test_tar_path):
    print("ab_test_ref.tar.gz file download failed.")
else:
    expected_ab_test_extensions = ['.csv', '.json']
    extracted_ab_test_file = extract_data(
        local_ab_test_tar_path,
        download_dir,
        expected_extensions=expected_ab_test_extensions
    )

    if not extracted_ab_test_file:
        print("Could not find a compatible data file in the A/B archive.")
    elif extracted_ab_test_file.endswith('.csv'):
        ab_test_df_raw = read_data(
            extracted_ab_test_file,
            spark.read.csv,
            ab_test_ref_parquet_path,
            header=True,
            inferSchema=True,
        )
    elif extracted_ab_test_file.endswith('.json'):
        ab_test_df_raw = read_data(
            extracted_ab_test_file,
            spark.read.json,
            ab_test_ref_parquet_path,
        )
    else:
        print(f"File format '{extracted_ab_test_file}' not supported for direct Spark loading.")

'ab_test_ref.tar.gz' already exists.
Extracting 'ab_test_ref.tar.gz' to './download_dir'...
File 'ab_test_ref.csv' already extracted and is valid. Skipping extraction.
Reading 'ab_test_ref_csv_df' from raw source: ./download_dir/ab_test_ref.csv
Initial load of 'ab_test_ref_csv_df' from raw source completed.
Saving 'ab_test_ref_csv_df' to raw Parquet: ./raw_parquet/ab_test_ref.parquet
Successfully saved 'ab_test_ref_csv_df' to raw Parquet.

DataFrame 'ab_test_ref_csv_df' from Parquet (raw):
root
 |-- customer_id: string (nullable = true)
 |-- is_target: string (nullable = true)

+--------------------+---------+
|         customer_id|is_target|
+--------------------+---------+
|755e1fa18f25caec5...|   target|
|b821aa8372b8e5b82...|  control|
|d425d6ee4c9d4e211...|  control|
|6a7089eea0a5dc294...|   target|
|dad6b7e222bab31c0...|  control|
+--------------------+---------+
only showing top 5 rows



In [21]:
# Print the row count, per-column null counts and, when numeric columns exist,
# their summary statistics for the raw A/B test reference.
raw_data_info(ab_test_df_raw)

Total rows: 806467

Null values per column:
+-----------+---------+
|customer_id|is_target|
+-----------+---------+
|          0|        0|
+-----------+---------+



### Cleaning and Enriching

#### Utils

#### orders

In [None]:
# Columns whose raw type must be converted before analysis.
cols_timestamp = ["order_created_at", "order_scheduled_date"]
cols_double = [
    "delivery_address_latitude",
    "delivery_address_longitude",
    "merchant_latitude",
    "merchant_longitude",
]

# Typed copy of the raw orders: timestamps first, then coordinates as doubles.
orders_df = cast_to_timestamp(orders_df_raw, cols_timestamp)
orders_df = cast_to_double(orders_df, cols_double)

orders_df.printSchema()

In [None]:
# Show the `items` value of the fourth collected row matching the filter.
(
    orders_df
    .select("items")
    .where("customer_name == 'GUSTAVO' and order_created_at < '2019-01-17'")
    .collect()[3]
)