## Set-up

In [1]:
import os
# if using Apple MPS, fall back to CPU for unsupported ops
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import cv2
from PIL import Image, ImageDraw, ImageFont
import subprocess
import piexif
from datetime import datetime
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

In [2]:
# select the device for computation: CUDA when torch reports it as available, otherwise CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
# MPS seems to crash every now and then
# elif torch.backends.mps.is_available():
#     device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"using device: {device}")

if device.type == "cuda":
    # use bfloat16 for the entire notebook
    # (the autocast context is entered by hand and is not exited in this cell,
    # so it remains active for the cells that follow)
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
elif device.type == "mps":
    # NOTE: with the MPS selection above commented out, `device` is only ever
    # "cuda" or "cpu", so this warning cannot currently be printed.
    print(
        "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
        "give numerically different outputs and sometimes degraded performance on MPS. "
        "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
    )

In [10]:
# params
# Root folder of the project; the image, output and movie paths below are derived from it.
base_path='/Users/peter/playground/parkbridge'
# Folder that is listed for the source .jpg photos.
folder_path = f'{base_path}/images'
# Folder the cropped, timestamped warped frames are saved to.
warped_images_folder = f'{base_path}/auto_warped'
# File name (inside folder_path) of the reference photo whose SAM masks are matched in every other photo.
base_image = 'IMG_20240820_074416.jpg'

# Matching presets: exactly one of #1 / #2 / #3 should be left uncommented.
#   matched_segments_bbox_expand   fraction by which each mask bbox is grown (expand_percent)
#   min/max_segment_surface        a mask is kept only if area / image area lies strictly between these
#   match_distance_penalty_factor  match score is multiplied by 1 / (1 + factor * pixel distance)
#   match_distance_threshold       despite the name, the minimum penalized score; lower matches are skipped
#1
matched_segments_bbox_expand = 0.05
min_segment_surface = 0.0015
max_segment_surface = 0.15
match_distance_penalty_factor=0.1
match_distance_threshold=0.015

#2
# matched_segments_bbox_expand = 0.10
# min_segment_surface = 0.0010
# max_segment_surface = 0.20
# match_distance_penalty_factor=0.1
# match_distance_threshold=0.010

#3
# matched_segments_bbox_expand = 0.10
# min_segment_surface = 0.0010
# max_segment_surface = 0.20
# match_distance_penalty_factor=0.15
# match_distance_threshold=0.015


# Path of the resulting movie file (presumably handed to the movie builder in a later cell - verify).
stop_motion_result_path=f'{base_path}/auto_parkbridge.mp4'


# Alternative data set: uncomment to override the four paths above.
#Gent
# folder_path = f'{base_path}/gent'
# warped_images_folder = f'{base_path}/gent_warped'
# base_image = '20241014_085000.jpg'
# stop_motion_result_path=f'{base_path}/gent.mp4'


In [4]:
# Seed NumPy's global RNG so the random overlay colours drawn by show_anns are repeatable.
np.random.seed(3)
def show_anns(anns, borders=True):
    """
    Draw a list of masks as a semi-transparent colour overlay on the current matplotlib axes.

    Masks are painted from largest to smallest 'area' so that smaller masks stay visible
    on top of larger ones. Each mask gets a random colour taken from NumPy's global RNG.

    Args:
        anns (list): Mask dictionaries, each with a boolean 'segmentation' array and an 'area' value.
        borders (bool): When True, the external contour of every mask is also drawn.

    Returns:
        None. Nothing is drawn when `anns` is empty.
    """
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=lambda ann: ann['area'], reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    # RGBA canvas with the size of the first mask, fully transparent to start with
    height, width = sorted_anns[0]['segmentation'].shape[:2]
    overlay = np.ones((height, width, 4))
    overlay[:, :, 3] = 0
    for ann in sorted_anns:
        mask = ann['segmentation']
        # random RGB with a fixed alpha of 0.5
        overlay[mask] = np.concatenate([np.random.random(3), [0.5]])
        if borders:
            contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
            # Try to smooth contours
            contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
            cv2.drawContours(overlay, contours, -1, (0, 0, 1, 0.4), thickness=1)

    ax.imshow(overlay)

def downscale_image_by_percentage(image, scale_percent):
    """
    Resize an image to a percentage of its size, applying the same factor to both sides.

    Args:
        image (PIL.Image or numpy.ndarray): Image to resize; arrays are wrapped in a PIL Image first.
        scale_percent (float): Target size as a percentage of the input size (50 halves both sides).

    Returns:
        PIL.Image: The resized image (LANCZOS resampling).
    """
    pil_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image

    # Both sides are scaled independently and truncated to whole pixels
    target_size = tuple(int(side * scale_percent / 100) for side in pil_image.size)

    return pil_image.resize(target_size, Image.Resampling.LANCZOS)

def load_image(path):
    """
    Read the image file at `path` and return it as an RGB numpy array.

    The image goes through downscale_image_by_percentage with scale_percent=100,
    i.e. the computed target size equals the original size.
    """
    rgb_image = Image.open(path).convert("RGB")
    resized = downscale_image_by_percentage(rgb_image, scale_percent=100)
    return np.array(resized.convert("RGB"))

def plt_image(image):
    """Display `image` in a new 8x8 inch figure with the axes hidden."""
    _, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(image)
    ax.axis('off')
    plt.show()
    
def expand_bounding_boxes_in_masks(image, masks, expand_percent=0.0):
    """
    Grow the 'bbox' of every mask by a fraction of its own size, clipped to the image.

    The extra width/height is split evenly over both sides of the box. The mask
    dictionaries are modified in place: their 'bbox' entry is overwritten.

    Args:
        image (numpy.ndarray): Image the boxes belong to; only its height and width are used.
        masks (list): Mask dictionaries, each with a 'bbox' of (x_min, y_min, width, height).
        expand_percent (float): Fraction to grow each box by (e.g. 0.1 for 10%).

    Returns:
        list: The same `masks` list, now holding the expanded 'bbox' values.
    """
    img_h, img_w = image.shape[:2]

    def _grow(bbox):
        # Half of the added size goes to each side; the result never leaves the image.
        left, top, box_w, box_h = bbox
        pad_x = box_w * expand_percent / 2
        pad_y = box_h * expand_percent / 2
        new_left = max(0, left - pad_x)
        new_top = max(0, top - pad_y)
        new_right = min(img_w, left + box_w + pad_x)
        new_bottom = min(img_h, top + box_h + pad_y)
        return (new_left, new_top, new_right - new_left, new_bottom - new_top)

    for mask in masks:
        mask['bbox'] = _grow(mask['bbox'])

    return masks

