In [1]:
import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')
from PIL import Image
import torch
import numpy as np
import cv2
import requests
from bunch import Bunch
from ruamel.yaml import YAML

import torch.nn as nn
import torch.nn.functional as F
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
from torchvision.models.segmentation import deeplabv3_resnet50
from torchsummary import summary
from torchvision.transforms import ToTensor

from utils.helpers import get_instance
import models



In [None]:


image_url = "https://farm1.staticflickr.com/6/9606553_ccc7518589_z.jpg"
image = np.array(Image.open(requests.get(image_url, stream=True).raw))
rgb_img = np.float32(image) / 255
input_tensor = preprocess_image(rgb_img,
                                mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])
# Taken from the torchvision tutorial
# https://pytorch.org/vision/stable/auto_examples/plot_visualization_utils.html
model = deeplabv3_resnet50(pretrained=True, progress=False)
model = model.eval()

if torch.cuda.is_available():
    model = model.cuda()
    input_tensor = input_tensor.cuda()

output = model(input_tensor)
print(type(output), output.keys())

In [None]:
print(image.dtype)
print(rgb_img.dtype)

In [4]:
class SegmentationModelOutputWrapper(torch.nn.Module):
    def __init__(self, model): 
        super(SegmentationModelOutputWrapper, self).__init__()
        self.model = model
        
    def forward(self, x):
        return self.model(x)["out"]
    
model = SegmentationModelOutputWrapper(model)
output = model(input_tensor)

In [None]:
normalized_masks = torch.nn.functional.softmax(output, dim=1).cpu()
sem_classes = [
    '__background__', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
    'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
    'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
]
sem_class_to_idx = {cls: idx for (idx, cls) in enumerate(sem_classes)}

car_category = sem_class_to_idx["car"]
car_mask = normalized_masks[0, :, :, :].argmax(axis=0).detach().cpu().numpy()
car_mask_uint8 = 255 * np.uint8(car_mask == car_category)
car_mask_float = np.float32(car_mask == car_category)

both_images = np.hstack((image, np.repeat(car_mask_uint8[:, :, None], 3, axis=-1)))
Image.fromarray(both_images)

In [None]:
print(normalized_masks.shape)
print(normalized_masks.dtype)
print(car_mask.shape)
print(car_mask.dtype)
print(car_mask_uint8.shape)
print(car_mask_uint8.dtype)

In [None]:
from pytorch_grad_cam import GradCAM

class SemanticSegmentationTarget:
    def __init__(self, category, mask):
        self.category = category
        self.mask = torch.from_numpy(mask)
        if torch.cuda.is_available():
            self.mask = self.mask.cuda()
        
    def __call__(self, model_output):
        return (model_output[self.category, :, : ] * self.mask).sum()

    
target_layers = [model.model.backbone.layer4]
targets = [SemanticSegmentationTarget(car_category, car_mask_float)]
with GradCAM(model=model,
             target_layers=target_layers) as cam:
    grayscale_cam = cam(input_tensor=input_tensor,
                        targets=targets)[0, :]
    cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)

Image.fromarray(cam_image)
    

# Load trained model FR-Unet

- Image size (565, 584, 3)
##########################
- input: 
torch.Size([1, 1, 592, 592]) <br>
torch.float32<br>

- model:
'<class 'torch.nn.parallel.data_parallel.DataParallel'>'

- output:
'<class 'torch.Tensor'>'<br>
torch.Size([1, 1, 592, 592])<br>
torch.float32
torch.float32

Inference Pipeline:
1. Padding input, make sure input in shape 592
2. 

In [2]:
yaml = YAML(typ='safe', pure=True)

with open("config.yaml", encoding="utf-8") as file:
    CFG = Bunch(yaml.load(file))

In [3]:
class ModelInference:
    def __init__(self, model, weight_path, device=None, input_size=592):
        self.device = torch.device(device if device else ("cuda" if torch.cuda.is_available() else "cpu"))
        self.model = nn.DataParallel(model.to(self.device)) if torch.cuda.is_available() else model.to(self.device)
        self.input_size = input_size
        self.checkpoint = torch.load(weight_path)
        self.model.load_state_dict(self.checkpoint['state_dict'])
        self.model.eval()  # Set the model to evaluation mode

    def _read_image(self, image_path: str) -> np.ndarray:
        img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            raise FileNotFoundError(f"Image not found at {image_path}")
        return img

    def _resize_image(self, raw_image):
        return cv2.resize(raw_image, (self.input_size, self.input_size), interpolation=cv2.INTER_LINEAR)

    def _pad_image(self, raw_image):
        """
        Pad image to the input_size if the image dimensions are less than input_size.
        """
        h, w = raw_image.shape
        pad_h = self.input_size - h 
        pad_w = self.input_size - w 
        return F.pad(torch.tensor(raw_image), (0, pad_w, 0, pad_h), mode="constant", value=0).numpy()

    def preprocess_image(self, raw_image):
        """
        Preprocesses the input image by either resizing or padding it to match the input_size.
        """
        h, w = raw_image.shape
        
        if h < self.input_size or w < self.input_size:
            preproc_image = self._pad_image(raw_image)
        elif h > self.input_size or w > self.input_size:
            preproc_image = self._resize_image(raw_image)

        return ToTensor()(preproc_image).unsqueeze(0).to(self.device)

    def predict(self, image_tensor):
        """
        Run inference on the preprocessed image tensor using the model.
        """
        with torch.no_grad():
            prediction = self.model(image_tensor)
            return torch.sigmoid(prediction).squeeze().cpu().numpy()

    def postprocess_output(self, prediction, threshold=0.5):
        """
        Postprocess the prediction to obtain a binary mask.
        """
        return (prediction >= threshold).astype(np.uint8)

    def save_output(self, binary_mask, output_path):
        """
        Save the binary mask as an image.
        """
        cv2.imwrite(output_path, binary_mask * 255)

    def infer(self, image_path, output_path=None, threshold=0.5):
        """
        Run the full inference pipeline: read image, preprocess, predict, postprocess, and save output.
        """
        # Read the raw image
        raw_image = self._read_image(image_path)

        # Preprocess the image
        image_tensor = self.preprocess_image(raw_image)
        print(type(image_tensor))
        print(image_tensor.shape)
        # Make a prediction
        prediction = self.predict(image_tensor)
        print(type(prediction))
         

        # Postprocess the prediction into a binary mask
        binary_mask = self.postprocess_output(prediction, threshold)

        # Save the output mask if output_path is provided
        if output_path:
            self.save_output(binary_mask, output_path)

        return binary_mask

