In [None]:
import os
import random
from collections import Counter
from glob import glob

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import cv2
from tqdm import tqdm

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2

In [None]:
class CustomVOCdataset(torchvision.datasets.VOCDetection):
    """Pascal VOC detection dataset that yields YOLOv1-style target grids.

    Call ``init_config_yolo`` after construction to set the grid size, the
    number of boxes/classes, the class-name -> index mapping and optional
    albumentations transforms.
    """

    def init_config_yolo(self, class_mapping, S=7, B=2, C=20, custom_transforms=None):
        """Store the YOLO-specific configuration.

        Args:
            class_mapping: dict mapping VOC class names to integer class indices.
            S: grid size; the target is an S x S grid.
            B: number of boxes per cell (sizes the target's last dimension).
            C: number of classes.
            custom_transforms: optional albumentations ``Compose`` called with the
                keyword targets ``image``, ``bboxes`` (YOLO format) and ``labels``;
                it is expected to declare
                ``bbox_params=A.BboxParams(format='yolo', label_fields=['labels'])``.
        """
        self.S = S  # grid size S x S
        self.B = B  # boxes per cell
        self.C = C  # number of classes
        self.class_mapping = class_mapping  # class name -> class index
        self.custom_transforms = custom_transforms

    @staticmethod
    def _to_chw_float(image):
        """Convert a PIL image, HWC ndarray or CHW tensor into a float32 CHW tensor.

        uint8 inputs are scaled to [0, 1]; other dtypes are only cast to float.
        """
        if isinstance(image, torch.Tensor):
            tensor = image  # already CHW (e.g. produced by ToTensorV2)
        else:
            tensor = torch.from_numpy(np.ascontiguousarray(np.array(image))).permute(2, 0, 1)
        if tensor.dtype == torch.uint8:
            return tensor.float() / 255.0
        return tensor.float()

    def __getitem__(self, index):
        """Return ``(image, label_matrix)`` for sample ``index``.

        ``image`` is a float32 CHW tensor. ``label_matrix`` has shape
        ``(S, S, C + 5 * B)``. Channels ``[0, C)`` are a one-hot class vector.
        Channel ``C`` is the objectness flag. Channels ``C+1 .. C+4`` hold
        ``(x_cell, y_cell, w * S, h * S)`` for the first object whose centre
        falls in that cell.
        """
        image, target = super(CustomVOCdataset, self).__getitem__(index)
        img_width, img_height = image.size

        # YOLO rows: [class_id, x_center, y_center, width, height], normalised.
        boxes = convert_to_yolo_format(target, img_width, img_height, self.class_mapping)
        boxes = boxes.reshape(-1, 5)
        # Drop objects whose class is not in class_mapping (encoded as -1);
        # they would otherwise set the last channel via label_matrix[..., -1].
        boxes = boxes[boxes[:, 0] >= 0]

        bboxes = boxes[:, 1:].tolist()  # coordinates only
        labels = boxes[:, 0].tolist()   # class indices

        if self.custom_transforms:
            # albumentations reads boxes from the `bboxes` key and moves the
            # `labels` field along with them (when declared as a label field).
            sample = self.custom_transforms(image=np.array(image), bboxes=bboxes, labels=labels)
            image = sample['image']
            bboxes = sample['bboxes']
            labels = sample['labels']

        image = self._to_chw_float(image)
        label_matrix = torch.zeros((self.S, self.S, self.C + 5 * self.B))

        for (x, y, width, height), class_label in zip(bboxes, labels):
            class_label = int(class_label)

            # Grid cell containing the box centre; clamp so a centre at exactly
            # 1.0 stays inside the grid.
            i = min(int(self.S * y), self.S - 1)  # row
            j = min(int(self.S * x), self.S - 1)  # column

            # Centre offset inside the cell, and size in units of cells.
            x_cell = self.S * x - j
            y_cell = self.S * y - i
            width_cell, height_cell = width * self.S, height * self.S

            # Only the first object assigned to a cell is recorded.
            if label_matrix[i, j, self.C] == 0:
                label_matrix[i, j, self.C] = 1
                label_matrix[i, j, self.C + 1:self.C + 5] = torch.tensor(
                    [x_cell, y_cell, width_cell, height_cell]
                )
                label_matrix[i, j, class_label] = 1  # one-hot class

        return image, label_matrix

