In [None]:
!pip install ijson

Collecting ijson
  Downloading ijson-3.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (21 kB)
Downloading ijson-3.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (119 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/119.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m119.2/119.2 kB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: ijson
Successfully installed ijson-3.3.0


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.data import Dataset, random_split
from torchvision.datasets import ImageFolder
import numpy as np
import matplotlib.pyplot as plt
import os
import ijson

# Data Preprocessing

In [None]:
!unzip 'trainA_original_700.zip' # initial image tensors without resize

Archive:  trainA_original_700.zip
  inflating: trainA_original_700/0d1fc580-678ed420.jpg  
  inflating: trainA_original_700/0d1fc580-61bbac5d.jpg  
  inflating: trainA_original_700/0d2eaa9b-d7d07e80.jpg  
  inflating: trainA_original_700/0d2c8e1d-16f37059.jpg  
  inflating: trainA_original_700/0d5b19b3-fe488e51.jpg  
  inflating: trainA_original_700/0d3a2ed1-ac4b05e1.jpg  
  inflating: trainA_original_700/0d5a4fa0-4d2d45e6.jpg  
  inflating: trainA_original_700/0d1e09fe-f692552c.jpg  
  inflating: trainA_original_700/0d3a4e7f-42aea218.jpg  
  inflating: trainA_original_700/0d2c8e1d-aebc65ec.jpg  
  inflating: trainA_original_700/0d3a4e7f-fa62521a.jpg  
  inflating: trainA_original_700/0d2c8e1d-e81ed48e.jpg  
  inflating: trainA_original_700/0d5b19b3-e3f51851.jpg  
  inflating: trainA_original_700/0d5b19b3-fd1646ea.jpg  
  inflating: trainA_original_700/0d3a8582-d018e1f6.jpg  
  inflating: trainA_original_700/0d4d66bf-2d327e43.jpg  
  inflating: trainA_original_700/0d1fc580-3be6937e.jpg

In [None]:
!unzip 'trainA_norm_700.zip' # image tensors resized to [3, 223, 223]

unzip:  cannot find or open trainA_norm_700.zip, trainA_norm_700.zip.zip or trainA_norm_700.zip.ZIP.


## Labeling

In [None]:
def extract_first_n_labels(json_file_path, n):
    """Stream the first ``n`` records of a JSON array of label records.

    The file is parsed incrementally with ``ijson`` so the full JSON never has
    to be held in memory. For every record only annotations that carry a
    ``box2d`` key are kept (e.g. ``poly2d`` annotations are dropped).

    Args:
        json_file_path: path to a JSON file whose top level is an array of
            objects with ``name``, ``timestamp`` and ``labels`` keys.
        n: maximum number of records to return.

    Returns:
        list of ``{"name", "timestamp", "labels"}`` dicts, where ``labels`` is a
        list of ``{"category", "box2d"}`` dicts. Numbers are whatever ``ijson``
        yields (``decimal.Decimal`` for non-integers).
    """
    def _box_annotations(record):
        # Keep only annotations with a 2D box; copy just the two fields we use.
        return [
            {"category": ann.get("category"), "box2d": ann.get("box2d")}
            for ann in record.get("labels", [])
            if "box2d" in ann
        ]

    collected = []
    with open(json_file_path, 'rb') as fh:
        for idx, record in enumerate(ijson.items(fh, 'item')):
            if idx >= n:  # stop streaming once n records have been taken
                break
            collected.append({
                "name": record.get("name"),
                "timestamp": record.get("timestamp"),
                "labels": _box_annotations(record),
            })
    return collected

# BDD100K training-split annotations: one large JSON array, streamed lazily.
json_file_path = 'bdd100k_labels_images_train.json'
# Keep only the first 700 records (presumably the same 700 images as in
# trainA_original_700 — images without a matching record get empty targets).
first_700_labels = extract_first_n_labels(json_file_path, n=700)


In [None]:
# Spot-check the first two parsed records (name, timestamp, every kept box).
for label_no, record in enumerate(first_700_labels[:2], start=1):
    print(f"Label {label_no}:")
    print(f"  Name: {record['name']}")
    print(f"  Timestamp: {record['timestamp']}")
    print("  Labels:")
    for ann in record['labels']:
        print(f"    Category: {ann['category']}")
        print(f"    2D Box: {ann['box2d']}")
    print("-" * 40)

Label 1:
  Name: 0000f77c-6257be58.jpg
  Timestamp: 10000
  Labels:
    Category: traffic light
    2D Box: {'x1': Decimal('1125.902264'), 'y1': Decimal('133.184488'), 'x2': Decimal('1156.978645'), 'y2': Decimal('210.875445')}
    Category: traffic light
    2D Box: {'x1': Decimal('1156.978645'), 'y1': Decimal('136.637417'), 'x2': Decimal('1191.50796'), 'y2': Decimal('210.875443')}
    Category: traffic sign
    2D Box: {'x1': Decimal('1101.731743'), 'y1': Decimal('211.122087'), 'x2': Decimal('1170.79037'), 'y2': Decimal('233.566141')}
    Category: traffic sign
    2D Box: {'x1': 0, 'y1': Decimal('0.246631'), 'x2': Decimal('100.381647'), 'y2': Decimal('122.825696')}
    Category: car
    2D Box: {'x1': Decimal('45.240919'), 'y1': Decimal('254.530367'), 'x2': Decimal('357.805838'), 'y2': Decimal('487.906215')}
    Category: car
    2D Box: {'x1': Decimal('507.82755'), 'y1': Decimal('221.727518'), 'x2': Decimal('908.367588'), 'y2': Decimal('442.715126')}
    Category: traffic sign
    2

## Preparing Datasets

In [None]:
import os
import torch
from torch.utils.data import Dataset, random_split
from PIL import Image
import torchvision.transforms as transforms
from decimal import Decimal

def standardize_filename(path_or_name):
    """Return the file's base name with directories and its LAST extension removed.

    Used as the join key between image files on disk and label records in the
    JSON. Only one extension is stripped:
      - "folder/subfolder/abcd123.jpg" -> "abcd123"
      - "abcd123.png"                  -> "abcd123"
      - "folder/abcd123.jpg.pt"        -> "abcd123.jpg"

    Args:
        path_or_name: a file path or a bare file name.

    Returns:
        str: the base name without its final extension.
    """
    stem, _ext = os.path.splitext(os.path.basename(path_or_name))
    return stem

class CustomDataset(Dataset):
    """Images in a directory paired with box annotations from label records.

    Each item is ``(image_tensor, target)``:

    * ``image_tensor`` – float tensor produced by ``PILToTensor`` then
      ``.float()``: (C, H, W), raw pixel values in [0, 255], NOT normalized.
    * ``target`` – dict with fixed-size fields so batches can be stacked:
        - ``"boxes"``  float32 (max_boxes, 4), ``[x1, y1, x2, y2]`` pixels;
          unused rows are all zero.
        - ``"labels"`` int64 (max_boxes,); 1 for a real box, 0 for padding.
        - ``"names"``  list of ``max_boxes`` category strings; padding is "".

    Decoded images are cached as ``<stem>.pt`` files in ``pt_dir`` so later
    passes skip JPEG/PNG decoding.

    Args:
        image_dir: directory containing the image files.
        labels: iterable of label records (dicts with ``"name"`` and
            ``"labels"``), e.g. the output of ``extract_first_n_labels``.
        pt_dir: cache directory for decoded tensors (created if missing).
        max_boxes: number of box slots per image; extra boxes are truncated.
        verbose: print cache loads/saves and join keys (debugging aid).
    """

    IMAGE_EXTENSIONS = ('.jpg', '.png', '.jpeg')

    def __init__(self, image_dir, labels, pt_dir='pt_files', max_boxes=16, verbose=False):
        self.image_dir = image_dir
        self.pt_dir = pt_dir
        self.image_files = sorted(
            os.path.join(image_dir, f)
            for f in os.listdir(image_dir)
            if f.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        self.max_boxes = max_boxes  # Maximum number of bounding boxes per image
        self.verbose = verbose
        # Built once instead of per item; PILToTensor keeps values in [0, 255].
        self.transform = transforms.Compose([transforms.PILToTensor()])

        # Create the directory for .pt files if it doesn't exist
        os.makedirs(self.pt_dir, exist_ok=True)

        # Join key (base name without extension) -> full label record.
        # Later records with the same key overwrite earlier ones.
        self.label_dict = {standardize_filename(item["name"]): item for item in labels}

    def __len__(self):
        return len(self.image_files)

    def _cache_path(self, image_path):
        # "<pt_dir>/<stem>.pt". Derived from the stem rather than str.replace so
        # upper-case extensions (".JPG", accepted by the filter above) and names
        # containing ".jpg" mid-string map to a proper .pt file.
        return os.path.join(self.pt_dir, standardize_filename(image_path) + '.pt')

    def _load_image_tensor(self, image_path):
        # Return the decoded float image, reading/writing the .pt cache.
        pt_path = self._cache_path(image_path)
        if os.path.exists(pt_path):
            if self.verbose:
                print(f"Loading .pt file: {pt_path}")
            return torch.load(pt_path)

        if self.verbose:
            print(f"Saving .pt file: {pt_path}")
        # Context manager closes the file handle PIL keeps open after open().
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        image_tensor = self.transform(image).float()
        torch.save(image_tensor, pt_path)
        return image_tensor

    def _empty_target(self):
        # All-padding target: zero boxes, label 0, empty names.
        return {
            "boxes": torch.zeros((self.max_boxes, 4), dtype=torch.float32),
            "labels": torch.zeros((self.max_boxes,), dtype=torch.int64),
            "names": [""] * self.max_boxes,
        }

    def _build_target(self, record):
        # Convert a label record's box2d annotations into the fixed-size target.
        boxes, names = [], []
        for obj in record.get("labels", []):
            if "box2d" not in obj:
                continue
            b2d = obj["box2d"]
            # ijson yields Decimal values; convert to float for torch.
            boxes.append([float(b2d[k]) for k in ("x1", "y1", "x2", "y2")])
            names.append(obj["category"])  # e.g. "car", "traffic light"

        # Truncate to the fixed number of slots; the rest stays padding.
        boxes = boxes[:self.max_boxes]
        names = names[:self.max_boxes]
        n_real = len(boxes)

        target = self._empty_target()
        if n_real:
            target["boxes"][:n_real] = torch.tensor(boxes, dtype=torch.float32)
            target["labels"][:n_real] = 1  # class-agnostic "object"; padding stays 0
            target["names"][:n_real] = names
        return target

    def __getitem__(self, idx):
        image_path = self.image_files[idx]
        image_tensor = self._load_image_tensor(image_path)

        # Same key function as label_dict so image files line up with the JSON.
        base_key = standardize_filename(image_path)
        if self.verbose:
            print(f"Base key for image {idx}: {base_key}")

        record = self.label_dict.get(base_key)
        if record is None:
            return image_tensor, self._empty_target()
        return image_tensor, self._build_target(record)

# Directory holding the original-resolution images unzipped above.
image_dir = 'trainA_original_700'

# Labels come from the Labeling section (`first_700_labels`); re-parsing the
# large JSON here would only duplicate that work.

# Cache directory for decoded image tensors (.pt files).
pt_dir = 'trainA_testing3'

dataset = CustomDataset(image_dir, first_700_labels, pt_dir, max_boxes=16)

# Touch every item once so all decoded tensors are written to the .pt cache
# before training; later passes then only read the cache.
for i in range(len(dataset)):
    _ = dataset[i]


Saving .pt file: trainA_testing3/0000f77c-6257be58.pt
Base key for image 0: 0000f77c-6257be58
Saving .pt file: trainA_testing3/0000f77c-62c2a288.pt
Base key for image 1: 0000f77c-62c2a288
Saving .pt file: trainA_testing3/0000f77c-cb820c98.pt
Base key for image 2: 0000f77c-cb820c98
Saving .pt file: trainA_testing3/000d4f89-3bcbe37a.pt
Base key for image 3: 000d4f89-3bcbe37a
Saving .pt file: trainA_testing3/000e0252-8523a4a9.pt
Base key for image 4: 000e0252-8523a4a9
Saving .pt file: trainA_testing3/000f157f-dab3a407.pt
Base key for image 5: 000f157f-dab3a407
Saving .pt file: trainA_testing3/000f8d37-d4c09a0f.pt
Base key for image 6: 000f8d37-d4c09a0f
Saving .pt file: trainA_testing3/00a0f008-3c67908e.pt
Base key for image 7: 00a0f008-3c67908e
Saving .pt file: trainA_testing3/00a0f008-a315437f.pt
Base key for image 8: 00a0f008-a315437f
Saving .pt file: trainA_testing3/00a1176f-0652080e.pt
Base key for image 9: 00a1176f-0652080e
Saving .pt file: trainA_testing3/00a1176f-5121b501.pt
Base k

In [None]:
# 80 / 10 / 10 train / val / test split. A seeded generator makes the split
# reproducible across kernel restarts; without it random_split draws from the
# global RNG.
SPLIT_SEED = 42
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size  # remainder, so sizes sum to len(dataset)

train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)


# RPN from Scratch

## Backbone Feature Extraction (ResNet-50)

In [None]:
class ResNetBackbone(nn.Module):
    """ResNet-50 convolutional trunk (conv1 .. layer4, without avgpool and fc).

    For a (B, 3, H, W) input it returns the layer4 feature map: 2048 channels
    at 1/32 of the input resolution (about ceil(H/32) x ceil(W/32)).

    NOTE(review): ImageNet-pretrained weights expect inputs scaled to [0, 1]
    and standardized with the ImageNet mean/std. This module does not do that,
    so callers must normalize (CustomDataset yields raw [0, 255] values).

    Args:
        pretrained: load ImageNet weights (downloads them on first use).
    """

    def __init__(self, pretrained=True):
        super().__init__()
        # torchvision >= 0.13 selects weights via `weights=`; the boolean
        # `pretrained=` flag is deprecated there. IMAGENET1K_V1 is exactly what
        # `pretrained=True` used to load. Older torchvision falls back.
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is not None:
            resnet = models.resnet50(weights=weights_enum.IMAGENET1K_V1 if pretrained else None)
        else:
            resnet = models.resnet50(pretrained=pretrained)

        # Drop the final global-average-pool and fully connected layers; only
        # the spatial feature extractor is needed for the RPN.
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])

    def forward(self, x):
        """Return the layer4 feature map for the image batch ``x``."""
        return self.backbone(x)

