In [2]:
import torch
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from sklearn.cluster import KMeans
from matplotlib import pyplot as plt

from pathlib import Path

from coco_downloader import COCOCatDogDownloader
from dataloader import create_split_datasets, create_dataloaders
from yolo_helpers import create_yolo_targets, decode_yolo_predictions

In [18]:
# Fetch the COCO annotations (the log below shows the download is skipped when
# they already exist) and select the cat/dog image subset.
# NOTE(review): presumably writes images + cat_dog_annotations.json under
# cat_dog_images/, which the next cell reads — confirm in coco_downloader.
downloader = COCOCatDogDownloader()
downloader.download_and_prepare_dataset()

Annotations already exist, skipping download
loading annotations into memory...
Done (t=9.30s)
creating index...
index created!
Cat category ID: [17]
Dog category ID: [18]
Found 4078 pure cat images
Found 4342 pure dog images
Found 220 mixed images
Selected 220 mixed images (all)
Selected 2500 pure cat images
Selected 2500 pure dog images


100%|██████████| 5220/5220 [00:00<00:00, 92811.64it/s]


In [19]:
# Split the combined annotation file into train/validation JSON files
# (15% held out for validation; see the printed split summary) and wrap both
# splits in DataLoaders.
# NOTE(review): batch_size and target_size passed here are the values training
# actually uses — main() reads these loaders as globals and never reads its own
# 'batch_size'/'target_size' config entries, so keep them in sync by hand.
annotations_file = "cat_dog_images/cat_dog_annotations.json"

train_file, val_file = create_split_datasets(
    annotations_file, val_ratio=0.15)

train_loader, val_loader = create_dataloaders(
    train_file, val_file,
    images_dir="cat_dog_images",  
    batch_size=16,
    target_size=(224, 224)  
)


Split Results:
Training set: 4437 images
- Mixed: 187
- Pure cat: 2125
- Pure dog: 2125
Validation set: 783 images
- Mixed: 33
- Pure cat: 375
- Pure dog: 375

Saved:
- Training annotations: cat_dog_images/train_annotations.json
- Validation annotations: cat_dog_images/val_annotations.json


In [None]:
def calculate_anchors(train_loader, num_anchors=3, grid_size=14, image_size=224):
    """Estimate YOLO anchor sizes by k-means clustering of ground-truth box sizes.

    Every box in ``train_loader`` is converted from (x1, y1, x2, y2) pixel
    coordinates into a (width, height) pair measured in grid cells, and the
    pairs are clustered with k-means.

    Args:
        train_loader: iterable yielding ``(images, targets)`` where ``targets`` is
            a sequence of dicts whose ``'boxes'`` entry holds (x1, y1, x2, y2) rows.
            NOTE(review): assumes the box coordinates are already expressed in the
            resized ``image_size`` frame — confirm against the dataloader.
        num_anchors: number of clusters (anchors) to return.
        grid_size: number of grid cells per side of the detection grid.
        image_size: side length in pixels of the square network input
            (previously hard-coded as 224).

    Returns:
        ``np.ndarray`` of shape ``(num_anchors, 2)`` holding the (w, h) cluster
        centres in grid-cell units, in KMeans' (unsorted) order.

    Raises:
        ValueError: if fewer than ``num_anchors`` boxes are found.
    """
    all_boxes = []
    for _, targets in train_loader:
        for target in targets:
            for box in target['boxes']:
                # pixels -> grid cells
                w = (box[2] - box[0]) * grid_size / image_size
                h = (box[3] - box[1]) * grid_size / image_size
                # float() accepts 0-dim tensors as well as plain numbers.
                all_boxes.append([float(w), float(h)])

    if len(all_boxes) < num_anchors:
        raise ValueError(
            f"Need at least {num_anchors} boxes to fit {num_anchors} anchors, "
            f"found {len(all_boxes)}")

    kmeans = KMeans(n_clusters=num_anchors, random_state=42)
    kmeans.fit(all_boxes)
    return kmeans.cluster_centers_

# Fit anchors on the training boxes, then order them from smallest to largest
# by area. Sorting whole rows keeps each (w, h) pair intact; np.sort(axis=0)
# would sort widths and heights independently and could pair one cluster's
# width with another cluster's height.
anchors = calculate_anchors(train_loader)
anchors = anchors[np.argsort(anchors[:, 0] * anchors[:, 1])]
anchors = anchors.astype(np.float32)  # float32 for the model's `anchors` buffer
anchors

array([[ 2.07443423,  2.33441195],
       [ 5.45975762,  6.25354672],
       [ 9.75938853, 10.51438995]])

