In [None]:
%matplotlib notebook
import os
import cv2
from SegTracker import SegTracker
from model_args import aot_args,sam_args,segtracker_args
from PIL import Image
from aot_tracker import _palette
import numpy as np
import torch
import imageio
import matplotlib.pyplot as plt
from scipy.ndimage import binary_dilation
import gc
def save_prediction(pred_mask,output_dir,file_name):
    """Save ``pred_mask`` as a palettised image at ``output_dir/file_name``.

    The mask is cast to uint8, converted to PIL mode 'P' and given the
    ``_palette`` colour table before being written.
    """
    target = os.path.join(output_dir, file_name)
    indexed = Image.fromarray(pred_mask.astype(np.uint8)).convert(mode='P')
    indexed.putpalette(_palette)
    indexed.save(target)
def colorize_mask(pred_mask):
    """Return ``pred_mask`` rendered through ``_palette`` as an RGB numpy array.

    The mask is cast to uint8 and turned into a mode 'P' PIL image that uses
    ``_palette``; that image is then converted to RGB and returned as an array.
    """
    indexed = Image.fromarray(pred_mask.astype(np.uint8)).convert(mode='P')
    indexed.putpalette(_palette)
    return np.array(indexed.convert(mode='RGB'))
def draw_mask(img, mask, alpha=0.5, id_countour=False):
    """Blend a coloured label mask over ``img`` and black out object outlines.

    Args:
        img: image array; it is only read, never written to.
        mask: integer label array whose shape matches the leading (spatial)
            dimensions of ``img``; label 0 is treated as background.
        alpha: weight of the mask colour in the blend (the image gets
            ``1 - alpha``).
        id_countour: when True, every non-zero label is blended with its own
            ``_palette`` colour and outlined separately (slow); when False,
            colours come from ``colorize_mask`` and a single outline is drawn
            around the union of all non-zero labels. The misspelt name is
            kept so existing keyword callers continue to work.

    Returns:
        A new array with the same shape and dtype as ``img``.
    """
    # Work on a copy: assigning ``img`` directly would alias the caller's
    # frame, so every overlay would also be painted onto the input image.
    img_mask = img.copy()
    if id_countour:
        # very slow ~ 1s per image
        obj_ids = np.unique(mask)
        obj_ids = obj_ids[obj_ids != 0]

        for obj_id in obj_ids:
            # Use a plain int: with a uint8 mask, ``label * 3`` wraps around
            # for labels above 85 and would select the wrong palette entry.
            obj_id = int(obj_id)
            if obj_id <= 255:
                color = _palette[obj_id*3:obj_id*3+3]
            else:
                # Labels outside the palette range are drawn in black.
                color = [0,0,0]
            # Blend against the untouched input so earlier objects' outlines
            # do not leak into this object's colour.
            foreground = img * (1-alpha) + alpha * np.array(color)
            binary_mask = (mask == obj_id)

            # Compose image
            img_mask[binary_mask] = foreground[binary_mask]

            # One-step dilation XOR the mask leaves the ring of pixels just
            # outside the object; those are painted black as the outline.
            countours = binary_dilation(binary_mask,iterations=1) ^ binary_mask
            img_mask[countours, :] = 0
    else:
        binary_mask = (mask != 0)
        countours = binary_dilation(binary_mask,iterations=1) ^ binary_mask
        foreground = img*(1-alpha) + colorize_mask(mask)*alpha
        img_mask[binary_mask] = foreground[binary_mask]
        img_mask[countours,:] = 0

    return img_mask

### Set parameters for input and output

In [None]:
video_name = 'au_air'
# Input clip and the locations the results are written to, all derived from video_name.
io_args = dict(
    input_video=f'./assets/{video_name}.mp4',
    output_mask_dir=f'./assets/{video_name}_masks',  # save pred masks
    output_video=f'./assets/{video_name}_seg.mp4',  # mask+frame visualization, mp4 or avi, else the same as input video
    output_gif=f'./assets/{video_name}_seg.gif',  # mask visualization
)

In [None]:
file_name = "au_air"
file_path = f'./assets/{file_name}'
# Every entry of the directory, joined with the directory path and sorted
# lexicographically.
imgs_path = sorted(os.path.join(file_path, entry) for entry in os.listdir(file_path))

### Tuning SAM on the First Frame for Good Initialization

In [None]:
def mouse_event(event):
    """Print the ``xdata`` and ``ydata`` attributes of ``event`` on one line."""
    print(event.xdata, event.ydata)

In [None]:
# Load the first frame (OpenCV reads BGR; convert to RGB for matplotlib).
frame = cv2.cvtColor(cv2.imread(imgs_path[0]), cv2.COLOR_BGR2RGB)

fig = plt.figure(figsize=(10,10))
plt.axis('off')
plt.imshow(frame)
# Mouse clicks on the figure are routed to mouse_event, which prints the
# clicked data coordinates.
cid = fig.canvas.mpl_connect('button_press_event', mouse_event)
plt.show()