In [4]:
# Initialize model
model = get_instance(models, 'model', CFG)
checkpoint = "pretrained_weights/DRIVE/checkpoint-epoch40.pth"

# Initialize inference class
inference = ModelInference(model, checkpoint)

# Run inference on a single image
image_path = r'C:\\Users\\UCL\\Desktop\\Do\\datasets\\DRIVE\\DRIVE\\test\\images\\01_test.tif'
output_path = "output_mask.png"
binary_mask = inference.infer(image_path, output_path)

# Print result shape
print(f"Processed {image_path}, result shape: {binary_mask.shape}")

<class 'torch.Tensor'>
torch.Size([1, 1, 592, 592])
<class 'numpy.ndarray'>
Processed C:\\Users\\UCL\\Desktop\\Do\\datasets\\DRIVE\\DRIVE\\test\\images\\01_test.tif, result shape: (592, 592)


In [5]:
print(type(binary_mask))

<class 'numpy.ndarray'>


In [None]:
from sklearn.metrics import roc_auc_score, f1_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix

class ModelEvaluation:
    def __init__(self, threshold=0.5):
        self.threshold = threshold

    def binarize(self, prediction, ground_truth):
        """
        Binarizes the prediction and ground truth using a threshold.
        """
        prediction_bin = (prediction >= self.threshold).astype(np.uint8)
        ground_truth_bin = (ground_truth >= self.threshold).astype(np.uint8)
        return prediction_bin, ground_truth_bin

    def accuracy(self, prediction, ground_truth):
        """
        Calculates accuracy.
        """
        prediction_bin, ground_truth_bin = self.binarize(prediction, ground_truth)
        return accuracy_score(ground_truth_bin.flatten(), prediction_bin.flatten())

    def sensitivity(self, prediction, ground_truth):
        """
        Calculates sensitivity (True Positive Rate).
        """
        prediction_bin, ground_truth_bin = self.binarize(prediction, ground_truth)
        cm = confusion_matrix(ground_truth_bin.flatten(), prediction_bin.flatten())
        tp = cm[1, 1]
        fn = cm[1, 0]
        return tp / (tp + fn) if (tp + fn) > 0 else 0

    def specificity(self, prediction, ground_truth):
        """
        Calculates specificity (True Negative Rate).
        """
        prediction_bin, ground_truth_bin = self.binarize(prediction, ground_truth)
        cm = confusion_matrix(ground_truth_bin.flatten(), prediction_bin.flatten())
        tn = cm[0, 0]
        fp = cm[0, 1]
        return tn / (tn + fp) if (tn + fp) > 0 else 0

    def f1_score(self, prediction, ground_truth):
        """
        Calculates F1 score.
        """
        prediction_bin, ground_truth_bin = self.binarize(prediction, ground_truth)
        return f1_score(ground_truth_bin.flatten(), prediction_bin.flatten())

    def auc(self, prediction, ground_truth):
        """
        Calculates AUC (Area Under Curve).
        """
        prediction_prob = prediction.flatten()  # Flatten prediction to 1D for AUC calculation
        ground_truth_bin = ground_truth.flatten()  # Flatten ground truth to 1D
        return roc_auc_score(ground_truth_bin, prediction_prob)

    def iou(self, prediction, ground_truth):
        """
        Calculates Intersection over Union (IoU).
        """
        prediction_bin, ground_truth_bin = self.binarize(prediction, ground_truth)
        intersection = np.sum(prediction_bin & ground_truth_bin)
        union = np.sum(prediction_bin | ground_truth_bin)
        return intersection / union if union > 0 else 0

    def evaluate(self, prediction, ground_truth):
        """
        Evaluates the model's prediction against the ground truth.
        Returns:
            dict: Dictionary containing all evaluation metrics.
        """
        metrics = {
            "Accuracy": self.accuracy(prediction, ground_truth),
            "Sensitivity": self.sensitivity(prediction, ground_truth),
            "Specificity": self.specificity(prediction, ground_truth),
            "F1 Score": self.f1_score(prediction, ground_truth),
            "AUC": self.auc(prediction, ground_truth),
            "IoU": self.iou(prediction, ground_truth),
        }
        return metrics



def load_mask_from_gif(gif_path):
    """
    Load mask ground truth from .gif file.
    """
    img = Image.open(gif_path)
    img = img.convert("L")  # Convert to Grayscale
    return np.array(img)

ground_truth_path = 'C:\\Users\\UCL\\Desktop\\Do\\datasets\\DRIVE\\DRIVE\\test\\1st_manual\\01_manual1.gif'
ground_truth = load_mask_from_gif(ground_truth_path)

evaluator = ModelEvaluation(threshold=0.5)

# Đánh giá kết quả
metrics = evaluator.evaluate(binary_mask, ground_truth)

# In kết quả đánh giá
for metric, value in metrics.items():
    print(f"{metric}: {value:.4f}")
