#### importing

In [73]:
from kad import KAD
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from combinations import elems_comp

#### preprocessing

In [64]:
shrooms = pd.read_csv('dataset/mushrooms.csv', usecols=['class', 'cap-shape','cap-surface', 'cap-color', 'gill-size', 'stalk-shape', 'veil-type', 'population'])
shrooms_sample = shrooms.groupby('class').sample(n=500, random_state=123)
shrooms_sample

Unnamed: 0,class,cap-shape,cap-surface,cap-color,gill-size,stalk-shape,veil-type,population
4831,e,k,y,e,b,e,p,c
605,e,b,s,w,b,e,p,n
6904,e,x,y,c,b,e,p,y
3409,e,f,y,n,b,t,p,v
1661,e,x,f,g,b,t,p,a
...,...,...,...,...,...,...,...,...
5131,p,b,f,y,n,e,p,v
6365,p,x,s,n,n,t,p,v
3270,p,x,f,p,n,e,p,s
6701,p,x,s,e,n,t,p,v


In [65]:
X = shrooms_sample[shrooms_sample.columns.drop('class')]
y = shrooms_sample[['class']]

In [66]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=123, stratify=y)

#### training

In [68]:
kad = KAD(X_train, y_train)
kad.process()

Processing combinations: 100%|██████████| 7/7 [00:15<00:00,  2.28s/it]


In [69]:
kad.rules

[([None, None, None, 'n', None, None, None],
  'p',
  np.float64(0.9099099099099099)),
 (['x', None, None, 'b', None, None, None],
  'e',
  np.float64(0.7160493827160493)),
 ([None, None, None, 'b', 't', None, None],
  'e',
  np.float64(0.9248826291079812)),
 ([None, None, None, 'n', 't', None, None],
  'p',
  np.float64(0.9337349397590361)),
 ([None, None, None, 'b', 't', 'p', None],
  'e',
  np.float64(0.9248826291079812)),
 ([None, None, None, 'n', 't', 'p', None],
  'p',
  np.float64(0.9337349397590361)),
 ([None, None, None, 'n', 't', None, 'v'],
  'p',
  np.float64(0.9337349397590361)),
 ([None, None, None, 'n', 't', 'p', 'v'], 'p', np.float64(0.9337349397590361))]

#### testing

In [70]:
y_pred = []
kad.rules.sort(key=lambda x: x[-1], reverse=True)
for elem in np.array(X_test):
    for rule in kad.rules:
        if elems_comp(elem, rule[0]):
            y_pred.append(rule[1])
            break

In [71]:
n = len(y_pred)

In [74]:
print(confusion_matrix(y_test[:n], y_pred))
print(classification_report(y_test[:n], y_pred))

[[70 52]
 [78 49]]
              precision    recall  f1-score   support

           e       0.47      0.57      0.52       122
           p       0.49      0.39      0.43       127

    accuracy                           0.48       249
   macro avg       0.48      0.48      0.47       249
weighted avg       0.48      0.48      0.47       249

