In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import *

spark = SparkSession.builder \
    .appName("PySpark Docker Expedia Hotel Recommendations") \
    .getOrCreate()

In [2]:
# Objetivo: Prever qual grupo de hotel (hotel cluster) um usuário vai reservar,
# com base nos eventos de busca e atributos do usuário registrados nos logs da Expedia.

df = spark.read.csv("./train.csv", header=True, inferSchema=True)

# df.show()

num_linhas = df.count()
print(f"Número de linhas no DataFrame: {num_linhas}")

Número de linhas no DataFrame: 37670293


In [3]:
# Entender a estrutura do dataset
df.printSchema()

root
 |-- date_time: timestamp (nullable = true)
 |-- site_name: integer (nullable = true)
 |-- posa_continent: integer (nullable = true)
 |-- user_location_country: integer (nullable = true)
 |-- user_location_region: integer (nullable = true)
 |-- user_location_city: integer (nullable = true)
 |-- orig_destination_distance: double (nullable = true)
 |-- user_id: integer (nullable = true)
 |-- is_mobile: integer (nullable = true)
 |-- is_package: integer (nullable = true)
 |-- channel: integer (nullable = true)
 |-- srch_ci: date (nullable = true)
 |-- srch_co: date (nullable = true)
 |-- srch_adults_cnt: integer (nullable = true)
 |-- srch_children_cnt: integer (nullable = true)
 |-- srch_rm_cnt: integer (nullable = true)
 |-- srch_destination_id: integer (nullable = true)
 |-- srch_destination_type_id: integer (nullable = true)
 |-- is_booking: integer (nullable = true)
 |-- cnt: integer (nullable = true)
 |-- hotel_continent: integer (nullable = true)
 |-- hotel_country: integer (n

In [4]:
# Criar temp view para análises SQL
df.createOrReplaceTempView("expedia")

In [5]:
# Criar temp view para dar suporte as análises com o valor total de reservas
df_reservas_totais = spark.sql("""
    SELECT COUNT(*) AS total
    FROM expedia
    WHERE is_booking = 1
""")

df_reservas_totais.createOrReplaceTempView("total_reservations")

In [6]:
# Distribuição das reservas por continente/país/região - qual tem maior concentração de reservas?
# Granularidade alta

spark.sql("""
    WITH country_reservations AS (
        SELECT hotel_continent, hotel_country, hotel_market, COUNT(*) AS total_reservations
        FROM expedia
        WHERE is_booking = 1
        GROUP BY 1,2,3
    )
    SELECT cr.hotel_continent,
           cr.hotel_country, 
           cr.hotel_market,
           cr.total_reservations, 
           round((cr.total_reservations / tr.total) * 100, 2) AS percentage_of_total
    FROM country_reservations cr
        CROSS JOIN total_reservations tr
    ORDER BY percentage_of_total DESC
""").show()

+---------------+-------------+------------+------------------+-------------------+
|hotel_continent|hotel_country|hotel_market|total_reservations|percentage_of_total|
+---------------+-------------+------------+------------------+-------------------+
|              2|           50|         628|            130374|               4.34|
|              2|           50|         675|             96262|               3.21|
|              2|           50|         365|             68722|               2.29|
|              6|           70|          19|             56677|               1.89|
|              2|           50|        1230|             51112|                1.7|
|              2|           50|         637|             45395|               1.51|
|              2|           50|         701|             44361|               1.48|
|              2|           50|         682|             44022|               1.47|
|              2|           50|         191|             38506|             

In [7]:
# Distribuição das reservas por continente/país - qual tem maior concentração de reservas?
# Granularidade baixa

spark.sql("""
    WITH country_reservations AS (
        SELECT hotel_continent, hotel_country, COUNT(*) AS total_reservations
        FROM expedia
        WHERE is_booking = 1
        GROUP BY 1,2
    )
    SELECT cr.hotel_continent,
           cr.hotel_country, 
           cr.total_reservations, 
           round((cr.total_reservations / tr.total) * 100, 2) AS percentage_of_total
    FROM country_reservations cr
        CROSS JOIN total_reservations tr
    ORDER BY percentage_of_total DESC
""").show()

+---------------+-------------+------------------+-------------------+
|hotel_continent|hotel_country|total_reservations|percentage_of_total|
+---------------+-------------+------------------+-------------------+
|              2|           50|           1616055|              53.86|
|              2|          198|            197228|               6.57|
|              6|           70|            102651|               3.42|
|              6|          105|             93326|               3.11|
|              4|            8|             80990|                2.7|
|              6|          204|             71709|               2.39|
|              6|           77|             60720|               2.02|
|              6|          144|             60504|               2.02|
|              3|          106|             45909|               1.53|
|              3|          182|             44362|               1.48|
|              0|           63|             34712|               1.16|
|     

In [8]:
# Distribuição das reservas por país/região - qual localização do usuário tem a maior índice de reservas?
# Granularidade alta

# Cria o DataFrame com o resultado da consulta
spark.sql("""
    WITH country_reservations AS (
        SELECT user_location_country, user_location_region, COUNT(*) AS total_reservations
        FROM expedia
        WHERE is_booking = 1
        GROUP BY 1,2
    )
    SELECT cr.user_location_country,
           cr.user_location_region,
           cr.total_reservations, 
           round((cr.total_reservations / tr.total) * 100, 2) AS percentage_of_total
    FROM country_reservations cr
        CROSS JOIN total_reservations tr
    ORDER BY percentage_of_total DESC
""").show()

+---------------------+--------------------+------------------+-------------------+
|user_location_country|user_location_region|total_reservations|percentage_of_total|
+---------------------+--------------------+------------------+-------------------+
|                   66|                 174|            337487|              11.25|
|                  205|                 354|            141188|               4.71|
|                   66|                 348|            130891|               4.36|
|                   66|                 442|            126091|                4.2|
|                   66|                 220|            110479|               3.68|
|                   66|                 462|             78642|               2.62|
|                  205|                 155|             70260|               2.34|
|                    3|                  50|             66622|               2.22|
|                  205|                 135|             66269|             

In [9]:
# Distribuição das reservas por país/região - qual localização do usuário tem a maior índice de reservas?
# Granularidade baixa

# Cria o DataFrame com o resultado da consulta
spark.sql("""
    WITH country_reservations AS (
        SELECT user_location_country, COUNT(*) AS total_reservations
        FROM expedia
        WHERE is_booking = 1
        GROUP BY 1
    )
    SELECT cr.user_location_country,
           cr.total_reservations, 
           round((cr.total_reservations / tr.total) * 100, 2) AS percentage_of_total
    FROM country_reservations cr
        CROSS JOIN total_reservations tr
    ORDER BY percentage_of_total DESC
""").show()

+---------------------+------------------+-------------------+
|user_location_country|total_reservations|percentage_of_total|
+---------------------+------------------+-------------------+
|                   66|           1675105|              55.82|
|                  205|            345625|              11.52|
|                   69|            158944|                5.3|
|                    3|            138103|                4.6|
|                   46|             67071|               2.24|
|                   77|             62543|               2.08|
|                    1|             52266|               1.74|
|                  215|             45526|               1.52|
|                  133|             30537|               1.02|
|                   23|             21676|               0.72|
|                  195|             19730|               0.66|
|                   68|             19939|               0.66|
|                  231|             18516|             

In [10]:
# Comparação entre a localização do usuário e a do hotel
# Os usuários preferem reservar hotéis no mesmo país/região ou se há uma tendência de viajar para outros locais?

spark.sql("""
    WITH country_reservations AS (
        SELECT user_location_country, user_location_region, hotel_country, hotel_market, COUNT(*) as total_reservations
        FROM expedia
        WHERE is_booking = 1
        GROUP BY 1,2,3,4
    )
    SELECT cr.user_location_country,
           cr.user_location_region, 
           cr.hotel_country,
           cr.hotel_market,
           cr.total_reservations, 
           round((cr.total_reservations / tr.total) * 100, 2) AS percentage_of_total
    FROM country_reservations cr
        CROSS JOIN total_reservations tr
    ORDER BY percentage_of_total DESC    
""").show()

+---------------------+--------------------+-------------+------------+------------------+-------------------+
|user_location_country|user_location_region|hotel_country|hotel_market|total_reservations|percentage_of_total|
+---------------------+--------------------+-------------+------------+------------------+-------------------+
|                   66|                 174|           50|         628|             30708|               1.02|
|                   66|                 174|           50|         365|             23193|               0.77|
|                   66|                 174|           50|        1230|             16042|               0.53|
|                   66|                 174|           50|         368|             13828|               0.46|
|                   66|                 174|           50|         366|             11971|                0.4|
|                  205|                 354|          198|         397|             11693|               0.39|
|

In [11]:
# Comparação entre a localização do usuário e a do hotel
# Os usuários preferem reservar hotéis no mesmo país ou se há uma tendência de viajar para outros locais?
# Granularidade baixa

spark.sql("""
    WITH country_reservations AS (
        SELECT user_location_country, hotel_country, COUNT(*) as total_reservations
        FROM expedia
        WHERE is_booking = 1
        GROUP BY 1,2
    )
    SELECT cr.user_location_country,
           cr.hotel_country,
           cr.total_reservations, 
           round((cr.total_reservations / tr.total) * 100, 2) AS percentage_of_total
    FROM country_reservations cr
        CROSS JOIN total_reservations tr
    ORDER BY percentage_of_total DESC    
""").show()

+---------------------+-------------+------------------+-------------------+
|user_location_country|hotel_country|total_reservations|percentage_of_total|
+---------------------+-------------+------------------+-------------------+
|                   66|           50|           1294233|              43.13|
|                  205|          198|            141941|               4.73|
|                  205|           50|            122762|               4.09|
|                   66|            8|             51261|               1.71|
|                   66|          198|             41416|               1.38|
|                   69|           70|             38191|               1.27|
|                   69|           50|             27256|               0.91|
|                   66|          105|             26761|               0.89|
|                    3|           50|             25128|               0.84|
|                   66|           70|             23580|               0.79|

In [12]:
# Análise temporal

# Criar colunas úteis de datas 
df = df.withColumn("search_year", year("date_time")) \
       .withColumn("search_month", month("date_time")) \
       .withColumn("search_week", weekofyear("date_time")) \
       .withColumn("day_of_week_ci", dayofweek("srch_ci"))

# Atualizar a temp view com as colunas novas
df.createOrReplaceTempView("expedia")

In [13]:
# Existem reservas com problemas de data (check-in > check-out)?
spark.sql("""
    SELECT COUNT(*) as registros_invalidos
    FROM expedia
    WHERE srch_ci > srch_co
""").show()

+-------------------+
|registros_invalidos|
+-------------------+
|                798|
+-------------------+



In [14]:
# De acordo com a análise existem registros inválidos, então é preciso removê-los

# Filtrar apenas os registros válidos (check-in <= check-out)
df = df.filter(col("srch_ci") <= col("srch_co"))

# Criar coluna para cálculo de duração de estadia
df = df.withColumn("stay_duration", datediff("srch_co", "srch_ci"))

# Atualizar a temp view com as colunas novas
df.createOrReplaceTempView("expedia")

In [15]:
# Verificar volume de buscas por ano e mês
spark.sql("""
    SELECT search_year, search_month, COUNT(*) as total_searches
    FROM expedia
    GROUP BY 1, 2
    ORDER BY total_searches desc
""").show()

+-----------+------------+--------------+
|search_year|search_month|total_searches|
+-----------+------------+--------------+
|       2014|          12|       2926235|
|       2014|           7|       2796314|
|       2014|           8|       2748207|
|       2014|           9|       2741193|
|       2014|          10|       2732761|
|       2014|          11|       2620666|
|       2014|           6|       2181086|
|       2014|           5|       1877741|
|       2014|           4|       1768935|
|       2014|           3|       1740751|
|       2014|           1|       1180808|
|       2014|           2|       1156106|
|       2013|           7|       1091303|
|       2013|           8|       1022005|
|       2013|           6|       1004482|
|       2013|           3|        996982|
|       2013|          10|        970269|
|       2013|           5|        958031|
|       2013|           9|        956577|
|       2013|           4|        938713|
+-----------+------------+--------

In [16]:
# Verificar volume de buscas por semana e dia de checkin,
spark.sql("""
    SELECT search_week, day_of_week_ci, COUNT(*) as total_searches
    FROM expedia
    GROUP BY 1, 2
    ORDER BY total_searches desc
""").show()

+-----------+--------------+--------------+
|search_week|day_of_week_ci|total_searches|
+-----------+--------------+--------------+
|         31|             6|        179609|
|         37|             6|        179051|
|         38|             6|        176652|
|         32|             6|        176248|
|         40|             6|        175996|
|         36|             6|        175235|
|         30|             6|        175135|
|         33|             6|        175018|
|         41|             6|        174507|
|         39|             6|        172546|
|         29|             6|        172297|
|         34|             6|        171093|
|         28|             6|        169400|
|         35|             6|        168623|
|         43|             6|        166336|
|         49|             6|        165105|
|         42|             6|        163454|
|         45|             6|        161097|
|         44|             6|        158028|
|         46|             6|    

In [17]:
# Verificar volume de buscas por dia de checkin,
spark.sql("""
    SELECT day_of_week_ci, COUNT(*) as total_searches
    FROM expedia
    GROUP BY 1
    ORDER BY total_searches desc
""").show()

+--------------+--------------+
|day_of_week_ci|total_searches|
+--------------+--------------+
|             6|       7380262|
|             7|       6311519|
|             5|       5733717|
|             4|       4873614|
|             1|       4646788|
|             2|       4495046|
|             3|       4181461|
+--------------+--------------+



In [18]:
# Análise de perfil e comportamental

# Criando a coluna `is_solo_trip` para verificar viagens sozinhas - isso pode indicar se é a trabalho, passeio de família...
df = df.withColumn('is_solo_trip', when((col('srch_adults_cnt') == 1) & (col('srch_children_cnt') == 0), 1).otherwise(0))

# Atualizar a temp view com as colunas novas
df.createOrReplaceTempView("expedia")

In [19]:
# Identificar o tipo de cliente (viagem em grupo, casal, sozinho...) com média percentual

spark.sql("""
    SELECT 
        ROUND(AVG(is_solo_trip)* 100,2) AS avg_is_solo_trip,
        ROUND(AVG(srch_adults_cnt)* 100,2) AS avg_adults,
        ROUND(AVG(srch_children_cnt)* 100,2) AS avg_children,
        ROUND(AVG(srch_rm_cnt)* 100,2) AS avg_rooms
    FROM expedia
    WHERE is_booking = 1
""").show()

+----------------+----------+------------+---------+
|avg_is_solo_trip|avg_adults|avg_children|avg_rooms|
+----------------+----------+------------+---------+
|           28.91|    187.23|       27.67|   112.96|
+----------------+----------+------------+---------+



In [20]:
# Identificar usuários frequentes vs. eventuais.

spark.sql("""
    SELECT user_id, COUNT(*) AS num_reservas
    FROM expedia
    WHERE is_booking = 1
    GROUP BY user_id
    ORDER BY num_reservas DESC
""").show()

+-------+------------+
|user_id|num_reservas|
+-------+------------+
| 911486|         100|
| 235912|         100|
| 602738|         100|
|1137579|         100|
|1049940|         100|
| 966474|         100|
| 625627|         100|
| 913337|         100|
|1017892|         100|
|1090993|          99|
| 944701|          99|
| 216089|          99|
|1009736|          99|
| 770199|          99|
|1097534|          99|
|  40415|          99|
| 859246|          99|
|  20929|          99|
| 356149|          99|
| 494038|          98|
+-------+------------+
only showing top 20 rows



In [21]:
# Análise de reservas com base em categoria
# Família com crianças: Se houver 1 adulto e pelo menos 1 criança
# Família sem crianças: Se houver mais de 1 adulto e nenhuma criança
# Solo trip: Se houver 1 adulto e nenhuma criança

df = df.withColumn("trip_category", 
                   when((col("srch_adults_cnt") == 1) & (col("srch_children_cnt") > 0), "Family with Children")
                    .when((col("srch_adults_cnt") > 1) & (col("srch_children_cnt") == 0), "Family without Children")
                    .otherwise("Solo Trip"))

# Atualizar a temp view com as colunas novas
df.createOrReplaceTempView("expedia")

In [22]:
# Contabilizar as Reservas por Categoria de Viagem

spark.sql("""
    WITH trip_reservations AS (
        SELECT trip_category, COUNT(*) as total_reservations
        FROM expedia
        WHERE is_booking = 1
        GROUP BY 1
    )
    SELECT cr.trip_category,
           cr.total_reservations, 
           round((cr.total_reservations / tr.total) * 100, 2) AS percentage_of_total
    FROM trip_reservations cr
        CROSS JOIN total_reservations tr
    ORDER BY percentage_of_total DESC    
""").show()

+--------------------+------------------+-------------------+
|       trip_category|total_reservations|percentage_of_total|
+--------------------+------------------+-------------------+
|Family without Ch...|           1563399|               52.1|
|           Solo Trip|           1316243|              43.86|
|Family with Children|            121048|               4.03|
+--------------------+------------------+-------------------+



In [23]:
# Análise se viaja sozinho ou não por canal de venda

spark.sql("""
    SELECT is_solo_trip, channel, COUNT(*) AS total_reservations
    FROM expedia
    WHERE is_booking = 1
    GROUP BY 1,2
    ORDER BY total_reservations DESC;
""").show()

+------------+-------+------------------+
|is_solo_trip|channel|total_reservations|
+------------+-------+------------------+
|           0|      9|           1209474|
|           1|      9|            559309|
|           0|      0|            251896|
|           0|      1|            188330|
|           0|      5|            168816|
|           0|      2|            137945|
|           1|      0|             89461|
|           0|      3|             80511|
|           0|      4|             77661|
|           1|      1|             73924|
|           1|      5|             49808|
|           1|      2|             41241|
|           1|      4|             32800|
|           1|      3|             14414|
|           0|      7|              9392|
|           0|      8|              6598|
|           1|      7|              3071|
|           1|      8|              2717|
|           0|      6|              2118|
|           1|      6|               768|
+------------+-------+------------

In [24]:
# Análise por canal de venda e se a reserva foi em um app mobile

spark.sql("""
    SELECT site_name, channel, is_mobile, COUNT(*) AS total_reservations
    FROM expedia
    WHERE is_booking = 1
    GROUP BY 1,2,3
    ORDER BY total_reservations DESC;
""").show()

+---------+-------+---------+------------------+
|site_name|channel|is_mobile|total_reservations|
+---------+-------+---------+------------------+
|        2|      9|        0|           1148494|
|        2|      0|        0|            201327|
|        2|      1|        0|            145065|
|        2|      9|        1|            124315|
|       34|      9|        0|            116700|
|        2|      2|        0|            112445|
|       11|      9|        0|            100628|
|       37|      9|        0|             70155|
|        2|      3|        0|             58060|
|        2|      4|        0|             46644|
|       24|      9|        0|             35868|
|        2|      5|        0|             32540|
|        2|      0|        1|             31018|
|       24|      1|        0|             27347|
|       11|      5|        0|             26750|
|       13|      9|        0|             25773|
|       11|      0|        0|             25427|
|       37|      4| 

In [25]:
# Análise de reservas por tipo de destino

spark.sql("""
    SELECT srch_destination_id, srch_destination_type_id, COUNT(*) AS total_reservations
    FROM expedia
    WHERE is_booking = 1
    GROUP BY 1,2
    ORDER BY total_reservations DESC;
""").show()

+-------------------+------------------------+------------------+
|srch_destination_id|srch_destination_type_id|total_reservations|
+-------------------+------------------------+------------------+
|               8250|                       1|             97678|
|               8267|                       1|             57680|
|               8279|                       1|             30320|
|               8253|                       1|             28297|
|               8745|                       1|             25713|
|              12206|                       6|             23258|
|               8268|                       1|             23168|
|               8230|                       1|             21976|
|               8254|                       1|             19032|
|               8791|                       1|             18342|
|               8260|                       1|             18273|
|               8291|                       1|             17621|
|         

In [26]:
df.printSchema()

root
 |-- date_time: timestamp (nullable = true)
 |-- site_name: integer (nullable = true)
 |-- posa_continent: integer (nullable = true)
 |-- user_location_country: integer (nullable = true)
 |-- user_location_region: integer (nullable = true)
 |-- user_location_city: integer (nullable = true)
 |-- orig_destination_distance: double (nullable = true)
 |-- user_id: integer (nullable = true)
 |-- is_mobile: integer (nullable = true)
 |-- is_package: integer (nullable = true)
 |-- channel: integer (nullable = true)
 |-- srch_ci: date (nullable = true)
 |-- srch_co: date (nullable = true)
 |-- srch_adults_cnt: integer (nullable = true)
 |-- srch_children_cnt: integer (nullable = true)
 |-- srch_rm_cnt: integer (nullable = true)
 |-- srch_destination_id: integer (nullable = true)
 |-- srch_destination_type_id: integer (nullable = true)
 |-- is_booking: integer (nullable = true)
 |-- cnt: integer (nullable = true)
 |-- hotel_continent: integer (nullable = true)
 |-- hotel_country: integer (n

In [27]:
# Verificar colunas que são categóricas e removê-las

columns_to_keep = [
    col_name for col_name, dtype in df.dtypes 
    if dtype in ['int']
]

# Criar um novo DataFrame com apenas as colunas 'int' para machine learning
df_int = df.select([col(c) for c in columns_to_keep])

In [28]:
# Filtrar apenas usuários que reservaram (is_booking == 1)
df_booked = df_int.filter(df["is_booking"] == 1)

# Atualizar a temp view com as colunas novas
df_booked.createOrReplaceTempView("expedia_final")

In [29]:
df_booked.printSchema()

root
 |-- site_name: integer (nullable = true)
 |-- posa_continent: integer (nullable = true)
 |-- user_location_country: integer (nullable = true)
 |-- user_location_region: integer (nullable = true)
 |-- user_location_city: integer (nullable = true)
 |-- user_id: integer (nullable = true)
 |-- is_mobile: integer (nullable = true)
 |-- is_package: integer (nullable = true)
 |-- channel: integer (nullable = true)
 |-- srch_adults_cnt: integer (nullable = true)
 |-- srch_children_cnt: integer (nullable = true)
 |-- srch_rm_cnt: integer (nullable = true)
 |-- srch_destination_id: integer (nullable = true)
 |-- srch_destination_type_id: integer (nullable = true)
 |-- is_booking: integer (nullable = true)
 |-- cnt: integer (nullable = true)
 |-- hotel_continent: integer (nullable = true)
 |-- hotel_country: integer (nullable = true)
 |-- hotel_market: integer (nullable = true)
 |-- hotel_cluster: integer (nullable = true)
 |-- search_year: integer (nullable = true)
 |-- search_month: integ

In [30]:
# Verificar relação de todas as colunas com is_solo_trip
spark.sql("""
    SELECT 
        site_name,
        posa_continent,
        user_location_country,
        user_location_region,
        user_location_city,
        user_id,
        is_mobile,
        is_package,
        channel,
        srch_adults_cnt,
        srch_children_cnt,
        srch_rm_cnt,
        srch_destination_id,
        srch_destination_type_id,
        is_booking,
        cnt,
        hotel_continent,
        hotel_country,
        hotel_market,
        hotel_cluster,
        search_year,
        search_month,
        search_week,
        day_of_week_ci,
        stay_duration,
        is_solo_trip,
        COUNT(*) AS total_reservations
    FROM expedia_final
    WHERE is_solo_trip=1
    GROUP BY 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26
    ORDER BY total_reservations DESC
""").show()

+---------+--------------+---------------------+--------------------+------------------+-------+---------+----------+-------+---------------+-----------------+-----------+-------------------+------------------------+----------+---+---------------+-------------+------------+-------------+-----------+------------+-----------+--------------+-------------+------------+------------------+
|site_name|posa_continent|user_location_country|user_location_region|user_location_city|user_id|is_mobile|is_package|channel|srch_adults_cnt|srch_children_cnt|srch_rm_cnt|srch_destination_id|srch_destination_type_id|is_booking|cnt|hotel_continent|hotel_country|hotel_market|hotel_cluster|search_year|search_month|search_week|day_of_week_ci|stay_duration|is_solo_trip|total_reservations|
+---------+--------------+---------------------+--------------------+------------------+-------+---------+----------+-------+---------------+-----------------+-----------+-------------------+------------------------+----------

In [31]:
# Verificar valores nulos em todas as colunas
nulos = df_booked.select([
    sum(when(col(c).isNull(), 1).otherwise(0)).alias(f"{c}_nulls") for c in df_booked.columns
])

print(nulos.count())

1


In [32]:
# Machine Learning

from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [33]:
# Obter todas as colunas exceto a target e associados
feature_cols = [field.name for field in df_booked.schema.fields if field.name not in ["is_solo_trip", "srch_adults_cnt", "srch_children_cnt"]]

assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")
df_features = assembler.transform(df_booked).select("features", "is_solo_trip")

# Separar treino e teste
train_data, test_data = df_features.randomSplit([0.7, 0.3], seed=42)

# Treinar modelo
dt = DecisionTreeClassifier(labelCol="is_solo_trip", featuresCol="features", maxDepth=5)
model = dt.fit(train_data)

# Fazer previsões
predictions = model.transform(test_data)

# Avaliar
evaluator = MulticlassClassificationEvaluator(
    labelCol="is_solo_trip", predictionCol="prediction", metricName="accuracy"
)
accuracy = evaluator.evaluate(predictions)
print(f"Acurácia da árvore de decisão: {accuracy:.2f}")


Acurácia da árvore de decisão: 0.72


In [34]:
print(model.toDebugString)

DecisionTreeClassificationModel: uid=DecisionTreeClassifier_842d25127b18, depth=5, numNodes=17, numClasses=2, numFeatures=23
  If (feature 9 <= 1.5)
   If (feature 21 <= 5.5)
    If (feature 22 <= 2.5)
     If (feature 7 <= 0.5)
      Predict: 0.0
     Else (feature 7 > 0.5)
      Predict: 1.0
    Else (feature 22 > 2.5)
     Predict: 0.0
   Else (feature 21 > 5.5)
    Predict: 0.0
  Else (feature 9 > 1.5)
   If (feature 9 <= 4.5)
    Predict: 0.0
   Else (feature 9 > 4.5)
    If (feature 9 <= 5.5)
     Predict: 0.0
    Else (feature 9 > 5.5)
     If (feature 0 <= 13.5)
      If (feature 9 <= 7.5)
       Predict: 0.0
      Else (feature 9 > 7.5)
       Predict: 1.0
     Else (feature 0 > 13.5)
      Predict: 0.0



In [35]:
for idx, name in enumerate(feature_cols):
    print(f"feature {idx}: {name}")

feature 0: site_name
feature 1: posa_continent
feature 2: user_location_country
feature 3: user_location_region
feature 4: user_location_city
feature 5: user_id
feature 6: is_mobile
feature 7: is_package
feature 8: channel
feature 9: srch_rm_cnt
feature 10: srch_destination_id
feature 11: srch_destination_type_id
feature 12: is_booking
feature 13: cnt
feature 14: hotel_continent
feature 15: hotel_country
feature 16: hotel_market
feature 17: hotel_cluster
feature 18: search_year
feature 19: search_month
feature 20: search_week
feature 21: day_of_week_ci
feature 22: stay_duration


In [36]:
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml.evaluation import MulticlassClassificationEvaluator, BinaryClassificationEvaluator
from pyspark.ml.feature import VectorAssembler
from pyspark.sql.functions import col

# Remoção de features relacionadas
excluded_cols = ["is_solo_trip", "srch_adults_cnt", "srch_children_cnt"]
feature_cols = [field.name for field in df_booked.schema.fields if field.name not in excluded_cols]

# Preparação dos dados
assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")
df_features = assembler.transform(df_booked).select("features", "is_solo_trip")

# Split treino-teste
train_data, test_data = df_features.randomSplit([0.7, 0.3], seed=42)

# Modelo e Grid Search
dt = DecisionTreeClassifier(labelCol="is_solo_trip", featuresCol="features")

param_grid = ParamGridBuilder() \
    .addGrid(dt.maxDepth, [3, 5, 7]) \
    .addGrid(dt.minInstancesPerNode, [1, 5, 10]) \
    .addGrid(dt.impurity, ["gini", "entropy"]) \
    .build()

evaluator = MulticlassClassificationEvaluator(
    labelCol="is_solo_trip", 
    predictionCol="prediction", 
    metricName="accuracy"  # Ou "f1" para dados desbalanceados
)

cv = CrossValidator(
    estimator=dt,
    estimatorParamMaps=param_grid,
    evaluator=evaluator,
    numFolds=3,
    seed=42
)

In [37]:
# Treinamento
cv_model = cv.fit(train_data)
best_model = cv_model.bestModel
predictions = best_model.transform(test_data)

# Avaliação
accuracy = evaluator.evaluate(predictions)
print(f"\nAcurácia do melhor modelo: {accuracy:.4f}")

# Métricas adicionais
evaluator_auc = BinaryClassificationEvaluator(
    labelCol="is_solo_trip", 
    metricName="areaUnderROC"
)
auc = evaluator_auc.evaluate(predictions)
print(f"AUC-ROC: {auc:.4f}")

print("\nMatriz de Confusão:")
predictions.groupBy("is_solo_trip", "prediction").count().show()

# Parâmetros do melhor modelo
print("\nMelhores hiperparâmetros:")
print(f"maxDepth: {best_model.getMaxDepth()}")
print(f"minInstancesPerNode: {best_model.getMinInstancesPerNode()}")
print(f"impurity: {best_model.getImpurity()}")


Acurácia do melhor modelo: 0.7194
AUC-ROC: 0.6138

Matriz de Confusão:
+------------+----------+------+
|is_solo_trip|prediction| count|
+------------+----------+------+
|           1|       0.0|244499|
|           0|       0.0|631892|
|           1|       1.0| 16308|
|           0|       1.0|  8301|
+------------+----------+------+


Melhores hiperparâmetros:
maxDepth: 7
minInstancesPerNode: 5
impurity: gini
