In [1]:
import sys
sys.path.insert(0, "../..")

import torch
import torch.nn as nn
from src.data import make_dataset
from pathlib import Path
from loguru import logger

In [2]:
datadir = Path("../../data/raw/")
batch_size = 64
train_dataloader, test_dataloader = make_dataset.get_MNIST(datadir, batch_size=batch_size) 

In [3]:
datadir.resolve()

PosixPath('/Users/raoulgrouls/code/ML22/data/raw')

In [4]:
len(train_dataloader), len(test_dataloader)

(938, 157)

We can obtain an item:

In [5]:
x, y = next(iter(train_dataloader))
x.shape, y.shape

(torch.Size([64, 1, 28, 28]), torch.Size([64]))

The image follows the channels-first convention: (channel, width, height). The label is an integer.

In [6]:
import torch
from torch import nn

# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

# Define model
class CNN(nn.Module):
    def __init__(self, filters, units1, units2, input_size=(32, 1, 28, 28)):
        super().__init__()

        self.convolutions = nn.Sequential(
            nn.Conv2d(1, filters, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(filters, filters, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(filters, filters, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

        activation_map_size = self._conv_test(input_size)
        logger.info(f"Aggregating activationmap with size {activation_map_size}")
        self.agg = nn.AvgPool2d(activation_map_size)

        self.dense = nn.Sequential(
            nn.Flatten(),
            nn.Linear(filters, units1),
            nn.ReLU(),
            nn.Linear(units1, units2),
            nn.ReLU(),
            nn.Linear(units2, 10)
        )

    def _conv_test(self, input_size = (32, 1, 28, 28)):
        x = torch.ones(input_size)
        x = self.convolutions(x)
        return x.shape[-2:]

    def forward(self, x):
        x = self.convolutions(x)
        x = self.agg(x)
        logits = self.dense(x)
        return logits

model = CNN(filters=32, units1=128, units2=64).to(device)

2023-02-10 13:34:59.872 | INFO     | __main__:__init__:26 - Aggregating activationmap with size torch.Size([2, 2])


Using cpu device


In [7]:
from torchsummary import summary
summary(model, input_size=(1, 28, 28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 28, 28]             320
              ReLU-2           [-1, 32, 28, 28]               0
         MaxPool2d-3           [-1, 32, 14, 14]               0
            Conv2d-4           [-1, 32, 12, 12]           9,248
              ReLU-5           [-1, 32, 12, 12]               0
         MaxPool2d-6             [-1, 32, 6, 6]               0
            Conv2d-7             [-1, 32, 4, 4]           9,248
              ReLU-8             [-1, 32, 4, 4]               0
         MaxPool2d-9             [-1, 32, 2, 2]               0
        AvgPool2d-10             [-1, 32, 1, 1]               0
          Flatten-11                   [-1, 32]               0
           Linear-12                  [-1, 128]           4,224
             ReLU-13                  [-1, 128]               0
           Linear-14                   

In [8]:
import torch.optim as optim
from src.models import metrics
optimizer = optim.Adam
loss_fn = torch.nn.CrossEntropyLoss()
accuracy = metrics.Accuracy()

In [9]:
yhat = model(x)
accuracy(y, yhat)

tensor(0.1250)

We now have everything we need to train the model.

In [10]:
import mlflow
mlflow.set_tracking_uri("sqlite:///mlflow.db")
mlflow.set_experiment("mnist_convolutions")

<Experiment: artifact_location='./mlruns/1', creation_time=1675954831805, experiment_id='1', last_update_time=1675954831805, lifecycle_stage='active', name='mnist_convolutions', tags={}>

In [11]:
from hyperopt import fmin, tpe, hp, STATUS_OK, Trials
from hyperopt.pyll import scope

In [12]:
import torch.optim as optim
from src.models import metrics
from src.models import train_model
from datetime import datetime
modeldir = Path("./models")

def objective(params):
    with mlflow.start_run():
        mlflow.set_tag("model", "convnet")
        mlflow.set_tag("dev", "raoul")
        mlflow.log_params(params)
        mlflow.log_param("datadir", f"{datadir.resolve()}")
        mlflow.log_param("batchsize", f"{batch_size}")

        optimizer = optim.Adam
        loss_fn = torch.nn.CrossEntropyLoss()
        accuracy = metrics.Accuracy()
        model = CNN(**params)
        model = train_model.trainloop(
            epochs=3,
            model=model,
            optimizer=optimizer,
            learning_rate=1e-3,
            loss_fn=loss_fn,
            metrics=[accuracy],
            train_dataloader=train_dataloader,
            test_dataloader=test_dataloader,
            log_dir="modellog",
            train_steps=100, #len(train_dataloader),
            eval_steps=100, #len(test_dataloader),
        )

        tag = datetime.now().strftime("%Y%m%d-%H%M")
        modelpath = modeldir / (tag + "model.pt")
        torch.save(model, modelpath)
        mlflow.log_artifact(local_path=modelpath, artifact_path="pytorch_models")
        return {'loss' : test_loss, 'status': STATUS_OK}

In [13]:
search_space = {
    'filters' : scope.int(hp.quniform('filters', 16, 128, 8)),
    'units1' : scope.int(hp.quniform('units1', 16, 128, 8)),
    'units2' : scope.int(hp.quniform('units2', 16, 128, 8)),
}

best_result = fmin(
    fn=objective,
    space=search_space,
    algo=tpe.suggest,
    max_evals=10,
    trials=Trials()
)

  0%|          | 0/10 [00:00<?, ?trial/s, best loss=?]

2023-02-10 13:35:26.522 | INFO     | __main__:__init__:26 - Aggregating activationmap with size torch.Size([2, 2])
2023-02-10 13:35:26.524 | INFO     | src.data.data_tools:dir_add_timestamp:129 - Logging to modellog/20230210-1335
  0%|[38;2;30;71;6m          [0m| 0/3 [00:00<?, ?it/s]
  0%|[38;2;30;71;6m          [0m| 0/100 [00:00<?, ?it/s][A
  2%|[38;2;30;71;6m2         [0m| 2/100 [00:00<00:05, 17.60it/s][A
  5%|[38;2;30;71;6m5         [0m| 5/100 [00:00<00:04, 19.75it/s][A
  8%|[38;2;30;71;6m8         [0m| 8/100 [00:00<00:04, 20.50it/s][A
 11%|[38;2;30;71;6m#1        [0m| 11/100 [00:00<00:04, 20.62it/s][A
 14%|[38;2;30;71;6m#4        [0m| 14/100 [00:00<00:04, 20.12it/s][A
 17%|[38;2;30;71;6m#7        [0m| 17/100 [00:00<00:04, 20.50it/s][A
 20%|[38;2;30;71;6m##        [0m| 20/100 [00:00<00:03, 20.74it/s][A
 23%|[38;2;30;71;6m##3       [0m| 23/100 [00:01<00:03, 20.95it/s][A
 26%|[38;2;30;71;6m##6       [0m| 26/100 [00:01<00:03, 20.90it/s][A
 29%|[38;2;30;

 10%|█         | 1/10 [00:19<02:55, 19.54s/trial, best loss: 0.8182338976860046]

2023-02-10 13:35:46.011 | INFO     | __main__:__init__:26 - Aggregating activationmap with size torch.Size([2, 2])
2023-02-10 13:35:46.012 | INFO     | src.data.data_tools:dir_add_timestamp:129 - Logging to modellog/20230210-1335
  0%|[38;2;30;71;6m          [0m| 0/3 [00:00<?, ?it/s]
  0%|[38;2;30;71;6m          [0m| 0/100 [00:00<?, ?it/s][A
  2%|[38;2;30;71;6m2         [0m| 2/100 [00:00<00:06, 14.35it/s][A
  4%|[38;2;30;71;6m4         [0m| 4/100 [00:00<00:06, 14.99it/s][A
  6%|[38;2;30;71;6m6         [0m| 6/100 [00:00<00:06, 15.31it/s][A
  8%|[38;2;30;71;6m8         [0m| 8/100 [00:00<00:05, 15.62it/s][A
 10%|[38;2;30;71;6m#         [0m| 10/100 [00:00<00:05, 15.67it/s][A
 12%|[38;2;30;71;6m#2        [0m| 12/100 [00:00<00:05, 15.71it/s][A
 14%|[38;2;30;71;6m#4        [0m| 14/100 [00:00<00:05, 15.78it/s][A
 16%|[38;2;30;71;6m#6        [0m| 16/100 [00:01<00:05, 15.72it/s][A
 18%|[38;2;30;71;6m#8        [0m| 18/100 [00:01<00:05, 15.78it/s][A
 20%|[38;2;30;7

 20%|██        | 2/10 [00:44<03:01, 22.68s/trial, best loss: 0.7253523311018943]

2023-02-10 13:36:10.878 | INFO     | __main__:__init__:26 - Aggregating activationmap with size torch.Size([2, 2])
2023-02-10 13:36:10.879 | INFO     | src.data.data_tools:dir_add_timestamp:129 - Logging to modellog/20230210-1336
  0%|[38;2;30;71;6m          [0m| 0/3 [00:00<?, ?it/s]
  0%|[38;2;30;71;6m          [0m| 0/100 [00:00<?, ?it/s][A
  4%|[38;2;30;71;6m4         [0m| 4/100 [00:00<00:02, 37.68it/s][A
  8%|[38;2;30;71;6m8         [0m| 8/100 [00:00<00:02, 37.74it/s][A
 12%|[38;2;30;71;6m#2        [0m| 12/100 [00:00<00:02, 37.62it/s][A
 16%|[38;2;30;71;6m#6        [0m| 16/100 [00:00<00:02, 37.92it/s][A
 20%|[38;2;30;71;6m##        [0m| 20/100 [00:00<00:02, 38.10it/s][A
 24%|[38;2;30;71;6m##4       [0m| 24/100 [00:00<00:02, 37.92it/s][A
 28%|[38;2;30;71;6m##8       [0m| 28/100 [00:00<00:01, 37.73it/s][A
 32%|[38;2;30;71;6m###2      [0m| 32/100 [00:00<00:01, 38.04it/s][A
 36%|[38;2;30;71;6m###6      [0m| 36/100 [00:00<00:01, 37.86it/s][A
 40%|[38;2;30

 30%|███       | 3/10 [00:54<01:59, 17.10s/trial, best loss: 0.7253523311018943]

2023-02-10 13:36:21.339 | INFO     | __main__:__init__:26 - Aggregating activationmap with size torch.Size([2, 2])
2023-02-10 13:36:21.340 | INFO     | src.data.data_tools:dir_add_timestamp:129 - Logging to modellog/20230210-1336
  0%|[38;2;30;71;6m          [0m| 0/3 [00:00<?, ?it/s]
  0%|[38;2;30;71;6m          [0m| 0/100 [00:00<?, ?it/s][A
  3%|[38;2;30;71;6m3         [0m| 3/100 [00:00<00:04, 23.34it/s][A
  6%|[38;2;30;71;6m6         [0m| 6/100 [00:00<00:04, 23.34it/s][A
  9%|[38;2;30;71;6m9         [0m| 9/100 [00:00<00:03, 22.89it/s][A
 12%|[38;2;30;71;6m#2        [0m| 12/100 [00:00<00:03, 22.91it/s][A
 15%|[38;2;30;71;6m#5        [0m| 15/100 [00:00<00:03, 22.76it/s][A
 18%|[38;2;30;71;6m#8        [0m| 18/100 [00:00<00:03, 22.41it/s][A
 21%|[38;2;30;71;6m##1       [0m| 21/100 [00:00<00:03, 22.54it/s][A
 24%|[38;2;30;71;6m##4       [0m| 24/100 [00:01<00:03, 22.14it/s][A
 27%|[38;2;30;71;6m##7       [0m| 27/100 [00:01<00:03, 22.24it/s][A
 30%|[38;2;30;

 40%|████      | 4/10 [01:11<01:41, 16.93s/trial, best loss: 0.7253523311018943]

2023-02-10 13:36:38.014 | INFO     | __main__:__init__:26 - Aggregating activationmap with size torch.Size([2, 2])
2023-02-10 13:36:38.016 | INFO     | src.data.data_tools:dir_add_timestamp:129 - Logging to modellog/20230210-1336
  0%|[38;2;30;71;6m          [0m| 0/3 [00:00<?, ?it/s]
  0%|[38;2;30;71;6m          [0m| 0/100 [00:00<?, ?it/s][A
  3%|[38;2;30;71;6m3         [0m| 3/100 [00:00<00:03, 28.04it/s][A
  7%|[38;2;30;71;6m7         [0m| 7/100 [00:00<00:03, 29.34it/s][A
 10%|[38;2;30;71;6m#         [0m| 10/100 [00:00<00:03, 29.53it/s][A
 13%|[38;2;30;71;6m#3        [0m| 13/100 [00:00<00:02, 29.52it/s][A
 16%|[38;2;30;71;6m#6        [0m| 16/100 [00:00<00:02, 29.51it/s][A
 20%|[38;2;30;71;6m##        [0m| 20/100 [00:00<00:02, 29.88it/s][A
 24%|[38;2;30;71;6m##4       [0m| 24/100 [00:00<00:02, 30.00it/s][A
 28%|[38;2;30;71;6m##8       [0m| 28/100 [00:00<00:02, 30.05it/s][A
 32%|[38;2;30;71;6m###2      [0m| 32/100 [00:01<00:02, 30.14it/s][A
 36%|[38;2;30

 50%|█████     | 5/10 [01:24<01:17, 15.48s/trial, best loss: 0.7253523311018943]

2023-02-10 13:36:50.921 | INFO     | __main__:__init__:26 - Aggregating activationmap with size torch.Size([2, 2])
2023-02-10 13:36:50.922 | INFO     | src.data.data_tools:dir_add_timestamp:129 - Logging to modellog/20230210-1336
  0%|[38;2;30;71;6m          [0m| 0/3 [00:00<?, ?it/s]
  0%|[38;2;30;71;6m          [0m| 0/100 [00:00<?, ?it/s][A
  3%|[38;2;30;71;6m3         [0m| 3/100 [00:00<00:03, 29.86it/s][A
  7%|[38;2;30;71;6m7         [0m| 7/100 [00:00<00:03, 30.41it/s][A
 11%|[38;2;30;71;6m#1        [0m| 11/100 [00:00<00:02, 30.30it/s][A
 15%|[38;2;30;71;6m#5        [0m| 15/100 [00:00<00:02, 30.44it/s][A
 19%|[38;2;30;71;6m#9        [0m| 19/100 [00:00<00:02, 30.39it/s][A
 23%|[38;2;30;71;6m##3       [0m| 23/100 [00:00<00:02, 30.36it/s][A
 27%|[38;2;30;71;6m##7       [0m| 27/100 [00:00<00:02, 30.47it/s][A
 31%|[38;2;30;71;6m###1      [0m| 31/100 [00:01<00:02, 30.51it/s][A
 35%|[38;2;30;71;6m###5      [0m| 35/100 [00:01<00:02, 30.59it/s][A
 39%|[38;2;30

 60%|██████    | 6/10 [01:37<00:58, 14.63s/trial, best loss: 0.7253523311018943]

2023-02-10 13:37:03.915 | INFO     | __main__:__init__:26 - Aggregating activationmap with size torch.Size([2, 2])
2023-02-10 13:37:03.916 | INFO     | src.data.data_tools:dir_add_timestamp:129 - Logging to modellog/20230210-1337
  0%|[38;2;30;71;6m          [0m| 0/3 [00:00<?, ?it/s]
  0%|[38;2;30;71;6m          [0m| 0/100 [00:00<?, ?it/s][A
  3%|[38;2;30;71;6m3         [0m| 3/100 [00:00<00:04, 20.40it/s][A
  6%|[38;2;30;71;6m6         [0m| 6/100 [00:00<00:04, 20.71it/s][A
  9%|[38;2;30;71;6m9         [0m| 9/100 [00:00<00:04, 20.72it/s][A
 12%|[38;2;30;71;6m#2        [0m| 12/100 [00:00<00:04, 20.65it/s][A
 15%|[38;2;30;71;6m#5        [0m| 15/100 [00:00<00:04, 20.73it/s][A
 18%|[38;2;30;71;6m#8        [0m| 18/100 [00:00<00:03, 20.67it/s][A
 21%|[38;2;30;71;6m##1       [0m| 21/100 [00:01<00:03, 20.51it/s][A
 24%|[38;2;30;71;6m##4       [0m| 24/100 [00:01<00:03, 20.50it/s][A
 27%|[38;2;30;71;6m##7       [0m| 27/100 [00:01<00:03, 20.59it/s][A
 30%|[38;2;30;

 70%|███████   | 7/10 [01:56<00:48, 16.11s/trial, best loss: 0.7253523311018943]

2023-02-10 13:37:23.065 | INFO     | __main__:__init__:26 - Aggregating activationmap with size torch.Size([2, 2])
2023-02-10 13:37:23.066 | INFO     | src.data.data_tools:dir_add_timestamp:129 - Logging to modellog/20230210-1337
  0%|[38;2;30;71;6m          [0m| 0/3 [00:00<?, ?it/s]
  0%|[38;2;30;71;6m          [0m| 0/100 [00:00<?, ?it/s][A
  3%|[38;2;30;71;6m3         [0m| 3/100 [00:00<00:03, 26.76it/s][A
  6%|[38;2;30;71;6m6         [0m| 6/100 [00:00<00:03, 26.99it/s][A
  9%|[38;2;30;71;6m9         [0m| 9/100 [00:00<00:03, 26.86it/s][A
 12%|[38;2;30;71;6m#2        [0m| 12/100 [00:00<00:03, 27.09it/s][A
 15%|[38;2;30;71;6m#5        [0m| 15/100 [00:00<00:03, 27.12it/s][A
 18%|[38;2;30;71;6m#8        [0m| 18/100 [00:00<00:03, 27.23it/s][A
 21%|[38;2;30;71;6m##1       [0m| 21/100 [00:00<00:02, 27.26it/s][A
 24%|[38;2;30;71;6m##4       [0m| 24/100 [00:00<00:02, 27.34it/s][A
 27%|[38;2;30;71;6m##7       [0m| 27/100 [00:00<00:02, 27.41it/s][A
 30%|[38;2;30;

 80%|████████  | 8/10 [02:11<00:31, 15.59s/trial, best loss: 0.7253523311018943]

2023-02-10 13:37:37.537 | INFO     | __main__:__init__:26 - Aggregating activationmap with size torch.Size([2, 2])
2023-02-10 13:37:37.538 | INFO     | src.data.data_tools:dir_add_timestamp:129 - Logging to modellog/20230210-1337
  0%|[38;2;30;71;6m          [0m| 0/3 [00:00<?, ?it/s]
  0%|[38;2;30;71;6m          [0m| 0/100 [00:00<?, ?it/s][A
  3%|[38;2;30;71;6m3         [0m| 3/100 [00:00<00:03, 29.58it/s][A
  6%|[38;2;30;71;6m6         [0m| 6/100 [00:00<00:03, 28.80it/s][A
  9%|[38;2;30;71;6m9         [0m| 9/100 [00:00<00:03, 28.12it/s][A
 12%|[38;2;30;71;6m#2        [0m| 12/100 [00:00<00:03, 27.66it/s][A
 15%|[38;2;30;71;6m#5        [0m| 15/100 [00:00<00:03, 28.20it/s][A
 19%|[38;2;30;71;6m#9        [0m| 19/100 [00:00<00:02, 29.29it/s][A
 22%|[38;2;30;71;6m##2       [0m| 22/100 [00:00<00:02, 29.29it/s][A
 25%|[38;2;30;71;6m##5       [0m| 25/100 [00:00<00:02, 29.46it/s][A
 28%|[38;2;30;71;6m##8       [0m| 28/100 [00:00<00:02, 29.52it/s][A
 31%|[38;2;30;

 90%|█████████ | 9/10 [02:24<00:14, 14.79s/trial, best loss: 0.7253523311018943]

2023-02-10 13:37:50.574 | INFO     | __main__:__init__:26 - Aggregating activationmap with size torch.Size([2, 2])
2023-02-10 13:37:50.575 | INFO     | src.data.data_tools:dir_add_timestamp:129 - Logging to modellog/20230210-1337
  0%|[38;2;30;71;6m          [0m| 0/3 [00:00<?, ?it/s]
  0%|[38;2;30;71;6m          [0m| 0/100 [00:00<?, ?it/s][A
  2%|[38;2;30;71;6m2         [0m| 2/100 [00:00<00:07, 13.05it/s][A
  4%|[38;2;30;71;6m4         [0m| 4/100 [00:00<00:07, 13.60it/s][A
  6%|[38;2;30;71;6m6         [0m| 6/100 [00:00<00:06, 13.78it/s][A
  8%|[38;2;30;71;6m8         [0m| 8/100 [00:00<00:06, 13.69it/s][A
 10%|[38;2;30;71;6m#         [0m| 10/100 [00:00<00:06, 13.74it/s][A
 12%|[38;2;30;71;6m#2        [0m| 12/100 [00:00<00:06, 13.78it/s][A
 14%|[38;2;30;71;6m#4        [0m| 14/100 [00:01<00:06, 13.81it/s][A
 16%|[38;2;30;71;6m#6        [0m| 16/100 [00:01<00:06, 13.75it/s][A
 18%|[38;2;30;71;6m#8        [0m| 18/100 [00:01<00:05, 13.82it/s][A
 20%|[38;2;30;7

100%|██████████| 10/10 [02:53<00:00, 17.30s/trial, best loss: 0.7253523311018943]


In [14]:
best_result

{'filters': 88.0, 'units1': 96.0, 'units2': 48.0}