# Instance segmentation with Mask R-CNN

In [1]:
import torch 
from torchvision.models.detection.mask_rcnn import MaskRCNN_ResNet50_FPN_Weights
from torchvision.models import detection
import lightning as pl
import os
os.chdir("/home/matrament/studia/deep_learning/DeepLearning")
from lightning.pytorch.loggers import MLFlowLogger
# visualisation
import torchvision
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.loggers import MLFlowLogger
from torchvision import transforms as T
import numpy as np
from pycocotools.coco import COCO

### Build the dataset

In [2]:
class CocoDataset(torch.utils.data.Dataset):
    """COCO-format instance-segmentation dataset yielding Mask R-CNN style targets.

    Each item is an ``(image, target)`` pair. ``target`` is a dict with the keys
    ``boxes`` (N, 4) in ``[xmin, ymin, xmax, ymax]`` order, ``labels`` (N,),
    ``masks`` (N, H, W), ``image_id``, ``area`` and ``iscrowd``.

    Args:
        image_dir: Directory containing the image files named in the annotation file.
        annotation_dir: Path to the COCO annotation JSON file (the parameter name
            is kept for backward compatibility although it points at a file).
        transforms: Optional callable invoked as ``transforms(img, target)`` that
            returns the transformed ``(img, target)`` pair.
    """

    def __init__(self, image_dir, annotation_dir, transforms=None):
        self.transforms = transforms
        self.image_dir = image_dir
        self.coco = COCO(annotation_dir)
        self.cat_ids = sorted(self.coco.getCatIds())
        # COCO category ids are not guaranteed to be contiguous, so map them to
        # 1..len(cat_ids); label 0 is left for the background class.
        self.cat_to_label = {cat_id: i + 1 for i, cat_id in enumerate(self.cat_ids)}
        # Inverse mapping, to translate predicted labels back to COCO category ids.
        self.label_to_cat = {label: cat_id for cat_id, label in self.cat_to_label.items()}
        self.ids = sorted(self.coco.imgs.keys())

    def __len__(self):
        """Return the number of images in the annotation file."""
        return len(self.ids)

    def __getitem__(self, idx):
        """Load the image at position ``idx`` together with its target dict."""
        # Index the same sorted id list that __len__ reports on.
        img_id = self.ids[idx]
        img_info = self.coco.loadImgs(img_id)[0]

        img_path = os.path.join(self.image_dir, img_info["file_name"])
        img = torchvision.datasets.folder.default_loader(img_path)

        ann_ids = self.coco.getAnnIds(imgIds=img_id, catIds=self.cat_ids, iscrowd=None)
        anns = self.coco.loadAnns(ann_ids)
        target = self._build_target(img_id, anns, img_info["height"], img_info["width"])

        if self.transforms:
            img, target = self.transforms(img, target)

        return img, target

    def _build_target(self, img_id, anns, height, width):
        """Convert COCO annotation dicts into the tensor dict used as the target.

        Args:
            img_id: COCO image id stored under ``image_id``.
            anns: List of COCO annotation dicts for this image.
            height: Image height, used to shape the mask tensor when ``anns`` is empty.
            width: Image width, used to shape the mask tensor when ``anns`` is empty.
        """
        boxes, masks, labels, areas, iscrowd = [], [], [], [], []

        for ann in anns:
            # COCO stores [xmin, ymin, width, height]; Mask R-CNN wants corners.
            x, y, w, h = ann["bbox"]
            if w <= 0 or h <= 0:
                # Zero-size boxes are rejected by the detection model in training.
                continue
            boxes.append([x, y, x + w, y + h])
            masks.append(self.coco.annToMask(ann))
            labels.append(self.cat_to_label[ann["category_id"]])
            areas.append(ann["area"])
            iscrowd.append(ann.get("iscrowd", 0))

        if masks:
            # Stack first: converting a list of arrays element by element is slow.
            masks_tensor = torch.as_tensor(np.stack(masks), dtype=torch.uint8)
        else:
            masks_tensor = torch.zeros((0, height, width), dtype=torch.uint8)

        return {
            # reshape keeps the (N, 4) layout even when there are no boxes.
            "boxes": torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            "labels": torch.as_tensor(labels, dtype=torch.int64),
            "masks": masks_tensor,
            "image_id": torch.tensor([img_id]),
            "area": torch.as_tensor(areas, dtype=torch.float32),
            "iscrowd": torch.as_tensor(iscrowd, dtype=torch.int64),
        }

