In [1]:
import torch
import torchvision
print("PyTorch version:", torch.__version__)
print("Torchvision version:", torchvision.__version__)
print("CUDA is available:", torch.cuda.is_available())
import sys
import os
import numpy as np
import matplotlib.pyplot as plt
import cv2
import json
from monai.metrics import DiceMetric, MeanIoU, SurfaceDiceMetric, SSIMMetric, GeneralizedDiceScore
from segment_anything.utils.transforms import ResizeLongestSide
from collections import defaultdict
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split
from monai.losses import GeneralizedDiceLoss, DiceLoss, GeneralizedDiceFocalLoss
from monai.metrics import DiceMetric, GeneralizedDiceScore
from LinearWarmupCosine import LinearWarmupCosineAnnealingLR



PyTorch version: 1.13.0
Torchvision version: 0.14.0
CUDA is available: True


In [2]:
kv_folder = 'D:\Yuheng Li\Segment Anything\\TestDataset\\TestDataset\\Kvasir'
cvc_db_folder =  'D:\Yuheng Li\Segment Anything\\TestDataset\\TestDataset\\CVC-ClinicDB'
cvc_colon_folder = 'D:\Yuheng Li\Segment Anything\\TestDataset\\TestDataset\\CVC-ColonDB'
cvc_300_folder = 'D:\Yuheng Li\Segment Anything\\TestDataset\\TestDataset\\CVC-300'
etis_folder = 'D:\Yuheng Li\Segment Anything\\TestDataset\\TestDataset\\ETIS-LaribPolypDB'


image_path = []
mask_path = []

for root, dirs, files in os.walk(etis_folder, topdown=False): #finds MRI files
    for name in files:
        if name.endswith(".png"):
            apath=os.path.join(root, name)
            if 'images' in apath:
                image_path.append(apath)
            if 'masks' in apath:
                mask_path.append(apath)
                
print(image_path[1],mask_path[1], len(image_path), len(mask_path))



D:\Yuheng Li\Segment Anything\TestDataset\TestDataset\ETIS-LaribPolypDB\images\10.png D:\Yuheng Li\Segment Anything\TestDataset\TestDataset\ETIS-LaribPolypDB\masks\10.png 196 196


In [3]:
import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# sam_checkpoint = "sam_vit_h_4b8939.pth"
# model_type = "vit_h"

# sam_checkpoint = "sam_vit_b_01ec64.pth"
# model_type = "vit_b"

sam_checkpoint = "sam_vit_l_0b3195.pth"
model_type = "vit_l"


device = "cuda"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

transform = ResizeLongestSide(sam.image_encoder.img_size)


In [4]:
def extract_bboxes(mask, num_instances):

    """Compute bounding boxes from masks.

    mask: [height, width, num_instances]. Mask pixels are either 1 or 0.

 

    Returns: bbox array [num_instances, (y1, x1, y2, x2)].

    """

    boxes = np.zeros([num_instances, 4], dtype=np.int32)

    for i in range(num_instances):

        m = mask

        # Bounding box.

        horizontal_indicies = np.where(np.any(m, axis=0))[0]

#         print("np.any(m, axis=0)",np.any(m, axis=0))

#         print("p.where(np.any(m, axis=0))",np.where(np.any(m, axis=0)))

        vertical_indicies = np.where(np.any(m, axis=1))[0]

        if horizontal_indicies.shape[0]:

            x1, x2 = horizontal_indicies[[0, -1]]

            y1, y2 = vertical_indicies[[0, -1]]

            # x2 and y2 should not be part of the box. Increment by 1.

            x2 += 1

            y2 += 1

        else:

            # No mask for this instance. Might happen due to

            # resizing or cropping. Set bbox to zeros

            x1, x2, y1, y2 = 0, 0, 0, 0

        boxes[i] = np.array([y1, x1, y2, x2])

    return boxes.astype(np.int32)

In [5]:
# coco mask style dataloader

class ColonDataset(Dataset):
    def __init__(self, image_path, mask_path, image_size):
        self.image_path = image_path
        self.mask_path = mask_path
        self.image_size = image_size
        
        # TODO: use ResizeLongestSide and pad to square
        self.to_tensor = transforms.ToTensor()
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def __len__(self):
        return len(self.image_path)

    def __getitem__(self, index):
        idx = self.image_path[index].split('images\\')[1].split('.png')[0]

        image = cv2.imread(self.image_path[index])
        gt = cv2.imread(self.mask_path[index])
        gt = cv2.cvtColor(gt, cv2.COLOR_BGR2GRAY) / 255
        gt = gt.astype('float32')

        bbox_arr = extract_bboxes(gt, 1)

        gt_resized = cv2.resize(gt, (1024, 1024), cv2.INTER_NEAREST)
        gt_resized = torch.as_tensor(gt_resized > 0).long()
        
        gt = torch.from_numpy(gt)
        gt_binary_mask = torch.as_tensor(gt > 0).long()

        transform = ResizeLongestSide(self.image_size)
        input_image = transform.apply_image(image)
        input_image =  cv2.resize(input_image, (1024, 1024), cv2.INTER_CUBIC)
        input_image= self.to_tensor(input_image)
        
        # input_image= self.normalize(input_image)
