In [13]:
# Fetch the red and white wine-quality CSV files from the UCI archive
# (-q silences wget's progress output); the loading cell reads them by file name.
!wget -q https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-red.csv
!wget -q https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-white.csv

In [1]:
from spn.algorithms.Inference import log_likelihood
from spn.algorithms.LearningWrappers import learn_parametric
from spn.structure.Base import Context
from spn.structure.leaves.parametric.Parametric import Categorical, Gaussian, Bernoulli
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd

def load_wine_csv(path):
    """Read one semicolon-separated wine-quality CSV.

    Spaces in the column names are replaced by underscores so that every
    column name is a single token (e.g. "fixed acidity" -> "fixed_acidity").

    Parameters
    ----------
    path : str
        Location of the CSV file to read.

    Returns
    -------
    pandas.DataFrame
        The file contents with the renamed columns.
    """
    wines = pd.read_csv(path, sep=";")
    return wines.rename(columns=lambda column: column.replace(" ", "_"))


# Both files go through the same loader instead of two copy-pasted stanzas.
red = load_wine_csv("winequality-red.csv")
white = load_wine_csv("winequality-white.csv")
print("%d red wines and %d white wines" % (len(red), len(white)))

# Stack red on top of white; ignore_index gives the combined frame a fresh 0..n-1 index.
frame = pd.concat([red, white], ignore_index=True)
print("%d rows" % len(frame))
print(frame.shape)
frame.head()

1599 red wines and 4898 white wines
6497 rows
(6497, 12)


Unnamed: 0,fixed_acidity,volatile_acidity,citric_acid,residual_sugar,chlorides,free_sulfur_dioxide,total_sulfur_dioxide,density,pH,sulphates,alcohol,quality
0,7.4,0.7,0.0,1.9,0.076,11.0,34.0,0.9978,3.51,0.56,9.4,5
1,7.8,0.88,0.0,2.6,0.098,25.0,67.0,0.9968,3.2,0.68,9.8,5
2,7.8,0.76,0.04,2.3,0.092,15.0,54.0,0.997,3.26,0.65,9.8,5
3,11.2,0.28,0.56,1.9,0.075,17.0,60.0,0.998,3.16,0.58,9.8,6
4,7.4,0.7,0.0,1.9,0.076,11.0,34.0,0.9978,3.51,0.56,9.4,5


In [2]:
# Column names double as the feature names used when annotating the SPN later on.
features = frame.columns.tolist()
n_features = len(features)

# Every column is given a Gaussian leaf type and none is flagged as categorical.
ptypes = [Gaussian for _ in range(n_features)]
categorical = [False for _ in range(n_features)]

# Plain float matrix: one row per wine, one column per feature.
data = frame.to_numpy().astype(float)

# Fixed random_state keeps the split reproducible across runs.
train, test = train_test_split(data, random_state=0)
print("%d in train, %d in test" % (len(train), len(test)))

4872 in train, 1625 in test


In [3]:
# Describe the leaf types and let the context derive the value domains from the training rows.
ds_context = Context(parametric_types=ptypes).add_domains(train)

# min_instances_slice is 1% of the training rows; rows="gmm" selects the row-splitting
# method (presumably Gaussian-mixture clustering -- confirm against the SPFlow docs).
net = learn_parametric(
    train,
    ds_context=ds_context,
    rows="gmm",
    min_instances_slice=len(train) / 100,
)



In [28]:
# csi2 supplies the SPN annotation and rule-extraction helpers used in this cell.
import csi2
from spn.structure.Base import get_nodes_by_type, Product


import importlib
# Reload csi2 so that changes made to the module since it was first imported
# take effect without restarting the kernel.
importlib.reload(csi2)

def format_condition(condition):
    """Render a split condition with its threshold fixed to four decimals.

    Parameters
    ----------
    condition : str
        Text of the form ``"<name> <sign> <number>"`` with exactly one space
        between the three parts, e.g. ``"chlorides <= 0.0585"``.

    Returns
    -------
    str
        ``"<name> <sign> <number>"`` with the number formatted as ``%.4f``.

    Raises
    ------
    ValueError
        If ``condition`` is missing (``None`` or empty) or does not split into
        exactly three space-separated parts. A non-numeric threshold also
        raises ``ValueError`` (from ``float``), which is propagated unchanged.
    """
    message = "Some nodes don't have conditions. reduce min_impurity_decrease"
    # A missing condition would otherwise surface as an opaque AttributeError
    # from ``None.split``; report it with the same actionable message instead.
    if not condition:
        raise ValueError(message)
    try:
        a, sign, b = condition.split(" ")
    except ValueError as e:
        # Chain the original unpacking error so the traceback keeps its cause.
        raise ValueError(message) from e
    return "%s %s %.4f" % (a, sign, float(b))