def find_matching_segment_with_distance_penalty(template, target_image, template_bbox, penalty_factor, distance_threshold):
    """
    Locate `template` in `target_image` with normalised cross-correlation template matching
    and down-weight the score by how far the hit lies from the template's original position.

    Args:
        template (numpy.ndarray): RGB patch cut from the base image.
        target_image (numpy.ndarray): RGB image that is searched.
        template_bbox (tuple): (x_min, y_min, width, height) of the patch in the base image;
            only the top-left corner is used.
        penalty_factor (float): Weight of the pixel distance in the penalty 1 / (1 + factor * distance).
        distance_threshold (float): Minimum penalized score; a lower score rejects the match.

    Returns:
        tuple: (top-left corner of the best match, penalized score), or (None, None) when rejected.
    """
    # Matching is done on grayscale versions of both images
    gray_template = cv2.cvtColor(template, cv2.COLOR_RGB2GRAY)
    gray_target = cv2.cvtColor(target_image, cv2.COLOR_RGB2GRAY)
    score_map = cv2.matchTemplate(gray_target, gray_template, cv2.TM_CCOEFF_NORMED)

    # Only the maximum and its location matter for TM_CCOEFF_NORMED
    _, max_val, _, best_top_left = cv2.minMaxLoc(score_map)

    # Euclidean distance between the hit and where the template came from
    original_top_left = np.array((template_bbox[0], template_bbox[1]))
    distance = np.linalg.norm(np.array(best_top_left) - original_top_left)

    penalized_score = max_val * (1 / (1 + penalty_factor * distance))
    if penalized_score >= distance_threshold:
        return best_top_left, penalized_score

    print(f'Score:{max_val} - Penalty score:{penalized_score}. Skipping')
    return None, None

def process_all_masks(image, masks, target_image):
    """
    Cut the bounding-box patch of every mask out of the base image and look for it in the target image.

    Only the 'bbox' of each mask is used. The penalty factor and score threshold are read from
    the notebook-level `match_distance_penalty_factor` and `match_distance_threshold` at call time.

    Args:
        image (numpy.ndarray): The base image the masks belong to.
        masks (list): Mask dictionaries, each with a 'bbox' of (x_min, y_min, width, height).
        target_image (numpy.ndarray): The image to search for the patches.

    Returns:
        list: One dictionary per accepted match with 'mask_index' (position in `masks`),
        'best_match_loc', 'match_score' and 'segment_shape' (height, width of the patch).
        Masks whose match is rejected are left out, so the list can be shorter than `masks`.
    """
    results = []

    for idx, mask_data in enumerate(masks):
        x_min, y_min, width, height = mask_data['bbox']
        # bbox values may be fractional; truncate to pixel indices (the +1 keeps the far edge)
        extracted_segment = image[int(y_min):int(y_min + height + 1), int(x_min):int(x_min + width + 1)]

        best_match_loc, match_score = find_matching_segment_with_distance_penalty(
            extracted_segment,
            target_image,
            mask_data['bbox'],
            penalty_factor=match_distance_penalty_factor,
            distance_threshold=match_distance_threshold,
        )
        if best_match_loc is None:
            continue

        results.append({
            'mask_index': idx,
            'best_match_loc': best_match_loc,
            'match_score': match_score,
            'segment_shape': extracted_segment.shape[:2]  # Height, width of the segment
        })

    return results

def plot_matches_side_by_side(base_image_name, target_image_name, base_image, target_image, match_results, masks):
    """
    Plot the base image with the matched segment boxes on the left and the target image
    with the corresponding match boxes on the right.

    Each match is paired with its own mask through result['mask_index']. `match_results`
    can be shorter than `masks` (rejected matches are left out), so pairing the two lists
    by position would draw the wrong base boxes.

    Args:
        base_image_name (str): Name shown in the title of the left panel.
        target_image_name (str): Name shown in the title of the right panel.
        base_image (numpy.ndarray): The base image the segments were extracted from.
        target_image (numpy.ndarray): The target image where matches were found.
        match_results (list): Match dictionaries with 'mask_index', 'best_match_loc' and 'segment_shape'.
        masks (list): The mask list the 'mask_index' values refer to (each with a 'bbox').
    """
    fig, axes = plt.subplots(1, 2, figsize=(16, 8))

    axes[0].imshow(base_image)
    axes[0].set_title(f"Original {base_image_name} with Segments")

    axes[1].imshow(target_image)
    axes[1].set_title(f"Matched Segments on {target_image_name}")

    for idx, result in enumerate(match_results):
        # Box of the mask this match was produced from (blue, on the base image)
        x_min, y_min, width, height = masks[result['mask_index']]['bbox']
        rect_base = patches.Rectangle((x_min, y_min), width, height, linewidth=2, edgecolor='blue', facecolor='none')
        axes[0].add_patch(rect_base)
        axes[0].text(x_min + width / 2, y_min + height / 2, str(idx), color='white', fontsize=12, ha='center', va='center')

        # Box where the segment was found (green, on the target image)
        top_left = result['best_match_loc']
        h, w = result['segment_shape']  # Height and width of the segment
        rect_target = patches.Rectangle(top_left, w, h, linewidth=2, edgecolor='green', facecolor='none')
        axes[1].add_patch(rect_target)
        axes[1].text(top_left[0] + w / 2, top_left[1] + h / 2, str(idx), color='white', fontsize=12, ha='center', va='center')

    axes[0].axis('off')
    axes[1].axis('off')

    plt.tight_layout()
    plt.show()
    
