In [1]:
import lib.Mask2Former as m2f
import lib.Mask2Former.mask2former as mask2former
import os
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from detectron2.engine import (launch)
from detectron2.config import get_cfg
from detectron2.projects.deeplab import add_deeplab_config
from detectron2.data import build_detection_train_loader
from lib.Mask2Former.train_net import Trainer
import numpy as np
from detectron2.structures import Boxes, Instances, BitMasks
import torch
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
DATA_SOURCE = "combined"
DATA_LOCATION = "_data"
DATA_DIR = "coco"
os.environ["DETECTRON2_DATASETS"] = os.path.join(DATA_LOCATION, DATA_DIR)
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Convert the dataset to COCO format
The following commands convert the existing PNG mask-based dataset to the coco annotations required for training Mask2Former

In [3]:
CONFIG = "lib/Mask2Former/configs/coco/instance-segmentation/swin/maskformer2_swin_base_IN21k_384_bs16_50ep.yaml"
NUM_GPUS = 1
BATCH_SIZE = 8
LEARNING_RATE = 0.001
DATASET_DIR = "_data/urban_street_combined"
IMAGES_DIR_NAME = "images"
IMAGE_DIR = os.path.join(DATASET_DIR, IMAGES_DIR_NAME)
INSTANCES_DIR_NAME = "leaf_instances"
INSTANCES_DIR = os.path.join(DATASET_DIR, INSTANCES_DIR_NAME)

# Custom Data Loader

In [4]:
class LeavesDataset(Dataset):
    def __init__(self, image_dir, label_dir, transform=None):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transform = transform
        self.image_files = os.listdir(image_dir)

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, index):
        image_path = os.path.join(self.image_dir, self.image_files[index])
        label_path = os.path.join(self.label_dir, self.image_files[index])

        image = Image.open(image_path).convert("RGB")
        label = Image.open(label_path).convert("L")

        if self.transform:
            image = self.transform(image)
            #label = self.transform(label).squeeze()

        # Convert label to tensor
        label = torch.from_numpy(np.array(label))

        # Create instances
        instances = Instances(image.shape[1:])

        # Create gt_boxes
        boxes = []
        gt_classes = []
        gt_masks = []
        unique_labels = torch.unique(label)
        if len(unique_labels) > 1:
            if 255 in unique_labels: 
                print("Invalid label in file", image_path)
            for obj_class in unique_labels:
                if obj_class > 0:
                    mask = label == obj_class
                    coords = torch.nonzero(mask)
                    xmin, ymin = coords.min(dim=0).values
                    xmax, ymax = coords.max(dim=0).values
                    boxes.append([xmin, ymin, xmax, ymax])
                    gt_classes.append(obj_class.item())
                    gt_masks.append(mask)

            instances.gt_boxes = Boxes(torch.tensor(boxes))
            instances.gt_classes = torch.tensor(gt_classes, dtype=torch.long)

            # Resize masks to match the image size
            resized_masks = []
            for mask in gt_masks:
                resized_mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0).float(), size=image.shape[1:], mode='nearest').squeeze().to(torch.bool)
                resized_masks.append(resized_mask)

            if len(resized_masks) > 0:
                instances.gt_masks = torch.stack(resized_masks)
            else:
                print("Masks empty, class lenght is", len(gt_classes))
                instances.gt_masks = torch.Tensor()

            return {
                "image": image,
                "height": image.shape[1],
                "width": image.shape[2],
                "instances": instances,
            }
        
        return {
            "image": image,
            "height": image.shape[1],
            "width": image.shape[2]
        }

In [5]:
def collate_fn(batch):
    images = []
    instances = []
    extras = {}

    for item in batch:
        images.append(item["image"])
        
        item_instances = item["instances"]
        item_instances["gt_boxes"] = torch.tensor(item_instances["gt_boxes"])
        item_instances["gt_classes"] = torch.tensor(item_instances["gt_classes"], dtype=torch.long)
        item_instances["gt_masks"] = torch.tensor(item_instances["gt_masks"])
        instances.append(item_instances)
        
        extras["height"] = item["height"]
        extras["width"] = item["width"]

    batched_inputs = [
        {"image": image, "instances": instance, **extras}
        for image, instance in zip(images, instances)
    ]

    return batched_inputs

class LeavesTrainer(Trainer):
    @classmethod
    def build_train_loader(cls, _):
        # Define your data transforms
        transform = transforms.Compose([
            transforms.Resize((800, 800)),
            transforms.ToTensor(),
            #transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        # Create the dataset
        dataset = LeavesDataset(IMAGE_DIR, INSTANCES_DIR, transform=transform, )
        
        # Create the DataLoader
        dataloader = build_detection_train_loader(dataset, mapper=None, total_batch_size=1)
        #dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)
        return dataloader

In [6]:
def get_trainer(cfg):
    trainer = LeavesTrainer(cfg)
    #trainer.resume_or_load(resume=args.resume)
    return trainer.train()

In [None]:
cfg = get_cfg()
add_deeplab_config(cfg)
mask2former.add_maskformer2_config(cfg)
cfg.merge_from_file(CONFIG)

launch(get_trainer, 1, args=(cfg,))

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


[32m[07/16 19:31:46 d2.engine.defaults]: [0mModel:
MaskFormer(
  (backbone): D2SwinTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0): BasicLayer(
        (blocks): ModuleList(
          (0): SwinTransformerBlock(
            (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
            (attn): WindowAttention(
              (qkv): Linear(in_features=128, out_features=384, bias=True)
              (attn_drop): Dropout(p=0.0, inplace=False)
              (proj): Linear(in_features=128, out_features=128, bias=True)
              (proj_drop): Dropout(p=0.0, inplace=False)
              (softmax): Softmax(dim=-1)
            )
            (drop_path): Identity()
            (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
            (mlp): Mlp(
      



[32m[07/16 19:31:56 d2.utils.events]: [0m eta: 1 day, 22:26:17  iter: 19  total_loss: 123.4  loss_ce: 5.416  loss_mask: 2.164  loss_dice: 4.68  loss_ce_0: 8.857  loss_mask_0: 2.087  loss_dice_0: 4.352  loss_ce_1: 6.32  loss_mask_1: 2.261  loss_dice_1: 4.452  loss_ce_2: 6.239  loss_mask_2: 2.354  loss_dice_2: 4.508  loss_ce_3: 6.153  loss_mask_3: 2.237  loss_dice_3: 4.521  loss_ce_4: 5.983  loss_mask_4: 2.081  loss_dice_4: 4.653  loss_ce_5: 5.822  loss_mask_5: 2.271  loss_dice_5: 4.575  loss_ce_6: 5.702  loss_mask_6: 2.156  loss_dice_6: 4.611  loss_ce_7: 5.416  loss_mask_7: 2.503  loss_dice_7: 4.57  loss_ce_8: 5.408  loss_mask_8: 2.242  loss_dice_8: 4.677    time: 0.4683  last_time: 0.4544  data_time: 0.0270  last_data_time: 0.0200   lr: 1e-05  max_mem: 6713M
[32m[07/16 19:32:06 d2.utils.events]: [0m eta: 1 day, 22:14:44  iter: 39  total_loss: 105.9  loss_ce: 3.757  loss_mask: 1.954  loss_dice: 4.659  loss_ce_0: 9.027  loss_mask_0: 1.613  loss_dice_0: 4.334  loss_ce_1: 3.895  loss_m