## Anchor Generator


In [None]:
def generate_base_anchors(scales, ratios, base_size=16):
    """Build zero-centred reference anchors as a (len(scales)*len(ratios), 4) array.

    Boxes are ``[x1, y1, x2, y2]`` and ordered with scales in the outer loop and
    ratios in the inner loop. For each pair the box has side ``base_size*scale``
    stretched to aspect ratio ``w / h == ratio``:
        w = base_size * scale * sqrt(ratio)
        h = base_size * scale / sqrt(ratio)

    Args:
        scales: iterable of scale multipliers.
        ratios: iterable of width/height aspect ratios.
        base_size: reference side length in pixels.

    Returns:
        np.ndarray of anchors centred on the origin.
    """
    def _centred_box(scale, ratio):
        side = base_size * scale
        root = np.sqrt(ratio)
        half_w = side * root / 2.0
        half_h = side / root / 2.0
        # Rectangle (square only when ratio == 1) centred at (0, 0).
        return [-half_w, -half_h, half_w, half_h]

    return np.array([_centred_box(s, r) for s in scales for r in ratios])

class AnchorGenerator:
    """Tile zero-centred base anchors over every cell of a feature map.

    Args:
        scales: anchor scale multipliers (side = stride * scale before the
            aspect-ratio adjustment).
        ratios: width / height aspect ratios.
        stride: input-pixel distance between adjacent feature-map cells, also
            used as the base anchor size. It must match the backbone's output
            stride for anchors to line up with feature cells.
    """

    def __init__(self, scales=(8, 16, 32), ratios=(0.5, 1, 2), stride=16):
        # Tuple defaults avoid a shared mutable default; stored as lists so the
        # attribute types are unchanged.
        self.scales = list(scales)
        self.ratios = list(ratios)
        self.base_anchors = generate_base_anchors(self.scales, self.ratios, base_size=stride)
        self.stride = stride

    def generate_anchors(self, feature_height, feature_width, device=None):
        """Return all anchors for a (feature_height, feature_width) feature map.

        Anchors are ordered position-major (rows top to bottom, columns left to
        right) with the A base anchors innermost: index ``(y * W + x) * A + a``.

        Args:
            feature_height, feature_width: spatial size of the feature map.
            device: optional torch device for the result (default device if None).

        Returns:
            float32 tensor of shape (feature_height * feature_width * A, 4) in
            input-pixel ``[x1, y1, x2, y2]`` coordinates.
        """
        xs = torch.arange(0, feature_width * self.stride, step=self.stride,
                          dtype=torch.float32, device=device)
        ys = torch.arange(0, feature_height * self.stride, step=self.stride,
                          dtype=torch.float32, device=device)

        # Row-major grid of cell offsets, one (x, y, x, y) shift per position.
        grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
        grid_x = grid_x.reshape(-1)
        grid_y = grid_y.reshape(-1)
        shifts = torch.stack((grid_x, grid_y, grid_x, grid_y), dim=1)  # (K, 4)

        base = torch.as_tensor(self.base_anchors, dtype=torch.float32, device=device)  # (A, 4)
        anchors = shifts[:, None, :] + base[None, :, :]  # (K, A, 4)
        return anchors.reshape(-1, 4)