def plot_matches_and_warped_side_by_side(base_image_name, target_image_name, base_image, target_image, warped_image, match_results, masks):
    """
    Show three panels next to each other: the base image with the matched segment boxes,
    the target image with the boxes where those segments were found, and the warped target image.

    Args:
        base_image_name (str): Name used in the title of the first panel.
        target_image_name (str): Name used in the titles of the second and third panel.
        base_image (numpy.ndarray): The base image the segments were extracted from.
        target_image (numpy.ndarray): The target image where matches were found.
        warped_image (numpy.ndarray): The target image after warping.
        match_results (list): Match dictionaries with 'mask_index', 'best_match_loc' and 'segment_shape'.
        masks (list): The mask list the 'mask_index' values refer to (each with a 'bbox').
    """
    fig, axes = plt.subplots(1, 3, figsize=(16, 8))
    base_ax, target_ax, warped_ax = axes

    panels = (
        (base_ax, base_image, f"Original {base_image_name} with Segments"),
        (target_ax, target_image, f"Matched Segments on {target_image_name}"),
        (warped_ax, warped_image, f"Warped {target_image_name}"),
    )
    for ax, picture, title in panels:
        ax.imshow(picture.copy())
        ax.set_title(title)

    def _outline(ax, corner, box_w, box_h, colour, label):
        # Rectangle plus a white label at its centre
        ax.add_patch(patches.Rectangle(corner, box_w, box_h, linewidth=2, edgecolor=colour, facecolor='none'))
        ax.text(corner[0] + box_w / 2, corner[1] + box_h / 2, label, color='white', fontsize=12, ha='center', va='center')

    for idx, result in enumerate(match_results):
        # bbox of the mask that produced this match
        mask_data = masks[result['mask_index']]['bbox']
        print(f"{idx} - mask bbox:{mask_data} - best_match_loc:{result['best_match_loc']} ")

        x_min, y_min, width, height = mask_data
        _outline(base_ax, (x_min, y_min), width, height, 'blue', str(idx))

        h, w = result['segment_shape']  # Height and width of the segment
        _outline(target_ax, result['best_match_loc'], w, h, 'green', str(idx))

    for ax in (base_ax, target_ax, warped_ax):
        ax.axis('off')

    plt.tight_layout()
    plt.show()
            
def calculate_bounding_box_center(bbox):
    """
    Return the centre point of a bounding box.

    Args:
        bbox (tuple): Bounding box as (x_min, y_min, width, height).

    Returns:
        tuple: (x, y) of the centre.
    """
    left, top, box_w, box_h = bbox
    return left + box_w / 2, top + box_h / 2

def calculate_homography(template_bboxes, target_bboxes):
    """
    Estimate the homography that maps the target box centres onto the template box centres.

    Args:
        template_bboxes (list of tuples): Boxes in the template image as (x_min, y_min, width, height).
        target_bboxes (list of tuples): The corresponding boxes in the target image, same format.

    Returns:
        The homography matrix returned by cv2.findHomography (RANSAC); the status mask is discarded.
    """
    def _centers(bboxes):
        return np.array([calculate_bounding_box_center(box) for box in bboxes])

    # Source points are the target centres, destination points the template centres
    homography, _ = cv2.findHomography(_centers(target_bboxes), _centers(template_bboxes), cv2.RANSAC)
    return homography

def warp_target_image(target_image, homography_matrix, template_image_size):
    """
    Apply a perspective warp to the target image.

    Args:
        target_image (numpy.ndarray): The image to warp.
        homography_matrix (numpy.ndarray): The 3x3 homography matrix to apply.
        template_image_size (tuple): Output size as (width, height).

    Returns:
        numpy.ndarray: The result of cv2.warpPerspective.
    """
    return cv2.warpPerspective(target_image, homography_matrix, template_image_size)

def align_images_using_homography(base_image, target_image, match_results, masks):
    """
    Align the target image to the base image with a homography fitted to the centres
    of the matched bounding boxes.

    Args:
        base_image (numpy.ndarray): The base image (template); its size is the output size.
        target_image (numpy.ndarray): The target image to be warped.
        match_results (list): Match dictionaries with 'mask_index', 'best_match_loc' and 'segment_shape'.
        masks (list): The mask list the 'mask_index' values refer to (each with a 'bbox').

    Returns:
        numpy.ndarray or None: The warped target image, or None when there are fewer than
        four matches or no homography could be estimated from them.
    """
    # Pair every match with the bbox of the mask it came from
    template_bboxes = [masks[result['mask_index']]['bbox'] for result in match_results]
    # Matched box in the target: top-left of the hit plus the patch size (shape is height, width)
    target_bboxes = [
        (result['best_match_loc'][0], result['best_match_loc'][1], result['segment_shape'][1], result['segment_shape'][0])
        for result in match_results
    ]
    print(f"Building homograph from {len(target_bboxes)} target_bboxes")
    print(f"...template:{template_bboxes}")
    print(f"...target:{target_bboxes}")

    # A homography needs at least four point correspondences
    if len(template_bboxes) < 4:
        print("Not enough matches to warp")
        return None

    H = calculate_homography(template_bboxes, target_bboxes)
    if H is None:
        # The estimation can fail (e.g. degenerate point sets); report it the same
        # way as "not enough matches" instead of crashing inside the warp.
        print("Could not estimate a homography from the matches")
        return None

    template_image_size = (base_image.shape[1], base_image.shape[0])  # (width, height)
    return warp_target_image(target_image, H, template_image_size)

def get_bbox_corners(bbox):
    """Return the corners of an (x, y, w, h) box: top-left, top-right, bottom-left, bottom-right."""
    left, top, box_w, box_h = bbox
    right = left + box_w
    bottom = top + box_h
    return [(left, top), (right, top), (left, bottom), (right, bottom)]

