RAIN - Real & Artificial Intelligence in Neuroscience

## Video aligner

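How to use:
- Select the videos you want to align.
- For each video, a blend of a few random frames is shown. Click a point and press Enter to confirm it, then do the same for a second point. Pick the same two landmarks in every video (press 'q' to quit).
- The aligned videos are saved in an 'Aligned' folder next to the first selected video.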

In [1]:
import cv2
import numpy as np
import os
from tkinter import Tk, filedialog, messagebox

import random

def merge_random_frames(video_file, num_frames: int = 5) -> np.ndarray:
    """
    Blend randomly chosen frames of a video into a single averaged image.

    Averaging several frames gives a representative still of the scene, which
    is used as the background on which alignment points are picked.

    Args:
        video_file: Path of the video to sample (anything accepted by
            ``cv2.VideoCapture``).
        num_frames (int): Number of random frames to blend. Default is 5.
            It is capped at the frame count reported by the video.

    Returns:
        np.ndarray: Per-pixel mean (``uint8``) of the frames that were read
        successfully, or ``None`` if no frame could be read.
    """
    cap = cv2.VideoCapture(video_file)
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Some containers report a frame count of 0 (or less); fall back to
        # sampling frame 0 instead of letting random.sample raise, and never
        # ask for more frames than the video has.
        population = range(max(total_frames, 1))
        n_samples = min(num_frames, len(population))
        # Sorted indices make the capture seek forward only, which is cheaper.
        selected_frame_indices = sorted(random.sample(population, n_samples))

        accumulator = None
        n_read = 0
        for frame_idx in selected_frame_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
            success, frame = cap.read()

            if not success:
                print(f"Could not read frame {frame_idx} from {video_file}")
                continue

            # Accumulate in float so nothing is lost to uint8 truncation or
            # saturation before the final division.
            if accumulator is None:
                accumulator = np.zeros(frame.shape, dtype=np.float64)
            accumulator += frame
            n_read += 1
    finally:
        cap.release()

    if accumulator is None:
        return None

    # Divide by the frames actually read, so skipped frames do not darken
    # the result.
    return np.clip(np.rint(accumulator / n_read), 0, 255).astype(np.uint8)

def _render_selection_view(frame, cursor, current_point, confirmed_points,
                           zoom_scale, zoom_window_size):
    """
    Build the image shown while points are being picked.

    Args:
        frame (np.ndarray): Background image; it is not modified (a copy is
            drawn on).
        cursor (tuple | None): Current mouse position ``(x, y)``. No zoom
            inset is drawn when ``None``.
        current_point (tuple | None): Clicked but not yet confirmed point,
            drawn in green.
        confirmed_points (list): Confirmed points, drawn in red.
        zoom_scale (int): Magnification factor of the inset.
        zoom_window_size (int): Half the width/height of the magnified area.

    Returns:
        np.ndarray: The annotated copy of ``frame``.
    """
    view = frame.copy()
    if current_point is not None:
        cv2.circle(view, current_point, 3, (0, 255, 0), -1)
    for point in confirmed_points:
        cv2.circle(view, point, 3, (0, 0, 255), -1)

    if cursor is None:
        return view

    x, y = cursor
    height, width = view.shape[:2]
    x1, x2 = max(0, x - zoom_window_size), min(width, x + zoom_window_size)
    y1, y2 = max(0, y - zoom_window_size), min(height, y + zoom_window_size)
    if x2 <= x1 or y2 <= y1:
        return view  # cursor reported outside the image

    # Magnify the view *before* pasting the inset, so the inset never
    # contains a magnified copy of a previous inset.
    zoomed = cv2.resize(view[y1:y2, x1:x2], None, fx=zoom_scale, fy=zoom_scale,
                        interpolation=cv2.INTER_LINEAR)
    zoom_h, zoom_w = zoomed.shape[:2]

    # Green crosshair on the cursor position. This is not simply the inset
    # centre, which differs when the magnified area is clipped at the border.
    cx = int((x - x1) * zoom_scale)
    cy = int((y - y1) * zoom_scale)
    line_length = 20
    cv2.line(zoomed, (cx, cy - line_length), (cx, cy + line_length), (0, 255, 0), 2)
    cv2.line(zoomed, (cx - line_length, cy), (cx + line_length, cy), (0, 255, 0), 2)

    margin = 10
    if zoom_w + margin > width or zoom_h + margin > height:
        return view  # image too small to host the inset
    if x2 > width - zoom_w - margin and y1 < margin + zoom_h:
        inset_x = margin  # cursor is under the top-right corner: use top-left
    else:
        inset_x = width - zoom_w - margin
    view[margin:margin + zoom_h, inset_x:inset_x + zoom_w] = zoomed
    return view


def _select_two_points(frame, zoom_scale, zoom_window_size):
    """
    Interactively collect two confirmed points on ``frame``.

    Left-click places a candidate point, Enter confirms it, and 'q' asks
    whether to abort.

    Returns:
        list | None: The two confirmed ``(x, y)`` points, or ``None`` if the
        user chose to exit.
    """
    window = 'Select Points'
    cursor = None
    current_point = None  # clicked, not yet confirmed
    confirmed_points = []

    def redraw():
        cv2.imshow(window, _render_selection_view(
            frame, cursor, current_point, confirmed_points,
            zoom_scale, zoom_window_size))

    def on_mouse(event, x, y, flags, param):
        nonlocal cursor, current_point
        cursor = (x, y)
        if event == cv2.EVENT_LBUTTONDOWN:
            current_point = (x, y)
        redraw()

    cv2.imshow(window, frame)
    cv2.setMouseCallback(window, on_mouse)

    while len(confirmed_points) < 2:
        key = cv2.waitKey(1) & 0xFF
        if key == 13:  # Enter key to confirm the current point
            if current_point is not None:
                confirmed_points.append(current_point)
                current_point = None
                print(f"Point confirmed: {confirmed_points[-1]}")
                redraw()
        elif key == ord('q'):  # Press 'q' to quit
            response = messagebox.askquestion("Exit", "Do you want to exit aligner?")
            if response == 'yes':
                print("Exiting point selection.")
                return None
    return confirmed_points


def _write_aligned_video(video_path, output_path, transform):
    """
    Warp every frame of ``video_path`` with the 2x3 affine ``transform``.

    The result is written to ``output_path`` with the mp4v codec, keeping the
    source frame size and frame rate.

    Returns:
        bool: ``False`` if the output file could not be opened for writing,
        ``True`` otherwise.
    """
    cap = cv2.VideoCapture(video_path)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps > 0:
        # Some containers report no frame rate, but VideoWriter needs one.
        print(f"No frame rate reported for '{video_path}', assuming 30 fps.")
        fps = 30.0
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    try:
        if not out.isOpened():
            print(f"Could not open '{output_path}' for writing.")
            return False
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            out.write(cv2.warpAffine(frame, transform, frame.shape[1::-1]))
    finally:
        cap.release()
        out.release()
    return True


def align_videos():
    """
    Interactively align several videos on two user-chosen reference points.

    Workflow:
      1. Ask for the video files (Tk file dialog).
      2. For each video, show a blend of random frames (``merge_random_frames``)
         and let the user confirm two points on it.
      3. Average the points over all videos to get the target positions,
         optionally forcing both targets onto the same horizontal line.
      4. For each video, build the rotation + uniform scale + translation that
         maps its two points onto the targets. Write the warped video to an
         'Aligned' folder next to the first selected file.

    Raises:
        ValueError: If no video file is selected.
    """
    print("Instructions:")
    print("1. Left-click to select points.")
    print("2. Enter to confirm the current point.")
    print("3. Select two points on each video to align them.")
    print("Press 'q' to quit without aligning.")

    zoom_scale = 5  # How much to zoom in
    zoom_window_size = 25  # Half the width/height of the zoomed-in area

    # The Tk root is only needed for the dialogs, so keep it hidden and make
    # sure it is destroyed on every exit path.
    root = Tk()
    root.withdraw()
    try:
        video_files = filedialog.askopenfilenames(
            title="Select Video Files",
            filetypes=[("Video Files", "*.mp4 *.avi *.mkv *.mov")]
        )
        if not video_files:
            raise ValueError("No video files selected.")

        print(f"Selected {len(video_files)} videos.")

        # Step 1: show a representative frame of each video and collect two points.
        selections = []  # (video_path, (height, width), [point1, point2])
        for video_path in video_files:
            frame = merge_random_frames(video_path)
            if frame is None:
                print(f"Skipping '{video_path}': no frame could be read.")
                continue
            points = _select_two_points(frame, zoom_scale, zoom_window_size)
            if points is None:
                cv2.destroyAllWindows()
                return
            selections.append((video_path, frame.shape[:2], points))

        cv2.destroyAllWindows()

        # Step 2: the target positions are the mean of the selected points.
        if not selections:
            print("No points were selected.")
            return

        point_pairs = np.array([points for _, _, points in selections], dtype=np.float64)
        # Keep sub-pixel precision; truncating to int would bias the targets.
        mean_point1, mean_point2 = point_pairs.mean(axis=0)

        response = messagebox.askquestion("Alignment", "Do you want the points to stand on the same horizontal line?")
        if response == 'yes':
            y_mean = (mean_point1[1] + mean_point2[1]) / 2
            mean_point1[1] = y_mean
            mean_point2[1] = y_mean

        print(f"Mean points: {mean_point1}, {mean_point2}")

        # Step 3: align each video with a single composed affine warp.
        output_folder = os.path.join(os.path.dirname(video_files[0]), 'Aligned')
        os.makedirs(output_folder, exist_ok=True)
        mean_vector = mean_point2 - mean_point1
        mean_length = np.linalg.norm(mean_vector)
        mean_angle = np.arctan2(mean_vector[1], mean_vector[0])

        for video_path, (height, width), (point1, point2) in selections:
            video_name = os.path.basename(video_path)
            p1 = np.asarray(point1, dtype=np.float64)
            vector = np.asarray(point2, dtype=np.float64) - p1
            length = np.linalg.norm(vector)
            if length == 0:
                print(f"Skipping '{video_name}': its two points coincide, so no scale/rotation can be derived.")
                continue

            scale = mean_length / length
            # getRotationMatrix2D(angle) rotates counter-clockwise on screen
            # (y axis pointing down), which maps a vector at atan2 angle `a`
            # to angle `a - angle`. Rotating by (a - mean_angle) therefore
            # lands this video's segment exactly on mean_angle.
            rotation_angle = np.arctan2(vector[1], vector[0]) - mean_angle
            transform = cv2.getRotationMatrix2D(
                (width // 2, height // 2), np.degrees(rotation_angle), scale)

            # Shift so that point1 lands on mean_point1. Folding the translation
            # into the same matrix means one interpolation pass, and nothing is
            # lost to an intermediate crop.
            new_point1 = transform[:, :2] @ p1 + transform[:, 2]
            dx, dy = mean_point1 - new_point1
            transform[:, 2] += (dx, dy)

            # splitext only touches the real extension (names may contain dots).
            stem, ext = os.path.splitext(video_name)
            output_path = os.path.join(output_folder, f"{stem}_aligned{ext}")
            if not _write_aligned_video(video_path, output_path, transform):
                continue

            print(f"Aligned '{video_name}' with scale {scale:.2f}, rotation {np.degrees(rotation_angle):.2f} deg, and translation {dx:.2f}, {dy:.2f}.")

        print(f"Aligned videos saved in '{output_folder}'.")
    finally:
        root.destroy()

In [2]:
# Launch the interactive aligner: choose the videos, pick and confirm two points
# on each one, and write the aligned copies to an 'Aligned' folder.
align_videos()

Instructions:
1. Left-click to select points.
2. Enter to confirm the current point.
3. Select two points on each video to align them.
Press 'q' to quit without aligning.
Selected 18 videos.
Point confirmed: (363, 398)
Point confirmed: (1248, 403)
Point confirmed: (368, 397)
Point confirmed: (1252, 400)
Point confirmed: (367, 395)
Point confirmed: (1252, 400)
Point confirmed: (357, 419)
Point confirmed: (1243, 406)
Point confirmed: (359, 416)
Point confirmed: (1246, 403)
Point confirmed: (360, 416)
Point confirmed: (1246, 403)
Point confirmed: (361, 416)
Point confirmed: (1246, 403)
Point confirmed: (359, 415)
Point confirmed: (1246, 402)
Point confirmed: (359, 413)
Point confirmed: (1244, 402)
Point confirmed: (358, 413)
Point confirmed: (1244, 402)
Point confirmed: (359, 414)
Point confirmed: (1243, 403)
Point confirmed: (357, 413)
Point confirmed: (1242, 402)
Point confirmed: (357, 414)
Point confirmed: (1242, 403)
Point confirmed: (357, 414)
Point confirmed: (1242, 402)
Point confi