# 1. Introduction
This notebook aims to develop and train a U-Net–based deep learning model for detecting and segmenting vehicles such as cars, buses, and trucks in images.

# 2. The U-Net Architecture
U-Net is one of the most common deep learning architectures for segmentation tasks. The block diagram of this model is depicted in the figure below (https://arxiv.org/abs/1505.04597):

<div style="text-align: center;">
    <img src="images/u-net-architecture.png" width="600"/>
</div>

The architecture features a "U" shape, consisting of two main stages: the contracting (encoder) and expansive (decoder) stages.

* The **encoder** captures the context and high-level features of the input image by using several convolutional layers. It gradually reduces the spatial dimensions while increasing the feature dimensions.

* The **decoder** is responsible for reconstructing the output image, which, in the case of segmentation, is the mask that identifies the objects of interest. This stage involves upsampling across the same number of levels as the encoder, followed by convolutional operations to "expand" the contracted image.

One of the unique characteristics of U-Net is its **skip connections**, which link the encoder and decoder stages at each level by merging features. While the contracting and expanding paths ("U" shape) capture high-level contextual information, the skip connections help preserve low-level spatial details that might be lost during downsampling.

# 3. Importing Libraries

In [None]:
# Generic libraries
import os
import torch
import json
import zipfile
import urllib
import random
import glob
import torch.backends.cudnn as cudnn
import matplotlib.pyplot as plt
from pathlib import Path
from torchinfo import summary
from torch.optim.lr_scheduler import CosineAnnealingLR, CosineAnnealingLR, SequentialLR, ConstantLR

# Import custom libraries
from utils.obj_detection_utils import collate_fn
from utils.segmentation_utils import display_image_with_mask, collapse_one_hot_mask, create_label_class_dict
from engines.segmentation import SegmentationEngine
from engines.schedulers import FixedLRSchedulerWrapper
from engines.loss_functions import DiceCrossEntropyLoss
from dataloaders.segmentation_dataloaders import ProcessDatasetSegmentation, SegmentationTransforms
from models.unet import create_unet

# Import custom libraries
from utils.common_utils import set_seeds
from utils.coco_dataset_utils import COCO_2_ImgMsk, split_dataset

# Runtime environment and warning filters
import warnings

# NOTE(review): presumably turns on CUDA device-side assertions — confirm
# against the PyTorch documentation.
os.environ["TORCH_USE_CUDA_DSA"] = "1"

# Silence two known, noisy warning sources (registered in this order)
for _category, _module in (
    (UserWarning, "torch.autograd.graph"),
    (FutureWarning, "onnxscript.converter"),
):
    warnings.filterwarnings("ignore", category=_category, module=_module)

# TorchDynamo settings
import torch._dynamo
torch._dynamo.config.cache_size_limit = 64
# Optional switches, currently disabled:
#   torch._dynamo.config.suppress_errors = True
#   warnings.filterwarnings("ignore", category=UserWarning, module="torch._dynamo")

# Directory that receives the trained models (created on demand)
MODEL_DIR = Path("outputs")
MODEL_DIR.mkdir(parents=True, exist_ok=True)

# Seed shared by the seeding helper, the dataset split and the previews
SEED = 42
set_seeds(SEED)

# Notebook switches
DOWNLOAD_COCO = False              # Download and unzip the COCO archives
PROCESS_COCO = False               # Convert COCO into the image/mask folder layout
VISUALIZE_TRANSFORMED_DATA = True  # Preview original vs. transformed samples

# 4. Specifying the Target Device

In [None]:
# Activate cuda benchmark (currently disabled)
#cudnn.benchmark = True

# Set device: use the GPU when CUDA is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

# IPython shell escape: runs the `nvidia-smi` command-line tool.
# This line is only valid inside a notebook / IPython session.
if device == "cuda":
    !nvidia-smi

# 5. Importing the COCO Image Dataset
This section converts the COCO dataset for vehicle detection and segmentation into a structured format. The process involves:

* Filtering the dataset to retain categories such as 'bus', 'car', and 'truck'.
* Storing original images and corresponding segmentation masks in separate folders.
* Splitting the dataset into three subsets: training, validation, and testing.

The converted dataset uses the same folder layout as the PennFudanPed dataset (`PNGImages` and `PedMasks`), so it can be used alongside that dataset, which targets pedestrian detection and segmentation tasks.

After this section, the dataset will be organized into three subdirectories — train, val, test — each containing two subdirectories: PNGImages and PedMasks.

```
driving/
├── train/
│   ├── PNGImages/
│   │   ├── img1.png
│   │   ├── img2.png
│   │   └── ...
│   └── PedMasks/
│       ├── msk1.png
│       ├── msk2.png
│       └── ...
├── val/
│   ├── PNGImages/
│   │   ├── img1.png
│   │   ├── img2.png
│   │   └── ...
│   └── PedMasks/
│       ├── msk1.png
│       ├── msk2.png
│       └── ...
└── test/
    ├── PNGImages/
    │   ├── img1.png
    │   ├── img2.png
    │   └── ...
    └── PedMasks/
        ├── msk1.png
        ├── msk2.png
        └── ...
```

**Note:** 25 GB of free disk space is required to download the complete dataset.

## 5.1. Downloading the Dataset

In [None]:
if DOWNLOAD_COCO:
    # `import urllib` alone does not guarantee that the `urllib.request`
    # submodule is loaded, so import it explicitly before calling urlretrieve.
    import urllib.request

    # Define download URLs
    coco_urls = {
        "val_images": "http://images.cocodataset.org/zips/val2017.zip",
        "test_images": "http://images.cocodataset.org/zips/test2017.zip",
        "train_images": "http://images.cocodataset.org/zips/train2017.zip",
        "annotations": "http://images.cocodataset.org/annotations/annotations_trainval2017.zip"
    }

    # Create a directory to store the dataset
    dataset_dir = "d:/Repos/coco_dataset"
    os.makedirs(dataset_dir, exist_ok=True)

    def download_coco(url, filename):
        """Download ``url`` into ``dataset_dir`` as ``filename`` unless it already exists.

        The data is first written to ``<filename>.part`` and moved into place
        only after the transfer finishes, so an interrupted download is never
        mistaken for a complete archive on the next run.

        Args:
            url: Address of the archive to download.
            filename: Name of the file to create inside ``dataset_dir``.
        """
        filepath = os.path.join(dataset_dir, filename)
        if os.path.exists(filepath):
            print(f"{filename} already exists.")
            return

        print(f"Downloading {filename}...")
        partial_path = filepath + ".part"
        try:
            urllib.request.urlretrieve(url, partial_path)
            os.replace(partial_path, filepath)
        finally:
            # Only present here when the download or the move failed
            if os.path.exists(partial_path):
                os.remove(partial_path)
        print(f"Saved to {filepath}")

    # Download all files; the local file name is the last component of the URL
    for url in coco_urls.values():
        filename = url.split("/")[-1]
        download_coco(url, filename)

In [None]:
if DOWNLOAD_COCO:
    # Unzip every downloaded archive into the dataset directory and delete the
    # archive afterwards to reclaim disk space.
    PATH = Path(dataset_dir)

    # Same extraction order as before: validation images, annotations,
    # test images, training images
    for archive_name in (
        "val2017.zip",
        "annotations_trainval2017.zip",
        "test2017.zip",
        "train2017.zip",
    ):
        zip_file = PATH / archive_name

        # Skip archives that are not on disk instead of failing on open
        if not zip_file.exists():
            print(f"{archive_name} not found in {dataset_dir}, skipping extraction.")
            continue

        with zipfile.ZipFile(zip_file, 'r') as zip_ref:
            zip_ref.extractall(dataset_dir)

        # Reached only after a successful extraction
        os.remove(zip_file)

## 5.2. Processing the Dataset for Vehicle Segmentation

In [None]:
if PROCESS_COCO:
    # Path to COCO annotations file
    ANNOTATIONS_PATH = r"D:\Repos\coco_dataset\annotations\instances_train2017.json"

    # Load COCO annotations. The encoding is given explicitly so that decoding
    # does not depend on the platform's default locale encoding.
    with open(ANNOTATIONS_PATH, "r", encoding="utf-8") as f:
        coco_data = json.load(f)

    # Extract category ID to name mapping
    categories = {c["id"]: c["name"] for c in coco_data["categories"]}

    # Display all categories, one "<id>: <name>" line each
    for cat_id, cat_name in categories.items():
        print(f"{cat_id}: {cat_name}")

In [None]:
# Label -> category dictionary: label 0 is the background, followed by the
# vehicle categories in alphabetical order
target_categories = dict(enumerate(("background", "bus", "car", "truck")))

# Overlay color (RGB) for each label, in the same label order
color_map = dict(
    zip(
        target_categories,
        (
            (0, 0, 0),       # Background (Black)
            (232, 66, 66),   # Bus (Red)
            (35, 171, 75),   # Car (Green)
            (28, 163, 218),  # Truck (Blue)
        ),
    )
)

# Display categories, one "<label>: <name>" line each
print("\n".join(f"{label}: {name}" for label, name in target_categories.items()))

In [None]:
if PROCESS_COCO:
    # Training dataset: convert the COCO train2017 images and annotations,
    # using the shared PNGImages / PedMasks folders as output directories.
    # NOTE(review): presumably only the categories in `target_categories` are
    # kept — confirm in COCO_2_ImgMsk.
    mapping = COCO_2_ImgMsk(
        coco_images_path=      r"D:\Repos\coco_dataset\train2017",
        coco_annotations_path= r"D:\Repos\coco_dataset\annotations\instances_train2017.json",
        output_images_dir=     r"D:\Repos\coco_dataset\driving\PNGImages",
        output_masks_dir=      r"D:\Repos\coco_dataset\driving\PedMasks",
        class_dictionary=      target_categories,
        label=                 "train"
    )

In [None]:
if PROCESS_COCO:
    # Validation dataset: convert the COCO val2017 images and annotations.
    # The output directories are the same ones used for the training split;
    # the return value is discarded.
    _ = COCO_2_ImgMsk(
        coco_images_path=      r"D:\Repos\coco_dataset\val2017",
        coco_annotations_path= r"D:\Repos\coco_dataset\annotations\instances_val2017.json",
        output_images_dir=     r"D:\Repos\coco_dataset\driving\PNGImages",
        output_masks_dir=      r"D:\Repos\coco_dataset\driving\PedMasks",
        class_dictionary=      target_categories,
        label=                 "val"
    )

In [None]:
if PROCESS_COCO:
    # Split dataset into train (80%), validation (10%), and test (10%) sets.
    # The sources are the converted image/mask folders; the destinations are
    # the train/val/test folders of the project's data directory.
    # SEED is passed so the split can be reproduced.
    split_dataset(
        src_images=       r"D:\Repos\coco_dataset\driving\PNGImages",
        src_masks=        r"D:\Repos\coco_dataset\driving\PedMasks",
        dst_train_images= r"D:\Repos\ML_Projects\torchsuite\data\driving\train\PNGImages",
        dst_train_masks=  r"D:\Repos\ML_Projects\torchsuite\data\driving\train\PedMasks",
        dst_val_images=   r"D:\Repos\ML_Projects\torchsuite\data\driving\val\PNGImages",
        dst_val_masks=    r"D:\Repos\ML_Projects\torchsuite\data\driving\val\PedMasks",
        dst_test_images=  r"D:\Repos\ML_Projects\torchsuite\data\driving\test\PNGImages",
        dst_test_masks=   r"D:\Repos\ML_Projects\torchsuite\data\driving\test\PedMasks",
        train_pct=        0.80,
        val_pct=          0.10,
        test_pct=         0.10,
        seed=             SEED
    )

## 5.3. Preparing Dataloaders

In [None]:
# One class per entry of `target_categories` (background plus the vehicle categories)
NUM_CLASSES = len(target_categories)
BATCH_SIZE = 4
ACCUM_STEPS = 8
IMG_SIZE = (384, 384) #(512, 512)
AUGMENT_MAGNITUDE = 4 # 1 (low) to 5 (high)
THEME = 'light' # or 'dark'. Default is 'light'


def _build_dataloader(split, train):
    """Create the DataLoader for the split stored under ``data/driving/<split>``.

    Args:
        split: Sub-directory name of the split ("train", "val" or "test").
        train: Passed as ``train`` to SegmentationTransforms (augmentation is
            not applied when False) and also used as the ``shuffle`` flag.

    Returns:
        A ``torch.utils.data.DataLoader`` with mean/std normalization enabled.
    """
    split_dataset_obj = ProcessDatasetSegmentation(
        root=f'data/driving/{split}',
        image_path="PNGImages",
        mask_path="PedMasks",
        transforms=SegmentationTransforms(
            train=train,
            img_size=IMG_SIZE,
            mean_std_norm=True,
            augment_magnitude=AUGMENT_MAGNITUDE
            ),
        class_dictionary=target_categories
        )
    return torch.utils.data.DataLoader(
        dataset=split_dataset_obj,
        batch_size=BATCH_SIZE,
        shuffle=train,
        pin_memory=True,
    )


# Define training, validation, and test data loaders
train_dataloader = _build_dataloader('train', train=True)
val_dataloader = _build_dataloader('val', train=False)
test_dataloader = _build_dataloader('test', train=False)

# NOTE(review): the validation loader is registered under the 'test' key —
# presumably the key the engine uses for its validation step; confirm.
dataloaders = {
    'train':         train_dataloader,
    'test':          val_dataloader
}

## 5.4. Visualizing Images with Masks

In [None]:
if VISUALIZE_TRANSFORMED_DATA:

    # Dedicated constants for this preview, so the global BATCH_SIZE used to
    # build the training dataloaders is not overwritten
    VIS_BATCH_SIZE = 64
    VIS_NUM_SAMPLES = 10

    def _build_preview_dataloader(augment):
        """Return a DataLoader over the training split without mean/std normalization.

        Args:
            augment: Passed as ``train`` to SegmentationTransforms; augmentation
                is not applied when False.
        """
        return torch.utils.data.DataLoader(
            dataset=ProcessDatasetSegmentation(
                root='data/driving/train',
                image_path="PNGImages",
                mask_path="PedMasks",
                transforms=SegmentationTransforms(
                    train=augment,
                    img_size=IMG_SIZE,
                    mean_std_norm=False,
                    augment_magnitude=AUGMENT_MAGNITUDE
                    ),
                class_dictionary=target_categories
                ),
            batch_size=VIS_BATCH_SIZE,
            shuffle=False,
            pin_memory=True,
            collate_fn=collate_fn)

    # Train dataloader without transformations
    dataloader_nt = _build_preview_dataloader(augment=False)

    # Train dataloader with transformations
    dataloader_t = _build_preview_dataloader(augment=True)

    random.seed(SEED+1)

    # Visualize images and masks with and without transformations.
    # Only the first batch is shown (see the `break` at the end of the loop).
    for (img_nt, target_nt), (img_t, target_t) in zip(dataloader_nt, dataloader_t):

        # Pick random images among those actually present in the batch: it
        # holds fewer than VIS_BATCH_SIZE samples when the dataset is small.
        # NOTE(review): assumes len() of a collated batch is its sample count.
        num_available = min(len(img_nt), len(img_t))
        random_indices = random.sample(range(num_available), min(VIS_NUM_SAMPLES, num_available))
        for i in random_indices:

            # Set up the figure: original on the left, transformed on the right
            fig, axes = plt.subplots(1, 2, figsize=(15, 6))

            # Collapse the one-hot masks before display
            mask_nt = collapse_one_hot_mask(target_nt[i])
            mask_t = collapse_one_hot_mask(target_t[i])

            # Create the label-class dictionary for the mask
            classes_nt = create_label_class_dict(target_nt[i], target_categories)
            classes_t = create_label_class_dict(target_t[i], target_categories)

            # Remove background, as it is always there
            classes_nt = dict(list(classes_nt.items())[1:])
            classes_t = dict(list(classes_t.items())[1:])

            # And generate the titles
            title_nt = f"Original: {', '.join(classes_nt.values())}"
            title_t = f"Transformed: {', '.join(classes_t.values())}"

            # Display overlaid images
            alpha, beta = 1.0, 0.5
            display_image_with_mask(img_nt[i], mask_nt, fig=fig, ax=axes[0], alpha=alpha, beta=beta, color_map=color_map, title=title_nt, theme=THEME)
            display_image_with_mask(img_t[i], mask_t, fig=fig, ax=axes[1], alpha=alpha, beta=beta, color_map=color_map, title=title_t, theme=THEME)

            plt.show()

        break

# 6. Creating the U-Net Architecture

In [None]:
# Instantiate the model: a U-Net built on a pretrained ConvNeXt backbone, with
# 3 input channels and one output class per entry of `target_categories`
# NOTE(review): the backbone name suggests 384x384 pretraining, matching
# IMG_SIZE — confirm against the list printed by `print_available_models`
model = create_unet(
    model_type="pretrained",
    backbone='convnext_large_384_in22ft1k', #'convnext_small_in22ft1k', #
    in_channels=3,
    num_classes=NUM_CLASSES,
    print_available_models=True,
    pretrained=True,
    encoder_freeze=False
)

# Make every parameter trainable (encoder included), whatever state
# create_unet left them in
for parameter in model.parameters():
    parameter.requires_grad = True

# Optional model summary (disabled)
#summary(model,
#        input_size=(BATCH_SIZE,3, IMG_SIZE[0], IMG_SIZE[1]),
#        col_names=["input_size", "output_size", "num_params", "trainable"],
#        col_width=20,
#        row_settings=["var_names"])

# 7. Training the Model

In [None]:
# Training hyperparameters and checkpoint naming
EPOCHS = 35                       # Total number of epochs
LR, ETAMIN = 1e-4, 1e-6           # Optimizer learning rate / minimum learning rate of the schedule
model_type = "model_seg_convnext" # Prefix shared by all checkpoint files
model_name = f"{model_type}.pth"  # File name of the model

In [None]:
# Release cached GPU memory before building the training objects
torch.cuda.empty_cache()

# Create AdamW optimizer over all model parameters
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=LR,
    betas=(0.9, 0.999),
    weight_decay=1e-4
    )

