In [1]:
from argparse import Namespace
import torch.optim

torch.manual_seed(1)  # fixed seed so random initialisation/shuffling is repeatable


# <------------- Hyperparameters/Config ------------->
config = Namespace(
    # Prefer the Apple-silicon (M series) GPU backend; fall back to the CPU so the
    # notebook still runs on machines where MPS is not available.
    DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu"),
    NUM_WORKERS = 2,
    PIN_MEMORY = True,

    EPOCHS = 50,
    LEARNING_RATE = 2e-5,
    BATCH_SIZE = 64,
    WEIGHT_DECAY = 0, # TODO: play with weight decay

    # Resume training from previously saved weights
    CON_TRAIN = False, # continue to train a model
    LOAD_MODEL_FILE = "./checkpoints/Yolov1_facemask_objectDetection_epoch50_2025-04-09-18h_31m.pt",

    DATASET_DIR = "./data", # root path to the dataset dir
    IMAGE_SIZE = 448,

    C = 18, # how many classes in the dataset
    B = 2, # how many bounding boxes the model predicts per cell
    S = 7, # split_size, how to split the image, 7x7=49 grid cells
    IOU_THRESHOLD = 0.5, # the iou threshold when comparing bounding boxes for NMS
    MIN_THRESHOLD = 0.4, # the minimal confidence to keep a predicted bounding box
)

# Total number of nodes per cell, laid out as
# [*classes, pc_1, bbox1_x_y_w_h, pc_2, bbox2_x_y_w_h]; the 5 per box is pc_score, x, y, w, h.
# With C=18, B=2 ==> 18 + 5 * 2 = 28 nodes.
config.NUM_NODES_PER_CELL = config.C + 5 * config.B
# Total number of nodes per image: one cell vector for each of the S x S grid cells.
# With S=7 ==> 7 * 7 * 28 = 1,372.
config.NUM_NODES_PER_IMG = config.S * config.S * config.NUM_NODES_PER_CELL



In [2]:
# Stand-alone copies of the grid size (S), class count (C) and boxes per cell (B).
S, C, B = 7, 18, 2

In [3]:
import torch

# The model output is a flat torch.Size([1, 1372]) tensor; the dummy below mimics
# it after reshaping to (1, 7, 7, 28), with every cell holding the values 0..27 so
# that each value equals its own index along the last axis.
tensor = torch.arange(28).repeat(1, 7, 7, 1)
tensor, tensor.shape

