# *PROYECTO ML - TRAIN*

## Importamos bibliotecas necesarias para llevar a cabo el proyecto:

In [1]:
import sqlite3
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
#from sklearn.linear_model import LinearRegression
from sklearn.ensemble import RandomForestRegressor
#from sklearn.ensemble import GradientBoostingRegressor
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.model_selection import RandomizedSearchCV
from sklearn.datasets import make_regression
from sklearn.metrics import mean_squared_error
import pickle

In [2]:
def cat_var(df, cols):
    '''
    Return: a Pandas dataframe object with the following columns:
        - "categorical_variable" => every categorical variable include as an input parameter (string).
        - "number_of_possible_values" => the amount of unique values that can take a given categorical variable (integer).
        - "values" => a list with the posible unique values for every categorical variable (list).

    Input parameters:
        - df -> Pandas dataframe object: a dataframe with categorical variables.
        - cols -> list object: a list with the name (string) of every categorical variable to analyse.
    '''
    cat_list = []
    for col in cols:
        cat = df[col].unique()
        cat_num = len(cat)
        cat_dict = {"categorical_variable":col,
                    "number_of_possible_values":cat_num,
                    "values":cat}
        cat_list.append(cat_dict)
    df = pd.DataFrame(cat_list).sort_values(by="number_of_possible_values", ascending=False)
    return df.reset_index(drop=True)

## Leo la base de datos usando SQLite3

### 1º. Creo un conector al fichero .db

In [3]:
conexion = sqlite3.connect('../data/diamonds_train.db')

### 2º. Creo la query

Siempre que usemos este tipo de biblioteca debemos usar "sqlite_master"

In [4]:
sql_query = """SELECT name FROM sqlite_master  
  WHERE type='table';"""

### 3º. Instancio un cursor

In [5]:
cursor = conexion.cursor()

### 4º. Ejecuto las queries usanto el cursor

In [6]:
tables = cursor.execute(sql_query)

In [7]:
print(cursor.execute("""SELECT * FROM diamonds_dimensions;"""))
df = cursor.fetchall()

<sqlite3.Cursor object at 0x12ead4e40>


### 5º. Creo los Pandas DF's usando "pd.read_sql_query"

In [8]:
diamonds_dimensions_df = pd.read_sql_query("SELECT * FROM diamonds_dimensions", conexion)
diamonds_dimensions_df

Unnamed: 0,index_id,depth,table,x,y,z
0,5feceb66ffc86f38d952786c6d696c79c2dbc239dd4e91...,62.4,58.0,6.83,6.79,4.25
1,6b86b273ff34fce19d6b804eff5a3f5747ada4eaa22f1d...,63.0,57.0,4.35,4.38,2.75
2,d4735e3a265e16eee03f59718b9b5d03019c07d8b6c51f...,65.5,55.0,5.62,5.53,3.65
3,4e07408562bedb8b60ce05c1decfe3ad16b72230967de0...,63.8,56.0,4.68,4.72,3.00
4,4b227777d4dd1fc61c6f884f48641d02b4d121d3fd328c...,60.5,59.0,6.55,6.51,3.95
...,...,...,...,...,...,...
40450,f0bc79169405ebeb24e308055156b946ffd819db9b4f75...,62.7,57.0,7.10,7.04,4.43
40451,339916a23bf22b052b54cb2a9b36ee8418c1c68b46acad...,57.1,60.0,8.31,8.25,4.73
40452,46957922b99954654c1deb8d854c3f069bf118b2ce9415...,62.7,56.0,6.37,6.42,4.01
40453,9d733392d362d5c6f1d9b9659b601c7d4b5a1c1c8df579...,61.9,54.3,4.45,4.47,2.76


In [9]:
diamonds_transactional_df = pd.read_sql_query("SELECT * FROM diamonds_transactional", conexion)
diamonds_transactional_df

