In [52]:
import itertools
import joblib
import json
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix, f1_score
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import OrdinalEncoder
from sklearn.impute import SimpleImputer
import yaml
import gdown
import os
import pprint
import zipfile
from sklearn.metrics import make_scorer 
# https://drive.google.com/file/d/1v5PlKhhafsmWEvRUBj4Zpo2kkNDDR4HY/view?usp=sharing

In [7]:
%cd ..

c:\Users\Paulo\Documents\Projects\poisonous_mushroom


In [30]:
with open('params.yaml') as conf_file:
    config = yaml.safe_load(conf_file)

pprint.pprint(config)

{'base': {'random_state': 42},
 'data_decompress': {'processed_path': 'data/processed'},
 'data_load': {'compact_name': 'mushroom.zip',
               'file_id': '1v5PlKhhafsmWEvRUBj4Zpo2kkNDDR4HY',
               'raw_path': 'data/raw'},
 'train': {'train_file': 'data/processed/train.csv'}}


In [28]:
# raw_data_path = config['data_load']['raw_path']
# filename = config['data_load']['compact_name']
# file_id = config['data_load']['file_id']


def download_compact_file(config_path: str) -> None:
    with open(config_path) as conf_file:
        config = yaml.safe_load(conf_file)
    raw_data_path = config['data_load']['raw_path']
    filename = config['data_load']['compact_name']
    file_id = config['data_load']['file_id']
    download_url = f'https://drive.google.com/uc?id={file_id}'
    output = os.path.join(raw_data_path, filename)
    print(f'Downloading to {output}')
    gdown.download(download_url, output, quiet=False)



download_compact_file('params.yaml')

Downloading to data/raw\mushroom.zip


Downloading...
From (original): https://drive.google.com/uc?id=1v5PlKhhafsmWEvRUBj4Zpo2kkNDDR4HY
From (redirected): https://drive.google.com/uc?id=1v5PlKhhafsmWEvRUBj4Zpo2kkNDDR4HY&confirm=t&uuid=4d3396ee-b9d9-41e0-9457-7ba1331cd785
To: c:\Users\Paulo\Documents\Projects\poisonous_mushroom\data\raw\mushroom.zip
100%|██████████| 86.3M/86.3M [00:15<00:00, 5.54MB/s]


In [29]:
def extract_datasets(config_path: str) -> None:
    raw_data_path = config['data_load']['raw_path']
    filename = config['data_load']['compact_name']
    zip_file_path = os.path.join(raw_data_path, filename)
    processed_path = config['data_decompress']['processed_path']
    os.makedirs(processed_path, exist_ok=True)
    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
        zip_ref.extractall(processed_path)


extract_datasets('params.yaml')

In [61]:
with open('params.yaml') as conf_file:
    config = yaml.safe_load(conf_file)

train_file = config['train']['train_file']
id_column = config['metadata']['id_col']
target_column = config['metadata']['target_col']

train_data = pd.read_csv(train_file, index_col=id_column)
train_data.drop(columns=['stem-height', 'stem-width', 'cap-diameter'], inplace=True)

train_data[target_column] = train_data[target_column].replace({'e':0.0, 'p':1.0})

x_train, x_val, y_train, y_val = train_test_split(
    train_data.drop(target_column, axis=1), train_data[target_column],
    test_size=config['metadata']['test_size'],
    random_state=config['base']['random_state']
)

imputer = SimpleImputer(strategy='most_frequent')
x_train = imputer.fit_transform(x_train)
x_val = imputer.transform(x_val)

ordinal_encoding = OrdinalEncoder(handle_unknown='use_encoded_value', unknown_value=-1)
x_train = ordinal_encoding.fit_transform(x_train)
x_val = ordinal_encoding.transform(x_val)


rfc = RandomForestClassifier(
    n_estimators=config['train']['random_forest']['n_estimators'],
    max_depth=config['train']['random_forest']['max_depth'],
    random_state=config['base']['random_state']
)

rfc.fit(x_train, y_train)

y_pred_train = rfc.predict(x_train)
y_pred = rfc.predict(x_val)


print(f'Train f1-Score: {f1_score(y_train, y_pred_train)}')
print(f'Validation f1-Score: {f1_score(y_val, y_pred)}')




    










Train f1-Score: 0.9500045715055652
Validation f1-Score: 0.9495553815491462


In [57]:
x_val

array([['s', 't', 'r', ..., 'k', 'd', 'a'],
       ['o', 't', 'g', ..., 'k', 'd', 's'],
       ['f', 't', 'n', ..., 'k', 'd', 'a'],
       ...,
       ['x', 's', 'n', ..., 'k', 'd', 'a'],
       ['x', 't', 'y', ..., 'k', 'd', 'u'],
       ['b', 'h', 'n', ..., 'k', 'd', 'a']], dtype=object)

In [58]:
train_data.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 3116945 entries, 0 to 3116944
Data columns (total 18 columns):
 #   Column                Dtype 
---  ------                ----- 
 0   class                 object
 1   cap-shape             object
 2   cap-surface           object
 3   cap-color             object
 4   does-bruise-or-bleed  object
 5   gill-attachment       object
 6   gill-spacing          object
 7   gill-color            object
 8   stem-root             object
 9   stem-surface          object
 10  stem-color            object
 11  veil-type             object
 12  veil-color            object
 13  has-ring              object
 14  ring-type             object
 15  spore-print-color     object
 16  habitat               object
 17  season                object
dtypes: object(18)
memory usage: 451.8+ MB


In [39]:
train_data.head()

Unnamed: 0_level_0,class,cap-diameter,cap-shape,cap-surface,cap-color,does-bruise-or-bleed,gill-attachment,gill-spacing,gill-color,stem-height,...,stem-root,stem-surface,stem-color,veil-type,veil-color,has-ring,ring-type,spore-print-color,habitat,season
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,e,8.8,f,s,u,f,a,c,w,4.51,...,,,w,,,f,f,,d,a
1,p,4.51,x,h,o,f,a,c,n,4.79,...,,y,o,,,t,z,,d,w
2,e,6.94,f,s,b,f,x,c,w,6.85,...,,s,n,,,f,f,,l,w
3,e,3.88,f,y,g,f,s,,g,4.16,...,,,w,,,f,f,,d,u
4,e,5.85,x,l,w,f,d,,w,3.37,...,,,w,,,f,f,,g,a
