# Mask R-CNN for Spine Segmentation

### Set Proxy Environment Variables

In [None]:
import os

# Point both the HTTP and HTTPS proxy variables at the same proxy endpoint
os.environ.update({
    'http_proxy': 'http://proxy:80',
    'https_proxy': 'http://proxy:80',
})

### Import Required Libraries

In [None]:
import os
import tqdm
import random
import numpy as np
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import torch
from torch.utils.data import DataLoader, Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torchvision
from torchvision import models, transforms
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.models.detection import MaskRCNN_ResNet50_FPN_Weights
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

### Define Custom Spine Dataset Class

In [None]:
class SpineDataset(Dataset):
    """Instance-segmentation dataset pairing input images with spine masks.

    Images are read from ``<root>/input_images`` and masks from
    ``<root>/spine_images``. Both directory listings are sorted and paired by
    position, so the two folders are expected to contain matching files in the
    same sort order.

    Args:
        root: Dataset directory containing the two sub-folders.
        transforms: Optional callable invoked as
            ``transforms(image=img, mask=mask)`` that returns a mapping with
            ``'image'`` and ``'mask'`` entries.
    """

    def __init__(self, root, transforms=None):
        self.root = root
        self.transforms = transforms
        self.imgs = self._list_files(os.path.join(root, "input_images"))
        self.masks = self._list_files(os.path.join(root, "spine_images"))

    @staticmethod
    def _list_files(directory):
        """Return the sorted names of the regular files in *directory*."""
        return sorted(
            name for name in os.listdir(directory)
            if os.path.isfile(os.path.join(directory, name))
        )

    @staticmethod
    def _build_target(mask, idx):
        """Convert a 2-D label mask into a detection target dictionary.

        Every distinct non-zero value in *mask* is treated as one object.
        Objects whose bounding box has zero width or height are skipped.

        Args:
            mask: 2-D array (or CPU tensor) of labels; 0 is background.
            idx: Sample index stored under ``"image_id"``.

        Returns:
            Dict with ``boxes`` (N, 4) float32 as ``[xmin, ymin, xmax, ymax]``,
            ``labels`` (N,) int64 (all ones), ``masks`` (N, H, W) uint8,
            ``image_id``, ``area`` (N,) float32 and ``iscrowd`` (N,) int64.
        """
        # NOTE(review): if the mask files are binary (one foreground value),
        # all spines collapse into a single instance - confirm against the data.
        mask = np.asarray(mask)

        # Drop the background label explicitly. Slicing off the first unique
        # value would discard a real object whenever the mask has no 0 pixel.
        obj_ids = np.unique(mask)
        obj_ids = obj_ids[obj_ids != 0]

        boxes, masks = [], []
        for obj_id in obj_ids:
            mask_obj = mask == obj_id
            ys, xs = np.where(mask_obj)
            xmin, xmax = np.min(xs), np.max(xs)
            ymin, ymax = np.min(ys), np.max(ys)
            # Skip degenerate boxes (single-pixel width or height).
            if xmax > xmin and ymax > ymin:
                boxes.append([xmin, ymin, xmax, ymax])
                masks.append(mask_obj)

        if not boxes:
            # No usable object: emit correctly shaped empty tensors.
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            labels = torch.zeros((0,), dtype=torch.int64)
            masks = torch.zeros((0, mask.shape[0], mask.shape[1]), dtype=torch.uint8)
            area = torch.zeros((0,), dtype=torch.float32)
            iscrowd = torch.zeros((0,), dtype=torch.int64)
        else:
            boxes = torch.as_tensor(boxes, dtype=torch.float32)
            labels = torch.ones((len(boxes),), dtype=torch.int64)
            masks = torch.as_tensor(np.stack(masks).astype(np.uint8), dtype=torch.uint8)
            area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
            iscrowd = torch.zeros((len(boxes),), dtype=torch.int64)

        return {
            "boxes": boxes,
            "labels": labels,
            "masks": masks,
            "image_id": torch.tensor([idx]),
            "area": area,
            "iscrowd": iscrowd
        }

    def __getitem__(self, idx):
        """Load sample *idx* and return ``(image, target)``.

        The image is opened as RGB and the mask as 8-bit grayscale; the
        optional transforms are applied to both before the target is built.
        """
        img_path = os.path.join(self.root, "input_images", self.imgs[idx])
        mask_path = os.path.join(self.root, "spine_images", self.masks[idx])
        img = np.array(Image.open(img_path).convert("RGB"))
        mask = np.array(Image.open(mask_path).convert("L"))

        if self.transforms:
            transformed = self.transforms(image=img, mask=mask)
            img = transformed['image']
            mask = transformed['mask']

        return img, self._build_target(mask, idx)

    def __len__(self):
        """Return the number of input images."""
        return len(self.imgs)

