# Algorithmic Tokenizer 
Takes Segmentation Masks and Returns a Text-based token representation of the frame and its masks

In [11]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import cv2
import sys
import pickle
import time
from IPython.display import clear_output
import os
import re
import imageio
sys.path.append('..')
from helper_functions import *

In [2]:
big_plot_dim = 8

In [3]:
# Raycast out from the centroid and find the distance to the outer border of the mask
# TODO: Implement logic for when centroid point is outside the mask- need to employ special logic in that case
def raycast(mask, point, angle, max_dist):
    furthest_dist = -1
    in_bounds = True
    for i in range(max_dist):
        x = int(point[0] + i * np.cos(angle))
        y = int(point[1] + i * np.sin(angle))
        if x < 0 or y < 0 or x >= mask.shape[1] or y >= mask.shape[0]:
            if in_bounds:
                furthest_dist = i
            return furthest_dist
        if not in_bounds and mask[y, x] == 1:
            in_bounds = True
        if in_bounds and mask[y, x] == 0:
            in_bounds = False
            furthest_dist = i
    return furthest_dist

In [4]:
def tokenize_single_mask(mask, id, num_rays):
    # Find the centroid of the mask
    centroid = find_centroid(mask)
    # Tokenize the mask
    tokens = []
    tokens.append(str(id))
    tokens.append(centroid[0])
    tokens.append(centroid[1])
    str_tok = "{"
    str_tok += str(id)
    str_tok += (",{")
    str_tok += str(centroid[0])
    str_tok += (",")
    str_tok += str(centroid[1])
    str_tok += ("}")
    for i in range(num_rays):
        angle = i * 2 * np.pi / num_rays
        raycast_distance = raycast(mask, centroid, angle, 400)
        tokens.append(raycast_distance)
        str_tok += (",")
        str_tok += str(raycast_distance)
    str_tok += ("}")
    return tokens, str_tok

In [5]:
def tokenize_masks(masks, num_rays):
    ret_tokens = []
    str_tokens = []
    for i in range(len(masks)):
        if i == 0:
            arr, str = tokenize_single_mask(masks[i], "gripper", num_rays)
        elif i == 1:
            arr, str = tokenize_single_mask(masks[i], "table", num_rays)
        elif i == 2:
            arr, str = tokenize_single_mask(masks[i], "yellow block", num_rays)
        elif i == 3:
            arr, str = tokenize_single_mask(masks[i], "green block", num_rays)
        elif i == 4:
            arr, str = tokenize_single_mask(masks[i], "blue block", num_rays)
        elif i == 5:
            arr, str = tokenize_single_mask(masks[i], "red block", num_rays)
        else:
            arr, str = tokenize_single_mask(masks[i], i, num_rays)
        ret_tokens.append(arr)
        str_tokens.append(str)
        # print('Tokens for mask', i, ':', tokens)
    return ret_tokens, str_tokens

In [6]:
# Given a list of points that represent a polygon, write a function to convert it to a mask
def polygon_to_mask(polygon, width, height):
    mask = np.zeros((height, width))
    polygon = np.array(polygon, np.int32)
    mask = cv2.fillPoly(mask, [polygon], 1)
    return mask

# Given a centroid and a list of radiused points and the height and width of the original image, reconstruct the mask, making sure all the points inside the polygon are filled in
def reconstruct_mask(centroid, radiused_points, width, height):
    mask = np.zeros((height, width))
    polygon_points = []
    for i in range(len(radiused_points)):
        angle = i * 2 * np.pi / len(radiused_points)
        polygon_points.append([int(centroid[0] + radiused_points[i] * np.cos(angle)), int(centroid[1] + radiused_points[i] * np.sin(angle))])

    mask = polygon_to_mask(polygon_points, width, height)
    return mask

In [7]:
image_dir = "/home/yashas/Documents/thesis/test-images/group_0/traj0/images0/"
image_prefix = 'im_'
image_suffix = '.jpg'

