In [1]:
import lib.Mask2Former as m2f
import lib.Mask2Former.mask2former as mask2former
import os
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from detectron2.engine import (launch)
from detectron2.config import get_cfg
from detectron2.projects.deeplab import add_deeplab_config
from detectron2.data import build_detection_train_loader
from lib.Mask2Former.train_net import Trainer
import numpy as np
from detectron2.structures import Boxes, Instances, BitMasks
import torch
import torch.nn.functional as F
from detectron2.evaluation import DatasetEvaluator, DatasetEvaluators
from detectron2.data import build_detection_test_loader
from detectron2.utils import comm
from detectron2.structures import BoxMode, pairwise_iou
import copy
import random

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
DATA_SOURCE = "combined"
DATA_LOCATION = "_data"
DATA_DIR = "coco"
os.environ["DETECTRON2_DATASETS"] = os.path.join(DATA_LOCATION, DATA_DIR)
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Convert the dataset to COCO format
The following commands convert the existing PNG mask-based dataset to the coco annotations required for training Mask2Former

In [3]:
CONFIG = "lib/Mask2Former/configs/coco/instance-segmentation/swin/maskformer2_swin_base_IN21k_384_bs16_50ep.yaml"
#CONFIG = "configs/mask2former.yaml"
NUM_GPUS = 1
BATCH_SIZE = 8
LEARNING_RATE = 0.001
DATASET_DIR = "_data/urban_street_combined"
DATASET_DIR_VAL = "_data/combined/val"
IMAGES_DIR_NAME = "images"
IMAGE_DIR = os.path.join(DATASET_DIR, IMAGES_DIR_NAME)
INSTANCES_DIR_NAME = "leaf_instances"
INSTANCES_DIR = os.path.join(DATASET_DIR, INSTANCES_DIR_NAME)
IMAGE_DIR_VAL = os.path.join(DATASET_DIR_VAL, IMAGES_DIR_NAME)
INSTANCES_DIR_VAL = os.path.join(DATASET_DIR_VAL, INSTANCES_DIR_NAME)

# Custom Data Loader

In [4]:
class LeavesDataset(Dataset):
    def __init__(self, image_dir, label_dir, transform=None):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transform = transform
        self.image_files = os.listdir(image_dir)

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, index):
        image_path = os.path.join(self.image_dir, self.image_files[index])
        label_path = os.path.join(self.label_dir, self.image_files[index])

        image = Image.open(image_path).convert("RGB")
        label = Image.open(label_path).convert("L")

        if self.transform:
            image = self.transform(image)
            #label = self.transform(label).squeeze()

        # Convert label to tensor
        label = torch.from_numpy(np.array(label))

        # Create instances
        instances = Instances(image.shape[1:])

        # Create gt_boxes
        boxes = []
        gt_classes = []
        gt_masks = []
        unique_labels = torch.unique(label)
        if len(unique_labels) > 1:
            if 255 in unique_labels: 
                print("Invalid label in file", image_path)
            for obj_class in unique_labels:
                if obj_class > 0:
                    mask = label == obj_class
                    coords = torch.nonzero(mask)
                    xmin, ymin = coords.min(dim=0).values
                    xmax, ymax = coords.max(dim=0).values
                    boxes.append([xmin, ymin, xmax, ymax])
                    gt_classes.append(obj_class.item())
                    gt_masks.append(mask)

            instances.gt_boxes = Boxes(torch.tensor(boxes))
            instances.gt_classes = torch.tensor(gt_classes, dtype=torch.long)

            # Resize masks to match the image size
            resized_masks = []
            for mask in gt_masks:
                resized_mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0).float(), size=image.shape[1:], mode='nearest').squeeze().to(torch.bool)
                resized_masks.append(resized_mask)

            if len(resized_masks) > 0:
                instances.gt_masks = torch.stack(resized_masks)
            else:
                print("Masks empty, class lenght is", len(gt_classes))
                instances.gt_masks = torch.Tensor()

            return {
                "image": image,
                "height": image.shape[1],
                "width": image.shape[2],
                "instances": instances,
            }
        
        return {
            "image": image,
            "height": image.shape[1],
            "width": image.shape[2]
        }

In [5]:
class LeavesEvaluator(DatasetEvaluator):
    def __init__(self, dataset_name):
        self.dataset_name = dataset_name
        self._cpu_device = torch.device("cpu")

    def reset(self):
        self._predictions = []
        self._targets = []

    def process(self, inputs, outputs):
        # sample single random instance
        idx = random.randrange(len(inputs))
        self._predictions.append(outputs[idx]["instances"].to(self._cpu_device))
        self._targets.append(inputs[idx]["instances"].to(self._cpu_device))
