# Testing the BoolFunction class for probabilistic Boolean function learning

In [1]:
import babool as bb
import numpy as np
import pandas as pd

## Reading cat-in-the-dat dataset

In [2]:
# Load the training table, using the 'id' column as the row index.
dfcat = pd.read_csv(
    'data/cat-in-the-dat/train.csv',
    index_col='id',
)

In [3]:
# Show the (n_rows, n_columns) shape of the raw frame.
dfcat.shape

(300000, 24)

In [4]:
# Expand dfcat into dummy/indicator columns using pandas' default
# get_dummies settings; the result is the frame all later cells work on.
dfb = pd.get_dummies(dfcat)

In [5]:
# Show the (n_rows, n_columns) shape after the dummy expansion.
dfb.shape

(300000, 16440)

In [8]:
# Drop the reference to the raw frame; the cells below only use dfb.
del dfcat

## Probabilistic Boolean function learning

### Create the model object

In [6]:
# Instantiate a BoolFunction learner.
# NOTE(review): the meaning of pgeomm and theta is defined inside babool and
# is not visible from this notebook -- confirm against the library.
model = bb.BoolFunction(
    pgeomm=0.5,
    theta=10,
)

### Train model

In [7]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve, auc

In [None]:
# Labels come from the 'target' column; every remaining column is a feature.
y = dfb['target'].values
M = dfb.drop(columns='target').values

In [14]:

# Stratified 50/50 split of features and labels, with a fixed seed so the
# partition is the same on every run.
Xtrain, Xtest, ytrain, ytest = train_test_split(
    M,
    y,
    test_size=0.5,
    stratify=y,
    random_state=1910,
)

In [72]:
%%time
# Settings passed positionally to model.fit below.
# NOTE(review): their exact semantics (number of chains, parallel jobs, steps,
# starting step) are inferred from the names only -- confirm against babool.
nchains = 20
njobs = -1
nsteps = 10000
nstart = 1

# Hyperparameter grid; it currently holds a single (theta, pgeo) point.
thetas = [5]
pgeos = [0.5]


# Grid evaluation on one held-out split (fit on Xtrain, score on Xtest);
# note that no k-fold cross-validation is performed here.
res = []   # one [theta, pgeo, roc_auc] row per evaluated pair
psis = []  # model.psi of every fitted model, in the same order as res
for theta in thetas:
    for pgeo in pgeos:
        model = bb.BoolFunction(pgeomm = pgeo, theta = theta)
        _ = model.fit(Xtrain, ytrain, nchains, njobs, nsteps, nstart )
        # presumably binary = False yields continuous scores instead of hard labels -- TODO confirm
        ypred = model.predict(Xtest, binary = False)

        # Area under the ROC curve of the predictions on the held-out half.
        false_positive_rate, true_positive_rate, thresholds = roc_curve(ytest, ypred)
        roc_auc = auc(false_positive_rate, true_positive_rate)
        res.append([theta, pgeo, roc_auc])
        psis.append(model.psi)
# After the loop, `model` stays bound to the last fitted model.

CPU times: user 1.41 s, sys: 46.9 ms, total: 1.46 s
Wall time: 1min 50s


In [54]:
import pickle

# Persist the collected [theta, pgeo, auc] rows to disk.
with open('mush2.pkl', mode='wb') as scores_file:
    pickle.dump(res, scores_file)

In [59]:
import pickle

# Reload the saved score rows into `res`, replacing whatever `res` holds now.
# NOTE(review): pickle.load can execute arbitrary code -- only open files that
# this notebook itself wrote.
with open("mush2.pkl", mode='rb') as saved_scores:
    res = pickle.load(saved_scores)

In [73]:
# Tabulate the scores: one row per evaluated (theta, pgeo) pair, with each
# column taken from the matching position of every row in res.
dfres = pd.DataFrame({
    label: [row[position] for row in res]
    for position, label in enumerate(('theta', 'pgeo', 'auc'))
})
dfres

Unnamed: 0,theta,pgeo,auc
0,5,0.5,0.998979


Converting the function $\psi$ into a logical classification rule

In [74]:
# Render model.psi as a readable rule: each element of psi is a group of
# feature indices joined with AND, and the groups are joined with OR.
#
# The indices are looked up in dfb's columns *without* 'target', because the
# feature matrix M was built from dfb.drop('target', axis = 1). Indexing
# dfb.columns[v + 1] instead is only right when 'target' is the first column.
feature_names = dfb.columns.drop('target')

# str.join places the separators only between items, so no trailing
# " AND " / ") OR (" has to be sliced off, and an empty clause or an empty psi
# cannot corrupt the result.
expr = ' OR '.join(
    '(' + ' AND '.join(feature_names[v] for v in m) + ')'
    for m in model.psi
)
expr

'(spore-print-color_r) OR (odor_f) OR (gill-color_b) OR (odor_c) OR (odor_p) OR (odor_n AND stalk-surface-below-ring_y AND ring-type_e) OR (stalk-color-below-ring_c AND veil-color_w)'