In [1]:
import os
import cv2
import yaml
import random
import torch
import torchvision
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import CocoDetection
from torchvision.models.detection import maskrcnn_resnet50_fpn, maskrcnn_resnet50_fpn_v2
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
import torchvision.transforms as transforms
from torchvision.transforms import functional as F 
import numpy as np
from collections import defaultdict
from tqdm import tqdm
from pycocotools.coco import COCO
from pathlib import Path
import matplotlib.pyplot as plt
import pandas as pd
from coco_eval import CocoEvaluator
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from PIL import Image, ImageDraw

In [2]:
# Number of passes over the training loader made by the training loop.
NUM_EPOCHS = 100
# Worker processes handed to both DataLoaders.
NUM_WORKERS = 2
# Samples per batch for both the training and the validation loader.
BATCH_SIZE = 4

# Directory the training loop writes metrics_eval.csv / metrics_train.csv into.
OUTPUT_DIR = "out/pt_maskrcnn"

In [3]:
# Prefer the GPU whenever CUDA is usable; fall back to the CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device {device}")

Using device cuda


In [4]:
def get_model_instance_segmentation(num_classes):
    """Build a Mask R-CNN (ResNet-50 FPN v2) with heads resized for ``num_classes``.

    The network is created with pretrained weights, then its box classifier
    and mask predictor are replaced so they output ``num_classes`` classes.

    Args:
        num_classes (int): Number of output classes, background included.

    Returns:
        The model with freshly initialised box and mask heads.
    """
    # `pretrained=True` is the deprecated legacy flag; the weights API is the
    # supported way to request the default pretrained checkpoint.
    model = maskrcnn_resnet50_fpn_v2(weights="DEFAULT")

    # Replace the box classifier with the desired number of classes
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Replace the mask predictor
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)

    return model