Unnamed: 0,index_id,price,city_id,carat
0,5feceb66ffc86f38d952786c6d696c79c2dbc239dd4e91...,4268,6c425048aa7badd9d84615bd8620ca1864efd81cfdb69d...,1.21
1,6b86b273ff34fce19d6b804eff5a3f5747ada4eaa22f1d...,505,89c7286890f7347ab235234e74d406596a127ae3679042...,0.32
2,d4735e3a265e16eee03f59718b9b5d03019c07d8b6c51f...,2686,2bd25cd960aba8b706e2b67f2bb38b750ee5384b0e9883...,0.71
3,4e07408562bedb8b60ce05c1decfe3ad16b72230967de0...,738,89c7286890f7347ab235234e74d406596a127ae3679042...,0.41
4,4b227777d4dd1fc61c6f884f48641d02b4d121d3fd328c...,4882,6c425048aa7badd9d84615bd8620ca1864efd81cfdb69d...,1.02
...,...,...,...,...
40450,f0bc79169405ebeb24e308055156b946ffd819db9b4f75...,10070,ca3aa21a5b70c3e88cc6336682c8c7da928a0c66a5ead4...,1.34
40451,339916a23bf22b052b54cb2a9b36ee8418c1c68b46acad...,12615,e9c722cbefc2f055ae60b4e2cbe73a2d99537eab0c37f3...,2.02
40452,46957922b99954654c1deb8d854c3f069bf118b2ce9415...,5457,89c7286890f7347ab235234e74d406596a127ae3679042...,1.01
40453,9d733392d362d5c6f1d9b9659b601c7d4b5a1c1c8df579...,456,89c7286890f7347ab235234e74d406596a127ae3679042...,0.33


In [10]:
diamonds_properties_df = pd.read_sql_query("SELECT * FROM diamonds_properties", conexion)
diamonds_properties_df

Unnamed: 0,index_id,cut_id,color_id,clarity_id
0,5feceb66ffc86f38d952786c6d696c79c2dbc239dd4e91...,de88c121a82a06352bf1aaceba20578356408a334ba046...,6da43b944e494e885e69af021f93c6d9331c78aa228084...,f0b2a1d0db08cc64f85d74f1d15c2191e0e49039f4d8f2...
1,6b86b273ff34fce19d6b804eff5a3f5747ada4eaa22f1d...,388655e25e91872329272fc10128ef5354b3b19a05d7e8...,44bd7ae60f478fae1061e11a7739f4b94d1daf917982d3...,f0b2a1d0db08cc64f85d74f1d15c2191e0e49039f4d8f2...
2,d4735e3a265e16eee03f59718b9b5d03019c07d8b6c51f...,f7b19afcde965ea4942b878d266f89f8ba9a5a833e60f7...,333e0a1e27815d0ceee55c473fe3dc93d56c63e3bee2b3...,ef736c1f91cd1900c3d9cde925b1bf4f013adc0211a9ee...
3,4e07408562bedb8b60ce05c1decfe3ad16b72230967de0...,c939327ca16dcf97ca32521d8b834bf1de16573d21deda...,3f39d5c348e5b79d06e842c114e6cc571583bbf44e4b0e...,bd4f4a250da88f87729febc739ae97f439a14f9d38f0e3...
4,4b227777d4dd1fc61c6f884f48641d02b4d121d3fd328c...,4e3cfaa334cbafb57a399a98fad8d3812ece460018f457...,333e0a1e27815d0ceee55c473fe3dc93d56c63e3bee2b3...,bd4f4a250da88f87729febc739ae97f439a14f9d38f0e3...
...,...,...,...,...
40450,f0bc79169405ebeb24e308055156b946ffd819db9b4f75...,4e3cfaa334cbafb57a399a98fad8d3812ece460018f457...,333e0a1e27815d0ceee55c473fe3dc93d56c63e3bee2b3...,ef736c1f91cd1900c3d9cde925b1bf4f013adc0211a9ee...
40451,339916a23bf22b052b54cb2a9b36ee8418c1c68b46acad...,c939327ca16dcf97ca32521d8b834bf1de16573d21deda...,f67ab10ad4e4c53121b6a5fe4da9c10ddee905b978d378...,03c358cbd92e83278fd523f58dc6a9b4b198d00728af65...
40452,46957922b99954654c1deb8d854c3f069bf118b2ce9415...,4e3cfaa334cbafb57a399a98fad8d3812ece460018f457...,44bd7ae60f478fae1061e11a7739f4b94d1daf917982d3...,bd4f4a250da88f87729febc739ae97f439a14f9d38f0e3...
40453,9d733392d362d5c6f1d9b9659b601c7d4b5a1c1c8df579...,4e3cfaa334cbafb57a399a98fad8d3812ece460018f457...,6da43b944e494e885e69af021f93c6d9331c78aa228084...,ef736c1f91cd1900c3d9cde925b1bf4f013adc0211a9ee...


