In [1]:
import os
import cv2
import torch
import torch.optim as optim
import numpy as np
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt

import torchvision
from torchvision.models.detection import MaskRCNN
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

In [2]:
def show_image_and_mask(dataset, idx):
    image, target_dict = dataset[idx]
    image = image
    mask = target_dict['masks'][0].numpy()

    fig, axs = plt.subplots(1, 2, figsize=(12, 6))
    axs[0].imshow(image, cmap='gray')
    axs[0].set_title('Image')
    axs[1].imshow(mask, cmap='gray')
    axs[1].set_title('Mask')
    plt.show()

In [3]:
def show_image_with_bbox(dataset, idx):
    image, target_dict = dataset[idx]
    bbox = target_dict['boxes'][0].numpy()
    xmin, ymin, xmax, ymax = bbox

    plt.imshow(image, cmap='gray')
    plt.gca().add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, edgecolor='red', linewidth=2))
    plt.title('Image with Bounding Box')
    plt.show()

In [4]:
def collate_fn(batch):
    return tuple(zip(*batch))

In [5]:
class SketchDataset(Dataset):
    def __init__(self, image_dir, target_dir, transform=None):
        self.image_dir = image_dir
        self.target_dir = target_dir
        self.transform = transform
       
        self.images = sorted([img for img in os.listdir(image_dir) if img.endswith('.png')])
        self.targets = sorted([tgt for tgt in os.listdir(target_dir) if tgt.endswith('.png')])

        assert len(self.images) == len(self.targets), "The number of images and targets must be the same"
        assert all(img.split('.')[0] == tgt.split('.')[0] for img, tgt in zip(self.images, self.targets)), "Image and target filenames must match"

    def __len__(self):
        return len(self.images)
        
    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.images[idx])

        target_path = os.path.join(self.target_dir, self.images[idx])
        
        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        image = image / 255
        
        
        target = cv2.imread(target_path, cv2.IMREAD_GRAYSCALE)

        if image is None or target is None:
            raise RuntimeError(f"Failed to load image or target at index {idx}")
        
        
        target[target != 0] = 1  # Convert to binary mask

        bbox = self.get_bounding_box(target)

        # Convert bbox from (xmin, ymin, xmax, ymax) to [N, 4] tensor
        boxes = torch.as_tensor([bbox], dtype=torch.float32)

        # There is only one class (mask)
        labels = torch.ones((1,), dtype=torch.int64)

        # Convert mask to [N, H, W] tensor
        masks = torch.as_tensor(target, dtype=torch.uint8)
        masks = masks.unsqueeze(0)  # Add an extra dimension for N

        target_dict = {
            "boxes": boxes,
            "labels": labels,
            "masks": masks
        }

        if self.transform:
            image = self.transform(image)

        return image, target_dict

    @staticmethod
    def get_bounding_box(mask):
        rows = np.any(mask, axis=1)
        cols = np.any(mask, axis=0)
        ymin, ymax = np.where(rows)[0][[0, -1]]
        xmin, xmax = np.where(cols)[0][[0, -1]]

        return (xmin, ymin, xmax, ymax)

In [6]:
IMAGE_DIR = "../../datasets/sketch-parse/images"
TARGET_DIR = "../../datasets/sketch-parse/masks"

In [7]:
dataset = SketchDataset(
    image_dir=IMAGE_DIR,
    target_dir=TARGET_DIR
)

In [8]:
dataset_size = len(dataset)
train_size = int(0.7 * dataset_size)
test_size = dataset_size-train_size
batch_size = 4

train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,collate_fn=collate_fn)

In [9]:
def create_maskrcnn_resnet50_fpn():
    num_classes = 2 #Background and object
    
    # Load a model pre-trained on COCO
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights='COCO_V1')

    # Replace the classifier with a new one for our number of classes
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Replace the mask predictor with a new one for our number of classes
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)

    return model

In [10]:
# Set the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [11]:
model = create_maskrcnn_resnet50_fpn().to(device)
model.train()

MaskRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=0.0)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.0)
          (relu): ReLU(in

In [12]:
params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(params, lr=0.005,momentum=0.9, weight_decay=0.0005)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

In [13]:
start_epoch = 0
num_epochs = 100
train_losses = []
test_losses = []

if start_epoch > 0:
    checkpoint = torch.load(f"results/model_epoch_{start_epoch}.pth")
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])

for epoch in range(start_epoch, num_epochs):
    model.train()  # Set the model to training mode
    total_train_loss = 0

    for images, targets in train_loader:
        images = list(torch.from_numpy(img.astype(np.float32)).unsqueeze(0).to(device) for img in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        optimizer.zero_grad()

        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        losses.backward()
        optimizer.step()

        total_train_loss += losses.item()

    avg_train_loss = total_train_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    lr_scheduler.step()
    
    with torch.no_grad():
        total_test_loss = 0
        for images, targets in test_loader:
            images = list(torch.from_numpy(img.astype(np.float32)).unsqueeze(0).to(device) for img in images)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            test_loss_dict = model(images, targets)
            test_loss = sum(loss for loss in test_loss_dict.values())
            total_test_loss += test_loss.item()

    avg_test_loss = total_test_loss / len(test_loader)
    test_losses.append(avg_test_loss)
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Training Loss: {avg_train_loss} | Test Loss: {avg_test_loss}")


    if(epoch % 10 == 0):
        state = {
            'epoch': epoch + 1,  # next epoch number
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()
        }   
        torch.save(state, f'results/model_epoch_{epoch+1}.pth')

Epoch [1/100], Training Loss: 0.6547335605056299 | Test Loss: 0.3278859766324361
Epoch [2/100], Training Loss: 0.2743396882204651 | Test Loss: 0.23017881910006205
Epoch [3/100], Training Loss: 0.22174706529675192 | Test Loss: 0.20594605505466462
Epoch [4/100], Training Loss: 0.1858733260648788 | Test Loss: 0.18831702669461567
Epoch [5/100], Training Loss: 0.17876692047353424 | Test Loss: 0.18608784159024558
Epoch [6/100], Training Loss: 0.17492632033852484 | Test Loss: 0.18537440657615661
Epoch [7/100], Training Loss: 0.17038921934331772 | Test Loss: 0.18532869617144268
Epoch [8/100], Training Loss: 0.17022478980997394 | Test Loss: 0.18530551294485728
Epoch [9/100], Training Loss: 0.17019036257198092 | Test Loss: 0.18482691089312236
Epoch [10/100], Training Loss: 0.17008321254239606 | Test Loss: 0.18455503424008687
Epoch [11/100], Training Loss: 0.1695884486843396 | Test Loss: 0.18515389502048493
Epoch [12/100], Training Loss: 0.16989651697047184 | Test Loss: 0.18494991600513458
Epoch 