In [5]:
class CustomDataset(torch.utils.data.Dataset):
    """Instance-segmentation dataset described by a YAML split file.

    The YAML file must contain a ``path`` entry plus one entry per split. The
    resolved split directory is expected to hold ``images/`` and ``labels/``
    sub-directories. Every label file stores one instance per line:
    ``class_id``, four bounding-box values, then the polygon as alternating
    normalised ``x y`` coordinates.

    NOTE(review): lines are assumed to be whitespace-separated and label files
    are assumed to share their image's file stem — confirm against the data.
    """

    def __init__(self, dataset_file, root, split, transforms=None):
        """Read the YAML description and pair every image with its label file.

        Args:
            dataset_file (str): Path to the YAML dataset description.
            root (str): Directory that the YAML ``path`` entry is relative to.
            split (str): Key of the split to load (must exist in the YAML).
            transforms (callable, optional): Applied to the RGB image array.

        Raises:
            ValueError: If ``split`` is not a key of the YAML file.
        """
        self.dataset_file = dataset_file
        self.root = root
        with open(dataset_file) as f:
            self.dataset = yaml.safe_load(f)
        if split not in self.dataset:
            raise ValueError(f"Split not defined in {dataset_file}")
        self.split = split
        self.transforms = transforms
        self.data_dir = os.path.join(root, self.dataset["path"], self.dataset[split])
        # Only step up when the split path itself points at the images folder;
        # a substring test would also fire on names that merely contain "images".
        if Path(self.data_dir).name == "images":
            self.data_dir = Path(self.data_dir).parent

        image_dir = os.path.join(self.data_dir, "images")
        label_dir = os.path.join(self.data_dir, "labels")
        labels_by_stem = {Path(name).stem: name for name in os.listdir(label_dir)}

        # os.listdir order is arbitrary, so images and labels are matched by
        # stem rather than by position. Images without a label file are skipped.
        self.image_files = []
        self.label_files = []
        for image_name in sorted(os.listdir(image_dir)):
            label_name = labels_by_stem.get(Path(image_name).stem)
            if label_name is not None:
                self.image_files.append(image_name)
                self.label_files.append(label_name)

    def __getitem__(self, index):
        """Return ``(image, target)`` for the sample at ``index``.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image file.
        """
        image_path = os.path.join(self.data_dir, "images", self.image_files[index])
        img = cv2.imread(image_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image {image_path}")
        target = self.load_annotations(os.path.join(self.data_dir, "labels", self.label_files[index]), img.shape[:2], index)
        # OpenCV decodes to BGR; convert to the RGB order the rest of the
        # pipeline (PIL / torchvision) works with.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if self.transforms is not None:
            img = self.transforms(img)
        return img, target

    def __len__(self):
        """Number of image/label pairs."""
        return len(self.image_files)

    @staticmethod
    def _parse_label_line(line):
        """Split one label line into ``(class_id, bbox, polygon)``."""
        fields = line.split()
        if len(fields) < 5:
            raise ValueError(f"Malformed label line: {line!r}")
        class_id = int(float(fields[0]))
        bbox = np.array([float(value) for value in fields[1:5]], dtype=np.float64)
        poly = np.array([float(value) for value in fields[5:]], dtype=np.float64)
        if poly.size % 2:
            raise ValueError(f"Polygon has an odd number of coordinates: {line!r}")
        return class_id, bbox, poly

    @staticmethod
    def _polygon_to_pixels(poly, shape):
        """Scale a flat normalised ``x y x y ...`` polygon to int32 pixel points.

        ``shape`` is ``(height, width)``; the result has shape ``(K, 1, 2)``.
        """
        height, width = shape
        points = np.array(poly, dtype=np.float64).reshape(-1, 2)
        points[:, 0] *= width   # x runs along the columns
        points[:, 1] *= height  # y runs along the rows
        return points.astype(np.int32).reshape(-1, 1, 2)

    def load_annotations(self, file, shape, idx):
        """Parse a label file into a torchvision-style detection target.

        Args:
            file (str): Path to the label file.
            shape (tuple): ``(height, width)`` of the matching image.
            idx (int): Dataset index, stored as the image id.

        Returns:
            dict: ``boxes`` (N, 4), ``labels`` (N,), ``masks`` (N, H, W) uint8,
            ``image_id`` (1,), ``area`` (N,) and ``iscrowd`` (N,). ``area`` is
            the mask's pixel count divided by the image area.
        """
        height, width = shape
        with open(file) as f:
            # Drop blank lines (e.g. the one produced by a trailing newline).
            lines = [line for line in f.read().splitlines() if line.strip()]

        image_area = height * width
        boxes = []
        labels = []
        masks = []
        areas = []
        iscrowds = []
        for line in lines:
            class_id, bbox, poly = self._parse_label_line(line)

            mask = np.zeros((height, width), dtype=np.uint8)
            if poly.size:
                # fillPoly takes a *list* of polygons; a bare (K, 1, 2) array
                # would be read as K separate one-point polygons.
                cv2.fillPoly(mask, [self._polygon_to_pixels(poly, shape)], 1)
            pixels = cv2.countNonZero(mask)

            # NOTE(review): torchvision detection models reserve label 0 for
            # background; if the files use 0-based class ids they need +1.
            labels.append(class_id)
            # NOTE(review): boxes are passed through as stored. The models
            # expect absolute [x1, y1, x2, y2] pixels — confirm the file format.
            boxes.append(bbox)
            masks.append(mask)
            areas.append(pixels / image_area)
            iscrowds.append(False)

        if masks:
            mask_tensor = torch.from_numpy(np.stack(masks))
        else:
            mask_tensor = torch.zeros((0, height, width), dtype=torch.uint8)

        target = {
            "boxes": torch.tensor(np.array(boxes), dtype=torch.float32).reshape(-1, 4),
            "labels": torch.tensor(labels, dtype=torch.int64),
            "masks": mask_tensor,  # Shape: (N, H, W)
            # One stable id per image, instead of a random per-instance value.
            "image_id": torch.tensor([idx]),
            "area": torch.tensor(areas, dtype=torch.float32),
            "iscrowd": torch.tensor(iscrowds, dtype=torch.int64)
        }

        return target

In [6]:
class CocoDataset(CocoDetection):
    """COCO dataset yielding ``(image, target)`` pairs for Mask R-CNN.

    Args:
        root (str): Directory containing the images.
        annFile (str): Path to the COCO annotation JSON.
        transforms (callable, optional): Single-argument transform applied to
            the image only.
    """

    def __init__(self, root, annFile, transforms=None):
        super(CocoDataset, self).__init__(root, annFile)
        # Kept under its own name: the parent class reserves ``self.transforms``
        # for a joint (image, target) transform that it calls itself.
        self.image_transforms = transforms

    def __getitem__(self, idx):
        """Return the image and a target dict of boxes, labels, masks, image_id."""
        img, annotations = super(CocoDataset, self).__getitem__(idx)

        # Extract image ID
        image_id = self.ids[idx]

        boxes = []
        labels = []
        masks = []

        for annotation in annotations:
            # Each annotation is a COCO dict; its bbox is [x, y, width, height]
            # and is converted to the [x1, y1, x2, y2] corners form here.
            x, y, width, height = annotation["bbox"]
            boxes.append([x, y, x + width, y + height])
            labels.append(annotation["category_id"])
            masks.append(self.coco.annToMask(annotation))

        if masks:
            mask_tensor = torch.as_tensor(np.stack(masks), dtype=torch.uint8)
        else:
            # No instances: keep the (N, H, W) layout with N == 0.
            # assumes the parent returns a PIL image here — TODO confirm
            img_width, img_height = img.size
            mask_tensor = torch.zeros((0, img_height, img_width), dtype=torch.uint8)

        # Convert to tensors
        target_dict = {
            "boxes": torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            "labels": torch.tensor(labels, dtype=torch.int64),
            "masks": mask_tensor,
            "image_id": torch.tensor([image_id]),
        }

        # Apply the image transform if provided
        if self.image_transforms is not None:
            img = self.image_transforms(img)

        return img, target_dict

In [7]:
class CustomCocoDataset(CocoDetection):
    def __init__(self, root, annFile, transform=None):
        """
        Custom COCO dataset that loads images, bounding boxes, segmentation masks,
        and category labels for each instance in an image.

        Args:
            root (str): Directory with all the images.
            annFile (str): Path to the COCO annotation file.
            transform (callable, optional): Optional transform to be applied on an image.
        """
        super().__init__(root, annFile)
        self.transform = transform

    def __getitem__(self, index):
        """
        Returns an image along with its target annotations.

        Args:
            index (int): Index of the sample.

        Returns:
            image (torch.Tensor): Transformed image tensor.
            target (dict): Dictionary with keys:
                - 'boxes': float32 tensor of bounding boxes (N, 4) as [x1, y1, x2, y2].
                - 'labels': int64 tensor of category labels (N).
                - 'masks': uint8 tensor of binary segmentation masks (N, H, W).
                - 'image_id': Tensor holding the COCO id of the image.
                - 'area': Tensor with the annotated area of each instance.
                - 'iscrowd': Tensor indicating if the instance is a crowd (1) or not (0).
                - 'ids': Tensor with the COCO annotation id of each instance.

            When an image has no usable annotation, every per-instance tensor is
            empty (N == 0) but keeps its dtype and trailing dimensions.
        """
        # Load image and annotations
        img, annotations = super().__getitem__(index)

        # Convert image to tensor
        if self.transform:
            img = self.transform(img)
        else:
            img = F.to_tensor(img)  # Convert image to tensor
        height, width = img.shape[-2:]

        # Use the COCO image id (not the dataset index) so predictions can be
        # matched against the ground-truth annotation file during evaluation.
        image_id = torch.tensor([self.ids[index]], dtype=torch.int64)

        # Initialize lists to hold instance data
        boxes = []
        labels = []
        masks = []
        area = []
        iscrowd = []
        ids = []

        for annotation in annotations:
            # Bounding box in [x, y, width, height]
            x, y, box_width, box_height = annotation['bbox']
            if box_width <= 0 or box_height <= 0:
                # Boxes without positive width and height are rejected by the
                # detection models, so such instances are dropped here.
                continue
            boxes.append([x, y, x + box_width, y + box_height])
            labels.append(annotation['category_id'])
            area.append(annotation.get('area', box_width * box_height))
            iscrowd.append(annotation.get('iscrowd', 0))
            ids.append(annotation['id'])

            # Process segmentation
            masks.append(self._create_segmentation_mask(annotation, height, width))

        if masks:
            masks = torch.stack(masks, dim=0)
        else:
            masks = torch.zeros((0, height, width), dtype=torch.uint8)

        # Create the target dictionary; reshape keeps boxes (0, 4) when empty
        target = {
            'boxes': torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            'labels': torch.as_tensor(labels, dtype=torch.int64),
            'masks': masks,
            'image_id': image_id,
            'area': torch.as_tensor(area, dtype=torch.float32),
            'iscrowd': torch.as_tensor(iscrowd, dtype=torch.int64),
            'ids': torch.as_tensor(ids, dtype=torch.int64)
        }

        return img, target

    def _create_segmentation_mask(self, annotation, height, width):
        """
        Creates a binary mask for a given instance's segmentation data.

        Args:
            annotation (dict): COCO annotation dictionary containing segmentation data.
            height (int): Height of the image.
            width (int): Width of the image.

        Returns:
            torch.Tensor: uint8 mask of shape (height, width) with values 0 / 1.
        """
        segmentation = annotation['segmentation']
        if isinstance(segmentation, dict):
            # RLE format: let pycocotools decode it.
            # assumes the decoded size equals (height, width) — TODO confirm
            return torch.as_tensor(self.coco.annToMask(annotation), dtype=torch.uint8)

        mask = Image.new("L", (width, height), 0)  # Create a blank mask
        draw = ImageDraw.Draw(mask)
        for polygon in segmentation:
            # polygon is a flat list of coordinates: [x1, y1, x2, y2, ..., xn, yn]
            if isinstance(polygon, list) and len(polygon) >= 6:
                draw.polygon([float(value) for value in polygon], outline=1, fill=1)
        # Build the tensor from the raw 0/1 pixel values. F.to_tensor would
        # rescale them to 1/255 and prepend a channel dimension.
        return torch.from_numpy(np.array(mask, dtype=np.uint8))

In [None]:
# Build the train / validation datasets from the COCO-format annotation files.
# (YAML-driven alternative, currently unused:
#  CustomDataset(root="_data/datasets", dataset_file="_data/synthetic.yaml",
#                split="train" or "val", transforms=F.to_tensor))
input_transforms = transforms.Compose(
    [
        transforms.PILToTensor(),
        transforms.ConvertImageDtype(torch.float32),
    ]
)

train_dataset = CustomCocoDataset(
    "_data/datasets/synthetic_leaf_instances/train/images",
    "_data/coco_synthetic_train.json",
    transform=input_transforms,
)
val_dataset = CustomCocoDataset(
    "_data/datasets/synthetic_leaf_instances/val/images",
    "_data/coco_synthetic_val.json",
    transform=input_transforms,
)

loading annotations into memory...


In [None]:
# Inspect the first training sample: image shape, target keys and values,
# then show the image itself (channels moved last for matplotlib).
sample_img, sample_annos = train_dataset[0]
print(sample_img.shape)
print(sample_annos.keys())
print(sample_annos)
plt.imshow(sample_img.permute(1, 2, 0).numpy())
plt.show()

In [None]:
def _collate_detection_batch(batch):
    """Regroup [(image, target), ...] into (images, targets) without stacking."""
    return tuple(zip(*batch))


# Settings shared by both loaders; only the shuffling differs.
_loader_options = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    collate_fn=_collate_detection_batch,
)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_options)

