## Example inference on the Iris flower dataset
In this example, we demonstrate how to learn using a Negative Log Likelihood (NLL) loss with the uGMM-NN model on the Iris dataset.

## Requirements

In [2]:
import sys
import os
cwd = os.getcwd()
parent_dir = os.path.abspath(os.path.join(cwd, '..'))
sys.path.append(parent_dir)

from test_architectures import *

import torch
from torch import nn
import torch.nn.functional as F

from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

### Posterior Inference on the Iris Dataset

For generative training on the Iris dataset, the uGMM-NN models the joint distribution over features and labels.  
To make predictions, we compute the posterior probability of each class given the input features:

$$
P(y \mid \mathbf{x}) \propto P(y, \mathbf{x}).
$$

Since the number of classes is small (3 labels), posterior inference can be performed by evaluating the joint likelihood once per class and selecting the label with the highest probability.  

This corresponds to Maximum A Posteriori (MAP) prediction:

$$
\hat{y} = \arg \max_{c} P(y = c, \mathbf{x}).
$$

Thus, inference is carried out by running a forward pass for each class hypothesis and choosing the most probable outcome.  


In [9]:
def predictNLLMPE(model, data, label, device, epoch):
    mpe = model.infer_mpe(data, mpe_vars=[4], mpe_states=[0.,1.,2.])
    predictions = mpe.argmax(dim=0).squeeze()
    accuracy = (predictions == label).sum() / len(label)
    print(f'epoch: {epoch}%, MPE Accuracy: {accuracy * 100}%')

## Define the model architecture:

Layers in a uGMM-NN model are organized similarly to a standard multilayer perceptron (MLP) architecture, with sequential layers transforming the data from input to output. The model begins with a variable layer representing the input features. This is followed by multiple fully connected Gaussian Mixture layers, 

In [10]:
def iris_nll_fc_ugmm(device):
    n_variables = 5
    model = uGMMNet(device)

    input_layer = InputLayer(n_variables=n_variables, n_var_nodes=5) 
    model.addLayer(input_layer)

    g1 = uGMMLayer(prev_layer=model.layers[-1], n_ugmm_nodes=20)
    model.addLayer(g1)
    
    g2 = uGMMLayer(prev_layer=model.layers[-1], n_ugmm_nodes=8)
    model.addLayer(g2)

    root = uGMMLayer(prev_layer=model.layers[-1], n_ugmm_nodes=1)
    model.addLayer(root)
    return model.to(device)

## Training the Model

In [11]:
def Classify_iris_nll():
      device = "cuda"
      random_seed = 0
      torch.manual_seed(random_seed)
      features, label = load_iris(return_X_y=True)
      scaler = StandardScaler()
      features = scaler.fit_transform(features)
      features = torch.tensor(features, dtype=torch.float32).to(device)
      label = torch.tensor(label, dtype=torch.int).to(device)
      data = torch.cat([features, label.unsqueeze(1).int()], dim=1).to(device)

      model = iris_nll_gmm(device)
      optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

      test = True
      for i in range(3000):
            optimizer.zero_grad()
            output = model.infer(data, training=True)
            loss = -1 * output.mean()
            loss.backward()
            optimizer.step()

            if i % 200 == 0:
                  print("log-likelihood: {output.sum().item():10.4f}")

                  if test:
                        predictNLLMPE(model, data, label, device, i)

In [12]:
Classify_iris_nll()

log-likelihood: {output.sum().item():10.4f}
epoch: 0%, MPE Accuracy: 37.33333206176758%
log-likelihood: {output.sum().item():10.4f}
epoch: 200%, MPE Accuracy: 62.0%
log-likelihood: {output.sum().item():10.4f}
epoch: 400%, MPE Accuracy: 69.33333587646484%
log-likelihood: {output.sum().item():10.4f}
epoch: 600%, MPE Accuracy: 80.0%
log-likelihood: {output.sum().item():10.4f}
epoch: 800%, MPE Accuracy: 84.0%
log-likelihood: {output.sum().item():10.4f}
epoch: 1000%, MPE Accuracy: 92.66666412353516%
log-likelihood: {output.sum().item():10.4f}
epoch: 1200%, MPE Accuracy: 98.0%
log-likelihood: {output.sum().item():10.4f}
epoch: 1400%, MPE Accuracy: 98.0%
log-likelihood: {output.sum().item():10.4f}
epoch: 1600%, MPE Accuracy: 97.33333587646484%
log-likelihood: {output.sum().item():10.4f}
epoch: 1800%, MPE Accuracy: 96.66667175292969%
log-likelihood: {output.sum().item():10.4f}
epoch: 2000%, MPE Accuracy: 98.0%
log-likelihood: {output.sum().item():10.4f}
epoch: 2200%, MPE Accuracy: 100.0%
log-l