#        for input, output in zip(inputs, outputs):
#           self._predictions.append(output["instances"].to(self._cpu_device))
#            self._targets.append(input["instances"].to(self._cpu_device))

    def evaluate(self):
        if comm.is_main_process():
            self._evaluate()

        if comm.is_main_process():
            return copy.deepcopy(self._results)
        else:
            return None

    def _evaluate(self):
        self._results = {}
        iou_thresholds = [0.5, 0.75]
        for iou_threshold in iou_thresholds:
            self._results[f"IoU_{iou_threshold}"] = self._compute_iou(iou_threshold)
        self._results["mask_mse_loss"] = self._compute_mask_mse_loss()
        print(self._results)

    def _compute_iou(self, iou_threshold):
        print("Computing IoU")
        num_instances = len(self._predictions)
        iou_sum = 0.0

        for pred, target in zip(self._predictions, self._targets):
            pred_boxes = pred.pred_boxes.tensor
            target_boxes = target.gt_boxes.tensor

            if len(pred_boxes) == 0 or len(target_boxes) == 0:
                continue

            # Convert the boxes to the format expected by the pairwise_iou function
            pred_boxes = BoxMode.convert(pred_boxes, BoxMode.XYXY_ABS, BoxMode.XYXY_ABS)
            target_boxes = BoxMode.convert(target_boxes, BoxMode.XYXY_ABS, BoxMode.XYXY_ABS)

            # Compute IoU between predicted and target boxes
            iou_matrix = pairwise_iou(Boxes(pred_boxes), Boxes(target_boxes))
            max_iou, _ = iou_matrix.max(dim=1)

            # Count the number of predicted boxes with IoU above the threshold
            num_above_threshold = (max_iou > iou_threshold).sum().item()
            iou_sum += num_above_threshold

        avg_iou = iou_sum / num_instances
        return avg_iou
    
    def _compute_mask_mse_loss(self):
        loss = 0
        for pred, target in zip(self._predictions, self._targets):
            for pred_mask, target_mask in zip(pred.pred_masks, target.gt_masks):
                target_mask = target_mask.float()
                diff2 = (torch.flatten(pred_mask) - torch.flatten(target_mask)) ** 2.0
                sum2 = 0.0
                num = 0

                flat_mask = torch.flatten(target_mask)
                assert(len(flat_mask) == len(diff2))
                for i in range(len(diff2)):
                    if flat_mask[i] == 1:
                        sum2 += diff2[i]
                        num += 1

                loss += sum2 / num
        return loss


In [6]:
def collate_fn(batch):
    images = []
    instances = []
    extras = {}

    for item in batch:
        images.append(item["image"])
        
        item_instances = item["instances"]
        item_instances["gt_boxes"] = torch.tensor(item_instances["gt_boxes"])
        item_instances["gt_classes"] = torch.tensor(item_instances["gt_classes"], dtype=torch.long)
        item_instances["gt_masks"] = torch.tensor(item_instances["gt_masks"])
        instances.append(item_instances)
        
        extras["height"] = item["height"]
        extras["width"] = item["width"]

    batched_inputs = [
        {"image": image, "instances": instance, **extras}
        for image, instance in zip(images, instances)
    ]

    return batched_inputs

class LeavesTrainer(Trainer):
    @classmethod
    def build_train_loader(cls, _):
        # Define your data transforms
        transform = transforms.Compose([
            transforms.Resize((800, 800)),
            transforms.ToTensor(),
            #transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        # Create the dataset
        dataset = LeavesDataset(IMAGE_DIR, INSTANCES_DIR, transform=transform, )
        
        # Create the DataLoader
        dataloader = build_detection_train_loader(dataset, mapper=None, total_batch_size=1)
        return dataloader
    
    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        # Define your data transforms
        transform = transforms.Compose([
            transforms.Resize((800, 800)),
            transforms.ToTensor(),
            #transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        # Create the dataset
        dataset = LeavesDataset(IMAGE_DIR, INSTANCES_DIR, transform=transform, )
        
        # Create the DataLoader
        dataloader = build_detection_test_loader(dataset, mapper=None)
        return dataloader
        
    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        return LeavesEvaluator(dataset_name)

In [7]:
def get_trainer(cfg):
    trainer = LeavesTrainer(cfg)
    #trainer.resume_or_load(resume=args.resume)
    return trainer.train()

In [None]:
cfg = get_cfg()
add_deeplab_config(cfg)
mask2former.add_maskformer2_config(cfg)
cfg.merge_from_file(CONFIG)

launch(get_trainer, 1, args=(cfg,))

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


[32m[07/17 12:47:17 d2.engine.defaults]: [0mModel:
MaskFormer(
  (backbone): D2SwinTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0): BasicLayer(
        (blocks): ModuleList(
          (0): SwinTransformerBlock(
            (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
            (attn): WindowAttention(
              (qkv): Linear(in_features=128, out_features=384, bias=True)
              (attn_drop): Dropout(p=0.0, inplace=False)
              (proj): Linear(in_features=128, out_features=128, bias=True)
              (proj_drop): Dropout(p=0.0, inplace=False)
              (softmax): Softmax(dim=-1)
            )
            (drop_path): Identity()
            (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
            (mlp): Mlp(
      



[32m[07/17 12:47:26 d2.utils.events]: [0m eta: 1 day, 20:50:49  iter: 19  total_loss: 132  loss_ce: 5.979  loss_mask: 2.508  loss_dice: 4.511  loss_ce_0: 8.791  loss_mask_0: 2.039  loss_dice_0: 4.348  loss_ce_1: 7.337  loss_mask_1: 2.212  loss_dice_1: 4.42  loss_ce_2: 6.183  loss_mask_2: 2.173  loss_dice_2: 4.548  loss_ce_3: 5.973  loss_mask_3: 3.044  loss_dice_3: 4.621  loss_ce_4: 6.269  loss_mask_4: 1.882  loss_dice_4: 4.652  loss_ce_5: 6.102  loss_mask_5: 1.717  loss_dice_5: 4.704  loss_ce_6: 6.107  loss_mask_6: 2.22  loss_dice_6: 4.591  loss_ce_7: 5.956  loss_mask_7: 2.501  loss_dice_7: 4.556  loss_ce_8: 5.908  loss_mask_8: 2.825  loss_dice_8: 4.586    time: 0.4400  last_time: 0.4436  data_time: 0.0250  last_data_time: 0.0272   lr: 1e-05  max_mem: 6710M
[32m[07/17 12:47:35 d2.evaluation.evaluator]: [0mStart inference on 9899 batches
Instances(num_instances=6, image_height=800, image_width=800, fields=[gt_boxes: Boxes(tensor([[380., 364., 439., 441.],
        [ 44., 199., 404., 