<a href="https://colab.research.google.com/github/tommasomncttn/NAS4CNN/blob/main/Implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Importing 

In [54]:
import os
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn import datasets
from sklearn.datasets import load_digits
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm

### Dataset 

In [12]:
class TensorizedDigits(Dataset):
    """Scikit-Learn Digits dataset split into train / validation / test parts.

    The samples returned by ``load_digits()`` are sliced by position:
    the first 1000 are the training split, samples 1000-1349 the validation
    split and everything from 1350 onwards the test split. Features are stored
    as ``float32`` NumPy arrays of 64 values per sample.

    Args:
        mode: ``"train"`` or ``"val"``; any other value selects the test split.
        transforms: optional callable ``(x, y) -> (x, y)`` applied to every
            sample. It is ignored when ``tensorized`` is true.
        tensorized: when true (the default) ``tensorization_transform`` is
            used as the transform, replacing whatever ``transforms`` was given.
    """

    def __init__(self, mode = "train", transforms = None, tensorized = True):
        digits = load_digits()
        if mode == "train":
            split = slice(None, 1000)
        elif mode == "val":
            split = slice(1000, 1350)
        else:
            split = slice(1350, None)

        self.data = digits.data[split].astype(np.float32)
        self.targets = digits.target[split]

        self.transforms = transforms

        # Tensorization takes precedence over a user-supplied transform.
        if tensorized:
            self.transforms = TensorizedDigits.tensorization_transform

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(x, y)`` pair at ``idx``, transformed if a transform is set."""
        sample_x = self.data[idx]
        sample_y = self.targets[idx]

        # Only call the transform when there is one: with tensorized=False and
        # transforms=None the raw NumPy sample is returned unchanged.
        if self.transforms is not None:
            sample_x, sample_y = self.transforms(sample_x, sample_y)

        return (sample_x, sample_y)

    @staticmethod
    def tensorization_transform(x, y):
        """Turn one raw sample into torch tensors shaped for a CNN.

        Args:
            x: NumPy array holding the 64 pixel values of one digit.
            y: the integer label of that digit.

        Returns:
            ``(sample_x, sample_y)`` where ``sample_x`` is a ``(1, 8, 8)``
            tensor and ``sample_y`` a 0-dim tensor. Both live on CUDA when it
            is available and on the CPU otherwise.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # reshape to (channels, height, width) to get a valid input for a CNN
        sample_x = torch.from_numpy(x.reshape(1, 8, 8)).to(device)
        # np.array() wraps the scalar label so that from_numpy accepts it
        sample_y = torch.from_numpy(np.array(y)).to(device)

        return sample_x, sample_y

    def visualize_datapoint(self, idx):
        """Show the digit at ``idx`` as a grayscale image."""
        x, _ = self[idx]

        # Samples may be (possibly CUDA) tensors or plain NumPy arrays,
        # depending on the transform; bring both to a host NumPy array.
        if torch.is_tensor(x):
            x = x.detach().cpu().numpy()
        image = np.asarray(x)

        if image.ndim == 1:
            # untransformed sample: flat vector of 64 pixels
            image = image.reshape(8, 8)
        elif image.ndim == 3:
            # (channels, height, width): show the first channel
            image = image[0]

        plt.imshow(image, cmap="gray")
        plt.axis("off")
        plt.show()


### Module

