In [None]:
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import re
import gc

from PIL import Image
from sam2.sam2_video_predictor import SAM2VideoPredictor
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    GenerationConfig,
    BitsAndBytesConfig
)
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas

In [None]:
# Cap the number of threads PyTorch uses for intra-op parallelism on the CPU.
torch_cpu_threads = 8
torch.set_num_threads(torch_cpu_threads)

In [None]:
# Run SAM2 on the GPU when one is available; fall back to the CPU otherwise so
# the tracking cells do not fail outright on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## SAM2 Helper Functions

In [None]:
def show_mask(mask, ax, obj_id=None, random_color=False):
    """Overlay a binary mask on `ax` as a semi-transparent colour layer (alpha 0.6).

    Args:
        mask: Array whose last two dims are (height, width); truthy pixels are coloured.
        ax: Matplotlib axes to draw on. Its axis decorations are turned off.
        obj_id: Index into the 'tab10' colormap (0 when None). Ignored if
            `random_color` is set.
        random_color: Use a random RGB colour instead of the colormap entry.
    """
    alpha = 0.6
    if not random_color:
        palette_idx = obj_id if obj_id is not None else 0
        rgb = plt.get_cmap('tab10')(palette_idx)[:3]
        rgba = np.array([*rgb, alpha])
    else:
        rgba = np.concatenate([np.random.random(3), np.array([alpha])], axis=0)

    height, width = mask.shape[-2:]
    # Broadcast the (H, W, 1) mask against the (1, 1, 4) colour to get an RGBA image.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
    ax.axis('off')

def show_points(coords, labels, ax, marker_size=200):
    """Scatter prompt points on `ax` as star markers.

    Args:
        coords: (N, 2) array of (x, y) positions.
        labels: (N,) array; label 1 points are drawn green, label 0 points red.
        ax: Matplotlib axes to draw on.
        marker_size: Marker area passed to `scatter` as `s`.
    """
    marker_style = dict(marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    for wanted_label, colour in ((1, 'green'), (0, 'red')):
        selected = coords[labels == wanted_label]
        ax.scatter(selected[:, 0], selected[:, 1], color=colour, **marker_style)


def show_box(box, ax):
    """Draw a box given as (x0, y0, x1, y1) on `ax` as an unfilled green rectangle."""
    x_min, y_min, x_max, y_max = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (x_min, y_min), x_max - x_min, y_max - y_min,
        edgecolor='green', facecolor=(0, 0, 0, 0), lw=2,
    )
    ax.add_patch(outline)

## Extract Video Frames

In [None]:
# Directory that receives the JPEG frames extracted from the input video.
video_frames_dir = 'video_frames'
if not os.path.isdir(video_frames_dir):
    os.makedirs(video_frames_dir, exist_ok=True)

In [None]:
# Check that ffmpeg is installed: `-version` prints the build info and exits cleanly.
!ffmpeg -version