def align_images_using_homography_corners(base_image, target_image, match_results, masks):
    """
    Align the target image to the base image with a homography fitted to the four corners
    of every matched bounding box (instead of only the box centres).

    Args:
        base_image (numpy.ndarray): The base image (template); its size is the output size.
        target_image (numpy.ndarray): The target image to be warped.
        match_results (list): Match dictionaries with 'mask_index', 'best_match_loc' and 'segment_shape'.
        masks (list): The mask list the 'mask_index' values refer to (each with a 'bbox').

    Returns:
        numpy.ndarray or None: The warped target image, or None when there are fewer than
        four corner points or no homography could be estimated from them.
    """
    # Pair every match with the bbox of the mask it came from
    template_bboxes = [masks[result['mask_index']]['bbox'] for result in match_results]
    # Matched box in the target: top-left of the hit plus the patch size (shape is height, width)
    target_bboxes = [
        (result['best_match_loc'][0], result['best_match_loc'][1], result['segment_shape'][1], result['segment_shape'][0])
        for result in match_results
    ]

    print(f"Building homograph from {len(target_bboxes)} target_bboxes")
    print(f"...template:{template_bboxes}")
    print(f"...target:{target_bboxes}")

    # Four corner points per box, in the same order for template and target
    template_points = np.array(
        [corner for bbox in template_bboxes for corner in get_bbox_corners(bbox)], dtype=np.float32
    )
    target_points = np.array(
        [corner for bbox in target_bboxes for corner in get_bbox_corners(bbox)], dtype=np.float32
    )

    # A homography needs at least four point correspondences (i.e. one matched box)
    if len(template_points) < 4:
        print("Not enough matches to warp")
        return None

    H, status = cv2.findHomography(target_points, template_points, cv2.RANSAC)
    if H is None:
        # The estimation can fail (e.g. degenerate point sets); report it the same
        # way as "not enough matches" instead of crashing inside the warp.
        print("Could not estimate a homography from the matches")
        return None

    template_image_size = (base_image.shape[1], base_image.shape[0])  # (width, height)
    return warp_target_image(target_image, H, template_image_size)

def mask_area_filter(image,masks,min_surf=None, max_surf=None):
    """
    Keep only the masks whose area is a given fraction of the image area.

    Args:
        image (numpy.ndarray): Image the masks belong to; only its height and width are used.
        masks (list): Mask dictionaries, each with an 'area' value in pixels.
        min_surf (float, optional): Exclusive lower bound on area / image area.
            Defaults to the current notebook-level `min_segment_surface`.
        max_surf (float, optional): Exclusive upper bound on area / image area.
            Defaults to the current notebook-level `max_segment_surface`.

    Returns:
        list: The masks with min_surf < area / image area < max_surf, in their original order.
    """
    # Resolve the defaults at call time so that re-running the params cell takes
    # effect without having to re-run the cell that defines this function.
    if min_surf is None:
        min_surf = min_segment_surface
    if max_surf is None:
        max_surf = max_segment_surface

    surface = image.shape[0] * image.shape[1]
    return [m for m in masks if min_surf < m['area'] / surface < max_surf]

def match_segments(image,masks,target_image,min_segment_area=None,max_segment_area=None):
    """
    Filter the masks by relative area and match the remaining segments in the target image.

    Args:
        image (numpy.ndarray): The base image the masks belong to.
        masks (list): Mask dictionaries of the base image.
        target_image (numpy.ndarray): The image to search for the segments.
        min_segment_area (float, optional): Lower bound on area / image area.
            Defaults to the current notebook-level `min_segment_surface`.
        max_segment_area (float, optional): Upper bound on area / image area.
            Defaults to the current notebook-level `max_segment_surface`.

    Returns:
        tuple: (match results, filtered masks). The 'mask_index' of a result refers to the filtered list.
    """
    # Resolve the defaults at call time so edits to the params cell are picked up
    if min_segment_area is None:
        min_segment_area = min_segment_surface
    if max_segment_area is None:
        max_segment_area = max_segment_surface

    filtered_masks = mask_area_filter(image, masks, min_segment_area, max_segment_area)
    results = process_all_masks(image, filtered_masks, target_image)
    return results,filtered_masks

def match_and_plot_segments(base_image_name,target_image_name, image,masks,target_image,min_segment_area=None,max_segment_area=None):
    """
    Filter the masks by relative area, match the remaining segments in the target image
    and plot base and target image side by side with the matches.

    Args:
        base_image_name (str): Name of the base image, used in the plot title.
        target_image_name (str): Name of the target image, used in the plot title.
        image (numpy.ndarray): The base image the masks belong to.
        masks (list): Mask dictionaries of the base image.
        target_image (numpy.ndarray): The image to search for the segments.
        min_segment_area (float, optional): Lower bound on area / image area.
            Defaults to the current notebook-level `min_segment_surface`.
        max_segment_area (float, optional): Upper bound on area / image area.
            Defaults to the current notebook-level `max_segment_surface`.

    Returns:
        tuple: (match results, filtered masks). The 'mask_index' of a result refers to the filtered list.
    """
    # Resolve the defaults at call time so edits to the params cell are picked up
    if min_segment_area is None:
        min_segment_area = min_segment_surface
    if max_segment_area is None:
        max_segment_area = max_segment_surface

    filtered_masks = mask_area_filter(image, masks, min_segment_area, max_segment_area)
    results = process_all_masks(image, filtered_masks, target_image)
    plot_matches_side_by_side(base_image_name,target_image_name, image, target_image, results, filtered_masks)
    return results,filtered_masks

def match_warp_and_plot_segments(base_image_name, target_image_name, base_image,masks,target_image,min_segment_area=None,max_segment_area=None):
    """
    Filter the masks by relative area, match the remaining segments in the target image,
    warp the target image onto the base image and plot base, target and warped image.

    With more than three matches the homography is fitted to the box centres; otherwise
    the box corners are used so that fewer matches still give enough points.

    Args:
        base_image_name (str): Name of the base image, used in the plot title.
        target_image_name (str): Name of the target image, used in the plot titles.
        base_image (numpy.ndarray): The base image the masks belong to.
        masks (list): Mask dictionaries of the base image.
        target_image (numpy.ndarray): The image to match and warp.
        min_segment_area (float, optional): Lower bound on area / image area.
            Defaults to the current notebook-level `min_segment_surface`.
        max_segment_area (float, optional): Upper bound on area / image area.
            Defaults to the current notebook-level `max_segment_surface`.

    Returns:
        tuple: (match results, filtered masks, warped image or None when warping was not possible).
    """
    # Resolve the defaults at call time, from the same settings the sibling
    # match_* helpers use, so edits to the params cell are picked up.
    if min_segment_area is None:
        min_segment_area = min_segment_surface
    if max_segment_area is None:
        max_segment_area = max_segment_surface

    filtered_masks = mask_area_filter(base_image, masks, min_segment_area, max_segment_area)
    results = process_all_masks(base_image, filtered_masks, target_image)
    if len(results) > 3:
        warped_image = align_images_using_homography(base_image, target_image, results, filtered_masks)
    else:
        print(f'Only {len(results)} matches. Using bboxes in stead of centers')
        warped_image = align_images_using_homography_corners(base_image, target_image, results, filtered_masks)
    if warped_image is not None:
        plot_matches_and_warped_side_by_side(base_image_name, target_image_name, base_image, target_image, warped_image, results, filtered_masks)
    return results,filtered_masks,warped_image

