In [1]:
import os

import torch
import torchvision
from pycocotools.coco import COCO

class COCODataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(image, target)`` pairs from a COCO annotation file.

    ``target`` is a dict with:
      * ``boxes``: float32 tensor of shape ``(N, 4)`` in ``(x1, y1, x2, y2)`` form
        (``(0, 4)`` when the image has no valid boxes),
      * ``labels``: int64 tensor of shape ``(N,)`` holding the raw COCO ``category_id``,
      * ``image_id``: one-element tensor with the COCO image id.

    Args:
        annotation_file: Path to the COCO-format JSON annotation file.
        image_dir: Directory that the annotations' ``file_name`` entries are relative to.
        transforms: Optional callable applied to the loaded RGB image.
    """

    def __init__(self, annotation_file, image_dir, transforms=None):
        self.coco = COCO(annotation_file)
        self.image_dir = image_dir
        self.transforms = transforms
        self.image_ids = list(self.coco.imgs.keys())

    def __len__(self):
        return len(self.image_ids)

    def _load_image(self, path):
        """Open ``path`` as an RGB PIL image, closing the file handle afterwards."""
        # torchvision's PIL loader is reachable through the existing `import torchvision`;
        # the bare `Image.open` this replaces referenced a name that was never imported.
        return torchvision.datasets.folder.pil_loader(path)

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        image_info = self.coco.loadImgs(image_id)[0]
        annotations = self.coco.loadAnns(self.coco.getAnnIds(imgIds=image_id))

        image = self._load_image(os.path.join(self.image_dir, image_info["file_name"]))

        # COCO boxes are (x, y, width, height); convert to corner form and skip
        # degenerate boxes whose width or height is not positive.
        boxes = []
        labels = []
        for ann in annotations:
            x, y, width, height = ann["bbox"]
            if width > 0 and height > 0:
                boxes.append([x, y, x + width, y + height])
                labels.append(ann["category_id"])

        target = {
            # reshape keeps a (0, 4) shape when no valid box survived the filter
            "boxes": torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            "labels": torch.as_tensor(labels, dtype=torch.int64),
            "image_id": torch.tensor([image_id]),
        }

        if self.transforms is not None:
            image = self.transforms(image)

        return image, target


# Validation split: COCO-format annotation JSON plus the folder holding its images.
val_image_dir = 'archive/dataset/dataset/valid/images'
val_annotation_file = 'archive/valid_annotations1.json'

from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# 3 object classes (pothole, cracks, open_manhole) plus the background class.
num_classes = 3 + 1

# Build Faster R-CNN from its default pretrained weights, then swap the box
# predictor for one whose output size matches num_classes.
weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
model = fasterrcnn_resnet50_fpn(weights=weights)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

def get_transform():
    """Return the image transform used for validation.

    The pipeline consists of a single ``ToTensor`` step wrapped in a ``Compose``.
    """
    to_tensor = torchvision.transforms.ToTensor()
    return torchvision.transforms.Compose([to_tensor])

# Validation dataset over the split configured above, converted to tensors.
val_dataset = COCODataset(
    annotation_file=val_annotation_file,
    image_dir=val_image_dir,
    transforms=get_transform(),
)

loading annotations into memory...
Done (t=0.01s)
creating index...
index created!


In [None]:
import os
import glob
import torch
from torch.utils.data import DataLoader

# `model` and `val_dataset` are built in the previous cell; run it first.


def collate_detection(batch):
    """Turn a list of ``(image, target)`` pairs into an ``(images, targets)`` pair of tuples.

    Detection targets hold a variable number of boxes per image, which the
    default collate function cannot stack into one tensor.
    """
    return tuple(zip(*batch))


# Create DataLoader for validation data.
# NOTE(review): with num_workers > 0 on platforms that spawn worker processes
# (Windows/macOS), objects defined in a notebook may fail to pickle; use
# num_workers=0 there if the loader errors.
val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=4,
    collate_fn=collate_detection,
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Evaluation: score detections against ground-truth boxes.
def _box_iou(boxes_a, boxes_b):
    """Pairwise intersection-over-union of two sets of ``(x1, y1, x2, y2)`` boxes.

    Args:
        boxes_a: Tensor of shape ``(N, 4)``.
        boxes_b: Tensor of shape ``(M, 4)``.

    Returns:
        Tensor of shape ``(N, M)`` with the IoU of every pair.
    """
    area_a = (boxes_a[:, 2] - boxes_a[:, 0]).clamp(min=0) * (boxes_a[:, 3] - boxes_a[:, 1]).clamp(min=0)
    area_b = (boxes_b[:, 2] - boxes_b[:, 0]).clamp(min=0) * (boxes_b[:, 3] - boxes_b[:, 1]).clamp(min=0)
    top_left = torch.max(boxes_a[:, None, :2], boxes_b[None, :, :2])
    bottom_right = torch.min(boxes_a[:, None, 2:], boxes_b[None, :, 2:])
    overlap = (bottom_right - top_left).clamp(min=0)
    intersection = overlap[..., 0] * overlap[..., 1]
    union = area_a[:, None] + area_b[None, :] - intersection
    # Clamp avoids 0/0 for pairs of zero-area boxes.
    return intersection / union.clamp(min=1e-12)


def _match_detections(prediction, target, iou_threshold, score_threshold):
    """Greedily match one image's predictions to its ground-truth boxes.

    A kept prediction (score >= ``score_threshold``) is a true positive when it
    has IoU >= ``iou_threshold`` with a still-unmatched ground-truth box of the
    same label. Each ground-truth box is matched at most once.

    Returns:
        ``(true_positives, num_kept_predictions, num_ground_truth)``.
    """
    keep = prediction["scores"] >= score_threshold
    # Highest-confidence predictions claim ground-truth boxes first.
    order = prediction["scores"][keep].argsort(descending=True)
    pred_boxes = prediction["boxes"][keep][order]
    pred_labels = prediction["labels"][keep][order]
    gt_boxes = target["boxes"]
    gt_labels = target["labels"]

    true_positives = 0
    if len(pred_boxes) > 0 and len(gt_boxes) > 0:
        ious = _box_iou(pred_boxes, gt_boxes)
        # A prediction may only match a ground-truth box of the same class.
        ious[pred_labels[:, None] != gt_labels[None, :]] = -1.0
        matched = torch.zeros(len(gt_boxes), dtype=torch.bool, device=ious.device)
        for candidate_ious in ious:
            best_iou, best_index = candidate_ious.masked_fill(matched, -1.0).max(dim=0)
            if best_iou.item() >= iou_threshold:
                matched[best_index] = True
                true_positives += 1
    return true_positives, len(pred_boxes), len(gt_boxes)


def evaluate(model, loader, device, iou_threshold=0.5, score_threshold=0.5):
    """Score a detection model on ``loader`` as an F1 over matched boxes.

    Predictions scoring below ``score_threshold`` are ignored. The rest are
    greedily matched to same-class ground-truth boxes with IoU >= ``iou_threshold``.
    The result is ``2 * TP / (num_kept_predictions + num_ground_truth)`` (the F1 of
    box precision and recall), so it lies in ``[0, 1]`` and higher is better.

    Args:
        model: Detector that, in eval mode, takes a list of image tensors and
            returns one dict per image with ``boxes``, ``labels`` and ``scores``.
        loader: Iterable of ``(images, targets)`` batches, where ``images`` is a
            sequence of image tensors and ``targets`` a sequence of dicts with
            ``boxes`` and ``labels``. An example is a DataLoader whose collate
            function is ``tuple(zip(*batch))``.
        device: Device the images are moved to before inference.
        iou_threshold: Minimum IoU for a prediction to count as a true positive.
        score_threshold: Minimum confidence for a prediction to be counted.

    Returns:
        The F1 score as a float. Returns ``0.0`` when there were neither kept
        predictions nor ground-truth boxes.
    """
    model.eval()
    true_positives = 0
    num_predictions = 0
    num_ground_truth = 0
    with torch.no_grad():
        for images, targets in loader:
            images = [image.to(device) for image in images]
            outputs = model(images)
            for output, target in zip(outputs, targets):
                # Match on the CPU so predictions and targets share a device.
                prediction = {key: output[key].cpu() for key in ("boxes", "labels", "scores")}
                truth = {key: target[key].cpu() for key in ("boxes", "labels")}
                tp, n_pred, n_gt = _match_detections(prediction, truth, iou_threshold, score_threshold)
                true_positives += tp
                num_predictions += n_pred
                num_ground_truth += n_gt

    denominator = num_predictions + num_ground_truth
    return 2 * true_positives / denominator if denominator > 0 else 0.0

# Candidate checkpoints: every *.pth file directly inside checkpoint_dir, in name order.
checkpoint_dir = 'ckpt1'
checkpoint_paths = glob.glob(os.path.join(checkpoint_dir, '*.pth'))
checkpoint_paths.sort()

if len(checkpoint_paths) == 0:
    raise FileNotFoundError(f"No .pth files found in {checkpoint_dir}")

def _extract_state_dict(checkpoint):
    """Return the model weights stored in a loaded checkpoint.

    Accepts a bare state dict, a dict wrapping one under ``state_dict``,
    ``model_state_dict`` or ``model``, or a whole pickled ``nn.Module``.
    """
    if isinstance(checkpoint, torch.nn.Module):
        return checkpoint.state_dict()
    for key in ('state_dict', 'model_state_dict', 'model'):
        if key in checkpoint:
            wrapped = checkpoint[key]
            return wrapped.state_dict() if isinstance(wrapped, torch.nn.Module) else wrapped
    return checkpoint


# Main evaluation loop
if __name__ == '__main__':
    model = model.to(device)

    # Start below any possible score so the first checkpoint is always recorded;
    # with 0.0 and a strict `>`, all-zero scores left best_ckpt as None.
    best_acc = float('-inf')
    best_ckpt = None

    for ckpt_path in checkpoint_paths:
        # torch.load unpickles the file: only load checkpoints from trusted sources.
        checkpoint = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(_extract_state_dict(checkpoint))

        acc = evaluate(model, val_loader, device)
        print(f"Checkpoint {os.path.basename(ckpt_path)}: Accuracy = {acc:.4f}")

        if acc > best_acc:
            best_acc = acc
            best_ckpt = ckpt_path

    if best_ckpt is None:
        print("No checkpoints were evaluated.")
    else:
        print(f"\nBest checkpoint: {os.path.basename(best_ckpt)} with Accuracy = {best_acc:.4f}")