In [3]:
class DetectionTransform:
    """Joint image/target transform for detection and segmentation samples.

    Converts the image with ``T.ToTensor()`` and, when ``train`` is true,
    horizontally flips the image together with its boxes and masks with
    probability ``flip_prob``.

    Called as ``transform(image, target)`` it returns ``(image, target)``;
    called as ``transform(image)`` it returns only the image.
    """

    def __init__(self, train, flip_prob=0.5):
        self.train = train
        self.flip_prob = flip_prob
        self.to_tensor = T.ToTensor()

    @staticmethod
    def hflip(image, target=None):
        """Flip ``image`` (..., H, W) and the matching ``target`` left-to-right.

        The input ``target`` dict and its tensors are not modified; a shallow
        copy with new ``boxes``/``masks`` tensors is returned.
        """
        flipped = image.flip(-1)
        if target is None:
            return flipped, None

        width = image.shape[-1]
        target = dict(target)
        boxes = target.get("boxes")
        if boxes is not None and boxes.numel() > 0:
            mirrored = boxes.clone()
            # New xmin is the mirrored old xmax and vice versa.
            mirrored[:, [0, 2]] = width - boxes[:, [2, 0]]
            target["boxes"] = mirrored
        if target.get("masks") is not None:
            target["masks"] = target["masks"].flip(-1)
        return flipped, target

    def __call__(self, image, target=None):
        image = self.to_tensor(image)
        if self.train and torch.rand(1).item() < self.flip_prob:
            image, target = self.hflip(image, target)
        if target is None:
            return image
        return image, target


def get_transform(train):
    """Build the sample transform used by the dataset.

    Args:
        train: When true, random horizontal flipping (p=0.5) is enabled.

    Returns:
        A callable accepting ``(image, target)`` and returning the transformed
        pair. Calling it with the image alone returns just the image.
    """
    return DetectionTransform(train=train, flip_prob=0.5)

### Load data

In [4]:
# Training split; train=True switches on the random horizontal flip.
dataset_train = CocoDataset(
    "dataset/train2017_subset",
    "dataset/instances_train2017_subset.json",
    get_transform(train=True),
)

# Each sample is an (image, target) pair. The collate function transposes the
# list of pairs into (images, targets) tuples instead of stacking tensors.
train_loader = torch.utils.data.DataLoader(
    dataset_train,
    batch_size=16,
    shuffle=True,
    num_workers=4,
    collate_fn=lambda batch: tuple(zip(*batch)),
)

loading annotations into memory...
Done (t=1.82s)
creating index...
index created!


### Load model

In [5]:
def build_model(num_classes, weights=None, hidden_layer=256):
    """Create a Mask R-CNN (ResNet-50 FPN) whose heads predict ``num_classes`` classes.

    Args:
        num_classes: Number of output classes, including the background class.
        weights: Forwarded to ``maskrcnn_resnet50_fpn`` as its ``weights``
            argument. The default ``None`` keeps the previous behaviour.
        hidden_layer: Number of channels in the hidden layer of the replacement
            mask predictor.

    Returns:
        The model with its box and mask predictors replaced by freshly
        initialised heads sized for ``num_classes``.
    """
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights=weights)

    # Swap the box classification/regression head for one sized to num_classes.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(
        in_features, num_classes
    )

    # Swap the mask head likewise, keeping the input width of the existing one.
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    model.roi_heads.mask_predictor = torchvision.models.detection.mask_rcnn.MaskRCNNPredictor(
        in_features_mask, hidden_layer, num_classes
    )
    return model