# IoU and Anchor Target Assignment

In [None]:
def compute_iou(boxes1, boxes2):
    """Pairwise IoU between two sets of ``[x1, y1, x2, y2]`` boxes.

    Args:
        boxes1: tensor of shape (N, 4).
        boxes2: tensor of shape (M, 4).

    Returns:
        (N, M) tensor whose ``[i, j]`` entry is IoU(boxes1[i], boxes2[j]).
        Pairs with zero union area (both boxes degenerate, e.g. the all-zero
        padding rows CustomDataset adds) get IoU 0 instead of NaN.
    """
    area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])  # (N,)
    area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])  # (M,)

    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N, M, 2] intersection top-left
    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N, M, 2] intersection bottom-right

    wh = (rb - lt).clamp(min=0)  # no overlap -> zero width/height
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N, M]

    union = area1[:, None] + area2 - inter
    has_union = union > 0
    # Divide by 1 where union is 0 so neither the value nor its gradient is NaN.
    safe_union = torch.where(has_union, union, torch.ones_like(union))
    return torch.where(has_union, inter / safe_union, torch.zeros_like(inter))
def assign_targets(anchors, gt_boxes, pos_iou_thresh=0.7, neg_iou_thresh=0.3):
    """Label anchors for RPN training and compute their regression targets (one image).

    Labelling rule:
      * max IoU with any GT box <  neg_iou_thresh          -> 0 (background)
      * max IoU with any GT box >= pos_iou_thresh          -> 1 (object)
      * for each GT box, the anchor(s) with its highest IoU -> 1, provided
        that IoU is > 0
      * everything else                                      -> -1 (ignore)

    Degenerate GT rows (width or height <= 0, e.g. the zero padding that
    CustomDataset adds up to ``max_boxes``) are discarded first. If no valid
    GT box remains, every anchor is labelled 0.

    Args:
        anchors: (N, 4) float tensor of ``[x1, y1, x2, y2]`` anchors.
        gt_boxes: (M, 4) ``[x1, y1, x2, y2]`` ground-truth boxes for ONE image.
        pos_iou_thresh, neg_iou_thresh: IoU thresholds described above.

    Returns:
        labels: (N,) int64 tensor with values in {-1, 0, 1}, on anchors' device.
        bbox_targets: (N, 4) float32 ``(dx, dy, dw, dh)`` deltas; non-zero only
            for positive anchors.
    """
    device = anchors.device
    N = anchors.shape[0]
    labels = torch.full((N,), -1, dtype=torch.int64, device=device)
    bbox_targets = torch.zeros((N, 4), dtype=torch.float32, device=device)

    gt_boxes = torch.as_tensor(gt_boxes).reshape(-1, 4).to(device=device, dtype=torch.float32)
    # Zero-area rows would have IoU 0 with every anchor, and the "best anchor
    # per GT" rule below would then mark all zero-IoU anchors positive.
    valid = (gt_boxes[:, 2] > gt_boxes[:, 0]) & (gt_boxes[:, 3] > gt_boxes[:, 1])
    gt_boxes = gt_boxes[valid]
    if gt_boxes.shape[0] == 0:
        labels.fill_(0)  # nothing to detect: every anchor is background
        return labels, bbox_targets

    iou = compute_iou(anchors, gt_boxes)  # (N, M)
    max_iou, argmax_iou = iou.max(dim=1)
    labels[max_iou < neg_iou_thresh] = 0
    labels[max_iou >= pos_iou_thresh] = 1

    # Guarantee each GT box at least one positive anchor (ties all count), but
    # skip GT boxes that overlap no anchor at all.
    max_iou_per_gt, _ = iou.max(dim=0)  # (M,)
    is_best = (iou == max_iou_per_gt[None, :]) & (max_iou_per_gt[None, :] > 0)
    labels[is_best.any(dim=1)] = 1

    pos_inds = torch.nonzero(labels == 1, as_tuple=False).view(-1)
    if pos_inds.numel() > 0:
        # Each positive anchor regresses toward the GT box it overlaps most.
        assigned_gt = gt_boxes[argmax_iou[pos_inds]]
        anchors_pos = anchors[pos_inds]
        widths = anchors_pos[:, 2] - anchors_pos[:, 0]
        heights = anchors_pos[:, 3] - anchors_pos[:, 1]
        ctr_x = anchors_pos[:, 0] + 0.5 * widths
        ctr_y = anchors_pos[:, 1] + 0.5 * heights
        gt_widths = assigned_gt[:, 2] - assigned_gt[:, 0]
        gt_heights = assigned_gt[:, 3] - assigned_gt[:, 1]
        gt_ctr_x = assigned_gt[:, 0] + 0.5 * gt_widths
        gt_ctr_y = assigned_gt[:, 1] + 0.5 * gt_heights
        # Standard Faster R-CNN box encoding (inverse of apply_bbox_offsets).
        targets_dx = (gt_ctr_x - ctr_x) / widths
        targets_dy = (gt_ctr_y - ctr_y) / heights
        targets_dw = torch.log(gt_widths / widths)
        targets_dh = torch.log(gt_heights / heights)
        bbox_targets[pos_inds] = torch.stack((targets_dx, targets_dy, targets_dw, targets_dh), dim=1)
    return labels, bbox_targets