def create_stop_motion_movie(input_folder, output_file, frame_duration=2, transition_duration=1, fps=30):
    """
    Build a stop-motion movie from the .jpg files of a folder.

    The images are used in sorted file-name order. Each image is held for `frame_duration`
    seconds and then cross-faded into the next one over `transition_duration` seconds; the
    last image fades back into the first. The frames are first written to a temporary
    'temp_output.mp4' in the working directory, which ffmpeg then re-encodes to `output_file`.

    Args:
        input_folder (str): Folder containing the .jpg frames (extension matched case-insensitively).
        output_file (str): Path of the movie written by ffmpeg.
        frame_duration (float): Seconds each image is held.
        transition_duration (float): Seconds of cross-fade between two images.
        fps (int): Frames per second of the movie.

    Raises:
        ValueError: If the folder holds no .jpg file or an image cannot be read.
        subprocess.CalledProcessError: If ffmpeg exits with an error.
    """
    image_files = sorted(f for f in os.listdir(input_folder) if f.lower().endswith('.jpg'))
    print(f"files {image_files}")
    if not image_files:
        raise ValueError(f"No .jpg images found in {input_folder}")

    def _read_frame(name, size=None):
        # cv2.imread signals failure by returning None rather than raising
        frame = cv2.imread(os.path.join(input_folder, name))
        if frame is None:
            raise ValueError(f"Could not read image {os.path.join(input_folder, name)}")
        # cv2.addWeighted needs equally sized inputs, so bring every frame to the movie size
        if size is not None and (frame.shape[1], frame.shape[0]) != size:
            frame = cv2.resize(frame, size)
        return frame

    # The first image defines the movie size
    current_img = _read_frame(image_files[0])
    height, width = current_img.shape[:2]
    frame_size = (width, height)

    # Whole frame counts, so fractional durations work for both phases
    hold_frames = int(fps * frame_duration)
    fade_frames = int(fps * transition_duration)

    temp_output = 'temp_output.mp4'
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(temp_output, fourcc, fps, frame_size)

    try:
        try:
            for i in range(len(image_files)):
                # Wraps around so the last image fades into the first
                next_img = _read_frame(image_files[(i + 1) % len(image_files)], frame_size)

                # Hold the current image
                for _ in range(hold_frames):
                    out.write(current_img)

                # Cross-fade to the next image
                for j in range(fade_frames):
                    alpha = j / (fps * transition_duration)
                    blended = cv2.addWeighted(current_img, 1 - alpha, next_img, alpha, 0)
                    out.write(blended)

                # The image just faded to is the one held next; no need to read it again
                current_img = next_img
        finally:
            out.release()

        # Use FFmpeg to convert the temporary video to the final output with improved compression
        ffmpeg_cmd = [
            'ffmpeg',
            '-y',
            '-i', temp_output,
            '-c:v', 'libx264',
            '-preset', 'slow',
            '-crf', '23',
            '-vf', 'scale=-2:720',  # Scale to 720p, maintaining aspect ratio
            '-movflags', '+faststart',
            '-c:a', 'aac',
            '-b:a', '128k',
            output_file
        ]
        subprocess.run(ffmpeg_cmd, check=True)
    finally:
        # Remove the temporary file, also when encoding failed half-way
        if os.path.exists(temp_output):
            os.remove(temp_output)

    print(f"Stop motion movie created: {output_file}")

def plot_overlayed(image,masks):
    """Show `image` in a 10x10 inch figure with the masks drawn on top as coloured overlays."""
    # Bounds of 0.0 and 1.0: only masks with an area fraction strictly between them are drawn
    masks_to_draw = mask_area_filter(image, masks, 0.0, 1.0)

    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    # show_anns draws on the current axes, i.e. the ones plt.imshow just used
    show_anns(masks_to_draw)
    plt.axis('off')
    plt.show()

def crop_center(image, crop_width, crop_height):
    """Return a crop of `crop_width` x `crop_height` pixels taken around the centre of a PIL image."""
    img_width, img_height = image.size
    # Top-left corner that leaves equal margins on opposite sides (rounded down)
    left = (img_width - crop_width) // 2
    top = (img_height - crop_height) // 2
    box = (left, top, left + crop_width, top + crop_height)
    return image.crop(box)