In [8]:
def tokenize_full_process():
    # importing segmentations
    segmentations = import_segmentations(image_dir)
    # print(segmentations.shape)

    # get copy of segmentations[0] where the value is 1 if the pixel is in the mask and 0 if it is not
    # print(binarize_mask(segmentations[0], 1))
    # print(max(segmentations[0].flatten()))

    # create numpy array of shape (segmentations.shape[0], 6, segmentations.shape[1], segmentations.shape[2]) to store the segmentations
    # for each image in the dataset
    for i in range(segmentations.shape[0]):
        segs = np.array([])
        for j in range(6):
            mask = binarize_and_preprocess(segmentations[i], j+1)
            np_mask = np.array([mask])
            if len(segs) == 0:
                segs = np_mask
            else:
                segs = np.concatenate([segs, np_mask], axis=0)
        if i == 0:
            segmentations_reshaped = np.array([segs])
        else:
            segmentations_reshaped = np.concatenate([segmentations_reshaped, [segs]], axis=0)
    print(segmentations_reshaped.shape)

    # plot the segmentations
    # fig, ax = plt.subplots(1, 1, figsize=(big_plot_dim, big_plot_dim))
    # show_mask(segmentations_reshaped[0][0], ax, random_color=True)
    # plt.show()

    curr_frame_num = 0
    curr_frame = segmentations_reshaped[curr_frame_num]
    print(curr_frame.shape)
    
    # Display the current frame and then the mask
    image = cv2.imread(image_dir + image_prefix + str(curr_frame_num) + image_suffix)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # plt.figure(figsize=(big_plot_dim, big_plot_dim))
    # plt.title('Frame ' + str(curr_frame_num))
    # plt.imshow(image)
    # for j in range(curr_frame.shape[0]):
    #     show_mask(curr_frame[j], plt.gca(), color=color_list[j])
    # plt.axis('off')
    # plt.show()

    # Getting the centroids
    centroids = []
    for i in range(len(curr_frame)):
        centroid = find_centroid(curr_frame[i])
        centroids.append(centroid)

    # Displaying the centroid on the image
    # plt.figure(figsize=(big_plot_dim, big_plot_dim))
    # plt.title('Frame ' + str(curr_frame_num))
    # plt.imshow(image)
    # for j in range(curr_frame.shape[0]):
    #     show_mask(curr_frame[j], plt.gca(), color=color_list[j])
    # for j in range(len(centroids)):
    #     plt.scatter(centroids[j][0], centroids[j][1], color='red', marker='.', s=375, edgecolor='white', linewidth=1.25)
    # plt.axis('off')
    # plt.show()

    # Test the raycast function
    angle = np.pi
    max_dist = 400
    raycast_distances = []
    for i in range(len(centroids)):
        raycast_distance = raycast(curr_frame[i], centroids[i], angle, max_dist)
        raycast_distances.append(raycast_distance)
        print('Raycast distance for mask', i, 'is', raycast_distance)

    # Displaying the raycast on the image
    # plt.figure(figsize=(big_plot_dim, big_plot_dim))
    # plt.title('Frame ' + str(curr_frame_num))
    # # plt.imshow(image)
    # for j in range(curr_frame.shape[0]):
    #     show_mask(curr_frame[j], plt.gca(), color=color_list[j])
    # for j in range(len(centroids)):
    #     plt.scatter(centroids[j][0], centroids[j][1], color='red', marker='.', s=375, edgecolor='white', linewidth=1.25)
    #     plt.plot([centroids[j][0], centroids[j][0] + raycast_distances[j] * np.cos(angle)], [centroids[j][1], centroids[j][1] + raycast_distances[j] * np.sin(angle)], color='red', linestyle='-', linewidth=2)
    # plt.axis('off')
    # plt.show()

    # Test the tokenization function
    num_rays = 25
    tokens = tokenize_single_mask(curr_frame[0], "gripper", num_rays)
    print('Tokens for mask 0:', tokens)


    # Test the tokenization function
    frame_tokens, string_tokens = tokenize_masks(curr_frame, num_rays)
    print(frame_tokens)
    np_frame = np.array(frame_tokens)
    print(np_frame.shape)
    print(string_tokens)

    # Visualize frame_tokens
    # plt.figure(figsize=(big_plot_dim, big_plot_dim))
    # plt.title('Frame ' + str(curr_frame_num))
    # plt.imshow(image)
    # for j in range(curr_frame.shape[0]):
    #     show_mask(curr_frame[j], plt.gca(), color=color_list[j])
    # for j in range(len(centroids)):
    #     plt.scatter(centroids[j][0], centroids[j][1], color='blue', marker='.', s=375, edgecolor='white', linewidth=1.25)
    #     for k in range(num_rays):
    #         plt.plot([centroids[j][0], centroids[j][0] + frame_tokens[j][k+3] * np.cos(k * 2 * np.pi / num_rays)], [centroids[j][1], centroids[j][1] + frame_tokens[j][k+3] * np.sin(k * 2 * np.pi / num_rays)], color='red', linestyle='-', linewidth=1)
    # plt.axis('off')
    # plt.show()

    # Test the reconstruction function
    reconstructed_masks = []
    for i in range(len(frame_tokens)):
        reconstructed_mask = reconstruct_mask(np.array([frame_tokens[i][1], frame_tokens[i][2]]), frame_tokens[i][3:], image.shape[1], image.shape[0])
        reconstructed_masks.append(reconstructed_mask)

    # Visualize the reconstructed masks
    # plt.figure(figsize=(big_plot_dim, big_plot_dim))
    # plt.title('Frame ' + str(curr_frame_num))
    # plt.imshow(image)
    # for j in range(len(reconstructed_masks)):
    #     show_mask(reconstructed_masks[j], plt.gca(), color=color_list[j])
    # plt.axis('off')
    # plt.show()


    # Tokenize all frames and visualize the reconstructed masks as an animation
    num_rays = 25
    reconstructed_masks_over_time = []
    string_tokens_over_time = []
    for i in range(len(segmentations_reshaped)):
        curr_frame = segmentations_reshaped[i]
        print('Tokenizing frame', i)
        frame_tokens, str_tokens = tokenize_masks(curr_frame, num_rays)
        string_tokens_over_time.append(str(str_tokens))
        reconstructed_masks = []
        for j in range(len(frame_tokens)):
            reconstructed_mask = reconstruct_mask(np.array([frame_tokens[j][1], frame_tokens[j][2]]), frame_tokens[j][3:], image.shape[1], image.shape[0])
            reconstructed_masks.append(reconstructed_mask)
        reconstructed_masks_over_time.append(reconstructed_masks)

    string_tokens_over_time = np.array(string_tokens_over_time)
    print(len(string_tokens_over_time))
    print(string_tokens_over_time)
    # output to a text file  in same directory as segmentations
    text_file_path = image_dir + "tokens_over_time.txt"
    np.savetxt(text_file_path, string_tokens_over_time, fmt='%s')


    # Visualize the reconstructed masks over time compared to the original masks
    # for i in range(len(reconstructed_masks_over_time)):
    #     image = cv2.imread(image_dir + image_prefix + str(i) + image_suffix)
    #     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    #     fig, axs = plt.subplots(1, 2, figsize=(big_plot_dim*3, big_plot_dim*3))
    #     axs[0].set_title(f'Frame {i} Reconstruction')
    #     # axs[0].imshow(image)
    #     for j in range(len(reconstructed_masks_over_time[i])):
    #         show_mask(reconstructed_masks_over_time[i][j], axs[0], color=color_list[j])
    #     # plt.axis('off')
    #     axs[1].set_title(f'Frame {i} Orig Segmentation')
    #     for j in range(segmentations_reshaped[i].shape[0]):
    #         show_mask(segmentations_reshaped[i][j], axs[1], color=color_list[j])
    #     plt.show()
    #     time.sleep(0.1)
    #     clear_output(wait=True)