(tensor([[[[ 0,  1,  2,  ..., 25, 26, 27],
           [ 0,  1,  2,  ..., 25, 26, 27],
           [ 0,  1,  2,  ..., 25, 26, 27],
           ...,
           [ 0,  1,  2,  ..., 25, 26, 27],
           [ 0,  1,  2,  ..., 25, 26, 27],
           [ 0,  1,  2,  ..., 25, 26, 27]],
 
          [[ 0,  1,  2,  ..., 25, 26, 27],
           [ 0,  1,  2,  ..., 25, 26, 27],
           [ 0,  1,  2,  ..., 25, 26, 27],
           ...,
           [ 0,  1,  2,  ..., 25, 26, 27],
           [ 0,  1,  2,  ..., 25, 26, 27],
           [ 0,  1,  2,  ..., 25, 26, 27]],
 
          [[ 0,  1,  2,  ..., 25, 26, 27],
           [ 0,  1,  2,  ..., 25, 26, 27],
           [ 0,  1,  2,  ..., 25, 26, 27],
           ...,
           [ 0,  1,  2,  ..., 25, 26, 27],
           [ 0,  1,  2,  ..., 25, 26, 27],
           [ 0,  1,  2,  ..., 25, 26, 27]],
 
          ...,
 
          [[ 0,  1,  2,  ..., 25, 26, 27],
           [ 0,  1,  2,  ..., 25, 26, 27],
           [ 0,  1,  2,  ..., 25, 26, 27],
           ...,
       

In [6]:
# Short alias for the dummy tensor, then show the first grid row of the first
# image: 7 cells with 28 values each.
t = tensor
t[0, 0]

tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27]])

In [13]:
# Per-cell layout: [*classes, pc_1, bbox1_x_y_w_h, pc_2, bbox2_x_y_w_h].
# Both coordinate slices are measured back from the end of the cell vector:
# box 2 is the last 4 values, box 1 is the 4 values just before pc_2.
_cell_end = config.NUM_NODES_PER_CELL
first_bbox_coord_slice = slice(_cell_end - 9, _cell_end - 5)
second_bbox_coord_slice = slice(_cell_end - 4, _cell_end)

In [15]:
# Box-1 coordinates of every cell; on the dummy tensor these are the values 19..22.
t[:, :, :, first_bbox_coord_slice]
# t[:, :, :, second_bbox_coord_slice]

tensor([[[[19, 20, 21, 22],
          [19, 20, 21, 22],
          [19, 20, 21, 22],
          [19, 20, 21, 22],
          [19, 20, 21, 22],
          [19, 20, 21, 22],
          [19, 20, 21, 22]],

         [[19, 20, 21, 22],
          [19, 20, 21, 22],
          [19, 20, 21, 22],
          [19, 20, 21, 22],
          [19, 20, 21, 22],
          [19, 20, 21, 22],
          [19, 20, 21, 22]],

         [[19, 20, 21, 22],
          [19, 20, 21, 22],
          [19, 20, 21, 22],
          [19, 20, 21, 22],
          [19, 20, 21, 22],
          [19, 20, 21, 22],
          [19, 20, 21, 22]],

         [[19, 20, 21, 22],
          [19, 20, 21, 22],
          [19, 20, 21, 22],
          [19, 20, 21, 22],
          [19, 20, 21, 22],
          [19, 20, 21, 22],
          [19, 20, 21, 22]],

         [[19, 20, 21, 22],
          [19, 20, 21, 22],
          [19, 20, 21, 22],
          [19, 20, 21, 22],
          [19, 20, 21, 22],
          [19, 20, 21, 22],
          [19, 20, 21, 22]],

         [

In [36]:
# One-element slice at index C + 5, the position of pc_2 in the per-cell layout.
_pc_2_index = config.C + 5
i = slice(_pc_2_index, _pc_2_index + 1)

In [45]:
# "Identity of obj i" in the paper: tells us whether there is an object in cell i.
# Integer indexing drops the last axis, so add it back as a trailing axis of size 1.
t[..., 18].unsqueeze(-1).shape

torch.Size([1, 7, 7, 1])

In [46]:
# Same selection as integer index C, but a one-element slice keeps the last axis.
t[..., config.C:config.C + 1]

tensor([[[[18],
          [18],
          [18],
          [18],
          [18],
          [18],
          [18]],

         [[18],
          [18],
          [18],
          [18],
          [18],
          [18],
          [18]],

         [[18],
          [18],
          [18],
          [18],
          [18],
          [18],
          [18]],

         [[18],
          [18],
          [18],
          [18],
          [18],
          [18],
          [18]],

         [[18],
          [18],
          [18],
          [18],
          [18],
          [18],
          [18]],

         [[18],
          [18],
          [18],
          [18],
          [18],
          [18],
          [18]],

         [[18],
          [18],
          [18],
          [18],
          [18],
          [18],
          [18]]]])

In [48]:
# One-element slice selecting pc_1, which sits right after the C class entries
# (index 18 when C == 18).
_pc_1_index = config.C
pc_1_slice = slice(_pc_1_index, _pc_1_index + 1)
pc_1_slice

slice(18, 19, None)

In [47]:
# First three entries of every cell vector.
t[:, :, :, :3]

tensor([[[[0, 1, 2],
          [0, 1, 2],
          [0, 1, 2],
          [0, 1, 2],
          [0, 1, 2],
          [0, 1, 2],
          [0, 1, 2]],

         [[0, 1, 2],
          [0, 1, 2],
          [0, 1, 2],
          [0, 1, 2],
          [0, 1, 2],
          [0, 1, 2],
          [0, 1, 2]],

         [[0, 1, 2],
          [0, 1, 2],
          [0, 1, 2],
          [0, 1, 2],
          [0, 1, 2],
          [0, 1, 2],
          [0, 1, 2]],

         [[0, 1, 2],
          [0, 1, 2],
          [0, 1, 2],
          [0, 1, 2],
          [0, 1, 2],
          [0, 1, 2],
          [0, 1, 2]],

         [[0, 1, 2],
          [0, 1, 2],
          [0, 1, 2],
          [0, 1, 2],
          [0, 1, 2],
          [0, 1, 2],
          [0, 1, 2]],

         [[0, 1, 2],
          [0, 1, 2],
          [0, 1, 2],
          [0, 1, 2],
          [0, 1, 2],
          [0, 1, 2],
          [0, 1, 2]],

         [[0, 1, 2],
          [0, 1, 2],
          [0, 1, 2],
          [0, 1, 2],
          [0, 1, 2],
 

In [51]:
from datetime import datetime

# Today's date formatted as YYYY-MM-DD.
date_str = f"{datetime.now():%Y-%m-%d}"
print(date_str)

2025-04-27


In [54]:
def save_checkpoint(state):
    """
    Announce a checkpoint save and print the file name it would be written to.

    The ``torch.save`` call is currently disabled, so nothing is written to disk
    and ``state`` is not used.

    Parameters
    ----------
        state : pytorch state_dict
            The state_dict of a model
    """
    print("\n=> Saving checkpoint\n")

    # Naming scheme: {model_architecture}_{dataset_name}_{input_size}_{extra_info}.pt
    file_name = f"Yolo_v1_taco_448_448_{datetime.now():%Y-%m-%d}.pt"
    # torch.save(state, file_name)
    print(file_name)

In [55]:
# Dry run: no model state is needed while the function only prints the file name.
save_checkpoint(state=None)


=> Saving checkpoint

Yolo_v1_taco_448_448_2025-04-27.pt


In [56]:
def change_list(lst):
    """Add the value 10 to the end of ``lst`` in place; nothing is returned."""
    lst.extend([10])

# The function receives the same list object, so its in-place change is visible here.
mylist = list(range(1, 4))
change_list(mylist)
print(mylist)  # [1, 2, 3, 10]

[1, 2, 3, 10]


In [60]:
# Banner mock-up: a 32-character rule above and below the status line.
print(f"\n{'#' * 32}")
print("Saving checkpoint: ")

print(f"\n{'#' * 32}", "\n")



################################
Saving checkpoint: 

################################ 



In [None]:
def parse_prev_epoch(filename, default=0):
    """
    Extract the epoch number that follows an ``epoch`` token in a checkpoint name.

    The name is split on underscores; the integer token right after a token equal
    to ``epoch`` is the epoch number (e.g. ``..._epoch_150_...`` -> 150). If the
    marker occurs more than once, the last valid occurrence wins.

    Parameters
    ----------
        filename : str
            Checkpoint file name to inspect
        default : int
            Value returned when no ``epoch_<int>`` pair is found

    Returns
    -------
        int
            The parsed epoch number, or ``default``
    """
    tokens = filename.split('_')
    epoch_num = default
    # Pairing each token with its successor means a trailing 'epoch' token simply
    # has no partner, instead of indexing past the end of the list.
    for token, following in zip(tokens, tokens[1:]):
        if token != 'epoch':
            continue
        try:
            epoch_num = int(following)
        except ValueError:
            # 'epoch' followed by something that is not an integer: ignore it.
            continue
    return epoch_num


filename = "Yolo_v1_taco_448_448_epoch_150_2025-04-27.pt"

# Split by underscore
parts = filename.split('_')
# Doing the scan inside the helper keeps loop variables out of the notebook
# namespace (the previous loop rebound the global `i`).
prev_epoch_num = parse_prev_epoch(filename)

prev_epoch_num

150