**Piotr Duszak - classifying whether a mushroom is poisonous or edible**

I worked on the problem of classifying whether a mushroom is edible or poisonous (one of the problems listed in the repository: https://github.com/pbiecek/InterpretableMachineLearning2020/issues/5). Each mushroom is described by a set of attributes covering, among other things, the cap shape, the cap color, the ring on the stalk, and the gills.

I used a neural network, implemented in Python with the Keras library.

In [None]:
import pandas as pd
import numpy as np
# Load the data and display the first and last few rows
shrooms_data = pd.read_csv('mushrooms.csv')
shrooms_data

Unnamed: 0,class,cap-shape,cap-surface,cap-color,bruises,odor,gill-attachment,gill-spacing,gill-size,gill-color,stalk-shape,stalk-root,stalk-surface-above-ring,stalk-surface-below-ring,stalk-color-above-ring,stalk-color-below-ring,veil-type,veil-color,ring-number,ring-type,spore-print-color,population,habitat
0,p,x,s,n,t,p,f,c,n,k,e,e,s,s,w,w,p,w,o,p,k,s,u
1,e,x,s,y,t,a,f,c,b,k,e,c,s,s,w,w,p,w,o,p,n,n,g
2,e,b,s,w,t,l,f,c,b,n,e,c,s,s,w,w,p,w,o,p,n,n,m
3,p,x,y,w,t,p,f,c,n,n,e,e,s,s,w,w,p,w,o,p,k,s,u
4,e,x,s,g,f,n,f,w,b,k,t,e,s,s,w,w,p,w,o,e,n,a,g
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
8119,e,k,s,n,f,n,a,c,b,y,e,?,s,s,o,o,p,o,o,p,b,c,l
8120,e,x,s,n,f,n,a,c,b,y,e,?,s,s,o,o,p,n,o,p,b,v,l
8121,e,f,s,n,f,n,a,c,b,n,e,?,s,s,o,o,p,o,o,p,b,c,l
8122,p,k,y,n,f,y,f,c,n,b,t,?,s,k,w,w,p,w,o,e,w,v,l


As shown above, the data set has 8124 rows. Each record (row) holds the label saying whether the mushroom is edible, followed by a series of attributes. All attributes are categorical (there is no numeric data).

In [None]:
# Display the row count, number of distinct values, most frequent value, and its frequency for each column
shrooms_data.describe()

Unnamed: 0,class,cap-shape,cap-surface,cap-color,bruises,odor,gill-attachment,gill-spacing,gill-size,gill-color,stalk-shape,stalk-root,stalk-surface-above-ring,stalk-surface-below-ring,stalk-color-above-ring,stalk-color-below-ring,veil-type,veil-color,ring-number,ring-type,spore-print-color,population,habitat
count,8124,8124,8124,8124,8124,8124,8124,8124,8124,8124,8124,8124,8124,8124,8124,8124,8124,8124,8124,8124,8124,8124,8124
unique,2,6,4,10,2,9,2,2,2,12,2,5,4,4,9,9,1,4,3,5,9,6,7
top,e,x,y,n,f,n,f,c,b,b,t,b,s,s,w,w,p,w,o,p,w,v,d
freq,4208,3656,3244,2284,4748,3528,7914,6812,5612,1728,4608,3776,5176,4936,4464,4384,8124,7924,7488,3968,2388,4040,3148


The attributes differ in how many categories they can take. For example, there are 6 cap shapes but as many as 10 cap colors. In this form, however, the attributes cannot be used as input to a neural network. To make them usable they have to be converted to a numeric encoding or, better still, to a one-hot encoding.

In [None]:
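# Convert every feature column to a one-hot encoding
for column in shrooms_data.columns[1:]:  # skip the first column ('class')
  for vals in shrooms_data[column].unique():
    # New indicator column named '<feature>_<value>': 1 where the value matches, 0 elsewhere
    shrooms_data.loc[shrooms_data[column]==vals , '_'.join([column, vals])] = 1
    shrooms_data.loc[shrooms_data[column]!=vals , '_'.join([column, vals])] = 0
  # Drop the original categorical column once it has been encoded
  shrooms_data.drop(column, axis=1, inplace=True)
# Encode the target: 1 = edible ('e'), 0 = poisonous ('p')
shrooms_data.loc[shrooms_data['class']=='e', 'class'] = 1
shrooms_data.loc[shrooms_data['class']=='p', 'class'] = 0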
shrooms_data

Unnamed: 0,class,cap-shape_x,cap-shape_b,cap-shape_s,cap-shape_f,cap-shape_k,cap-shape_c,cap-surface_s,cap-surface_y,cap-surface_f,cap-surface_g,cap-color_n,cap-color_y,cap-color_w,cap-color_g,cap-color_e,cap-color_p,cap-color_b,cap-color_u,cap-color_c,cap-color_r,bruises_t,bruises_f,odor_p,odor_a,odor_l,odor_n,odor_f,odor_c,odor_y,odor_s,odor_m,gill-attachment_f,gill-attachment_a,gill-spacing_c,gill-spacing_w,gill-size_n,gill-size_b,gill-color_k,gill-color_n,...,stalk-color-below-ring_n,stalk-color-below-ring_e,stalk-color-below-ring_y,stalk-color-below-ring_o,stalk-color-below-ring_c,veil-type_p,veil-color_w,veil-color_n,veil-color_o,veil-color_y,ring-number_o,ring-number_t,ring-number_n,ring-type_p,ring-type_e,ring-type_l,ring-type_f,ring-type_n,spore-print-color_k,spore-print-color_n,spore-print-color_u,spore-print-color_h,spore-print-color_w,spore-print-color_r,spore-print-color_o,spore-print-color_y,spore-print-color_b,population_s,population_n,population_a,population_v,population_y,population_c,habitat_u,habitat_g,habitat_m,habitat_d,habitat_p,habitat_w,habitat_l
0,0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,...,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
1,1,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0,...,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0
2,1,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0,0.0,1.0,...,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0
3,0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0,...,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
4,1,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,1.0,1.0,0.0,...,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
8119,1,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,...,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
8120,1,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,...,0.0,0.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
8121,1,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,1.0,0.0,1.0,...,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
8122,0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0


After converting the data to a one-hot encoding, each mushroom has 117 input features and 1 output value indicating whether the mushroom is edible or poisonous.
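The same kind of encoding can also be sketched with `pandas.get_dummies`, which is an alternative to the explicit loop used in this notebook, not the method used here. The small `toy` table below is a hypothetical stand-in for the real data; the sketch also shows why the number of one-hot columns equals the sum of the distinct values per feature (for the real data, the `unique` row of the `describe()` output sums to 117).

```python
import pandas as pd

# Toy stand-in for the mushroom table (hypothetical values, not the real data)
toy = pd.DataFrame({
    "class": ["p", "e", "e", "p"],
    "cap-shape": ["x", "x", "b", "k"],
    "bruises": ["t", "t", "f", "f"],
})

# Expected number of one-hot columns = sum of distinct values per feature
feature_columns = [c for c in toy.columns if c != "class"]
expected_width = sum(toy[c].nunique() for c in feature_columns)

# One indicator column per (feature, value) pair, named "<feature>_<value>"
features = pd.get_dummies(toy[feature_columns], prefix_sep="_").astype(float)

# Binary target: 1 = edible, 0 = poisonous, as in the notebook
target = (toy["class"] == "e").astype(int)

print(features.shape[1], expected_width)
```

Note that `get_dummies` orders the indicator columns by sorted category value, so the column order can differ from the order produced by the loop, which follows first appearance in the data.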

In [None]:
# Convert to a NumPy array, shuffle the rows, and split into training, validation, and test sets
data = shrooms_data.to_numpy()
np.random.shuffle(data)
train, val, test = data[:4062], data[4062:6093], data[6093:]
train_y, train_x = train[:, 0].astype(int), train[:, 1:].astype(float)
val_y, val_x = val[:, 0].astype(int), val[:, 1:].astype(float)
test_y, test_x = test[:, 0].astype(int), test[:, 1:].astype(float)

The data were split into training, validation, and test sets (4062, 2031, and 2031 rows), and each set was further separated into inputs and outputs.
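The split above depends on an unseeded shuffle, so the exact rows in each set change from run to run. A minimal sketch of a reproducible variant is shown below; the helper `split_rows`, its default fractions, and the seed are assumptions made for illustration and are not part of the original notebook.

```python
import numpy as np

def split_rows(data, train_frac=0.5, val_frac=0.25, seed=0):
    """Shuffle rows with a fixed seed and cut them into train/val/test parts."""
    rng = np.random.default_rng(seed)
    shuffled = data[rng.permutation(len(data))]
    n_train = int(len(data) * train_frac)
    n_val = int(len(data) * val_frac)
    train = shuffled[:n_train]
    val = shuffled[n_train:n_train + n_val]
    test = shuffled[n_train + n_val:]  # whatever remains after train and val
    return train, val, test

# Dummy array with the same number of rows as the mushroom data
demo = np.arange(8124 * 2).reshape(8124, 2)
train, val, test = split_rows(demo)
print(len(train), len(val), len(test))  # 4062 2031 2031
```

With fractions of 0.5 and 0.25 the sizes match the slice boundaries used in the notebook (4062 and 6093).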


The model is a neural network with three hidden layers (64, 128, and 64 neurons). The input layer has 117 neurons and the output layer has 1. All layers except the last use the ReLU activation; the last one uses a sigmoid and outputs the probability that the mushroom is edible (class 1). In the code, each hidden layer is followed by dropout with a rate of 0.5.
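As a quick size check for the architecture described here, the number of trainable parameters can be worked out by hand: a dense layer with `n_in` inputs and `n_out` neurons has `n_in * n_out` weights plus `n_out` biases, and dropout layers add no parameters. The helper name `dense_param_count` is introduced only for this sketch.

```python
# Layer widths: input, three hidden layers, output
layer_sizes = [117, 64, 128, 64, 1]

def dense_param_count(n_in, n_out):
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out

# Pair each layer with the next one and count the parameters between them
per_layer = [dense_param_count(a, b) for a, b in zip(layer_sizes, layer_sizes[1:])]
total = sum(per_layer)

print(per_layer)  # [7552, 8320, 8256, 65]
print(total)      # 24193
```

The network is therefore small, roughly 24 thousand parameters for about 4 thousand training rows.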

In [None]:
# Build the model
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.optimizers import SGD

model = Sequential()
model.add(Dense(64, activation='relu', input_shape=(117,)))
model.add(Dropout(0.5))
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(64, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(1, activation='sigmoid'))

# Train the network
model.compile(optimizer=SGD(0.01), loss='binary_crossentropy', metrics=['acc'])
model.fit(train_x, train_y, batch_size=8, epochs=10, validation_data=(val_x, val_y))

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x7f3033b887d0>

In [None]:
# Evaluate on the training set (returns [loss, accuracy])
model.evaluate(train_x, train_y)



[0.0016603524563834071, 1.0]

In [None]:
# Evaluate on the validation set (returns [loss, accuracy])
model.evaluate(val_x, val_y)



[0.0023846866097301245, 0.9995076060295105]

In [None]:
# Evaluate on the test set (returns [loss, accuracy])
model.evaluate(test_x, test_y)



[0.0015886153560131788, 1.0]

The final results of the classifier are:
* 100% accuracy on the training set
* 99.95% accuracy on the validation set
* 100% accuracy on the test set

The classifier can therefore be considered very good on this data set.
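To put the accuracy figures in concrete terms, the sketch below converts an accuracy into a number of misclassified rows and, separately, shows one way to count the two kinds of error from predicted probabilities. The helpers `misclassified` and `confusion_counts`, the 0.5 threshold, and the four demo predictions are assumptions for illustration; the notebook itself reports accuracy only.

```python
import numpy as np

def misclassified(accuracy, n_samples):
    """Number of wrong predictions implied by an accuracy on n_samples rows."""
    return int(round((1.0 - accuracy) * n_samples))

def confusion_counts(y_true, y_prob, threshold=0.5):
    """Count outcomes, treating class 1 (edible) as the positive class."""
    y_true = np.asarray(y_true)
    y_pred = (np.asarray(y_prob) >= threshold).astype(int)
    return {
        "true_edible": int(np.sum((y_true == 1) & (y_pred == 1))),
        "true_poisonous": int(np.sum((y_true == 0) & (y_pred == 0))),
        "poisonous_called_edible": int(np.sum((y_true == 0) & (y_pred == 1))),
        "edible_called_poisonous": int(np.sum((y_true == 1) & (y_pred == 0))),
    }

# Validation accuracy reported above, on the 2031 validation rows
val_errors = misclassified(0.9995076060295105, 2031)
print(val_errors)  # 1

# Hypothetical labels and sigmoid outputs, only to show the counting
demo_counts = confusion_counts([1, 0, 1, 0], [0.98, 0.03, 0.40, 0.70])
print(demo_counts)
```

The reported validation accuracy thus corresponds to a single misclassified mushroom. With the real model, `confusion_counts(val_y, model.predict(val_x).ravel())` would show which kind of error that one case is.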