In [3]:
class ResidualFeatureAdapter(nn.Module):
    """Residual bottleneck block that refines a feature map without changing its shape.

    Computes ``x + adapter(x)`` where ``adapter`` is
    1x1 conv -> BN -> ReLU -> Dropout2d -> 3x3 conv -> BN -> ReLU -> Dropout2d
    -> 1x1 conv -> BN, projecting ``in_channels`` down to ``hidden_channels``
    and back up again.

    Args:
        in_channels: channels of the input (and therefore the output) feature map.
        hidden_channels: width of the inner bottleneck.
        dropout_rate: probability used by both ``Dropout2d`` layers.
    """

    def __init__(self, in_channels, hidden_channels, dropout_rate=0.1):
        super().__init__()
        squeeze = [
            nn.Conv2d(in_channels, hidden_channels, kernel_size=1),
            nn.BatchNorm2d(hidden_channels),
            nn.ReLU(inplace=True),
            nn.Dropout2d(dropout_rate),
        ]
        spatial = [
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden_channels),
            nn.ReLU(inplace=True),
            nn.Dropout2d(dropout_rate),
        ]
        expand = [
            nn.Conv2d(hidden_channels, in_channels, kernel_size=1),
            nn.BatchNorm2d(in_channels),
        ]
        # One flat Sequential keeps the state_dict keys as adapter.0 ... adapter.9.
        self.adapter = nn.Sequential(*squeeze, *spatial, *expand)
        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal (fan_out, ReLU) conv weights with zero bias; BN gamma=1, beta=0."""
        for layer in self.modules():
            if isinstance(layer, nn.BatchNorm2d):
                nn.init.ones_(layer.weight)
                nn.init.zeros_(layer.bias)
            elif isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Return ``x`` plus the adapter's output (same shape as ``x``)."""
        refinement = self.adapter(x)
        return x + refinement
    