# Create loss function (combined Dice + cross-entropy, with label smoothing)
#loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
loss_fn = DiceCrossEntropyLoss(
    num_classes=NUM_CLASSES,
    label_smoothing=0.1
    )

# Set scheduler: cosine annealing from LR towards ETAMIN during the first
# EPOCHS-5 epochs, then a constant phase for the last 5 epochs in which the
# base learning rate is scaled by ETAMIN/LR (i.e. held at ETAMIN).
# The two schedulers are created in this order and chained by SequentialLR,
# which switches from the first to the second at the milestone epoch.
cosine = CosineAnnealingLR(optimizer, T_max=EPOCHS-5, eta_min=ETAMIN)
fixed = ConstantLR(optimizer, factor=ETAMIN/LR, total_iters=5)
scheduler = SequentialLR(
    optimizer,
    schedulers=[cosine, fixed],
    milestones=[EPOCHS-5]
)
# Or (it is equivalent)
#scheduler = FixedLRSchedulerWrapper(
#    scheduler=CosineAnnealingLR(optimizer, T_max=EPOCHS-5, eta_min=1e-6),
#    fixed_lr=1e-6,
#    fixed_epoch=EPOCHS-5
#    )

# And train...

# Instantiate the segmentation engine with the created model and the target device
engine = SegmentationEngine(
    model=model,                                # Model to be trained
    optimizer=optimizer,                        # Optimizer
    loss_fn=loss_fn,                            # Loss function
    scheduler=scheduler,                        # Scheduler
    theme=THEME,                                # Theme
    log_verbose=True,                           # Verbosity
    device=device                               # Target device
    )