In [None]:
def convert_to_yolo_format(target, img_width, img_height, class_mapping):
    """
    Convert annotation data from VOC format to YOLO format.

    Parameters:
    target (dict): Annotation data from VOCDetection dataset.
    img_width (int): Width of the original image (unused; the size stored in the
        annotation is used because the box coordinates refer to it).
    img_height (int): Height of the original image (unused, see img_width).
    class_mapping (dict): Mapping from class names to integer IDs. Names that are
        not present are encoded as class_id -1.

    Returns:
    torch.Tensor: Float tensor of shape [N, 5] for N bounding boxes, each with
    [class_id, x_center, y_center, width, height] normalised to the image size.
    Shape is [0, 5] when the annotation has no objects.
    """
    annotation = target['annotation']

    # Real image size as recorded in the annotation.
    real_width = float(annotation['size']['width'])
    real_height = float(annotation['size']['height'])

    # A single object may be parsed as a dict rather than a list; no object key
    # at all means an empty image.
    objects = annotation.get('object', [])
    if not isinstance(objects, list):
        objects = [objects]

    boxes = []
    for obj in objects:
        bndbox = obj['bndbox']
        # float() also accepts decimal coordinate strings, which int() rejects.
        xmin = float(bndbox['xmin']) / real_width
        xmax = float(bndbox['xmax']) / real_width
        ymin = float(bndbox['ymin']) / real_height
        ymax = float(bndbox['ymax']) / real_height

        class_id = class_mapping.get(obj['name'], -1)
        boxes.append([
            class_id,
            (xmin + xmax) / 2,  # x_center
            (ymin + ymax) / 2,  # y_center
            xmax - xmin,        # width
            ymax - ymin,        # height
        ])

    if not boxes:
        # Keep the [N, 5] contract so callers can slice boxes[:, 1:] safely.
        return torch.zeros((0, 5))
    return torch.tensor(boxes)

In [None]:
def intersection_over_union(boxes_preds, boxes_labels, box_format="midpoint"):
    """
    Calculate the Intersection over Union (IoU) between bounding boxes.

    Parameters:
    boxes_preds (tensor): Predicted bounding boxes (..., 4).
    boxes_labels (tensor): Ground truth bounding boxes (..., 4).
    box_format (str): "midpoint" for (x, y, w, h) or "corners" for (x1, y1, x2, y2).

    Returns:
    tensor: IoU scores with shape (..., 1), one per box pair.

    Raises:
    ValueError: if box_format is not "midpoint" or "corners".
    """

    def _corners(boxes):
        # Return (x1, y1, x2, y2), each keeping a trailing dimension of size 1.
        if box_format == "midpoint":
            half_w = boxes[..., 2:3] / 2
            half_h = boxes[..., 3:4] / 2
            return (boxes[..., 0:1] - half_w, boxes[..., 1:2] - half_h,
                    boxes[..., 0:1] + half_w, boxes[..., 1:2] + half_h)
        return boxes[..., 0:1], boxes[..., 1:2], boxes[..., 2:3], boxes[..., 3:4]

    if box_format not in ("midpoint", "corners"):
        raise ValueError(f"box_format must be 'midpoint' or 'corners', got {box_format!r}")

    box1_x1, box1_y1, box1_x2, box1_y2 = _corners(boxes_preds)
    box2_x1, box2_y1, box2_x2, box2_y2 = _corners(boxes_labels)

    # Coordinates of the intersection rectangle.
    x1 = torch.max(box1_x1, box2_x1)
    y1 = torch.max(box1_y1, box2_y1)
    x2 = torch.min(box1_x2, box2_x2)
    y2 = torch.min(box1_y2, box2_y2)

    # clamp(0) makes the intersection zero when the boxes do not overlap.
    intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)

    box1_area = abs((box1_x2 - box1_x1) * (box1_y2 - box1_y1))
    box2_area = abs((box2_x2 - box2_x1) * (box2_y2 - box2_y1))

    # Small epsilon avoids division by zero for degenerate boxes.
    return intersection / (box1_area + box2_area - intersection + 1e-6)