In [11]:
diamonds_cut_df = pd.read_sql_query("SELECT * FROM diamonds_cut", conexion)
diamonds_cut_df

Unnamed: 0,cut_id,cut
0,388655e25e91872329272fc10128ef5354b3b19a05d7e8...,Very Good
1,4e3cfaa334cbafb57a399a98fad8d3812ece460018f457...,Ideal
2,c939327ca16dcf97ca32521d8b834bf1de16573d21deda...,Good
3,de88c121a82a06352bf1aaceba20578356408a334ba046...,Premium
4,f7b19afcde965ea4942b878d266f89f8ba9a5a833e60f7...,Fair


In [12]:
diamonds_color_df = pd.read_sql_query("SELECT * FROM diamonds_color", conexion)
diamonds_color_df

Unnamed: 0,color_id,color
0,333e0a1e27815d0ceee55c473fe3dc93d56c63e3bee2b3...,G
1,3f39d5c348e5b79d06e842c114e6cc571583bbf44e4b0e...,D
2,44bd7ae60f478fae1061e11a7739f4b94d1daf917982d3...,H
3,6da43b944e494e885e69af021f93c6d9331c78aa228084...,J
4,a83dd0ccbffe39d071cc317ddf6e97f5c6b1c87af91919...,I
5,a9f51566bd6705f7ea6ad54bb9deb449f795582d6529a0...,E
6,f67ab10ad4e4c53121b6a5fe4da9c10ddee905b978d378...,F


In [13]:
diamonds_clarity_df = pd.read_sql_query("SELECT * FROM diamonds_clarity", conexion)
diamonds_clarity_df

Unnamed: 0,clarity_id,clarity
0,03c358cbd92e83278fd523f58dc6a9b4b198d00728af65...,SI2
1,3f9db06236e9719b61c826b612b882fa702ec81574e44c...,VVS2
2,66686ae1f0c9c400ba32dc600a34ff0aa173395bcbc2d8...,VVS1
3,7020fd7aaf1656dea5c0c0c3d4bb5a28ebf6243fba95d3...,IF
4,bd4f4a250da88f87729febc739ae97f439a14f9d38f0e3...,SI1
5,c2818bc4e5ec4ae4a357a0df6fed73652e169ec676f7d4...,I1
6,ef736c1f91cd1900c3d9cde925b1bf4f013adc0211a9ee...,VS1
7,f0b2a1d0db08cc64f85d74f1d15c2191e0e49039f4d8f2...,VS2


In [14]:
diamonds_city_df = pd.read_sql_query("SELECT * FROM diamonds_city", conexion)
diamonds_city_df

Unnamed: 0,city_id,city
0,0013c01fe0d094209b8bd3d23b8f96dbabcd01ddd2a039...,Amsterdam
1,1e73b1647343b286269d517e6f07e6e07ccef10cd7b785...,Zurich
2,2bd25cd960aba8b706e2b67f2bb38b750ee5384b0e9883...,Las Vegas
3,5a59ef2e40c5f89adb8c0c6ad0a8019b0e252fff530cf2...,New York City
4,5dd272b4f316b776a7b8e3d0894b37e1e42be3d5d3b204...,Paris
5,68371d5bdaab31b5cbc25fbf94b8f9c1238294fc50e715...,Tel Aviv
6,6c425048aa7badd9d84615bd8620ca1864efd81cfdb69d...,Dubai
7,89c7286890f7347ab235234e74d406596a127ae3679042...,Kimberly
8,ba04d229e496b8383a4df91f5e77c194a95cf1a069b0b2...,Surat
9,ca3aa21a5b70c3e88cc6336682c8c7da928a0c66a5ead4...,Antwerp