In [9]:
# import paths from the file
with open("/home/yashas/Documents/thesis/test-images/image_dirs.txt", "r") as f:
    paths = f.read().splitlines()

for path in paths:
    path = path + "/"
    image_dir = path
    print('Processing images in', image_dir)
    tokenize_full_process()

Processing images in /home/yashas/Documents/thesis/test-images/group_0/traj0/images0/
(47, 6, 480, 640)
(6, 480, 640)
Raycast distance for mask 0 is 87
Raycast distance for mask 1 is 312
Raycast distance for mask 2 is 43
Raycast distance for mask 3 is 43
Raycast distance for mask 4 is 45
Raycast distance for mask 5 is 39
Tokens for mask 0: (['gripper', 231, 48, 91, 80, 85, 94, 98, 14, 13, 13, 82, 78, 71, 94, 90, 86, 88, 84, 64, 55, 50, 50, 52, 59, 72, 91, 87], '{gripper,{231,48},91,80,85,94,98,14,13,13,82,78,71,94,90,86,88,84,64,55,50,50,52,59,72,91,87}')
[['gripper', 231, 48, 91, 80, 85, 94, 98, 14, 13, 13, 82, 78, 71, 94, 90, 86, 88, 84, 64, 55, 50, 50, 52, 59, 72, 91, 87], ['table', 311, 309, 329, 340, 351, 247, 201, 178, 170, 173, 187, 220, 288, 336, 315, 315, 336, 298, 277, 224, 181, 193, 190, 198, 224, 276, 340], ['yellow block', 210, 313, 43, 36, 32, 30, 30, 33, 39, 50, 50, 50, 51, 49, 47, 37, 30, 28, 26, 27, 30, 37, 48, 49, 52, 53, 49], ['green block', 166, 217, 42, 38, 31, 27,

KeyboardInterrupt: 

In [10]:
# import predicted mask and real mask from JSON
import json
with open('../eval_tokens_over_time.json') as f:
    predicted_masks = json.load(f)

print("Loaded", len(predicted_masks), "frames of predicted masks")

Loaded 186 frames of predicted masks


In [26]:
image_dir = "/home/yashas/Documents/thesis/test-images/group_0/traj0/images0/"
image = cv2.imread(image_dir + image_prefix + str(0) + image_suffix)

# getting predicted and real output masks from tokens
output_frame_tokens = []
predic_frame_tokens = []

for i in range(len(predicted_masks)):
    # print('Processing frame', i)
    # print(predicted_masks[i]['output'])
    # print(predicted_masks[i]['predic'])
    # print("-----")
    output_token_arr = re.split(r'[{,}]', predicted_masks[i]['output'])
    output_token_arr = output_token_arr[1:-1]
    output_token_arr = [x.strip() for x in output_token_arr]
    output_token_arr = list(filter(None, output_token_arr))
    output_non_num_indexes = [j for j in range(len(output_token_arr)) if not output_token_arr[j].isdigit()]

    predic_token_arr = re.split(r'[{,}]', predicted_masks[i]['predic'])
    predic_token_arr = predic_token_arr[1:-1]
    predic_token_arr = [x.strip() for x in predic_token_arr]
    predic_token_arr = list(filter(None, predic_token_arr))
    predic_non_num_indexes = [j for j in range(len(predic_token_arr)) if not predic_token_arr[j].isdigit()]

    if len(output_non_num_indexes) != len(predic_non_num_indexes):
        print("Error: Number of tokens in output and predicted masks do not match")
        continue

    output_frame_tok = []
    predic_frame_tok = []
    for n in range(len(output_non_num_indexes)):
        output_ind = output_non_num_indexes[n]
        predic_ind = predic_non_num_indexes[n]
        output_id = output_token_arr[output_ind]
        predic_id = predic_token_arr[predic_ind]
        output_centroid = (output_token_arr[output_ind+1], output_token_arr[output_ind+2])
        predic_centroid = (predic_token_arr[predic_ind+1], predic_token_arr[predic_ind+2])
        if n == len(output_non_num_indexes)-1:
            output_ray = output_token_arr[output_ind+3:]
            predic_ray = predic_token_arr[predic_ind+3:]
        else:
            output_ray = output_token_arr[output_ind+3:output_non_num_indexes[n+1]]
            predic_ray = predic_token_arr[predic_ind+3:predic_non_num_indexes[n+1]]
        output_ray = [int(i) for i in output_ray]
        predic_ray = [int(i) for i in predic_ray]
        output_tok = [output_id, int(output_centroid[0]), int(output_centroid[1])]
        predic_tok = [predic_id, int(predic_centroid[0]), int(predic_centroid[1])]
        output_tok.extend(output_ray)
        predic_tok.extend(predic_ray)
        # print(output_tok)
        # print(predic_tok)
        output_frame_tok.append(output_tok)
        predic_frame_tok.append(predic_tok)

    output_frame_tokens.append(output_frame_tok)
    predic_frame_tokens.append(predic_frame_tok)


print("Loaded", len(output_frame_tokens), "frames of output masks"
        " and", len(predic_frame_tokens), "frames of predicted masks")

# Create a figure and subplots for the animation
fig, axs = plt.subplots(1, 2, figsize=(8, 4))

# Create a list to store the frames of the animation
frames = []

for frame in range(len(output_frame_tokens)):
    print(f"Processing frame {frame}")
    
    axs[0].clear()
    axs[1].clear()
    
    output_reconstructed_masks_test = []
    predic_reconstructed_masks_test = []
    for j in range(len(output_frame_tokens[0])):
        output_reconstructed_mask = reconstruct_mask(np.array([output_frame_tokens[frame][j][1], output_frame_tokens[frame][j][2]]), output_frame_tokens[frame][j][3:], image.shape[1], image.shape[0])
        output_reconstructed_masks_test.append(output_reconstructed_mask)
        predic_reconstructed_mask = reconstruct_mask(np.array([predic_frame_tokens[frame][j][1], predic_frame_tokens[frame][j][2]]), predic_frame_tokens[frame][j][3:], image.shape[1], image.shape[0])
        predic_reconstructed_masks_test.append(predic_reconstructed_mask)
    
    axs[0].set_title(f'Ex {frame} Output Recon')
    for j in range(len(output_reconstructed_masks_test)):
        show_mask(output_reconstructed_masks_test[j], axs[0], color=color_list[j])
    axs[0].axis('off')
    
    axs[1].set_title(f'Ex {frame} Predic Recon')
    for j in range(len(predic_reconstructed_masks_test)):
        show_mask(predic_reconstructed_masks_test[j], axs[1], color=color_list[j])
    axs[1].axis('off')
    
    # Capture the current frame as an image
    fig.canvas.draw()
    image = np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8')
    image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    
    # Append the frame to the list of frames
    frames.append(image)

# Save the frames as a GIF using imageio
imageio.mimsave('prediction_reconstruction_comparison.gif', frames, fps=1)

plt.close(fig)

# for i in range(len(output_frame_tokens)):
#     output_reconstructed_masks_test = []
#     predic_reconstructed_masks_test = []
#     for j in range (len(output_frame_tokens[0])):
#         output_reconstructed_mask = reconstruct_mask(np.array([output_frame_tokens[i][j][1], output_frame_tokens[i][j][2]]), output_frame_tokens[i][j][3:], image.shape[1], image.shape[0])
#         output_reconstructed_masks_test.append(output_reconstructed_mask)
#         predic_reconstructed_mask = reconstruct_mask(np.array([predic_frame_tokens[i][j][1], predic_frame_tokens[i][j][2]]), predic_frame_tokens[i][j][3:], image.shape[1], image.shape[0])
#         predic_reconstructed_masks_test.append(predic_reconstructed_mask)
#     # Visualize output reconstructed mask in one subplot and predic reconstructed mask in another
#     fig, axs = plt.subplots(1, 2, figsize=(big_plot_dim*3, big_plot_dim*3))
#     axs[0].set_title(f'Frame {i} Output Reconstruction')
#     # axs[0].imshow(image)
#     for j in range(len(output_reconstructed_masks_test)):
#         show_mask(output_reconstructed_masks_test[j], axs[0], color=color_list[j])
#     # plt.axis('off')
#     axs[1].set_title(f'Frame {i} Predic Reconstruction')
#     for j in range(len(predic_reconstructed_masks_test)):
#         show_mask(predic_reconstructed_masks_test[j], axs[1], color=color_list[j])
#     plt.show()
#     time.sleep(0.01)
#     clear_output(wait=True)

Error: Number of tokens in output and predicted masks do not match
Error: Number of tokens in output and predicted masks do not match
Error: Number of tokens in output and predicted masks do not match
Error: Number of tokens in output and predicted masks do not match
Error: Number of tokens in output and predicted masks do not match
Error: Number of tokens in output and predicted masks do not match
Error: Number of tokens in output and predicted masks do not match
Error: Number of tokens in output and predicted masks do not match
Error: Number of tokens in output and predicted masks do not match
Error: Number of tokens in output and predicted masks do not match
Loaded 176 frames of output masks and 176 frames of predicted masks
Processing frame 0
Processing frame 1


  image = np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8')


Processing frame 2
Processing frame 3
Processing frame 4
Processing frame 5
Processing frame 6
Processing frame 7
Processing frame 8
Processing frame 9
Processing frame 10
Processing frame 11
Processing frame 12
Processing frame 13
Processing frame 14
Processing frame 15
Processing frame 16
Processing frame 17
Processing frame 18
Processing frame 19
Processing frame 20
Processing frame 21
Processing frame 22
Processing frame 23
Processing frame 24
Processing frame 25
Processing frame 26
Processing frame 27
Processing frame 28
Processing frame 29
Processing frame 30
Processing frame 31
Processing frame 32
Processing frame 33
Processing frame 34
Processing frame 35
Processing frame 36
Processing frame 37
Processing frame 38
Processing frame 39
Processing frame 40
Processing frame 41
Processing frame 42
Processing frame 43
Processing frame 44
Processing frame 45
Processing frame 46
Processing frame 47
Processing frame 48
Processing frame 49
Processing frame 50
Processing frame 51
Processi