## Example: discriminative inference on the MNIST dataset
In this example, we show how to classify MNIST digits with a fully connected, deep univariate Gaussian Mixture Model neural network (uGMM-NN) trained with a **cross-entropy loss**.
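In this discriminative setting, `nn.CrossEntropyLoss` treats the 10 output values of the network as unnormalized logits. It applies a log-softmax and averages the negative log-probability of the true class over the batch. The sketch below uses random scores as a stand-in for the model's output (an assumption: it does not call any uGMM-NN code) and checks that the built-in loss matches that formula.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
# Hypothetical per-class scores for a batch of 4 images and 10 digit classes,
# standing in for the output of the uGMM-NN's 10 root nodes.
scores = torch.randn(4, 10)
labels = torch.tensor([3, 7, 0, 9])

# Built-in loss, as used in the training loop.
loss_builtin = nn.CrossEntropyLoss()(scores, labels)

# Manual equivalent: log-softmax over the classes, pick the log-probability
# of the true class for each image, then average the negatives.
log_probs = F.log_softmax(scores, dim=1)
loss_manual = -log_probs[torch.arange(len(labels)), labels].mean()

print(loss_builtin.item(), loss_manual.item())
```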

## Requirements

In [1]:
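```python
import os
import sys

# Make the parent directory importable so the uGMM-NN code can be found.
parent_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))
sys.path.append(parent_dir)

# Expected to provide uGMMNet, InputLayer and uGMMLayer, which are used below.
from test_architectures import *

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torchvision.datasets import MNIST
```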

## Define code to test accuracy
Inference with the uGMM-NN model is more computationally efficient when it is trained with a discriminative loss, because the network includes a dedicated output node for each class label: a single forward pass produces one score per digit, and the predicted class is the highest-scoring one. However, this discriminative setup reduces flexibility. The model can only perform classification (i.e., estimate P(Digit | Data)) and does not support more general probabilistic queries.
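Assuming the root layer's 10 outputs are used as logits (which is how `nn.CrossEntropyLoss` in the training loop treats them), a softmax turns them into an estimate of P(Digit | Data). Softmax is monotonic, so taking the argmax of the raw scores, as `testAccuracy` does, picks the same digit as taking the argmax of the probabilities. The random `scores` tensor below is a hypothetical stand-in for `model.infer(data, training=False)`.

```python
import torch

torch.manual_seed(0)
# Hypothetical class scores for 3 images (one column per digit 0-9),
# standing in for model.infer(data, training=False).
scores = torch.randn(3, 10)

# Normalizing with softmax gives an estimate of P(Digit | Data) per image.
probs = torch.softmax(scores, dim=1)

# Softmax preserves the ordering of the scores, so both argmaxes agree.
pred_from_scores = scores.argmax(dim=1)
pred_from_probs = probs.argmax(dim=1)

print(probs.sum(dim=1))  # each row sums to 1 (up to rounding)
print(pred_from_scores, pred_from_probs)
```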

In [2]:
```python
def testAccuracy(model, test_loader, device):
    """Prints the classification accuracy of `model` on `test_loader`."""
    correct, total = 0, 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for test_batch_data, test_batch_labels in test_loader:
            # Flatten each 1x28x28 image into a 784-dimensional vector.
            batch_size = test_batch_data.shape[0]
            data = test_batch_data.reshape(batch_size, 28*28).to(device)
            # One score per digit class; the prediction is the highest-scoring class.
            output = model.infer(data, training=False)
            predictions = output.argmax(dim=1)
            labels = test_batch_labels.to(device)
            total += labels.size(0)
            correct += (predictions == labels).sum().item()

    accuracy = correct / total
    print(f'Test Accuracy: {accuracy * 100:.2f}%')
```
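Because `testAccuracy` only prints its result, a variant that returns the accuracy is easier to reuse or assert on. The sketch below exercises the same counting logic offline. `StubModel` and `compute_accuracy` are hypothetical names: the stub is a single linear layer (not a uGMM-NN) that merely exposes an `infer(data, training)` method like the one called above, and the data are random tensors shaped like MNIST batches.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class StubModel(nn.Module):
    """Hypothetical stand-in exposing the same infer(data, training) call as uGMMNet."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(28 * 28, 10)

    def infer(self, data, training=False):
        # Returns one score per digit class, shape (batch, 10).
        return self.linear(data)


def compute_accuracy(model, loader, device):
    """Same counting logic as testAccuracy, but returns the accuracy."""
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            data = images.reshape(images.shape[0], 28 * 28).to(device)
            predictions = model.infer(data, training=False).argmax(dim=1)
            labels = labels.to(device)
            total += labels.size(0)
            correct += (predictions == labels).sum().item()
    return correct / total


torch.manual_seed(0)
# Random fake "images" shaped like MNIST tensors (N, 1, 28, 28).
images = torch.rand(50, 1, 28, 28)
labels = torch.randint(0, 10, (50,))
loader = DataLoader(TensorDataset(images, labels), batch_size=16)

acc = compute_accuracy(StubModel(), loader, "cpu")
print(f"Accuracy on random data: {acc * 100:.2f}%")
```

With an untrained stub and random labels the accuracy is meaningless; the point is only to check that the loop, reshaping and argmax run end to end.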


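## Define the model architecture

Layers in a uGMM-NN model are defined much like those of a standard multilayer perceptron (MLP): an input layer with one node per pixel (28×28 = 784), two hidden uGMM layers with 128 and 64 nodes, and a root layer with 10 nodes, one per digit class. In this architecture, dropout with p=0.3 is applied in the first hidden layer.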

In [3]:
```python
def mnist_fc_ugmm(device):
    """Builds a fully connected uGMM-NN for flattened 28x28 MNIST images."""
    n_variables = 28*28
    model = uGMMNet(device)

    # Input layer: one node per pixel.
    input_layer = InputLayer(n_variables=n_variables, n_var_nodes=n_variables)
    model.addLayer(input_layer)

    # Hidden layers: 128 uGMM nodes (dropout 0.3), then 64 uGMM nodes.
    g1 = uGMMLayer(prev_layer=model.layers[-1], n_ugmm_nodes=128, dropout=0.3)
    model.addLayer(g1)

    g2 = uGMMLayer(prev_layer=model.layers[-1], n_ugmm_nodes=64, dropout=0.0)
    model.addLayer(g2)

    # Root (output) layer: one node per digit class.
    root = uGMMLayer(prev_layer=model.layers[-1], n_ugmm_nodes=10, dropout=0.0)
    model.addLayer(root)
    return model.to(device)
```
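To make the MLP analogy concrete, here is a conventional PyTorch MLP with the same layer widths (784 → 128 → 64 → 10) and dropout of 0.3 after the first hidden layer. This is only a point of comparison, not a uGMM-NN: the ReLU activations and the placement of `nn.Dropout` are assumptions of this sketch, and where `uGMMLayer` actually applies its dropout is not shown in this notebook.

```python
import torch
from torch import nn

# Hypothetical plain-MLP analogue of mnist_fc_ugmm (NOT a uGMM-NN).
mlp = nn.Sequential(
    nn.Linear(28 * 28, 128),
    nn.ReLU(),
    nn.Dropout(p=0.3),   # dropout on the first hidden layer only
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 10),   # one output unit per digit, like the 10-node root layer
)

x = torch.rand(4, 28 * 28)  # a fake batch of 4 flattened images
mlp.eval()                  # disable dropout for a deterministic forward pass
out = mlp(x)
print(out.shape)  # torch.Size([4, 10])

# Weights plus biases of the three linear layers.
n_params = sum(p.numel() for p in mlp.parameters())
print(n_params)  # 109386
```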


## Training the model

In [4]:
```python
def mnistCrossEntropy():
    # ToTensor converts each image to a float tensor of shape (1, 28, 28) in [0, 1].
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    torch.manual_seed(0)
    batch_size = 256

    train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Fall back to the CPU when no CUDA device is available.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = mnist_fc_ugmm(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()  # treats the 10 root-node outputs as logits

    num_epochs = 65
    for epoch in range(num_epochs):
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            # Flatten each image into a 784-dimensional vector.
            batch_size = inputs.shape[0]
            data = inputs.reshape(batch_size, 28*28).to(device)
            output = model.infer(data, training=True)
            loss = criterion(output, labels.to(device))

            loss.backward()
            optimizer.step()

        # Note: this is the loss of the epoch's last mini-batch, not an epoch average.
        print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}')
        testAccuracy(model, test_loader, device)
```
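The `Loss` printed by `mnistCrossEntropy` is the loss of each epoch's final mini-batch only, which is why it can rise from one epoch to the next in the log below (e.g. about 0.34 at epoch 6, then 0.38 at epoch 7) even while test accuracy improves. A steadier signal is the sample-weighted mean loss over the whole epoch. The sketch below shows one way to compute it; `train_one_epoch` and `StubModel` are hypothetical names, and the stub (a single linear layer with an `infer(data, training)` method) plus random MNIST-shaped tensors let it run without the dataset or the `test_architectures` module.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class StubModel(nn.Module):
    """Hypothetical stand-in with the same infer(data, training) call as uGMMNet."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(28 * 28, 10)

    def infer(self, data, training=True):
        return self.linear(data)


def train_one_epoch(model, loader, optimizer, criterion, device):
    """Runs one epoch and returns the sample-weighted mean training loss."""
    total_loss, total_samples = 0.0, 0
    for inputs, labels in loader:
        optimizer.zero_grad()
        data = inputs.reshape(inputs.shape[0], 28 * 28).to(device)
        labels = labels.to(device)
        loss = criterion(model.infer(data, training=True), labels)
        loss.backward()
        optimizer.step()
        # criterion returns the batch mean, so weight it by the batch size.
        total_loss += loss.item() * labels.size(0)
        total_samples += labels.size(0)
    return total_loss / total_samples


torch.manual_seed(0)
images = torch.rand(64, 1, 28, 28)
labels = torch.randint(0, 10, (64,))
loader = DataLoader(TensorDataset(images, labels), batch_size=16, shuffle=True)

model = StubModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
for epoch in range(3):
    mean_loss = train_one_epoch(model, loader, optimizer, criterion, "cpu")
    print(f"Epoch {epoch + 1}/3, mean loss: {mean_loss:.4f}")
```

Weighting by `labels.size(0)` keeps the average exact when the last batch is smaller than the others, as happens with MNIST's 60,000 training images and a batch size of 256.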

In [9]:
```python
mnistCrossEntropy()
```

Epoch 1/65, Loss: 0.7547492980957031
Test Accuracy: 85.24%
Epoch 2/65, Loss: 0.5385421514511108
Test Accuracy: 90.78%
Epoch 3/65, Loss: 0.4680165946483612
Test Accuracy: 91.54%
Epoch 4/65, Loss: 0.38211575150489807
Test Accuracy: 91.69%
Epoch 5/65, Loss: 0.349941611289978
Test Accuracy: 92.59%
Epoch 6/65, Loss: 0.3358570337295532
Test Accuracy: 93.49%
Epoch 7/65, Loss: 0.3844209909439087
Test Accuracy: 93.58%
Epoch 8/65, Loss: 0.30797672271728516
Test Accuracy: 94.23%
Epoch 9/65, Loss: 0.37255048751831055
Test Accuracy: 94.30%
Epoch 10/65, Loss: 0.27254417538642883
Test Accuracy: 94.84%
Epoch 11/65, Loss: 0.21158216893672943
Test Accuracy: 94.61%
Epoch 12/65, Loss: 0.17260713875293732
Test Accuracy: 94.87%
Epoch 13/65, Loss: 0.17667162418365479
Test Accuracy: 94.47%
Epoch 14/65, Loss: 0.08908198028802872
Test Accuracy: 95.72%
Epoch 15/65, Loss: 0.11320378631353378
Test Accuracy: 94.97%
Epoch 16/65, Loss: 0.11895287036895752
Test Accuracy: 95.48%
Epoch 17/65, Loss: 0.08914307504892349
T