In [1]:
import os
import random

import torch

from dataloader import *
from CNN_Models import *
from CNN_Training import *
from model_io import *
from utils.constants import *

Configure Paths

In [2]:
# Prefix prepended to every relative data path below; leave empty to resolve
# paths relative to the notebook's working directory.
current_path = ""
image_path = f"{current_path}data/images_type"     # image directory passed to get_type_datasets
label_path = f"{current_path}data/type_label.csv"  # label CSV passed to get_type_datasets

Set Up the Compute Device (CUDA, MPS, or CPU)

In [3]:
# Pick the best available accelerator: CUDA first, then Apple-silicon MPS,
# otherwise fall back to the CPU. The MPS diagnostics are printed only when we
# actually end up on the CPU, so CUDA machines are not told "MPS not available".
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
    if not torch.backends.mps.is_built():
        print("MPS not available because the current PyTorch install was not "
              "built with MPS enabled.")
    else:
        print("MPS not available because the current MacOS version is not 12.3+ "
              "and/or you do not have an MPS-enabled device on this machine.")

print(f"Currently Using Device: {device}")

Currently Using Device: mps


Create the Data Loaders, Model, Optimizer, Scheduler, and Loss Function

In [4]:
# Seed the RNGs so that dropout, PyTorch-side shuffling and (presumably) the
# train/valid split inside get_type_datasets are reproducible on Restart & Run
# All. NOTE(review): if dataloader uses numpy randomness, seed numpy as well.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

BATCH_SIZE = 16  # samples per mini-batch
DROPOUT = 0.2    # dropout probability for the classifier (kept small on purpose)

(train, valid) = get_type_datasets(label_path, image_path, BATCH_SIZE)
model = get_resnet18_classifier(TYPE_N_CLASSES, dropout=DROPOUT)

# SGD with momentum and L2 weight decay; the learning rate decays by 5% per
# scheduler step (exponential, gamma = 0.95).
optimizer = get_optimizer(model, "sgd", lr=1e-3, weight_decay=1e-4, momentum=0.9)
scheduler = get_scheduler(optimizer, "exponential", gamma=0.95)
loss_criterion = get_loss()

Train the Model in Stages (Saving a Checkpoint After Each Stage)

In [5]:
# Train for num_stages * epoch_per_stage epochs in total, saving a checkpoint at
# the end of every stage (i.e. every epoch_per_stage = 10 epochs).
num_stages = 3
epoch_per_stage = 10
model_path = current_path + "model/"

# Create the checkpoint directory so the first save does not fail on a fresh checkout.
os.makedirs(model_path, exist_ok=True)

# Per-epoch metric history. The same dict is passed to every CNN_train call
# (presumably filled in place) and saved alongside each checkpoint.
logger = {"train_loss" : [], "valid_loss" : [], "train_acc" : [], "valid_acc" : []}

for stage in range(num_stages):
    # NOTE(review): the trailing positional True is a CNN_train flag
    # (presumably verbose/progress output) — confirm against its signature.
    CNN_train(model, train, valid, epoch_per_stage,
              loss_criterion, optimizer, scheduler, logger, True)
    model_save_name = model_path + f"model_stage{stage+1}.pth"
    save_checkpoint(model_save_name, model, optimizer, scheduler, logger)

                                                               

Epoch 1: Current LR: [0.00095]
Train Loss: 1.05665 Valid Loss: 0.88703 Train Acc:  0.65174 Valid Acc:  0.69623 


                                                               

Epoch 2: Current LR: [0.0009025]
Train Loss: 0.9138  Valid Loss: 0.8494  Train Acc:  0.70202 Valid Acc:  0.71748 


                                                               

Epoch 3: Current LR: [0.000857375]
Train Loss: 0.86185 Valid Loss: 0.81365 Train Acc:  0.71457 Valid Acc:  0.73727 


                                                               

Epoch 4: Current LR: [0.0008145062499999999]
Train Loss: 0.81613 Valid Loss: 0.81984 Train Acc:  0.73262 Valid Acc:  0.72334 


                                                               

Epoch 5: Current LR: [0.0007737809374999998]
Train Loss: 0.78454 Valid Loss: 0.76626 Train Acc:  0.73674 Valid Acc:  0.74606 


                                                               

Epoch 6: Current LR: [0.0007350918906249997]
Train Loss: 0.74792 Valid Loss: 0.7871  Train Acc:  0.75277 Valid Acc:  0.74313 


                                                               

Epoch 7: Current LR: [0.0006983372960937497]
Train Loss: 0.72674 Valid Loss: 0.74556 Train Acc:  0.75717 Valid Acc:  0.75742 


                                                               