def add_timestamp_from_exif(orig_image_path,image):
    """
    Draw the EXIF capture time of the original photo at the top centre of an image.

    Args:
        orig_image_path (str): Path of the original photo whose EXIF DateTimeOriginal is read.
        image (numpy.ndarray): Image to annotate. It is treated as BGR: its channels are
            passed through cv2.COLOR_BGR2RGB before it is turned into a PIL image.

    Returns:
        PIL.Image: The image as an RGB PIL image, with the timestamp (formatted like
        '20-Aug 07:44') drawn on it when the original photo has an EXIF DateTimeOriginal.
        Without that tag the converted image is returned unannotated, so callers always
        get the same type back.
    """
    # Only the raw EXIF bytes are needed; the context manager closes the file again
    with Image.open(orig_image_path) as orig_image:
        exif_bytes = orig_image.info.get('exif')

    # Hand data to piexif only when the file actually carries an EXIF block
    exif_datetime = None
    if exif_bytes:
        exif_data = piexif.load(exif_bytes)
        exif_datetime = exif_data.get('Exif', {}).get(piexif.ExifIFD.DateTimeOriginal)

    rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    pil_image = Image.fromarray(rgb_image)

    if not exif_datetime:
        print(f'No exif datetime for {orig_image_path}')
        return pil_image

    dt = datetime.strptime(exif_datetime.decode('utf-8'), '%Y:%m:%d %H:%M:%S')
    timestamp_str = dt.strftime('%d-%b %H:%M')

    draw = ImageDraw.Draw(pil_image)
    # NOTE: macOS font location; ImageFont.truetype fails if this file is not present
    font_path = "/Library/Fonts/PTSans-Regular.ttf"
    font_size = 40
    font = ImageFont.truetype(font_path, font_size)
    img_width, img_height = pil_image.size
    # Horizontally centred; with anchor "ms" the text baseline sits 80 pixels below the top edge
    text_position = (img_width // 2, 80)
    draw.text(text_position, timestamp_str, font=font, fill=(255, 255, 255), anchor="ms")
    return pil_image


In [5]:

# Checkpoint and config of the SAM 2.1 model; both paths are relative to the working directory.
sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"

# Build the model once on the selected device and share it between both mask generators.
sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)
# Generator with the constructor defaults (not selected below).
mask_generator_1 = SAM2AutomaticMaskGenerator(sam2)
# Generator with every option spelled out. The trailing comments appear to mirror the
# constructor signature; going by them, only min_mask_region_area differs from its default.
mask_generator_2 = SAM2AutomaticMaskGenerator(
    model=sam2,
    points_per_side=32,               # points_per_side: Optional[int] = 32,
    points_per_batch=64,              # points_per_batch: int = 64,
    pred_iou_thresh=0.8,              # pred_iou_thresh: float = 0.8,
    stability_score_thresh=0.95,      # stability_score_thresh: float = 0.95,
    stability_score_offset=1.0,       # stability_score_offset: float = 1.0,
    crop_n_layers=0,                  # crop_n_layers: int = 0,
    box_nms_thresh=0.7,               # box_nms_thresh: float = 0.7,
    crop_n_points_downscale_factor=1, # crop_n_points_downscale_factor: int = 1,
    min_mask_region_area=5.0,        # min_mask_region_area: int = 0,
    use_m2m=False,                     # use_m2m: bool = False,
)

# The generator actually used for the base image.
mask_generator=mask_generator_2

# Load every .jpg of folder_path, in sorted file-name order, as an RGB array.
# NOTE(review): this extension check is case-sensitive, unlike the .lower() check in create_stop_motion_movie.
image_files = sorted([f for f in os.listdir(folder_path) if f.endswith('.jpg')])
# image_files = sorted([f for f in os.listdir(folder_path) if f.endswith('.jpg') and (f==base_image or f=='IMG_20240528_080644.jpg')])
images_list = [load_image(os.path.join(folder_path, file)) for file in image_files]
print(f'Loaded {len(images_list)} images from {folder_path}')

# Position of the reference photo in image_files / images_list
# (list.index raises ValueError when base_image is not among the loaded files).
base_image_ix = image_files.index(base_image)
# Segment the reference photo; these masks are what gets matched in every other photo.
masks_base_image = mask_generator.generate(images_list[base_image_ix])

In [11]:
# Expand the mask bounding boxes of the base image on shallow copies of the mask dicts.
# expand_bounding_boxes_in_masks overwrites each mask's 'bbox' in place, so handing it
# masks_base_image directly would grow the boxes again every time this cell is re-run.
masks_0 = expand_bounding_boxes_in_masks(
    images_list[base_image_ix],
    [dict(mask) for mask in masks_base_image],
    matched_segments_bbox_expand,
)

# Per image (same order as images_to_warp): the match results and the filtered masks they index into
matched_results=[]
matched_masks=[]

images_to_warp = images_list[:len(images_list)]

## Choose one of 1|2|3
# 1. Match segments
# for idx, image in enumerate(images_to_warp):
#     results, masks = match_segments(images_list[base_image_ix], masks_0, image)
#     matched_results.append(results)
#     matched_masks.append(masks)

# 2. Match and plot matched segments
# for idx, image in enumerate(images_to_warp):
#     results, masks = match_and_plot_segments(image_files[base_image_ix],image_files[idx], images_list[base_image_ix], masks_0, image)
#     matched_results.append(results)
#     # matched_top_masks.append(masks)
#     matched_masks.append(masks)


# 3. Match and warp and plot matched segments and warped image
# cv2.imwrite does not create missing folders, so make sure the output folder exists
os.makedirs('./auto_warped', exist_ok=True)
for idx, image in enumerate(images_to_warp):
    print(f"Working on {image_files[idx]}")
    results, masks, warped_image = match_warp_and_plot_segments(image_files[base_image_ix],image_files[idx],  images_list[base_image_ix], masks_0, image)
    matched_results.append(results)
    matched_masks.append(masks)
    if(warped_image is not None):
        # The images are RGB arrays; cv2.imwrite expects BGR channel order
        warped_image_bgr = cv2.cvtColor(warped_image, cv2.COLOR_RGB2BGR)
        # cv2.imwrite reports failure through its return value instead of raising
        if not cv2.imwrite(f'./auto_warped/{image_files[idx]}', warped_image_bgr):
            print(f"Failed to write ./auto_warped/{image_files[idx]}")
    else:
        print(f"Cant't warp {image_files[idx]}")

In [12]:
## Warp every image again from the stored matches, crop it and stamp it with its EXIF capture time
for ix, image in enumerate(images_to_warp):
    # Same rule as in match_warp_and_plot_segments: box centres with more than three matches, corners otherwise
    if(len(matched_results[ix]) > 3):
        warped_image = align_images_using_homography(images_list[base_image_ix], images_list[ix], matched_results[ix], matched_masks[ix])
    else:
        print(f'Only {len(matched_results[ix])} matches. Using bboxes in stead of centers')
        warped_image = align_images_using_homography_corners(images_list[base_image_ix], images_list[ix], matched_results[ix], matched_masks[ix])
    
    if(warped_image is not None):
        # images_list holds RGB arrays (see load_image), so this BGR2RGB call reverses the
        # channels to BGR here; add_timestamp_from_exif applies the same conversion once
        # more when it draws the timestamp, which brings the saved picture back to RGB.
        rgb_image = cv2.cvtColor(warped_image, cv2.COLOR_BGR2RGB)
        pil_image = Image.fromarray(rgb_image)
        # Fixed 3000x1800 pixel crop around the image centre
        cropped_img = crop_center(pil_image, 3000, 1800)
        opencv_cropped_img = np.array(cropped_img)
        opencv_cropped_img_bgr = opencv_cropped_img
        # The timestamp is read from the original, unwarped photo on disk
        final_img = add_timestamp_from_exif(f"{folder_path}/{image_files[ix]}", opencv_cropped_img_bgr)
        final_img.save(f'{warped_images_folder}/{image_files[ix]}')
    else:
        print(f"Cant't warp {image_files[ix]}")