In [None]:
# Sanity check: pull a single collated batch from the training loader and print it.
loader_sample = next(iter(train_loader))
print(loader_sample)

In [None]:
# Load the validation annotation file and show how many annotation ids it holds.
coco_test = COCO("_data/coco_synthetic_val.json")
len(coco_test.getAnnIds())

In [None]:
# Build the network and move it to the selected device.
# Two classes: background + 1 foreground class (adjust as needed).
num_classes = 2
model = get_model_instance_segmentation(num_classes).to(device)

In [None]:
# Optimizer over the trainable parameters only, plus a step-wise LR schedule.
params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005,
)
# NOTE(review): this divides the learning rate by 10 every 3 epochs; across
# NUM_EPOCHS = 100 it becomes vanishingly small early on — confirm intended.
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=3,
    gamma=0.1,
)

In [None]:
def convert_to_coco_format(detections, image_ids, category_mapping):
    """
    Converts model detections to a flat list of per-instance result dicts.

    Args:
        detections (list of dict): List of detection results where each dict contains:
            - 'boxes': Tensor of shape (N, 4) with bounding boxes in [x_min, y_min, x_max, y_max]
            - 'labels': Tensor of shape (N,) with category labels
            - 'scores': Tensor of shape (N,) with confidence scores
            - 'masks' (optional): Tensor whose first dimension indexes the N instances
        image_ids (list): List of image IDs corresponding to each detection.
        category_mapping (dict): Mapping from class indices to COCO category IDs.

    Returns:
        list: One dict per detected instance with 'image_id', 'category_id',
        'bbox' as [x, y, width, height], 'score' and, when the detection
        carries masks, 'mask' holding that instance's own mask as a numpy array.
    """
    # NOTE(review): the official COCO results format stores masks RLE-encoded
    # under 'segmentation'; the 'mask' key is kept for existing consumers.
    coco_results = []
    for i, det in enumerate(detections):
        image_id = int(image_ids[i])

        boxes = det['boxes'].cpu().numpy()  # Bounding boxes
        labels = det['labels'].cpu().numpy()  # Class labels
        scores = det['scores'].cpu().numpy()  # Confidence scores
        # Move the masks to the CPU once per image, not once per instance.
        masks = det.get('masks')
        if masks is not None:
            masks = masks.cpu().numpy()

        for j in range(len(boxes)):
            x_min, y_min, x_max, y_max = boxes[j]
            width = x_max - x_min
            height = y_max - y_min
            result = {
                'image_id': image_id,
                'category_id': int(category_mapping[int(labels[j])]),
                'bbox': [float(x_min), float(y_min), float(width), float(height)],
                'score': float(scores[j]),
            }
            if masks is not None:
                # Attach only this instance's mask, not the whole stack.
                result['mask'] = masks[j]
            coco_results.append(result)

    return coco_results