### Define Transformations for Dataset

In [None]:
# Per-channel normalisation statistics shared by both pipelines
# (these look like the standard ImageNet values - confirm if it matters).
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)

# Training pipeline: normalise, then convert to a tensor (no augmentation).
train_transform = A.Compose([
    A.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ToTensorV2(),
])

# Validation pipeline: same steps as training.
val_transform = A.Compose([
    A.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ToTensorV2(),
])

### Create Dataset and DataLoader

In [None]:
# Collate function for DataLoader
def collate_fn(batch):
    """Transpose a batch of samples into a tuple of per-field tuples.

    A batch such as ``[(img0, tgt0), (img1, tgt1)]`` becomes
    ``((img0, img1), (tgt0, tgt1))``; an empty batch yields ``()``.
    """
    columns = zip(*batch)
    return tuple(columns)

# Dataset locations
root_train = 'Dataset/DeepD3_Training'
root_val = 'Dataset/DeepD3_Validation'

# Settings common to both loaders; samples are batched one at a time and
# combined with the custom collate function.
_loader_kwargs = dict(batch_size=1, num_workers=2, collate_fn=collate_fn)

# Training data is reshuffled every epoch
train_loader = DataLoader(
    SpineDataset(root_train, transforms=train_transform),
    shuffle=True,
    **_loader_kwargs,
)

# Validation data keeps a fixed order
val_loader = DataLoader(
    SpineDataset(root_val, transforms=val_transform),
    shuffle=False,
    **_loader_kwargs,
)

### Initialize Model

In [None]:
# Model initialization function
def get_model_instance_segmentation(num_classes, hidden_layer=512):
    """Build a Mask R-CNN (ResNet-50 FPN) with heads resized for *num_classes*.

    The model is created with the ``COCO_V1`` weights, then its box predictor
    and mask predictor are replaced by freshly initialised heads so the output
    dimensions match the requested number of classes.

    Args:
        num_classes: Number of output classes, including background.
        hidden_layer: Channel width of the replacement mask predictor's
            hidden layer. Defaults to 512, the value previously hard-coded.

    Returns:
        The modified torchvision Mask R-CNN model.
    """
    weights = MaskRCNN_ResNet50_FPN_Weights.COCO_V1
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights=weights)

    # Swap the box classification/regression head.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Swap the mask head, reusing the input width of the existing one.
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)
    return model

# Prefer the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Two classes: background and spine
num_classes = 2
model = get_model_instance_segmentation(num_classes)
model.to(device)
print('Model Loaded')

### Training Function

