# **Inference**
# Mask RCNN
Instance Segmentation with torchvision

---
### Imports

In [None]:
# general
import cv2  # Use OpenCV instead of Pillow
import numpy as np
import matplotlib.pyplot as plt

# PyTorch
import torch
import torchvision.transforms as T

---
### Mask RCNN

In [None]:
# Load Mask-RCNN model
def get_model_instance_segmentation(num_classes):
    """Build a Mask R-CNN (ResNet-50 FPN) with heads sized for ``num_classes``.

    The model is created with ``pretrained=True`` and its box and mask
    predictors are then replaced by freshly constructed ones that output
    ``num_classes`` classes.

    Args:
        num_classes: Number of classes passed to both replacement
            predictors. The caller in this file passes 2 and counts the
            background as one of them.

    Returns:
        The torchvision Mask R-CNN model with both predictor heads replaced.
    """
    # Imported locally: the file-level `import torchvision.transforms as T`
    # binds only `T`, not the name `torchvision`. The predictor classes are
    # imported from the submodules that define them, because
    # `torchvision.models.detection` does not re-export MaskRCNNPredictor.
    from torchvision.models.detection import maskrcnn_resnet50_fpn
    from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
    from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

    model = maskrcnn_resnet50_fpn(pretrained=True)

    # Swap the box classification head, keeping its input feature size.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Swap the mask head, keeping its input channel count.
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    model.roi_heads.mask_predictor = MaskRCNNPredictor(
        in_features_mask, hidden_layer, num_classes
    )

    return model

In [None]:
# Define transformation
def get_transform():
    """Return the preprocessing pipeline applied to an image before inference.

    Returns:
        A ``T.Compose`` whose single step is ``T.ToTensor()``.
    """
    steps = [T.ToTensor()]
    return T.Compose(steps)

---
### Inference 

In [None]:
def plot_results(img, prediction, threshold=0.5, mask_threshold=0.5):
    """Show ``img`` with the predicted instance masks overlaid.

    Args:
        img: Image passed straight to ``plt.imshow``.
        prediction: Mapping with ``'masks'`` and ``'scores'`` tensors, as
            returned for one image by the model in this file. Masks are
            indexed as ``[instance, 0]``, i.e. one channel per instance.
        threshold: Detections with a score not greater than this are
            skipped.
        mask_threshold: Mask pixels below this value are left transparent.
            The torchvision Mask R-CNN documentation describes masks as soft
            values in ``[0, 1]``, usually thresholded at 0.5.
    """
    # Filter out objects with low scores, then move to NumPy once.
    keep = prediction['scores'] > threshold
    masks = prediction['masks'][keep].detach().cpu().numpy()

    plt.figure(figsize=(10, 10))
    plt.imshow(img)

    # Plot each mask. Pixels below the mask threshold are masked out so that
    # only the object is tinted; drawing the full-frame mask would also tint
    # the background and darken the image with every extra instance.
    for mask in masks:
        overlay = np.ma.masked_less(mask[0], mask_threshold)
        plt.imshow(overlay, alpha=0.5, vmin=0.0, vmax=1.0)

    plt.axis('off')
    plt.show()

In [None]:
# Perform inference
def run_inference(image_path, model, device):
    """Run ``model`` on a single image file and return the image and result.

    Args:
        image_path: Path of the image to read with ``cv2.imread``.
        model: Detection model; it is switched to eval mode by this call.
        device: Device the input tensor is moved to before the forward pass.

    Returns:
        A tuple ``(img, prediction)`` where ``img`` is the RGB image array
        and ``prediction`` is the first element of the model output.

    Raises:
        OSError: If the image cannot be read from ``image_path``.
    """
    img = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising; fail
    # here with a clear message rather than inside cvtColor.
    if img is None:
        raise OSError(
            f"Could not read image (missing file or unsupported format): {image_path}"
        )
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # Convert BGR to RGB

    transform = get_transform()
    # Add a leading batch dimension of size 1 before moving to the device.
    img_tensor = transform(img).unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        prediction = model(img_tensor)

    return img, prediction[0]

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Load the model
model = get_model_instance_segmentation(num_classes=2)  # Background + 1 object class
model.to(device)

# Load trained weights (specify the path to your model checkpoint).
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU still loads when this runs on a CPU-only machine.
state_dict = torch.load('path/to/model.pth', map_location=device)
model.load_state_dict(state_dict)

# Inference
image_path = 'path/to/test_image.jpg'
img, prediction = run_inference(image_path, model, device)

# Plot the results
plot_results(img, prediction)