In [None]:
# Ground-truth annotations the validation predictions are scored against.
coco_gt = COCO("_data/coco_synthetic_val.json")
def evaluate(model, data_loader, device):
    """Run the model over ``data_loader`` and return the segmentation stats.

    Batches in which any target has no boxes are skipped.

    Args:
        model: Detection model; it is switched to eval mode.
        data_loader: Iterable yielding ``(images, targets)`` batches.
        device: Device the images are moved to before inference.

    Returns:
        The ``stats`` of the evaluator's "segm" entry after summarising.

    Raises:
        RuntimeError: If no batch was evaluated (empty loader, or every batch
            was skipped), since there is nothing to summarise.
    """
    model.eval()
    coco_evaluator = None

    with torch.no_grad():
        for images, targets in tqdm(data_loader):
            # Skip the batch as soon as one target carries no boxes.
            if any(len(target["boxes"]) == 0 for target in targets):
                continue
            images = [img.to(device) for img in images]
            outputs = model(images)

            # The evaluator is created lazily, on the first usable batch.
            if coco_evaluator is None:
                coco_evaluator = CocoEvaluator(coco_gt, iou_types=["bbox", "segm"])

            # NOTE(review): torchvision's reference CocoEvaluator.update takes
            # a {image_id: raw model output} dict, not a list of result dicts —
            # confirm what the local coco_eval module expects.
            batch_image_ids = [target['image_id'] for target in targets]
            coco_evaluator.update(convert_to_coco_format(outputs, image_ids=batch_image_ids, category_mapping={1:1}))

        if coco_evaluator is None:
            # Fail with a clear message instead of an AttributeError on None.
            raise RuntimeError("evaluate(): no batch was evaluated; the loader is empty or every batch was skipped")

        # Gather evaluation results
        coco_evaluator.synchronize_between_processes()
        coco_evaluator.accumulate()
        coco_evaluator.summarize()

        return coco_evaluator.coco_eval["segm"].stats  # return segmentation stats