In [None]:
# del segtracker
# Collect garbage first: memory still held by unreachable Python objects
# cannot be released by empty_cache() until those objects are destroyed.
gc.collect()
torch.cuda.empty_cache()

In [None]:
def get_click_mask(image, click_prompt):
    """Segment one image from click prompts using a freshly built SegTracker.

    Side effects: overwrites ``sam_args['generator_args']`` on the shared
    ``sam_args`` dict and constructs a new ``SegTracker`` on every call.

    Args:
        image: path of the image file to read.
        click_prompt: dict with the keys "points_coord", "points_mode" and
            "multimask"; the first two are converted to numpy arrays and all
            three are forwarded to ``segtracker.seg_acc_click``.

    Returns:
        The ``(seg_mask, masked_frame)`` pair returned by ``seg_acc_click``.

    Raises:
        FileNotFoundError: if ``image`` cannot be read by ``cv2.imread``.
    """
    # choose good parameters in sam_args based on the first frame segmentation result
    # other arguments can be modified in model_args.py
    # note the object number limit is 255 by default, which requires < 10GB GPU memory with amp
    sam_args['generator_args'] = {
            'points_per_side': 30,
            'pred_iou_thresh': 0.8,
            'stability_score_thresh': 0.9,
            'crop_n_layers': 1,
            'crop_n_points_downscale_factor': 2,
            'min_mask_region_area': 200,
        }
    # cv2.imread signals failure by returning None instead of raising; check
    # before building the tracker so a bad path fails fast with a clear error.
    frame = cv2.imread(image)
    if frame is None:
        raise FileNotFoundError("cannot read image: {!r}".format(image))
    frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)

    segtracker = SegTracker(segtracker_args,sam_args,aot_args)
    segtracker.restart_tracker()
    seg_mask, masked_frame = segtracker.seg_acc_click(
                                                  origin_frame=frame,
                                                  coords=np.array(click_prompt["points_coord"]),
                                                  modes=np.array(click_prompt["points_mode"]),
                                                  multimask=click_prompt["multimask"],
                                                )
    return seg_mask, masked_frame
    

In [None]:
click_masks = []
masked_frames = []

# One entry per prompted object: (index into imgs_path, first click, second click).
# Each prompt uses point modes [1, 0] — presumably 1 marks a point on the
# object and 0 a point to exclude; confirm against seg_acc_click.
click_specs = [
    (0, [1060,500], [1276,464]),
    (0, [1535,460], [1276,464]),
    (6, [775,502], [904,490]),
]

for frame_idx, first_xy, second_xy in click_specs:
    seg_mask, masked_frame = get_click_mask(imgs_path[frame_idx], click_prompt = {
        "points_coord":[first_xy, second_xy],
        "points_mode":[1,0],
        "multimask":"True",
    })
    click_masks.append(seg_mask)
    masked_frames.append(masked_frame)


In [None]:
# Preview the masked frame returned for the first click prompt.
first_overlay = masked_frames[0]
plt.figure(figsize=(10,10))
plt.axis('off')
plt.imshow(first_overlay)
plt.show()

### Generate Results for the Whole Video

In [None]:
# SAM is run every sam_gap frames to find new objects and add them for tracking.
# A larger sam_gap is faster but may not spot new objects in time.
segtracker_args = dict(
    sam_gap=5,  # the interval to run sam to segment new objects
    min_area=200,  # minimal mask area to add a new mask as a new object
    max_obj_num=255,  # maximal object number to track in a video
    min_new_obj_iou=0.8,  # the area of a new object in the background should be > 80%
)
sam_gap = segtracker_args['sam_gap']

# output masks: create the directory on first use
output_dir = io_args['output_mask_dir']
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

# Accumulators filled while tracking.
pred_list, masked_pred_list = [], []

In [None]:
def track_frames(start: int, end: int):
    """Track frames ``start`` .. ``end - 1`` and collect their mask overlays.

    Relies on the notebook-level ``segtracker``, ``imgs_path`` and
    ``masked_pred_list``: each frame is read from ``imgs_path``, passed to
    ``segtracker.track`` with ``update_memory=True``, and the result of
    ``draw_mask`` is appended to ``masked_pred_list``.
    """
    with torch.cuda.amp.autocast():
        for frame_idx in range(start, end):
            # OpenCV loads BGR; convert to RGB before tracking and drawing.
            rgb_frame = cv2.cvtColor(cv2.imread(imgs_path[frame_idx]), cv2.COLOR_BGR2RGB)
            pred_mask = segtracker.track(rgb_frame, update_memory=True)
            masked_pred_list.append(draw_mask(rgb_frame, pred_mask))
            # '\r' keeps the progress message on a single, overwritten line.
            print("processed frame {}, obj_num {}".format(frame_idx, segtracker.get_obj_num()), end='\r')

