<a href="https://colab.research.google.com/github/thaneesan99/Pytorch_Mask_RCNN/blob/main/PytorchMaskRCNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Download the dataset zip file
!curl -L "place_your_dataset_link" -o roboflow.zip

# Unzip the dataset into the 'dataset' folder
!unzip roboflow.zip -d dataset

In [2]:
import os
import numpy as np
import torch
from PIL import Image
from pycocotools.coco import COCO
from pycocotools import mask as coco_mask
from torchvision import transforms as T
import torchvision


In [3]:
class CocoInstanceDataset(torch.utils.data.Dataset):
    """COCO-format instance-segmentation dataset producing Mask R-CNN targets.

    Each item is ``(img, target)``. ``img`` is the RGB image, passed through
    ``transforms`` if one was given. ``target`` is a dict with:

    * ``boxes``    -- float32 ``[N, 4]`` as ``(x1, y1, x2, y2)`` pixel coordinates
    * ``labels``   -- int64 ``[N]`` holding the raw COCO ``category_id`` values
    * ``masks``    -- uint8 ``[N, H, W]`` binary instance masks
    * ``image_id`` -- int64 ``[1]`` COCO image id

    Some annotations are skipped: those with a non-positive box width/height,
    and those without a usable segmentation. ``N`` may therefore be 0, and the
    tensors keep the shapes above in that case.

    Note: ``transforms`` is applied to the image only. It must not change the
    image geometry (resize/flip/crop), or the targets will no longer line up.
    """

    def __init__(self, img_dir, ann_file, transforms=None):
        """
        Args:
            img_dir: directory containing the image files named in ``ann_file``.
            ann_file: path to the COCO annotation JSON.
            transforms: optional callable applied to the PIL image.
        """
        self.img_dir = img_dir
        self.coco = COCO(ann_file)
        self.ids = list(sorted(self.coco.imgs.keys()))
        self.transforms = transforms

    def __getitem__(self, index):
        coco = self.coco
        img_id = self.ids[index]
        anns = coco.loadAnns(coco.getAnnIds(imgIds=img_id))

        file_name = coco.loadImgs(img_id)[0]['file_name']
        # The context manager closes the file handle promptly. convert() returns
        # a fully loaded copy, so `img` stays valid after the file is closed.
        with Image.open(os.path.join(self.img_dir, file_name)) as pil_img:
            img = pil_img.convert("RGB")
        height, width = img.height, img.width

        boxes, labels, masks = [], [], []
        for ann in anns:
            x, y, w, h = ann['bbox']
            # torchvision detection models reject boxes with x2 <= x1 or y2 <= y1.
            if w <= 0 or h <= 0:
                continue
            mask = self._ann_to_mask(ann, height, width)
            if mask is None:
                continue
            boxes.append([x, y, x + w, y + h])
            labels.append(ann['category_id'])
            masks.append(mask)

        target = self._build_target(boxes, labels, masks, img_id, height, width)

        if self.transforms:
            img = self.transforms(img)

        return img, target

    def __len__(self):
        return len(self.ids)

    @staticmethod
    def _ann_to_mask(ann, height, width):
        """Decode one annotation's ``segmentation`` into a uint8 ``[H, W]`` mask.

        Handles all three COCO encodings: polygon lists, uncompressed RLE
        (``counts`` is a list) and compressed RLE (``counts`` is a string).
        Returns ``None`` when the annotation has no usable segmentation.
        """
        segm = ann.get('segmentation')
        if not segm:
            return None
        if isinstance(segm, list):
            # A valid polygon needs >= 3 points (6 coordinates). Dropping shorter
            # ones also stops frPyObjects from misreading a 4-value polygon as a bbox.
            polygons = [poly for poly in segm if len(poly) >= 6]
            if not polygons:
                return None
            # One RLE per polygon part, unioned into a single instance mask.
            rle = coco_mask.merge(coco_mask.frPyObjects(polygons, height, width))
        elif isinstance(segm.get('counts'), list):
            # Uncompressed RLE -> compressed RLE.
            rle = coco_mask.frPyObjects(segm, height, width)
        else:
            # Already compressed RLE; decode() accepts it directly.
            rle = segm
        return np.asarray(coco_mask.decode(rle), dtype=np.uint8)

    @staticmethod
    def _build_target(boxes, labels, masks, img_id, height, width):
        """Pack per-instance lists into the target dict torchvision expects.

        Empty lists still produce ``boxes`` of shape ``[0, 4]`` and ``masks``
        of shape ``[0, H, W]``. A plain ``torch.as_tensor([])`` would have
        shape ``[0]``, which torchvision's box-shape check rejects.
        """
        if masks:
            masks_t = torch.as_tensor(np.stack(masks), dtype=torch.uint8)
        else:
            masks_t = torch.zeros((0, height, width), dtype=torch.uint8)
        return {
            "boxes": torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            "labels": torch.as_tensor(labels, dtype=torch.int64),
            "masks": masks_t,
            "image_id": torch.tensor([img_id]),
        }