#         print(input_image.shape)
#         plt.figure()
#         plt.imshow(input_image[0])
#         print('before preprcoess', torch.max(input_image[0]), torch.min(input_image[0]))
        # input_image = sam.preprocess(input_image.to('cuda:0')).detach().cpu()
#         print('after preprcoess', torch.max(input_image[0]), torch.min(input_image[0]))
#         input_image = cv2.resize(input_image.numpy(), (1024, 1024), cv2.INTER_CUBIC)

#         plt.figure()
#         plt.imshow(input_image[0])
        
        original_image_size = image.shape[:2]
        input_size = tuple(input_image.shape[-2:])
        
        return input_image, np.array(bbox_arr), gt_binary_mask, gt_resized, original_image_size, input_size
    

def my_collate(batch):
    
    images, bboxes, masks, gt_resized, original_image_size, input_size = zip(*batch)
    images = torch.stack(images, dim=0)
    gt_resized = torch.stack(gt_resized, dim=0)
    
    masks = [m for m in masks]
    bboxes = [m for m in bboxes]
    original_image_size = [m for m in original_image_size]
    input_size = [m for m in input_size]
    
    return images, bboxes, masks, gt_resized, original_image_size, input_size

    

In [6]:


train_dataset = ColonDataset(image_path, mask_path, sam.image_encoder.img_size)
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn = my_collate)



In [8]:


# model_path =  'D:\\Yuheng Li\\Segment Anything\\Model results\\CVC Clinical\\SAM Finetune Enc Dec'

# model_path  = 'D:\\Yuheng Li\\Segment Anything\\Model results\\All datasets\\SAM Finetune Enc Dec'

# model_path = 'D:\\Yuheng Li\\Segment Anything\\Model results\\All datasets\\SAM Finetune Enc Dec'

model_path = 'D:\\Yuheng Li\\Segment Anything\\Model results\\All datasets\\SAM_L Finetune Enc Dec\\SAM Finetune Enc Dec'


sam.prompt_encoder.load_state_dict(torch.load(os.path.join(model_path, "prompt_enc_best_dice_model_DL.pth")))
sam.image_encoder.load_state_dict(torch.load(os.path.join(model_path, "img_enc_best_dice_model_DL.pth")))
sam.mask_decoder.load_state_dict(torch.load(os.path.join(model_path, "dec_best_dice_model_DL.pth")))
sam.eval()

with torch.no_grad():
    batch_dice = []
    batch_gd = []
    batch_iou = []

    for batch in train_dataloader:

        img, bbox, mask, gt_resized, original_image_size, input_size = batch[0], batch[1], batch[2], batch[3], batch[4], batch[5]

        dice = DiceMetric()
        gd =  GeneralizedDiceScore()
        iou = MeanIoU()

        for i in range(len(mask)):
            image_embedding = sam.image_encoder(img[i].unsqueeze(0).to(device))

            orig_x, orig_y =  original_image_size[i][0], original_image_size[i][1]
            col_x1, col_x2 = bbox[i][:,1] * 1024/orig_y, bbox[i][:,3]* 1024/orig_y
            col_y1, col_y2 = bbox[i][:,0]* 1024/orig_x, bbox[i][:,2]* 1024/orig_x

            box = np.array([col_x1, col_y1, col_x2, col_y2]).transpose()

            num_masks = box.shape[0]
            box_torch = torch.as_tensor(box, dtype=torch.float, device=device)
            sparse_embeddings, dense_embeddings = sam.prompt_encoder(
              points=None,
              boxes= box_torch,
              masks = None
            )

            low_res_masks, iou_predictions = sam.mask_decoder(
              image_embeddings=image_embedding,
              image_pe=sam.prompt_encoder.get_dense_pe(),
              sparse_prompt_embeddings=sparse_embeddings,
              dense_prompt_embeddings=dense_embeddings,
              multimask_output=False
            )

            upscaled_masks = sam.postprocess_masks(low_res_masks, input_size[i], original_image_size[i])

            binary_mask = torch.sigmoid(upscaled_masks.detach().cpu())
            binary_mask =  (binary_mask>0.5).float()

            gt_binary_mask = mask[i].detach().cpu()

            if binary_mask.size()[0] > 1:
                binary_mask = torch.unsqueeze(torch.sum(binary_mask, 0) / binary_mask.size()[0],0)

            dice.reset()
            gd.reset()
            iou.reset()

            dice(binary_mask[0,:], gt_binary_mask.unsqueeze(0))
            gd(binary_mask[0,:], gt_binary_mask.unsqueeze(0))
            iou(binary_mask[0,:], gt_binary_mask.unsqueeze(0))
            final_dice = dice.aggregate().numpy()[0]
            final_gd = gd.aggregate().numpy()[0]
            final_iou = iou.aggregate().numpy()[0]
            batch_dice.append(final_dice)
            batch_gd.append(final_gd)
            batch_iou.append(final_iou)


    print(f'Mean val dice: {sum(batch_dice) / len(batch_dice)}')
    print(f'Mean val gd: {sum(batch_gd) / len(batch_gd)}')
    print(f'Mean val iou: {sum(batch_iou) / len(batch_iou)}')



Mean val dice: 0.9054498590376913
Mean val gd: 0.6419873968698084
Mean val iou: 0.8603473611328066


In [8]:
print(len(train_dataset
         ))

196