## Unimos todas las tablas en una única tabla

In [15]:
union_data = diamonds_dimensions_df.merge(diamonds_transactional_df)
union_data
union_data = union_data.merge(diamonds_properties_df)
union_data
union_data = union_data.merge(diamonds_cut_df)
union_data
union_data = union_data.merge(diamonds_color_df)
union_data
union_data = union_data.merge(diamonds_clarity_df)
union_data
union_data = union_data.merge(diamonds_city_df)
union_data

Unnamed: 0,index_id,depth,table,x,y,z,price,city_id,carat,cut_id,color_id,clarity_id,cut,color,clarity,city
0,5feceb66ffc86f38d952786c6d696c79c2dbc239dd4e91...,62.4,58.0,6.83,6.79,4.25,4268,6c425048aa7badd9d84615bd8620ca1864efd81cfdb69d...,1.21,de88c121a82a06352bf1aaceba20578356408a334ba046...,6da43b944e494e885e69af021f93c6d9331c78aa228084...,f0b2a1d0db08cc64f85d74f1d15c2191e0e49039f4d8f2...,Premium,J,VS2,Dubai
1,41667f6e2629360aecaf00b20f8732e3310417ebd54b24...,61.6,58.0,6.40,6.35,3.93,3513,6c425048aa7badd9d84615bd8620ca1864efd81cfdb69d...,1.02,de88c121a82a06352bf1aaceba20578356408a334ba046...,6da43b944e494e885e69af021f93c6d9331c78aa228084...,f0b2a1d0db08cc64f85d74f1d15c2191e0e49039f4d8f2...,Premium,J,VS2,Dubai
2,01f8667f50d52677bea23231a74156e4f92360d7bc3db6...,62.3,58.0,5.86,5.80,3.63,1792,6c425048aa7badd9d84615bd8620ca1864efd81cfdb69d...,0.77,de88c121a82a06352bf1aaceba20578356408a334ba046...,6da43b944e494e885e69af021f93c6d9331c78aa228084...,f0b2a1d0db08cc64f85d74f1d15c2191e0e49039f4d8f2...,Premium,J,VS2,Dubai
3,c3867352aab641358faec75d733af012dbe2259a014ea8...,59.6,60.0,7.58,7.48,4.49,7553,6c425048aa7badd9d84615bd8620ca1864efd81cfdb69d...,1.51,de88c121a82a06352bf1aaceba20578356408a334ba046...,6da43b944e494e885e69af021f93c6d9331c78aa228084...,f0b2a1d0db08cc64f85d74f1d15c2191e0e49039f4d8f2...,Premium,J,VS2,Dubai
4,0da4b104c4d8589fcb96a03aa0787549a2631935b0f499...,60.2,62.0,5.40,5.33,3.23,1176,6c425048aa7badd9d84615bd8620ca1864efd81cfdb69d...,0.57,de88c121a82a06352bf1aaceba20578356408a334ba046...,6da43b944e494e885e69af021f93c6d9331c78aa228084...,f0b2a1d0db08cc64f85d74f1d15c2191e0e49039f4d8f2...,Premium,J,VS2,Dubai
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
40450,3df6f3da962b819057888dbbe3cf4d11f9a59261ba0af7...,62.2,54.0,5.24,5.27,3.27,2729,ba04d229e496b8383a4df91f5e77c194a95cf1a069b0b2...,0.54,4e3cfaa334cbafb57a399a98fad8d3812ece460018f457...,f67ab10ad4e4c53121b6a5fe4da9c10ddee905b978d378...,7020fd7aaf1656dea5c0c0c3d4bb5a28ebf6243fba95d3...,Ideal,F,IF,Surat
40451,e5bc95d36abda5bfd67641eca60f2ab92f37b27c0397e0...,61.9,54.0,5.22,5.25,3.24,2802,ba04d229e496b8383a4df91f5e77c194a95cf1a069b0b2...,0.53,4e3cfaa334cbafb57a399a98fad8d3812ece460018f457...,f67ab10ad4e4c53121b6a5fe4da9c10ddee905b978d378...,7020fd7aaf1656dea5c0c0c3d4bb5a28ebf6243fba95d3...,Ideal,F,IF,Surat
40452,5ad0fcca0db9b3d399e31a3d909a1a3f4da38c663256b3...,62.3,55.0,4.30,4.34,2.69,886,ba04d229e496b8383a4df91f5e77c194a95cf1a069b0b2...,0.30,4e3cfaa334cbafb57a399a98fad8d3812ece460018f457...,f67ab10ad4e4c53121b6a5fe4da9c10ddee905b978d378...,7020fd7aaf1656dea5c0c0c3d4bb5a28ebf6243fba95d3...,Ideal,F,IF,Surat
40453,c8559278b1ac356e1e414320a4efbbe788dc16623e7873...,60.9,55.0,4.15,4.23,2.55,768,ba04d229e496b8383a4df91f5e77c194a95cf1a069b0b2...,0.26,4e3cfaa334cbafb57a399a98fad8d3812ece460018f457...,f67ab10ad4e4c53121b6a5fe4da9c10ddee905b978d378...,7020fd7aaf1656dea5c0c0c3d4bb5a28ebf6243fba95d3...,Ideal,F,IF,Surat