In [None]:
def train_one_epoch(model, optimizer, data_loader, device, epoch, num_epochs):
    """Run one optimisation pass over *data_loader*.

    Args:
        model: Detection model that returns a dict of losses when called with
            images and targets in training mode.
        optimizer: Optimizer stepped once per batch.
        data_loader: Iterable yielding ``(images, targets)`` batches.
        device: Device the batch tensors are moved to.
        epoch: Zero-based epoch index (shown one-based in the progress bar).
        num_epochs: Total epoch count, used for the progress-bar label.

    Returns:
        ``(mean_total_loss, mean_components)`` where the second item maps
        ``loss_box_reg``, ``loss_classifier`` and ``loss_mask`` to their mean
        over the batches of the loader.
    """
    model.train()
    running_total = 0
    tracked = dict.fromkeys(('loss_box_reg', 'loss_classifier', 'loss_mask'), 0)
    bar = tqdm.tqdm(data_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training")

    for batch_images, batch_targets in bar:
        # Move the whole batch onto the target device
        batch_images = [image.to(device) for image in batch_images]
        batch_targets = [
            {key: value.to(device) for key, value in target.items()}
            for target in batch_targets
        ]

        # The total loss is the sum of every loss the model reports
        loss_dict = model(batch_images, batch_targets)
        total = sum(loss_dict.values())

        optimizer.zero_grad()
        total.backward()
        optimizer.step()

        batch_loss = total.item()
        running_total += batch_loss
        for name in tracked:
            if name in loss_dict:
                tracked[name] += loss_dict[name].item()

        bar.set_postfix(loss=batch_loss)

    batch_count = len(data_loader)
    mean_components = {name: value / batch_count for name, value in tracked.items()}
    return running_total / batch_count, mean_components

### Validation Function

In [None]:
def compute_loss(model, images, targets):
    """Compute the model's losses for one batch without changing its mode.

    The model is switched to training mode for the forward pass, because that
    is how this function obtains a loss dict from ``model(images, targets)``.
    Afterwards the mode the model had on entry is restored, even if the
    forward pass raises.

    Args:
        model: Module that returns a dict of loss values in training mode.
        images: Batch of images passed to the model.
        targets: Batch of targets passed to the model.

    Returns:
        ``(loss_dict, losses)`` - the per-component dict and their sum.
    """
    was_training = model.training
    model.train()
    try:
        loss_dict = model(images, targets)
        losses = sum(loss_dict.values())
    finally:
        # Put the model back exactly as the caller left it.
        model.train(was_training)
    return loss_dict, losses

def validate_one_epoch(model, data_loader, device, epoch, num_epochs):
    """Measure the mean losses over *data_loader* without updating weights.

    Args:
        model: Detection model evaluated through ``compute_loss``.
        data_loader: Iterable yielding ``(images, targets)`` batches.
        device: Device the batch tensors are moved to.
        epoch: Zero-based epoch index (shown one-based in the progress bar).
        num_epochs: Total epoch count, used for the progress-bar label.

    Returns:
        ``(mean_total_loss, mean_components)`` where the second item maps
        ``loss_box_reg``, ``loss_classifier`` and ``loss_mask`` to their mean
        over the batches of the loader.
    """
    model.eval()
    running_total = 0
    tracked = dict.fromkeys(('loss_box_reg', 'loss_classifier', 'loss_mask'), 0)
    bar = tqdm.tqdm(data_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Validation")

    # Gradients are not needed for validation
    with torch.no_grad():
        for batch_images, batch_targets in bar:
            batch_images = [image.to(device) for image in batch_images]
            batch_targets = [
                {key: value.to(device) for key, value in target.items()}
                for target in batch_targets
            ]

            loss_dict, total = compute_loss(model, batch_images, batch_targets)

            batch_loss = total.item()
            running_total += batch_loss
            for name in tracked:
                if name in loss_dict:
                    tracked[name] += loss_dict[name].item()

            bar.set_postfix(loss=batch_loss)

    batch_count = len(data_loader)
    mean_components = {name: value / batch_count for name, value in tracked.items()}
    return running_total / batch_count, mean_components

### Setting Parameters

In [None]:
# Optimise only the parameters that require gradients
params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.Adam(params, lr=1e-5)
# Halve the learning rate after 3 epochs without a lower monitored value
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=3,
)

EPOCHS = 100
best_val_loss = np.inf

# Per-epoch history of total losses and of their components
train_losses = []
val_losses = []
train_loss_components_list = []
val_loss_components_list = []

# Files used for checkpointing, logging and loss history
CHECKPOINT_DIR = 'checkpoint_spines'
CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, 'checkpoint.pth')
LOG_PATH = os.path.join(CHECKPOINT_DIR, 'training_log.log')
LOSS_PATH = os.path.join(CHECKPOINT_DIR, 'losses.npz')

# Make sure the checkpoint directory is present before anything is written
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

