In [1]:
from candle.trainer import Trainer
from candle.callbacks import EarlyStopping, LRTracker
import os

In [2]:
import torch

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [4]:
from torchvision.datasets import MNIST

In [5]:
from torch.utils.data import DataLoader
from torchvision.transforms import transforms

In [6]:
import torch
print(torch.__version__)
print(torch.version.cuda)  # Displays CUDA version

2.5.1+cu124
12.4


In [7]:
!nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Tue_Feb_27_16:28:36_Pacific_Standard_Time_2024
Cuda compilation tools, release 12.4, V12.4.99
Build cuda_12.4.r12.4/compiler.33961263_0


In [8]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # Mean and standard deviation for MNIST
])

# Step 2: Load the MNIST dataset
train_ds = MNIST(root='./data', train=True, download=True, transform=transform)
val_ds = MNIST(root='./data', train=False, download=True, transform=transform)

# Step 3: Create DataLoaders
train_loader = DataLoader(dataset=train_ds, batch_size=64, shuffle=True)
val_loader = DataLoader(dataset=val_ds, batch_size=64, shuffle=False)

In [9]:
train_ds[0][0].shape

torch.Size([1, 28, 28])

In [10]:
import torch.nn as nn
import torch.nn.functional as F

In [11]:
class GarmentClassifier(nn.Module):
  def __init__(self):
    super(GarmentClassifier,self).__init__()
    self.pool = nn.MaxPool2d(2, stride=2)
    self.conv1 = nn.Conv2d(in_channels=1,out_channels=6,kernel_size=(3,3),padding=1)
    self.batchnorm1 = nn.BatchNorm2d(6)
    self.conv2 = nn.Conv2d(6, 16, kernel_size=(3,3),padding=1)
    self.batchnorm2 = nn.BatchNorm2d(16)
    self.conv3 = nn.Conv2d(16, 32, kernel_size=(3,3),padding=1)
    self.batchnorm3 = nn.BatchNorm2d(32)
    self.conv4 = nn.Conv2d(32, 64, kernel_size=(3,3),padding=1)
    self.batchnorm4 = nn.BatchNorm2d(64)
    self.fc1 = nn.Linear(64 * 1 * 1, 128)
    self.batchnorm5 = nn.BatchNorm1d(128)
    self.fc2 = nn.Linear(128, 128)
    self.batchnorm6 = nn.BatchNorm1d(128)
    self.fc3 = nn.Linear(128, 64)
    self.batchnorm7 = nn.BatchNorm1d(64)
    self.fc4 = nn.Linear(64, 32)
    self.batchnorm8 = nn.BatchNorm1d(32)
    self.fc5 = nn.Linear(32, 10)
    self.dropout = nn.Dropout1d(0.1)
  def forward(self,x):
    x = self.pool(F.mish(self.batchnorm1(self.conv1(x))))
    x = self.pool(F.mish(self.batchnorm2(self.conv2(x))))
    x = self.pool(F.mish(self.batchnorm3(self.conv3(x))))
    x = self.pool(F.mish(self.batchnorm4(self.conv4(x))))
    x = x.view(-1, 64 * 1 * 1)
    x = self.dropout(x)
    x = F.leaky_relu(self.batchnorm5(self.fc1(x)))
    x = F.leaky_relu(self.batchnorm6(self.fc2(x)))
    x = F.leaky_relu(self.batchnorm7(self.fc3(x)))
    x = F.leaky_relu(self.batchnorm8(self.fc4(x)))
    x = self.fc5(x)
    return x


In [12]:
from candle.models.vision import BasicCNNClassifier

In [13]:
model = BasicCNNClassifier(input_shape=(1,28,28), num_output_classes=10)
model = model.to(device)

In [14]:
from torchsummary import summary
summary(model, input_size=(1, 28, 28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 28, 28]             160
       BatchNorm2d-2           [-1, 16, 28, 28]              32
              ReLU-3           [-1, 16, 28, 28]               0
         MaxPool2d-4           [-1, 16, 14, 14]               0
            Conv2d-5           [-1, 32, 14, 14]           4,640
       BatchNorm2d-6           [-1, 32, 14, 14]              64
         LeakyReLU-7           [-1, 32, 14, 14]               0
         AvgPool2d-8             [-1, 32, 7, 7]               0
            Conv2d-9             [-1, 64, 7, 7]          18,496
      BatchNorm2d-10             [-1, 64, 7, 7]             128
              ELU-11             [-1, 64, 7, 7]               0
        MaxPool2d-12             [-1, 64, 3, 3]               0
           Linear-13                  [-1, 128]          73,856
      BatchNorm1d-14                  [

In [15]:
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [16]:
from candle.metrics import Accuracy

In [17]:
accuracy = Accuracy(binary_output=False)

In [18]:
trainer = Trainer( model,
                 criterion=loss_fn,
                 input_shape=(1,28,28),
                 optimizer=optimizer,
                 display_time_elapsed=False,
                 metrics=[accuracy],
                 callbacks= [EarlyStopping(basis="val_accuracy", metric_minimize=True, patience=1),
                             LRTracker()],
                 use_amp= True,
                 device=device)

In [19]:
history = trainer.fit(train_loader,val_loader, epochs=5, epoch_start=0)

EPOCH 0: : 938it [00:17, 53.72it/s]
--> Metrics:   accuracy: 0.9723 ,val_accuracy: 0.9854 ,loss: 0.0900 ,val_loss: 0.0442
----------------------------------------------------------------------------------------------------
EPOCH 1: : 938it [00:18, 51.98it/s]
--> Metrics:   accuracy: 0.9861 ,val_accuracy: 0.9868 ,loss: 0.0445 ,val_loss: 0.0425
----------------------------------------------------------------------------------------------------
Early-stopping at epoch 1, basis : val_accuracy↑
----------------------------------------------------------------------------------------------------
Restoring best weights...
	Best epoch: 0
	Training loss: 0.0900
	Validation loss: 0.0442
	Training accuracy: 0.9723
	Validation accuracy: 0.9854
----------------------------------------------------------------------------------------------------


In [20]:
history

{'accuracy': [0.9723314232409381, 0.9861074093816631],
 'val_accuracy': [0.9853702229299363, 0.9867635350318471],
 'loss': [0.08996184853944125, 0.04450817240564935],
 'val_loss': [0.044190079780520905, 0.04249656062786746],
 'lr': [0.01, 0.01]}

In [20]:
# trainer.load_progress(r"F:\Projects\ML\MyLibs\furnance\temp\saves")

In [25]:
# trainer.tracker.metrics['val_accuracy'].latest