# RPN Head, Model Wrapper, and Training

In [None]:
class RPN(nn.Module):
    """Region Proposal Network head.

    A shared 3x3 conv (512 channels, ReLU) feeds two sibling 1x1 convs applied
    at every feature-map location: ``objectness`` predicts 2 scores per anchor
    and ``bbox_reg`` predicts 4 box deltas per anchor.

    Args:
        in_channels: channels of the input feature map.
        num_anchors: anchors per location (A).

    ``forward(x)`` returns ``(objectness_scores, bbox_offsets)`` with shapes
    (B, 2*A, H, W) and (B, 4*A, H, W) for an input of shape (B, C, H, W).
    """

    def __init__(self, in_channels, num_anchors=9):
        super().__init__()
        self.num_anchors = num_anchors
        hidden_channels = 512
        # Registration order (conv, objectness, bbox_reg) is kept so parameter
        # initialization and state_dict keys are unchanged.
        self.conv = nn.Conv2d(in_channels, hidden_channels, kernel_size=3, stride=1, padding=1)
        self.objectness = nn.Conv2d(hidden_channels, 2 * num_anchors, kernel_size=1, stride=1)
        self.bbox_reg = nn.Conv2d(hidden_channels, 4 * num_anchors, kernel_size=1, stride=1)

    def forward(self, x):
        shared = self.conv(x).relu()
        return self.objectness(shared), self.bbox_reg(shared)


