## Example (discriminative) inference on the MNIST dataset
In this example, we demonstrate how to perform classification on the MNIST dataset using a fully connected deep univariate Gaussian Mixture Model (uGMM-NN) network trained with a **cross-entropy loss**.

## Requirements

In [1]:
import sys
import os
cwd = os.getcwd()
parent_dir = os.path.abspath(os.path.join(cwd, '..'))
sys.path.append(parent_dir)

from test_architectures import *

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torchvision.datasets import MNIST

## Define code to test accuracy
After training, we can evaluate the uGMM-NN on the test set by comparing the predicted class with the true label.

In [2]:
def testAccuracy(model, test_loader, device):
        correct, total = 0, 0
        with torch.no_grad():
            for batch_idx, (test_batch_data, test_batch_labels) in enumerate(test_loader):
                batch_size = test_batch_data.shape[0]
                data = test_batch_data.reshape(batch_size, 28*28)
                data = data.to(device)
                output = model.infer(data, training = False)
                predictions = output.argmax(dim=1)
                labels = test_batch_labels.to(device)
                total += labels.size(0)
                correct += (predictions == labels).sum().item()
    
        accuracy = correct / total
        print(f'Test Accuracy: {accuracy * 100:.2f}%')


## Define the model architecture:

Layers in a uGMM-NN model are defined similarly to the standard MLP architecture. In this example architecture, we introduce a dropout of **p=0.2** in the first hidden layer.

In [4]:
def mnist_fc_ugmm(device):
    n_variables = 28*28
    model = uGMMNet(device)

    input_layer = InputLayer(n_variables=n_variables, n_var_nodes=n_variables) 
    model.addLayer(input_layer)   

    g1 = uGMMLayer(prev_layer=model.layers[-1], n_ugmm_nodes=128, dropout=0.2)
    model.addLayer(g1)

    g2 = uGMMLayer(prev_layer=model.layers[-1], n_ugmm_nodes=64, dropout=0.0)
    model.addLayer(g2)
 
    root = uGMMLayer(prev_layer=model.layers[-1], n_ugmm_nodes=10, dropout=0.0)
    model.addLayer(root)
    return model.to(device)


## Training the Model

In [7]:
def mnistCrossEntropy():    
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])    
    torch.manual_seed(0)
    batch_size = 256

    train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    device = "cuda"

    model = mnist_fc_ugmm(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, 
        milestones=[20, 45],
        gamma=0.1
    )
    criterion = nn.CrossEntropyLoss()

    num_epochs = 50
    for epoch in range(num_epochs):
        for batch_index, (inputs, labels) in enumerate(train_loader):
            optimizer.zero_grad()  
            batch_size = inputs.shape[0]
            data = inputs.reshape(batch_size, 28*28)
            data = data.to(device)
            output = model.infer(data, training=True)
            loss = criterion(output, labels.to(device))         
            
            loss.backward()  
            optimizer.step()

        print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}')
        testAccuracy(model, test_loader, device)
        scheduler.step()

In [8]:
mnistCrossEntropy()

Epoch 1/50, Loss: 0.5083840489387512
Test Accuracy: 90.10%
Epoch 2/50, Loss: 0.43400970101356506
Test Accuracy: 91.90%
Epoch 3/50, Loss: 0.4375993013381958
Test Accuracy: 92.64%
Epoch 4/50, Loss: 0.24422836303710938
Test Accuracy: 92.73%
Epoch 5/50, Loss: 0.34684380888938904
Test Accuracy: 94.46%
Epoch 6/50, Loss: 0.19134074449539185
Test Accuracy: 95.07%
Epoch 7/50, Loss: 0.1892576664686203
Test Accuracy: 95.29%
Epoch 8/50, Loss: 0.32417571544647217
Test Accuracy: 95.28%
Epoch 9/50, Loss: 0.20768284797668457
Test Accuracy: 95.49%
Epoch 10/50, Loss: 0.18329529464244843
Test Accuracy: 95.46%
Epoch 11/50, Loss: 0.17306679487228394
Test Accuracy: 95.73%
Epoch 12/50, Loss: 0.08532044291496277
Test Accuracy: 95.73%
Epoch 13/50, Loss: 0.06585118919610977
Test Accuracy: 95.98%
Epoch 14/50, Loss: 0.12104979902505875
Test Accuracy: 96.35%
Epoch 15/50, Loss: 0.1453668177127838
Test Accuracy: 96.11%
Epoch 16/50, Loss: 0.13293862342834473
Test Accuracy: 95.99%
Epoch 17/50, Loss: 0.1057614311575889