## LIMPIAMOS LOS DATOS QUE NO NOS INTERESAN

### Dropeamos las columnas innecesarias

In [16]:
drop_data = union_data.drop(columns = ['index_id', 'city_id', 'cut_id', 'color_id', 'clarity_id'])
drop_data

Unnamed: 0,depth,table,x,y,z,price,carat,cut,color,clarity,city
0,62.4,58.0,6.83,6.79,4.25,4268,1.21,Premium,J,VS2,Dubai
1,61.6,58.0,6.40,6.35,3.93,3513,1.02,Premium,J,VS2,Dubai
2,62.3,58.0,5.86,5.80,3.63,1792,0.77,Premium,J,VS2,Dubai
3,59.6,60.0,7.58,7.48,4.49,7553,1.51,Premium,J,VS2,Dubai
4,60.2,62.0,5.40,5.33,3.23,1176,0.57,Premium,J,VS2,Dubai
...,...,...,...,...,...,...,...,...,...,...,...
40450,62.2,54.0,5.24,5.27,3.27,2729,0.54,Ideal,F,IF,Surat
40451,61.9,54.0,5.22,5.25,3.24,2802,0.53,Ideal,F,IF,Surat
40452,62.3,55.0,4.30,4.34,2.69,886,0.30,Ideal,F,IF,Surat
40453,60.9,55.0,4.15,4.23,2.55,768,0.26,Ideal,F,IF,Surat


### Comprobamos si hay valores nulos o valores que no nos interesen

In [17]:
#No hay valores nulos

drop_data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 40455 entries, 0 to 40454
Data columns (total 11 columns):
 #   Column   Non-Null Count  Dtype  
---  ------   --------------  -----  
 0   depth    40455 non-null  float64
 1   table    40455 non-null  float64
 2   x        40455 non-null  float64
 3   y        40455 non-null  float64
 4   z        40455 non-null  float64
 5   price    40455 non-null  int64  
 6   carat    40455 non-null  float64
 7   cut      40455 non-null  object 
 8   color    40455 non-null  object 
 9   clarity  40455 non-null  object 
 10  city     40455 non-null  object 
dtypes: float64(6), int64(1), object(4)
memory usage: 3.4+ MB


In [18]:
clean_data = drop_data

## Realizamos el encoding

Analizamos las features

In [19]:
cat_cols = ['cut', 'color', 'clarity', 'city']
diamantes_encoded = clean_data[cat_cols]
diamantes_encoded