In [None]:
def coco_eval_to_dict(coco_eval):
    """
    Converts COCO evaluation statistics to a labelled dictionary.

    Args:
        coco_eval: Either the sequence of 12 summary statistics (the form
            ``evaluate`` returns) or an object exposing that sequence as a
            ``stats`` attribute, such as a summarised COCOeval.

    Returns:
        dict: Maps each metric description to the matching statistic.

    Raises:
        IndexError: If fewer than 12 statistics are available.
    """
    # Accept both the evaluator object and its bare stats sequence.
    stats = getattr(coco_eval, "stats", coco_eval)

    eval_metrics = {
        'Average Precision (AP) @[ IoU=0.50:0.95 | area=all | maxDets=100 ]': stats[0],
        'Average Precision (AP) @[ IoU=0.50      | area=all | maxDets=100 ]': stats[1],
        'Average Precision (AP) @[ IoU=0.75      | area=all | maxDets=100 ]': stats[2],
        'Average Precision (AP) @[ IoU=0.50:0.95 | area=small | maxDets=100 ]': stats[3],
        'Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ]': stats[4],
        'Average Precision (AP) @[ IoU=0.50:0.95 | area=large | maxDets=100 ]': stats[5],
        'Average Recall (AR) @[ IoU=0.50:0.95    | area=all | maxDets=1 ]': stats[6],
        'Average Recall (AR) @[ IoU=0.50:0.95    | area=all | maxDets=10 ]': stats[7],
        'Average Recall (AR) @[ IoU=0.50:0.95    | area=all | maxDets=100 ]': stats[8],
        'Average Recall (AR) @[ IoU=0.50:0.95    | area=small | maxDets=100 ]': stats[9],
        'Average Recall (AR) @[ IoU=0.50:0.95    | area=medium | maxDets=100 ]': stats[10],
        'Average Recall (AR) @[ IoU=0.50:0.95    | area=large | maxDets=100 ]': stats[11],
    }

    return eval_metrics

