# 1. Introduction

This notebook outlines the creation, compilation, and training of a deep learning network for object detection and segmentation. It is based on the TorchVision object detection fine-tuning tutorial, available [here](https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html).

The training and inference framework has been validated with the pre-trained Faster R-CNN models described in this [link](https://pytorch.org/vision/master/models/faster_rcnn.html). Support for custom backbones will be available in a future version of TorchSuite.

# 2. Object Detection and Segmentation

Object detection is a fundamental task in computer vision that involves identifying and localizing objects within an image. Unlike image classification, which assigns a single label to an entire image, object detection provides:

* Bounding boxes around detected objects.
* Class labels to categorize each detected object.

Object segmentation goes a step further by providing pixel-wise masks for objects instead of just bounding boxes. It is categorized into:

* Semantic segmentation: labels every pixel with a class but does not distinguish between instances.
* Instance segmentation: detects and segments each object separately (e.g., Mask R-CNN).

The main challenges in object detection are:

* Scale variability: Objects can appear at different sizes, requiring models to detect both large and small objects.
* Occlusion: Objects may be partially hidden behind others.
* Class imbalance: Some object classes may appear more frequently than others.
* Background clutter: Distracting elements in images can mislead detection models.
* Real-time processing: Faster models are needed for applications like autonomous driving and surveillance.

# 3. The R-CNN Model: Two-Stage Object Detection

Faster R-CNN, the Region-Based Convolutional Neural Network (R-CNN) variant used in this notebook, is a two-stage object detection model that first generates region proposals and then classifies and refines them. The Region Proposal Network (RPN) scans the backbone's feature map and suggests candidate object regions, predicting bounding boxes and objectness scores. These proposals are then passed to the Region of Interest (RoI) head, which classifies objects, refines box coordinates, and (in the case of Mask R-CNN) predicts segmentation masks. This two-stage approach delivers high accuracy but is more computationally intensive than single-stage detectors such as YOLO and SSD.

**Stage 1: Region Proposal Network (RPN)**

The Region Proposal Network (RPN) is responsible for generating object proposals, that is, regions of the image that are likely to contain objects. It works as follows:

1. Sliding window over the feature map: The RPN slides a small convolutional window over the feature map produced by the backbone CNN (e.g., ResNet, VGG).
2. Anchor boxes: At each sliding-window position, the RPN considers multiple anchor boxes, i.e., predefined reference boxes at various sizes and aspect ratios.
3. Objectness score: The network outputs a score for each anchor indicating whether it contains an object (foreground) or background.
4. Bounding box refinement: The RPN predicts offsets to each anchor's coordinates to better localize the object.
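
To make the anchor idea concrete, here is a minimal sketch that generates the anchors for one feature map. The function name `generate_anchors`, the sizes, aspect ratios, and stride are illustrative assumptions; torchvision's own `AnchorGenerator` differs in details such as rounding and where the anchor centres sit within each cell.

```python
import torch


def generate_anchors(feat_h, feat_w, stride, sizes=(32, 64, 128), ratios=(0.5, 1.0, 2.0)):
    """Return (feat_h * feat_w * len(sizes) * len(ratios), 4) anchors in XYXY pixel coordinates."""
    # Base anchors centred at (0, 0): one per (size, ratio) pair, ratio = height / width
    base = []
    for size in sizes:
        for ratio in ratios:
            w = size / ratio ** 0.5
            h = size * ratio ** 0.5
            base.append([-w / 2, -h / 2, w / 2, h / 2])
    base = torch.tensor(base)                                           # (A, 4)

    # Centre of every feature-map cell, mapped back to image pixels
    ys = (torch.arange(feat_h) + 0.5) * stride
    xs = (torch.arange(feat_w) + 0.5) * stride
    cy, cx = torch.meshgrid(ys, xs, indexing="ij")                      # (H, W) each
    shifts = torch.stack([cx, cy, cx, cy], dim=-1).reshape(-1, 1, 4)    # (H*W, 1, 4)

    # Every base anchor placed at every cell centre
    return (shifts + base).reshape(-1, 4)


anchors = generate_anchors(feat_h=2, feat_w=3, stride=16)
print(anchors.shape)
```

With a 2×3 feature map and three sizes times three ratios, this yields 2·3·9 = 54 anchors; the RPN then predicts one objectness score and four coordinate offsets per anchor.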

**Stage 2: Region of Interest (RoI) Head**

The RoI Head takes the region proposals from the RPN and performs classification, bounding box refinement, and (for Mask R-CNN) segmentation mask prediction.

1. RoI pooling (or RoIAlign, introduced with Mask R-CNN): The variable-sized proposals from the RPN are pooled into fixed-size feature maps so that they can be fed into fully connected layers. (torchvision's Faster R-CNN also uses RoIAlign.)
2. Classification & bounding box refinement: The RoI head assigns each proposal to one of the predefined categories (or to background) and further refines its bounding box.
3. (For Mask R-CNN) Segmentation mask prediction: A separate mask branch predicts a per-pixel mask for each detected object.

## Links of Interest
* https://medium.com/@soumyajitdatta123/faster-rcnns-explained-af76f96a0b70
* https://medium.com/@RobuRishabh/understanding-and-implementing-faster-r-cnn-248f7b25ff96

# 4. Importing Libraries

In [None]:
# Generic libraries
import os
import torch
import zipfile
import torch.backends.cudnn as cudnn
import matplotlib.pyplot as plt
from pathlib import Path
from torchinfo import summary
from torch.optim.lr_scheduler import CosineAnnealingLR

# Torchvision libraries
from torchvision import tv_tensors
from torchvision.io import read_image
from torchvision.transforms import v2 as T
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks

# Import custom libraries
from utils.classification_utils import set_seeds, load_model
from utils.obj_detection_utils import collate_fn, prune_predictions, diplay_predictions, visualize_transformed_data
from engines.obj_detection import ObjectDetectionEngine
from engines.schedulers import FixedLRSchedulerWrapper
from dataloaders.obj_dect_dataloaders import ProcessDataset
from models.faster_rcnn import FasterRCNN

# Warnings
import warnings
os.environ['TORCH_USE_CUDA_DSA'] = "1"
warnings.filterwarnings("ignore", category=UserWarning, module="torch.autograd.graph")
warnings.filterwarnings("ignore", category=FutureWarning, module="onnxscript.converter")

# Create target model directory
MODEL_DIR = Path("outputs")
MODEL_DIR.mkdir(parents=True, exist_ok=True)

# Set seeds
set_seeds(42)

# 5. Specifying the Target Device

In [None]:
# Activate cuda benchmark
#cudnn.benchmark = True

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

if device == "cuda":
    !nvidia-smi

# 6. Downloading the Penn-Fudan Dataset

In [None]:
# Download the dataset (--create-dirs creates data/ if it does not exist yet)
!curl -L -k --create-dirs https://www.cis.upenn.edu/~jshi/ped_html/PennFudanPed.zip -o data/PennFudanPed.zip

In [None]:
# Path to the downloaded zip file
zip_file_path = "data/PennFudanPed.zip"
extract_dir = "data"

# Ensure the extraction directory exists
os.makedirs(extract_dir, exist_ok=True)

# Unzip the file
with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
    zip_ref.extractall(extract_dir)

print(f"Files extracted to {extract_dir}")

zip_file = Path(zip_file_path)
if zip_file.exists():
    os.remove(zip_file)

## Image and Mask Visualization

In [None]:
image = read_image("data/PennFudanPed/PNGImages/FudanPed00046.png")
mask = read_image("data/PennFudanPed/PedMasks/FudanPed00046_mask.png")

plt.figure(figsize=(16, 8))
plt.subplot(121)
plt.title("Image")
plt.imshow(image.permute(1, 2, 0))
plt.subplot(122)
plt.title("Mask")
plt.imshow(mask.permute(1, 2, 0))
plt.show()

# 7. Preparing Dataloaders

In [None]:
# Pre-processing transformations
def get_transform(train, mean_std_norm):
    transforms = []
    if train:
        # Data augmentation (geometric transforms also update boxes and masks wrapped as tv_tensors)
        transforms.append(T.RandomHorizontalFlip(p=0.5))
        transforms.append(T.RandomVerticalFlip(p=0.5))
        transforms.append(T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2))
        transforms.append(T.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0)))
        transforms.append(T.RandomPerspective(distortion_scale=0.2, p=0.5))
        transforms.append(T.RandomGrayscale(p=0.1))
        # Zoom out on a black canvas; (123, 117, 104) is an alternative mean-color image fill
        transforms.append(T.RandomZoomOut(fill={tv_tensors.Image: (0, 0, 0), "others": 0}, side_range=(1.0, 2.0), p=0.2))
    # Convert uint8 images to float in [0, 1]
    transforms.append(T.ToDtype(torch.float, scale=True))
    # torchvision's R-CNN models already normalize inputs internally, so this is usually left off
    if mean_std_norm:
        transforms.append(T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
    transforms.append(T.ToPureTensor())
    return T.Compose(transforms)

In [None]:
# The dataset contains two classes only: background and person
num_classes = 2
BATCHES = 2

# Create two views of the dataset: one with training augmentations and one without
dataset_tr = ProcessDataset(
    root='data/PennFudanPed',
    image_path="PNGImages",
    mask_path="PedMasks",
    transforms=get_transform(train=True, mean_std_norm=False),
    num_classes=num_classes-1) # exclude the background

dataset_ntr = ProcessDataset(
    root='data/PennFudanPed',
    image_path="PNGImages",
    mask_path="PedMasks",
    transforms=get_transform(train=False, mean_std_norm=False),
    num_classes=num_classes-1) # exclude the background

# Split the dataset in train and test set
indices = torch.randperm(len(dataset_tr)).tolist()
train_dataset = torch.utils.data.Subset(dataset_tr, indices[:-25])
test_dataset = torch.utils.data.Subset(dataset_ntr, indices[-25:])
test_dataset_t = torch.utils.data.Subset(dataset_tr, indices[-25:])

# Define training and validation data loaders
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCHES,
    shuffle=True,
    collate_fn=collate_fn
)

test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=BATCHES,
    shuffle=False,
    collate_fn=collate_fn
)

# Augmented copy of the test set, used only to visualize the training transformations
test_dataloader_t = torch.utils.data.DataLoader(
    test_dataset_t,
    batch_size=BATCHES,
    shuffle=False,
    collate_fn=collate_fn
)
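
A custom collate function is needed because detection images can have different sizes and each target holds a different number of objects, so the default collate function cannot stack them into a single tensor. `utils.obj_detection_utils.collate_fn` isn't shown here; the PyTorch tutorial uses the one-liner below, which regroups a batch into a tuple of images and a tuple of targets (the name `detection_collate` is hypothetical).

```python
import torch


def detection_collate(batch):
    # [(img0, tgt0), (img1, tgt1)] -> ((img0, img1), (tgt0, tgt1))
    return tuple(zip(*batch))


# Two images of different sizes with different numbers of boxes
batch = [
    (torch.zeros(3, 300, 400), {"boxes": torch.zeros(1, 4)}),
    (torch.zeros(3, 500, 350), {"boxes": torch.zeros(3, 4)}),
]
images, targets = detection_collate(batch)
print(len(images), images[1].shape, targets[1]["boxes"].shape)
```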

In [None]:
# Compare the original and augmented versions of the same test images (first six batches)
for idx, ((img, target), (img_t, target_t)) in enumerate(zip(test_dataloader, test_dataloader_t)):
    for i in range(len(img)):  # the last batch may hold fewer than BATCHES images
        visualize_transformed_data(img[i], target[i], img_t[i], target_t[i])
    if idx > 4:
        break

# 8. Creating the Object Detection Model Based on Faster R-CNN

In [None]:
model = FasterRCNN(
    backbone="resnet50_v2",
    num_classes=num_classes,
    device=device
    )

#model = torch.compile(model, backend="aot_eager")

#summary(model,
#        input_size=(64, 3, 224, 224),
#        col_names=["input_size", "output_size", "num_params", "trainable"],
#        col_width=20,
#        row_settings=["var_names"])

# 9. Training the Model
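
`ObjectDetectionEngine.train` hides the training loop. As a rough guide to what its `amp` and `accumulation_steps` options usually mean, here is a minimal sketch of one epoch for a torchvision-style detector, which returns a dictionary of losses when called in training mode with targets. The function name and structure are assumptions, not the engine's actual code.

```python
import torch


def train_one_epoch_sketch(model, optimizer, dataloader, device, scaler=None, accumulation_steps=1):
    """One epoch of detection training; pass a GradScaler as `scaler` to enable AMP."""
    model.train()
    optimizer.zero_grad()
    total_loss = 0.0
    for step, (images, targets) in enumerate(dataloader):
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        # Mixed precision only when a scaler is supplied
        with torch.autocast(device_type=torch.device(device).type, enabled=scaler is not None):
            loss_dict = model(images, targets)  # e.g. classifier, box-regression, and RPN losses
            loss = sum(loss_dict.values()) / accumulation_steps
        if scaler is not None:
            scaler.scale(loss).backward()
        else:
            loss.backward()
        # Update the weights once every `accumulation_steps` batches
        if (step + 1) % accumulation_steps == 0:
            if scaler is not None:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad()
        total_loss += loss.item() * accumulation_steps
    return total_loss / max(len(dataloader), 1)
```

On a GPU, AMP would be enabled by passing something like `torch.amp.GradScaler("cuda")` (older PyTorch versions: `torch.cuda.amp.GradScaler()`). If the number of batches is not a multiple of `accumulation_steps`, this sketch discards the last partial accumulation.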

In [None]:
# Model name
model_type = "model"
model_name = model_type + ".pth"
EPOCHS = 30
LR = 0.001

# Create the optimizer
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(
    params,
    lr=LR,
    momentum=0.9,
    weight_decay=0.0005
)

# Create AdamW optimizer
#optimizer = torch.optim.AdamW(
#    params=model.parameters(),
#    lr=LR,
#    betas=(0.9, 0.999),
#    weight_decay=0.01
#)


# Create the scheduler
#scheduler = torch.optim.lr_scheduler.StepLR(
#    optimizer,
#    step_size=5,
#    gamma=0.1
#)

# Cosine-anneal the LR over the first 25 epochs, then keep it fixed at 1e-6 from epoch 25 onwards
scheduler = FixedLRSchedulerWrapper(
    scheduler=CosineAnnealingLR(optimizer, T_max=25, eta_min=1e-6),
    fixed_lr=1e-6,
    fixed_epoch=25)

# Instantiate the engine with the created model and the target device
engine = ObjectDetectionEngine(
    model=model,
    device=device)

# Configure the training method
results = engine.train(
    target_dir=MODEL_DIR,                       # Directory where the model will be saved
    model_name=model_name,                      # Name of the model
    save_best_model=["last", "loss"],           # Save the last model and the one with the lowest loss
    keep_best_models_in_memory=True,            # Keep the best models in memory so predict() can use them directly
    train_dataloader=train_dataloader,          # Train dataloader
    test_dataloader=test_dataloader,            # Test dataloader
    optimizer=optimizer,                        # Optimizer
    scheduler=scheduler,                        # Scheduler
    epochs=EPOCHS,                              # Total number of epochs
    amp=True,                                   # Enable Automatic Mixed Precision (AMP)
    enable_clipping=False,                      # Disable gradient clipping; only useful if training becomes unstable
    debug_mode=False,                           # Disable debug mode
    accumulation_steps=1,                       # Gradient accumulation steps: effective batch size = batch_size x accumulation_steps
    apply_validation=True                       # Enable validation step
    )
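
The code of `FixedLRSchedulerWrapper` isn't shown. Assuming it steps the wrapped `CosineAnnealingLR` until `fixed_epoch` and then holds `fixed_lr`, the learning rate per epoch follows the closed-form cosine annealing schedule shown in the comment below. The helper `lr_at_epoch` (a hypothetical name) computes that value from the notebook's settings so the schedule can be inspected without training.

```python
import math


def lr_at_epoch(epoch, base_lr=1e-3, eta_min=1e-6, t_max=25, fixed_epoch=25, fixed_lr=1e-6):
    # Assumed wrapper behaviour: hold fixed_lr from fixed_epoch onwards
    if epoch >= fixed_epoch:
        return fixed_lr
    # Closed form of cosine annealing: eta_min + (base_lr - eta_min) * (1 + cos(pi * t / T_max)) / 2
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2


for epoch in (0, 5, 12, 20, 24, 25, 29):
    print(f"epoch {epoch:2d}: lr = {lr_at_epoch(epoch):.2e}")
```

With `EPOCHS = 30`, the last five epochs would therefore run at the fixed learning rate of 1e-6.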

# 10. Making Predictions

In [None]:
OPTION = 1

# Make predictions using the `engine` object; the best model is already stored internally
if OPTION == 1:
    # Make predictions and plot the results
    preds = engine.predict(
        dataloader=test_dataloader,
        model_state='loss',  # Take the model with the lowest loss
        prune_predictions=True,
        #score_threshold=0.66,
        #mask_threshold=0.5,
        #iou_threshold=0.5
    )

# Make predictions by loading the already trained model manually
else:
    # Instantiate the trained model
    # First, load the architecture
    model = FasterRCNN(
        backbone="resnet50_v2",
        num_classes=num_classes,
        device=device
        )

    # Second, load the parameters of the best model
    # (adjust the file name to a checkpoint actually saved in the outputs directory)
    model = load_model(model, "outputs", "model_loss_epoch21.pth")

    # Instantiate the engine with the created model and the target device
    engine2 = ObjectDetectionEngine(
        model=model,
        device=device)

    # Make predictions and plot the results
    preds = engine2.predict(
        dataloader=test_dataloader,
        prune_predictions=True,
        #score_threshold=0.66,
        #mask_threshold=0.5,
        #iou_threshold=0.5
    )

In [None]:
# Configuration parameters
MASK_COLOR = "blue"
BOX_COLOR = "white"
WIDTH = 3
PRINT_LABELS = True

# Display predictions
diplay_predictions(
    preds=preds,
    dataset=test_dataset,
    box_color=BOX_COLOR,
    mask_color=MASK_COLOR,
    width=WIDTH,
    print_classes=PRINT_LABELS,
    print_scores=True,
    label_to_class_dict={1: 'pedestrian'}
    )

In [None]:
# Load an arbitrary image from a different dataset
image = read_image("images/examples/000000000674.jpg")

# And make a prediction
eval_transform = get_transform(train=False, mean_std_norm=False)
model.eval()
with torch.no_grad():
    x = eval_transform(image)
    # Keep only the RGB channels (drops the alpha channel of RGBA images) and move to device
    x = x[:3, ...].to(device)
    predictions = model([x, ])
    pred = prune_predictions(predictions[0])

# Prepare the image for plotting
image = (255.0 * (image - image.min()) / (image.max() - image.min())).to(torch.uint8)
image = image[:3, ...]
pred_labels = [f"roi: {score:.3f}" for label, score in zip(pred["labels"], pred["scores"])]
pred_boxes = pred["boxes"].long()
output_image = draw_bounding_boxes(image, pred_boxes, pred_labels, colors="white", width=3)

# Mask R-CNN models only (Faster R-CNN does not return "masks"):
#masks = (pred["masks"] > 0.7).squeeze(1)
#output_image = draw_segmentation_masks(output_image, masks, alpha=0.5, colors="blue")

# Plot the image
fig, ax = plt.subplots(figsize=(12, 12))
ax.imshow(output_image.permute(1, 2, 0))
ax.axis("off")
plt.show()