## Part 1: Import Libraries


In [None]:
# Import Required Libraries
import torch
import torchvision
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.transforms import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import CocoDetection
from torchvision.transforms import Compose, ToTensor
import os
import numpy as np
from PIL import Image
import torch.cuda.amp as amp

# Set PyTorch CUDA memory configuration.
# `expandable_segments:True` is intended to reduce GPU memory fragmentation and
# thereby avoid fragmentation-induced out-of-memory errors.
# NOTE(review): this assignment runs after `import torch`; confirm it still takes
# effect (it has to be in place before the CUDA allocator is first used).
# NOTE(review): older PyTorch releases read PYTORCH_CUDA_ALLOC_CONF rather than
# PYTORCH_ALLOC_CONF - verify the variable name against the installed version.
os.environ['PYTORCH_ALLOC_CONF'] = 'expandable_segments:True'

## Part 2: Dataset Preparation


In [None]:
# Dataset Preparation
# Root directory of the dataset export
root_dir = "/workspaces/ai-projects/Dataset/Fish_COCO"  # Change this to your dataset path if different

# Each split folder holds its images together with that split's COCO annotation file
train_images_dir = os.path.join(root_dir, "train")
val_images_dir = os.path.join(root_dir, "valid")
train_annotations_file = os.path.join(train_images_dir, "_annotations.coco.json")
val_annotations_file = os.path.join(val_images_dir, "_annotations.coco.json")

# Report which of the expected folders / annotation files are present
print("Checking dataset directories and files...")
for images_dir in (train_images_dir, val_images_dir):
    if not os.path.exists(images_dir):
        print(f"✗ {images_dir} does not exist - please create and populate it")
        continue
    print(f"✓ {images_dir} exists")

for annotations_file in (train_annotations_file, val_annotations_file):
    if not os.path.exists(annotations_file):
        print(f"✗ {annotations_file} does not exist - please export from Roboflow")
        continue
    print(f"✓ {annotations_file} exists")

# Optional sanity check: count the image files found in each split folder
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')
if os.path.exists(train_images_dir):
    num_train_images = sum(
        name.endswith(IMAGE_EXTENSIONS) for name in os.listdir(train_images_dir)
    )
    print(f"Number of training images: {num_train_images}")
if os.path.exists(val_images_dir):
    num_val_images = sum(
        name.endswith(IMAGE_EXTENSIONS) for name in os.listdir(val_images_dir)
    )
    print(f"Number of validation images: {num_val_images}")

## Part 3: Defining Custom Dataset Class (Obsolete)


In [None]:
# Define the Custom Dataset Class
class CustomDataset(torch.utils.data.Dataset):
    """Instance-segmentation dataset backed by paired image and mask files.

    NOTE: currently not needed, because the COCO annotation files are loaded
    directly instead of going through this class.

    Layout read by this class: ``<root>/images`` and ``<root>/masks``. Both
    folders are listed and sorted, and entry ``idx`` of one is paired with
    entry ``idx`` of the other, so the two listings must line up one-to-one.
    In a mask file, pixel value 0 is background and every other distinct
    value marks one object instance.
    """

    def __init__(self, root, transforms=None):
        """Index the image and mask files found under ``root``.

        Args:
            root: Directory containing the ``images`` and ``masks`` folders.
            transforms: Stored on the instance only; ``__getitem__`` does not
                apply it.
        """
        self.root = root
        self.transforms = transforms
        # Sorted listings so that image i and mask i refer to the same sample
        self.imgs = sorted(os.listdir(os.path.join(root, "images")))
        self.masks = sorted(os.listdir(os.path.join(root, "masks")))

    @staticmethod
    def _target_from_mask(mask):
        """Build the Mask R-CNN target dict from an (H, W) instance-id mask.

        Returns a dict with ``boxes`` (N, 4) float32 in [xmin, ymin, xmax, ymax]
        order, ``labels`` (N,) int64 and ``masks`` (N, H, W) uint8, where N is
        the number of distinct non-zero ids in ``mask``.
        """
        # Distinct instance ids, with the background id (0) removed
        obj_ids = torch.unique(mask)
        obj_ids = obj_ids[obj_ids != 0]

        # One boolean mask per instance: (N, H, W)
        masks = mask == obj_ids[:, None, None]

        # Tight box around the pixels of each instance
        boxes = []
        for instance_mask in masks:
            rows, cols = torch.where(instance_mask)
            boxes.append([
                cols.min().item(),
                rows.min().item(),
                cols.max().item(),
                rows.max().item(),
            ])
        # reshape keeps the (N, 4) layout even when there are no objects (N == 0)
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)

        # Labels: assuming all objects are the same class (fish), label=1.
        # Multiple classes would need class info encoded in masks or separate files.
        labels = torch.ones((len(obj_ids),), dtype=torch.int64)

        return {"boxes": boxes, "labels": labels, "masks": masks.to(torch.uint8)}

    def __getitem__(self, idx):
        """Return ``(image_tensor, target)`` for sample ``idx``."""
        img_path = os.path.join(self.root, "images", self.imgs[idx])
        # Context managers close the underlying files once the pixels are read
        with Image.open(img_path) as pil_img:
            img = F.to_tensor(pil_img.convert("RGB"))

        mask_path = os.path.join(self.root, "masks", self.masks[idx])
        with Image.open(mask_path) as pil_mask:
            mask = torch.as_tensor(np.array(pil_mask), dtype=torch.uint8)

        return img, self._target_from_mask(mask)

    def __len__(self):
        """Number of samples, i.e. the number of files in the images folder."""
        return len(self.imgs)