In [None]:
# Training loop
metrics_df = pd.DataFrame()
training_df = pd.DataFrame()

# Rows are collected in plain lists and turned into DataFrames once per epoch;
# growing a DataFrame with pd.concat on every step is quadratic.
metric_rows, metric_index = [], []
loss_rows, loss_index = [], []
log_interval = 100  # record the individual loss terms every N-th batch

# to_csv does not create missing directories.
os.makedirs(OUTPUT_DIR, exist_ok=True)

for epoch in range(NUM_EPOCHS):
    model.train()
    epoch_loss = 0.0
    batches_trained = 0  # batches that actually produced an optimizer step
    total_corrupt = 0
    p_bar = tqdm(enumerate(train_loader), desc=f"Training Epoch {epoch+1}", total=len(train_loader))
    for i, (images, targets) in p_bar:
        # A batch with any box-less target is counted as corrupt and skipped.
        if any(len(target["boxes"]) == 0 for target in targets):
            total_corrupt += 1
            continue
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        if i % log_interval == 0:
            loss_rows.append({key: loss.item() for key, loss in loss_dict.items()})
            loss_index.append(epoch * len(train_loader) + i)

        # Backpropagation
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        epoch_loss += losses.item()
        batches_trained += 1
        # Average over trained batches only, so skipped ones do not dilute it.
        p_bar.set_description(f"Training Epoch {epoch+1}, loss: {epoch_loss / batches_trained:.4f}, corrupted: {total_corrupt}")

    # Step the scheduler
    lr_scheduler.step()

    # Print training loss
    mean_loss = epoch_loss / batches_trained if batches_trained else float("nan")
    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Loss: {mean_loss:.4f}")

    # Evaluate on validation set and log metrics
    segm_stats = evaluate(model, val_loader, device)
    metric_rows.append(coco_eval_to_dict(segm_stats))
    metric_index.append(epoch)
    metrics_df = pd.DataFrame(metric_rows, index=metric_index)
    training_df = pd.DataFrame(loss_rows, index=loss_index)
    print(f"Segmentation metrics (AP, AR) at epoch {epoch+1}: {segm_stats}")
    # Both files are rewritten every epoch so partial runs leave usable logs.
    metrics_df.to_csv(os.path.join(OUTPUT_DIR, "metrics_eval.csv"))
    training_df.to_csv(os.path.join(OUTPUT_DIR, "metrics_train.csv"))


print("Finished training")