In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.amp import autocast, GradScaler
# Importing the class to get the dataset
from utils.import_data import WiderFaceDataset, TRANSFORM, TRAIN_ROOT, ANN_FILE
from utils.anchors import AnchorMatcher, AnchorGenerator, box_nms, compute_loss_with_anchors


# Build the training dataset from the project's WIDER FACE root directory and
# annotation file. Images are requested at 224x224 (img_size) and passed
# through the shared TRANSFORM.
# single_face_only=True: per the original author's note, this keeps roughly a
# third of the images.
dataset = WiderFaceDataset(
    root_dir=TRAIN_ROOT,
    annotation_file=ANN_FILE,
    transform=TRANSFORM,
    img_size=224,
    single_face_only=True,
)

# Report how many samples the dataset exposes.
print(f"Total images: {len(dataset)}")

# Custom collate function to separate the boxes from the images
def custom_collate_fn(batch):
    """Regroup a batch of (image, target) samples into (images, targets).

    Nothing is stacked here: both results are plain tuples with one entry
    per sample, so the caller decides how to batch the images and can keep
    per-image targets of differing sizes.
    """
    columns = zip(*batch)  # transpose: one tuple per sample field
    images, targets = columns
    return (images, targets)

# Create a dataloader
dataloader = DataLoader(
    dataset,
    batch_size=128,
    shuffle=True,
    num_workers=0,
    # Pinned (page-locked) host memory only speeds up host-to-GPU copies.
    # Requesting it on a machine without CUDA just produces a warning, so
    # follow the same availability check used to pick the training device.
    pin_memory=torch.cuda.is_available(),
    # A custom collate is needed because every sample carries an image plus
    # a per-image target (boxes), which the default collate cannot stack.
    collate_fn=custom_collate_fn,
)

Total images: 4631


In [2]:
class FaceDetectionNet(nn.Module):
    """Small single-scale convolutional face detector.

    The backbone is four 3x3 conv + batch-norm + ReLU stages; the last three
    are each followed by a 2x2 max-pool, giving an overall stride of 8
    (e.g. 224x224 input -> 28x28 grid). A 1x1 convolution then emits
    ``num_anchors * 5`` values per grid cell.
    """

    def __init__(self, num_anchors=1):
        """
        num_anchors: number of boxes predicted per spatial cell (simplest: 1)
        """
        super().__init__()

        # One entry per conv stage: (in_channels, out_channels, pool_after).
        # kernel_size=3 with padding=1 keeps the spatial size, so only the
        # max-pools change resolution: 224 -> 112 -> 56 -> 28.
        stage_specs = (
            (3, 32, False),    # RGB input
            (32, 64, True),    # pool -> 112x112
            (64, 128, True),   # pool -> 56x56
            (128, 256, True),  # pool -> 28x28
        )

        # Backbone (feature extractor)
        layers = []
        for in_ch, out_ch, pool_after in stage_specs:
            layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1))
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU())
            if pool_after:
                layers.append(nn.MaxPool2d(2, 2))
        self.backbone = nn.Sequential(*layers)

        # Detection head: a 1x1 conv producing 5 values per anchor per cell,
        # intended as (x, y, w, h, conf).
        self.det_head = nn.Conv2d(256, num_anchors * 5, kernel_size=1)

    def forward(self, x):
        """
        x: [batch_size, 3, H, W]
        Returns:
            out: [batch_size, num_anchors, 5, H/8, W/8]
                 Each cell predicts (x, y, w, h, confidence) per anchor
        """
        raw = self.det_head(self.backbone(x))  # [B, 5*num_anchors, H', W']

        batch, _, grid_h, grid_w = raw.shape
        # Split the channel axis into (anchor, 5 values).
        return raw.view(batch, -1, 5, grid_h, grid_w)

In [3]:
net = FaceDetectionNet()

# Use the GPU (CUDA) when one is available; otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Training on: {device}")
net = net.to(device)

# Mixed precision (autocast + loss scaling) is only switched on for CUDA.
# With enabled=False both autocast and GradScaler act as pass-throughs, so the
# same loop runs in full precision when the CPU fallback above is taken
# instead of requesting CUDA autocast on a machine that has no GPU.
use_amp = device.type == "cuda"

