In [1]:
from transformers import SamModel, SamConfig, SamProcessor
import torch
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import numpy as np
from PIL import Image,ImageDraw
import os
from sklearn.model_selection import train_test_split
from transformers import SamModel
from torch.optim import Adam
import monai
from tqdm import tqdm
from statistics import mean
import torch
from torch.nn.functional import threshold, normalize
import numpy as np
import random
import torch
import matplotlib.pyplot as plt
from torchvision.transforms.functional import to_pil_image
import matplotlib.patches as patches
import numpy as np
import torch
import cv2
from scipy.ndimage import label, binary_fill_holes
from torchvision.ops import nms
from torchvision.transforms.functional import to_pil_image, to_tensor, resize
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches



In [2]:
def generate_input_points(array_size=1024, grid_size=5):
    """Build a regular square grid of integer point prompts.

    ``grid_size`` evenly spaced coordinates are taken along each axis from 0 to
    ``array_size - 1`` (both inclusive) and truncated to integers. Every
    ``[x, y]`` combination is emitted in row-major order: x varies fastest,
    y slowest.

    Args:
        array_size: Side length of the square coordinate space being sampled.
        grid_size: Number of grid points along each axis.

    Returns:
        torch.Tensor of shape ``(1, 1, grid_size * grid_size, 2)`` holding the
        ``[x, y]`` pairs.
    """
    # Both axes use the same tick positions, so compute them once.
    ticks = [int(value) for value in np.linspace(0, array_size - 1, grid_size)]
    # Outer loop over rows (y), inner loop over columns (x) == meshgrid order.
    grid = [[col, row] for row in ticks for col in ticks]
    return torch.tensor(grid).view(1, 1, grid_size * grid_size, 2)

In [3]:
def refine_masks(masks, kernel_size=3):
    """Clean up a stack of binary masks.

    Each mask along the first axis is processed independently: enclosed holes
    are filled, then a morphological opening with a square
    ``kernel_size x kernel_size`` structuring element removes small specks.

    Args:
        masks: Array whose first axis indexes the individual masks.
        kernel_size: Side length of the rectangular structuring element.

    Returns:
        Array with the same shape and dtype as ``masks`` containing the
        refined masks.
    """
    structuring_element = cv2.getStructuringElement(
        cv2.MORPH_RECT, (kernel_size, kernel_size)
    )
    cleaned = np.zeros_like(masks)

    for idx in range(masks.shape[0]):
        # Hole filling yields a boolean array; OpenCV needs uint8.
        filled = binary_fill_holes(masks[idx]).astype(np.uint8)
        # Opening (erode then dilate) drops isolated noise pixels.
        cleaned[idx] = cv2.morphologyEx(filled, cv2.MORPH_OPEN, structuring_element)

    return cleaned


def extract_bounding_boxes(masks, iou_threshold=0.1, min_area=1, max_area=1000, aspect_ratio_range=(0.5, 2)):
    """Extract bounding boxes from binary masks and de-duplicate them with NMS.

    For every mask along the first axis, external contours are located and
    each contour's bounding rectangle becomes a candidate box. Candidates are
    dropped when the contour area lies outside ``[min_area, max_area]`` or the
    width/height ratio lies outside ``aspect_ratio_range``. The surviving boxes
    from *all* masks are pooled and passed through non-maximum suppression,
    using the contour area as the score (larger regions win).

    Args:
        masks: Array whose first axis indexes the masks; non-zero = foreground.
        iou_threshold: IoU threshold handed to ``nms``.
        min_area: Smallest contour area (in pixels) to keep.
        max_area: Largest contour area (in pixels) to keep.
        aspect_ratio_range: ``(low, high)`` inclusive bounds on width / height.

    Returns:
        ``(boxes, scores)`` as float32 NumPy arrays. ``boxes`` has shape
        ``(K, 4)`` in ``[x1, y1, x2, y2]`` order and ``scores`` has shape
        ``(K,)``. When nothing survives the filters, ``K == 0`` and the arrays
        keep those shapes, so callers can index columns unconditionally.
    """
    boxes = []
    scores = []
    for i in range(masks.shape[0]):
        # findContours expects an 8-bit single-channel image.
        mask = masks[i].astype(np.uint8) * 255

        # Ensure the mask is 2D
        if mask.ndim > 2:
            mask = mask.squeeze()

        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        for cnt in contours:
            area = cv2.contourArea(cnt)
            if area < min_area or area > max_area:
                continue

            x, y, w, h = cv2.boundingRect(cnt)
            aspect_ratio = w / float(h)
            if not (aspect_ratio_range[0] <= aspect_ratio <= aspect_ratio_range[1]):
                continue

            # Contour area doubles as the NMS score.
            boxes.append([x, y, x + w, y + h])
            scores.append(area)

    if not boxes:
        # Keep shape and dtype consistent with the non-empty branch.
        return np.zeros((0, 4), dtype=np.float32), np.zeros((0,), dtype=np.float32)

    boxes = torch.as_tensor(boxes, dtype=torch.float32)
    scores = torch.as_tensor(scores, dtype=torch.float32)
    keep = nms(boxes, scores, iou_threshold)
    return boxes[keep].numpy(), scores[keep].numpy()


In [4]:
# Define paths
image_directory = '/Users/sahiljethani/Desktop/MLP/archive/preprocessed_train'

df_test = pd.read_csv('/Users/sahiljethani/Desktop/MLP/flattened_df_normalized.csv')