# Configure the training method and run it; the returned value is kept in `results`
results = engine.train(
    target_dir=MODEL_DIR,                       # Directory where the model will be saved
    model_name=model_name,                      # Name of the model
    enable_resume=True,                         # Resume training from the last saved checkpoint
    save_best_model=[
        "last", "loss", "dice", "iou"],         # Save the best models based on different criteria
    keep_best_models_in_memory=False,           # If False: do not keep the models stored in memory for the sake of training time and memory efficiency
    dataloaders=dataloaders,                    # Dictionary with the dataloaders
    apply_validation=True,                      # Enable validation step
    augmentation_strategy="always",             # Augmentation strategy
    epochs=EPOCHS,                              # Total number of epochs
    amp=True,                                   # Enable Automatic Mixed Precision (AMP)
    enable_clipping=False,                      # Disable clipping on gradients, only useful if training becomes unstable
    debug_mode=False,                           # Disable debug mode
    accumulation_steps=ACCUM_STEPS,             # Accumulation steps: effective batch size = batch_size x accumulation steps
    )

# 8. Making Predictions on Test

In [None]:
from os.path import exists
def rename_model(model_name, new_name: str):
    """Rename a saved checkpoint file to ``new_name``.

    Args:
        model_name: Path of the file to rename, or a sequence of candidate
            paths such as the list returned by ``glob.glob``. When several
            candidates are given, the most recently modified one is renamed.
        new_name: Destination path.

    Raises:
        FileNotFoundError: If ``model_name`` is an empty sequence.
    """
    # A single path is wrapped in a list so it is not indexed character by character
    if isinstance(model_name, (str, os.PathLike)):
        candidates = [model_name]
    else:
        candidates = list(model_name)

    if not candidates:
        raise FileNotFoundError(f"No model file found to rename to {new_name}")

    # Deterministic choice instead of whatever order the file listing returned
    old_name = max(candidates, key=os.path.getmtime)
    os.rename(old_name, new_name)
    print(f"Renamed {old_name} to {new_name}")
    