def summarize_rules(rules):
    """Format extracted rules and collect their antecedent/consequent counts.

    Each rule is unpacked as ``(antecedent, consequent, *scores)``. The scores
    fill the ``%.2f, %.2f, %d`` slots of the output string, so exactly three
    are expected (presumably precision, recall and instance count -- confirm
    against csi2).

    Returns
    -------
    tuple of (list of str, list, list)
        The formatted rule strings, the antecedent counts and the consequent
        counts, as three parallel lists with one entry per rule.
    """
    formatted = []
    antecedent_counts = []
    consequent_counts = []
    for antecedent, consequent, *scores in rules:
        A = csi2.format_antecedent(antecedent, format_condition)
        C = csi2.format_consequent(consequent)
        a_count = csi2.antecedent_count(A)
        c_count = csi2.consequent_count(C)
        antecedent_counts.append(a_count)
        consequent_counts.append(c_count)
        formatted.append("{%s} => {%s} | %.2f, %.2f, %d | %d %d" % (A, C, *scores, a_count, c_count))
    return formatted, antecedent_counts, consequent_counts


def _mean_or_nan(values):
    """Mean of ``values``; NaN for an empty list without numpy's empty-slice warning."""
    return np.mean(values) if len(values) else float("nan")


print('#product nodes = ', len(get_nodes_by_type(net, (Product))))
names = features
csi2.annotate_spn(net, names, categorical,
                  min_impurity_decrease=0.05, max_depth=2)

# Pass 1: instance_threshold=0, no precision/recall thresholds given.
# (The original cell issued this exact call twice in a row and discarded the first result.)
rules = csi2.context_specific_independences(net, instance_threshold=0)
csis, ac, cc = summarize_rules(rules)
print("%d, %.2f, %.2f" % (len(csis), _mean_or_nan(ac), _mean_or_nan(cc)))

# Pass 2: instance threshold of 5% of the training rows plus precision/recall
# thresholds of 0.7. `csis` from this pass is what the next cell displays.
rules = csi2.context_specific_independences(net, instance_threshold=len(train) / 20,
                                            precision_threshold=0.7, recall_threshold=0.7)
csis, ac, cc = summarize_rules(rules)
print("%d, %.2f, %.2f" % (len(csis), _mean_or_nan(ac), _mean_or_nan(cc)))

#product nodes =  236
236, 12.45, 6.76
5, 3.60, 2.60


In [24]:
# Show the formatted rule strings left in `csis` by the most recent rule-extraction run.
csis

['{chlorides <= 0.0585} => {(fixed_acidity,citric_acid,residual_sugar,chlorides,free_sulfur_dioxide,total_sulfur_dioxide,density,pH,alcohol,quality), (volatile_acidity), (sulphates)} | 0.94, 0.95, 3411 | 1 3',
 '{[chlorides > 0.0585] & [(total_sulfur_dioxide <= 72.5000) & (residual_sugar <= 3.3500)]} => {(fixed_acidity,volatile_acidity,citric_acid,residual_sugar,chlorides,density,pH,sulphates,alcohol,quality), (free_sulfur_dioxide,total_sulfur_dioxide)} | 0.87, 0.85, 884 | 3 2',
 '{[chlorides <= 0.0585] & [density > 0.9948]} => {(fixed_acidity,citric_acid,residual_sugar,density,pH,alcohol), (chlorides), (free_sulfur_dioxide,total_sulfur_dioxide), (quality)} | 0.87, 0.79, 1455 | 2 4',
 '{[chlorides <= 0.0585] & [density > 0.9948] & [alcohol <= 9.8500] & [residual_sugar <= 11.7750] & [residual_sugar > 3.2500]} => {(fixed_acidity,residual_sugar,density,pH,alcohol), (citric_acid)} | 0.82, 0.79, 366 | 5 2',
 '{[chlorides <= 0.0585] & [density > 0.9948] & [free_sulfur_dioxide <= 88.5000] & [

In [10]:
from spn.algorithms.Inference import log_likelihood

# Mean of the log-likelihood values the learned SPN assigns to the held-out test split.
test_log_likelihoods = log_likelihood(net, test)
print("%.2f" % np.mean(test_log_likelihoods))

-3.55
