<a href="https://colab.research.google.com/github/sidhu2690/YOLO_V8_PyTorch/blob/main/00_YOLO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [14]:
import torch
import torch.nn as nn

In [13]:
class TinyYOLO(nn.Module):

    def __init__(self, num_classes=20, grid_size=7, num_boxes=2):
        super(TinyYOLO, self).__init__()

        self.num_classes = num_classes
        self.grid_size = grid_size
        self.num_boxes = num_boxes

        self.output_channels = num_boxes * 5 + num_classes

        self.conv1 = nn.Conv2d(3, 32, 3, 1, 1)
        self.pool1 = nn.MaxPool2d(2, 2)

        self.conv2 = nn.Conv2d(32, 64, 3, 1, 1)
        self.pool2 = nn.MaxPool2d(2, 2)

        self.conv3 = nn.Conv2d(64, 128, 3, 1, 1)
        self.pool3 = nn.MaxPool2d(2, 2)

        self.conv4 = nn.Conv2d(128, 256, 3, 1, 1)
        self.conv5 = nn.Conv2d(256, self.output_channels, 1, 1, 0)

        self.relu = nn.ReLU()

    def forward(self, x):

        x = self.relu(self.conv1(x))
        x = self.pool1(x)  # 224 -> 112

        x = self.relu(self.conv2(x))
        x = self.pool2(x)  # 112 -> 56

        x = self.relu(self.conv3(x))
        x = self.pool3(x)  # 56 -> 28

        # Detection
        x = self.relu(self.conv4(x))
        x = self.conv5(x)  # Still 28x28

        x = nn.functional.adaptive_avg_pool2d(x, (self.grid_size, self.grid_size))

        # To (B, S, S, output_channels)
        B = x.size(0)
        x = x.permute(0, 2, 3, 1).contiguous()

        return x


class YOLOLoss(nn.Module):
    """
    Loss = Î»_coord * bbox_loss + confidence_loss + classification_loss
    """

    def __init__(self, num_classes=20, num_boxes=2, lambda_coord=5.0, lambda_noobj=0.5):
        super(YOLOLoss, self).__init__()

        self.num_classes = num_classes
        self.num_boxes = num_boxes
        self.lambda_coord = lambda_coord
        self.lambda_noobj = lambda_noobj

    def forward(self, predictions, targets):
        """
        Compute YOLO loss.

        Args:
            predictions: (B, S, S, num_boxes*5 + num_classes)
            targets: (B, S, S, 5 + num_classes)
                    Format: [x, y, w, h, confidence, class_probs...]

        Returns:
            total_loss: scalar tensor
        """
        B, S, S, _ = predictions.shape

        # Split predictions into components
        box_predictions = predictions[..., :self.num_boxes*5].view(B, S, S, self.num_boxes, 5)
        class_predictions = predictions[..., self.num_boxes*5:]

        # Target components
        target_boxes = targets[..., :4]
        target_conf = targets[..., 4:5]
        target_class = targets[..., 5:]

        # Object mask: 1 if cell contains object, 0 otherwise
        obj_mask = target_conf.squeeze(-1)
        noobj_mask = 1 - obj_mask

        # Coordinate loss (only for cells with objects)
        coord_loss = 0
        conf_loss = 0

        for b in range(self.num_boxes):
            pred_xy = box_predictions[..., b, :2]
            pred_wh = box_predictions[..., b, 2:4]
            pred_conf = box_predictions[..., b, 4]

            # Expand obj_mask for coordinate dimensions
            obj_mask_xy = obj_mask.unsqueeze(-1)

            # XY loss
            xy_loss = torch.sum(obj_mask_xy * ((pred_xy - target_boxes[..., :2])**2))

            # WH loss (using square root)
            wh_loss = torch.sum(obj_mask_xy * ((torch.sqrt(pred_wh + 1e-6) -
                                                torch.sqrt(target_boxes[..., 2:4] + 1e-6))**2))

            coord_loss += xy_loss + wh_loss

            # Confidence loss
            conf_loss += torch.sum(obj_mask * ((pred_conf - target_conf.squeeze(-1))**2))
            conf_loss += self.lambda_noobj * torch.sum(noobj_mask * (pred_conf**2))

        # Classification loss
        class_loss = torch.sum(obj_mask.unsqueeze(-1) *
                               ((class_predictions - target_class)**2))

        # Total loss
        total_loss = (self.lambda_coord * coord_loss +
                     conf_loss +
                     class_loss)

        return total_loss / B


# Example usage
if __name__ == "__main__":
    # Create model
    model = TinyYOLO(num_classes=20, grid_size=7, num_boxes=2)

    # Example input: batch of 2 images, 224x224
    x = torch.randn(2, 3, 224, 224)

    # Forward pass
    predictions = model(x)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {predictions.shape}")
    print(f"Expected: (2, 7, 7, {2*5 + 20})")

    # Example loss computation
    loss_fn = YOLOLoss(num_classes=20, num_boxes=2)

    # Dummy targets
    targets = torch.zeros(2, 7, 7, 5 + 20)
    targets[0, 3, 3, 4] = 1  # Object at grid cell (3,3)
    targets[0, 3, 3, :4] = torch.tensor([0.5, 0.5, 0.3, 0.4])  # bbox
    targets[0, 3, 3, 5] = 1  # Class 0

    loss = loss_fn(predictions, targets)
    print(f"\nLoss: {loss.item():.4f}")

    # Model parameters
    total_params = sum(p.numel() for p in model.parameters())
    print(f"\nTotal parameters: {total_params:,}")

Input shape: torch.Size([2, 3, 224, 224])
Output shape: torch.Size([2, 7, 7, 30])
Expected: (2, 7, 7, 30)

Loss: nan

Total parameters: 396,126