Unnamed: 0,cut,color,clarity,city
0,Premium,J,VS2,Dubai
1,Premium,J,VS2,Dubai
2,Premium,J,VS2,Dubai
3,Premium,J,VS2,Dubai
4,Premium,J,VS2,Dubai
...,...,...,...,...
40450,Ideal,F,IF,Surat
40451,Ideal,F,IF,Surat
40452,Ideal,F,IF,Surat
40453,Ideal,F,IF,Surat


In [20]:
col_diamantes = list(diamantes_encoded.columns)
col_diamantes

['cut', 'color', 'clarity', 'city']

In [21]:
cat_diamantes = cat_var(diamantes_encoded, col_diamantes)
cat_diamantes

Unnamed: 0,categorical_variable,number_of_possible_values,values
0,city,13,"[Dubai, Luxembourg, New York City, Antwerp, Ma..."
1,clarity,8,"[VS2, VVS2, SI1, VS1, SI2, I1, VVS1, IF]"
2,color,7,"[J, E, I, G, D, H, F]"
3,cut,5,"[Premium, Very Good, Fair, Good, Ideal]"


In [22]:
le = LabelEncoder()

In [23]:
for column in ['cut', 'color', 'clarity', 'city']:
    clean_data[column] = le.fit_transform(clean_data[column])

In [24]:
clean_data

Unnamed: 0,depth,table,x,y,z,price,carat,cut,color,clarity,city
0,62.4,58.0,6.83,6.79,4.25,4268,1.21,3,6,5,2
1,61.6,58.0,6.40,6.35,3.93,3513,1.02,3,6,5,2
2,62.3,58.0,5.86,5.80,3.63,1792,0.77,3,6,5,2
3,59.6,60.0,7.58,7.48,4.49,7553,1.51,3,6,5,2
4,60.2,62.0,5.40,5.33,3.23,1176,0.57,3,6,5,2
...,...,...,...,...,...,...,...,...,...,...,...
40450,62.2,54.0,5.24,5.27,3.27,2729,0.54,2,2,1,10
40451,61.9,54.0,5.22,5.25,3.24,2802,0.53,2,2,1,10
40452,62.3,55.0,4.30,4.34,2.69,886,0.30,2,2,1,10
40453,60.9,55.0,4.15,4.23,2.55,768,0.26,2,2,1,10


In [25]:
column_to_move = 'price'
cols = [col for col in clean_data.columns if col != column_to_move] + [column_to_move]
final_data = clean_data[cols]
final_data

Unnamed: 0,depth,table,x,y,z,carat,cut,color,clarity,city,price
0,62.4,58.0,6.83,6.79,4.25,1.21,3,6,5,2,4268
1,61.6,58.0,6.40,6.35,3.93,1.02,3,6,5,2,3513
2,62.3,58.0,5.86,5.80,3.63,0.77,3,6,5,2,1792
3,59.6,60.0,7.58,7.48,4.49,1.51,3,6,5,2,7553
4,60.2,62.0,5.40,5.33,3.23,0.57,3,6,5,2,1176
...,...,...,...,...,...,...,...,...,...,...,...
40450,62.2,54.0,5.24,5.27,3.27,0.54,2,2,1,10,2729
40451,61.9,54.0,5.22,5.25,3.24,0.53,2,2,1,10,2802
40452,62.3,55.0,4.30,4.34,2.69,0.30,2,2,1,10,886
40453,60.9,55.0,4.15,4.23,2.55,0.26,2,2,1,10,768


In [26]:
corrM = final_data.corr()
corrM