For your custom videos, you can extract their JPEG frames using ffmpeg (https://ffmpeg.org/) as follows:
```
ffmpeg -i <your_video>.mp4 -q:v 2 -start_number 0 <output_dir>/'%05d.jpg'
```
where `-q:v 2` sets a high JPEG quality (lower values mean higher quality) and `-start_number 0` makes ffmpeg number the output files starting from `00000.jpg`.

In [None]:
# Extract every frame of the demo video as high-quality JPEGs (00000.jpg, 00001.jpg, ...)
# into `video_frames_dir`, so the directory name is defined in one place only.
!ffmpeg -i ../demo_data/video_1.mp4 -q:v 2 -start_number 0 {video_frames_dir}/'%05d.jpg'

In [None]:
# Collect the extracted JPEG frames, ordered by their numeric file name.
jpeg_exts = ('.jpg', '.jpeg', '.JPG', '.JPEG')
frame_names = sorted(
    (name for name in os.listdir(video_frames_dir) if os.path.splitext(name)[-1] in jpeg_exts),
    key=lambda name: int(os.path.splitext(name)[0]),
)

# Display the first frame as a quick sanity check.
frame_idx = 0
first_frame_path = os.path.join(video_frames_dir, frame_names[frame_idx])
plt.figure(figsize=(9, 6))
plt.title(f"frame {frame_idx}")
plt.imshow(Image.open(first_frame_path))

## Get the Molmo Points and Delete Model from Memory

In [None]:
# Load MolmoE-1B (weights quantized to 4 bit with bitsandbytes) and its processor.
# trust_remote_code=True lets transformers run the model/processor code shipped
# in the Hub repository.
molmo_checkpoint = 'allenai/MolmoE-1B-0924'

quant_config = BitsAndBytesConfig(load_in_4bit=True)

processor = AutoProcessor.from_pretrained(
    molmo_checkpoint,
    torch_dtype='auto',
    device_map='auto',
    trust_remote_code=True,
)

model = AutoModelForCausalLM.from_pretrained(
    molmo_checkpoint,
    torch_dtype='auto',
    quantization_config=quant_config,
    offload_folder='offload',
    trust_remote_code=True,
)

In [None]:
def draw_point_and_show(image_path=None, points=None):
    """Draw each point as a green circle on the image and display it with matplotlib.

    Args:
        image_path: Path of the image to load with OpenCV.
        points: Iterable of (x, y) pixel positions (e.g. the output of
            `get_coords`, or rows of an (N, 2) array). None draws nothing.

    Raises:
        FileNotFoundError: If OpenCV cannot read `image_path`.
    """
    image = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising.
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")

    if points is None:
        points = []

    for point in points:
        image = cv2.circle(
            image,
            # cv2.circle needs plain ints; coordinates may arrive as floats or numpy scalars.
            (int(point[0]), int(point[1])),
            radius=5,
            color=(0, 255, 0),
            thickness=5,
            lineType=cv2.LINE_AA,
        )

    # OpenCV images are BGR; reverse the channel axis to RGB for matplotlib.
    plt.imshow(image[..., ::-1])
    plt.axis('off')
    plt.show()

def get_coords(output_string, image_path):
    """Parse Molmo point annotations into pixel coordinates.

    Molmo reports coordinates as percentages of the image size. Two formats are
    handled:
      * when `output_string` contains 'points', numbered pairs such as
        x1="..." y1="..." x2="..." y2="...";
      * otherwise a single pair x="..." y="...".
    Each value is scaled by the width/height of the image at `image_path` and
    truncated to int.

    Returns:
        List of (x, y) int tuples; empty when no coordinates are found.

    Raises:
        FileNotFoundError: If OpenCV cannot read `image_path`.
    """
    image = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising.
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    h, w = image.shape[:2]

    def to_pixels(x_val, y_val):
        # Percent-of-size -> absolute pixel position.
        return int(float(x_val) / 100 * w), int(float(y_val) / 100 * h)

    if 'points' in output_string:
        matches = re.findall(r'x\d+="([\d.]+)" y\d+="([\d.]+)"', output_string)
        return [to_pixels(x_val, y_val) for x_val, y_val in matches]

    match = re.search(r'x="([\d.]+)" y="([\d.]+)"', output_string)
    # Previously `coordinates` was left unbound here when nothing matched.
    return [to_pixels(match.group(1), match.group(2))] if match else []

In [None]:
def get_output(image_path=None, prompt='Describe this image.'):
    """Run Molmo on a single image plus text prompt and return the generated text.

    Uses the notebook-level `processor` and `model` (the Molmo processor and model).

    Args:
        image_path: Path to a local image. When falsy, a sample image is
            downloaded from https://picsum.photos instead.
        prompt: Text prompt passed to the processor together with the image.

    Returns:
        The decoded generated text (it is also printed).
    """
    import io
    import urllib.request

    if image_path:
        image = Image.open(image_path)
    else:
        # The original code called `requests`, which is never imported in this
        # notebook (NameError). Use the standard library instead, and buffer
        # the body in memory so PIL receives a seekable file object.
        url = 'https://picsum.photos/id/237/536/354'
        with urllib.request.urlopen(url) as response:
            image = Image.open(io.BytesIO(response.read()))

    # Preprocess the image and text into model inputs.
    inputs = processor.process(images=[image], text=prompt)

    # Move inputs to the model's device and make a batch of size 1.
    inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}

    # Generate at most 200 new tokens; stop when <|endoftext|> is generated.
    output = model.generate_from_batch(
        inputs,
        GenerationConfig(max_new_tokens=200, stop_strings='<|endoftext|>'),
        tokenizer=processor.tokenizer,
    )

    # Keep only the newly generated tokens (drop the prompt) and decode them.
    generated_tokens = output[0, inputs['input_ids'].size(1):]
    generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)

    print(generated_text)
    return generated_text