# Build the tracker and seed it on frame 0 with the first two click masks,
# then track frames 0-5.
frame = cv2.imread(imgs_path[0])
frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)
with torch.cuda.amp.autocast():
    segtracker = SegTracker(segtracker_args,sam_args,aot_args)
    # First object: merge its click mask and store the result as the merged mask.
    refined_merged_mask = segtracker.add_mask(click_masks[0])
    segtracker.update_origin_merged_mask(refined_merged_mask)
    # presumably curr_idx is the label add_mask assigns to the next object — TODO confirm
    segtracker.curr_idx += 1
    # Second object.
    # NOTE(review): curr_idx is bumped before update_origin_merged_mask here but
    # after it for the first object — confirm the order does not matter.
    refined_merged_mask = segtracker.add_mask(click_masks[1])
    segtracker.curr_idx += 1
    segtracker.update_origin_merged_mask(refined_merged_mask)
    # Register frame 0 together with the merged mask as the tracking reference.
    segtracker.add_reference(frame, segtracker.origin_merged_mask)
    segtracker.first_frame_mask = segtracker.origin_merged_mask

    # Frame 6 is handled separately in a later cell, where a third object is added.
    track_frames(0, 6)

In [None]:
# Frame 6 is the frame the third click mask (click_masks[2]) was produced on.
frame = cv2.cvtColor(cv2.imread(imgs_path[6]), cv2.COLOR_BGR2RGB)

In [None]:
# Track the current `frame`, then add the object from click_masks[2] to the tracker.
with torch.cuda.amp.autocast():
    # seg_mask = segtracker.seg(frame)
    torch.cuda.empty_cache()
    gc.collect()
    track_mask = segtracker.track(frame)
    # find new objects, and update tracker with new objects
    new_obj_mask = segtracker.find_new_objs(track_mask, click_masks[2])
    # Element-wise sum of the two label masks.
    # assumes new_obj_mask is zero wherever track_mask is non-zero — TODO confirm
    pred_mask = track_mask + new_obj_mask
    # segtracker.restart_tracker()
    # presumably the third argument (6) is the frame index of this reference — verify against SegTracker
    segtracker.add_reference(frame, pred_mask, 6)

In [None]:
# Scratch inspection of the new object's mask; none of these values are used
# by later cells in this notebook.
# Keep click_masks[2] (scaled by 2) only where the tracker predicted background.
seg_mask = (track_mask==0) * (click_masks[2]*2)
# new_obj_mask2 is the same object as seg_mask, not a copy.
new_obj_mask2 = seg_mask
new_obj_ids = np.unique(new_obj_mask2)
new_obj_ids = new_obj_ids[new_obj_ids!=0]
idx = 2
# Both sums count pixels equal to idx in the same array, so they are always equal.
new_obj_area = np.sum(new_obj_mask2==idx)
obj_area = np.sum(seg_mask==idx)

In [None]:
# Largest value in track_mask, shown as the cell output.
np.max(track_mask)

In [None]:
# Overlay pred_mask on the current frame and display it.
overlay = draw_mask(frame, pred_mask, id_countour=False)
plt.figure(figsize=(10,10))
plt.axis('off')
plt.imshow(overlay)
# plt.imshow(frame)
plt.show()

In [None]:
# Continue tracking from frame 6 up to, but not including, frame 17.
# NOTE(review): track_frames opens its own autocast context, so this outer one
# looks redundant. Frame 6 was also passed to segtracker.track in an earlier
# cell — confirm that tracking it again here is intended.
with torch.cuda.amp.autocast():
    track_frames(6, 17)

In [None]:
# Always raises AssertionError. Presumably a deliberate stop so that "Run All"
# does not continue past this cell — confirm before removing.
assert False

In [None]:
from matplotlib.widgets import Slider

# Interactive viewer: a slider selects which tracked overlay is displayed.
fig = plt.figure(figsize=(9,6))
ax = fig.add_subplot(1, 1, 1)
img = ax.imshow(masked_pred_list[0], interpolation='nearest')
plt.axis('off')

axfreq = fig.add_axes([0.25, 0.1, 0.65, 0.03])
freq_slider = Slider(
    ax=axfreq,
    label="image",
    valmin=0,
    # Highest valid list index; len(masked_pred_list) itself is one past the
    # end and raised IndexError when the slider was dragged fully right.
    valmax=len(masked_pred_list) - 1,
    valstep=1,
    valinit=0,
)

# The function to be called anytime a slider's value changes
def update(val):
    """Show the overlay whose index matches the slider value ``val``."""
    # The slider may deliver a float; list indices must be integers.
    img.set_data(masked_pred_list[int(val)])
    fig.canvas.draw_idle()

freq_slider.on_changed(update)

plt.show()