# If "<model_type>.pth" does not exist yet, take the checkpoint saved with the
# "<model_type>_dice_epoch" prefix and rename it to that name
new_model_name = str(MODEL_DIR / f"{model_name}")
if not exists(new_model_name):
    # Most recently modified match first, so the selection is deterministic
    model_name_dice = sorted(
        glob.glob(str(MODEL_DIR / f"{model_type}_dice_epoch*.pth")),
        key=os.path.getmtime,
        reverse=True,
    )
    if not model_name_dice:
        raise FileNotFoundError(
            f"No '{model_type}_dice_epoch*.pth' checkpoint found in {MODEL_DIR}"
        )
    rename_model(model_name_dice, new_model_name)

# Instantiate the segmentation engine with the created model and the target
# device, then load the checkpoint selected above
engine = SegmentationEngine(
    model=model,
    log_verbose=True,
    device=device
    ).load(
        target_dir=MODEL_DIR,
        model_name=Path(new_model_name).name
    )

# Run inference on the test split
preds = engine.predict(
    dataloader=test_dataloader,
    num_classes=NUM_CLASSES,
    model_state="last", # This option takes the default loaded model
    output_type="onehot" #one-hot binary masks: (num_batches, batch_size, num_classes, W, H)
)

In [None]:
# Show the shape of the predictions returned by the engine
# (presumably (num_batches, batch_size, num_classes, W, H) for one-hot output — TODO confirm)
print(f"Shape of the prediction tensor: {preds.shape}")