In [None]:
# Ask Molmo to point at the target object in the first extracted frame.
image_path = 'video_frames/00000.jpg'
molmo_prompt = "Point to the main player's shoes"
outputs = get_output(prompt=molmo_prompt, image_path=image_path)

In [None]:
# Molmo is no longer needed: drop the references, collect them, and return
# cached GPU memory before SAM2 is loaded.
del model, processor
gc.collect()
torch.cuda.empty_cache()

## Initialize SAM2 Inference State

In [None]:
# Load the tiny SAM 2.1 video predictor onto the selected device.
sam2_checkpoint = 'facebook/sam2.1-hiera-tiny'
predictor = SAM2VideoPredictor.from_pretrained(sam2_checkpoint, device=device)

In [None]:
# Build the predictor's per-video inference state from the extracted frames.
# Optionally wrap this call (and the SAM2 calls below) in
# `torch.inference_mode(), torch.autocast(device, dtype=torch.bfloat16)`.
inference_state = predictor.init_state(video_path=video_frames_dir)

## Segment and Track an Object with Manual Points

In [None]:
# Manually chosen prompt point in pixel coordinates of the first frame: (x, y).
object_point1, object_point2 = 310, 220

In [None]:
# Show the manual prompt point (red dot) on the first frame. This cell also
# records the frame size (w, h) that the video writers below use.
sample_frame_path = os.path.join(video_frames_dir, frame_names[frame_idx])
sample_frame = Image.open(sample_frame_path)
w, h = sample_frame.size
print(w, h)
plt.imshow(sample_frame)
plt.plot(object_point1, object_point2, 'ro')

In [None]:
# Prompt on the first frame; track the object under ID 1 (any unique int works).
ann_frame_idx, ann_object_id = 0, 1

In [None]:
# One click at the manual point (here, the ball). Label 1 marks a positive
# point to track; label 0 would mark a point to exclude.
points = np.array([[object_point1, object_point2]], dtype=np.float32)
labels = np.array([1], dtype=np.int32)

In [None]:
# Register the click prompt for object `ann_object_id` on frame `ann_frame_idx`.
# The predictor returns the object ids and mask logits for that frame.
# (Optionally run under torch.inference_mode() / bfloat16 autocast.)
prompt_kwargs = dict(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_object_id,
    points=points,
    labels=labels,
)
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(**prompt_kwargs)

In [None]:
# Overlay the prompt point and the predicted mask (logits > 0) on the prompted frame.
prompted_frame = Image.open(os.path.join(video_frames_dir, frame_names[ann_frame_idx]))
first_mask = (out_mask_logits[0] > 0.0).cpu().numpy()
plt.figure(figsize=(9, 6))
plt.title(f"frame {ann_frame_idx}")
plt.imshow(prompted_frame)
axes = plt.gca()
show_points(points, labels, axes)
show_mask(first_mask, axes, obj_id=out_obj_ids[0])

