In [1]:
import json
import os
import sys

import cv2
import matplotlib.pyplot as plt
import numpy as np
from segment_anything import sam_model_registry, SamPredictor

import geometry
import platonics
import scale_transform
import segmentfunction


In [None]:
# configuration
# Path to the SAM ViT-H checkpoint; set the SAM_CHECKPOINT environment variable to
# override it without editing the notebook. In a raw string a single backslash is
# already a literal separator, so it must not be doubled.
checkpointfilepath = os.environ.get(
    "SAM_CHECKPOINT", r"C:\Users\aarus\Downloads\sam_vit_h_4b8939.pth"
)
num_slices = 5  # number of evenly spaced z-slices segmented per transform in the main loop
transform_list = platonics.get_cube_transforms()

VOLUME_SIZE = 128

# test_1: a box of intensity 200 inside an otherwise empty 128^3 volume.
test_1 = np.zeros((VOLUME_SIZE, VOLUME_SIZE, VOLUME_SIZE))
test_1[32:96, 48:80, 32:96] = 200

# test_2: a solid ball of ones (radius 50) centred at voxel (64, 64, 64), zeros outside.
# Vectorised with open grids instead of a triple Python loop over every voxel;
# comparing squared distance with 50**2 is equivalent to comparing sqrt(d2) > 50.
SPHERE_CENTER = 64
SPHERE_RADIUS = 50
grid_i, grid_j, grid_k = np.ogrid[:VOLUME_SIZE, :VOLUME_SIZE, :VOLUME_SIZE]
dist2 = (grid_i - SPHERE_CENTER) ** 2 + (grid_j - SPHERE_CENTER) ** 2 + (grid_k - SPHERE_CENTER) ** 2
test_2 = np.ones((VOLUME_SIZE, VOLUME_SIZE, VOLUME_SIZE))
test_2[dist2 > SPHERE_RADIUS ** 2] = 0

plt.imshow(test_1[:, :, 64])
plt.title("test_1, z-slice 64")
plt.show()

In [3]:
def get_prompt_slices(image, transforms=None, out_dir='slices_for_prompting'):
    """Save the central z-slice of `image` under each transform as a PNG for prompting.

    For every transform, the volume is mapped into that transform's local frame with
    ``scale_transform.global_to_local`` and the slice at ``z = shape[2] // 2`` is written
    to ``{out_dir}/slice_{a}.png``, where ``a`` is the transform's position in the list.

    Args:
        image: 3-D volume to slice.
        transforms: transforms to apply; defaults to the notebook-level ``transform_list``.
        out_dir: directory that receives the PNGs; created if it does not exist.

    Returns:
        A list with one dict per transform, in transform order, with keys
        'idx' (the sliced z index), 'transform' (the transform) and
        'shape' (shape of the transformed volume).

    Raises:
        IOError: if OpenCV reports that a slice image could not be written.
    """
    if transforms is None:
        transforms = transform_list
    # cv2.imwrite does not create missing directories; it just returns False.
    os.makedirs(out_dir, exist_ok=True)

    slices_list = []
    for a, t in enumerate(transforms):
        transformed_img = scale_transform.global_to_local(image, t)
        mid_z = transformed_img.shape[2] // 2
        path = os.path.join(out_dir, f'slice_{a}.png')
        # imwrite signals failure through its return value rather than raising.
        if not cv2.imwrite(path, transformed_img[:, :, mid_z]):
            raise IOError(f'cv2.imwrite failed to write {path}')
        slices_list.append({
            'idx': mid_z,
            'transform': t,
            'shape': transformed_img.shape,
        })

    return slices_list

def _polylines_to_segments(polylines, idx, transform, shape):
    """Map the 2-D polylines drawn on one prompt slice to consecutive global segments.

    Each point's first two entries are combined with the slice index `idx`, then passed
    through ``scale_transform.scale_forward`` and ``scale_transform.coord_to_index``.
    Every pair of consecutive mapped points in a polyline becomes one ``[start, end]``.
    """
    segments = []
    for line in polylines:
        global_line = [
            scale_transform.coord_to_index(
                scale_transform.scale_forward(list(point[:2]) + [idx], shape),
                transform,
                shape,
            )
            for point in line
        ]
        segments.extend([start, end] for start, end in zip(global_line, global_line[1:]))
    return segments


def get_line_segments(slices_list, pos_polylines_slices, neg_polylines_slices):
    """Convert per-slice positive/negative polylines into global line segments.

    Args:
        slices_list: slice descriptors as returned by ``get_prompt_slices`` (dicts with
            'idx', 'shape' and 'transform').
        pos_polylines_slices: positive polylines per slice, indexed like ``slices_list``.
        neg_polylines_slices: negative polylines per slice, indexed like ``slices_list``.

    Returns:
        (pos_seg, neg_seg): lists of ``[start, end]`` segments in global index space.
    """
    pos_seg = []
    neg_seg = []

    # A distinct name for the slice index matters: the original reused `i` for the
    # inner segment loop, so negative polylines were read from the wrong slice.
    for slice_no, s in enumerate(slices_list):
        idx = s['idx']
        shape = s['shape']
        transform_curr = s['transform']

        pos_seg.extend(_polylines_to_segments(pos_polylines_slices[slice_no], idx, transform_curr, shape))
        neg_seg.extend(_polylines_to_segments(neg_polylines_slices[slice_no], idx, transform_curr, shape))

    return pos_seg, neg_seg