## Part 4: Define Data Transforms

In [None]:
# Define Data Transforms
def get_transforms(train):
    """
    Build the transform pipeline for the training or validation split.

    No augmentations are enabled at the moment, so the result is None for
    both splits. Tensor conversion is not part of this pipeline; it is done
    separately with F.to_tensor.
    """
    pipeline = []
    # Training-only augmentations (when `train` is true) would be appended here,
    # for example a random horizontal flip or a random crop.
    # Split-independent steps such as normalization or resizing would follow.
    if not pipeline:
        return None
    return Compose(pipeline)

## Part 5: Load and Modify the Pre-trained Model


In [None]:
# Load and Modify the Pre-trained Mask R-CNN Model
# Load Mask R-CNN with a ResNet-50 + FPN backbone, initialised from the default
# COCO-pretrained weights so that fine-tuning does not have to start from scratch.
# `weights=` replaces the deprecated `pretrained=True` flag (torchvision >= 0.13).
model = maskrcnn_resnet50_fpn(weights="DEFAULT")

# Number of output classes, counting background (index 0).
# The COCO category_id of each annotation is used directly as the training label,
# so num_classes must be greater than the largest category_id in the annotation files.
# NOTE(review): an earlier comment here said "2 fish classes + background = 3" while
# the value is 4 - confirm against the categories in _annotations.coco.json.
num_classes = 4  # Adjust based on your classes

# Replace the box classification/regression head so it predicts num_classes outputs
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(
    in_features, num_classes
)

# Replace the mask head the same way
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
hidden_layer = 256  # Standard value
model.roi_heads.mask_predictor = torchvision.models.detection.mask_rcnn.MaskRCNNPredictor(
    in_features_mask, hidden_layer, num_classes
)

print(f"Model loaded with {num_classes} classes")

## Part 6: Setting up Data Loaders

In [None]:
# Setting up Data Loaders
# Create training and validation datasets using CocoDetection
# CocoDetection(root, annFile, transforms=None) - root is images folder, annFile is annotations.json
train_dataset = CocoDetection(root=train_images_dir, annFile=train_annotations_file, transforms=get_transforms(train=True))
val_dataset = CocoDetection(root=val_images_dir, annFile=val_annotations_file, transforms=get_transforms(train=False))


def detection_collate(batch):
    """Regroup ``[(image, target), ...]`` into ``(images, targets)`` tuples.

    The default collate would try to stack the samples, which fails because
    every image carries a different number of annotations. This is a named
    module-level function rather than a lambda so it can be pickled when
    DataLoader workers are started with the spawn/forkserver method.
    """
    return tuple(zip(*batch))


# Create data loaders
# Batch size: reduce for GPU memory (try 2-4 for Mask R-CNN)
batch_size = 4
# num_workers: reduce to avoid memory issues
num_workers = 2

data_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    collate_fn=detection_collate,
)

val_data_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    collate_fn=detection_collate,
)

print(f"Training dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")
print(f"Batch size: {batch_size}")

## Part 7: Training the Model

In [None]:
# Training the Model
import torch.optim as optim

# Number of passes over the training set
num_epochs = 50  # Adjust as needed

# Run on the GPU when PyTorch can see one, otherwise on the CPU
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
model.to(device)
print(f"Using device: {device}")

# Release cached GPU memory before training starts
torch.cuda.empty_cache()

# Adam over all model parameters
learning_rate = 0.001
weight_decay = 0.0005
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# Gradient scaler used together with mixed-precision autocast
scaler = amp.GradScaler()

# Function to process COCO target from list of annotations to dict
def _annotation_mask(ann, coco):
    """Rasterise one COCO annotation into an (H, W) uint8 mask tensor."""
    if ann.get('segmentation'):
        # Polygon or RLE segmentation -> binary mask at the image's resolution
        return torch.as_tensor(coco.annToMask(ann), dtype=torch.uint8)
    # No segmentation stored for this annotation: fall back to a filled bounding
    # box so that the mask still has the image's resolution and a defined value.
    info = coco.imgs[ann['image_id']]
    mask = torch.zeros((info['height'], info['width']), dtype=torch.uint8)
    x, y, w, h = ann['bbox']
    x0, y0 = max(int(x), 0), max(int(y), 0)
    x1, y1 = int(round(x + w)), int(round(y + h))
    mask[y0:y1, x0:x1] = 1
    return mask