Epoch 8: Current LR: [0.0006634204312890621]
Train Loss: 0.70302 Valid Loss: 0.76275 Train Acc:  0.76459 Valid Acc:  0.75229 


                                                               

Epoch 9: Current LR: [0.000630249409724609]
Train Loss: 0.66158 Valid Loss: 0.79105 Train Acc:  0.7764  Valid Acc:  0.73837 


                                                                

Epoch 10: Current LR: [0.0005987369392383785]
Train Loss: 0.64463 Valid Loss: 0.74676 Train Acc:  0.78263 Valid Acc:  0.76218 
Model Chekcpoint saved to model/model_stage1.pth


                                                                

Epoch 11: Current LR: [0.0005688000922764595]
Train Loss: 0.61178 Valid Loss: 0.80692 Train Acc:  0.79317 Valid Acc:  0.74496 


                                                                

Epoch 12: Current LR: [0.0005403600876626365]
Train Loss: 0.59757 Valid Loss: 0.78329 Train Acc:  0.79674 Valid Acc:  0.74569 


                                                                

Epoch 13: Current LR: [0.0005133420832795047]
Train Loss: 0.5729  Valid Loss: 0.77053 Train Acc:  0.80517 Valid Acc:  0.75339 


                                                                

Epoch 14: Current LR: [0.00048767497911552944]
Train Loss: 0.54209 Valid Loss: 0.79608 Train Acc:  0.81286 Valid Acc:  0.75522 


                                                                

Epoch 15: Current LR: [0.00046329123015975297]
Train Loss: 0.51792 Valid Loss: 0.79549 Train Acc:  0.82376 Valid Acc:  0.74826 


                                                                

Epoch 16: Current LR: [0.0004401266686517653]
Train Loss: 0.48887 Valid Loss: 0.81183 Train Acc:  0.83375 Valid Acc:  0.76621 


                                                                

Epoch 17: Current LR: [0.00041812033521917703]
Train Loss: 0.46929 Valid Loss: 0.81095 Train Acc:  0.84419 Valid Acc:  0.75889 


                                                                

Epoch 18: Current LR: [0.00039721431845821814]
Train Loss: 0.44149 Valid Loss: 0.82377 Train Acc:  0.85271 Valid Acc:  0.75082 


                                                                

Epoch 19: Current LR: [0.0003773536025353072]
Train Loss: 0.4165  Valid Loss: 0.82417 Train Acc:  0.85747 Valid Acc:  0.75449 


                                                                

Epoch 20: Current LR: [0.0003584859224085418]
Train Loss: 0.39886 Valid Loss: 0.87748 Train Acc:  0.86663 Valid Acc:  0.75742 
Model Chekcpoint saved to model/model_stage2.pth


                                                                

Epoch 21: Current LR: [0.0003405616262881147]
Train Loss: 0.37406 Valid Loss: 0.87997 Train Acc:  0.86993 Valid Acc:  0.74386 


                                                                

Epoch 22: Current LR: [0.00032353354497370894]
Train Loss: 0.36245 Valid Loss: 0.93089 Train Acc:  0.87552 Valid Acc:  0.7358  


                                                                

Epoch 23: Current LR: [0.00030735686772502346]
Train Loss: 0.35053 Valid Loss: 0.86338 Train Acc:  0.87955 Valid Acc:  0.74936 


                                                                

Epoch 24: Current LR: [0.00029198902433877225]
Train Loss: 0.3158  Valid Loss: 0.93105 Train Acc:  0.89209 Valid Acc:  0.74569 


                                                                

Epoch 25: Current LR: [0.00027738957312183364]
Train Loss: 0.29637 Valid Loss: 0.95733 Train Acc:  0.89823 Valid Acc:  0.74093 


                                                                

Epoch 26: Current LR: [0.0002635200944657419]
Train Loss: 0.28936 Valid Loss: 0.9683  Train Acc:  0.89979 Valid Acc:  0.74496 


                                                                

Epoch 27: Current LR: [0.0002503440897424548]
Train Loss: 0.26705 Valid Loss: 1.02024 Train Acc:  0.91041 Valid Acc:  0.74716 


                                                                

Epoch 28: Current LR: [0.00023782688525533205]
Train Loss: 0.26203 Valid Loss: 0.95038 Train Acc:  0.90867 Valid Acc:  0.75046 


                                                                

Epoch 29: Current LR: [0.00022593554099256544]
Train Loss: 0.24012 Valid Loss: 0.99107 Train Acc:  0.91674 Valid Acc:  0.75486 


                                                                

Epoch 30: Current LR: [0.00021463876394293716]
Train Loss: 0.23211 Valid Loss: 0.99521 Train Acc:  0.91829 Valid Acc:  0.75559 
Model Chekcpoint saved to model/model_stage3.pth