In [None]:
# Propagate the prompt through the video and keep, for every frame, a boolean
# mask (logits > 0) per object id.
video_segments = {}  # frame index -> {object id: mask}
max_frame_num_to_track = None  # no limit: run throughout the video
# (Optionally run under torch.inference_mode() / bfloat16 autocast.)
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
    inference_state, max_frame_num_to_track=max_frame_num_to_track
):
    frame_masks = {}
    for obj_pos, obj_id in enumerate(out_obj_ids):
        frame_masks[obj_id] = (out_mask_logits[obj_pos] > 0.0).cpu().numpy()
    video_segments[out_frame_idx] = frame_masks

In [None]:
# Directory for the rendered output videos.
output_dir = 'video_out'
if not os.path.isdir(output_dir):
    os.makedirs(output_dir, exist_ok=True)

In [None]:
# XVID-encoded AVI writer at 30 fps. Frames written to it should match the
# (w, h) size of the source frames.
codec = cv2.VideoWriter_fourcc(*"XVID")
out = cv2.VideoWriter(f"{output_dir}/output.avi", codec, 30, (w, h))

In [None]:
# Render every `vis_frame_stride`-th frame with its SAM2 masks overlaid and
# append it to the VideoWriter `out`, then finalize the video file.
vis_frame_stride = 1
plt.close('all')

dpi = plt.rcParams['figure.dpi']

for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
    #### SAM visualization starts here ####
    image = Image.open(os.path.join(video_frames_dir, frame_names[out_frame_idx]))

    # Size the figure so that, at `dpi`, it renders at the frame's pixel size.
    figsize = image.size[0] / dpi, image.size[1] / dpi
    fig, ax = plt.subplots(figsize=figsize)

    ax.imshow(image)
    ax.set_axis_off()
    for out_obj_id, out_mask in video_segments[out_frame_idx].items():
        show_mask(out_mask, ax, obj_id=out_obj_id)

    # Let the axes fill the whole figure (no margins).
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
    #### SAM visualization ends here ####

    #### Converting to NumPy and saving video starts here ####
    canvas = FigureCanvas(fig)
    canvas.draw()

    # buffer_rgba() is shaped (rows, cols, 4) for the actual rendered canvas,
    # so no hand-written reshape that can disagree with the canvas size.
    image_rgba = np.array(canvas.buffer_rgba())
    image_bgr = cv2.cvtColor(image_rgba, cv2.COLOR_RGBA2BGR)

    # figsize * dpi can round to one pixel less than the frame; VideoWriter may
    # silently skip frames whose size differs from the (w, h) it was opened with.
    if image_bgr.shape[:2] != (h, w):
        image_bgr = cv2.resize(image_bgr, (w, h))

    out.write(image_bgr)

    # Close every figure inside the loop; otherwise one figure per frame accumulates.
    plt.close(fig)

# Finalize the container so the .avi file is complete and playable.
out.release()

In [None]:
# Clear the manual prompts and tracking results before prompting with Molmo's points.
predictor.reset_state(inference_state=inference_state)

## Use Molmo Coordinates for SAM Prediction

In [None]:
# Convert Molmo's percentage-based points into pixel coordinates of the first frame.
coords = get_coords(output_string=outputs, image_path=image_path)

In [None]:
# Display the parsed (x, y) pixel coordinates (rich display of the last expression).
coords

In [None]:
# Visual check: Molmo's points drawn on the first frame.
draw_point_and_show(image_path=image_path, points=coords)

In [None]:
# Every Molmo point becomes a positive click (label 1).
input_points = np.array(coords)
input_labels = np.full(len(input_points), 1, dtype=np.int32)
print(input_points, input_labels)

In [None]:
# Register each Molmo point (e.g. both shoes) on frame 0 as its own object,
# with object ids 0, 1, ... in point order.
# (Optionally run under torch.inference_mode() / bfloat16 autocast.)
for ann_object_id, (point, label) in enumerate(zip(input_points, input_labels)):
    ann_frame_idx = 0  # Frame index to interact/start with.
    _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
        inference_state=inference_state,
        frame_idx=ann_frame_idx,
        obj_id=ann_object_id,
        points=np.array([point]),
        labels=np.array([label]),
    )

