In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class YOLOv2(nn.Module):
    """YOLOv2-style detector: a Darknet-19 backbone followed by a 1x1 detection conv.

    For every grid cell the head emits ``num_bboxes`` candidate boxes, each
    encoded as ``5 + num_classes`` values laid out as
    ``(x, y, w, h, confidence, class_0, ..., class_{C-1})``.

    Args:
        num_classes: Number of object classes (C).
        anchors: Anchor box dimensions. Stored on the instance as ``anchors``;
            the code in this class does not read it.
        grid_size: Nominal side length of the output grid. Stored as
            ``grid_size`` for callers; the actual grid is always taken from
            the backbone's feature map, so other input resolutions also work.
        num_bboxes: Number of boxes predicted per grid cell (B).
    """

    def __init__(self, num_classes, anchors, grid_size=13, num_bboxes=5):
        super().__init__()

        # Darknet-19 backbone. NOTE: Darknet19 is not defined in this file; it is
        # implemented in darknet-19.ipynb and must be in scope when this runs.
        self.backbone = Darknet19()

        # Final detection layer: 1x1 convolution over a 1024-channel feature map.
        # Output channels are B * (5 + C): for each of the B boxes per cell,
        # 5 values for (x, y, w, h, confidence) plus C class scores.
        self.det_conv = nn.Conv2d(
            1024,
            num_bboxes * (5 + num_classes),
            kernel_size=1,
            stride=1,
            padding=0,
        )

        self.grid_size = grid_size
        self.num_bboxes = num_bboxes
        self.num_classes = num_classes
        self.anchors = anchors  # List of anchor box dimensions

    def forward(self, x):
        """Run backbone and detection head and return raw, channels-last predictions.

        Args:
            x: Input batch passed straight to the backbone.

        Returns:
            Tensor of shape ``(batch, grid_h, grid_w, B * (5 + C))`` holding raw
            (pre-sigmoid) values. ``grid_h`` and ``grid_w`` are whatever spatial
            size the backbone produces for ``x``.
        """
        features = self.backbone(x)

        # (batch, B * (5 + C), grid_h, grid_w)
        output = self.det_conv(features)

        # Move channels last so all values of one cell are contiguous along the
        # final axis. The grid is deliberately not forced to
        # (grid_size, grid_size): the head is a 1x1 conv, so any spatial size
        # coming out of the backbone is valid.
        return output.permute(0, 2, 3, 1)

    def predict(self, x, threshold=0.5):
        """Return the per-cell boxes whose confidence exceeds ``threshold``.

        Runs a forward pass without gradient tracking, applies a sigmoid to the
        confidence and class scores, and keeps the boxes whose confidence is
        strictly greater than ``threshold``.

        This method does NOT perform Non-Maximum Suppression and does NOT decode
        the box coordinates: the returned ``(x, y, w, h)`` values are the raw
        network outputs. It also leaves the module's train/eval mode untouched,
        so call ``model.eval()`` beforehand if inference behavior is wanted.

        Args:
            x: Input batch passed to ``forward``.
            threshold: Confidence cut-off applied after the sigmoid.

        Returns:
            A list with one entry per image. Each entry is a list with one tuple
            per grid cell that has at least one surviving box, in row-major cell
            order. Each tuple is ``(confidence, coords, class_probs)`` with
            shapes ``(n,)``, ``(n, 4)`` and ``(n, C)``, where ``n`` is the number
            of surviving boxes in that cell.
        """
        # Inference only: avoid building an autograd graph.
        with torch.no_grad():
            output = self.forward(x)
            batch_size, grid_h, grid_w, _ = output.shape

            # Split the last axis into (B, 5 + C) once for the whole batch
            # instead of once per cell.
            cells = output.reshape(
                batch_size, grid_h, grid_w, self.num_bboxes, 5 + self.num_classes
            )
            box_coords = cells[..., :4]  # raw x, y, w, h
            box_confidence = torch.sigmoid(cells[..., 4])
            class_probs = torch.sigmoid(cells[..., 5:])

            keep = box_confidence > threshold
            # Cells with at least one box above the threshold.
            active_cells = keep.any(dim=-1)

            predictions = []
            for i in range(batch_size):
                image_predictions = []
                # torch.nonzero yields indices in row-major order, matching a
                # nested "for row / for column" traversal of the grid.
                for j, k in torch.nonzero(active_cells[i]).tolist():
                    mask = keep[i, j, k]
                    # Each prediction is (confidence, bbox, class_probs)
                    image_predictions.append(
                        (
                            box_confidence[i, j, k][mask],
                            box_coords[i, j, k][mask],
                            class_probs[i, j, k][mask],
                        )
                    )
                predictions.append(image_predictions)

        return predictions