<a href="https://colab.research.google.com/github/ttssibley/MSPP/blob/main/SAM_for_Microstructural_Analysis_Loops_Example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Print the NVIDIA GPU status for this runtime. A GPU runtime is recommended
# if available; the model below falls back to CPU when CUDA is not present.
!nvidia-smi

In [None]:
import os
import torch
import torchvision
import cv2
from google.colab import files
import math
import matplotlib.pyplot as plt
import json
import numpy as np
from sklearn.metrics import f1_score, jaccard_score
from PIL import Image
import pandas as pd
import csv

In [None]:
# Working directory of the notebook; every path below is built from it.
HOME = os.getcwd()
print("HOME:", HOME)

# Install segment-anything from GitHub, plus helper packages.
# This must run before the `segment_anything` import further down.
!pip install -q 'git+https://github.com/facebookresearch/segment-anything'
!pip install -q jupyter_bbox_widget roboflow dataclasses-json supervision

# Download the ViT-H SAM checkpoint into {HOME}/weights.
!mkdir -p {HOME}/weights
!wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -P {HOME}/weights
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
CHECKPOINT_PATH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth")

# Create the working folders. The evaluation cells read images from
# {HOME}/images; the other folders are not referenced in the visible cells.
!mkdir -p {HOME}/data
!mkdir -p {HOME}/labels
!mkdir -p {HOME}/images
!mkdir -p {HOME}/masks

# Use the first CUDA device when available, otherwise the CPU.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Registry key; must match the architecture of the checkpoint downloaded above.
MODEL_TYPE = "vit_h"
# Build the model from the checkpoint and move it to DEVICE.
sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device=DEVICE)