In [None]:
# Visualize ground-truth masks next to the predicted masks on the test split
BATCH_SIZE = test_dataloader.batch_size

# Create a test dataloader without mean/std transformation and augmentations.
# It uses the same batch size and no shuffling, so that batch `b` here lines up
# with `preds[b]`
test_dataloader_nt = torch.utils.data.DataLoader(
    dataset=ProcessDatasetSegmentation(
        root='data/driving/test',
        image_path="PNGImages",
        mask_path="PedMasks",
        transforms=SegmentationTransforms(
            train=False,
            img_size=IMG_SIZE,
            mean_std_norm=False,
            augment_magnitude=AUGMENT_MAGNITUDE
            ),
        class_dictionary=target_categories
        ),
    batch_size=BATCH_SIZE,
    shuffle=False,
    pin_memory=True,
)

for batch, (img, target) in enumerate(test_dataloader_nt):

    # Pick random images among those actually present in the batch: the last
    # batch can hold fewer than BATCH_SIZE samples
    num_in_batch = len(img)
    random_indices = random.sample(range(num_in_batch), min(2, num_in_batch))
    for i in random_indices:

        # Set up the figure: ground truth on the left, prediction on the right
        fig, axes = plt.subplots(1, 2, figsize=(15, 6))

        # Collapse the one-hot masks before display
        mask_true = collapse_one_hot_mask(target[i])
        mask_pred = collapse_one_hot_mask(preds[batch][i])

        # Create the label-class dictionary for the mask
        classes_true = create_label_class_dict(target[i], target_categories)
        classes_pred = create_label_class_dict(preds[batch][i], target_categories)

        # Remove background, as it is always there
        classes_true = dict(list(classes_true.items())[1:])
        classes_pred = dict(list(classes_pred.items())[1:])

        # And generate the titles
        title_true = f"Ground-Truth: {', '.join(classes_true.values())}"
        title_pred = f"Predictions: {', '.join(classes_pred.values())}"

        # Display overlaid images
        alpha, beta = 1.0, 0.5
        display_image_with_mask(img[i], mask_true, fig=fig, ax=axes[0], alpha=alpha, beta=beta, color_map=color_map, title=title_true, theme=THEME)
        display_image_with_mask(img[i], mask_pred, fig=fig, ax=axes[1], alpha=alpha, beta=beta, color_map=color_map, title=title_pred, theme=THEME)

        plt.show()

    # Stop after the batch with index 6, i.e. at most seven batches are shown
    if batch > 5:
        break