class ResNetYOLODetector(nn.Module):
    """Single-scale YOLO-style detector built on a pretrained torchvision ResNet.

    Pipeline: ResNet backbone (classification head removed, strides adjusted for
    ``grid_size``) -> ``ResidualFeatureAdapter`` -> conv prediction head with
    ``num_anchors * (5 + num_classes)`` output channels. ``forward`` reshapes the
    head output to ``(batch, num_anchors, 5 + num_classes, grid_h, grid_w)``.
    Along dim 2, index 4 is the objectness logit (see ``_initialize_weights``)
    and indices 5+ are the class logits.

    The backbone is frozen at construction. ``update_epoch`` unfreezes it once
    ``freeze_backbone_epochs`` is reached, either all of it or only its last few
    child modules.

    NOTE(review): ``requires_grad=False`` does not stop BatchNorm layers from
    updating their running statistics under ``model.train()``. The "frozen"
    backbone's BN statistics therefore still drift during training; consider
    keeping the backbone's BN layers in eval mode while it is frozen.

    Args:
        anchor_boxes: ``(num_anchors, 2)`` anchor (w, h) values (array, list or
            tensor). They are stored as the float32 buffer ``anchors`` so they are
            saved and restored with the ``state_dict``.
        backbone_name: torchvision hub ResNet name, e.g. 'resnet18', 'resnet50'.
        grid_size: output grid side, one of 7, 14, 28 (for the 224x224 inputs
            used in this notebook).
        freeze_backbone_epochs: first epoch at which ``update_epoch`` unfreezes.
        dropout_rate: ``Dropout2d`` rate for the adapter and the prediction head.
    """

    def __init__(self, anchor_boxes, backbone_name="resnet50", grid_size=14, freeze_backbone_epochs=15, dropout_rate=0.1):
        super().__init__()

        self.num_classes = 2
        self.num_anchors = len(anchor_boxes)
        self.freeze_backbone_epochs = freeze_backbone_epochs
        self.current_epoch = 0
        self.dropout_rate = dropout_rate

        self.backbone, backbone_channels = self._load_backbone(backbone_name, grid_size)
        self.feature_adapter = ResidualFeatureAdapter(backbone_channels, backbone_channels // 2, dropout_rate)

        self.prediction_head = nn.Sequential(
            nn.Conv2d(backbone_channels, backbone_channels // 4, kernel_size=3, padding=1),
            nn.BatchNorm2d(backbone_channels // 4),
            nn.ReLU(inplace=True),
            nn.Dropout2d(dropout_rate),
            nn.Conv2d(backbone_channels // 4, self.num_anchors * (5 + self.num_classes), kernel_size=1)
        )

        # Always float32 (a float64 numpy array would otherwise give a float64
        # buffer) and a private copy, decoupled from the caller's array/tensor.
        self.register_buffer('anchors', torch.as_tensor(anchor_boxes, dtype=torch.float32).clone())

        self._initialize_weights()
        self.freeze_backbone()

    def _initialize_weights(self):
        """Initialise the newly added prediction head.

        Only ``self.prediction_head`` is touched. Iterating ``self.modules()``
        would also reach the backbone and overwrite its pretrained weights and
        BatchNorm statistics. ``ResidualFeatureAdapter`` initialises itself in
        its own constructor.

        Each anchor's objectness bias is set to ``-log((1 - p) / p)`` with
        ``p = 0.01``, so the initial sigmoid objectness is 1%.
        """
        for m in self.prediction_head.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        final_conv = self.prediction_head[-1]
        num_outputs_per_anchor = 5 + self.num_classes
        for i in range(self.num_anchors):
            obj_idx = i * num_outputs_per_anchor + 4
            nn.init.constant_(final_conv.bias[obj_idx], -np.log((1 - 0.01) / 0.01))

    def freeze_backbone(self):
        """Set ``requires_grad=False`` on every backbone parameter."""
        for param in self.backbone.parameters():
            param.requires_grad = False

    def unfreeze_backbone(self):
        """Set ``requires_grad=True`` on every backbone parameter."""
        for param in self.backbone.parameters():
            param.requires_grad = True

    def unfreeze_backbone_layers(self, num_layers=2):
        """Make the last ``num_layers`` child modules of the backbone trainable.

        NOTE(review): for the torchvision ResNets built by ``_load_backbone`` the
        last children should be layer3/layer4 — confirm if the backbone changes.
        """
        backbone_layers = list(self.backbone.children())

        if num_layers > 0:
            layers_to_unfreeze = backbone_layers[-num_layers:]
            for layer in layers_to_unfreeze:
                for param in layer.parameters():
                    param.requires_grad = True

    def _backbone_fully_frozen(self):
        """True when no backbone parameter is trainable."""
        return not any(param.requires_grad for param in self.backbone.parameters())

    def update_epoch(self, epoch, partial_unfreeze_layers=2):
        """Record ``epoch`` and unfreeze the backbone once it reaches ``freeze_backbone_epochs``.

        With ``partial_unfreeze_layers == -1`` the whole backbone is unfrozen;
        otherwise only its last ``partial_unfreeze_layers`` children are. The
        unfreeze runs only while *no* backbone parameter is trainable yet.
        Checking just ``backbone[0]`` (the stem) stays True after a partial
        unfreeze, which re-ran the unfreeze and its log line every epoch.
        """
        self.current_epoch = epoch
        if epoch >= self.freeze_backbone_epochs and self._backbone_fully_frozen():
            if partial_unfreeze_layers == -1:
                self.unfreeze_backbone()
                print(f"🔓 Epoch {epoch}: Unfroze entire backbone")
            else:
                self.unfreeze_backbone_layers(partial_unfreeze_layers)
                print(f"🔓 Epoch {epoch}: Unfroze last {partial_unfreeze_layers} backbone layers")

    def _load_backbone(self, backbone_name, grid_size=7):
        """Load a pretrained torchvision ResNet and cut it down to a feature extractor.

        The last two children (the pooling/classifier head of torchvision
        ResNets) are dropped. For ``grid_size`` 14 the stride-2 conv and shortcut
        of the first ``layer4`` block are set to stride 1; for 28 the same is
        also done in ``layer3``. Each change halves the total downsampling.

        Returns:
            ``(backbone nn.Sequential, number of output channels)``.

        Raises:
            ValueError: for a ``grid_size`` other than 7, 14 or 28. This is
            checked before anything is downloaded.
        """
        destride_layers = {7: [], 14: ['layer4'], 28: ['layer3', 'layer4']}
        if grid_size not in destride_layers:
            raise ValueError(f"Unsupported grid size: {grid_size}. Supported: [7, 14, 28]")

        backbone = torch.hub.load('pytorch/vision:v0.10.0', backbone_name, pretrained=True)

        # BasicBlock (resnet18/34) downsamples in conv1, Bottleneck (resnet50+) in conv2.
        if backbone_name in ['resnet18', 'resnet34']:
            final_channels = 512
            strided_conv = 'conv1'
        else:
            final_channels = 2048
            strided_conv = 'conv2'

        for layer_name in destride_layers[grid_size]:
            first_block = getattr(backbone, layer_name)[0]
            getattr(first_block, strided_conv).stride = (1, 1)
            first_block.downsample[0].stride = (1, 1)

        backbone_modified = nn.Sequential(*list(backbone.children())[:-2])
        output_channels = final_channels

        print(f"Backbone {backbone_name} configured for {grid_size}x{grid_size} grid")
        print(f"Output channels: {output_channels}")
        print(f"Approximate backbone parameters: {sum(p.numel() for p in backbone_modified.parameters()) / 1e6:.1f}M")

        return backbone_modified, output_channels

    def forward(self, x):
        """Return raw predictions shaped ``(batch, num_anchors, 5 + num_classes, H, W)``."""
        batch_size = x.size(0)

        features = self.backbone(x)
        adapted_features = self.feature_adapter(features)
        predictions = self.prediction_head(adapted_features)

        # Channel layout is anchor-major: [anchor0 attrs..., anchor1 attrs..., ...].
        predictions = predictions.view(
            batch_size,
            self.num_anchors,
            5 + self.num_classes,
            predictions.size(-2),
            predictions.size(-1)
        )

        return predictions

    def get_model_info(self):
        """Return parameter counts per component plus freeze/epoch state.

        ``backbone_frozen`` is True only when no backbone parameter is trainable,
        so it reports False after a partial unfreeze.
        """
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)

        backbone_params = sum(p.numel() for p in self.backbone.parameters())
        adapter_params = sum(p.numel() for p in self.feature_adapter.parameters())
        head_params = sum(p.numel() for p in self.prediction_head.parameters())

        return {
            'total_parameters': total_params,
            'trainable_parameters': trainable_params,
            'backbone_parameters': backbone_params,
            'adapter_parameters': adapter_params,
            'head_parameters': head_params,
            'backbone_frozen': self._backbone_fully_frozen(),
            'current_epoch': self.current_epoch
        }

In [44]:
def resnet_yolo_detector_loss(predictions, targets, coord_weight=5.0, noobj_weight=0.5):
    """YOLO-style detection loss, summed over cells and divided by the batch size.

    ``predictions`` and ``targets`` are sliced along dim 2 as
    [x, y, w, h, objectness, class...] (e.g. the
    ``(B, A, 5 + C, H, W)`` output of ``ResNetYOLODetector``). A position is
    "positive" where the target objectness is > 0.

    Terms (all with ``reduction='sum'``):
      * xy    - MSE between sigmoid(pred x, y) and target x, y on positives.
      * wh    - MSE between raw pred w, h and target w, h on positives.
                NOTE(review): the target w/h encoding comes from
                ``create_yolo_targets``; confirm it matches raw outputs.
      * obj   - BCE-with-logits of objectness on positives.
      * noobj - BCE-with-logits of objectness on all other positions.
      * cls   - BCE-with-logits of class logits on positives.

    Returns:
        dict with ``total_loss`` =
        ``(coord_weight * (xy + wh) + obj + noobj_weight * noobj + cls) / batch``
        and each individual term (``xy_loss`` ... ``cls_loss``) divided by the
        batch size. A term whose mask selects nothing is a zero scalar on
        ``predictions.device``.
    """
    n_images = predictions.shape[0]
    device = predictions.device

    positive = targets[:, :, 4:5] > 0
    negative = ~positive
    any_positive = bool(positive.any())
    any_negative = bool(negative.any())

    def summed(loss_fn, pred, target, mask, active):
        # Sum-reduced loss over the masked elements, or a zero scalar when inactive.
        if not active:
            return torch.tensor(0.0, device=device)
        return loss_fn(pred[mask], target[mask], reduction='sum')

    xy_pred = torch.sigmoid(predictions[:, :, 0:2])
    wh_pred = predictions[:, :, 2:4]
    obj_logits = predictions[:, :, 4:5]
    cls_logits = predictions[:, :, 5:]

    coord_mask = positive.expand_as(xy_pred)
    xy_loss = summed(F.mse_loss, xy_pred, targets[:, :, 0:2], coord_mask, any_positive)
    wh_loss = summed(F.mse_loss, wh_pred, targets[:, :, 2:4], coord_mask, any_positive)

    obj_target = targets[:, :, 4:5]
    obj_loss = summed(F.binary_cross_entropy_with_logits, obj_logits, obj_target, positive, any_positive)
    noobj_loss = summed(F.binary_cross_entropy_with_logits, obj_logits, obj_target, negative, any_negative)

    cls_mask = positive.expand_as(cls_logits) if any_positive else None
    cls_loss = summed(F.binary_cross_entropy_with_logits, cls_logits, targets[:, :, 5:], cls_mask, any_positive)

    total_loss = (
        coord_weight * (xy_loss + wh_loss) +
        obj_loss +
        noobj_weight * noobj_loss +
        cls_loss
    ) / n_images

    per_term = (
        ('xy_loss', xy_loss),
        ('wh_loss', wh_loss),
        ('obj_loss', obj_loss),
        ('noobj_loss', noobj_loss),
        ('cls_loss', cls_loss),
    )
    result = {'total_loss': total_loss}
    for name, value in per_term:
        result[name] = value / n_images
    return result

In [None]:
def train_epoch(model, anchors, train_loader, optimizer, scheduler, device, epoch, partial_unfreeze_layers=2):
    """Run one training epoch and return the per-batch mean of each loss term.

    Calls ``model.update_epoch`` first so the backbone is unfrozen on schedule.
    Then, for every batch, it:
      * stacks the image list into a tensor,
      * builds YOLO targets from ``anchors``,
      * backpropagates ``total_loss``,
      * clips the gradient norm to 1.0,
      * steps the optimizer.
    ``scheduler.step()`` is called once, after the last batch.

    Args:
        model: a ``ResNetYOLODetector`` (needs ``update_epoch``).
        anchors: anchor (w, h) values forwarded to ``create_yolo_targets``.
        train_loader: yields ``(list_of_image_tensors, targets)``.
        optimizer, scheduler: stepped per batch / per epoch respectively.
        device: device the images and targets are moved to.
        epoch: 1-based epoch number, shown in the progress bar.
        partial_unfreeze_layers: forwarded to ``model.update_epoch``
            (``-1`` = whole backbone). Previously hard-coded to 2.

    Returns:
        dict with keys 'total', 'xy', 'wh', 'obj', 'noobj', 'cls'. All values
        are 0.0 when ``train_loader`` yields no batches, instead of raising
        ZeroDivisionError.
    """
    model.train()
    model.update_epoch(epoch, partial_unfreeze_layers=partial_unfreeze_layers)

    loss_keys = ('total', 'xy', 'wh', 'obj', 'noobj', 'cls')
    running_losses = dict.fromkeys(loss_keys, 0.0)
    num_batches = 0
    pbar = tqdm(train_loader, desc=f'Epoch {epoch}')

    for batch_idx, (images, targets) in enumerate(pbar):
        images = torch.stack(images).to(device)
        yolo_targets = create_yolo_targets(targets, anchors).to(device)

        optimizer.zero_grad()
        outputs = model(images)

        losses = resnet_yolo_detector_loss(outputs, yolo_targets)
        losses['total_loss'].backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        num_batches += 1
        for key in loss_keys:
            # The loss dict names every term '<key>_loss' ('total' -> 'total_loss').
            loss_value = losses.get(f'{key}_loss')
            if loss_value is not None:
                running_losses[key] += loss_value.item()

        if batch_idx % 10 == 0:
            avg_loss = running_losses['total'] / (batch_idx + 1)
            current_lr = optimizer.param_groups[0]['lr']
            pbar.set_postfix({'loss': f'{avg_loss:.4f}', 'lr': f'{current_lr:.2e}'})

    scheduler.step()

    denominator = max(num_batches, 1)
    return {key: running_losses[key] / denominator for key in loss_keys}

def validate_epoch(model, anchors, val_loader, device):
    """Evaluate ``model`` on ``val_loader`` without gradients.

    Args:
        model: detector whose output ``resnet_yolo_detector_loss`` accepts.
        anchors: anchor (w, h) values forwarded to ``create_yolo_targets``.
        val_loader: yields ``(list_of_image_tensors, targets)``.
        device: device the images and targets are moved to.

    Returns:
        dict with keys 'total', 'xy', 'wh', 'obj', 'noobj', 'cls' holding the
        per-batch mean of each loss term. All values are 0.0 when ``val_loader``
        yields no batches, instead of raising ZeroDivisionError.
    """
    model.eval()

    loss_keys = ('total', 'xy', 'wh', 'obj', 'noobj', 'cls')
    running_losses = dict.fromkeys(loss_keys, 0.0)
    num_batches = 0

    with torch.no_grad():
        for images, targets in tqdm(val_loader, desc='Validation'):
            images = torch.stack(images).to(device)
            yolo_targets = create_yolo_targets(targets, anchors).to(device)

            outputs = model(images)
            losses = resnet_yolo_detector_loss(outputs, yolo_targets)

            num_batches += 1
            for key in loss_keys:
                # The loss dict names every term '<key>_loss' ('total' -> 'total_loss').
                loss_value = losses.get(f'{key}_loss')
                if loss_value is not None:
                    running_losses[key] += loss_value.item()

    denominator = max(num_batches, 1)
    return {key: running_losses[key] / denominator for key in loss_keys}

In [None]:
def save_checkpoint(model, optimizer, scheduler, epoch, loss, filepath):
    """Serialize the training state to ``filepath`` with ``torch.save``.

    The saved dict holds ``epoch``, ``loss`` and the ``state_dict()`` of the
    model, optimizer and scheduler under ``'<name>_state_dict'`` keys. This is
    the layout that ``load_checkpoint`` and ``load_model_for_inference`` read.
    """
    payload = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        scheduler_state_dict=scheduler.state_dict(),
        loss=loss,
    )
    torch.save(payload, filepath)

def load_checkpoint(model, optimizer, scheduler, filepath, map_location='cpu'):
    """Restore training state written by ``save_checkpoint``.

    Loads the model and optimizer state. The scheduler state is restored only
    when a scheduler is given and the checkpoint contains one.

    Args:
        model: module whose ``load_state_dict`` receives ``'model_state_dict'``.
        optimizer: optimizer whose ``load_state_dict`` receives
            ``'optimizer_state_dict'``.
        scheduler: LR scheduler to restore, or ``None`` to skip it.
        filepath: path of the checkpoint file.
        map_location: forwarded to ``torch.load``. It defaults to ``'cpu'`` so a
            checkpoint written from an MPS/CUDA run can be opened on any machine.
            ``load_state_dict`` then copies values into the existing parameters,
            on whatever device they already live.

    Returns:
        ``(epoch, loss)`` as stored in the checkpoint.
    """
    checkpoint = torch.load(filepath, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if scheduler is not None and 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    return checkpoint['epoch'], checkpoint['loss']

def plot_losses(train_losses, val_losses, save_path=None):
    """Plot train vs. validation curves for each loss term on a 2x3 grid.

    Args:
        train_losses, val_losses: dicts mapping 'total', 'xy', 'wh', 'obj',
            'noobj' and 'cls' to per-epoch value lists. The x axis length is
            taken from ``train_losses['total']``.
        save_path: if truthy, the figure is also saved there at 300 dpi.
    """
    import matplotlib.pyplot as plt

    panels = [
        ('total', 'Total Loss'),
        ('xy', 'XY Loss'),
        ('wh', 'WH Loss'),
        ('obj', 'Objectness Loss'),
        ('noobj', 'No-Object Loss'),
        ('cls', 'Classification Loss'),
    ]
    epoch_numbers = range(1, len(train_losses['total']) + 1)

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    fig.suptitle('Training and Validation Losses')

    # axes.flat walks the 2x3 grid row by row, matching the panel order.
    for ax, (loss_type, title) in zip(axes.flat, panels):
        ax.plot(epoch_numbers, train_losses[loss_type], 'b-', label='Train')
        ax.plot(epoch_numbers, val_losses[loss_type], 'r-', label='Validation')
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.legend()
        ax.grid(True)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

def main():
    """Train ResNetYOLODetector on the cat/dog split and plot the loss curves.

    Writes the checkpoint with the lowest validation total loss to
    ``checkpoints/<experiment_name>/best_model.pth``. The loss plot goes to
    ``training_curves.png`` in the same directory.

    NOTE(review): relies on the notebook globals ``train_loader``, ``val_loader``
    and ``anchors`` created by earlier cells. The ``batch_size`` and
    ``target_size`` entries below are not read; the loaders were already built
    with their own values.
    """
    config = {
        'experiment_name': 'resnet50_YOLO_improved_regularization_14',
        'batch_size': 16,
        'num_epochs': 25,
        # Fall back to CPU so the cell still runs where Apple MPS is unavailable.
        'device': 'mps' if torch.backends.mps.is_available() else 'cpu',
        'target_size': (224, 224),
        'lr': 3e-4,
        'weight_decay': 1e-2,
        'dropout_rate': 0.15,
    }

    print(f"Using device: {config['device']}")
    device = torch.device(config['device'])

    save_dir = Path("checkpoints") / config['experiment_name']
    save_dir.mkdir(parents=True, exist_ok=True)

    print(f"Training samples: {len(train_loader.dataset)}")
    print(f"Validation samples: {len(val_loader.dataset)}")

    model = ResNetYOLODetector(
        anchor_boxes=anchors,
        backbone_name="resnet50",
        freeze_backbone_epochs=12,
        grid_size=14,
        dropout_rate=config['dropout_rate']
    ).to(device)

    info = model.get_model_info()
    print("Model Information:")
    for key, value in info.items():
        # bool is a subclass of int; print it as True/False rather than "1".
        if isinstance(value, int) and not isinstance(value, bool):
            print(f"  {key}: {value:,}")
        else:
            print(f"  {key}: {value}")

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config['lr'],
        weight_decay=config['weight_decay']
    )

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=config['num_epochs'],
        eta_min=config['lr'] * 0.01
    )

    loss_keys = ('total', 'xy', 'wh', 'obj', 'noobj', 'cls')
    train_history = {key: [] for key in loss_keys}
    val_history = {key: [] for key in loss_keys}

    best_val_loss = float('inf')

    for epoch in range(1, config['num_epochs'] + 1):
        print(f"\nEpoch {epoch}/{config['num_epochs']}")
        print("-" * 50)

        train_losses = train_epoch(model, anchors, train_loader, optimizer, scheduler, device, epoch)
        val_losses = validate_epoch(model, anchors, val_loader, device)

        for split_label, split_losses in (("Train", train_losses), ("Val  ", val_losses)):
            print(f"{split_label} - Total: {split_losses['total']:.4f}, "
                  f"XY: {split_losses['xy']:.4f}, "
                  f"WH: {split_losses['wh']:.4f}, "
                  f"Obj: {split_losses['obj']:.4f}, "
                  f"NoObj: {split_losses['noobj']:.4f}, "
                  f"Cls: {split_losses['cls']:.4f}")

        print(f"Current LR: {optimizer.param_groups[0]['lr']:.2e}")

        for key in train_history:
            train_history[key].append(train_losses[key])
            val_history[key].append(val_losses[key])

        if val_losses['total'] < best_val_loss:
            best_val_loss = val_losses['total']
            save_checkpoint(
                model, optimizer, scheduler, epoch, val_losses['total'],
                save_dir / 'best_model.pth'
            )
            print(f"Saved best model with validation loss: {best_val_loss:.4f}")

    plot_losses(train_history, val_history,
                save_path=save_dir / 'training_curves.png')

if __name__ == "__main__":
    main()

Using device: mps
Training samples: 4437
Validation samples: 783


Using cache found in /Users/tobysmith/.cache/torch/hub/pytorch_vision_v0.10.0


Backbone resnet50 configured for 14x14 grid
Output channels: 2048
Approximate backbone parameters: 23.5M
Model Information:
  total_parameters: 46,601,301
  trainable_parameters: 23,093,269
  backbone_parameters: 23,508,032
  adapter_parameters: 13,643,776
  head_parameters: 9,449,493
  backbone_frozen: 1
  current_epoch: 0

Epoch 1/25
--------------------------------------------------


Epoch 1:  49%|████▊     | 135/278 [01:06<01:10,  2.04it/s, loss=354.9033, lr=3.00e-04]


KeyboardInterrupt: 

In [5]:
def load_model_for_inference(model_path, device='mps', backbone_name="resnet50", grid_size=14):
    """Rebuild a ResNetYOLODetector from a checkpoint and put it in eval mode.

    The anchors are taken from the checkpoint's own ``'anchors'`` buffer, so
    the model is constructed with the number of anchors it was trained with.
    A fixed 3-anchor placeholder would fail to load any other anchor count.

    Args:
        model_path: checkpoint written by ``save_checkpoint``.
        device: device to load the checkpoint and model onto.
        backbone_name: must match the backbone used for training.
        grid_size: must match the grid size used for training.

    Returns:
        The detector on ``device`` in eval mode.
    """
    checkpoint = torch.load(model_path, map_location=device)
    state_dict = checkpoint['model_state_dict']

    model = ResNetYOLODetector(
        anchor_boxes=state_dict['anchors'].detach().cpu().clone(),
        backbone_name=backbone_name,
        freeze_backbone_epochs=12,
        grid_size=grid_size
    ).to(device)

    model.load_state_dict(state_dict)
    model.eval()

    print(f"Loaded model from {model_path}")
    print(f"Model was trained for {checkpoint['epoch']} epochs")
    print(f"Best validation loss: {checkpoint['loss']:.4f}")
    print(f"Loaded anchors from model: {model.anchors}")

    return model

def inference_on_images(model, image_paths, device='mps', conf_threshold=0.5, target_size=(224, 224)):
    """Run the detector on one or more image files and display annotated results.

    For each image:
      * resize it to ``target_size`` (``(height, width)``, the order
        ``torchvision.transforms.Resize`` expects),
      * ImageNet-normalise it and run ``model``,
      * decode the raw output with ``decode_yolo_predictions`` into
        ``(x1, y1, x2, y2, conf, cls)`` rows,
      * rescale the boxes to the original image size and draw them with labels.
    Images are laid out on a grid of up to three columns, and per-image counts
    are printed.

    NOTE(review): assumes the decoded boxes are pixel coordinates in the resized
    ``target_size`` frame. Also assumes class id 0 is 'dog' and 1 is 'cat'.
    Confirm both against ``decode_yolo_predictions`` and the dataloader's
    label mapping.

    Args:
        model: detector exposing an ``anchors`` buffer (e.g. ResNetYOLODetector).
        image_paths: a single path (str / path-like) or a list of paths. An
            empty list returns without plotting.
        device: device the input tensors are moved to (should match ``model``).
        conf_threshold: forwarded to ``decode_yolo_predictions``.
        target_size: network input size as ``(height, width)``.
    """
    import os
    from PIL import Image, ImageDraw, ImageFont
    import torchvision.transforms as transforms

    anchors = model.anchors

    if isinstance(image_paths, (str, os.PathLike)):
        image_paths = [image_paths]
    if len(image_paths) == 0:
        return

    target_h, target_w = target_size
    transform = transforms.Compose([
        transforms.Resize(target_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    class_names = {0: 'dog', 1: 'cat'}
    colors = {0: 'red', 1: 'blue'}

    # Load the label font once rather than once per detection.
    try:
        font = ImageFont.truetype("/System/Library/Fonts/Arial.ttf", 16)
    except OSError:
        font = ImageFont.load_default()

    num_images = len(image_paths)
    cols = min(3, num_images)
    rows = (num_images + cols - 1) // cols

    # squeeze=False always yields a 2-D array, so a single flatten covers every layout.
    _, axes = plt.subplots(rows, cols, figsize=(5*cols, 5*rows), squeeze=False)
    axes = axes.flatten()

    model.eval()
    with torch.no_grad():
        for idx, image_path in enumerate(image_paths):
            original_image = Image.open(image_path).convert('RGB')
            orig_w, orig_h = original_image.size  # PIL reports (width, height)

            input_tensor = transform(original_image).unsqueeze(0).to(device)
            predictions = model(input_tensor)
            detections = decode_yolo_predictions(predictions, anchors, conf_threshold=conf_threshold)
            image_detections = detections[0] if len(detections) > 0 else []

            # Map boxes from the resized (target_h, target_w) frame back to the original image.
            scale_x = orig_w / target_w
            scale_y = orig_h / target_h

            draw_image = original_image.copy()
            draw = ImageDraw.Draw(draw_image)

            detection_count = {'cat': 0, 'dog': 0}

            for detection in image_detections:
                x1, y1, x2, y2, conf, cls = detection

                x1 = int(x1 * scale_x)
                y1 = int(y1 * scale_y)
                x2 = int(x2 * scale_x)
                y2 = int(y2 * scale_y)

                class_id = int(cls)
                class_name = class_names[class_id]
                color = colors[class_id]
                detection_count[class_name] += 1

                draw.rectangle([x1, y1, x2, y2], outline=color, width=3)

                label = f'{class_name}: {conf:.2f}'
                bbox = draw.textbbox((0, 0), label, font=font)
                text_width = bbox[2] - bbox[0]
                text_height = bbox[3] - bbox[1]

                # Place the label above the box, but keep it inside the image at the top edge.
                label_top = max(y1 - text_height - 4, 0)
                draw.rectangle([x1, label_top, x1 + text_width + 4, label_top + text_height + 4], fill=color)
                draw.text((x1 + 2, label_top + 2), label, fill='white', font=font)

            ax = axes[idx]
            ax.imshow(draw_image)
            ax.set_title(f'Image {idx+1}\nCats: {detection_count["cat"]}, Dogs: {detection_count["dog"]}')
            ax.axis('off')

            print(f"Image {idx+1}: {image_path}")
            print(f"  Detections: {len(image_detections)}")
            print(f"  Cats: {detection_count['cat']}, Dogs: {detection_count['dog']}")

    # Hide any unused panels in the last row.
    for unused_ax in axes[num_images:]:
        unused_ax.axis('off')

    plt.tight_layout()
    plt.show()

# Example usage:
# Load your trained model (anchors are read from the checkpoint's model state).
# NOTE(review): this path points at experiment 'resnet50_YOLO_residual_adapter_14',
# while main() above saves to 'checkpoints/resnet50_YOLO_improved_regularization_14/'
# — point it at the run you actually want to evaluate.
model_path = "checkpoints/resnet50_YOLO_residual_adapter_14/best_model.pth"
# Rebinds the notebook-level `device` name to the string 'mps'.
device = 'mps'

# Load the model - anchors are now loaded automatically from the saved model
# model = load_model_for_inference(model_path, device)

# Test on single image
# inference_on_images(model, "path/to/your/test/image.jpg", device)

# Test on multiple images
# image_paths = ["path/to/image1.jpg", "path/to/image2.jpg", "path/to/image3.jpg"]
# inference_on_images(model, image_paths, device)