optimizer = torch.optim.Adam(net.parameters(), lr=0.0005)
scaler = GradScaler(device.type, enabled=use_amp)

# Initialize anchor generator and matcher
anchor_gen = AnchorGenerator(scales=[1.0], aspect_ratios=[1.0])
matcher = AnchorMatcher(pos_iou_thresh=0.5, neg_iou_thresh=0.4)

# Generate anchors for the 28x28 feature map: the backbone applies three
# 2x2 max-pools (stride 8), and 224 / 8 = 28.
anchors = anchor_gen.generate_anchors(feature_h=28, feature_w=28, stride=8, img_size=224)

num_epochs = 3

for epoch in range(num_epochs):
    net.train()  # sets the net to training mode
    epoch_loss = 0

    # The dataset is fed to the net in mini-batches (size set on the dataloader).
    for images, targets in dataloader:
        # images is a tuple of tensors; stack into a single batch tensor
        images = torch.stack(images).to(device)    # shape: [batch_size, 3, 224, 224]
        boxes = [t['boxes'].to(device) for t in targets]  # move each image's boxes to the training device

        optimizer.zero_grad()               # resets gradients

        # Forward pass and loss under autocast (reduced precision on CUDA only).
        with autocast(device_type=device.type, enabled=use_amp):
            outputs = net(images)
            loss = compute_loss_with_anchors(outputs, boxes, anchors, matcher, device)

        # Scale the loss before backward; scaler.step() then calls
        # optimizer.step() and scaler.update() adjusts the scale factor.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item()  # accumulate batch loss

    # Mean of the per-batch losses, printed once per epoch.
    avg_epoch_loss = epoch_loss / len(dataloader)
    print(f"Epoch [{epoch+1}/{num_epochs}] - Average Loss: {avg_epoch_loss:.4f}")






Training on: cuda
Epoch [1/3] - Average Loss: 0.3115
Epoch [2/3] - Average Loss: 0.2035
Epoch [3/3] - Average Loss: 0.2052


In [4]:
# Inference with NMS
def detect_faces(image, model, anchors, conf_threshold=0.5, nms_threshold=0.5):
    """
    Detect faces in image using trained model.

    Args:
        image: a single image tensor without a batch axis (one is added
            here). It is moved to the model's device before the forward pass.
        model: network whose output for one image has shape
            [1, num_anchors, 5, H, W], with index 4 of the size-5 axis being
            the confidence logit.
        anchors: one row per prediction; assumed to be a tensor so it can be
            moved to the prediction device and indexed with a boolean mask.
        conf_threshold: predictions with sigmoid(confidence) <= this value
            are discarded.
        nms_threshold: passed through to box_nms.

    Returns:
        (boxes, scores): the anchors that survive the confidence filter and
        NMS, and their confidence scores. Both are empty when nothing passes
        the confidence filter.

    Note:
        The model is switched to eval mode and left that way.
    """
    model.eval()
    with torch.no_grad():
        # Run on whatever device the model lives on; fall back to the image's
        # device for a parameter-free model.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else image.device

        outputs = model(image.unsqueeze(0).to(device))  # [1, num_anchors, 5, H, W]

        # Move the size-5 axis last so every row is one prediction's five
        # values. For num_anchors == 1 this matches `view(5, -1).T` exactly;
        # that form would mix values across anchors when num_anchors > 1.
        # NOTE(review): for num_anchors > 1 the row order here is
        # (anchor, y, x) - confirm it matches the ordering of `anchors`.
        pred_flat = outputs[0].permute(0, 2, 3, 1).reshape(-1, 5)
        pred_conf = torch.sigmoid(pred_flat[:, 4])

        # NOTE(review): the four regression outputs (pred_flat[:, :4]) are not
        # decoded; the returned boxes are the anchors themselves. Decoding
        # needs the box encoding used by the training loss, which is not
        # visible here.

        # Filter by confidence
        keep_idx = pred_conf > conf_threshold
        boxes = anchors.to(keep_idx.device)[keep_idx]
        scores = pred_conf[keep_idx]

        # Nothing above the threshold: skip NMS and return the empty results.
        if scores.numel() == 0:
            return boxes, scores

        # Apply NMS
        keep = box_nms(boxes, scores, nms_threshold)

        return boxes[keep], scores[keep]