In [75]:
class MultiConfigCNN(nn.Module):
    '''Conv2d → f(.) → Pooling → Flatten → Linear 1 → f(.) → Linear 2 → LogSoftmax

    Small configurable CNN classifier for square single-image inputs. The
    network outputs log-probabilities over 10 classes, which pair with the
    negative log-likelihood loss stored in ``self.nll``.

    Args:
        cnn_i_N: number of input channels of the convolution.
        cnn_o_N: number of output channels of the convolution.
        cnn_k_size: (square) kernel size of the convolution.
        stride: stride of the convolution.
        padding: zero padding of the convolution.
        pool_k_size: kernel size (and therefore stride) of the max pooling.
        fnn_o_N: number of hidden units between the two linear layers.
        input_size: height/width of the square input images (8 for Digits).

    Raises:
        ValueError: if the configuration leaves no spatial positions after
            the convolution or the pooling.
    '''
    def __init__(self, cnn_i_N = 1, cnn_o_N = 8, cnn_k_size = 3, stride = 1, padding = 1, pool_k_size = 2, fnn_o_N = 10, input_size = 8):
        super(MultiConfigCNN, self).__init__()

        self.cnn_i_N = cnn_i_N
        self.cnn_o_N = cnn_o_N
        self.cnn_k_size = cnn_k_size
        self.stride = stride
        self.padding = padding
        self.pool_k_size = pool_k_size
        self.fnn_o_N = fnn_o_N
        self.input_size = input_size

        self.cnn =  nn.Conv2d(in_channels = cnn_i_N, out_channels = cnn_o_N, kernel_size = cnn_k_size, stride = stride, padding = padding)
        self.activation1 = nn.ReLU() # or sigmoid, tanh, softplus, elu
        self.pool = nn.MaxPool2d(kernel_size = pool_k_size) # or avg pool
        self.flatten = nn.Flatten()

        # The width of the first linear layer depends on the conv/pool configuration.
        self.linear1 = nn.Linear(in_features = self.compute_input_2_linear(), out_features = fnn_o_N)
        self.activation2 = nn.ReLU() # or sigmoid, tanh, softplus, elu
        self.linear2 = nn.Linear(in_features = fnn_o_N, out_features = 10)
        self.softmax = nn.LogSoftmax(dim=1)

        # per-sample losses; the reduction is applied in compute_loss
        self.nll = nn.NLLLoss(reduction="none")

    def compute_input_2_linear(self):
        '''Return the number of features that reach ``linear1`` after flattening.

        Both formulas use floor division: Conv2d and MaxPool2d drop the
        positions that do not fit a full kernel, so a fractional size would
        overestimate the number of features.
        '''
        # after convolution => floor((W - K + 2P) / S) + 1
        after_cnn_size = (self.input_size - self.cnn_k_size + 2 * self.padding) // self.stride + 1

        # after pooling => stride = kernel size, padding = 0, dilation = 1
        # => floor((W - K) / K) + 1, formula at end of
        # https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html
        after_pool_size = (after_cnn_size - self.pool_k_size) // self.pool_k_size + 1

        # Reject impossible configurations here: squaring a negative size
        # would otherwise yield a positive, meaningless feature count.
        if after_cnn_size < 1 or after_pool_size < 1:
            raise ValueError(
                "Invalid CNN configuration: spatial size is {0} after the convolution "
                "and {1} after the pooling".format(after_cnn_size, after_pool_size)
            )

        # after flattening => height * width * channels
        return after_pool_size * after_pool_size * self.cnn_o_N

    def classify(self, log_prob):
        '''Return the index of the most likely class for every row of ``log_prob``.'''
        y_pred = torch.argmax(log_prob, dim = 1).long()
        return y_pred

    def forward(self, x):
        '''Map a batch of images to per-class log-probabilities.'''
        x = self.cnn(x)
        x = self.activation1(x)
        x = self.pool(x)
        x = self.flatten(x)
        x = self.linear1(x)
        x = self.activation2(x)
        x = self.linear2(x)
        log_prob = self.softmax(x)

        return log_prob


    def compute_loss(self, log_prob, y, reduction="avg"):
        '''Return the negative log-likelihood of targets ``y`` under ``log_prob``.

        ``reduction="sum"`` adds up the per-sample losses; any other value
        (including the default ``"avg"``) returns their mean.
        '''
        loss = self.nll(log_prob, y)

        if reduction == "sum":
            return loss.sum()

        else:
            return loss.mean()

    def count_misclassified(self, predictions, targets):
        '''Return, as a Python float, how many predictions differ from the targets.'''
        misclassified = (predictions != targets).float().sum().item()

        return misclassified

### Training and Validation Loop

In [76]:
def configure_optimizer(model, lr = 1e-3, wd = 1e-5):
    """Create an Adamax optimizer over all parameters of ``model``.

    Args:
        model: module whose ``parameters()`` are optimized.
        lr: learning rate.
        wd: weight decay coefficient.

    Returns:
        The configured ``torch.optim.Adamax`` instance.
    """
    trainable = model.parameters()
    optimizer = torch.optim.Adamax(trainable, weight_decay=wd, lr=lr)
    return optimizer

