## Example inference on the Iris flower dataset
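In this example, we train the uGMM-NN model on the Iris dataset by minimizing the Negative Log-Likelihood (NLL) of the joint distribution over the flower features and the class label, and then use the trained model to classify flowers.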

## Requirements

In [34]:
import sys
import os
# Put the directory above the notebook's working directory on the import
# path so the project modules imported below (spn, variable_layer,
# gmm_layer) can be found there.
cwd = os.getcwd()
parent_dir = os.path.abspath(os.path.join(cwd, os.pardir))  # os.pardir == '..'
sys.path.append(parent_dir)

from spn import SPN, Layer
from variable_layer import *
from gmm_layer import *

import torch
from torch import nn

from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

## Define code to perform brute-force MPE inference on the class variable
Because the uGMM-NN model lacks an efficient MPE (Most Probable Explanation) inference algorithm, the class label is predicted by brute force. The model is evaluated once for each candidate class label, and the label with the highest score is chosen as the prediction. In effect, this treats the class variable as a batch of hypotheses, with one forward pass per class.

For the Iris dataset, this approach is feasible because the number of class labels (3 classes) is small.


In [35]:
def predictNLLMPE(spn, data, label, device, epoch, class_var=4,
                  class_states=(0., 1., 2.)):
    """Print and return the brute-force MPE classification accuracy.

    The class variable is treated as unknown. ``spn.infer_mpe`` scores each
    candidate value in ``class_states`` for variable ``class_var``, and the
    index of the highest-scoring candidate is taken as the prediction.

    Args:
        spn: Model exposing ``infer_mpe(data, mpe_vars=..., mpe_states=...)``.
        data: Input tensor whose column ``class_var`` holds the class variable.
        label: Ground-truth class indices, one per sample. Predictions are
            *indices into* ``class_states``, so ``label`` must use the same
            indexing (true for Iris, where the states are 0, 1, 2).
        device: Unused; kept so existing call sites keep working.
        epoch: Training step; only used in the log line.
        class_var: Index of the class variable (4 for Iris: 4 features + label).
        class_states: Candidate class values to evaluate (Iris has 3 classes).

    Returns:
        The accuracy as a Python float in [0, 1].
    """
    mpe = spn.infer_mpe(data, mpe_vars=[class_var], mpe_states=list(class_states))
    # NOTE(review): assumes the candidates lie along dim 0 of ``mpe`` (one
    # score per candidate per sample); confirm against SPN.infer_mpe.
    # reshape (rather than squeeze) keeps a 1-sample batch shaped like
    # ``label`` and fails loudly on a shape mismatch instead of broadcasting.
    predictions = mpe.argmax(dim=0).reshape(label.shape)
    accuracy = (predictions == label).float().mean().item()
    print(f'epoch: {epoch}, MPE Accuracy: {accuracy * 100:.2f}%')
    return accuracy

## Define the model architecture

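Layers in a uGMM-NN model are organized much like a standard multilayer perceptron (MLP), with sequential layers transforming the data from input to output. The model begins with a variable layer representing the input variables (here, the four features plus the class label). This is followed by fully connected Gaussian mixture layers (20 and then 8 nodes). The network ends in a single root node, whose output is used as the log-likelihood of each sample during training.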

In [36]:
def iris_nll_fc_ugmm(device, *, n_variables=5, hidden_sizes=(20, 8)):
    """Build the fully connected uGMM-NN used in the Iris example.

    Architecture:
    - a ``VariableLayer`` over ``n_variables`` variables;
    - one ``GMixture`` layer per entry of ``hidden_sizes`` (default: 20 nodes,
      then 8 nodes), each connected to the previously added layer;
    - a single-node ``GMixture`` root layer marked ``TYPE_GPRODUCT_ROOT``.

    Args:
        device: Passed to ``SPN(...)``; the finished model is also moved there.
        n_variables: Number of modelled variables; the default 5 is the four
            Iris features plus the class label.
        hidden_sizes: Node counts (``n_prod_nodes``) of the hidden mixture
            layers, in input-to-output order.

    Returns:
        The assembled ``SPN`` after ``.to(device)``.
    """
    spn = SPN(device)

    # NOTE(review): n_var_nodes keeps its original value of 5. Its relation to
    # n_variables is not visible from here, so confirm before tying them.
    spn.addLayer(VariableLayer(n_variables=n_variables, n_var_nodes=5))

    for width in hidden_sizes:
        spn.addLayer(GMixture(prev_layer=spn.layers[-1], n_prod_nodes=width))

    root_layer = GMixture(prev_layer=spn.layers[-1], n_prod_nodes=1)
    root_layer.type = TYPE_GPRODUCT_ROOT
    spn.addLayer(root_layer)
    return spn.to(device)

## Training the Model

In [37]:
def _iris_joint_tensor(features, label, device):
    """Return ``(data, label)`` tensors, where ``data`` is the features with the label appended as the last column."""
    features = torch.tensor(features, dtype=torch.float32, device=device)
    label = torch.tensor(label, dtype=torch.int, device=device)
    # Cast explicitly rather than relying on torch.cat's dtype promotion.
    data = torch.cat([features, label.unsqueeze(1).float()], dim=1)
    return data, label


def Classify_iris_nll(device="cpu", n_epochs=3000, lr=0.001, log_every=200,
                      random_seed=0, test_size=None):
    """Train the uGMM-NN on Iris by minimizing the joint negative log-likelihood.

    The model is fit, full-batch with Adam, to the joint distribution of the
    standardized features and the class label (the label is the last column).
    Every ``log_every`` steps it prints the summed log-likelihood of the
    training batch, then the brute-force MPE classification accuracy.

    Args:
        device: Torch device string, e.g. ``"cpu"`` or ``"cuda"``.
        n_epochs: Number of full-batch optimization steps.
        lr: Adam learning rate.
        log_every: Logging/evaluation interval, in steps.
        random_seed: Seed for ``torch.manual_seed`` and for the optional split.
        test_size: If ``None`` (default), train on the whole dataset and report
            *training* accuracy, which is optimistic. Otherwise, hold out a
            stratified test set of this size (fraction or count, as accepted by
            ``train_test_split``) and report accuracy on it.
    """
    torch.manual_seed(random_seed)
    features, label = load_iris(return_X_y=True)

    if test_size is None:
        train_x, train_y = features, label
        eval_x = eval_y = None
    else:
        train_x, eval_x, train_y, eval_y = train_test_split(
            features, label, test_size=test_size,
            random_state=random_seed, stratify=label)

    # Fit the scaler on the training data only, so no held-out statistics leak in.
    scaler = StandardScaler()
    train_data, train_label = _iris_joint_tensor(
        scaler.fit_transform(train_x), train_y, device)
    if eval_x is None:
        eval_data, eval_label = train_data, train_label
    else:
        eval_data, eval_label = _iris_joint_tensor(
            scaler.transform(eval_x), eval_y, device)

    spn = iris_nll_fc_ugmm(device)
    optimizer = torch.optim.Adam(spn.parameters(), lr=lr)

    for i in range(n_epochs):
        optimizer.zero_grad()
        output = spn.infer(train_data, training=True)
        loss = -1 * output.mean()
        loss.backward()
        optimizer.step()

        if i % log_every == 0:
            # ``output`` comes from before this step's update; the accuracy
            # below is measured after it.
            print(f"Epoch: {i}, log-likelihood: {output.sum().item():10.4f}")
            predictNLLMPE(spn, eval_data, eval_label, device, i)


In [38]:
# Run training with the default settings. Progress is printed every 200 steps.
Classify_iris_nll()

Epoch: 0, log-likelihood:  -474.2241
epoch: 0%, MPE Accuracy: 37.33333206176758%
Epoch: 200, log-likelihood:  -221.9651
epoch: 200%, MPE Accuracy: 62.0%
Epoch: 400, log-likelihood:  -122.2520
epoch: 400%, MPE Accuracy: 69.33333587646484%
Epoch: 600, log-likelihood:   -68.5943
epoch: 600%, MPE Accuracy: 80.0%
Epoch: 800, log-likelihood:   -23.7078
epoch: 800%, MPE Accuracy: 84.0%
Epoch: 1000, log-likelihood:    23.6015
epoch: 1000%, MPE Accuracy: 92.66666412353516%
Epoch: 1200, log-likelihood:    76.9703
epoch: 1200%, MPE Accuracy: 98.0%
Epoch: 1400, log-likelihood:   137.1010
epoch: 1400%, MPE Accuracy: 98.0%
Epoch: 1600, log-likelihood:   201.3132
epoch: 1600%, MPE Accuracy: 97.33333587646484%
Epoch: 1800, log-likelihood:   262.8509
epoch: 1800%, MPE Accuracy: 96.66666412353516%
Epoch: 2000, log-likelihood:   318.1485
epoch: 2000%, MPE Accuracy: 98.66667175292969%
Epoch: 2200, log-likelihood:   367.4844
epoch: 2200%, MPE Accuracy: 99.33333587646484%
Epoch: 2400, log-likelihood:   412.