In [None]:
#Optimizing the NN structure with Pytorch and Optuna, mnist fashion example
!pip install optuna torchvision

In [None]:
#import the libraries
import torchvision.transforms as transforms
from torchvision.datasets import MNIST, FashionMNIST
from torch.utils.data import DataLoader, random_split

In [None]:
import torch as torch
import torch.nn as nn
import torch.optim as optim

In [None]:
import optuna

In [None]:
# Define the PyTorch model
class Net(nn.Module):
    """Single-hidden-layer MLP classifier for 28x28 images.

    Every sample is flattened to a 784-element vector, passed through a ReLU
    hidden layer of width ``n_units`` and projected to 10 output scores
    (consumed as logits by ``nn.CrossEntropyLoss`` during training).
    """

    def __init__(self, n_units):
        super().__init__()
        # Attribute names fc1/fc2 determine the state_dict keys ("fc1.*", "fc2.*").
        self.fc1 = nn.Linear(28 * 28, n_units)
        self.fc2 = nn.Linear(n_units, 10)

    def forward(self, x):
        batch_size = x.size(0)
        flat = x.view(batch_size, -1)
        hidden = torch.relu(self.fc1(flat))
        scores = self.fc2(hidden)
        return scores


In [None]:
def objective(trial):
    """Optuna objective: train a `Net` on Fashion-MNIST and return its validation accuracy.

    Hyperparameters sampled from ``trial``:
      * ``n_units``   -- hidden-layer width, integer in [32, 512];
      * ``optimizer`` -- ``'Adam'`` or ``'SGD'``, looked up by name in ``torch.optim``;
      * ``lr``        -- learning rate, log-uniform in [1e-4, 1e-2].

    The model is trained for 10 epochs on a fixed split of the Fashion-MNIST
    training set and scored on the held-out remainder. The official test set is
    deliberately not used here, so it remains an unbiased estimate for the final
    model trained later in the notebook.

    Returns:
        Fraction of correctly classified validation images (to be maximized).
    """
    # 1. Load Fashion-MNIST -- the same dataset the final model is trained on.
    transform = transforms.Compose([transforms.ToTensor()])
    full_train_dataset = FashionMNIST('.', train=True, download=True, transform=transform)

    # Hold out one sixth of the training images for validation. A fixed generator
    # seed gives every trial the identical split, so trials are comparable.
    n_val = len(full_train_dataset) // 6
    train_subset, val_subset = random_split(
        full_train_dataset,
        [len(full_train_dataset) - n_val, n_val],
        generator=torch.Generator().manual_seed(0),
    )
    train_loader = DataLoader(train_subset, batch_size=128, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=128, shuffle=False)

    # 2. Define model, optimizer, and criterion
    n_units = trial.suggest_int('n_units', 32, 512)
    model = Net(n_units)
    optimizer_name = trial.suggest_categorical('optimizer', ['Adam', 'SGD'])
    lr = trial.suggest_float('lr', 1e-4, 1e-2, log=True)
    optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # 3. Train model
    for _ in range(10):
        model.train()
        for data, target in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()

    # 4. Evaluate on the held-out validation split (not the test set)
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in val_loader:
            pred = model(data).argmax(dim=1)
            correct += (pred == target).sum().item()

    return correct / len(val_subset)


In [None]:
# Run Optuna on the model.
# Seed Optuna's sampler and PyTorch's global RNG (used for weight initialization and,
# by default, DataLoader shuffling) so that re-running the notebook repeats the search
# (up to any nondeterministic kernels).
SEED = 42
torch.manual_seed(SEED)
study = optuna.create_study(
    direction='maximize',
    sampler=optuna.samplers.TPESampler(seed=SEED),
)
study.optimize(objective, n_trials=10)

print('Number of finished trials: ', len(study.trials))
print('Best trial:')
trial = study.best_trial
print('Value: ', trial.value)
print('Params: ')
for key, value in trial.params.items():
    print(f'    {key}: {value}')


# Now that Optuna has found good hyperparameters, we apply them to the final model.
# First, define the model:

In [None]:
class Net(nn.Module):
    """Single-hidden-layer MLP used for the final Fashion-MNIST model.

    Same architecture as the `Net` used during the Optuna search: each sample is
    flattened to 784 values, passed through a ReLU hidden layer of width
    ``n_units`` and projected to 10 output scores.
    """

    def __init__(self, n_units):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28*28, n_units)
        self.fc2 = nn.Linear(n_units, 10)

    def forward(self, x):
        # Flatten per sample while keeping the batch dimension (as the search-time Net
        # does). `x.view(-1, 28*28)` would silently re-batch inputs whose per-sample
        # size is not 784 (e.g. 2x28x28 -> two rows per sample) instead of failing.
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x
    
# Use the hidden-layer width selected by the Optuna study run above, rather than a
# hardcoded value that can drift from the study's actual result.
n_units_optimal = study.best_params['n_units']
model = Net(n_units_optimal)

In [None]:
# Load the Fashion-MNIST dataset (ToTensor preprocessing, as in the Optuna search).
from torchvision.datasets import FashionMNIST

transform = transforms.Compose([transforms.ToTensor()])
BATCH_SIZE = 128

train_dataset = FashionMNIST(root='.', train=True, download=True, transform=transform)
test_dataset = FashionMNIST(root='.', train=False, transform=transform)

train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Define the optimizer and loss function from the best trial's hyperparameters.
# The optimizer class is looked up by name exactly as in `objective` ('Adam' or 'SGD'),
# so an SGD winner is not silently replaced by Adam.
best_params = study.best_params
lr_optimal = best_params['lr']
optimizer = getattr(optim, best_params['optimizer'])(model.parameters(), lr=lr_optimal)
criterion = nn.CrossEntropyLoss()

In [None]:
# Train the model, reporting test-set accuracy after every epoch.
num_epochs = 10


def _train_one_epoch(net, loader, opt, loss_fn):
    """Run one optimization pass of `net` over every batch yielded by `loader`."""
    net.train()
    for inputs, labels in loader:
        opt.zero_grad()
        batch_loss = loss_fn(net(inputs), labels)
        batch_loss.backward()
        opt.step()


def _accuracy(net, loader):
    """Return the fraction of samples in `loader` that `net` classifies correctly."""
    net.eval()
    n_correct = 0
    with torch.no_grad():
        for inputs, labels in loader:
            predictions = net(inputs).argmax(dim=1, keepdim=True)
            n_correct += predictions.eq(labels.view_as(predictions)).sum().item()
    return n_correct / len(loader.dataset)


for epoch in range(num_epochs):
    _train_one_epoch(model, train_loader, optimizer, criterion)
    accuracy = _accuracy(model, test_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Accuracy: {accuracy*100:.2f}%")

# Exporting the model
# First, convert the trained model to TorchScript:

In [None]:
# Put the model in inference mode before export. The scripted module inherits the
# current train/eval flag, which would still be "train" if the training cell above
# was interrupted mid-epoch.
model.eval()
scripted_model = torch.jit.script(model)

In [None]:
# Then, serialize the TorchScript module to disk.
scripted_model.save("PTmodel.pt")

In [None]:
# The model can then be loaded and run.
loaded_model = torch.jit.load("PTmodel.pt")
loaded_model.eval()

# Example input: one batch from the Fashion-MNIST test loader. Replace `input_data`
# with your own inference data, preprocessed the same way (transforms.ToTensor()).
input_data, _ = next(iter(test_loader))
with torch.no_grad():
    output = loaded_model(input_data)
predicted_classes = output.argmax(dim=1)
predicted_classes[:10]