In [1]:
import sys

import numpy as np
from matplotlib import pyplot as plt
import cv2
import time
import itertools
from scipy.ndimage import label
import warnings
from sklearn.metrics import precision_recall_curve, auc

from helpers import *
from gaussian_diffusion import GaussianDiffusionModel, get_beta_schedule
from unet import UNetModel
import data_loader
from anomaly_detection import *
import skimage.exposure

from torchvision import transforms
import random
import lpips
from torch.cuda.amp import autocast
import cv2
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def normalize_image(image, low, high):
    """
    Normalize the input image to the range [low, high].

    :param image: Input image as a NumPy array.
    :param low: The lower bound of the target range.
    :param high: The upper bound of the target range.

    :return: The normalized image as a NumPy array.
    """
    # Find the minimum and maximum pixel values in the image
    min_val = np.min(image)
    max_val = np.max(image)

    # Normalize the image to the range [low, high]
    normalized_image = (image - min_val) / (max_val - min_val) * (high - low) + low

    # Convert the image to the appropriate data type (uint8 for an image)
    normalized_image = normalized_image

    return normalized_image

def create_gaussian_blur_difference_map(x_0, x_pred, kernel_size=3, threshold=5.0):
    """
    Create a difference map between the input image and the predicted input image with Gaussian blur.

    :param x_0: input image
    :param x_pred: predicted input image
    :param kernel_size: kernel size for Gaussian blur
    :param threshold: threshold for anomaly
    :return: difference map
    """

    x_0_array = x_0.cpu().squeeze().numpy()
    x_0_blurred = cv2.GaussianBlur(x_0_array, (kernel_size, kernel_size), 0)
    x_pred_array = x_pred.cpu().detach().numpy().squeeze()
    x_pred_blurred = cv2.GaussianBlur(x_pred_array, (kernel_size, kernel_size), 0)

    diff = abs(x_0_blurred - x_pred_blurred)
    diff[diff < threshold] = 0

    diff_final = remove_small_spots(diff)
    return diff_final


def remove_small_spots(map, threshold=30):
    """
    Remove too small spots from the difference map.

    :param map: difference map
    :param threshold: threshold for the size of the spots
    :return: difference map with removed small spots
    """

    binary_map = map > 0
    labeled_map, num_features = label(binary_map)
    component_sizes = np.bincount(labeled_map.ravel())
    large_components_masked = component_sizes[labeled_map] >= threshold
    return large_components_masked * map

def get_dice_score(diff_truth, diff_pred):
    """
    Calculate the Dice score between the ground truth and predicted anomaly.

    :param diff_truth: ground truth anomaly
    :param diff_pred: predicted anomaly
    :return: Dice score
    """

    if diff_truth.sum() == 0 and diff_pred.sum() == 0:
        return 1.0
    dice_score = 2 * (diff_truth & diff_pred).sum() / (diff_truth.sum() + diff_pred.sum())
    return round(dice_score, 4)

def get_precision_score(diff_truth, diff_pred):
    """
    Calculate the precision score between the ground truth and predicted anomaly.

    :param diff_truth: ground truth anomaly
    :param diff_pred: predicted anomaly
    :return: precision score
    """

    if diff_truth.sum() == 0 and diff_pred.sum() == 0:
        return 1.0
    true_positives = np.sum(diff_truth & diff_pred)
    false_positives = np.sum(diff_pred) - true_positives
    if true_positives + false_positives != 0:
        precision_score = true_positives / (true_positives + false_positives)
    else:
        precision_score = 0.0
    return round(precision_score, 4)

def get_recall_score(diff_truth, diff_pred):
    """
    Calculate the recall score between the ground truth and predicted anomaly.

    :param diff_truth: ground truth anomaly
    :param diff_pred: predicted anomaly
    :return: recall score
    """

    if diff_truth.sum() == 0 and diff_pred.sum() == 0:
        return 1.0
    true_positives = np.sum(diff_truth & diff_pred)
    false_negatives = np.sum(diff_truth) - true_positives
    if true_positives + false_negatives != 0:
        recall_score = true_positives / (true_positives + false_negatives)
    else:
        recall_score = 0.0
    return round(recall_score, 4)

def enlarge_bounding_box(bbox, scale=1.1):
    """
    Enlarge a bounding box by a given scale factor.
    
    :param bbox: [x_min, y_min, x_max, y_max]
    :param scale: Scale factor (default is 1.1 for 10% increase)
    :return: Enlarged bounding box [x_min_new, y_min_new, x_max_new, y_max_new]
    """
    x_min, y_min, x_max, y_max = bbox

    # Compute width and height
    width = x_max - x_min
    height = y_max - y_min

    # Compute expansion amount (10% increase)
    width_increase = (scale - 1) * width / 2
    height_increase = (scale - 1) * height / 2

    # Apply changes
    x_min_new = x_min - width_increase
    y_min_new = y_min - height_increase
    x_max_new = x_max + width_increase
    y_max_new = y_max + height_increase

    return [int(x_min_new), int(y_min_new), int(x_max_new), int(y_max_new)]