In [6]:
# One output class per annotated category, plus one for the background.
num_classes = 1 + len(dataset_train.coco.cats)
model = build_model(num_classes)

### Create a Lightning Module

In [7]:
class MaskRCNNLitModule(pl.LightningModule):
    """Lightning wrapper that trains a torchvision-style detection model.

    Args:
        model: Detection model that returns a dict of losses when called with
            ``(images, targets)`` in training mode.
        learning_rate: Learning rate for the SGD optimizer.
    """

    def __init__(self, model, learning_rate=0.005):
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate

    def forward(self, images, targets=None):
        """Run the wrapped model; ``targets`` may be omitted for inference."""
        return self.model(images, targets)

    def _total_loss(self, images, targets):
        """Sum every entry of the loss dict returned by the wrapped model."""
        loss_dict = self.model(list(images), list(targets))
        return sum(loss_dict.values())

    def training_step(self, batch, batch_idx):
        """Compute and log the summed training loss for one batch."""
        images, targets = batch
        total_loss = self._total_loss(images, targets)
        # The batch is a tuple of tensors, so give Lightning its size explicitly.
        self.log("train_loss", total_loss, batch_size=len(images))
        return total_loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the summed validation loss for one batch."""
        images, targets = batch
        # Lightning puts the module in eval mode for validation, but torchvision
        # detection models only return the loss dict in train mode. Switch
        # temporarily and restore the previous mode afterwards.
        # NOTE(review): any non-frozen BatchNorm layers would update their
        # running statistics here - confirm this is acceptable for the backbone.
        was_training = self.model.training
        self.model.train()
        try:
            with torch.no_grad():
                total_loss = self._total_loss(images, targets)
        finally:
            self.model.train(was_training)
        self.log("val_loss", total_loss, batch_size=len(images))
        return total_loss

    def configure_optimizers(self):
        """Return an SGD optimizer over all parameters of the wrapped model."""
        return torch.optim.SGD(
            self.model.parameters(),
            lr=self.learning_rate,
            momentum=0.9,
            weight_decay=0.0005,
        )

### Extension of the callback class (Lightning)

In [8]:
class MetricTracker(pl.Callback):
    """Callback that records validation losses in ``self.collection``.

    Both per-batch values (from ``validation_step`` outputs) and the per-epoch
    logged ``val_loss`` are appended to the same list, in the order seen.
    """

    def __init__(self):
        super().__init__()
        self.collection = []

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        """Store the validation loss of the batch that just finished."""
        if outputs is None:
            return
        # validation_step may return the loss directly or a dict holding it.
        val_loss = outputs.get("val_loss") if isinstance(outputs, dict) else outputs
        if val_loss is not None:
            self.collection.append(val_loss)

    def on_validation_epoch_end(self, trainer, module):
        """Store the logged ``val_loss`` once validation ends, if it was logged."""
        epoch_loss = trainer.logged_metrics.get("val_loss")
        if epoch_loss is not None:
            self.collection.append(epoch_loss)
        # do whatever is needed

### MLflow logger

In [9]:
# Logger pointed at an MLflow tracking server on localhost, port 5000.
mlf_logger = MLFlowLogger(
    tracking_uri="http://localhost:5000",
    experiment_name=f"maskrcnn_resnet",
    log_model=True,
)

### Training (with the PyTorch Lightning Trainer)

In [None]:
# The callback is kept in a variable so its `collection` can be read after fitting.
metr = MetricTracker()
litmodule = MaskRCNNLitModule(model)

trainer = pl.Trainer(
    logger=mlf_logger,
    callbacks=[metr],
    max_epochs=30,
    default_root_dir=f"/tmp/{'v'}_{74}",
    enable_progress_bar=True,
)

# Only the training loader is passed; the validation loader is still commented out.
trainer.fit(
    model=litmodule,
    train_dataloaders=train_loader,  # , val_dataloaders=val_loader
)