def load_checkpoint():
    """Restore training state from ``CHECKPOINT_PATH`` if it exists.

    Loads the model and optimizer state (and the scheduler state when the
    checkpoint contains one), updates the module-level ``best_val_loss`` and
    replaces the contents of ``train_losses`` / ``val_losses`` with the saved
    history from ``LOSS_PATH`` when that file exists.

    Returns:
        The epoch index to resume from: 0 when there is no checkpoint,
        otherwise the saved epoch plus one.
    """
    global best_val_loss
    if not os.path.exists(CHECKPOINT_PATH):
        return 0

    # map_location lets a checkpoint written on a GPU machine load on CPU.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # Checkpoints without a scheduler entry are still accepted.
    if 'scheduler_state_dict' in checkpoint:
        lr_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    best_val_loss = checkpoint['best_val_loss']

    if os.path.exists(LOSS_PATH):
        # Replace rather than extend, so calling this twice (e.g. re-running
        # a notebook cell) does not duplicate the history.
        with np.load(LOSS_PATH) as loaded_losses:
            train_losses[:] = loaded_losses['train_losses'].tolist()
            val_losses[:] = loaded_losses['val_losses'].tolist()

    return checkpoint['epoch'] + 1

def save_checkpoint(epoch):
    """Persist the training state and append this epoch to the log.

    Writes the model, optimizer and scheduler state plus ``best_val_loss`` to
    ``CHECKPOINT_PATH``, stores the loss history in ``LOSS_PATH`` and appends
    the latest train/validation losses to ``LOG_PATH``.

    Args:
        epoch: Zero-based index of the epoch that just finished. It is stored
            as-is in the checkpoint and written one-based to the log.
    """
    # Write to a temporary file first and swap it in, so an interrupted save
    # cannot leave a truncated checkpoint at CHECKPOINT_PATH.
    tmp_path = CHECKPOINT_PATH + '.tmp'
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        # Saved so the plateau scheduler's progress survives a resume.
        'scheduler_state_dict': lr_scheduler.state_dict(),
        'best_val_loss': best_val_loss
    }, tmp_path)
    os.replace(tmp_path, CHECKPOINT_PATH)
    np.savez(LOSS_PATH, train_losses=train_losses, val_losses=val_losses)

    formatted_train_loss_components = " | ".join(
        f"{k}: {v:.4f}" for k, v in train_loss_components_list[-1].items()
    )
    formatted_val_loss_components = " | ".join(
        f"{k}: {v:.4f}" for k, v in val_loss_components_list[-1].items()
    )

    with open(LOG_PATH, 'a', encoding='utf-8') as logfile:
        logfile.write(f"Epoch {epoch + 1}/{EPOCHS}\n")
        logfile.write(f"Train Loss: {train_losses[-1]:.4f} | {formatted_train_loss_components}\n")
        logfile.write(f"Valid Loss: {val_losses[-1]:.4f} | {formatted_val_loss_components}\n\n")

### Training Loop

In [None]:
# Resume from the saved checkpoint when there is one, otherwise start at 0
start_epoch = load_checkpoint()

# Main loop: train, validate, record, report, then checkpoint
for epoch in range(start_epoch, EPOCHS):
    avg_train_loss, train_loss_components = train_one_epoch(
        model, optimizer, train_loader, device, epoch, EPOCHS
    )
    avg_val_loss, val_loss_components = validate_one_epoch(
        model, val_loader, device, epoch, EPOCHS
    )

    # Record this epoch's totals and per-component means
    train_losses.append(avg_train_loss)
    train_loss_components_list.append(train_loss_components)
    val_losses.append(avg_val_loss)
    val_loss_components_list.append(val_loss_components)

    # The scheduler monitors the validation loss
    lr_scheduler.step(avg_val_loss)

    formatted_train_loss_components = " | ".join(
        f"{name}: {value:.4f}" for name, value in train_loss_components.items()
    )
    formatted_val_loss_components = " | ".join(
        f"{name}: {value:.4f}" for name, value in val_loss_components.items()
    )

    print(f"Epoch {epoch+1}/{EPOCHS}")
    print(f"Train Loss: {avg_train_loss:.4f} | {formatted_train_loss_components}")
    print(f"Val Loss: {avg_val_loss:.4f} | {formatted_val_loss_components}")

    # Keep a separate copy of the weights with the lowest validation loss
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), 'spines_model.pt')

    save_checkpoint(epoch)

### End of Script