In [82]:
def train_one_epoch(dataloader, model, optimizer, epoch_n = None):
    """Train ``model`` for one pass over ``dataloader`` and print epoch statistics.

    Args:
        dataloader: iterable of ``(X, y)`` batches that also exposes ``.dataset``.
        model: module providing ``classify``, ``count_misclassified`` and
            ``compute_loss`` next to the usual forward pass.
        optimizer: optimizer that updates the parameters of ``model``.
        epoch_n: optional epoch index, only used for the printed header.

    Returns:
        ``(avg_loss, avg_ce)``: the summed loss and the number of
        misclassified samples, each divided by the dataset size. ``avg_loss``
        is a detached 0-dim tensor once at least one batch has been processed.
    """
    model.train()

    size = len(dataloader.dataset)
    total_loss = 0
    total_miss = 0
    samples_seen = 0

    visual_dl = tqdm(dataloader)

    for (X, y) in visual_dl:

        samples_seen += X.shape[0]

        # start from clean gradients, whatever happened before this call
        optimizer.zero_grad()

        # log-probabilities
        log_prob = model(X)

        # classification
        predictions = model.classify(log_prob)

        # misclassified
        total_miss += model.count_misclassified(predictions, y)

        # loss
        loss = model.compute_loss(log_prob, y, reduction = "sum")

        # Backpropagation
        loss.backward()
        optimizer.step()

        # Accumulate a detached copy so the running total does not hold on
        # to the autograd graph of every batch.
        total_loss += loss.detach()

        # Divide by the samples actually seen: the last batch may be smaller.
        visual_dl.set_postfix({'train_loss': (total_loss / samples_seen).item()})

    visual_dl.close()

    # compute epoch statistics (an empty dataset yields zeros instead of a ZeroDivisionError)
    avg_loss = total_loss / size if size else 0.0
    avg_ce = total_miss / size if size else 0.0

    # "is not None" so that the header is also printed for epoch 0
    if epoch_n is not None:
        print("")
        print(f"Results of epoch number {epoch_n}:")

    print('')
    print('    Average training loss: {0:.5f}'.format(avg_loss))
    print('    Average classification error: {0:.5f}'.format(avg_ce))

    return avg_loss, avg_ce


### Inference 

In [83]:
# Initialize training, validation and test sets.
train_data = TensorizedDigits(mode="train")
val_data = TensorizedDigits(mode="val")
test_data = TensorizedDigits(mode="test")

# Initialize data loaders.
training_loader = DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = DataLoader(val_data, batch_size=64, shuffle=False)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)

In [84]:
# TensorizedDigits only moves samples to CUDA when it is available, so the
# model follows the same rule instead of failing on CPU-only machines.
cnn = MultiConfigCNN().to("cuda" if torch.cuda.is_available() else "cpu")

In [85]:
# Optimizer over the CNN parameters, with the default learning rate and weight decay.
opt = configure_optimizer(model = cnn)

