# TorchGeo: An Introduction to Object Detection Example
This notebook follows the article [TorchGeo: An Introduction to Object Detection Example](https://medium.com/@byeonghyeokyu/torchgeo-an-introduction-to-object-detection-example-b0fd43e89649).

VHR-10 dataset references:
- https://doi.org/10.1016/j.isprsjprs.2014.10.002
- https://doi.org/10.1109/IGARSS.2019.8898573
- https://doi.org/10.3390/rs12060989

## Installing Dependencies

Install the `rar` and `unrar` command-line tools, which are needed to extract the dataset's `.rar` archive.

In [None]:
# Download the RARLAB Linux x64 build (5.5.0), unpack it, and copy the
# `rar` and `unrar` binaries into /usr/local/bin.
!wget https://www.rarlab.com/rar/rarlinux-x64-5.5.0.tar.gz
!tar xzvf rarlinux-x64-5.5.0.tar.gz 
!sudo cp rar/rar rar/unrar /usr/local/bin/

Install Python Dependencies

In [None]:
!pip install torchgeo[all]
!pip install gdown
!pip install -q -U pytorch-lightning

In [None]:
import torchgeo
from torchgeo.datasets import VHR10
from torchgeo.trainers import ObjectDetectionTask

import torch
from torch.utils.data import DataLoader
import lightning.pytorch as pl

import matplotlib.pyplot as plt

In [None]:
torchgeo.__version__
# Need version 0.5.2

## Downloading the VHR-10 Dataset

In [None]:
import os

import gdown

# The archive is placed in the same directory that is later passed to
# VHR10(root=...).
os.makedirs('data/VHR10/', exist_ok=True)

url = 'https://drive.google.com/uc?id=1--foZ3dV5OCsqXQXT84UeKtrAqc5CkAE'
output_path = 'data/VHR10/NWPU VHR-10 dataset.rar'

# Skip the download when the archive is already on disk so that re-running
# this cell does not fetch the whole file again.
if not os.path.isfile(output_path):
    gdown.download(url, output_path, quiet=False)

In [None]:
def preprocess(sample):
    """Convert the sample's image to float and divide it by 255.

    The ``"image"`` entry is replaced in place; the same (mutated) sample
    mapping is returned.
    """
    scaled_image = sample["image"].float() / 255.0
    sample["image"] = scaled_image
    return sample

# Build the VHR-10 dataset from data/VHR10/; `preprocess` is applied as the
# per-sample transform.
ds = VHR10(
    root="data/VHR10/",  # Directory that holds the downloaded archive
    split="positive",  # presumably the images that contain annotated objects — confirm in the torchgeo docs
    transforms=preprocess,  # Convert each image to float and divide by 255
    download=True,
    checksum=True,
)

## Exploring the VHR-10 Dataset

In [None]:
print(f"VHR-10 dataset: {len(ds)}")


In [None]:
ds[0]["image"].shape


In [None]:
# NOTE(review): this looks like the displayed output of the previous cell
# (the image shape) pasted in as a code cell — confirm. As code it only
# constructs a torch.Size value.
torch.Size([3, 808, 958])


In [None]:
image = ds[5]["image"].permute(1, 2, 0)
plt.imshow(image)
plt.show()

In [None]:
ds.plot(ds[5])
plt.savefig('ground_truth.png', bbox_inches='tight')
plt.show()

## Model Training

In [None]:
def collate_fn(batch):
    """Group per-sample fields into per-key lists without stacking them.

    Args:
        batch: Sequence of sample dicts, each providing "image", "boxes",
            "labels" and "masks" entries.

    Returns:
        Dict with those four keys, each mapped to a list holding that field
        from every sample, in batch order.
    """
    fields = ("image", "boxes", "labels", "masks")
    return {field: [sample[field] for sample in batch] for field in fields}

# Data Loader
# The custom collate_fn returns every field as a plain list of per-sample
# values instead of one stacked tensor.

dl = DataLoader(
    ds,  # Dataset
    batch_size=32,  # Number of data to load at one time
    num_workers=2,  # Number of processes to use for data loading
    shuffle=True,  # Whether to shuffle the dataset before loading
    collate_fn=collate_fn,  # collate_fn function for batch processing
)


In [None]:
class VariableSizeInputObjectDetectionTask(ObjectDetectionTask):
    """ObjectDetectionTask whose training step accepts list-collated batches.

    The batch size is taken from the number of entries in ``batch["image"]``,
    so the images do not have to be stacked into a single tensor.
    """

    def training_step(self, batch, batch_idx, dataloader_idx=0):
        """Compute the summed loss for one training batch.

        Args:
            batch: Dict with "image", "boxes" and "labels" entries, each a
                sequence holding one element per sample.
            batch_idx: Index of the batch (not used here).
            dataloader_idx: Index of the dataloader (not used here).

        Returns:
            The sum of all values in the loss dict returned by the forward
            pass.
        """
        x = batch["image"]  # Images
        batch_size = len(x)  # Number of images in the batch
        # One target dict (boxes + labels) per image.
        y = [
            {"boxes": batch["boxes"][i], "labels": batch["labels"][i]}
            for i in range(batch_size)
        ]
        loss_dict = self(x, y)  # Individual loss terms
        train_loss: torch.Tensor = sum(loss_dict.values())
        # Pass batch_size explicitly: the batch is a dict of lists, so the
        # logger cannot reliably infer the size from its contents.
        self.log_dict(loss_dict, batch_size=batch_size)
        return train_loss

task = VariableSizeInputObjectDetectionTask(
    model="faster-rcnn",  # Faster R-CNN model
    backbone="resnet18",  # ResNet18 neural network architecture
    weights=True,  # Use pretrained weights
    in_channels=3,  # Number of channels in the input image (RGB images)
    num_classes=11,  # Number of classes to classify (10 + background)
    trainable_layers=3,  # Number of trainable layers
    lr=1e-3,  # Learning rate
    patience=10,  # Patience of the learning-rate scheduler (per the torchgeo ObjectDetectionTask docs), not early stopping
    freeze_backbone=False,  # Do not freeze the backbone; its weights are trained as well
)
# Metric name to monitor; "loss_classifier" is presumably one of the loss_dict
# keys logged in training_step — confirm against the model's loss names.
task.monitor = "loss_classifier"

In [None]:
# Configure the Lightning trainer. accelerator="gpu" with devices=[0] targets
# the first GPU, so this cell needs a GPU runtime.
trainer = pl.Trainer(
    default_root_dir="logs/",  # Set the default directory
    accelerator="gpu",  # Set the type of hardware accelerator for training (using GPU)
    devices=[0],  # List of device IDs to use ([0] means the first GPU)
    min_epochs=6,  # Set the minimum number of training epochs
    max_epochs=100,  # Set the maximum number of training epochs
    log_every_n_steps=20,  # Set how often to log after a number of steps
)

In [None]:
%%time
# Model training
trainer.fit(task, train_dataloaders=dl)

## Model Inference Example

In [None]:
batch = next(iter(dl))

In [None]:
# Put the task's model into evaluation mode and run it on one batch of
# images with gradient tracking disabled.
model = task.model
model.eval()

with torch.no_grad():
    out = model(batch["image"])

In [None]:
def create_sample(batch, out, batch_idx):
    """Combine ground truth and predictions for one batch element.

    Args:
        batch: Dict of per-key sequences with "image", "boxes", "labels"
            and "masks" entries.
        out: Sequence of per-image prediction dicts with "labels", "boxes"
            and "scores" entries.
        batch_idx: Position of the element to extract.

    Returns:
        Dict holding the element's ground-truth fields plus the model's
        predictions under "prediction_labels", "prediction_boxes" and
        "prediction_scores".
    """
    # Ground-truth fields are copied over under their original keys.
    sample = {
        key: batch[key][batch_idx]
        for key in ("image", "boxes", "labels", "masks")
    }
    # Model outputs are stored under "prediction_"-prefixed keys.
    prediction = out[batch_idx]
    sample["prediction_labels"] = prediction["labels"]
    sample["prediction_boxes"] = prediction["boxes"]
    sample["prediction_scores"] = prediction["scores"]
    return sample

batch_idx = 0
sample = create_sample(batch, out, batch_idx)

In [None]:
ds.plot(sample)
plt.savefig('inference.png', bbox_inches='tight')
plt.show()

In [None]:
# Visualizing Sample for Batch Index 3
batch_idx = 3
sample = create_sample(batch, out, batch_idx)

ds.plot(sample)
plt.show()

In [None]:
# Visualizing Sample for Batch Index 5
batch_idx = 5
sample = create_sample(batch, out, batch_idx)

ds.plot(sample)
plt.show()