BSD 3-Clause License

Copyright (c) 2017-2022, Pytorch contributors
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

In [None]:
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import os
import json
import cv2

In [None]:
# Root of the Kaggle dataset; each split lives in its own sub-directory.
_DATASET_ROOT = '/kaggle/input/brain-tumor-image-dataset-semantic-segmentation'
# Every split stores its COCO annotations under the same file name.
_ANNOTATION_FILE = '_annotations.coco.json'

# Image directories, one per split.
TRAIN_PATH = f'{_DATASET_ROOT}/train'
VAL_PATH = f'{_DATASET_ROOT}/valid'
TEST_PATH = f'{_DATASET_ROOT}/test'

# COCO annotation files, located inside the matching image directory.
TRAIN_ANN = f'{TRAIN_PATH}/{_ANNOTATION_FILE}'
VAL_ANN = f'{VAL_PATH}/{_ANNOTATION_FILE}'
TEST_ANN = f'{TEST_PATH}/{_ANNOTATION_FILE}'

In [None]:
# Peek at the first five entries of the training directory.
train_entries = os.listdir(TRAIN_PATH)
print(train_entries[:5])

In [None]:
# The training directory also holds `_annotations.coco.json`, and
# os.listdir() returns entries in arbitrary order, so skip JSON files instead
# of blindly opening entry 0 (which may not be an image).
first_image_name = next(
    name for name in os.listdir(TRAIN_PATH) if not name.lower().endswith(".json")
)
image_sample = Image.open(os.path.join(TRAIN_PATH, first_image_name))
plt.imshow(image_sample)

In [None]:
# Read the raw training annotation text; the file is only needed for the read.
with open(TRAIN_ANN, "r") as ann_file:
    raw_coco = ann_file.read()

# Parse the COCO JSON and list its top-level sections.
train_coco = json.loads(raw_coco)
print(train_coco.keys())

In [None]:
# Inspect one entry of the "images" section to see which fields it carries.
sample_image_record = train_coco["images"][100]
print(sample_image_record)

In [None]:
# Inspect one entry of the "annotations" section to see which fields it carries.
sample_annotation_record = train_coco["annotations"][100]
print(sample_annotation_record)

In [None]:
def image_id_to_image_and_mask(coco, dir_name, image_id):
    """Load an image and build a rectangular mask from its bounding boxes.

    Args:
        coco: Parsed COCO annotation dict with "images" and "annotations"
            lists.
        dir_name: Directory that contains the image files named in
            ``coco["images"][...]["file_name"]``.
        image_id: COCO id of the image. Annotations are matched on their
            "image_id" field.

    Returns:
        A tuple ``(image, mask)``. ``image`` is the decoded image as a NumPy
        array. ``mask`` is a float array of shape ``(height, width)`` that is
        255 inside every annotated bounding box of the image and 0 elsewhere
        (boxes, not segmentation polygons, are rasterised).

    Raises:
        IndexError: If no image record has this id and ``image_id`` is not a
            valid position in ``coco["images"]`` either.
    """
    # Resolve the record by its COCO "id" so it always agrees with the
    # "image_id" used to select annotations below. List position and id only
    # coincide when ids happen to be 0..N-1 in file order.
    image_record = next(
        (record for record in coco["images"] if record.get("id") == image_id),
        None,
    )
    if image_record is None:
        # Backward-compatible fallback: treat image_id as a list position.
        image_record = coco["images"][image_id]

    image = np.array(Image.open(dir_name + "/" + image_record["file_name"]))

    mask = np.zeros((image_record["height"], image_record["width"]))
    for ann in coco["annotations"]:
        if ann["image_id"] != image_id:
            continue
        # COCO boxes are [x_min, y_min, width, height]; truncate to whole pixels.
        x_min, y_min, width, height = (int(value) for value in ann["bbox"][:4])
        mask[y_min:y_min + height, x_min:x_min + width] = 255

    return image, mask

In [None]:
# Show training image 0 next to the box mask derived from its annotations.
image, mask = image_id_to_image_and_mask(train_coco, TRAIN_PATH, 0)

panels = (("image", image), ("mask", mask))
for position, (panel_title, pixels) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.imshow(Image.fromarray(pixels))
    plt.title(panel_title)

In [None]:
import pathlib

import torch
import torch.utils.data

from torchvision import models, datasets, tv_tensors
from torchvision.transforms import v2

In [None]:
# Augmentation pipeline applied jointly to each image and its targets.
augmentation_steps = [
    v2.ToImage(),
    # p=1: the photometric distortion is applied to every sample.
    v2.RandomPhotometricDistort(p=1),
    # Pad images with a fixed RGB colour and every other target type with 0.
    v2.RandomZoomOut(fill={tv_tensors.Image: (123, 117, 104), "others": 0}),
    v2.RandomIoUCrop(),
    # p=1: every sample is flipped horizontally.
    v2.RandomHorizontalFlip(p=1),
    # Placed after the geometric transforms above, which can alter the boxes.
    v2.SanitizeBoundingBoxes(),
    v2.ToDtype(torch.float32, scale=True),
]
transforms = v2.Compose(augmentation_steps)