In [None]:
class RCNN(nn.Module):
    """Backbone + RPN head + anchor generator (the proposal stage of Faster R-CNN).

    ``forward(x, gt_boxes=None)``:
        x: (B, 3, H, W) image batch.
        gt_boxes (used only in training mode), any of:
            * a dict with a ``"boxes"`` entry (the batch dict built by
              ``custom_collate_fn``),
            * a (B, M, 4) tensor,
            * a list/tuple of B tensors of shape (M_i, 4),
            * a single (M, 4) tensor when B == 1.
          Zero-area rows (padding) are ignored.

    Returns:
        training mode with ``gt_boxes``:
            ``(objectness_scores, bbox_offsets, anchors, labels, bbox_targets)``
        otherwise:
            ``(objectness_scores, bbox_offsets, anchors)``
        ``anchors`` is the (K*A, 4) anchor set of ONE image (shared by the
        batch), on the features' device. ``labels`` (B*K*A,) and
        ``bbox_targets`` (B*K*A, 4) are concatenated image by image, in the
        row order of ``objectness_scores.permute(0, 2, 3, 1).reshape(-1, 2)``.

    Raises:
        ValueError: if ``gt_boxes`` does not describe exactly B images.
    """

    def __init__(self, backbone, rpn, anchor_generator):
        super().__init__()
        self.backbone = backbone
        self.rpn = rpn
        self.anchor_generator = anchor_generator

    @staticmethod
    def _per_image_boxes(gt_boxes, batch_size):
        """Normalize the accepted ground-truth formats to a list of (M_i, 4) tensors."""
        if isinstance(gt_boxes, dict):
            gt_boxes = gt_boxes["boxes"]
        if torch.is_tensor(gt_boxes):
            if gt_boxes.dim() == 2:  # a single image's boxes given as (M, 4)
                gt_boxes = gt_boxes.unsqueeze(0)
            gt_boxes = list(gt_boxes.unbind(0))
        if len(gt_boxes) != batch_size:
            raise ValueError(
                f"expected ground-truth boxes for {batch_size} image(s), got {len(gt_boxes)}"
            )
        return [torch.as_tensor(boxes).reshape(-1, 4) for boxes in gt_boxes]

    @staticmethod
    def _assign_batch(anchors, gt_list):
        """Run assign_targets per image and concatenate the results in batch order."""
        # Assignment runs on CPU; results are moved back to the anchors' device.
        anchors_cpu = anchors.detach().cpu()
        num_anchors = anchors_cpu.shape[0]
        all_labels, all_targets = [], []
        for boxes in gt_list:
            boxes = boxes.detach().cpu().float()
            # Drop zero-area padding rows (CustomDataset pads to max_boxes).
            boxes = boxes[(boxes[:, 2] > boxes[:, 0]) & (boxes[:, 3] > boxes[:, 1])]
            if boxes.shape[0] == 0:
                # No objects in this image: all anchors are background.
                img_labels = torch.zeros(num_anchors, dtype=torch.int64)
                img_targets = torch.zeros((num_anchors, 4), dtype=torch.float32)
            else:
                img_labels, img_targets = assign_targets(anchors_cpu, boxes)
            all_labels.append(img_labels)
            all_targets.append(img_targets)
        return torch.cat(all_labels).to(anchors.device), torch.cat(all_targets).to(anchors.device)

    def forward(self, x, gt_boxes=None):
        features = self.backbone(x)
        objectness_scores, bbox_offsets = self.rpn(features)
        batch_size, _, feat_h, feat_w = features.shape
        # The anchor generator may build anchors on CPU; keep them next to the features.
        anchors = self.anchor_generator.generate_anchors(feat_h, feat_w).to(features.device)
        if self.training and gt_boxes is not None:
            gt_list = self._per_image_boxes(gt_boxes, batch_size)
            labels, bbox_targets = self._assign_batch(anchors, gt_list)
            return objectness_scores, bbox_offsets, anchors, labels, bbox_targets
        return objectness_scores, bbox_offsets, anchors