def get_intersections(matrix_shape, pos_seg, neg_seg, t, z):
    """Find where the positive and negative global segments cross the local plane at `z`.

    Both endpoints of each segment are mapped into the frame of transform `t` with
    ``scale_transform.index_to_coord`` followed by ``scale_transform.scale_backward``.
    ``scale_transform.get_intersection_point`` then computes the crossing with `z`.
    Segments for which it returns a falsy value are skipped.

    Returns:
        (pos_intersects, neg_intersects): the first two components of every
        intersection found, in segment order.
    """
    def to_local(index_point):
        coord = scale_transform.index_to_coord(index_point, t, matrix_shape)
        return scale_transform.scale_backward(coord, matrix_shape)

    def crossings(segments):
        hits = []
        for seg in segments:
            start_local = to_local(seg[0])
            end_local = to_local(seg[1])
            hit = scale_transform.get_intersection_point(start_local, end_local, z)
            if hit:
                hits.append(hit[:2])
        return hits

    return crossings(pos_seg), crossings(neg_seg)

def normalize(image):
    """Min-max scale a 2-D image to 0..255 uint8 and replicate it into 3 channels.

    Args:
        image: 2-D array of intensities.

    Returns:
        uint8 array of shape ``image.shape + (3,)`` with the scaled image in every channel.
        A constant (or all-NaN) image maps to all zeros instead of dividing by zero.
    """
    image = np.asarray(image)
    lo = np.min(image)
    span = np.max(image) - lo
    if span > 0:
        scaled = (image - lo) / span * 255
    else:
        # Flat slice (e.g. an empty region of the volume): avoid 0/0 -> NaN, whose
        # cast to uint8 is undefined.
        scaled = np.zeros(image.shape)
    gray = scaled.astype(np.uint8)
    return np.stack([gray, gray, gray], axis=2)


In [7]:
# Write one central prompting slice per cube transform and keep each slice's
# z index, transform and transformed-volume shape for mapping prompts back.
slices_list = get_prompt_slices(image=test_1)

In [None]:
# Load the user's prompt polylines. The JSON holds one entry per prompt slice, in the
# same order as slices_list, each with 'pos_polylines' and 'neg_polylines' lists.
with open('test_data.json', 'r', encoding='utf-8') as file:
    prompt_points = json.load(file)

# get_line_segments indexes the polyline lists by slice position, so the counts must agree.
assert len(prompt_points) == len(slices_list), (
    f"test_data.json has {len(prompt_points)} prompt entries but "
    f"{len(slices_list)} prompt slices were generated"
)

pos_polylines_slices = [prompt['pos_polylines'] for prompt in prompt_points]
neg_polylines_slices = [prompt['neg_polylines'] for prompt in prompt_points]

# Map the per-slice 2-D polylines to global line segments.
pos_seg, neg_seg = get_line_segments(slices_list, pos_polylines_slices, neg_polylines_slices)

In [None]:
# Build the Segment Anything predictor from the ViT-H checkpoint configured above.
# No .to(device) call is made here, so the model is not moved to a GPU; to use CUDA,
# call sam.to(device="cuda") before constructing the predictor.
model_type = "vit_h"
sam_checkpoint = checkpointfilepath
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
predictor = SamPredictor(sam)

In [None]:
# Main loop: for every cube transform, segment `num_slices` evenly spaced z-slices of
# the transformed volume. SAM is prompted with the points where the global
# positive/negative polyline segments cross each slice plane.
prompting_slices_dict = dict()  # (transform index, z) -> return value of segmentfunction.segment
count = 0  # number of slices actually passed to segmentfunction.segment

for transform_idx, t in enumerate(transform_list):
    transformed_img = scale_transform.global_to_local(test_1, t)
    matrix_shape = transformed_img.shape

    # num_slices interior z indices; dropping the first linspace value skips z = 0.
    z_coord_list = np.linspace(0, matrix_shape[2], num_slices + 1, endpoint=False, dtype=int)[1:]

    for z in z_coord_list:
        slice_transformed_img = transformed_img[:, :, z]
        pos_intersects, neg_intersects = get_intersections(matrix_shape, pos_seg, neg_seg, t, z)

        # Only segment slices crossed by at least one positive segment.
        if not pos_intersects:
            continue

        prompt = [pos_intersects, neg_intersects]
        # NOTE(review): the meaning of the third argument (0) is defined in segmentfunction — confirm.
        prompting_slices_dict[(transform_idx, int(z))] = segmentfunction.segment(
            predictor, normalize(slice_transformed_img), 0, prompt
        )
        count += 1