In [87]:
# Train for 100 epochs; each call prints the statistics of its epoch.
for i in range(100):
    train_one_epoch(training_loader, cnn, opt, epoch_n = i)

  0%|          | 0/16 [00:00<?, ?it/s]


    Average training loss: 1.19561
    Average classification error: 0.33100


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 1:

    Average training loss: 1.11555
    Average classification error: 0.28400


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 2:

    Average training loss: 1.04566
    Average classification error: 0.26700


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 3:

    Average training loss: 0.97771
    Average classification error: 0.23800


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 4:

    Average training loss: 0.91606
    Average classification error: 0.22600


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 5:

    Average training loss: 0.85888
    Average classification error: 0.20700


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 6:

    Average training loss: 0.80298
    Average classification error: 0.21400


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 7:

    Average training loss: 0.74896
    Average classification error: 0.19000


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 8:

    Average training loss: 0.70437
    Average classification error: 0.17300


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 9:

    Average training loss: 0.66332
    Average classification error: 0.16500


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 10:

    Average training loss: 0.62377
    Average classification error: 0.15600


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 11:

    Average training loss: 0.59023
    Average classification error: 0.15500


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 12:

    Average training loss: 0.55929
    Average classification error: 0.14300


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 13:

    Average training loss: 0.53008
    Average classification error: 0.14100


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 14:

    Average training loss: 0.50208
    Average classification error: 0.12900


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 15:

    Average training loss: 0.47575
    Average classification error: 0.11800


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 16:

    Average training loss: 0.45243
    Average classification error: 0.12300


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 17:

    Average training loss: 0.42913
    Average classification error: 0.11900


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 18:

    Average training loss: 0.40908
    Average classification error: 0.11200


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 19:

    Average training loss: 0.39348
    Average classification error: 0.10700


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 20:

    Average training loss: 0.37412
    Average classification error: 0.09700


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 21:

    Average training loss: 0.35950
    Average classification error: 0.09700


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 22:

    Average training loss: 0.34598
    Average classification error: 0.09500


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 23:

    Average training loss: 0.33053
    Average classification error: 0.09000


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 24:

    Average training loss: 0.31692
    Average classification error: 0.08700


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 25:

    Average training loss: 0.30578
    Average classification error: 0.08300


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 26:

    Average training loss: 0.29330
    Average classification error: 0.08000


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 27:

    Average training loss: 0.28310
    Average classification error: 0.07500


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 28:

    Average training loss: 0.27262
    Average classification error: 0.07600


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 29:

    Average training loss: 0.26454
    Average classification error: 0.07400


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 30:

    Average training loss: 0.25586
    Average classification error: 0.07200


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 31:

    Average training loss: 0.24712
    Average classification error: 0.06600


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 32:

    Average training loss: 0.23796
    Average classification error: 0.06200


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 33:

    Average training loss: 0.23042
    Average classification error: 0.06000


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 34:

    Average training loss: 0.22250
    Average classification error: 0.05800


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 35:

    Average training loss: 0.21627
    Average classification error: 0.05900


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 36:

    Average training loss: 0.20948
    Average classification error: 0.05700


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 37:

    Average training loss: 0.20377
    Average classification error: 0.05400


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 38:

    Average training loss: 0.19830
    Average classification error: 0.05200


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 39:

    Average training loss: 0.19104
    Average classification error: 0.05500


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 40:

    Average training loss: 0.18921
    Average classification error: 0.05600


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 41:

    Average training loss: 0.18224
    Average classification error: 0.05000


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 42:

    Average training loss: 0.17622
    Average classification error: 0.04900


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 43:

    Average training loss: 0.17469
    Average classification error: 0.04500


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 44:

    Average training loss: 0.16990
    Average classification error: 0.04800


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 45:

    Average training loss: 0.16165
    Average classification error: 0.04000


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 46:

    Average training loss: 0.15887
    Average classification error: 0.04300


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 47:

    Average training loss: 0.15242
    Average classification error: 0.04000


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 48:

    Average training loss: 0.15112
    Average classification error: 0.04100


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 49:

    Average training loss: 0.14876
    Average classification error: 0.03800


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 50:

    Average training loss: 0.14320
    Average classification error: 0.04200


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 51:

    Average training loss: 0.13948
    Average classification error: 0.03500


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 52:

    Average training loss: 0.13668
    Average classification error: 0.03600


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 53:

    Average training loss: 0.13448
    Average classification error: 0.03600


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 54:

    Average training loss: 0.12931
    Average classification error: 0.03300


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 55:

    Average training loss: 0.12725
    Average classification error: 0.03400


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 56:

    Average training loss: 0.12565
    Average classification error: 0.03200


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 57:

    Average training loss: 0.12211
    Average classification error: 0.03100


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 58:

    Average training loss: 0.11697
    Average classification error: 0.03000


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 59:

    Average training loss: 0.11433
    Average classification error: 0.02900


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 60:

    Average training loss: 0.11306
    Average classification error: 0.03100


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 61:

    Average training loss: 0.10961
    Average classification error: 0.02800


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 62:

    Average training loss: 0.10795
    Average classification error: 0.02900


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 63:

    Average training loss: 0.10735
    Average classification error: 0.03000


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 64:

    Average training loss: 0.10438
    Average classification error: 0.02800


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 65:

    Average training loss: 0.10000
    Average classification error: 0.02200


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 66:

    Average training loss: 0.09732
    Average classification error: 0.02600


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 67:

    Average training loss: 0.09581
    Average classification error: 0.02800


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 68:

    Average training loss: 0.09363
    Average classification error: 0.02400


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 69:

    Average training loss: 0.09340
    Average classification error: 0.02400


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 70:

    Average training loss: 0.09101
    Average classification error: 0.02200


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 71:

    Average training loss: 0.08996
    Average classification error: 0.02700


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 72:

    Average training loss: 0.08738
    Average classification error: 0.02100


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 73:

    Average training loss: 0.08760
    Average classification error: 0.02300


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 74:

    Average training loss: 0.08163
    Average classification error: 0.01900


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 75:

    Average training loss: 0.08100
    Average classification error: 0.01900


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 76:

    Average training loss: 0.07786
    Average classification error: 0.01800


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 77:

    Average training loss: 0.07711
    Average classification error: 0.01500


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 78:

    Average training loss: 0.07680
    Average classification error: 0.01700


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 79:

    Average training loss: 0.07510
    Average classification error: 0.01700


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 80:

    Average training loss: 0.07262
    Average classification error: 0.01800


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 81:

    Average training loss: 0.07201
    Average classification error: 0.01500


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 82:

    Average training loss: 0.07010
    Average classification error: 0.01500


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 83:

    Average training loss: 0.06775
    Average classification error: 0.01400


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 84:

    Average training loss: 0.06713
    Average classification error: 0.01500


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 85:

    Average training loss: 0.06615
    Average classification error: 0.01100


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 86:

    Average training loss: 0.06559
    Average classification error: 0.01400


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 87:

    Average training loss: 0.06311
    Average classification error: 0.01000


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 88:

    Average training loss: 0.06265
    Average classification error: 0.01100


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 89:

    Average training loss: 0.06098
    Average classification error: 0.01000


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 90:

    Average training loss: 0.05906
    Average classification error: 0.00900


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 91:

    Average training loss: 0.05922
    Average classification error: 0.00800


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 92:

    Average training loss: 0.05667
    Average classification error: 0.01000


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 93:

    Average training loss: 0.05494
    Average classification error: 0.00800


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 94:

    Average training loss: 0.05379
    Average classification error: 0.00800


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 95:

    Average training loss: 0.05400
    Average classification error: 0.00900


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 96:

    Average training loss: 0.05207
    Average classification error: 0.00900


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 97:

    Average training loss: 0.05126
    Average classification error: 0.00800


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 98:

    Average training loss: 0.05013
    Average classification error: 0.00900


  0%|          | 0/16 [00:00<?, ?it/s]