In [None]:
def non_max_suppression(bboxes, iou_threshold, threshold, box_format="corners"):
    """Greedy, per-class Non-Maximum Suppression.

    Args:
        bboxes: list of boxes ``[class_pred, prob_score, c0, c1, c2, c3]`` whose
            four coordinates follow ``box_format``.
        iou_threshold: a lower-scored box of the same class is discarded when its
            IoU with an already kept box is at least this value.
        threshold: boxes whose score is not strictly greater than this are
            dropped before suppression.
        box_format: ``"midpoint"`` or ``"corners"``, forwarded to the IoU helper.

    Returns:
        The kept boxes, ordered by decreasing score.
    """
    assert type(bboxes) is list

    # Confident boxes only, highest score first (sorted is stable on ties).
    remaining = sorted(
        [candidate for candidate in bboxes if candidate[1] > threshold],
        key=lambda candidate: candidate[1],
        reverse=True,
    )

    kept = []
    while remaining:
        best, rest = remaining[0], remaining[1:]
        kept.append(best)
        # Boxes of another class always survive; same-class boxes survive only
        # when they overlap the kept box by less than iou_threshold.
        remaining = [
            other
            for other in rest
            if other[0] != best[0]
            or intersection_over_union(
                torch.tensor(best[2:]),
                torch.tensor(other[2:]),
                box_format=box_format,
            ) < iou_threshold
        ]

    return kept

In [None]:
def mean_average_precision(pred_boxes, true_boxes, iou_threshold=0.5, box_format="midpoint", num_classes=20):
    """
    Calculate the mean average precision (mAP).

    Parameters:
    pred_boxes (list): Predicted boxes, each
                       [train_idx, class_pred, prob_score, c0, c1, c2, c3] where the
                       four coordinates follow box_format.
    true_boxes (list): Ground-truth boxes in the same layout.
    iou_threshold (float): A detection is a true positive when its best IoU with an
                           unmatched ground truth of the same image and class is
                           strictly greater than this value.
    box_format (str): "midpoint" or "corners", forwarded to the IoU helper.
    num_classes (int): Number of classes; classes without ground truth are skipped.

    Returns:
    torch.Tensor: 0-d tensor with the mAP over classes that have ground truth
    (0.0 when no class has any ground truth).
    """
    average_precisions = []
    epsilon = 1e-6

    for c in range(num_classes):
        detections = [det for det in pred_boxes if det[1] == c]
        ground_truths = [gt for gt in true_boxes if gt[1] == c]

        total_true_bboxes = len(ground_truths)
        # AP is undefined without ground truth for this class.
        if total_true_bboxes == 0:
            continue

        # Ground truths grouped by image, plus a "matched" flag per ground truth
        # so each one can satisfy at most one detection.
        gts_by_image = {}
        for gt in ground_truths:
            gts_by_image.setdefault(gt[0], []).append(gt)
        matched = {img: torch.zeros(len(gts)) for img, gts in gts_by_image.items()}

        # Highest-confidence detections claim ground truths first.
        detections.sort(key=lambda det: det[2], reverse=True)
        TP = torch.zeros(len(detections))
        FP = torch.zeros(len(detections))

        for det_idx, detection in enumerate(detections):
            image_gts = gts_by_image.get(detection[0], [])
            best_iou = 0.0
            best_gt_idx = -1

            for gt_idx, gt in enumerate(image_gts):
                iou = float(intersection_over_union(
                    torch.tensor(detection[3:]),
                    torch.tensor(gt[3:]),
                    box_format=box_format,
                ))
                if iou > best_iou:
                    best_iou, best_gt_idx = iou, gt_idx

            if best_iou > iou_threshold and matched[detection[0]][best_gt_idx] == 0:
                TP[det_idx] = 1
                matched[detection[0]][best_gt_idx] = 1
            else:
                # Either a poor match or a duplicate detection of a claimed object.
                FP[det_idx] = 1

        TP_cumsum = torch.cumsum(TP, dim=0)
        FP_cumsum = torch.cumsum(FP, dim=0)
        recalls = TP_cumsum / (total_true_bboxes + epsilon)
        precisions = TP_cumsum / (TP_cumsum + FP_cumsum + epsilon)

        # Start the precision/recall curve at (recall=0, precision=1) before
        # integrating precision over recall.
        precisions = torch.cat((torch.tensor([1.0]), precisions))
        recalls = torch.cat((torch.tensor([0.0]), recalls))
        average_precisions.append(torch.trapz(precisions, recalls))

    if not average_precisions:
        # Tensor (not int) so callers can still use .item().
        return torch.tensor(0.0)
    return sum(average_precisions) / len(average_precisions)

