In [1]:
from candle.trainers import Trainer

In [2]:
import torch

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [4]:
from torchvision.datasets import MNIST

In [5]:
from torch.utils.data import DataLoader
from torchvision.transforms import transforms

In [6]:
import torch
print(torch.__version__)
print(torch.version.cuda)  # Displays CUDA version

2.5.1+cu124
12.4


In [7]:
!nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Tue_Feb_27_16:28:36_Pacific_Standard_Time_2024
Cuda compilation tools, release 12.4, V12.4.99
Build cuda_12.4.r12.4/compiler.33961263_0


In [8]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # Mean and standard deviation for MNIST
])

# Step 2: Load the MNIST dataset
train_ds = MNIST(root='./data', train=True, download=True, transform=transform)
val_ds = MNIST(root='./data', train=False, download=True, transform=transform)

# Step 3: Create DataLoaders
train_loader = DataLoader(dataset=train_ds, batch_size=64, shuffle=True)
val_loader = DataLoader(dataset=val_ds, batch_size=64, shuffle=False)

In [9]:
train_ds[0][0].shape

torch.Size([1, 28, 28])

In [10]:
import torch.nn as nn
import torch.nn.functional as F

In [11]:
class GarmentClassifier(nn.Module):
  def __init__(self):
    super(GarmentClassifier,self).__init__()
    self.pool = nn.MaxPool2d(2, stride=2)
    self.conv1 = nn.Conv2d(in_channels=1,out_channels=6,kernel_size=(3,3),padding=1)
    self.batchnorm1 = nn.BatchNorm2d(6)
    self.conv2 = nn.Conv2d(6, 16, kernel_size=(3,3),padding=1)
    self.batchnorm2 = nn.BatchNorm2d(16)
    self.conv3 = nn.Conv2d(16, 32, kernel_size=(3,3),padding=1)
    self.batchnorm3 = nn.BatchNorm2d(32)
    self.conv4 = nn.Conv2d(32, 64, kernel_size=(3,3),padding=1)
    self.batchnorm4 = nn.BatchNorm2d(64)
    self.fc1 = nn.Linear(64 * 1 * 1, 128)
    self.batchnorm5 = nn.BatchNorm1d(128)
    self.fc2 = nn.Linear(128, 128)
    self.batchnorm6 = nn.BatchNorm1d(128)
    self.fc3 = nn.Linear(128, 64)
    self.batchnorm7 = nn.BatchNorm1d(64)
    self.fc4 = nn.Linear(64, 32)
    self.batchnorm8 = nn.BatchNorm1d(32)
    self.fc5 = nn.Linear(32, 10)
    self.dropout = nn.Dropout1d(0.1)
  def forward(self,x):
    x = self.pool(F.mish(self.batchnorm1(self.conv1(x))))
    x = self.pool(F.mish(self.batchnorm2(self.conv2(x))))
    x = self.pool(F.mish(self.batchnorm3(self.conv3(x))))
    x = self.pool(F.mish(self.batchnorm4(self.conv4(x))))
    x = x.view(-1, 64 * 1 * 1)
    x = self.dropout(x)
    x = F.leaky_relu(self.batchnorm5(self.fc1(x)))
    x = F.leaky_relu(self.batchnorm6(self.fc2(x)))
    x = F.leaky_relu(self.batchnorm7(self.fc3(x)))
    x = F.leaky_relu(self.batchnorm8(self.fc4(x)))
    x = self.fc5(x)
    return x


In [12]:
# model = BasicCNNClassifier(input_shape=(1,28,28), num_output_classes=10)
model = GarmentClassifier()
model = model.to(device)


In [13]:
from torchsummary import summary
summary(model, input_size=(1, 28, 28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 6, 28, 28]              60
       BatchNorm2d-2            [-1, 6, 28, 28]              12
         MaxPool2d-3            [-1, 6, 14, 14]               0
            Conv2d-4           [-1, 16, 14, 14]             880
       BatchNorm2d-5           [-1, 16, 14, 14]              32
         MaxPool2d-6             [-1, 16, 7, 7]               0
            Conv2d-7             [-1, 32, 7, 7]           4,640
       BatchNorm2d-8             [-1, 32, 7, 7]              64
         MaxPool2d-9             [-1, 32, 3, 3]               0
           Conv2d-10             [-1, 64, 3, 3]          18,496
      BatchNorm2d-11             [-1, 64, 3, 3]             128
        MaxPool2d-12             [-1, 64, 1, 1]               0
        Dropout1d-13                   [-1, 64]               0
           Linear-14                  [

In [14]:
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [15]:
from candle.metrics import Accuracy

In [16]:
accuracy = Accuracy(binary_output=False)

In [17]:
from candle.callbacks import EarlyStopping, LRTracker, ConsoleLogger

In [18]:
callbacks = [
            ConsoleLogger(display_time_elapsed=True, progress_bar_positions=["train", "validation"]),
            EarlyStopping(basis="val_accuracy", metric_minimize=True, patience=1),
            LRTracker()]

In [19]:
trainer = Trainer( model,
                 criterion=loss_fn,
                 optimizer=optimizer,
                 metrics=[accuracy],
                 callbacks= callbacks,
                 use_amp= True,
                 device=device)

In [20]:
trainer.callbacks.callbacks

[<candle.callbacks.console_logger.ConsoleLogger at 0x1f00cef05b0>,
 <candle.callbacks.early_stopping.EarlyStopping at 0x1f00cef2770>,
 <candle.callbacks.lr_handlers.LRTracker at 0x1f00cef3760>]

In [21]:
history = trainer.fit(train_loader,val_loader, epochs=5, epoch_start=0)

---------------------------------------------Progress---------------------------------------------
EPOCH 0: 100%|██████████| 938/938 [00:50<00:00, 18.67it/s]
Validation: 100%|██████████| 157/157 [00:05<00:00, 28.43it/s]
--> Metrics:   accuracy: 0.8662 ,val_accuracy: 0.9812 ,loss: 0.3925 ,val_loss: 0.0705
Time elapsed: 55.77954721450806 s
----------------------------------------------------------------------------------------------------
EPOCH 1: 100%|██████████| 938/938 [00:48<00:00, 19.35it/s]
Validation: 100%|██████████| 157/157 [00:06<00:00, 24.15it/s]
--> Metrics:   accuracy: 0.8882 ,val_accuracy: 0.9820 ,loss: 0.3098 ,val_loss: 0.0625
Time elapsed: 110.80687236785889 s
----------------------------------------------------------------------------------------------------
Early-stopping at epoch 1, basis : val_accuracy↑
----------------------------------------------------------------------------------------------------
Restoring best weights...
	Best epoch: 0
	Training loss: 0.3925
	V

In [27]:
history

{'accuracy': [0.8879264392324094, 0.8948560767590619],
 'val_accuracy': [0.9785031847133758, 0.9794984076433121],
 'loss': [0.30896627043546643, 0.2870177571207031],
 'val_loss': [0.0773598939939669, 0.07421214381458273],
 'lr': [0.01, 0.01]}

In [21]:
# trainer.load_progress(r"F:\Projects\ML\MyLibs\furnance\temp\saves")

In [22]:
trainer.tracker['val_accuracy'].latest

0.9835788216560509

In [25]:
trainer.tracker['val_accuracy'][1]

0.9835788216560509