In [13]:
import os
import cv2
import subprocess
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.ticker as ticker
from matplotlib.dates import DateFormatter
from datetime import datetime
import locale

def create_stop_motion_movie_with_steps(input_folder, output_file, csv_file, frame_duration=2, transition_duration=1, fps=30, audio_file=None):
    """
    Build a stop-motion movie from the .jpg files in a folder, overlay a
    step-count line chart on every frame and encode the result with ffmpeg.

    Each image is held for `frame_duration` seconds and then cross-faded into
    the next one over `transition_duration` seconds; the last image fades back
    into the first. The chart drawn for an image uses the CSV rows dated up to
    the image's date, which is parsed from characters 4-11 of the filename
    (IMG_yyyymmdd_hhmmss.jpg).

    Working files 'temp_output.mp4' and 'line_chart.png' are written to the
    current directory; the temporary video is removed again, also on failure.

    Args:
        input_folder: Folder with the .jpg frames; files are used in sorted name order.
        output_file: Path of the final movie produced by ffmpeg.
        csv_file: CSV file with a 'date' column and a 'steps' column.
        frame_duration: Seconds each image is held (may be fractional).
        transition_duration: Seconds each cross-fade lasts (may be fractional).
        fps: Frame rate of the generated video.
        audio_file: Optional audio file to mux in. When it is longer than the
            video, the last frame is held for the remaining time. When omitted,
            a silent movie is encoded.

    Raises:
        ValueError: If the folder contains no .jpg files or an image cannot be read.
        subprocess.CalledProcessError: If ffprobe or ffmpeg exits with an error.
    """
    # Read CSV data
    data = pd.read_csv(csv_file, parse_dates=['date'])

    # Get all jpg files in the input folder, sorted to ensure correct order
    image_files = sorted(f for f in os.listdir(input_folder) if f.lower().endswith('.jpg'))
    print(f"files {image_files}")
    if not image_files:
        raise ValueError(f"No .jpg images found in {input_folder}")

    def read_frame(file_name):
        # cv2.imread returns None instead of raising when a file cannot be decoded.
        frame = cv2.imread(os.path.join(input_folder, file_name))
        if frame is None:
            raise ValueError(f"Could not read image {file_name}")
        return frame

    # The first image defines the frame size of the whole movie
    height, width = read_frame(image_files[0]).shape[:2]

    # int() so fractional durations work for the hold loop as well as the fade loop
    hold_frames = int(fps * frame_duration)
    transition_frames = int(fps * transition_duration)

    line_chart_file = 'line_chart.png'
    full_data = data.copy()

    # Create a temporary video file
    temp_output = 'temp_output.mp4'
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(temp_output, fourcc, fps, (width, height))

    try:
        try:
            for i in range(len(image_files)):
                # Resize images to ensure they have the same size
                current_img = cv2.resize(read_frame(image_files[i]), (width, height))
                next_img = cv2.resize(read_frame(image_files[(i + 1) % len(image_files)]), (width, height))

                # Extract the date from the filename (format: IMG_yyyymmdd_hhmmss.jpg)
                image_date = datetime.strptime(image_files[i][4:12], '%Y%m%d')

                # Redraw the chart with the CSV data up to the current image date
                filtered_data = data[data['date'] <= image_date]
                create_line_chart(filtered_data, full_data, line_chart_file)

                # overlay_line_chart draws into the array it is given. Overlay on a
                # copy so current_img stays chart-free for the cross-fade below;
                # otherwise the semi-transparent chart is blended in twice.
                current_img_with_chart = overlay_line_chart(current_img.copy(), line_chart_file)

                # Hold the current image (with the updated chart)
                for _ in range(hold_frames):
                    out.write(current_img_with_chart)

                # Cross-fade the images, but without fading the line chart
                for j in range(transition_frames):
                    alpha = j / (fps * transition_duration)
                    blended_img = cv2.addWeighted(current_img, 1 - alpha, next_img, alpha, 0)
                    # blended_img is a fresh array, so overlaying in place is safe here
                    blended_img_with_chart = overlay_line_chart(blended_img, line_chart_file)
                    out.write(blended_img_with_chart)

            # Hold the last image if audio is longer than the video
            last_img = current_img_with_chart
            video_duration = len(image_files) * (frame_duration + transition_duration)  # In seconds

            if audio_file:
                # Keep stderr out of stdout so diagnostics are never parsed as the duration
                result = subprocess.run(
                    ['ffprobe', '-v', 'error', '-show_entries', 'format=duration',
                     '-of', 'default=noprint_wrappers=1:nokey=1', audio_file],
                    stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True)
                audio_duration = float(result.stdout)

                if audio_duration > video_duration:
                    hold_duration = audio_duration - video_duration
                    print(f"Holding the last frame for {hold_duration} seconds.")
                    for _ in range(int(hold_duration * fps)):
                        out.write(last_img)
        finally:
            # Always finalize the container, even when a frame failed
            out.release()

        # FFmpeg command; the audio input and audio options are only added when
        # an audio file was given (None is not a valid argv entry).
        ffmpeg_cmd = ['ffmpeg', '-y', '-i', temp_output]
        if audio_file:
            ffmpeg_cmd += ['-i', audio_file]
        ffmpeg_cmd += [
            '-c:v', 'libx264',
            '-preset', 'slow',
            '-crf', '23',
            '-vf', 'scale=-2:720',  # Scale to 720p, maintaining aspect ratio
            '-movflags', '+faststart',
        ]
        if audio_file:
            ffmpeg_cmd += [
                '-c:a', 'aac',
                '-b:a', '128k',
                '-ar', '44100',   # Resample audio to 44.1 kHz
                '-shortest',      # Stop when the shorter stream (video or audio) ends
            ]
        ffmpeg_cmd.append(output_file)

        # Run the FFmpeg command
        subprocess.run(ffmpeg_cmd, check=True)
    finally:
        # Remove the temporary file, also when rendering or encoding failed
        if os.path.exists(temp_output):
            os.remove(temp_output)

    print(f"Stop motion movie with audio created: {output_file}")