HOME: /content
  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for segment_anything (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m88.7/88.7 kB[0m [31m6.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m66.8/66.8 kB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m49.9/49.9 MB[0m [31m16.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m207.2/207.2 kB[0m [31m16.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.9/50.9 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m65.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.2/4.2 MB[0m [31m68.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━

In [None]:
def calculate_iou(binary_array1, binary_array2):
    """Return the intersection-over-union of two binary masks.

    Any non-zero element counts as foreground, matching the behaviour of
    ``np.logical_and`` / ``np.logical_or``. The inputs must be
    broadcast-compatible.

    Args:
        binary_array1: First mask (array-like).
        binary_array2: Second mask (array-like).

    Returns:
        Intersection pixel count divided by union pixel count. When both
        masks are empty the union is zero; 0.0 is returned in that case
        instead of the NaN (plus RuntimeWarning) that 0/0 would produce.
    """
    mask_a = np.asarray(binary_array1).astype(bool)
    mask_b = np.asarray(binary_array2).astype(bool)

    union = np.count_nonzero(mask_a | mask_b)
    if union == 0:
        # Nothing is foreground in either mask: avoid dividing by zero.
        return 0.0

    intersection = np.count_nonzero(mask_a & mask_b)
    return intersection / union

def calculate_f1_score(ground_truth, predicted_labels):
    """Compute F1, precision and recall for binary labels.

    Only elements exactly equal to 1 are treated as positive and only
    elements exactly equal to 0 as negative; any other value is ignored
    by all three counts.

    Args:
        ground_truth: Reference labels (array-like of 0/1).
        predicted_labels: Predicted labels, same shape as ``ground_truth``.

    Returns:
        Tuple ``(f1, precision, recall, true_positives, false_positives,
        false_negatives)``. A small epsilon in each denominator keeps the
        ratios finite when a count is zero.
    """
    eps = 1e-10
    truth = np.array(ground_truth)
    guess = np.array(predicted_labels)

    # Boolean views reused by the three confusion-matrix counts.
    truth_on, truth_off = truth == 1, truth == 0
    guess_on, guess_off = guess == 1, guess == 0

    true_positives = np.sum(truth_on & guess_on)
    false_positives = np.sum(truth_off & guess_on)
    false_negatives = np.sum(truth_on & guess_off)

    precision = true_positives / (true_positives + false_positives + eps)
    recall = true_positives / (true_positives + false_negatives + eps)

    harmonic_mean = 2 * (precision * recall) / (precision + recall + eps)
    return (
        harmonic_mean,
        precision,
        recall,
        true_positives,
        false_positives,
        false_negatives,
    )

def make_gt_mask(gt_segmentations, combined_mask):
    """Rasterise ground-truth polygons into ``combined_mask`` in place.

    Each entry of ``gt_segmentations`` is converted to an int32 array,
    reshaped to ``(-1, 2)`` vertex pairs and filled with the value 255.
    Afterwards every element equal to 255 is rewritten to 1.

    Args:
        gt_segmentations: Iterable of polygon coordinate lists.
        combined_mask: Array to draw into; it is modified and returned.

    Returns:
        The same ``combined_mask`` object.
    """
    fill_colour = (255, 255, 255)
    for coords in gt_segmentations:
        vertices = np.array(coords, dtype=np.int32).reshape((-1, 2))
        # One fillPoly call per annotation, exactly as many as there are entries.
        cv2.fillPoly(combined_mask, [vertices], color=fill_colour)

    # Collapse the 255 fill value to a 0/1 labelling.
    np.putmask(combined_mask, combined_mask == 255, 1)
    return combined_mask


def find_particle_centers(binary_array):
    """Return the bounding-box centre of every external contour in a mask.

    Args:
        binary_array: Array with an ``astype`` method; it is cast to uint8
            before contour extraction.

    Returns:
        List of ``(x_center, y_center)`` tuples, one per external contour,
        where each centre is the midpoint of that contour's upright
        bounding rectangle.
    """
    mask_u8 = binary_array.astype(np.uint8)
    outlines, _hierarchy = cv2.findContours(
        mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    rectangles = (cv2.boundingRect(outline) for outline in outlines)
    return [
        (left + box_w / 2, top + box_h / 2)
        for left, top, box_w, box_h in rectangles
    ]

def parse_line(line):
    """Parse one label line of the form ``<class> <x> <y> <w> <h>``.

    The five fields are whitespace-separated numbers. The class field is
    read but deliberately discarded: every box is reported as class 1,
    exactly as the previous implementation did after overwriting the
    parsed value.

    Args:
        line: One text line; surrounding whitespace is ignored.

    Returns:
        Tuple ``(class_label, x, y, w, h)`` with ``class_label`` always 1
        and the remaining fields as floats.

    Raises:
        ValueError: If a field is not numeric or the line does not contain
            exactly five fields (this includes blank lines).
    """
    # Convert once; the old code converted to float twice and stored an
    # int class label that was immediately overwritten.
    fields = [float(token) for token in line.split()]
    _class_field, x, y, w, h = fields
    return 1, x, y, w, h

def generate_bbs_from_txt(txt_file, image_width, image_height):
    """Read a label file and return pixel boxes plus a filled box mask.

    Each non-blank line is parsed with ``parse_line`` and treated as a
    centre-based box ``(x, y, w, h)`` whose values are fractions of the
    image size. The fractions are scaled to pixels, converted to corner
    coordinates and clipped to the valid index range.

    Args:
        txt_file: Path of the label file.
        image_width: Image width in pixels.
        image_height: Image height in pixels.

    Returns:
        Tuple ``(boxes, mask)`` where ``boxes`` is a list of
        ``(x1, y1, x2, y2)`` integer tuples and ``mask`` is a uint8 array
        of shape ``(image_height, image_width)`` with each box region set
        to the class label returned by ``parse_line``.
    """
    bb_gt_yolo = []
    mask = np.zeros((image_height, image_width), dtype=np.uint8)

    with open(txt_file, 'r') as file:
        for line in file:
            # Blank lines (e.g. a trailing newline) carry no box; skip them
            # instead of letting the parser fail.
            if not line.strip():
                continue

            # Parse once; the line used to be parsed twice.
            class_label, x, y, w, h = parse_line(line)

            # Normalised -> pixel units (truncating, as before).
            x = int(x * image_width)
            y = int(y * image_height)
            w = int(w * image_width)
            h = int(h * image_height)

            # Centre/size -> corners, clipped to valid pixel indices.
            x1 = int(np.clip(int(x - w / 2), 0, image_width - 1))
            y1 = int(np.clip(int(y - h / 2), 0, image_height - 1))
            x2 = int(np.clip(int(x + w / 2), 0, image_width - 1))
            y2 = int(np.clip(int(y + h / 2), 0, image_height - 1))

            bb_gt_yolo.append((x1, y1, x2, y2))
            mask[y1:y2, x1:x2] = class_label

    return bb_gt_yolo, mask

In [None]:
folder_path = '/content/images'

# For updated_test.json, use the file provided by Jacobs et al.
with open(os.path.join(HOME, "updated_test.json")) as f:
    bbox_dictionary = json.load(f)

for filename in os.listdir(folder_path):
    # Jupyter may create this folder inside the image directory.
    if filename == ".ipynb_checkpoints":
        continue

    IMAGE_NAME = filename.split(".")[0]
    # Read the file that was actually listed rather than rebuilding a
    # "<name>.jpg" path under a different root.
    IMAGE_PATH = os.path.join(folder_path, filename)
    image_bgr = cv2.imread(IMAGE_PATH)
    if image_bgr is None:
        # cv2.imread returns None for anything it cannot decode.
        print(f"Skipping {filename}: not a readable image")
        continue
    if IMAGE_NAME not in bbox_dictionary:
        print(f"Skipping {filename}: no annotations in updated_test.json")
        continue

    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    # cv2.imread yields BGR, so the grayscale conversion must use the BGR
    # code (the RGB code would swap the red/blue channel weights).
    grayscale_image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2GRAY)

    annotations = bbox_dictionary[IMAGE_NAME]
    bounding_boxes_gt = [item['bbox'] for item in annotations]
    segmentations_gt = [item['segmentation'] for item in annotations]
    bounding_boxes_gt_corrected = []
    image_height, image_width, _ = image_rgb.shape

    for segmentation in segmentations_gt:
        # The first polygon is read as interleaved x, y values.
        x_coordinates = segmentation[0][::2]
        y_coordinates = segmentation[0][1::2]
        width = max(x_coordinates) - min(x_coordinates)
        height = max(y_coordinates) - min(y_coordinates)
        # NOTE(review): width/height are computed but nothing is appended to
        # bounding_boxes_gt_corrected, so the tensor built below is empty.
        # Presumably a box derived from the polygon extent was meant to be
        # stored here -- confirm the intended box format before adding it.

# Input thresholds (used by the evaluation cell that follows).
slope_cutoff = 1  # example, set as needed
area_cutoff = 50  # example, set as needed
aspect_ratio_cutoff_H = .8  # example, set as needed
aspect_ratio_cutoff_L = .1  # example, set as needed
circularity_cutoff = .7  # example, set as needed
bounding_boxes_gt_tensor = torch.tensor(bounding_boxes_gt_corrected)



In [None]:
folder_path = '/content/images'

# input points per side
_points_ = 50  # example, set as needed

# Post-filter thresholds. aspect_ratio_cutoff_L / aspect_ratio_cutoff_H and
# circularity_cutoff are taken from the thresholds cell above.
area_cutoff = 5400  # example, set as needed
# This name was used in the filter below but never defined (NameError).
# 0.80 matches pred_iou_thresh in the generator, so it should not reject
# anything extra until it is raised.
predicted_iou_cutoff = 0.80  # example, set as needed

# Refer to the Jacobs et al. (2022) paper for the .json files.
with open(os.path.join(HOME, "updated_test.json")) as f:
    bbox_dictionary = json.load(f)

# The generator configuration does not depend on the image, so build it once.
# The former `randpoints=False` argument was removed: it is not a parameter of
# SamAutomaticMaskGenerator in the upstream segment-anything package installed
# above and raises TypeError there.
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=_points_,
    pred_iou_thresh=0.80,  # example, set as needed
    stability_score_thresh=0.91,  # example, set as needed
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=50,  # example, set as needed
)

for filename in os.listdir(folder_path):
    if filename == ".ipynb_checkpoints":
        continue

    IMAGE_NAME = filename.split(".")[0]
    # Read the file that was actually listed.
    IMAGE_PATH = os.path.join(folder_path, filename)
    image_bgr = cv2.imread(IMAGE_PATH)
    if image_bgr is None:
        # cv2.imread returns None for anything it cannot decode.
        print(f"Skipping {filename}: not a readable image")
        continue
    if IMAGE_NAME not in bbox_dictionary:
        print(f"Skipping {filename}: no annotations in updated_test.json")
        continue
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

    annotations = bbox_dictionary[IMAGE_NAME]
    bounding_boxes_gt = [item['bbox'] for item in annotations]
    image_height, image_width, _ = image_rgb.shape
    # reshape(-1, 4) keeps a valid (0, 4) shape when there are no boxes.
    # NOTE(review): torchvision.ops.box_iou expects (x1, y1, x2, y2); confirm
    # the JSON 'bbox' entries are stored in that format and not (x, y, w, h).
    bounding_boxes_gt_tensor = torch.tensor(
        bounding_boxes_gt, dtype=torch.float32
    ).reshape(-1, 4)

    sam_result = mask_generator.generate(image_rgb)

    bounding_boxes_predicted = []
    # Keyed by position in bounding_boxes_predicted so that the NMS indices
    # (which index that list) look up the matching mask. Keying by the raw
    # SAM result index broke as soon as any mask was filtered out.
    bb_dictionary = {}

    for result in sam_result:
        x, y, w, h = result['bbox']
        x1, y1, x2, y2 = x, y, x + w, y + h
        area = result['area']
        segmentation = result['segmentation']
        pred_iou = result['predicted_iou']

        # Geometric features: extent of the foreground pixels, vectorised
        # instead of enumerating every pixel in Python.
        y_coordinates, x_coordinates = np.nonzero(segmentation)
        if x_coordinates.size == 0:
            continue
        width = int(x_coordinates.max() - x_coordinates.min())
        height = int(y_coordinates.max() - y_coordinates.min())
        aspect_ratio = width / height if height > 0 else 0

        currmask = segmentation.astype(np.uint8) * 255
        contours, _ = cv2.findContours(currmask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        perimeter = cv2.arcLength(contours[0], True) if contours else 0
        # NOTE(review): this is perimeter / sqrt(4*pi*area), not the more
        # common 4*pi*area / perimeter**2. It is about 1 for a disc and
        # larger for elongated shapes, so `< circularity_cutoff` with a
        # cutoff below 1 may only admit very small masks -- confirm intent.
        circularity = perimeter / (math.sqrt((4 * np.pi * area))) if area > 0 else 0

        if area < area_cutoff and aspect_ratio_cutoff_L < aspect_ratio < aspect_ratio_cutoff_H and pred_iou > predicted_iou_cutoff and circularity < circularity_cutoff:
            bb_dictionary[len(bounding_boxes_predicted)] = {
                'bbox': (x1, y1, x2, y2),
                'circularity': circularity,
                'area': area,
                'iou': result['predicted_iou'],
                'stability_score': result['stability_score'],
                'aspect_ratio': aspect_ratio,
                'segmentation': segmentation,
            }
            bounding_boxes_predicted.append((x1, y1, x2, y2))

    bounding_boxes_predicted_tensor = torch.tensor(
        bounding_boxes_predicted, dtype=torch.float32
    ).reshape(-1, 4)
    num_predicted = bounding_boxes_predicted_tensor.shape[0]

    iou = torchvision.ops.box_iou(bounding_boxes_predicted_tensor, bounding_boxes_gt_tensor)
    if iou.numel() > 0:
        iou_scores, _ = torch.max(iou, dim=1)
    else:
        # No predictions or no ground truth: there is nothing to take a max over.
        iou_scores = torch.zeros(num_predicted)

    # NMS ranks boxes by their best IoU with the ground truth (as before).
    if num_predicted > 0:
        nms_indices = torchvision.ops.nms(bounding_boxes_predicted_tensor, iou_scores.float(), iou_threshold=0.1)
    else:
        nms_indices = torch.empty(0, dtype=torch.long)
    bounding_boxes_predicted_nms = bounding_boxes_predicted_tensor[nms_indices]

    iou = torchvision.ops.box_iou(bounding_boxes_predicted_nms, bounding_boxes_gt_tensor)
    iou_array = iou.numpy()

    # pixelwise mask
    gt_segmentations = [item['segmentation'] for item in annotations]
    gtmask = make_gt_mask(gt_segmentations, np.zeros((image_height, image_width), dtype=np.uint8))
    pred_union = np.zeros((image_height, image_width), dtype=bool)
    for i in nms_indices.tolist():
        pred_union |= np.asarray(bb_dictionary[i]['segmentation'], dtype=bool)
    pred_mask = pred_union.astype(np.uint8)

    # object-wise metrics
    thresh = 0.1
    num_kept, num_gt = iou_array.shape
    # A max over a zero-length axis raises, so fall back to zeros: with no
    # ground truth every prediction is a false positive, and with no
    # predictions every ground-truth object is a false negative.
    best_iou_per_prediction = iou_array.max(axis=1) if num_gt > 0 else np.zeros(num_kept)
    best_iou_per_gt = iou_array.max(axis=0) if num_kept > 0 else np.zeros(num_gt)
    tps_ow = int(np.sum(best_iou_per_prediction > thresh))
    fps_ow = int(np.sum(best_iou_per_prediction <= thresh))
    fns_ow = int(np.sum(best_iou_per_gt < thresh))

    precision_ow = tps_ow / (tps_ow + fps_ow + 1e-9)
    recall_ow = tps_ow / (tps_ow + fns_ow + 1e-9)
    f1_ow = 2 * (precision_ow * recall_ow) / (precision_ow + recall_ow + 1e-9)

    f1_pw, precision_pw, recall_pw, tps_pw, fps_pw, fns_pw = calculate_f1_score(gtmask, pred_mask)

    print(IMAGE_NAME, f"Objectwise F1: {f1_ow:.3f}, Pixelwise F1: {f1_pw:.3f}")

    # visualization
    image_with_boxes = image_bgr.copy()
    for box in bounding_boxes_predicted_nms.numpy():
        x1, y1, x2, y2 = map(int, box)
        cv2.rectangle(image_with_boxes, (x1, y1), (x2, y2), (0, 255, 0), 2)

    image_with_boxes_rgb = cv2.cvtColor(image_with_boxes, cv2.COLOR_BGR2RGB)
    plt.figure(figsize=(8, 6))
    plt.axis('off')
    plt.title(f'{IMAGE_NAME}')
    plt.imshow(image_with_boxes_rgb)
    plt.show()