In [3]:
pretrained_path = '../output/Ultrasound/model/args_us_27/best_model.pt'
# pretrained_path = '../output/BraTS/model/args_brats_12/best_model.pt'
# pretrained_path = '../output/LiTS/model/args_lits_14/best_model.pt'
betas = get_beta_schedule("cosine", 1000)

diffusion = GaussianDiffusionModel(
            128, betas, img_channels=1, loss_type="vlb",
            loss_weight="none", noise_fn="gaussian", noise_params=None, diffusion_mode="inference"
            )

model = UNetModel(128, in_channels=1, model_channels=128,
                num_res_blocks=2, attention_resolutions="32,16,8",
                dropout=0.0, channel_mult="", num_heads=2,
                num_head_channels=64,).to('cuda')

checkpoint = torch.load(pretrained_path)
model.load_state_dict(checkpoint["model_state_dict"])
model.to('cuda')
model.eval()
print("model loaded")

model loaded


In [26]:
root_path = "/home/camp/Projects/Yuan/Data/Ultrasound_synomaly/unhealthy_selected"
folders = os.listdir(root_path)
folders.sort()
dice_list = []
precision_list = []
recall_list = []

noise_steps = [250]
kernel = 15
threshold = 0.3

for f in folders:
    img_files = os.listdir(os.path.join(root_path,f,"img"))
    mask_files = os.listdir(os.path.join(root_path,f,"plaque"))
    img_files.sort()
    mask_files.sort()
    try:
        b_boxes = np.load(os.path.join(root_path,f,"b_boxes.npy"))
    except FileNotFoundError:
        continue
    
    for i, (img_filename, mask_filename, bbox) in enumerate(zip(img_files, mask_files, b_boxes)):
        if i <5:
            continue
        img = cv2.imread(os.path.join(root_path,f,"img",img_filename))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        
        mask = cv2.imread(os.path.join(root_path,f,"plaque",mask_filename))
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
        
        bbox = enlarge_bounding_box(bbox, scale = 1.1)
        
        x_min, y_min, x_max, y_max = bbox

        # Crop the image
        img = img[y_min:y_max, x_min:x_max]
        mask = mask[y_min:y_max, x_min:x_max]
        
        img = cv2.resize(img, (128, 128))
        img = normalize_image(img, -1, 1)
        
        mask = cv2.resize(mask, (128, 128), interpolation=cv2.INTER_NEAREST)
        
        x_0 = torch.from_numpy(img).float().view(1,1,128,128).to('cuda')
        device = 'cuda'

        with torch.no_grad():
            x_0_cache = x_0

            inference_images = []
            inference_images.append(x_0)

            difference_map_cache = np.ones_like(x_0.cpu().squeeze().numpy())
            iter_num = 0
            difference = 255*255

            while difference > 0.01:
                t = torch.tensor(noise_steps, device=x_0_cache.device).repeat(x_0_cache.shape[0])
                noise = diffusion.noise_fn(x_0_cache, None)
                x_t = diffusion.q_sample(x_0_cache, t, noise)
                x_pred = diffusion.p_sample(model, x_t, t)

                # x_0_cache=x_pred



                difference_map = create_gaussian_blur_difference_map(x_0, x_pred,
                                                                     kernel_size=3,
                                                                     threshold=0.1)
                difference_map = difference_map.astype(bool).astype(int)

                difference_map_t = torch.from_numpy(difference_map)
                difference_map_t = difference_map_t.view(1,1,128,128).to(device)
                x_0_cache = x_pred*difference_map_t+x_0*(1-difference_map_t)

                inference_images.extend([x_t, x_0_cache])


                nominator = np.abs(np.sum(difference_map_cache.astype(bool).astype(int))-np.sum(difference_map.astype(bool).astype(int)))
                denominator = np.max([np.sum(difference_map_cache.astype(bool).astype(int)),500])
                difference = nominator/denominator
                difference_map_cache = difference_map
                iter_num += 1
                if iter_num>=5:
                    break

        anomaly_map = create_gaussian_blur_difference_map(x_0,x_pred,kernel_size=kernel, threshold=threshold)
        
        predicted_anomaly = anomaly_map.astype(bool)
        groundtruth_anomaly = mask.astype(bool)
        
        dice = get_dice_score(groundtruth_anomaly, predicted_anomaly)
        precision = get_precision_score(groundtruth_anomaly, predicted_anomaly)
        recall = get_recall_score(groundtruth_anomaly, predicted_anomaly)
        
        dice_list.append(dice)
        precision_list.append(precision)
        recall_list.append(recall)

In [27]:
print(np.array(dice_list).mean())
print(np.array(precision_list).mean())
print(np.array(recall_list).mean())

0.6765760948905108
0.8702777372262773
0.5971979927007299


In [17]:
print(np.array(dice_list).mean())
print(np.array(precision_list).mean())
print(np.array(recall_list).mean())

0.6754642335766423


In [28]:
print(np.array(dice_list).mean(), np.array(dice_list).std())
print(np.array(precision_list).mean())
print(np.array(recall_list).mean())

0.6765760948905108 0.18851480027108342
0.8702777372262773
0.5971979927007299