Unnamed: 0,depth,table,x,y,z,carat,cut,color,clarity,city,price
depth,1.0,-0.293114,-0.026348,-0.030966,0.094655,0.026528,-0.196852,0.047988,-0.058557,-0.000178,-0.014864
table,-0.293114,1.0,0.196059,0.184673,0.155189,0.183392,0.153463,0.03112,-0.084253,-0.007166,0.130111
x,-0.026348,0.196059,1.0,0.973712,0.984876,0.975688,0.026544,0.272498,-0.228392,-0.000813,0.885848
y,-0.030966,0.184673,0.973712,1.0,0.964828,0.951667,0.032142,0.265611,-0.219984,0.000329,0.866163
z,0.094655,0.155189,0.984876,0.964828,1.0,0.96757,0.005101,0.275022,-0.230862,-0.001165,0.8745
carat,0.026528,0.183392,0.975688,0.951667,0.96757,1.0,0.021164,0.294027,-0.218085,-0.000283,0.921935
cut,-0.196852,0.153463,0.026544,0.032142,0.005101,0.021164,1.0,-0.000461,0.029184,0.001864,0.044885
color,0.047988,0.03112,0.272498,0.265611,0.275022,0.294027,-0.000461,1.0,-0.031686,0.005613,0.174855
clarity,-0.058557,-0.084253,-0.228392,-0.219984,-0.230862,-0.218085,0.029184,-0.031686,1.0,0.005767,-0.074228
city,-0.000178,-0.007166,-0.000813,0.000329,-0.001165,-0.000283,0.001864,0.005613,0.005767,1.0,-0.000127


In [27]:
final_data = final_data.drop(columns = ['city'])
final_data

Unnamed: 0,depth,table,x,y,z,carat,cut,color,clarity,price
0,62.4,58.0,6.83,6.79,4.25,1.21,3,6,5,4268
1,61.6,58.0,6.40,6.35,3.93,1.02,3,6,5,3513
2,62.3,58.0,5.86,5.80,3.63,0.77,3,6,5,1792
3,59.6,60.0,7.58,7.48,4.49,1.51,3,6,5,7553
4,60.2,62.0,5.40,5.33,3.23,0.57,3,6,5,1176
...,...,...,...,...,...,...,...,...,...,...
40450,62.2,54.0,5.24,5.27,3.27,0.54,2,2,1,2729
40451,61.9,54.0,5.22,5.25,3.24,0.53,2,2,1,2802
40452,62.3,55.0,4.30,4.34,2.69,0.30,2,2,1,886
40453,60.9,55.0,4.15,4.23,2.55,0.26,2,2,1,768


In [28]:
corrM = final_data.corr()
corrM

Unnamed: 0,depth,table,x,y,z,carat,cut,color,clarity,price
depth,1.0,-0.293114,-0.026348,-0.030966,0.094655,0.026528,-0.196852,0.047988,-0.058557,-0.014864
table,-0.293114,1.0,0.196059,0.184673,0.155189,0.183392,0.153463,0.03112,-0.084253,0.130111
x,-0.026348,0.196059,1.0,0.973712,0.984876,0.975688,0.026544,0.272498,-0.228392,0.885848
y,-0.030966,0.184673,0.973712,1.0,0.964828,0.951667,0.032142,0.265611,-0.219984,0.866163
z,0.094655,0.155189,0.984876,0.964828,1.0,0.96757,0.005101,0.275022,-0.230862,0.8745
carat,0.026528,0.183392,0.975688,0.951667,0.96757,1.0,0.021164,0.294027,-0.218085,0.921935
cut,-0.196852,0.153463,0.026544,0.032142,0.005101,0.021164,1.0,-0.000461,0.029184,0.044885
color,0.047988,0.03112,0.272498,0.265611,0.275022,0.294027,-0.000461,1.0,-0.031686,0.174855
clarity,-0.058557,-0.084253,-0.228392,-0.219984,-0.230862,-0.218085,0.029184,-0.031686,1.0,-0.074228
price,-0.014864,0.130111,0.885848,0.866163,0.8745,0.921935,0.044885,0.174855,-0.074228,1.0


In [29]:
final_data.describe()