def process_target(target, coco=None):
    """Convert the COCO annotations of one image into a Mask R-CNN target dict.

    Args:
        target: Either an already-built target dict (returned unchanged) or the
            list of COCO annotation dicts for one image. Each annotation must
            provide 'bbox' as [x, y, width, height] and 'category_id'.
        coco: Optional pycocotools COCO index (e.g. ``dataset.coco``). When
            given, every kept annotation is rasterised to a full-resolution
            binary mask. When omitted, 'masks' is a zero-filled (N, 1, 1)
            placeholder, which is not a usable training target.

    Returns:
        dict with 'boxes' (N, 4) float32 as [x1, y1, x2, y2], 'labels' (N,)
        int64 and 'masks' uint8. Annotations whose box has non-positive width
        or height are dropped, because torchvision's detection models refuse
        such boxes during training.
    """
    if isinstance(target, dict):
        return target
    boxes = []
    labels = []
    masks = []
    for ann in target:
        x, y, w, h = ann['bbox']
        if w <= 0 or h <= 0:
            continue
        # COCO stores [x, y, w, h]; the model expects corner coordinates
        boxes.append([x, y, x + w, y + h])
        labels.append(ann['category_id'])
        if coco is not None:
            masks.append(_annotation_mask(ann, coco))
    boxes = torch.tensor(boxes, dtype=torch.float32).view(-1, 4)
    labels = torch.tensor(labels, dtype=torch.int64)
    if masks:
        masks = torch.stack(masks)
    else:
        # Zero-initialised placeholder (never uninitialised memory)
        masks = torch.zeros((len(labels), 1, 1), dtype=torch.uint8)
    return {'boxes': boxes, 'labels': labels, 'masks': masks}

# Training loop
# Instance masks are rasterised from the training annotations through the
# dataset's COCO index; they match the images as long as the images are not
# resized before reaching the model.
train_coco = train_dataset.coco
num_batches = len(data_loader)
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    for batch_idx, (images, targets) in enumerate(data_loader):
        # Process targets (boxes, labels and real instance masks)
        targets = [process_target(t, coco=train_coco) for t in targets]
        # Convert PIL images to tensors and move to device
        images = [F.to_tensor(img).to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # Forward pass with mixed precision
        with amp.autocast():
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())

        # Backward pass with scaler
        optimizer.zero_grad()
        scaler.scale(losses).backward()
        scaler.step(optimizer)
        scaler.update()

        # Read the scalar once; each .item() call forces a device sync
        batch_loss = losses.item()
        epoch_loss += batch_loss

        # Optional: print batch loss
        if batch_idx % 10 == 0:
            print(f"Epoch {epoch+1}, Batch {batch_idx}, Loss: {batch_loss:.4f}")

    # Average loss per epoch (guarded so an empty loader does not divide by zero)
    avg_epoch_loss = epoch_loss / max(num_batches, 1)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_epoch_loss:.4f}")

print("Training completed!")

## Part 8: Save the Model

In [None]:
# Save the Trained Model
# Persist only the learned parameters (state dict), not the whole model object
model_save_path = "mask_rcnn_model.pth"
trained_state = model.state_dict()
torch.save(trained_state, model_save_path)
print(f"Model saved to {model_save_path}")

## Part 9: Evaluation

In [None]:
# Evaluation
# Load the saved model for inference.
# map_location remaps the stored tensors onto the active device, so a checkpoint
# written on a GPU can still be loaded on a CPU-only machine.
model.load_state_dict(torch.load(model_save_path, map_location=device))
model.to(device)
model.eval()

# Example: Run inference on a validation image
with torch.no_grad():
    # Get a sample from validation set
    img, target = val_dataset[0]
    # Ground truth for the sample (not needed by the forward pass itself)
    target = process_target(target)
    img = F.to_tensor(img).unsqueeze(0).to(device)  # Convert to tensor and add batch dimension

    # Run model: in eval mode it returns one prediction dict per input image
    predictions = model(img)

    # Print predictions
    print("Predictions for sample image:")
    print(f"Boxes: {predictions[0]['boxes']}")
    print(f"Labels: {predictions[0]['labels']}")
    print(f"Scores: {predictions[0]['scores']}")
    print(f"Masks shape: {predictions[0]['masks'].shape}")

# For full evaluation, you could loop over val_data_loader and compute mAP, etc.
# But that requires additional libraries like pycocotools for COCO metrics

In [None]:
import matplotlib.pyplot as plt
# Drop the batch dimension and move channels last so the tensor can be shown as an image
display_image = img.squeeze(0).permute(1, 2, 0).cpu()
plt.imshow(display_image)
plt.show()