# Train DNNs with SecML-Torch

In this notebook, we will use the basic training functionalities of SecML-Torch to train a regular PyTorch Deep Neural Network (DNN) classifier.

We will train a classifier for the MNIST dataset.
First, we define the model as a `torch.nn.Module`, as is usual when working with `torch`.


In [None]:
import torch


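class MNISTNet(torch.nn.Module):
    """Fully connected network with two hidden layers for 28x28 MNIST images."""

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(784, 200)  # 28 * 28 = 784 input pixels
        self.fc2 = torch.nn.Linear(200, 200)
        self.fc3 = torch.nn.Linear(200, 10)  # one output per digit class

    def forward(self, x):
        x = x.flatten(1)  # (N, 1, 28, 28) -> (N, 784)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)  # unnormalized class scores (logits)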


net = MNISTNet()
device = "cpu"
net = net.to(device)
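As a quick sanity check of the architecture, the sketch below pushes a batch of random inputs through the network and counts its parameters. It repeats the model definition so that it runs on its own; the `dummy_batch` tensor is a made-up input with the shape of MNIST images, not real data.

```python
import torch


class MNISTNet(torch.nn.Module):
    """Same three-layer fully connected network as defined above."""

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(784, 200)
        self.fc2 = torch.nn.Linear(200, 200)
        self.fc3 = torch.nn.Linear(200, 10)

    def forward(self, x):
        x = x.flatten(1)  # (N, 1, 28, 28) -> (N, 784)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)  # unnormalized class scores (logits)


net = MNISTNet()
dummy_batch = torch.rand(4, 1, 28, 28)  # hypothetical batch of 4 random "images"
with torch.no_grad():
    logits = net(dummy_batch)

# Count weights and biases of the three linear layers.
n_params = sum(p.numel() for p in net.parameters())
print(tuple(logits.shape), n_params)
```

Each input sample yields ten scores, one per digit class.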


We load the MNIST training and test datasets from `torchvision` and wrap each of them in a `DataLoader`, which serves the samples in batches.

In [None]:
%%capture
import torchvision.datasets
from torch.utils.data import DataLoader

dataset_path = "data/datasets/"
training_dataset = torchvision.datasets.MNIST(
    transform=torchvision.transforms.ToTensor(),
    train=True,
    root=dataset_path,
    download=True,
)
training_data_loader = DataLoader(training_dataset, batch_size=64, shuffle=False)
test_dataset = torchvision.datasets.MNIST(
    transform=torchvision.transforms.ToTensor(),
    train=False,
    root=dataset_path,
    download=True,
)
test_data_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
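To see what a data loader yields, the following sketch iterates over a small synthetic dataset instead of MNIST, so that it runs without any download. The names `toy_dataset` and `toy_loader` and the random tensors are assumptions made for illustration; only the tensor shapes mirror what `ToTensor()` produces for MNIST images.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for MNIST: 150 random 1x28x28 images with labels in 0..9.
images = torch.rand(150, 1, 28, 28)
labels = torch.randint(0, 10, (150,))
toy_dataset = TensorDataset(images, labels)
toy_loader = DataLoader(toy_dataset, batch_size=64, shuffle=False)

# Each iteration returns one (inputs, labels) pair; the last batch may be smaller.
batch_shapes = [(tuple(x.shape), tuple(y.shape)) for x, y in toy_loader]
for x_shape, y_shape in batch_shapes:
    print(x_shape, y_shape)
```

With 150 samples and a batch size of 64, the loader produces two full batches and a final batch with the remaining 22 samples.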


Next, we initialize the optimizer used to train the model.

In [None]:
from torch.optim import Adam

optimizer = Adam(net.parameters(), lr=1e-3)

We can now use the SecML-Torch functionalities to train the previously defined model on the MNIST dataset we just loaded.

We will use the class `secmlt.models.pytorch.base_pytorch_trainer.BasePyTorchTrainer` to prepare a training loop.
This class implements a regular training loop: it iterates over the batches of samples and performs an optimization step on each one (with the optimizer of choice), for a given number of epochs (passed as an input parameter).
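For readers who want to see what such a loop looks like, here is a minimal sketch in plain PyTorch. It is not the SecML-Torch implementation: the helper name `train_loop`, the choice of cross-entropy as the loss, and the toy data are all assumptions made for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def train_loop(model, dataloader, optimizer, epochs=1):
    """Hypothetical helper: a plain PyTorch training loop (not SecML-Torch code)."""
    loss_fn = torch.nn.CrossEntropyLoss()  # assumed loss for classification
    model.train()
    epoch_losses = []
    for _ in range(epochs):
        running_loss = 0.0
        for x, y in dataloader:
            optimizer.zero_grad()  # clear gradients from the previous step
            loss = loss_fn(model(x), y)  # forward pass and loss on this batch
            loss.backward()  # backpropagate
            optimizer.step()  # one optimization step
            running_loss += loss.item() * x.shape[0]
        # Average loss per sample over the epoch.
        epoch_losses.append(running_loss / len(dataloader.dataset))
    return epoch_losses


torch.manual_seed(0)
# Toy, linearly separable data so the sketch runs without downloading MNIST.
x = torch.randn(256, 20)
y = (x[:, 0] > 0).long()
loader = DataLoader(TensorDataset(x, y), batch_size=64, shuffle=False)
toy_model = torch.nn.Linear(20, 2)
toy_optimizer = torch.optim.Adam(toy_model.parameters(), lr=5e-2)
losses = train_loop(toy_model, loader, toy_optimizer, epochs=5)
print(losses)
```

The returned list holds the average loss of each epoch, which should decrease on this easy toy problem.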

We wrap the model in a `secmlt.models.pytorch.base_pytorch_nn.BasePytorchClassifier`, which provides the API for using `torch.nn.Module` subclasses within SecML-Torch.
Then, we can train our model by calling `model.train(dataloader=training_data_loader)`.


In [None]:
from secmlt.models.pytorch.base_pytorch_nn import BasePytorchClassifier
from secmlt.models.pytorch.base_pytorch_trainer import BasePyTorchTrainer

# Training MNIST model
trainer = BasePyTorchTrainer(optimizer=optimizer, epochs=1)
model = BasePytorchClassifier(model=net, trainer=trainer)
model.train(dataloader=training_data_loader)


We can check how the model performs on the testing dataset by using the `secmlt.metrics.classification.Accuracy` wrapper.
This provides the accuracy scoring loop that queries the model with all the batches and counts how many predictions are correct.

In [None]:
from secmlt.metrics.classification import Accuracy

# Test MNIST model
accuracy = Accuracy()(model, test_data_loader)
print("test accuracy: ", accuracy)

test accuracy:  tensor(0.9517)
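The scoring loop described above can be sketched in plain PyTorch as follows. This is an illustration under stated assumptions, not the code of the `Accuracy` class: `accuracy_loop` and `FirstFeatureSign` are hypothetical names, and the sketch returns a Python float rather than a tensor.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def accuracy_loop(model, dataloader):
    """Hypothetical helper: fraction of correctly classified samples."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in dataloader:
            predictions = model(x).argmax(dim=1)  # class with the highest score
            correct += (predictions == y).sum().item()
            total += y.shape[0]
    return correct / total


class FirstFeatureSign(torch.nn.Module):
    """Toy "model" with known behavior: predicts 1 when the first feature is positive."""

    def forward(self, x):
        return torch.stack([-x[:, 0], x[:, 0]], dim=1)


x = torch.tensor([[1.0], [2.0], [-1.0], [-3.0]])
y = torch.tensor([1, 1, 0, 1])  # the last label disagrees with the rule
loader = DataLoader(TensorDataset(x, y), batch_size=2)
toy_accuracy = accuracy_loop(FirstFeatureSign(), loader)
print(toy_accuracy)  # 3 of 4 predictions match, so this prints 0.75
```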

Finally, we can save the model weights with the standard `torch` saving functionalities.
To get the underlying `torch.nn.Module`, we access the `model` attribute of the `secmlt.models.pytorch.base_pytorch_nn.BasePytorchClassifier`.

In [None]:
from pathlib import Path

model_path = Path("data/models/mnist")
model_path.mkdir(parents=True, exist_ok=True)
torch.save(model.model.state_dict(), model_path / "mnist_model.pt")
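To use the saved weights later, they can be loaded back into a fresh instance of the same architecture. The sketch below shows this round trip with standard `torch` calls; it writes to a temporary directory and uses an untrained network as a stand-in for the trained one, both of which are assumptions made so the example runs on its own.

```python
import tempfile
from pathlib import Path

import torch


class MNISTNet(torch.nn.Module):
    """Same architecture as the network trained above."""

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(784, 200)
        self.fc2 = torch.nn.Linear(200, 200)
        self.fc3 = torch.nn.Linear(200, 10)

    def forward(self, x):
        x = x.flatten(1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)


trained_net = MNISTNet()  # stands in for the trained network
with tempfile.TemporaryDirectory() as tmp_dir:
    weights_file = Path(tmp_dir) / "mnist_model.pt"
    torch.save(trained_net.state_dict(), weights_file)

    restored_net = MNISTNet()  # fresh instance with the same architecture
    state_dict = torch.load(weights_file, map_location="cpu")
    restored_net.load_state_dict(state_dict)

# Both networks should now produce identical outputs for the same input.
restored_net.eval()
sample = torch.rand(2, 1, 28, 28)
with torch.no_grad():
    same_outputs = torch.equal(trained_net(sample), restored_net(sample))
print(same_outputs)
```

The class definition must match the one used when saving, since `load_state_dict` matches parameters by name.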