Unnamed: 0,depth,table,x,y,z,carat,cut,color,clarity,price
count,40455.0,40455.0,40455.0,40455.0,40455.0,40455.0,40455.0,40455.0,40455.0,40455.0
mean,61.752841,57.446133,5.729392,5.732819,3.537154,0.797706,2.55254,2.599234,3.840143,3928.444469
std,1.431725,2.233535,1.124453,1.14665,0.697062,0.475544,1.028828,1.70126,1.725009,3992.416147
min,43.0,43.0,0.0,0.0,0.0,0.2,0.0,0.0,0.0,326.0
25%,61.0,56.0,4.71,4.72,2.91,0.4,2.0,1.0,2.0,945.0
50%,61.8,57.0,5.69,5.71,3.52,0.7,2.0,3.0,4.0,2397.0
75%,62.5,59.0,6.54,6.54,4.035,1.04,3.0,4.0,5.0,5331.0
max,79.0,95.0,10.23,58.9,8.06,4.5,4.0,6.0,7.0,18823.0


## Entrenamos al modelo usando el Gradient Boosting Regressor

In [30]:
X = final_data.drop(columns = 'price')
y = final_data.loc[:,'price']
print(X.shape,y.shape)

(40455, 9) (40455,)


In [31]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print(f"X_train: {X_train.shape}, X_test: {X_test.shape}, y_train: {y_train.shape}, y_test: {y_test.shape}")
print(f"X_train: {type(X_train)}, X_test: {type(X_test)}, y_train: {type(y_train)}, y_test: {type(y_test)}")

X_train: (32364, 9), X_test: (8091, 9), y_train: (32364,), y_test: (8091,)
X_train: <class 'pandas.core.frame.DataFrame'>, X_test: <class 'pandas.core.frame.DataFrame'>, y_train: <class 'pandas.core.series.Series'>, y_test: <class 'pandas.core.series.Series'>


In [32]:
rfr = RandomForestRegressor()

In [33]:
rfr.fit(X_train, y_train)

In [34]:
y_pred = rfr.predict(X_test)

In [35]:
rmse = mean_squared_error(y_test, y_pred, squared=False)
rmse

595.6631962374004

In [36]:
param_grid = {'n_estimators': [100, 2500],
              'max_depth': [None, 3, 10],
              'min_samples_split': [2, 10],
              'min_samples_leaf': [1, 7],
              'max_features': [None, 'sqrt', 'log2']
              }

In [37]:
grid_search = GridSearchCV(rfr,
                           param_grid,
                           cv=5,
                           verbose=4,
                           scoring='neg_root_mean_squared_error',
                           n_jobs=-1)

In [38]:
grid_search.fit(X,y)

Fitting 5 folds for each of 72 candidates, totalling 360 fits
[CV 2/5] END max_depth=None, max_features=None, min_samples_leaf=1, min_samples_split=2, n_estimators=100;, score=-537.813 total time=  21.0s
[CV 3/5] END max_depth=None, max_features=None, min_samples_leaf=1, min_samples_split=2, n_estimators=2500;, score=-584.820 total time= 8.5min
[CV 5/5] END max_depth=None, max_features=None, min_samples_leaf=1, min_samples_split=2, n_estimators=2500;, score=-556.721 total time= 8.2min
[CV 2/5] END max_depth=None, max_features=None, min_samples_leaf=7, min_samples_split=2, n_estimators=2500;, score=-553.290 total time= 5.2min
[CV 1/5] END max_depth=None, max_features=None, min_samples_leaf=7, min_samples_split=10, n_estimators=100;, score=-585.332 total time=  11.3s
[CV 2/5] END max_depth=None, max_features=None, min_samples_leaf=7, min_samples_split=10, n_estimators=100;, score=-553.767 total time=  12.9s
[CV 3/5] END max_depth=None, max_features=None, min_samples_leaf=7, min_samples_s

In [39]:
y_pred = grid_search.predict(X_test)
y_pred

array([8734.5852    ,  500.7072    , 8004.172     , ..., 8922.8268    ,
       3623.196     , 4662.40696429])

In [40]:
rmse = mean_squared_error(y_test, y_pred, squared=False)
rmse

207.9637052262315

In [41]:
rmse = mean_squared_error(y_test, y_pred, squared=False)
rmse

207.9637052262315

In [42]:
filename = '../modelos/RandomForestRegressor(prueba 3).sav'
pickle.dump(grid_search, open(filename, 'wb'))