# Training and Inference in Neural Networks
In this example, let's look at a complete cycle of training and inference in a neural network. We will assume that our data is already preprocessed and is ready for training. We will show a simple but practical example using PyTorch.

We will use the MNIST dataset, which consists of handwritten digits from 0 to 9. We will build a neural network to classify these digits. So, this is a _classification_ problem with 10 classes (digits 0 to 9). 

In this example, we keep the model and steps simple. For a more advanced implementation, see [this official PyTorch example](https://github.com/pytorch/examples/blob/main/mnist/main.py).



## Data
As always in machine learning, we start by exploring our data with [Exploratory Data Analysis (EDA)](https://pooya.io/ai/ai_machine_learning_overview/#exploratory-data-analysis-eda-and-feature-engineering).

The MNIST dataset contains 60,000 training images and 10,000 test images. 
- Each image is 28x28 pixels, and each label is a digit from 0 to 9.
- The images are grayscale (one channel). Each pixel value ranges from 0 to 255 and represents the brightness (intensity) of that pixel.
- The dataset is already split into a training set and a test set.

If we flatten each image, the training set becomes a matrix of 60,000 rows and 1 x 28 x 28 = 784 columns. Each row represents a single image with 784 features (columns).

$$Channel \times Height \times Width = 1 \times 28 \times 28 = 784$$




$$X_{\text{train}} \in \mathbb{R}^{60000 \times 784}$$

$$X_{\text{train}} = \begin{bmatrix}
\vec{\mathbf{x}}^{(1)} \\
\vec{\mathbf{x}}^{(2)} \\
\vdots \\
\vec{\mathbf{x}}^{(60000)}
\end{bmatrix}$$

Where:
- $\vec{\mathbf{x}}^{(i)} \in \mathbb{R}^{784}$ is the $i$-th image in the training set.

$X_{\text{test}}$ similarly is a matrix of 10,000 rows and 784 columns.
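
The flattening from a 28x28 image to a row of 784 features can be sketched as follows. This is only an illustration on a small random tensor standing in for real images (the names `images` and `X_flat` are hypothetical), not how the dataset is loaded in this example.

```python
import torch

# Hypothetical stand-in for a few MNIST images: 5 images of 28x28 pixels
images = torch.randint(0, 256, (5, 28, 28), dtype=torch.uint8)

# Flatten every image into one row of 28 * 28 = 784 features
X_flat = images.reshape(images.shape[0], -1)

print(X_flat.shape)  # torch.Size([5, 784])
```

Each row of `X_flat` holds the pixels of one image, read row by row.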



The labels are a vector of the same size (60,000) as the number of training images. Each label is an integer from 0 to 9, representing the digit in the corresponding image.

$$y_{\text{train}} \in \{0, 1, \dots, 9\}^{60000}$$

Similarly, $y_{\text{test}}$ is a vector of size 10,000.

$$y_{\text{test}} \in \{0, 1, \dots, 9\}^{10000}$$

In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor

Let's download the MNIST dataset using `torchvision` offered by PyTorch.

In [2]:
# Define the transformation to be applied to the images
transform = transforms.Compose([transforms.ToTensor()])

# Download the MNIST training and test datasets
train_data = datasets.MNIST(
    root="data", train=True, download=True, transform=transform
)
test_data = datasets.MNIST(
    root="data", train=False, download=True, transform=transform
)

In [3]:
print(
    f"X_train shape: {train_data.data.shape}, dtype: {train_data.data.dtype}"
)
print(f"y_train shape: {train_data.targets.shape}")

print(f"X_test shape: {test_data.data.shape}, dtype: {test_data.data.dtype}")
print(f"y_test shape: {test_data.targets.shape}")

X_train shape: torch.Size([60000, 28, 28]), dtype: torch.uint8
y_train shape: torch.Size([60000])
X_test shape: torch.Size([10000, 28, 28]), dtype: torch.uint8
y_test shape: torch.Size([10000])


In [4]:
print(f"First image shape: {train_data.data[0].shape}")
print(f"First image label: {train_data.targets[0]}")

First image shape: torch.Size([28, 28])
First image label: 5
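
Note that the raw pixels are stored as `uint8`, while the `ToTensor` transform we passed when loading the dataset hands the model floating-point values scaled to the range 0 to 1, with a leading channel dimension. A rough manual equivalent, assuming a random image as a stand-in (`raw_image` and `scaled` are hypothetical names), might look like this:

```python
import torch

# Hypothetical stand-in for one raw MNIST image (uint8, values 0..255)
raw_image = torch.randint(0, 256, (28, 28), dtype=torch.uint8)

# Add a channel dimension and scale the values to the range [0.0, 1.0]
scaled = raw_image.unsqueeze(0).to(torch.float32) / 255.0

print(scaled.shape, scaled.dtype)  # torch.Size([1, 28, 28]) torch.float32
```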


**Labels:**

In the MNIST dataset, where we are classifying handwritten digits from 0 to 9, the labels are simply the digits themselves. In other words, class 0 (label 0) corresponds to the digit 0, class 1 (label 1) corresponds to the digit 1, and so on.

| Class (Label) | Value |
|---------------|-------|
| 0             | 0     |
| 1             | 1     |
| 2             | 2     |
| 3             | 3     |
| 4             | 4     |
| 5             | 5     |
| 6             | 6     |
| 7             | 7     |
| 8             | 8     |
| 9             | 9     |

However, in other datasets the classes are not numbers. For example, in [Fashion MNIST](https://github.com/zalandoresearch/fashion-mnist?tab=readme-ov-file#labels), each class is a type of clothing with a string name, and each name is mapped to an integer label as follows:

| Class (Label) | Value |
|---------------|-------|
| 0             | T-shirt/top |
| 1             | Trouser     |
| 2             | Pullover    |
| ...           | ...         |

So it is important to note that, regardless of the actual class values, we always map the classes to integers starting from 0. The logits of the output layer then line up with the class indices: $z_{0}$ corresponds to class 0, $z_{1}$ to class 1, and so on.
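
As a small sketch of this index mapping, the snippet below picks the class with the largest logit and looks up its name. The logit values are made up for illustration, and only the three class names shown in the table above are used.

```python
# Class names for the first three Fashion MNIST classes shown in the table above
class_names = ["T-shirt/top", "Trouser", "Pullover"]

# Hypothetical logits z_0, z_1, z_2 from an output layer with three units
logits = [0.3, 2.1, -1.0]

# The index of the largest logit is the predicted class
predicted_index = max(range(len(logits)), key=lambda i: logits[i])
predicted_name = class_names[predicted_index]

print(predicted_index, predicted_name)  # 1 Trouser
```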

**Batching:**

In PyTorch, we wrap our dataset in a `DataLoader` object, which lets us iterate over the dataset in batches and supports shuffling, sampling, and multi-processing. The `DataLoader` is what feeds the data to the model batch by batch.

We set the batch size to 64. This means that in each iteration of training (Gradient Descent) we use 64 images to calculate the cost and the gradients, and then update the parameters of the model.

In [5]:
batch_size = 64

train_dataloader = DataLoader(train_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)
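
The loaders above keep the default order of the examples. Shuffling, which `DataLoader` also supports, is often enabled for training data. The sketch below shows it on a tiny toy dataset rather than MNIST; `toy_data` and `toy_loader` are hypothetical names used only for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy dataset: 10 examples with 3 features each, labels 0..9
features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
labels = torch.arange(10)
toy_data = TensorDataset(features, labels)

# shuffle=True draws the examples in a new random order on every pass
toy_loader = DataLoader(toy_data, batch_size=4, shuffle=True)

for X, y in toy_loader:
    print(X.shape, y)
```

With 10 examples and a batch size of 4, the loader yields three batches, the last of which holds the remaining 2 examples.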

As we saw earlier, our training data is a 3D tensor of shape (60000, 28, 28). In other words, we have 60,000 examples, each of which is a 28x28 pixel image.

However, once we wrap it in a `DataLoader` object, the dataset is served in chunks of the batch size, and each image gains a channel dimension from the `ToTensor` transform.

In [6]:
for X, y in train_dataloader:
    print(f"Shape of X [Batch size, Channel, Height, Width]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

Shape of X [Batch size, Channel, Height, Width]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64
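
With 60,000 training images and a batch size of 64, the number of batches in one pass over the data follows from simple arithmetic. The sketch below assumes the loader keeps the final partial batch, which is the `DataLoader` default (`drop_last=False`).

```python
import math

num_examples = 60_000  # size of the MNIST training set
batch_size = 64

# Number of batches (iterations) needed to see every example once
num_batches = math.ceil(num_examples / batch_size)

# The last batch holds whatever is left over
last_batch_size = num_examples - (num_batches - 1) * batch_size

print(num_batches, last_batch_size)  # 938 32
```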


## Creating the Model

In this example we create a very simple model using 3 fully connected layers (also called linear layers, or Dense layers). 

**Layer 1 (Dense):**
- $28 \times 28 \times 1 = 784$ inputs, $512$ outputs
- Activation function: ReLU
- In this layer we have $512$ neurons. The shape of matrix $W1$ is (512, 784) and the shape of vector $b1$ is (512,).

**Layer 2 (Dense):**
- $512$ inputs, $512$ outputs.
- Activation function: ReLU
- In this layer we have $512$ neurons. The shape of matrix $W2$ is (512, 512) and the shape of vector $b2$ is (512,).

**Layer 3 (Dense):**
- $512$ inputs, $10$ outputs (one for each class)
- Activation function: None (linear activation)
- In this layer we have $10$ neurons. The shape of matrix $W3$ is (10, 512) and the shape of vector $b3$ is (10,).
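
The shapes listed above also tell us how many trainable parameters the model has. The short calculation below is a sketch that derives the count from the layer sizes alone, without building the model.

```python
# (inputs, outputs) for each dense layer described above
layer_sizes = [(784, 512), (512, 512), (512, 10)]

total = 0
for n_in, n_out in layer_sizes:
    weights = n_out * n_in  # W has shape (n_out, n_in)
    biases = n_out          # b has shape (n_out,)
    print(f"W: ({n_out}, {n_in}), b: ({n_out},), parameters: {weights + biases}")
    total += weights + biases

print(f"Total parameters: {total}")  # Total parameters: 669706
```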

**Placement of Activation Function for the Output Layer:**

As we discussed [here](), the activation function of the output layer is applied separately to its logits, so the model itself returns raw logits. Here we have a multi-class classification problem, so we will use the softmax activation function.
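
To make the role of softmax concrete, here is a minimal plain-Python sketch that turns a list of logits into probabilities. The `softmax` helper and the logit values are illustrative assumptions; in practice we would rely on the PyTorch implementation.

```python
import math


def softmax(logits):
    """Convert a list of logits into probabilities that sum to 1."""
    # Subtract the maximum for numerical stability; the result is unchanged
    largest = max(logits)
    exps = [math.exp(z - largest) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits for a 3-class problem
probs = softmax([2.0, 1.0, 0.1])
print([round(p, 3) for p in probs])  # [0.659, 0.242, 0.099]
```

The largest logit receives the largest probability, so the predicted class is the same whether we take the argmax of the logits or of the probabilities.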



In [7]:
# Define the neural network
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.linear1 = nn.Linear(in_features=28 * 28, out_features=512)
        self.linear2 = nn.Linear(in_features=512, out_features=512)
        self.linear3 = nn.Linear(in_features=512, out_features=10)

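    def forward(self, x):
        # Flatten each image from (1, 28, 28) to a vector of 784 features
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        # The output layer returns raw logits (no activation)
        logits = self.linear3(x)
        return logits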
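
A quick way to sanity-check a model like this is to pass a dummy batch through it and inspect the output shape. The sketch below is self-contained: `SketchNetwork` is a hypothetical stand-in with the same three dense layers described above, and the input is random noise rather than real images.

```python
import torch
from torch import nn
from torch.nn import functional as F


class SketchNetwork(nn.Module):
    """Hypothetical stand-in with the three dense layers described above."""

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(28 * 28, 512)
        self.linear2 = nn.Linear(512, 512)
        self.linear3 = nn.Linear(512, 10)

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)  # (N, 1, 28, 28) -> (N, 784)
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        return self.linear3(x)  # raw logits, shape (N, 10)


sketch_model = SketchNetwork()
dummy_batch = torch.rand(64, 1, 28, 28)  # random values, not real images

with torch.no_grad():
    logits = sketch_model(dummy_batch)
    probs = F.softmax(logits, dim=1)  # softmax applied outside the model

print(logits.shape)  # torch.Size([64, 10])
```

Each of the 64 rows holds 10 logits, one per digit class, and each row of `probs` sums to 1.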