In [None]:
def apply_bbox_offsets(anchors, bbox_offsets, bbox_xform_clip=float(np.log(1000.0 / 16))):
    """Decode ``(dx, dy, dw, dh)`` deltas relative to anchors into boxes.

    Inverse of the encoding used in ``assign_targets``:
        cx' = dx * w + cx,   cy' = dy * h + cy,
        w'  = exp(dw) * w,   h'  = exp(dh) * h

    Args:
        anchors: (N, 4) tensor of ``[x1, y1, x2, y2]`` anchors.
        bbox_offsets: (N, 4) tensor of predicted deltas.
        bbox_xform_clip: upper bound applied to dw/dh before ``exp`` so large
            (e.g. untrained) predictions cannot overflow to inf. The default
            log(1000 / 16) is the one torchvision's BoxCoder uses. Pass
            ``float('inf')`` to disable clipping.

    Returns:
        (N, 4) tensor of decoded ``[x1, y1, x2, y2]`` boxes. Gradients flow to
        ``bbox_offsets``.
    """
    widths = anchors[:, 2] - anchors[:, 0]
    heights = anchors[:, 3] - anchors[:, 1]
    ctr_x = anchors[:, 0] + 0.5 * widths
    ctr_y = anchors[:, 1] + 0.5 * heights

    dx = bbox_offsets[:, 0]
    dy = bbox_offsets[:, 1]
    dw = bbox_offsets[:, 2].clamp(max=bbox_xform_clip)
    dh = bbox_offsets[:, 3].clamp(max=bbox_xform_clip)

    pred_ctr_x = dx * widths + ctr_x
    pred_ctr_y = dy * heights + ctr_y
    pred_w = torch.exp(dw) * widths
    pred_h = torch.exp(dh) * heights

    return torch.stack((
        pred_ctr_x - 0.5 * pred_w,
        pred_ctr_y - 0.5 * pred_h,
        pred_ctr_x + 0.5 * pred_w,
        pred_ctr_y + 0.5 * pred_h,
    ), dim=1)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Hyperparameters