# Wrap the COCO dataset so its targets expose boxes, labels and masks in the
# form the v2 transforms operate on.
dataset = datasets.wrap_dataset_for_transforms_v2(
    datasets.CocoDetection(TRAIN_PATH, TRAIN_ANN, transforms=transforms),
    target_keys=["boxes", "labels", "masks"],
)

In [None]:
import torch.optim as optim

In [None]:
# Detection models take a sequence of images plus a sequence of target
# dictionaries. The default collate function would try to torch.stack() the
# samples, which fails in general because the number of bounding boxes differs
# between images of one batch, so each batch is transposed into
# (images, targets) tuples instead.
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=4,
    collate_fn=lambda batch: tuple(zip(*batch)),
)

model = models.get_model(
    "maskrcnn_resnet50_fpn_v2",
    weights="DEFAULT",
    weights_backbone="DEFAULT",
)
model = model.train()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Train on at most 30 batches.
count = 0
for imgs, targets in data_loader:
    if count == 30:
        break
    count += 1

    # In train mode the model returns a dict of named losses; optimise their sum.
    optimizer.zero_grad()
    loss_dict = model(imgs, targets)
    losses = sum(loss_dict.values())
    losses.backward()
    optimizer.step()

    # Per-batch diagnostics: input shapes, target types and each loss term.
    print(f"{[img.shape for img in imgs] = }")
    print(f"{[type(target) for target in targets] = }")
    for name, loss_val in loss_dict.items():
        print(f"{name:<20}{loss_val:.3f}")

In [None]:
# Switch the model to evaluation mode before running inference.
model.eval()

In [None]:
# Read the raw test annotation text; the file is only needed for the read.
with open(TEST_ANN, "r") as ann_file:
    raw_test_ann = ann_file.read()

# Parse the COCO JSON for the test split.
test_coco = json.loads(raw_test_ann)

In [None]:
# Predict on test image 0, the same image whose ground-truth mask the
# comparison cells use. Previously `image` still held training sample 0 here,
# so the prediction and the ground truth came from different images.
image, mask = image_id_to_image_and_mask(test_coco, TEST_PATH, 0)

# HWC uint8 array -> 1xCxHxW float tensor scaled to [0, 1].
# NOTE(review): assumes the image has three dimensions (H, W, C); confirm the
# dataset contains no single-channel files.
image_tensor = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
image_tensor = image_tensor.unsqueeze(0)

# No gradients are needed for inference.
with torch.no_grad():
    preds = model(image_tensor)
pred_masks = preds[0]["masks"]

In [None]:
# Take the first predicted instance mask (presumably the highest-scoring
# detection; verify against the torchvision Mask R-CNN docs). When the model
# returns no detections, fall back to an all-background mask of the same
# spatial size instead of failing with an IndexError.
if pred_masks.shape[0] > 0:
    pred_mask = pred_masks[0, 0].cpu().numpy()
else:
    pred_mask = np.zeros(tuple(pred_masks.shape[-2:]), dtype=np.float32)

# Binarise the soft mask at 0.5; the result holds 0/1 values as uint8.
pred_mask = (pred_mask > 0.5).astype(np.uint8)

In [None]:
# Display the binarised prediction on its own.
pred_mask_image = Image.fromarray(pred_mask)
plt.imshow(pred_mask_image)

In [None]:
# Side-by-side comparison for test image 0: input, ground-truth box mask and
# the model's predicted mask.
image, mask = image_id_to_image_and_mask(test_coco, TEST_PATH, 0)
mask = Image.fromarray(mask)
pred_mask = Image.fromarray(pred_mask)

plt.subplot(1, 3, 1)
plt.imshow(image)
plt.title("image")

plt.subplot(1, 3, 2)
plt.imshow(mask)
plt.title("mask")

plt.subplot(1, 3, 3)
plt.imshow(pred_mask)
# Label the prediction distinctly; it was previously titled "mask" as well,
# making the two panels indistinguishable.
plt.title("predicted mask")

In [None]:
import cv2

In [None]:
# Outline the ground-truth mask on the image and save the figure.
contours, hierarchy = cv2.findContours(
    np.array(mask).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
)

# OpenCV is imported as `cv2`; the former `cv.drawContours` raised NameError.
# drawContours paints onto `image` in place.
cv2.drawContours(image, contours, -1, (255, 0, 0), 3)
plt.imshow(image)
plt.savefig("original.png")

In [None]:
# Outline the predicted mask on the image and save the figure.
contours, hierarchy = cv2.findContours(
    np.array(pred_mask).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
)

# OpenCV is imported as `cv2`; the former `cv.drawContours` raised NameError.
# drawContours paints onto `image` in place, so anything already drawn on
# `image` remains visible underneath these contours.
cv2.drawContours(image, contours, -1, (0, 255, 0), 3)
plt.imshow(image)
plt.savefig("pred.png")