In [None]:
import torch
import torch.nn as nn
from collections import Counter

# Backbone architecture configuration consumed by YoloV1.create_conv_layers.
# Each tuple is (kernel_size, out_channels, stride, padding) for one CNNBlock;
# "M" is a 2x2 max-pooling layer with stride 2. With one stride-2 conv and five
# pools, a 448x448 input ends at 7x7 spatial resolution with 1024 channels.
architecture_config = [
    (7, 64, 2, 3),     # Conv block 1
    "M",               # Max-pooling layer 1
    (3, 192, 1, 1),   # Conv block 2
    "M",               # Max-pooling layer 2
    (1, 128, 1, 0),   # Conv block 3
    (3, 256, 1, 1),   # Conv block 4
    (1, 256, 1, 0),   # Conv block 5
    (3, 512, 1, 1),   # Conv block 6
    "M",               # Max-pooling layer 3
    (3, 512, 1, 1),   # Conv block 7
    (3, 512, 1, 1),   # Conv block 8
    (3, 512, 1, 1),   # Conv block 9
    "M",               # Max-pooling layer 4
    (3, 1024, 1, 1),  # Conv block 10
    "M",               # Max-pooling layer 5
    (3, 1024, 1, 1),  # Conv block 11
]

# CNN building block: Conv2d -> BatchNorm2d -> LeakyReLU(0.1).
class CNNBlock(nn.Module):
    """Conv2d (bias disabled because BatchNorm follows) + BatchNorm2d + LeakyReLU(0.1).

    Extra keyword arguments (kernel_size, stride, padding, ...) are forwarded to nn.Conv2d.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super(CNNBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.batchnorm = nn.BatchNorm2d(out_channels)
        self.leakyrelu = nn.LeakyReLU(0.1)

    def forward(self, x):
        return self.leakyrelu(self.batchnorm(self.conv(x)))


# YOLOv1 model: convolutional backbone described by a config list + fully connected head.
class YoloV1(nn.Module):
    """YOLOv1 detector.

    Keyword args:
        architecture_config: backbone description (tuples
            ``(kernel_size, out_channels, stride, padding)`` and ``"M"`` for a
            2x2/stride-2 max-pool). Defaults to the module-level
            ``architecture_config``.
        split_size, num_boxes, num_classes: grid size S, boxes per cell B and
            number of classes C. The output has ``S * S * (C + 5 * B)`` values
            per image.
    """

    def __init__(self, in_channels=3, **kwargs):
        super(YoloV1, self).__init__()
        self.in_channels = in_channels
        # Pop the config so it is not forwarded to create_fc_layers, which does
        # not accept it. Fall back to the notebook's default backbone: an empty
        # backbone could never produce the input size the FC head expects.
        config = kwargs.pop('architecture_config', None)
        self.architecture_config = architecture_config if config is None else config
        self.conv_layers = self.create_conv_layers(self.architecture_config)
        self.fc_layers = self.create_fc_layers(**kwargs)

    def forward(self, x):
        x = self.conv_layers(x)
        return self.fc_layers(torch.flatten(x, start_dim=1))

    def create_conv_layers(self, architecture):
        """Build the backbone and record its output channel count."""
        layers = []
        in_channels = self.in_channels

        for x in architecture:
            if isinstance(x, tuple):
                layers.append(
                    CNNBlock(
                        in_channels, x[1],
                        kernel_size=x[0],
                        stride=x[2],
                        padding=x[3],
                    )
                )
                in_channels = x[1]
            elif x == "M":
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))

        # Channel count leaving the backbone; sizes the first FC layer.
        self.backbone_out_channels = in_channels
        return nn.Sequential(*layers)

    def create_fc_layers(self, split_size, num_boxes, num_classes):
        """Build the head; assumes the backbone output is split_size x split_size spatially."""
        in_channels = getattr(self, 'backbone_out_channels', 1024)
        return nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_channels * split_size * split_size, 4096),
            nn.Dropout(0.0),  # rate 0.0 makes this a no-op
            nn.LeakyReLU(0.1),
            nn.Linear(4096, split_size * split_size * (num_boxes * 5 + num_classes)),
        )

# Build a YOLOv1 model with a 7x7 grid, 2 boxes per cell and 20 classes, then print its layers.
model = YoloV1(architecture_config=architecture_config, split_size=7, num_boxes=2, num_classes=20)
print(model)

In [None]:
import torch
import torch.nn as nn

class YoloLoss(nn.Module):
    """Sum-squared-error loss for YOLOv1.

    Channel layout of the last dimension (size ``C + 5 * B``)::

        [0, C)        class scores
        C             confidence of box 1
        C+1 .. C+4    box 1 (x, y, w, h)
        C+5           confidence of box 2
        C+6 .. C+9    box 2 (x, y, w, h)

    Only two predicted boxes are used, so this loss assumes ``B == 2``. The
    target carries the ground truth in the class, box-1 confidence and box-1
    coordinate slots.
    """

    def __init__(self, S=7, B=2, C=20):
        super(YoloLoss, self).__init__()
        self.mse = nn.MSELoss(reduction="sum")
        self.S = S
        self.B = B
        self.C = C
        # Down-weight empty cells and up-weight box coordinates.
        self.lambda_noobj = 0.5
        self.lambda_coord = 5

    @staticmethod
    def _sqrt_wh(boxes):
        """Replace (w, h) by sign(w) * sqrt(|w| + 1e-6), leaving (x, y) untouched.

        Errors on large boxes then weigh less than equal errors on small boxes.
        Built with torch.cat instead of in-place slice assignment so autograd
        never sees a saved tensor being modified.
        """
        wh = boxes[..., 2:4]
        return torch.cat((boxes[..., 0:2], torch.sign(wh) * torch.sqrt(torch.abs(wh) + 1e-6)), dim=-1)

    def forward(self, predictions, target):
        """Return the scalar YOLOv1 loss for a batch.

        Args:
            predictions: model output, reshaped here to (N, S, S, C + 5 * B).
            target: tensor of shape (N, S, S, C + 5 * B) in the layout above.
        """
        C = self.C
        conf1, box1 = slice(C, C + 1), slice(C + 1, C + 5)
        conf2, box2 = slice(C + 5, C + 6), slice(C + 6, C + 10)

        predictions = predictions.reshape(-1, self.S, self.S, C + self.B * 5)

        # The predicted box with the higher IoU against the ground truth is "responsible".
        iou_b1 = intersection_over_union(predictions[..., box1], target[..., box1])
        iou_b2 = intersection_over_union(predictions[..., box2], target[..., box1])
        ious = torch.cat((iou_b1.unsqueeze(0), iou_b2.unsqueeze(0)), dim=0)
        _, bestbox = torch.max(ious, dim=0)  # (N, S, S, 1): 0 -> box 1, 1 -> box 2
        exists_box = target[..., conf1]      # (N, S, S, 1): 1 where a cell holds an object

        # ======================== #
        #   BOX COORDINATE LOSS    #
        # ======================== #
        box_predictions = exists_box * (
            bestbox * predictions[..., box2]
            + (1 - bestbox) * predictions[..., box1]
        )
        box_targets = exists_box * target[..., box1]
        box_loss = self.mse(
            torch.flatten(self._sqrt_wh(box_predictions), end_dim=-2),
            torch.flatten(self._sqrt_wh(box_targets), end_dim=-2),
        )

        # ================= #
        #    OBJECT LOSS    #
        # ================= #
        pred_conf = (
            bestbox * predictions[..., conf2]
            + (1 - bestbox) * predictions[..., conf1]
        )
        object_loss = self.mse(
            torch.flatten(exists_box * pred_conf),
            torch.flatten(exists_box * target[..., conf1]),
        )

        # ================== #
        #   NO OBJECT LOSS   #
        # ================== #
        # Both predicted confidences are pushed toward 0 in empty cells.
        no_object_loss = self.mse(
            torch.flatten((1 - exists_box) * predictions[..., conf1]),
            torch.flatten((1 - exists_box) * target[..., conf1]),
        ) + self.mse(
            torch.flatten((1 - exists_box) * predictions[..., conf2]),
            torch.flatten((1 - exists_box) * target[..., conf1]),
        )

        # ================= #
        #    CLASS LOSS     #
        # ================= #
        class_loss = self.mse(
            torch.flatten(exists_box * predictions[..., :C], end_dim=-2),
            torch.flatten(exists_box * target[..., :C], end_dim=-2),
        )

        return (
            self.lambda_coord * box_loss
            + object_loss
            + self.lambda_noobj * no_object_loss
            + class_loss
        )


In [None]:
import torch
import albumentations as A

# Network input resolution in pixels, used by the resize transforms.
WIDTH, HEIGHT = 448, 448

# Augmentation pipeline for the training set.
def get_train_transforms():
    """Return the albumentations pipeline for training images.

    Applies photometric jitter and random flips, then resizes to HEIGHT x WIDTH
    and converts to a CHW tensor. Boxes are YOLO-format ``bboxes`` with a
    matching ``labels`` field, so geometric transforms move them with the image.
    """
    return A.Compose(
        [
            A.OneOf([
                # NOTE(review): hue/sat/val shift limits of 0.2 are very small in
                # albumentations' units; confirm these were intended.
                A.HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.9),
                A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.9),
            ]),
            A.ToGray(p=0.01),
            A.HorizontalFlip(p=0.2),
            A.VerticalFlip(p=0.2),
            # Fixed size so images can be batched and the FC head sees a fixed grid.
            A.Resize(height=HEIGHT, width=WIDTH, p=1.0),
            ToTensorV2(p=1.0),
        ],
        bbox_params=A.BboxParams(format='yolo', min_area=0, label_fields=['labels']),
    )

# Deterministic pipeline for the validation/test sets.
def get_valid_transforms():
    """Return the albumentations pipeline for evaluation images.

    Resizes to HEIGHT x WIDTH and converts to a CHW tensor. YOLO-format
    ``bboxes`` and their ``labels`` field are carried through the resize.
    """
    return A.Compose(
        [
            A.Resize(height=HEIGHT, width=WIDTH, p=1.0),
            ToTensorV2(p=1.0),
        ],
        # BboxParams configures the pipeline; it is not a transform, so it goes
        # in bbox_params rather than in the transform list.
        bbox_params=A.BboxParams(format='yolo', min_area=0, label_fields=['labels']),
    )

# Seed every random number generator used in this notebook for reproducibility:
# torch (weight init, samplers) plus NumPy and Python's `random`, which
# augmentation libraries such as albumentations may draw from.
seed = 123
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Training hyper-parameters and runtime configuration.
LEARNING_RATE = 2e-5
# Fall back to CPU when no GPU is present instead of failing on the first .to(DEVICE).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 16
EPOCHS = 300
NUM_WORKERS = 2
# Pinned host memory only helps host->GPU copies.
PIN_MEMORY = DEVICE == "cuda"
LOAD_MODEL = False
LOAD_MODEL_FILE = "yolov1.pth.tar"

# Pascal VOC class names -> contiguous integer ids 0..19, in the order listed.
class_mapping = {
    name: idx
    for idx, name in enumerate((
        'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
        'bus', 'car', 'cat', 'chair', 'cow',
        'diningtable', 'dog', 'horse', 'motorbike', 'person',
        'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',
    ))
}

In [None]:
import torch

def train_fn(train_loader, model, optimizer, loss_fn, epoch):
    """Run one training epoch and return the mean per-batch mAP.

    Prints progress about every 20% of the epoch and a final summary line.
    NOTE(review): relies on get_bboxes_training, which is not defined in this
    notebook section; confirm it exists.
    """
    model.train()
    mean_loss = []
    mean_AP = []

    total_batches = len(train_loader)
    # Print about every 20% of the batches; max(1, ...) avoids a modulo by
    # zero when there are fewer than 5 batches.
    display_interval = max(1, total_batches // 5)

    for batch_idx, (x, y) in enumerate(train_loader):
        x = x.to(DEVICE)
        y = y.to(DEVICE)

        out = model(x)
        loss = loss_fn(out, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Metrics only: detach so box decoding does not hold on to the graph.
        pred_boxes, true_boxes = get_bboxes_training(out.detach(), y, iou_threshold=0.4)
        # float() accepts both tensor and plain-number results.
        mAP = float(mean_average_precision(pred_boxes, true_boxes, iou_threshold=0.4, box_format="midpoint"))

        mean_loss.append(loss.item())
        mean_AP.append(mAP)

        if batch_idx % display_interval == 0 or batch_idx == total_batches - 1:
            print(f"Epoch: {epoch}: Iter: {batch_idx}/{total_batches}: Loss: {loss.item():.3f}  mAP: {mAP:.3f}")

    # max(..., 1) keeps an empty loader from dividing by zero.
    avg_loss = sum(mean_loss) / max(len(mean_loss), 1)
    avg_mAP = sum(mean_AP) / max(len(mean_AP), 1)
    print(f"Train Loss: {avg_loss:.3f}  mAP: {avg_mAP:.3f}")

    return avg_mAP


def test_fn(test_loader, model, loss_fn, epoch):
    """Evaluate ``model`` on ``test_loader`` and return the mean per-batch mAP.

    Runs in eval mode under ``torch.no_grad()`` and always switches the model
    back to training mode afterwards, even if evaluation raises. ``epoch`` is
    accepted for signature parity with ``train_fn`` but is not used.
    NOTE(review): relies on get_bboxes_training, which is not defined in this
    notebook section; confirm it exists.
    """
    model.eval()
    mean_loss = []
    mean_AP = []

    try:
        # No gradients are needed for evaluation; this saves memory and time.
        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(DEVICE)
                y = y.to(DEVICE)

                out = model(x)
                loss = loss_fn(out, y)

                pred_boxes, true_boxes = get_bboxes_training(out, y, iou_threshold=0.4)
                mAP = mean_average_precision(pred_boxes, true_boxes, iou_threshold=0.4, box_format="midpoint")

                mean_loss.append(loss.item())
                mean_AP.append(float(mAP))
    finally:
        model.train()

    # max(..., 1) keeps an empty loader from dividing by zero.
    avg_loss = sum(mean_loss) / max(len(mean_loss), 1)
    avg_mAP = sum(mean_AP) / max(len(mean_AP), 1)
    print(f"Test Loss: {avg_loss:.3f}  mAP: {avg_mAP:.3f}")

    return avg_mAP

In [None]:
import torch
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from torch.utils.data import DataLoader, SubsetRandomSampler

def plot_image_with_labels(image, ground_truth_boxes, predicted_boxes, class_mapping):
    """Draw both ground truth and predicted bounding boxes on an image, with labels."""
    inverted_class_mapping = {v: k for k, v in class_mapping.items()}

    # Convert the image to a numpy array and get its dimensions
    im = np.array(image)
    height, width, _ = im.shape

    # Create a figure and axis for plotting
    fig, ax = plt.subplots(1)
    ax.imshow(im)

    # Plot each ground truth box in green
    for box in ground_truth_boxes:
        label_index = box[0]
        box = box[1:]
        upper_left_x = box[0] - box[2] / 2
        upper_left_y = box[1] - box[3] / 2
        rect = patches.Rectangle((upper_left_x * width, upper_left_y * height), box[2] * width, box[3] * height,
                                 linewidth=1, edgecolor="green", facecolor="none")
        ax.add_patch(rect)

        class_name = inverted_class_mapping.get(label_index, "Unknown")
        ax.text(upper_left_x * width, upper_left_y * height, class_name, fontsize=12,
                bbox=dict(facecolor='green', alpha=0.2))

    # Plot each predicted box in red
    for box in predicted_boxes:
        label_index = box[0]
        box = box[1:]
        upper_left_x = box[0] - box[2] / 2
        upper_left_y = box[1] - box[3] / 2
        rect = patches.Rectangle((upper_left_x * width, upper_left_y * height), box[2] * width, box[3] * height,
                                 linewidth=1, edgecolor="red", facecolor="none")
        ax.add_patch(rect)

        class_name = inverted_class_mapping.get(label_index, "Unknown")
        ax.text(upper_left_x * width, upper_left_y * height, class_name, fontsize=12,
                bbox=dict(facecolor='red', alpha=0.2))

    plt.show()

def test():
    """Plot ground-truth vs predicted boxes for up to 8 images of the first VOC ``val`` batch.

    NOTE(review): relies on cellboxes_to_boxes, which is not defined in this
    notebook section; confirm it exists.
    """
    model = YoloV1(split_size=7, num_boxes=2, num_classes=20).to(DEVICE)

    # Optionally restore saved weights; map_location lets a GPU checkpoint load on CPU.
    if LOAD_MODEL:
        model.load_state_dict(torch.load(LOAD_MODEL_FILE, map_location=DEVICE)['state_dict'])

    test_dataset = CustomVOCdataset(root='./data/', image_set='val', download=True)
    test_dataset.init_config_yolo(class_mapping=class_mapping, custom_transforms=get_valid_transforms())
    test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE,
                             num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY,
                             shuffle=False, drop_last=False)

    # Minimum score for a box to be drawn.
    score_threshold = 0.4

    model.eval()
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(DEVICE)
            out = model(x)

            pred_boxes = cellboxes_to_boxes(out)
            gt_boxes = cellboxes_to_boxes(y)

            # The last batch may hold fewer than 8 images.
            for idx in range(min(8, x.shape[0])):
                pred_box = non_max_suppression(pred_boxes[idx], iou_threshold=0.5,
                                               threshold=score_threshold, box_format="midpoint")
                gt_box = non_max_suppression(gt_boxes[idx], iou_threshold=0.5,
                                             threshold=score_threshold, box_format="midpoint")

                # Scale back to 0-255 and cast to uint8 so imshow renders the image correctly.
                image = x[idx].permute(1, 2, 0).cpu().numpy() * 255
                image = np.clip(image, 0, 255).astype(np.uint8)
                plot_image_with_labels(image, gt_box, pred_box, class_mapping)

            break  # only the first batch

# Main function to train the model
def train():
    """Train YOLOv1 on VOC2012 ``train`` and report mAP on two splits of VOC2012 ``val``.

    The ``val`` set is split by index: the first 15% is the validation split and
    the remaining 85% is the test split. Prints the best train/val/test mAP seen
    over EPOCHS epochs.
    """
    model = YoloV1(split_size=7, num_boxes=2, num_classes=20).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    loss_fn = YoloLoss()

    if LOAD_MODEL:
        checkpoint = torch.load(LOAD_MODEL_FILE, map_location=DEVICE)
        model.load_state_dict(checkpoint['state_dict'])
        # NOTE(review): the optimizer key name is assumed; it is restored only when present.
        if 'optimizer' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])

    train_dataset = CustomVOCdataset(root='./data/', year='2012', image_set='train', download=True)
    train_dataset.init_config_yolo(class_mapping=class_mapping, custom_transforms=get_train_transforms())

    testval_dataset = CustomVOCdataset(root='./data/', year='2012', image_set='val', download=True)
    testval_dataset.init_config_yolo(class_mapping=class_mapping, custom_transforms=get_valid_transforms())

    # Split the VOC val set into validation (first 15%) and test (rest) by index.
    dataset_size = len(testval_dataset)
    val_size = int(0.15 * dataset_size)
    val_sampler = SubsetRandomSampler(list(range(val_size)))
    test_sampler = SubsetRandomSampler(list(range(val_size, dataset_size)))

    # shuffle=True so batches differ between epochs.
    train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, drop_last=True)

    val_loader = DataLoader(dataset=testval_dataset, batch_size=BATCH_SIZE,
                            num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, sampler=val_sampler)

    test_loader = DataLoader(dataset=testval_dataset, batch_size=BATCH_SIZE,
                             num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, sampler=test_sampler)

    best_mAP_train = 0
    best_mAP_val = 0
    best_mAP_test = 0

    for epoch in range(EPOCHS):
        train_mAP = train_fn(train_loader, model, optimizer, loss_fn, epoch)
        val_mAP = test_fn(val_loader, model, loss_fn, epoch)
        test_mAP = test_fn(test_loader, model, loss_fn, epoch)

        best_mAP_train = max(best_mAP_train, train_mAP)
        best_mAP_val = max(best_mAP_val, val_mAP)
        best_mAP_test = max(best_mAP_test, test_mAP)

    print(f"Best Train mAP: {best_mAP_train:.3f}")
    print(f"Best Val mAP: {best_mAP_val:.3f}")
    print(f"Best Test mAP: {best_mAP_test:.3f}")

# Run training. NOTE(review): inside a Jupyter kernel __name__ is "__main__", so
# executing this cell starts the full EPOCHS-long training run.
if __name__ == "__main__":
    train()