In [None]:
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

# Define the transform for the dataset
def get_transform():
    """Return the image preprocessing pipeline used for every split.

    The single step, ``T.ToTensor()``, turns a PIL image into a float
    ``[C, H, W]`` tensor scaled to ``[0, 1]``. The dataset applies this to the
    image only (not the targets), so geometric augmentations do not belong here.
    """
    pipeline_steps = [T.ToTensor()]
    return T.Compose(pipeline_steps)


def get_model_instance_segmentation(num_classes, hidden_layer=256):
    """Build a COCO-pretrained Mask R-CNN whose heads predict ``num_classes`` classes.

    Only the final box predictor and mask predictor are replaced with freshly
    initialised layers sized for ``num_classes``. All other layers keep their
    pretrained weights.

    Args:
        num_classes: number of output classes, including background (index 0).
        hidden_layer: channels of the hidden layer in the replacement
            ``MaskRCNNPredictor``. Defaults to the previously hard-coded 256.

    Returns:
        A torchvision Mask R-CNN model ready for fine-tuning.
    """
    # weights="DEFAULT" is the torchvision >= 0.13 replacement for the deprecated
    # pretrained=True. For this model it selects the same COCO_V1 weights.
    model = maskrcnn_resnet50_fpn(weights="DEFAULT")

    # Replace box head; its input width is read from the existing classifier layer.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Replace mask head; its input channels are read from the existing predictor's conv5_mask.
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)

    return model


# Root of the extracted export; each split lives in DATA_DIR/<split>.
DATA_DIR = "/content/dataset"
BATCH_SIZE = 2


def collate_detection_batch(batch):
    """Regroup a list of ``(image, target)`` pairs into ``(images, targets)`` tuples.

    The default collate tries to stack tensors, which fails when images differ
    in size or carry different numbers of instances. A named module-level
    function, unlike a lambda, can be pickled. DataLoader needs that when
    ``num_workers > 0`` under the ``spawn`` start method.
    """
    return tuple(zip(*batch))


def make_split_dataset(split):
    """Create the dataset for one export split (e.g. ``"train"``).

    The split's images and its ``_annotations.coco.json`` live together in
    ``DATA_DIR/<split>``.
    """
    split_dir = os.path.join(DATA_DIR, split)
    return CocoInstanceDataset(
        split_dir,
        os.path.join(split_dir, "_annotations.coco.json"),
        transforms=get_transform(),
    )


train_dataset = make_split_dataset("train")
val_dataset = make_split_dataset("valid")

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_detection_batch)

val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_detection_batch)


In [None]:
# Seed torch's global RNG before the model is built. This makes the freshly
# initialised heads and the DataLoader shuffle order reproducible across runs
# (some CUDA kernels remain nondeterministic).
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The dataset emits raw COCO category_id values as labels, so the heads need
# max(category_id) + 1 outputs; index 0 is reserved for background. Deriving
# the count avoids the out-of-range label error (a device-side assert on CUDA)
# that a stale hard-coded value causes when the export has more classes.
# NOTE(review): assumes no annotated class uses category_id 0 — confirm for your export.
num_classes = max(train_dataset.coco.getCatIds()) + 1
num_epochs = 10

model = get_model_instance_segmentation(num_classes)
model.to(device)

# Optimise only trainable parameters; SGD hyperparameters are unchanged.
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)


In [None]:
for epoch in range(num_epochs):
    # In train mode the model takes targets and returns a dict of named losses.
    model.train()
    running_loss = 0.0
    num_batches = 0

    for images, targets in train_loader:
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        loss_value = losses.item()

        # A NaN/inf loss would corrupt the weights on step(), so stop early.
        # The per-component losses are included to help diagnose the cause.
        if not np.isfinite(loss_value):
            components = {name: val.item() for name, val in loss_dict.items()}
            raise RuntimeError(
                f"Non-finite loss {loss_value} at epoch {epoch + 1}: {components}")

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        running_loss += loss_value
        num_batches += 1

    # Report the mean per-batch loss; a raw sum grows with the dataset size.
    mean_loss = running_loss / max(num_batches, 1)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Mean loss: {mean_loss:.4f}")


In [7]:
# Persist only the learned parameters (state_dict). Tensors keep their device,
# so pass map_location to torch.load when loading on a CPU-only machine.
torch.save(obj=model.state_dict(), f="maskrcnn_coco_model.pt")

In [None]:
# Colab-only: send the checkpoint saved in the previous cell to the browser as a download.
from google.colab import files

files.download('maskrcnn_coco_model.pt')