# Load the model configuration and the matching input processor
model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create an instance of the model architecture with the loaded configuration
# (no pretrained weights), then restore the fine-tuned checkpoint into it.
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
model = SamModel(config=model_config)
model.load_state_dict(torch.load("/Users/sahiljethani/Desktop/MLP/epochs_4_model_checkpoint.pth", map_location=torch.device('cpu')))

# Inference moves its inputs to `device`, so the model has to live there as
# well; otherwise a CUDA host fails with a device-mismatch error. The model is
# only used for prediction, hence eval mode.
model = model.to(device).eval()


In [5]:
# Add an empty placeholder column; makedataset() later writes each image's
# predicted boxes (as a stringified list) into it.
df_test['list_pred_bboxes_sam_grid_4epoch'] = None
df_test

In [6]:
def load_and_preprocess_image(image_path, image_id=None, output_path=None):
    """Load an image and replicate its channel data three times.

    The file is read, converted to a tensor, repeated three times along the
    channel axis and converted back to a PIL image. For a single-channel
    (grayscale) input this yields a 3-channel image.

    Args:
        image_path: Path of the image file to read.
        image_id: Optional base name (without extension) for the saved copy.
        output_path: Optional directory in which to save the copy as
            ``<image_id>.png``.

    Returns:
        The channel-replicated PIL image.

    Raises:
        ValueError: If only one of ``image_id`` / ``output_path`` is given.
    """
    # Saving needs both pieces of information; reject a half-specified request
    # instead of silently skipping the write.
    if (image_id is None) != (output_path is None):
        raise ValueError("image_id and output_path must be provided together")

    # The context manager releases the file handle; to_tensor copies the pixel
    # data, so the tensor remains valid after the file is closed.
    with Image.open(image_path) as source:  # Load grayscale image
        pixels = to_tensor(source)  # Convert to PyTorch tensor

    # Repeat the single channel to get 3 channels, then go back to PIL.
    image = to_pil_image(pixels.repeat(3, 1, 1))

    if image_id is not None:
        image.save(os.path.join(output_path, image_id + '.png'))
    return image

In [7]:


# Inference settings read (as globals) by makedataset().
mask_prob = 0.6      # sigmoid-probability cut-off used to binarise predicted masks
iou_threshold = 0.1  # IoU threshold forwarded to NMS in extract_bounding_boxes
min_area = 0.6       # minimum contour area forwarded to extract_bounding_boxes

# Folder containing the "<image_id>.png" files that makedataset() loads.
test_image_directory = '/Users/sahiljethani/Desktop/MLP/archive/preprocessed_test'


def makedataset(df,A,B):
    """Predict bounding boxes with SAM for the images referenced in rows ``A:B``.

    For each distinct ``image_id`` in ``df[A:B]`` the image
    ``<test_image_directory>/<image_id>.png`` is loaded, SAM is prompted with a
    5x5 point grid, and the predicted mask is thresholded at ``mask_prob``,
    refined, and converted to boxes. Box coordinates are divided by 256 and the
    resulting list is stored as a string in the
    ``'list_pred_bboxes_sam_grid_4epoch'`` column of every row of ``df`` with
    that ``image_id``.

    Relies on the module-level ``processor``, ``model``, ``device``,
    ``mask_prob``, ``iou_threshold``, ``min_area`` and ``test_image_directory``.

    Args:
        df: DataFrame with an ``image_id`` column; modified in place.
        A: Start of the row slice (inclusive).
        B: End of the row slice (exclusive).

    Returns:
        The same ``df`` object, with predictions filled in.
    """
    # Loop-invariant setup: evaluation mode and the prompt grid never change
    # between images, so do both once.
    model.eval()
    input_points = generate_input_points(grid_size=5)

    # An image_id may appear on several rows. Results are written to every
    # matching row anyway, so one inference per distinct id is sufficient.
    for image_id in df[A:B]['image_id'].unique():
        image_path = os.path.join(test_image_directory, f"{image_id}.png")
        image = load_and_preprocess_image(image_path)

        inputs = processor(images=image, input_points=input_points, return_tensors="pt")

        # Move inputs to the correct device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Forward pass (no gradients needed for inference)
        with torch.no_grad():
            outputs = model(**inputs, multimask_output=False)

        # Post-process the output: logits -> probabilities -> binary masks
        mask_probabilities = torch.sigmoid(outputs.pred_masks.squeeze(1)).cpu().numpy().squeeze(1)
        binary_masks = (mask_probabilities > mask_prob).astype(np.uint8)

        refined_masks = refine_masks(binary_masks)
        boxes, _ = extract_bounding_boxes(refined_masks, iou_threshold=iou_threshold, min_area=min_area, max_area=50000, aspect_ratio_range=(0.5, 2))

        # Divide every coordinate by 256 (presumably the side length of the
        # predicted masks, giving coordinates in [0, 1] - TODO confirm), then
        # serialise so the list fits in a single DataFrame cell.
        boxes = str((boxes / 256).tolist())

        # Store on the frame that was passed in, not on a module-level global.
        df.loc[df['image_id'] == image_id, 'list_pred_bboxes_sam_grid_4epoch'] = boxes
    return df


In [8]:
# Row range of df_test to process in this run
A = 0
B = 3000

save_path = '/Users/sahiljethani/Desktop/MLP/testcsv/bbox_eval_sam_{}_{}.csv'.format(A,B)
# Create the output folder *before* the long-running inference so the results
# are not lost to a missing-directory error when writing the CSV.
os.makedirs(os.path.dirname(save_path), exist_ok=True)

df_new = makedataset(df_test, A, B)
df_new.to_csv(save_path, index=False)