def overlay_line_chart(image, line_chart_file, chart_size=(750, 450), margin=10):
    """
    Overlay the line chart in the bottom-right corner of the given image.

    The chart is resized to `chart_size` and placed `margin` pixels away from
    the right and bottom edges. A chart with an alpha channel is alpha-blended
    onto the image; a 3-channel chart is pasted as fully opaque.

    NOTE: the chart is drawn into `image` itself (the same array is returned).
    Pass a copy when the untouched frame is still needed.

    Args:
        image: Frame to draw on; indexed as (rows, columns, channels).
        line_chart_file: Path of the chart image to overlay.
        chart_size: (width, height) in pixels the chart is resized to.
        margin: Distance in pixels between the chart and the image edges.

    Returns:
        The same `image` array with the chart drawn onto it.

    Raises:
        OSError: If the chart file cannot be read.
        ValueError: If the image is too small to hold the chart plus margin.
    """
    # Load the line chart unchanged so a possible alpha channel is preserved
    line_chart_img = cv2.imread(line_chart_file, cv2.IMREAD_UNCHANGED)
    if line_chart_img is None:
        # cv2.imread returns None instead of raising for missing/unreadable files
        raise OSError(f"Could not read line chart image: {line_chart_file}")

    line_chart_img_resized = cv2.resize(line_chart_img, chart_size)
    chart_height, chart_width = line_chart_img_resized.shape[:2]

    # Top-left corner of the region of interest (ROI) in the bottom-right corner
    x_offset = image.shape[1] - chart_width - margin
    y_offset = image.shape[0] - chart_height - margin
    if x_offset < 0 or y_offset < 0:
        # Negative offsets would wrap around and select the wrong region
        raise ValueError(
            f"Image of {image.shape[1]}x{image.shape[0]} px is too small for a "
            f"{chart_width}x{chart_height} px chart with a {margin} px margin"
        )

    # View into the image: writing to it updates `image` in place
    roi = image[y_offset:y_offset + chart_height, x_offset:x_offset + chart_width, :3]

    if line_chart_img_resized.ndim == 3 and line_chart_img_resized.shape[2] == 4:
        # Alpha scaled to 0..1 with a trailing axis so it broadcasts over B, G, R
        alpha = line_chart_img_resized[:, :, 3:4] / 255.0
        roi[:] = alpha * line_chart_img_resized[:, :, :3] + (1 - alpha) * roi
    elif line_chart_img_resized.ndim == 3 and line_chart_img_resized.shape[2] == 3:
        # No alpha channel means every chart pixel is opaque
        roi[:] = line_chart_img_resized

    return image

def create_line_chart(filtered_data, full_data, output_file):
    """
    Draw a small static line chart of the step counts and save it as an image.

    The line is plotted from `filtered_data` (the rows up to a specific date),
    while the axis limits come from `full_data`, so the scale stays the same
    from one chart to the next. The last plotted point is annotated with its
    value in 1K units (1 decimal) and its date formatted as 'dd mon', in Dutch
    when the nl_NL.UTF-8 locale is available.

    Args:
        filtered_data: DataFrame with 'date' and 'steps' columns to plot.
        full_data: DataFrame with 'date' and 'steps' columns used for the axis limits.
        output_file: Path the chart image is written to.
    """
    # Dutch month names for the annotation. The locale is not installed on every
    # system; fall back to the current locale instead of aborting the render.
    try:
        locale.setlocale(locale.LC_TIME, 'nl_NL.UTF-8')
    except locale.Error:
        print("Locale nl_NL.UTF-8 is not available; using the current locale for dates")

    fig, ax = plt.subplots(figsize=(4, 2))  # 400x200 px at the dpi=100 used below

    try:
        # Light grey background with 70% opacity
        # NOTE(review): assumes savefig keeps the figure facecolor (matplotlib's
        # 'auto' default) so the saved image is semi-transparent -- confirm.
        fig.patch.set_facecolor(mcolors.to_rgba('lightgrey', 0.7))
        ax.set_facecolor(mcolors.to_rgba('lightgrey', 0.7))

        # Plot the filtered data
        ax.plot(filtered_data['date'], filtered_data['steps'], lw=2)

        # Extend the x-axis by 10% of the full date range to leave room for the annotation
        date_range = full_data['date'].max() - full_data['date'].min()
        margin = date_range * 0.10

        # Axis limits are based on the full data set, not on the filtered rows
        ax.set_xlim(full_data['date'].min(), full_data['date'].max() + margin)
        ax.set_ylim(full_data['steps'].min() * 0.9, full_data['steps'].max() * 1.1)

        # Format y-axis to show steps in 1K units with 1 decimal
        # NOTE(review): set_axis_off() below hides the axis, so this formatter
        # presumably has no visible effect -- confirm before removing.
        ax.yaxis.set_major_formatter(ticker.FuncFormatter(lambda x, _: f'{x/1000:.1f}K'))

        # Annotate the last data point if available, also in 1K units
        if not filtered_data.empty:
            last_date = filtered_data['date'].iloc[-1]
            last_value = filtered_data['steps'].iloc[-1] / 1000  # Convert to 1K units
            last_date_str = last_date.strftime('%d %b')  # e.g. '20 aug' with the Dutch locale

            ax.annotate(f'{last_value:.1f}K\n{last_date_str}',
                        xy=(last_date, last_value * 1000),  # Point itself is in the original scale
                        xytext=(5, 5),  # Slightly offset the text
                        textcoords='offset points',
                        fontsize=10,
                        color='black')

        # Remove axis lines and labels
        ax.set_axis_off()

        # Save this figure explicitly rather than whichever figure is current
        fig.savefig(output_file, dpi=100)
    finally:
        # Close the figure even when plotting or saving failed; this function is
        # called once per image, so leaked figures would pile up.
        plt.close(fig)


In [None]:
# Stop motion from warped images
print(warped_images_folder)
input_folder = warped_images_folder
output_file = stop_motion_result_path
create_stop_motion_movie_with_steps(
    input_folder,
    output_file,
    "steps.csv",
    transition_duration=0.75,
    audio_file='one_fine_day.mp3',
)

Show all the masks overlaid on the image.

In [None]:
# Plot the base image with all of its masks drawn on top.
ix = 10
plot_overlayed(images_list[base_image_ix], masks_0)