<a href="https://colab.research.google.com/github/rahiakela/modern-computer-vision-with-pytorch/blob/main/8-advanced-object-detection/2_training_yolo_on_custom_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Training YOLO on a custom dataset

You Only Look Once (YOLO) and its variants are among the most prominent
object detection algorithms.

To see why, first consider a limitation of Faster R-CNN: it identifies the regions that are likely to contain an object, and then makes the bounding box corrections. However, the fully connected layer receives only the RoI pooling output of the detected region as input. When a region proposal does not fully encompass the object (that is, the object extends beyond the boundaries of the proposal's bounding box), the network has to guess the real boundaries of the object, as it has seen only the region proposal and not the full image.

**YOLO comes in handy in such scenarios, as it looks at the whole image while
predicting the bounding boxes of the objects in it.**

Furthermore, Faster R-CNN is still slow, as we have two networks: the RPN and the final network that predicts classes and bounding boxes around objects.

YOLO overcomes these limitations of Faster R-CNN, both by looking at the whole image at once and by using a single network to make predictions.

## Working details of YOLO

We will look at how data is prepared for YOLO through the
following example:

**Step-1: Create a ground truth to train a model for a given image:**

Let's consider an image with the given ground truth of bounding
boxes in red:

<img src='https://github.com/rahiakela/img-repo/blob/master/object-detection-images/bounding-box-1.png?raw=1' width='800'/>

- Divide the image into $N \times N$ grid cells – for now, let's say $N=3$:

<img src='https://github.com/rahiakela/img-repo/blob/master/object-detection-images/bounding-box-2.png?raw=1' width='800'/>

- Identify those grid cells that contain the center of at least one ground
truth bounding box. In our case, they are cells $b1$ and $b3$ of our 3 x 3
grid image.

- The cell(s) where the middle point of ground truth bounding box falls
is/are responsible for predicting the bounding box of the object. Let's
create the ground truth corresponding to each cell.

- The output ground truth corresponding to each cell is as follows:

<img src='https://github.com/rahiakela/img-repo/blob/master/object-detection-images/bounding-box-3.png?raw=1' width='800'/>

Here, pc (the objectness score) is the probability of the cell containing
an object.

Let's understand how to calculate bx, by, bw, and bh.

First, we consider the grid cell (let's consider the b1 grid cell) as our
universe, and normalize it to a scale between 0 and 1, as follows:

<img src='https://github.com/rahiakela/img-repo/blob/master/object-detection-images/bounding-box-4.png?raw=1' width='800'/>

bx and by are the locations of the mid-point of the ground truth
bounding box with respect to the grid cell, normalized as described
previously. In our case, bx = 0.5, as the mid-point of the ground truth is
at a distance of 0.5 units from the origin. Similarly, by = 0.5:

<img src='https://github.com/rahiakela/img-repo/blob/master/object-detection-images/bounding-box-5.png?raw=1' width='800'/>

So far, we have calculated offsets from the grid cell origin to the
ground truth center corresponding to the object in the image. Now,
let's understand how bw and bh are calculated.

**bw is the ratio of the width of the bounding box with respect to the
width of the grid cell.**

**bh is the ratio of the height of the bounding box with respect to the
height of the grid cell.**
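
The calculation of bx, by, bw, and bh described above can be sketched in a few lines of plain Python. This is only an illustration under assumptions: the function name `encode_box`, the pixel box format `(x_min, y_min, x_max, y_max)`, and the 300 x 300 image with a made-up box are hypothetical and do not come from the figures.

```python
def encode_box(box, image_size, grid_size):
    """Encode a pixel box (x_min, y_min, x_max, y_max) relative to its grid cell.

    Returns (row, col, bx, by, bw, bh), where (row, col) is the cell that
    contains the box center. All names here are illustrative assumptions.
    """
    x_min, y_min, x_max, y_max = box
    img_w, img_h = image_size
    cell_w = img_w / grid_size
    cell_h = img_h / grid_size

    # Center of the ground truth box in pixels
    cx = (x_min + x_max) / 2
    cy = (y_min + y_max) / 2

    # Grid cell responsible for this object (clamped to stay inside the grid)
    col = min(int(cx // cell_w), grid_size - 1)
    row = min(int(cy // cell_h), grid_size - 1)

    # Center offsets measured from the cell origin, in units of one cell
    bx = cx / cell_w - col
    by = cy / cell_h - row

    # Width and height as ratios of the cell width and height
    bw = (x_max - x_min) / cell_w
    bh = (y_max - y_min) / cell_h
    return row, col, bx, by, bw, bh


# Hypothetical box on a 300 x 300 image divided into a 3 x 3 grid
print(encode_box((5, 120, 95, 180), image_size=(300, 300), grid_size=3))
```

With these made-up numbers the box center lies in the middle of its cell, so bx and by both come out as 0.5, matching the situation described above, while bw and bh can in general exceed 1 when a box is larger than a cell.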

Next, we will predict the class corresponding to the grid cell. If we
have three classes (c1: truck, c2: car, c3: bus), we will predict the
probability of the object in the cell belonging to each of the three
classes. **Note that we do not need a background class here, as pc
corresponds to whether the grid cell contains an object.**

Now that we understand how to represent the output layer of each cell,
let's understand how we construct the output of our 3 x 3 grid cells.

- Let's consider the output of the grid cell a3:

<img src='https://github.com/rahiakela/img-repo/blob/master/object-detection-images/bounding-box-6.png?raw=1' width='800'/>

As the grid cell does not contain an object, the first output (pc – objectness
score) is 0 and the remaining values do not matter as the cell does not
contain the center of any ground truth bounding boxes of an object.

- Let's consider the output corresponding to grid cell b1:

<img src='https://github.com/rahiakela/img-repo/blob/master/object-detection-images/bounding-box-7.png?raw=1' width='800'/>

In the preceding output, pc is 1 because the grid cell contains an
object, the bx, by, bw, and bh values are obtained in the same way as
described earlier, and, since the class is car, c2 is 1 while c1 and c3
are 0.

**Note that for each cell, we are able to fetch 8 outputs. Hence, for the 3 x
3 grid of cells, we fetch 3 x 3 x 8 outputs.**
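
As a rough sketch of how such a 3 x 3 x 8 ground truth could be assembled, the snippet below uses nested Python lists. The function name `build_target`, the tuple layout of `objects`, and the numeric values are assumptions made for illustration; empty cells are filled with zeros here only as a placeholder, since their remaining values do not matter.

```python
def build_target(objects, grid_size, num_classes):
    """Build a grid_size x grid_size x (5 + num_classes) ground truth.

    objects: list of (row, col, bx, by, bw, bh, class_index) tuples.
    Each cell holds [pc, bx, by, bw, bh, c1, ..., cK].
    """
    cell_len = 5 + num_classes
    target = [[[0.0] * cell_len for _ in range(grid_size)]
              for _ in range(grid_size)]
    for row, col, bx, by, bw, bh, class_index in objects:
        cell = target[row][col]
        cell[0] = 1.0                    # pc: the cell contains an object center
        cell[1:5] = [bx, by, bw, bh]     # box values relative to the cell
        cell[5 + class_index] = 1.0      # one-hot class (no background class)
    return target


# One hypothetical car (class index 1 of truck, car, bus) in row 1, column 0
target = build_target([(1, 0, 0.5, 0.5, 0.9, 0.6, 1)], grid_size=3, num_classes=3)
print(target[1][0])  # cell with an object
print(target[0][2])  # empty cell
```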

**Step-2: Define a model where the input is an image and the output is 3 x 3 x 8 with the ground truth being as defined in the previous step:**

<img src='https://github.com/rahiakela/img-repo/blob/master/object-detection-images/bounding-box-8.png?raw=1' width='800'/>

**Step-3: Define the ground truth by considering the anchor boxes.**

So far, we have been building for a scenario where the expectation is that
there is only one object within a grid cell. However, in reality, there can be
scenarios where there are multiple objects within the same grid cell. This
would result in creating ground truths that are incorrect. 

Let's understand this phenomenon through the following example image:

<img src='https://github.com/rahiakela/img-repo/blob/master/object-detection-images/bounding-box-9.png?raw=1' width='800'/>

In the preceding example, the mid-points of the ground truth bounding
boxes for both the car and the person fall in the same cell – cell b1.

One way to avoid such a scenario is by having a grid that has more rows
and columns – for example, a 19 x 19 grid. However, there can still be a
scenario where an increase in the number of grid cells does not help.
**Anchor boxes come in handy in such a scenario.** Let's say we have two
anchor boxes – one that has a greater height than width (corresponding to
the person) and another that has a greater width than height (corresponding
to the car):

<img src='https://github.com/rahiakela/img-repo/blob/master/object-detection-images/bounding-box-10.png?raw=1' width='800'/>

Typically, the anchor boxes would have the grid cell center as their centers.
The output for each cell in a scenario where we have two anchor boxes is
represented as a concatenation of the output expected of the two anchor
boxes:

<img src='https://github.com/rahiakela/img-repo/blob/master/object-detection-images/bounding-box-11.png?raw=1' width='800'/>

Here, bx, by, bw, and bh represent the offsets from the anchor box, which,
as seen in the image, is the universe in this scenario instead of the grid cell.

From the preceding figure, we see that the output is 3 x 3 x 16, as we have
two anchors with 8 values each. The expected output is of the shape N x N x
(num_anchor_boxes x (5 + num_classes)), where N x N is the number of
cells in the grid, num_classes is the number of classes in the dataset,
num_anchor_boxes is the number of anchor boxes, and 5 accounts for pc, bx,
by, bw, and bh.
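
The text does not spell out how an object is matched to one of the anchor boxes. One common approach, assumed here purely for illustration, is to pick the anchor whose shape has the highest IoU with the ground truth box when both share the same center. The helper names, anchor sizes, and object sizes below are hypothetical. The sketch also checks the per-cell output size, counting pc, four box values, and the class scores for each anchor.

```python
def shape_iou(wh_a, wh_b):
    """IoU of two boxes given as (width, height) that share the same center."""
    inter = min(wh_a[0], wh_b[0]) * min(wh_a[1], wh_b[1])
    union = wh_a[0] * wh_a[1] + wh_b[0] * wh_b[1] - inter
    return inter / union


def best_anchor(box_wh, anchors):
    """Index of the anchor whose shape overlaps most with the given box."""
    return max(range(len(anchors)), key=lambda i: shape_iou(box_wh, anchors[i]))


def output_depth(num_classes, num_anchors):
    """Values per grid cell: (pc + 4 box values + class scores) per anchor."""
    return num_anchors * (5 + num_classes)


# Hypothetical anchors in grid-cell units: one tall, one wide
anchors = [(0.5, 1.5), (1.5, 0.5)]
person_wh = (0.4, 1.2)   # taller than wide
car_wh = (1.4, 0.6)      # wider than tall

print(best_anchor(person_wh, anchors))   # index of the tall anchor
print(best_anchor(car_wh, anchors))      # index of the wide anchor
print(output_depth(num_classes=3, num_anchors=2))
```

With this assignment, the person and the car that share cell b1 would fill different anchor slots of the same cell, so neither ground truth overwrites the other.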

**Step-4: Now we define the loss function to train the model.**

When calculating the loss associated with the model, we need to ensure that
we do not calculate the regression loss and classification loss when the
objectness score is less than a certain threshold (this corresponds to the cells
that do not contain an object).

Next, if the cell contains an object, we need to ensure that the classification
across different classes is as accurate as possible.

Finally, if the cell contains an object, the bounding box offsets should be as
close to expected as possible. However, since the offsets of width and height
can be much higher than the offsets of the center (the offsets of the center
range between 0 and 1, while the offsets of width and height need not), we
give a lower weight to the offsets of width and height by taking their
square roots.

Calculate the loss of localization and classification as follows:

<img src='https://github.com/rahiakela/img-repo/blob/master/object-detection-images/yolo-equation.png?raw=1' width='800'/>

**The overall loss is a sum of classification and regression loss values.**
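
To make the structure of this loss concrete, here is a minimal per-cell sketch in plain Python. It assumes a sum-of-squared-errors form in the style of the original YOLO paper, and it may differ in detail from the equation shown in the image. The function name `cell_loss`, the weights `lambda_coord=5.0` and `lambda_noobj=0.5`, and the objectness penalty for empty cells are assumptions, not something the text above specifies.

```python
import math


def cell_loss(pred, target, lambda_coord=5.0, lambda_noobj=0.5):
    """Squared-error loss for one cell laid out as [pc, bx, by, bw, bh, c1..cK].

    Assumed form: regression and classification terms are counted only when
    the ground truth cell contains an object (target pc == 1).
    """
    p_pc, p_bx, p_by, p_bw, p_bh = pred[:5]
    t_pc, t_bx, t_by, t_bw, t_bh = target[:5]

    if t_pc == 0:
        # Empty cell: only the objectness score is penalized (assumption)
        return lambda_noobj * (p_pc - t_pc) ** 2

    # Center offsets are compared directly
    center = (p_bx - t_bx) ** 2 + (p_by - t_by) ** 2

    # Width and height are compared through square roots to reduce their weight
    size = ((math.sqrt(max(p_bw, 0.0)) - math.sqrt(t_bw)) ** 2
            + (math.sqrt(max(p_bh, 0.0)) - math.sqrt(t_bh)) ** 2)

    objectness = (p_pc - t_pc) ** 2
    classes = sum((p - t) ** 2 for p, t in zip(pred[5:], target[5:]))

    # Overall loss: regression (localization) plus classification terms
    return lambda_coord * (center + size) + objectness + classes


# Hypothetical cell: the prediction underestimates the width of a car
target = [1.0, 0.5, 0.5, 1.0, 0.64, 0.0, 1.0, 0.0]
pred = [1.0, 0.5, 0.5, 0.81, 0.64, 0.0, 1.0, 0.0]
print(cell_loss(pred, target))
```

Summing `cell_loss` over all cells (and over all anchors, when anchors are used) would give the loss for one image under these assumptions.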

With this in place, we are now in a position to train a model to predict the bounding boxes around objects.


## Training YOLO on a custom dataset