# Segment Anything Model (SAM): inference without fine-tuning

In [1]:
from skimage.segmentation import watershed, felzenszwalb
from skimage.filters import sobel
import pandas as pd
from pathlib import Path
import cv2
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from skimage.filters import rank
from scipy import ndimage as ndi
from skimage.morphology import disk
import sklearn.metrics

In [3]:
# Training labels. The CSV stores one column per image, hence the transpose:
# afterwards each row holds one image's flattened per-pixel labels.
data_dir = Path("./../../data/")
labels_train = (
    pd.read_csv(data_dir / "Y_train.csv", index_col=0)
    .transpose()
)

In [4]:
# Here is a function to load the data
def load_dataset(dataset_dir):
    """Load every ``<n>.png`` image of a directory as a grayscale stack.

    Images are ordered by the integer in their file name (0, 1, 2, ..., 10, ...)
    rather than lexicographically, so the stack order follows the file numbering.

    Parameters
    ----------
    dataset_dir : str or Path
        Directory containing the ``<n>.png`` files.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_images, height, width)`` built from the grayscale images.

    Raises
    ------
    ValueError
        If the directory contains no ``.png`` file, or if an image cannot be decoded.
    """
    # Numeric sort: a plain sort would give 0, 1, 10, 100, 2, ...
    # `stem` removes the extension; `rstrip(".png")` would strip any of the
    # characters '.', 'p', 'n', 'g' from the end rather than the suffix.
    image_files = sorted(Path(dataset_dir).glob("*.png"), key=lambda path: int(path.stem))
    if not image_files:
        raise ValueError(f"No .png images found in {dataset_dir}")

    images = []
    for image_file in image_files:
        image = cv2.imread(str(image_file), cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise ValueError(f"Could not read image {image_file}")
        images.append(image)
    return np.stack(images, axis=0)

In [6]:
# Image stacks for both splits, each shaped (n_images, height, width).
# Point `data_dir` somewhere else if the data folder is not two levels up.
data_dir = Path("./../../data/")
data_train = load_dataset(data_dir.joinpath("X_train"))
data_test = load_dataset(data_dir.joinpath("X_test"))

In [7]:
# Sanity check: images are a (n_images, 512, 512) array, labels a frame with one
# row per image and 262144 (= 512 * 512) flattened pixel columns.
print(
    f"X_train shape: {data_train.shape}",
    f"Y_train shape: {labels_train.shape}",
    sep="\n",
)

X_train shape: (1000, 512, 512)
Y_train shape: (1000, 262144)


# Here we work on the X_train images, which are labelled

In [30]:
import os
import shutil

# Number of labelled train images copied into the evaluation folder.
N_GROUND_TRUTH_IMAGES = 200

# Folder containing the train images
path_to_X_train = './../../data/X_train/'

# Destination folder for the first N_GROUND_TRUTH_IMAGES images.
# `exist_ok=True` makes the cell safe to re-run (os.mkdir raises if the folder exists).
path_to_X_train_ground_truth = './../../data/X_train_ground_truth'
os.makedirs(path_to_X_train_ground_truth, exist_ok=True)

# Copy images 0 .. N_GROUND_TRUTH_IMAGES - 1.
# NOTE: files left in the destination by a previous run are not removed.
for i in range(N_GROUND_TRUTH_IMAGES):
    image_path = os.path.join(path_to_X_train, f'{i}.png')
    ground_truth_path = os.path.join(path_to_X_train_ground_truth, f'{i}.png')
    shutil.copyfile(image_path, ground_truth_path)

In [8]:
import os

from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

# Path to the SAM ViT-L checkpoint. Set the SAM_CHECKPOINT environment variable to use
# another location instead of editing this machine-specific absolute path.
SAM_CHECKPOINT = os.environ.get(
    "SAM_CHECKPOINT",
    "/Users/sarrabenyahia/Documents/GitHub/MedSAM/checkpoints_sam/sam_vit_l_0b3195.pth",
)
sam = sam_model_registry["vit_l"](checkpoint=SAM_CHECKPOINT)

In [31]:
import os
import cv2
import numpy as np
import pandas as pd

# Input folder for inference. Despite its name, `x_test_dir` points at the labelled
# X_train subset: SAM is not trained on these images, and they are the only ones with a
# ground truth to compute metrics against. Change it to run on another folder.
# Both paths keep their trailing slash because file names are appended to them by
# plain string concatenation later on.
x_test_dir, predictions_dir = (
    './../../data/X_train_ground_truth/',
    './../../data/predictions/inference_sam/',
)

# Output folder for predictions: created on the first run, reused afterwards.
os.makedirs(predictions_dir, exist_ok=True)


In [32]:

# Build the SAM model (the network the mask generator wraps, not the generator itself).
# It is normally already built by the earlier loading cell. Rebuilding from the ViT-L
# checkpoint is slow and memory-hungry, so only do it when `sam` is not defined yet.
if "sam" not in globals():
    sam = sam_model_registry["vit_l"](
        checkpoint=os.environ.get(
            "SAM_CHECKPOINT",
            "/Users/sarrabenyahia/Documents/GitHub/MedSAM/checkpoints_sam/sam_vit_l_0b3195.pth",
        )
    )


In [33]:
import csv
import json


def _mask_to_rle(mask):
    """Encode a 2-D boolean mask as uncompressed run-length encoding.

    Uses the layout of SAM's ``output_mode="uncompressed_rle"``:
    ``{"size": [height, width], "counts": [...]}``. The mask is flattened in
    column-major (Fortran) order, and ``counts`` alternate between runs of False
    and True, always starting with a (possibly empty) run of False.
    """
    height, width = mask.shape
    flat = np.asarray(mask, dtype=bool).ravel(order="F")
    # Each position where the value changes starts a new run.
    run_starts = np.flatnonzero(flat[1:] != flat[:-1]) + 1
    counts = np.diff(np.concatenate(([0], run_starts, [flat.size]))).tolist()
    if flat.size and flat[0]:
        counts.insert(0, 0)  # the first run always counts False pixels
    return {"size": [height, width], "counts": counts}


# Only the numbered PNGs, in numeric order. os.listdir returns entries in arbitrary
# order and may include non-image files (e.g. macOS .DS_Store) that cv2.imread cannot decode.
image_filenames = sorted(
    (name for name in os.listdir(x_test_dir) if name.endswith(".png")),
    key=lambda name: int(os.path.splitext(name)[0]),
)

# The generator is reusable across images: build it once, not once per image.
mask_generator = SamAutomaticMaskGenerator(sam)

for filename in tqdm(image_filenames, desc="SAM inference"):
    # Read the image; OpenCV loads BGR, the mask generator expects RGB.
    image_path = os.path.join(x_test_dir, filename)
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Could not read image {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    masks = mask_generator.generate(image)

    # One CSV per image, one row per predicted mask.
    with open(f'{predictions_dir}{filename}.csv', 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['segmentation', 'area', 'bbox', 'predicted_iou', 'point_coords', 'stability_score', 'crop_box'])
        for mask in masks:
            # Writing the boolean array directly would store numpy's *summarised* repr
            # ("[[False False ... False]]") and lose the mask, so store JSON RLE instead.
            writer.writerow([
                json.dumps(_mask_to_rle(mask['segmentation'])),
                mask['area'],
                mask['bbox'],
                mask['predicted_iou'],
                mask['point_coords'],
                mask['stability_score'],
                mask['crop_box'],
            ])



# TODO: visualise the SAM inference against the ground truth

In [None]:
    # Plot the image and mask
    plt.figure(figsize=(7, 7))
    plt.imshow(image)
    show_anns(masks)
    plt.axis('off')
    plt.show()

**TODO:** add metrics comparing the SAM predictions with the ground truth.