# Train DNNs with SecML-Torch

In this notebook, we will use the basic training functionalities of SecML-Torch to train a regular PyTorch Deep Neural Network (DNN) classifier.

In [None]:
import torch


class MNISTNet(torch.nn.Module):
    """Fully connected network: 784 -> 200 -> 200 -> 10, with ReLU after the first two layers."""

    def __init__(self):
        """Create the three linear layers ``fc1``, ``fc2`` and ``fc3``."""
        super().__init__()
        # Attribute names double as the keys of the module's state_dict.
        self.fc1 = torch.nn.Linear(in_features=784, out_features=200)
        self.fc2 = torch.nn.Linear(in_features=200, out_features=200)
        self.fc3 = torch.nn.Linear(in_features=200, out_features=10)

    def forward(self, x):
        """Run a forward pass.

        Every dimension of ``x`` after the first is flattened into one, so
        each sample must contain 784 values in total (e.g. a 1x28x28 image).
        Returns the output of ``fc3`` with no activation applied: one row of
        10 values per sample.
        """
        hidden = torch.flatten(x, start_dim=1)
        for layer in (self.fc1, self.fc2):
            hidden = layer(hidden).relu()
        return self.fc3(hidden)


# Build the network and move it to the target device (CPU here).
device = "cpu"
net = MNISTNet().to(device)


In [1]:
%%capture
import torchvision.datasets
from torch.utils.data import DataLoader

# Root folder handed to torchvision for both MNIST splits (download=True).
dataset_path = "data/datasets/"

# Training split: every sample goes through ToTensor(); batches of 64 are
# served in dataset order because shuffle=False.
training_dataset = torchvision.datasets.MNIST(
    root=dataset_path,
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)
training_data_loader = DataLoader(
    dataset=training_dataset,
    batch_size=64,
    shuffle=False,
)

# Test split: same transform, batch size and ordering as the training split.
test_dataset = torchvision.datasets.MNIST(
    root=dataset_path,
    train=False,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)
test_data_loader = DataLoader(
    dataset=test_dataset,
    batch_size=64,
    shuffle=False,
)


In [None]:
from pathlib import Path

import torch
from secmlt.metrics.classification import Accuracy
from secmlt.models.pytorch.base_pytorch_nn import BasePytorchClassifier
from secmlt.models.pytorch.base_pytorch_trainer import BasePyTorchTrainer
from torch.optim import Adam



# Adam over all parameters of `net`, with learning rate 1e-3.
optimizer = Adam(lr=1e-3, params=net.parameters())

# Training MNIST model: wrap `net` in the SecML-Torch classifier and train it
# with a trainer configured for a single epoch (epochs=1).
trainer = BasePyTorchTrainer(optimizer, epochs=1)
model = BasePytorchClassifier(net, trainer=trainer)
model.train(training_data_loader)

# Test MNIST model: evaluate the Accuracy metric on the test loader.
accuracy = Accuracy()(model, test_data_loader)
print("test accuracy: ", accuracy)

# Save the weights. mkdir(parents=True, exist_ok=True) is already a no-op when
# the directory exists, so no separate exists() check is needed (and a path
# that exists as a regular file now fails loudly here instead of at save time).
model_path = Path("data/models/mnist")
model_path.mkdir(parents=True, exist_ok=True)
# NOTE(review): `model.model` is presumably the wrapped torch module (`net`) — confirm.
torch.save(model.model.state_dict(), model_path / "mnist_model.pt")


test accuracy:  tensor(0.9517)
