# In [1]:
import os
import time
import tkinter as tk
from tkinter import filedialog, messagebox

import cv2
import imutils
import numpy as np
from PIL import Image, ImageEnhance, ImageTk

class PanoramaBuilder:
    """Tkinter front end that turns a panning video into a stitched panorama.

    Workflow (triggered by the "Generate" menu entry):
    1. ask the user for a video file,
    2. walk its frames and keep key frames that match the frame before them
       well enough (ORB features + RANSAC homography),
    3. stitch the key frames with OpenCV's Stitcher,
    4. write the raw and border-cropped panoramas to the current working
       directory and show them in OpenCV windows.
    """

    # Width in pixels of the progress-bar canvas (and of a full bar).
    PROGRESS_WIDTH = 600
    # Lowe's ratio-test threshold used when filtering ORB matches.
    LOWE_RATIO = 0.75
    # A frame becomes a key frame only if its RANSAC inlier count lies
    # strictly between these bounds.
    # NOTE(review): cv2.ORB_create() is used with its defaults below; if its
    # feature cap is 500 (OpenCV's documented default) MAX_INLIERS can never
    # be reached -- confirm whether the upper bound is still wanted.
    MIN_INLIERS = 35
    MAX_INLIERS = 1500
    # Maximum reprojection error (pixels) for a point to count as an inlier.
    RANSAC_REPROJ_THRESHOLD = 5.0

    def __init__(self, tk_window):
        """Initialize instance state and build the main window.

        Args:
            tk_window: the root ``tk.Tk`` window the UI is attached to.
        """
        # Disable OpenCV's OpenCL path (kept from the original; presumably to
        # avoid backend-specific problems -- confirm if it is still needed).
        cv2.ocl.setUseOpenCL(False)

        self.video_path = None
        self.video_name = None
        self.images = []
        # Progress-bar widgets are created lazily the first time they are needed.
        self.progress_canvas = None
        self.fill_line = None

        self.tk_window = tk_window
        self.tk_window.title("Panorama Builder")
        self.tk_window.geometry("900x700")
        self.tk_window.resizable(False, False)

        menu_bar = tk.Menu(self.tk_window)
        self.tk_window.config(menu=menu_bar)

        menu_bar.add_command(label="Generate", command=self.generate_panorama)

        help_menu = tk.Menu(menu_bar, tearoff=False)
        help_menu.add_command(label="About", command=self.show_about)
        menu_bar.add_cascade(label="Help", menu=help_menu)

        menu_bar.add_command(label="Exit", command=self.quit)

        # Welcome canvas; the progress bar is later placed on top of it.
        self.canvas = tk.Canvas(self.tk_window, width=700, height=500)
        self.canvas.pack()
        self.canvas.create_text(350, 250, text="Welcome to Panorama Builder", fill='blue', font=100)

    def _ensure_progress_bar(self):
        """Create the progress label, canvas and fill rectangle exactly once.

        Reusing the same widgets across runs avoids stacking a new canvas for
        every video and keeps ``fill_line`` valid for the canvas it lives on.
        """
        if getattr(self, 'progress_canvas', None) is not None:
            return
        tk.Label(self.tk_window, text='Progress Bar').place(x=50, y=80)
        self.progress_canvas = tk.Canvas(self.tk_window, width=self.PROGRESS_WIDTH, height=22, bg="white")
        self.progress_canvas.place(x=150, y=80)
        self.fill_line = self.progress_canvas.create_rectangle(1.5, 1.5, 1.5, 23, width=0, fill="green")

    def update_loading(self, current_image, total_images):
        """Redraw the progress bar for *current_image* out of *total_images*.

        The bar width is proportional to the progress ratio, clamped to
        [0, 1] because the frame counter may exceed the reported frame count.
        A non-positive total is shown as no progress. The fill colour fades
        from red (#ff0000, start) to green (#00ff00, done).
        """
        self._ensure_progress_bar()
        if total_images > 0:
            ratio = min(1.0, max(0.0, current_image / total_images))
        else:
            ratio = 0.0

        self.progress_canvas.coords(self.fill_line, 1.5, 1.5, self.PROGRESS_WIDTH * ratio, 23)

        green = int(255 * ratio)
        red = 255 - green
        self.progress_canvas.itemconfig(self.fill_line, fill=f'#{red:02x}{green:02x}00')

        # Let Tk repaint while the blocking frame loop keeps running.
        self.tk_window.update()

    def enhance(self, image, dfa):
        """Correct an image whose mean brightness deviates from mid-gray.

        Args:
            image: image array, expected in RGB order (PIL's channel order).
            dfa: mean gray level minus 128, as returned by
                ``calculate_coefficient``.

        Too-bright images (dfa > 0) are darkened; otherwise contrast and
        sharpness are boosted. Each factor moves at most 0.5 away from 1.0.

        Returns:
            The enhanced image as a numpy array, or the unchanged input array
            if PIL fails.
        """
        step = min(0.1 * abs(dfa) / 128, 0.5)
        brightness = contrast = sharpness = 1.0
        if dfa > 0:
            brightness -= step
        else:
            contrast += step
            sharpness += step

        try:
            enhanced = Image.fromarray(np.uint8(image))
            enhanced = ImageEnhance.Brightness(enhanced).enhance(brightness)
            enhanced = ImageEnhance.Contrast(enhanced).enhance(contrast)
            enhanced = ImageEnhance.Sharpness(enhanced).enhance(sharpness)
            return np.array(enhanced)
        except Exception as e:
            print(f"Error enhancing image: {e}")
            # Hand back the caller's array, never a half-processed PIL image.
            return image

    def calculate_coefficient(self, gray_img):
        """Measure how far a grayscale image's brightness is from mid-gray.

        Returns:
            (bc, dfa): ``dfa`` is the mean gray level minus 128 and ``bc`` is
            ``|dfa|`` divided by the mean absolute deviation of the gray
            levels (0 when that deviation is 0). ``(0, 0)`` is returned if
            the histogram cannot be computed.
        """
        NUM_HISTOGRAM_BIN = 256

        try:
            # Normalized histogram of gray values (sums to 1).
            hist = cv2.calcHist([gray_img], [0], None, [NUM_HISTOGRAM_BIN], [0, 256]).ravel() / gray_img.size
            levels = np.arange(NUM_HISTOGRAM_BIN)

            mean_brightness = np.dot(hist, levels)
            mean_absolute_deviation = np.sum(np.abs(levels - mean_brightness) * hist)

            dfa = mean_brightness - 128
            bc = np.abs(dfa) / mean_absolute_deviation if mean_absolute_deviation != 0 else 0
            return bc, dfa
        except Exception as e:
            print(f"Error calculating image coefficients: {e}")
            return 0, 0

    def preprocess_video(self):
        """Extract key frames from ``self.video_path``; no-op when it is unset."""
        if not self.video_path:
            return
        self.setup_video_processing()
        self.process_frames()

    def setup_video_processing(self):
        """Open the video and initialise the per-run frame-scanning state."""
        self.capture = cv2.VideoCapture(self.video_path)
        self.images = []
        # ORB for feature detection.
        self.orb = cv2.ORB_create()
        self.total_images = int(self.capture.get(cv2.CAP_PROP_FRAME_COUNT))

        self.frame_rate = round(self.capture.get(cv2.CAP_PROP_FPS))
        # Examine roughly three frames per second of video (at least every frame).
        self.frame_interval = max(1, int(self.frame_rate // 3))

        self.frame_count = 1
        self.setup_progress_bar()

    def setup_progress_bar(self):
        """Show the progress bar (creating it on first use) and reset it to empty."""
        self._ensure_progress_bar()
        self.update_loading(0, self.total_images)

    def process_frames(self):
        """Walk the video and collect key frames for stitching.

        The first frame is always kept. After that, every ``frame_interval``-th
        frame is compared with the frame decoded just before it and is kept
        when the match is good enough (see ``handle_valid_matches``). Finally
        the last decoded frame is checked against itself so the end of the pan
        is included, unless it was already added as a key frame.
        """
        success, last_image = self.capture.read()
        if not success:
            # Nothing could be decoded (bad file or codec): just clean up.
            self.finalize_progress_bar()
            return

        self.process_first_frame(last_image)
        # Whether the most recently decoded frame is already in self.images;
        # the first frame always is.
        last_added = True

        while success:
            success, image = self.capture.read()
            if success:
                last_added = False
                if self.frame_count % self.frame_interval == 0:
                    keyframes_before = len(self.images)
                    self.process_subsequent_frame(last_image, image)
                    last_added = len(self.images) > keyframes_before
                # Comparisons are always against the immediately preceding frame.
                last_image = image
            self.frame_count += 1

        # Make sure the last frame is considered, without adding it twice.
        if not last_added:
            self.process_subsequent_frame(last_image, last_image)
        self.finalize_progress_bar()

    def process_first_frame(self, last_image):
        """Preprocess the first frame and start the key-frame list with it."""
        self.images.append(self.preprocess_image(last_image))

    def auto_white_balance(self, image):
        """Apply a Gray World style white balance in LAB space.

        Each chroma channel (a, b) is pulled toward the neutral value 128 in
        proportion to its own value. The arithmetic is done in float and
        clipped back to [0, 255], so pixels cannot wrap around when they are
        cast back to uint8.
        """
        lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB).astype(np.float64)
        for channel in (1, 2):
            values = lab[:, :, channel]
            shift = values.mean() - 128
            lab[:, :, channel] = values - shift * (values / 128)
        lab = np.clip(lab, 0, 255).astype(np.uint8)
        return cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)

    def preprocess_image(self, image):
        """White-balance a BGR frame and enhance it if its brightness is far off.

        The result is always BGR, so every key frame shares the channel order
        used by the rest of the OpenCV pipeline (``cv2.imwrite`` in particular).
        """
        gray_img = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        bc, dfa = self.calculate_coefficient(gray_img)

        image = self.auto_white_balance(image)

        if bc > 1:
            # PIL works in RGB: convert there for enhancement, then back to BGR.
            rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = cv2.cvtColor(self.enhance(rgb, dfa), cv2.COLOR_RGB2BGR)
        return image

    def process_subsequent_frame(self, last_image, image):
        """Compare *image* with *last_image* and keep *image* if they match well.

        A homography is attempted only when at least five ratio-test matches
        exist; otherwise only the progress bar is refreshed.
        """
        kp1, kp2, valid_matches = self.find_matches(last_image, image)
        if len(valid_matches) > 4:
            self.handle_valid_matches(kp1, kp2, last_image, image, valid_matches)
        else:
            self.update_loading(self.frame_count, self.total_images)

    def find_matches(self, last_image, image):
        """Match ORB features between two frames.

        Returns:
            (kp1, kp2, good_matches): keypoints of *last_image*, keypoints of
            *image*, and the matches that pass Lowe's ratio test. The match
            list is empty when either frame has no descriptors or *image* has
            fewer than two (a 2-nearest-neighbour search needs two).
        """
        kp1, des1 = self.orb.detectAndCompute(last_image, None)
        kp2, des2 = self.orb.detectAndCompute(image, None)
        if des1 is None or des2 is None or len(des2) < 2:
            return kp1, kp2, []

        matcher = cv2.BFMatcher(cv2.NORM_HAMMING)
        good_matches = [
            pair[0]
            for pair in matcher.knnMatch(des1, des2, k=2)
            # knnMatch may return fewer than k neighbours for some queries.
            if len(pair) == 2 and pair[0].distance < self.LOWE_RATIO * pair[1].distance
        ]
        return kp1, kp2, good_matches

    def handle_valid_matches(self, kp1, kp2, last_image, image, valid_matches):
        """Keep *image* as a key frame if enough matches are RANSAC inliers.

        A homography from *last_image*'s keypoints to *image*'s is fitted with
        RANSAC. The preprocessed frame is appended when the inlier count lies
        strictly between MIN_INLIERS and MAX_INLIERS. The progress bar is
        refreshed either way.
        """
        src_pts = np.float32([kp1[m.queryIdx].pt for m in valid_matches]).reshape(-1, 1, 2)
        dst_pts = np.float32([kp2[m.trainIdx].pt for m in valid_matches]).reshape(-1, 1, 2)
        _, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, self.RANSAC_REPROJ_THRESHOLD)
        # The mask is None when no homography could be estimated.
        inliers = 0 if mask is None else int(np.count_nonzero(mask))
        if self.MIN_INLIERS < inliers < self.MAX_INLIERS:
            self.images.append(self.preprocess_image(image))
        self.update_loading(self.frame_count, self.total_images)

    def finalize_progress_bar(self):
        """Fill the progress bar and release the video capture."""
        self.update_loading(self.total_images, self.total_images)
        self.capture.release()

    def open_video(self):
        """Ask the user for a video file and extract its key frames.

        Cancelling the dialog clears any key frames left over from a previous
        video, so ``generate_panorama`` reports "No video loaded." instead of
        re-stitching stale data.
        """
        path = filedialog.askopenfilename(
            filetypes=[("Video files", ("*.mp4", "*.avi")), ("All files", "*.*")])
        # On cancel the dialog returns '' (possibly an empty tuple on some Tk builds).
        if not path:
            self.video_path = None
            self.images = []
            return
        self.video_path = path
        self.video_name = self._video_name(path)
        self.preprocess_video()

    @staticmethod
    def _video_name(path):
        """Return the file name in *path* without its directory or extension."""
        return os.path.splitext(os.path.basename(path))[0]

    def show_image(self, name, image):
        """Show *image* in a resizable 800x600 OpenCV window titled *name*."""
        cv2.namedWindow(name, cv2.WINDOW_NORMAL)
        cv2.resizeWindow(name, 800, 600)
        cv2.imshow(name, image)

    def find_boundrect(self, image):
        """Return the bounding box (x, y, w, h) of the largest edge contour.

        Falls back to the whole image when no contour is found.
        """
        binary_image = cv2.Canny(image, 50, 150)
        contours, _ = cv2.findContours(binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            return (0, 0, image.shape[1], image.shape[0])

        largest = max(contours, key=cv2.contourArea)
        return cv2.boundingRect(largest)

    def generate_panorama(self):
        """Menu handler: pick a video, stitch its key frames, then save and show the result."""
        self.open_video()
        if not self.images:
            messagebox.showerror("Error", "No video loaded.")
            return

        stitched_image = self.stitch_images()
        if stitched_image is None:
            messagebox.showerror("Error", "Generate panorama failed. Please upload a clear video.")
            return

        self.save_and_display_results(stitched_image)

    def stitch_images(self):
        """Stitch ``self.images`` with OpenCV's Stitcher.

        Returns:
            The panorama array, or None if stitching failed.
        """
        stitcher = cv2.Stitcher_create()
        try:
            # Pass a list of frames; packing them into one ndarray only adds a copy.
            status, stitched_image = stitcher.stitch(list(self.images))
        except cv2.error as e:
            print(f"Error stitching images: {e}")
            return None
        return stitched_image if status == cv2.Stitcher_OK else None

    def save_and_display_results(self, stitched_image):
        """Write the raw and cropped panoramas to disk and show both."""
        panorama_image_name = f'Panorama_{self.video_name}.jpg'
        crop_panorama_image_name = f'Cropped_Panorama_{self.video_name}.jpg'

        cv2.imwrite(panorama_image_name, stitched_image)
        cropped_image = self.crop_image(stitched_image)
        cv2.imwrite(crop_panorama_image_name, cropped_image)

        self.show_image("Panorama", stitched_image)
        self.show_image("Cropped Panorama", cropped_image)
        # Wait for a key press in an OpenCV window, then close the windows so
        # they do not linger. The Tk UI does not respond while waiting.
        cv2.waitKey()
        cv2.destroyAllWindows()

    def crop_image(self, stitched_image):
        """Crop the panorama to a rectangle free of black border pixels.

        A 2-pixel black border is added, and the bounding box of the non-black
        region is found. That box is eroded until it lies entirely on
        non-black pixels, and the image is cropped to what remains.
        """
        stitched_image = cv2.copyMakeBorder(stitched_image, 2, 2, 2, 2, cv2.BORDER_CONSTANT, value=(0, 0, 0))
        gray = cv2.cvtColor(stitched_image, cv2.COLOR_BGR2GRAY)
        # Any non-black pixel becomes 255.
        binary_image = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY)[1]

        mask = np.zeros(binary_image.shape, dtype="uint8")
        x, y, w, h = self.find_boundrect(binary_image)
        cv2.rectangle(mask, (x, y), (x + w, y + h), 255, -1)

        min_rect = mask.copy()
        sub_rect = mask.copy()
        # Shrink the rectangle until no part of it covers a black pixel.
        while cv2.countNonZero(sub_rect) > 0:
            min_rect = cv2.erode(min_rect, None)
            sub_rect = cv2.subtract(min_rect, binary_image)

        x, y, w, h = self.find_boundrect(min_rect)
        return stitched_image[y:y + h, x:x + w]

    def show_about(self):
        """Show the About dialog."""
        messagebox.showinfo("About", "Panorama Builder Version 1.0")

    def quit(self):
        """Ask for confirmation and close the main window if the user agrees."""
        message = messagebox.askokcancel("Info", "Do you want to quit?", default=messagebox.CANCEL)
        if message:
            self.tk_window.destroy()

# Build the Tkinter GUI and run its event loop only when executed directly
# (script or notebook cell), not when this module is imported.
if __name__ == "__main__":
    tk_window = tk.Tk()
    app = PanoramaBuilder(tk_window)

    # Run the Tkinter event loop (blocks until the window is closed).
    tk_window.mainloop()