num_epochs = 10
learning_rate = 0.001
batch_size = 4
SEED = 0  # seeds RPN-head initialization (and the global torch RNG) for repeatable runs

torch.manual_seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize model, optimizer, and loss functions
backbone = ResNetBackbone(pretrained=True)
rpn = RPN(in_channels=2048)  # ResNet-50 backbone outputs 2048 channels
# The ResNet-50 trunk downsamples by 32, so anchors must be tiled every 32 px;
# with stride 16 they would cover only about the top-left quarter of the image.
# Scales (4, 8, 16) at base size 32 keep the original 128 / 256 / 512 px anchors.
anchor_generator = AnchorGenerator(scales=[4, 8, 16], ratios=[0.5, 1, 2], stride=32)
model = RCNN(backbone, rpn, anchor_generator).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# assign_targets marks "ignore" anchors with -1; they must not enter the loss.
cls_criterion = nn.CrossEntropyLoss(ignore_index=-1)  # For objectness classification
reg_criterion = nn.SmoothL1Loss()      # For bounding box regression


from torch.utils.data._utils.collate import default_collate

def custom_collate_fn(batch):
    """Collate ``(image, target)`` samples from CustomDataset into one batch.

    Images and the fixed-size ``boxes`` / ``labels`` tensors are stacked with
    ``default_collate``. ``names`` (per-image lists of strings) are kept as a
    plain list of lists rather than being transposed by ``default_collate``.

    Returns:
        ``(images, {"boxes": ..., "labels": ..., "names": [...]})``.
    """
    stacked_images = default_collate([sample[0] for sample in batch])
    targets = [sample[1] for sample in batch]

    return stacked_images, {
        "boxes": default_collate([t["boxes"] for t in targets]),
        "labels": default_collate([t["labels"] for t in targets]),
        "names": [t["names"] for t in targets],
    }

# Create DataLoaders with the custom collate function; all three use the
# `batch_size` hyperparameter so changing it in one place takes effect.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate_fn)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate_fn)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate_fn)

# Sanity-check the shapes of one training batch.
sample_images, sample_targets = next(iter(train_loader))
print("Images shape:", sample_images.shape)  # Should be [batch_size, 3, height, width]
print("Boxes shape:", sample_targets["boxes"].shape)  # Should be [batch_size, max_boxes, 4]
print("Labels shape:", sample_targets["labels"].shape)  # Should be [batch_size, max_boxes]
print("Names:", sample_targets["names"])  # Should be a list of lists of strings

# Training loop
#
# RPN targets are assigned here, image by image, from the padded
# (B, max_boxes, 4) box tensor produced by custom_collate_fn.
# NOTE(review): every labelled anchor enters the classification loss (no
# 256-anchor positive/negative sampling as in Faster R-CNN), so background
# anchors dominate; consider sampling. BatchNorm layers of the pretrained
# backbone also update their statistics with batch_size images.