In [None]:
# First-frame prediction: all Molmo prompt points plus the mask of every object
# returned by the last prompt call. The previous version drew only index 0, so
# only one of the prompted objects was shown.
plt.figure(figsize=(9, 6))
plt.title(f"frame {ann_frame_idx}")
plt.imshow(Image.open(os.path.join(video_frames_dir, frame_names[ann_frame_idx])))
show_points(input_points, input_labels, plt.gca())
for mask_pos, mask_obj_id in enumerate(out_obj_ids):
    show_mask((out_mask_logits[mask_pos] > 0.0).cpu().numpy(), plt.gca(), obj_id=mask_obj_id)

In [None]:
# Propagate the Molmo prompts through the entire video. For every frame, keep a
# boolean mask (logits > 0) per object id.
video_segments = {}  # frame index -> {object id: mask}
max_frame_num_to_track = None  # no limit: run throughout the video
# (Optionally run under torch.inference_mode() / bfloat16 autocast.)
propagation = predictor.propagate_in_video(
    inference_state, max_frame_num_to_track=max_frame_num_to_track
)
for out_frame_idx, out_obj_ids, out_mask_logits in propagation:
    frame_masks = {}
    for obj_pos, obj_id in enumerate(out_obj_ids):
        frame_masks[obj_id] = (out_mask_logits[obj_pos] > 0.0).cpu().numpy()
    video_segments[out_frame_idx] = frame_masks

In [None]:
# Directory for the rendered output videos.
output_dir = 'video_out'
if not os.path.isdir(output_dir):
    os.makedirs(output_dir, exist_ok=True)

In [None]:
# XVID-encoded AVI writer at 30 fps for the Molmo-prompted result. Frames
# written to it should match the (w, h) size of the source frames.
codec = cv2.VideoWriter_fourcc(*"XVID")
out = cv2.VideoWriter(f"{output_dir}/molmo_points_output.avi", codec, 30, (w, h))

In [None]:
# Render every `vis_frame_stride`-th frame with its SAM2 masks overlaid and
# append it to the VideoWriter `out`, then finalize the video file.
vis_frame_stride = 1
plt.close('all')

dpi = plt.rcParams['figure.dpi']

for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
    #### SAM visualization starts here ####
    image = Image.open(os.path.join(video_frames_dir, frame_names[out_frame_idx]))

    # Size the figure so that, at `dpi`, it renders at the frame's pixel size.
    figsize = image.size[0] / dpi, image.size[1] / dpi
    fig, ax = plt.subplots(figsize=figsize)

    ax.imshow(image)
    ax.set_axis_off()
    for out_obj_id, out_mask in video_segments[out_frame_idx].items():
        show_mask(out_mask, ax, obj_id=out_obj_id)

    # Let the axes fill the whole figure (no margins).
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
    #### SAM visualization ends here ####

    #### Converting to NumPy and saving video starts here ####
    canvas = FigureCanvas(fig)
    canvas.draw()

    # buffer_rgba() is shaped (rows, cols, 4) for the actual rendered canvas,
    # so no hand-written reshape that can disagree with the canvas size.
    image_rgba = np.array(canvas.buffer_rgba())
    image_bgr = cv2.cvtColor(image_rgba, cv2.COLOR_RGBA2BGR)

    # figsize * dpi can round to one pixel less than the frame; VideoWriter may
    # silently skip frames whose size differs from the (w, h) it was opened with.
    if image_bgr.shape[:2] != (h, w):
        image_bgr = cv2.resize(image_bgr, (w, h))

    out.write(image_bgr)

    # Close every figure inside the loop; otherwise one figure per frame accumulates.
    plt.close(fig)

# Finalize the container so the .avi file is complete and playable.
out.release()