In [2]:
import torch
import time
import numpy as np
import skimage

from fast_slic import Slic
from fast_slic.avx2 import SlicAvx2
from PIL import Image
import cv2

In [51]:
def original_segment(x, n_segments=196):
    B, C, H, W = x.shape
    x = x.permute(0, 2, 3, 1) # Change channels to be last dimension
    x = x.numpy()

     # Iterate over each image in batch and get segmentation mask
    start_time = time.time()
    save_mask = np.zeros((B, H, W))
    for i, img in enumerate(x):
        cp_img = np.squeeze(img)

        segmentation_mask = skimage.segmentation.slic(cp_img, n_segments=n_segments, start_label=0, min_size_factor=0)
        save_mask[i, :, :] = segmentation_mask

    print(f"Time to get segmentation masks (OG): {time.time() - start_time}")
    return save_mask

def test_segment(x, n_segments=196):
    B, C, H, W = x.shape
    x = x.permute(0, 2, 3, 1) # Change channels to be last dimension
    x = x.numpy()

     # Iterate over each image in batch and get segmentation mask
    start_time = time.time()
    save_mask = np.zeros((B, H, W))
    for i, img in enumerate(x):
        cp_img = np.squeeze(img)

        # slic = Slic(num_components=n_segments, min_size_factor=0)
        slic = SlicAvx2(num_components=n_segments, min_size_factor=0)
        segmentation_mask = slic.iterate(cp_img)
        save_mask[i, :, :] = segmentation_mask
        # print(len(np.unique(segmentation_mask)))

    print(f"Time to get segmentation masks (NEW): {time.time() - start_time}")
    return save_mask


def original_ft(x, save_mask, n_segments=196, n_points=64, grayscale=True):

    x = x.permute(0, 2, 3, 1)
    x = x.cpu().numpy()
    B = x.shape[0]

    start_time = time.time()
    
    seg_out = torch.zeros((B, n_segments, n_points * n_points * 2))
    pos_out = torch.zeros((B, n_segments, 5))

    ft_time = time.time()

    for j, img in enumerate(x):

        unique_integers = np.unique(save_mask[j], return_counts=True)

        # Iterate over number of tokens we want
        for i in range(n_segments):                    
            # If token ID not in mask, we must set to [PAD]
            if i not in unique_integers[0]:
                seg_out[j, i, :] = torch.zeros((1, 1, n_points*n_points*2))
                pos_out[j, i, :] = torch.zeros((1, 1, 5))

            # Else, token exists and we must extract freq content & positional embeddings
            else:                        
                # Get each segment and take FT. Unroll and save
                binary_mask = (save_mask[j] == i)
                segmented_img = img * np.expand_dims(binary_mask, axis=-1)

                # Convert to grayscale if specified
                if grayscale:
                    cp_img = cv2.cvtColor(segmented_img, cv2.COLOR_BGR2GRAY)

                # Take FT
                fourier_transform = np.fft.fft2(cp_img, s=(n_points, n_points))

                # Extract magnitude and phase information
                magnitude = np.abs(fourier_transform)
                phase = np.angle(fourier_transform)

                # Save FT info
                to_save = np.stack((magnitude, phase)).flatten(order="F")                    
                assert to_save.shape[0] == seg_out.shape[2]
                seg_out[j, i, :] = torch.from_numpy(to_save)

                #### Get position embedding info (static) 
                # Area
                area = np.sum(binary_mask) / binary_mask.size

                # Center (Average)
                center_x = np.average(np.where(binary_mask)[1]) / binary_mask.shape[1]
                center_y = np.average(np.where(binary_mask)[0]) / binary_mask.shape[0]

                # Width/Height
                # print(np.where(binary_mask)[1].shape)
                width = (np.max(np.where(binary_mask)[1]) - np.min(np.where(binary_mask)[1])) / binary_mask.shape[1]
                height = (np.max(np.where(binary_mask)[0]) - np.min(np.where(binary_mask)[0])) / binary_mask.shape[0]

                # Store in array and convert to tensor on device
                pos_save = torch.from_numpy(np.array([area, center_x, center_y, width, height])) #TODO: make this device (how does this work with DataParallel)
                pos_out[j, i, :] = pos_save

    print(f"Time to get FT + pos info (OG): {time.time() - start_time}")
    return seg_out, pos_out

def test_ft(x, save_mask, n_segments=196, n_points=64, grayscale=True):

    x = x.permute(0, 2, 3, 1)
    x = x.cpu().numpy()
    B = x.shape[0]

    # Sanity check
    test = save_mask.reshape(save_mask.shape[0], -1) # Flatten H/W info
    test = np.apply_along_axis(np.unique, 1, test)
    test = np.max(test, axis=1)
    assert np.all(test == (n_segments-1))
    
    seg_out = torch.zeros((B, n_segments, n_points * n_points * 2))
    pos_out = torch.zeros((B, n_segments, 5))

    ft_time = time.time()
    x = np.array([cv2.cvtColor(color_image, cv2.COLOR_BGR2GRAY) for color_image in x])

    # TODO: Verify implementation
    for i in range(n_segments):
        binary_mask = (save_mask == i)

        segmented_imgs = binary_mask * x

        fourier_transform = np.fft.fft2(segmented_imgs, s=(n_points, n_points))

        magnitude = np.abs(fourier_transform)
        phase = np.angle(fourier_transform)

        to_save = np.stack((magnitude, phase)).reshape((fourier_transform.shape[0], -1), order="F") 
        seg_out[:, i, :] = torch.from_numpy(to_save)

        area = np.sum(binary_mask, axis=(1,2)) / (binary_mask.shape[1] * binary_mask.shape[2])

        centroid = np.array([np.mean(np.argwhere(img_mask),axis=0) / np.array(img_mask.shape) for img_mask in binary_mask])
        center_x = centroid[:, 0]
        center_y = centroid[:, 1]

        rect = np.array([(np.max(np.argwhere(img_mask), axis=0) - np.min(np.argwhere(img_mask), axis=0)) / np.array(img_mask.shape) for img_mask in binary_mask])

        width = rect[:, 0]
        height = rect[:, 1]

        pos_save = torch.from_numpy(np.array([area, center_x, center_y, width, height]).transpose()) #TODO: make this device (how does this work with DataParallel)
        pos_out[:, i, :] = pos_save
            

    print(f"Time to get FT +pos info (NEW): {time.time() - ft_time}")
    return seg_out, pos_out

In [43]:
# with Image.open("test_img.png") as f:
#    image = np.array(f)

# image = cv2.resize(image, (224, 224))
# data = np.tile(image, (256, 1, 1, 1))

# Randomly generate image data
data = np.random.randint(0, 256, size=(256, 224, 224, 3), dtype=np.uint8)
data = (torch.from_numpy(data)).permute(0, 3, 1, 2)

# Test with original implementation
og_segments = original_segment(data)
og_seg, og_pos = original_ft(data, og_segments)

Time to get segmentation masks (OG): 13.031810522079468
Time to get FT + pos info (OG): 48.90101361274719


In [52]:
new_segments = test_segment(data)
print(new_segments.shape)
test_seg, test_pos = test_ft(data, new_segments)

Time to get segmentation masks (NEW): 0.6145634651184082
(256, 224, 224)
Time to get FT +pos info (NEW): 24.836795568466187