def _normalize_images(images):
    """Scale raw [0, 255] pixels to [0, 1] and standardize with ImageNet mean/std."""
    # The pretrained ResNet-50 expects this input distribution; CustomDataset
    # returns unnormalized pixel values.
    mean = images.new_tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = images.new_tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
    return (images / 255.0 - mean) / std


def _batch_anchor_targets(anchors, boxes_batch):
    """Assign RPN targets per image and concatenate them in batch order.

    anchors: (K*A, 4) anchors of one image. boxes_batch: (B, max_boxes, 4)
    with zero-padded rows. Returns labels (B*K*A,) and bbox targets
    (B*K*A, 4) on ``anchors.device``, in the row order of
    ``scores.permute(0, 2, 3, 1).reshape(-1, 2)``.
    """
    anchors_cpu = anchors.detach().cpu()
    num_anchors = anchors_cpu.shape[0]
    all_labels, all_targets = [], []
    for boxes in boxes_batch:
        boxes = boxes.detach().cpu().float()
        # Drop zero-area padding rows before matching.
        boxes = boxes[(boxes[:, 2] > boxes[:, 0]) & (boxes[:, 3] > boxes[:, 1])]
        if boxes.shape[0] == 0:
            img_labels = torch.zeros(num_anchors, dtype=torch.int64)
            img_targets = torch.zeros((num_anchors, 4), dtype=torch.float32)
        else:
            img_labels, img_targets = assign_targets(anchors_cpu, boxes)
        all_labels.append(img_labels)
        all_targets.append(img_targets)
    return torch.cat(all_labels).to(anchors.device), torch.cat(all_targets).to(anchors.device)


def _paired_iou(boxes_a, boxes_b):
    """Element-wise IoU of aligned (P, 4) box sets: row i of a vs row i of b."""
    lt = torch.max(boxes_a[:, :2], boxes_b[:, :2])
    rb = torch.min(boxes_a[:, 2:], boxes_b[:, 2:])
    wh = (rb - lt).clamp(min=0)
    inter = wh[:, 0] * wh[:, 1]
    area_a = (boxes_a[:, 2] - boxes_a[:, 0]) * (boxes_a[:, 3] - boxes_a[:, 1])
    area_b = (boxes_b[:, 2] - boxes_b[:, 0]) * (boxes_b[:, 3] - boxes_b[:, 1])
    return inter / (area_a + area_b - inter).clamp(min=1e-6)


train_losses = []
train_ious = []
device = next(model.parameters()).device  # train wherever the model was placed

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    iou_sum, iou_count = 0.0, 0

    for images, targets in train_loader:
        images = _normalize_images(images.to(device).float())

        # Forward pass (no gt passed: targets are built below from the batch)
        optimizer.zero_grad()
        objectness_scores, bbox_offsets, anchors = model(images)
        anchors = anchors.to(device)
        anchor_labels, bbox_targets = _batch_anchor_targets(anchors, targets["boxes"])

        # Flatten to one row per (image, y, x, anchor), matching target order.
        objectness_scores = objectness_scores.permute(0, 2, 3, 1).reshape(-1, 2)
        bbox_offsets = bbox_offsets.permute(0, 2, 3, 1).reshape(-1, 4)
        batch_anchors = anchors.repeat(images.shape[0], 1)

        # Classification loss over labelled anchors only (-1 = ignore).
        labelled = anchor_labels >= 0
        cls_loss = cls_criterion(objectness_scores[labelled], anchor_labels[labelled])

        # Regression loss only for positive anchors
        pos_inds = torch.nonzero(anchor_labels == 1, as_tuple=False).squeeze(1)
        if pos_inds.numel() > 0:
            reg_loss = reg_criterion(bbox_offsets[pos_inds], bbox_targets[pos_inds])
        else:
            reg_loss = torch.zeros((), device=device)

        loss = cls_loss + reg_loss
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

        # IoU between each positive anchor's decoded prediction and its matched
        # GT box (recovered by decoding the regression target).
        with torch.no_grad():
            if pos_inds.numel() > 0:
                pos_anchors = batch_anchors[pos_inds]
                pred_boxes = apply_bbox_offsets(pos_anchors, bbox_offsets[pos_inds])
                matched_gt = apply_bbox_offsets(pos_anchors, bbox_targets[pos_inds])
                iou = _paired_iou(pred_boxes, matched_gt)
                iou_sum += iou.sum().item()
                iou_count += iou.numel()

    # Mean loss per batch; mean IoU over all positive anchors seen this epoch.
    epoch_loss /= max(len(train_loader), 1)
    epoch_iou = iou_sum / iou_count if iou_count else 0.0
    train_losses.append(epoch_loss)
    train_ious.append(epoch_iou)

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, IoU: {epoch_iou:.4f}")