Results of epoch number 99:

    Average training loss: 0.04938
    Average classification error: 0.00800


In [53]:
# Sanity check on the first training batch only: shapes, number of
# misclassified samples, log-probabilities, predictions and mean loss.
for (x,y) in training_loader:
    print(x.shape)
    print(y.shape)
    log_prob = cnn(x)
    prediction  = cnn.classify(log_prob)
    loss = cnn.compute_loss(log_prob,y)
    # call the method: printing the attribute alone shows the bound method, not the count
    print(cnn.count_misclassified(prediction, y))
    print(log_prob)
    print(prediction)
    print(loss)
    break

torch.Size([64, 1, 8, 8])
torch.Size([64])
55.0
tensor([2, 4, 2, 0, 5, 1, 5, 3, 2, 1, 2, 2, 3, 3, 4, 3, 0, 5, 7, 5, 6, 9, 4, 7,
        4, 1, 0, 7, 1, 7, 6, 0, 4, 5, 1, 7, 9, 5, 0, 0, 3, 8, 9, 1, 6, 2, 4, 2,
        8, 2, 3, 8, 1, 2, 7, 9, 2, 9, 8, 3, 0, 8, 9, 1], device='cuda:0')
tensor([1, 1, 1, 8, 1, 8, 8, 1, 1, 1, 9, 1, 1, 1, 1, 1, 1, 9, 1, 8, 1, 8, 1, 9,
        1, 1, 8, 1, 1, 1, 1, 8, 1, 1, 1, 9, 9, 8, 8, 8, 9, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 8, 1, 9, 1], device='cuda:0')
