In [1]:
from utils.import_data import WiderFaceDataset, TRANSFORM, IMG_ROOT, ANN_FILE

# Build the dataset; samples are resized to img_size x img_size and passed
# through TRANSFORM (both the class and the constants come from utils.import_data).
dataset = WiderFaceDataset(
    root_dir=IMG_ROOT,
    annotation_file=ANN_FILE,
    img_size=224,
    transform=TRANSFORM,
)

# Sanity-check one sample: indexing yields an (image, target) pair.
img, target = dataset[0]
print(img.shape)          # expected: torch.Size([3, 224, 224])
print(target['boxes'])    # bounding boxes, scaled to the 224x224 image

torch.Size([3, 224, 224])
tensor([[ 98.2188,  53.3718, 124.9062,  77.4700]])


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# ----------------------------
# Simple Face Detection Net
# ----------------------------
# Input: 3x128x128 image (the network is fully convolutional, so other sizes also work)
# Output: 5 values per predicted box: [x_center, y_center, width, height, confidence]
class SimpleFaceDetector(nn.Module):
    def __init__(self):
        super(SimpleFaceDetector, self).__init__()
        
        # CNN backbone
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 64x64
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 32x32
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 16x16
        )
        
        # Detection head (predict boxes and confidence for 16x16 grid)
        # Here we predict 1 box per grid cell
        self.detector = nn.Conv2d(64, 5, kernel_size=1)  # 5 = [x, y, w, h, conf]
        
    def forward(self, x):
        x = self.features(x)
        x = self.detector(x)  # shape: (batch, 5, 16, 16)
        
        # reshape to (batch, num_boxes, 5)
        batch_size = x.size(0)
        x = x.permute(0, 2, 3, 1)  # (batch, 16, 16, 5)
        x = x.reshape(batch_size, -1, 5)  # flatten grid â†’ (batch, 256, 5)
        return x


In [None]:
# Create the network
net = SimpleFaceDetector()

# Dummy input: batch of 2 images, 3 channels, 128x128.
# `torch.randn` needs the top-level `import torch`; importing only
# `torch.nn as nn` does not bind the name `torch` in the notebook namespace.
dummy_input = torch.randn(2, 3, 128, 128)

# Forward pass
output = net(dummy_input)
print("Output shape:", output.shape)  # (batch, num_boxes, 5)

# Three 2x2 poolings turn 128x128 into a 16x16 grid -> 256 boxes of 5 values.
assert output.shape == (2, 256, 5), f"unexpected output shape: {tuple(output.shape)}"
