# In [None]:
import tempfile
from pathlib import Path
from urllib.request import urlretrieve

import cv2
import matplotlib.pyplot as plt
import numpy as np
import rp
import torch
import torchvision.transforms
from icecream import ic
from PIL import Image, ImageDraw
from torchvision.io import read_video
from torchvision.models.optical_flow import raft_large
from torchvision.utils import flow_to_image
from tqdm import tqdm



class RaftOpticalFlow:
    """Optical flow estimation and visualisation built on torchvision's RAFT-large model.

    Conventions used by the methods of this class:
      * A flow map is a torch tensor of shape (2, height, width). Channel 0 is
        the x displacement and channel 1 the y displacement, both in pixels.
      * A point set is a torch tensor of shape (2, N): row 0 holds x
        coordinates and row 1 holds y coordinates, both in pixels.
    """

    def __init__(self, device):
        """Load the pretrained RAFT-large model onto `device` and put it in eval mode."""
        self.device = device
        self.model = raft_large(pretrained=True, progress=False).to(device)
        self.model.eval()

    def _preprocess_image(self, image):
        """Turn an image into the batched tensor that is fed to the model.

        The image is converted to float RGB, mapped from [0, 1] to [-1, 1] and
        resized so that its height and width are multiples of 8 (rounded down).

        Returns:
            A float tensor of shape (1, 3, new_height, new_width) on
            `self.device`, where new_height/new_width are the floored sizes.
        """
        assert rp.is_image(image)

        image = rp.as_float_image(rp.as_rgb_image(image))

        # Floor height and width to the nearest multiple of 8
        height, width = rp.get_image_dimensions(image)
        new_height = (height // 8) * 8
        new_width = (width // 8) * 8
        assert new_height > 0 and new_width > 0, "image must be at least 8x8 pixels"

        T = torchvision.transforms
        transforms = T.Compose(
            [
                T.ToTensor(),
                T.Normalize(mean=0.5, std=0.5),  # map [0, 1] into [-1, 1]
                T.Resize(size=(new_height, new_width)),
            ]
        )

        output = transforms(image)[None].to(self.device).float()

        assert rp.is_torch_tensor(output)
        # The tensor has the *resized* dimensions, not the original ones.
        assert output.shape == (1, 3, new_height, new_width)

        return output

    def get_flow_map(self, from_image, to_image):
        """Estimate the optical flow that carries `from_image` onto `to_image`.

        Both images must have the same dimensions.

        Returns:
            A tensor of shape (2, height, width) at the resolution of the input
            images; channel 0 is the x displacement, channel 1 the y
            displacement, expressed in pixels of the input images.
        """
        assert rp.is_image(from_image)
        assert rp.is_image(to_image)
        assert rp.get_image_dimensions(from_image) == rp.get_image_dimensions(to_image)

        height, width = rp.get_image_dimensions(from_image)

        with torch.no_grad():
            img1 = self._preprocess_image(from_image)
            img2 = self._preprocess_image(to_image)

            # The model returns a list of flow predictions; use the last one.
            list_of_flows = self.model(img1, img2)
            output_flow = list_of_flows[-1][0]
            model_height, model_width = output_flow.shape[-2:]

            # Resize the predicted flow back to the original image size
            resize = torchvision.transforms.Resize((height, width))
            output_flow = resize(output_flow[None])[0]

            # The displacements were measured in pixels of the resized image;
            # stretching the map stretches the displacements by the same factor.
            scale = output_flow.new_tensor([width / model_width, height / model_height])
            output_flow = output_flow * scale[:, None, None]

        assert rp.is_torch_tensor(output_flow)
        assert output_flow.shape == (2, height, width)

        return output_flow

    @staticmethod
    def _make_dot_grid(height, width, num_rows, num_cols):
        """Return flattened (xs, ys) integer pixel coordinates of a regular grid of dots.

        The grid has roughly `num_rows` x `num_cols` dots; the exact count
        depends on how the integer step divides the image size.
        """
        # Clamp to 1 so a grid denser than the image cannot yield a zero step,
        # which np.arange rejects.
        x_step = max(1, width // (num_cols + 1))
        y_step = max(1, height // (num_rows + 1))

        dots_x, dots_y = np.meshgrid(
            np.arange(x_step, width, x_step),
            np.arange(y_step, height, y_step),
        )
        return dots_x.flatten(), dots_y.flatten()

    @staticmethod
    def _draw_points_on_image(image, points):
        """Return a PIL copy of `image` with every point of `points` painted as a red pixel.

        `points` is a (2, N) tensor of x/y pixel coordinates. Non-finite points
        are skipped; points outside the image are clamped onto its border.
        """
        canvas = np.array(image)  # np.array copies, so the caller's frame is untouched

        dot_positions = points.cpu().numpy().T
        dot_positions = dot_positions[np.isfinite(dot_positions).all(axis=1)]
        dot_positions = dot_positions.astype(np.int32)

        rows = np.clip(dot_positions[:, 1], 0, canvas.shape[0] - 1)
        cols = np.clip(dot_positions[:, 0], 0, canvas.shape[1] - 1)
        canvas[rows, cols, :3] = (255, 0, 0)  # red in RGB; an alpha channel, if any, is kept

        return Image.fromarray(canvas)

    def demo_optical_flow(self, from_image, to_image):
        """Show `from_image` next to a colour rendering of its flow towards `to_image`."""
        predicted_flow = self.get_flow_map(from_image, to_image)
        flow_img = flow_to_image(predicted_flow)

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
        ax1.imshow(from_image)
        ax1.set_title("Input Image")
        ax1.axis("off")
        ax2.imshow(flow_img.permute(1, 2, 0).cpu().numpy())
        ax2.set_title("Predicted Optical Flow")
        ax2.axis("off")
        plt.tight_layout()
        plt.show()

    def demo_optical_flow_dots(self, from_image, to_image, num_rows=10, num_cols=10):
        """Show a grid of red dots on `from_image` and, in blue, where the flow moves them on `to_image`."""
        predicted_flow = self.get_flow_map(from_image, to_image)

        height, width = rp.get_image_dimensions(from_image)
        from_image = rp.as_pil_image(from_image)
        to_image = rp.as_pil_image(to_image)

        dots_x, dots_y = self._make_dot_grid(height, width, num_rows, num_cols)

        # Draw dots on the from_image
        from_image_with_dots = from_image.copy()
        draw = ImageDraw.Draw(from_image_with_dots)
        for x, y in zip(dots_x.tolist(), dots_y.tolist()):
            draw.ellipse((x - 2, y - 2, x + 2, y + 2), fill="red")

        # Warp the dots to the to_image based on the predicted flow.
        # The flow is copied to the CPU once instead of once per dot.
        flow = predicted_flow.cpu().numpy()
        warped_xs = dots_x + flow[0, dots_y, dots_x]
        warped_ys = dots_y + flow[1, dots_y, dots_x]

        to_image_with_dots = to_image.copy()
        draw = ImageDraw.Draw(to_image_with_dots)
        for warped_x, warped_y in zip(warped_xs.tolist(), warped_ys.tolist()):
            draw.ellipse((warped_x - 2, warped_y - 2, warped_x + 2, warped_y + 2), fill="blue")

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
        ax1.imshow(from_image_with_dots)
        ax1.set_title("Input Image with Dots")
        ax1.axis("off")
        ax2.imshow(to_image_with_dots)
        ax2.set_title("Output Image with Warped Dots")
        ax2.axis("off")
        plt.tight_layout()
        plt.show()

    def query_flow_map(self, points, flow_map):
        """Bilinearly sample `flow_map` at (possibly fractional) pixel positions.

        Args:
            points: tensor of shape (2, N); row 0 is x, row 1 is y, in pixels.
                Any numeric dtype is accepted; it is cast to the flow's dtype.
            flow_map: tensor of shape (2, height, width).

        Returns:
            A tensor of shape (2, N) holding the x/y flow at each point.
            grid_sample's default zero padding applies outside the map, so
            points far outside it receive zero flow.
        """
        # TODO: audit this function; consider redoing it with a generic map-sampling helper.
        height, width = flow_map.shape[1:]

        # Integer coordinates would be truncated by the normalisation below and
        # rejected by grid_sample, which needs the grid in the input's dtype.
        points = points.to(flow_map.dtype)

        # Normalize points to [-1, 1] range (align_corners=True convention:
        # -1 is the centre of the first pixel, +1 the centre of the last one).
        # max(..., 1) avoids a division by zero on 1-pixel-wide/tall maps.
        normalized_x = (points[0] / max(width - 1, 1)) * 2 - 1
        normalized_y = (points[1] / max(height - 1, 1)) * 2 - 1
        grid = torch.stack([normalized_x, normalized_y], dim=-1)[None, None]  # (1, 1, N, 2)

        # Perform bilinear interpolation using grid_sample
        sampled = torch.nn.functional.grid_sample(
            flow_map[None], grid, mode='bilinear', align_corners=True
        )  # (1, 2, 1, N)

        return sampled[0, :, 0, :]

    def add_optical_flow(self, points, flow_map):
        """Return `points` (shape (2, N)) displaced by the flow sampled at their positions."""
        deltas = self.query_flow_map(points, flow_map)
        return points + deltas

    def demo_optical_flow_animation(self, frames, num_rows=100, num_cols=100):
        """Track a grid of dots through `frames` by accumulating frame-to-frame flow.

        Every frame is shown with the dots at the positions they occupy *in
        that frame*: the first frame carries the initial grid, and each later
        frame carries the dots after the flow leading into it has been applied.
        The resulting images are displayed as a slideshow.
        """
        height, width = rp.get_image_dimensions(frames[0])

        dots_x, dots_y = self._make_dot_grid(height, width, num_rows, num_cols)
        dots_x = torch.from_numpy(dots_x).float().to(self.device)
        dots_y = torch.from_numpy(dots_y).float().to(self.device)

        points = torch.stack([dots_x, dots_y])

        animation_frames = [self._draw_points_on_image(frames[0], points)]

        frame_pairs = zip(frames[:-1], frames[1:])
        for from_image, to_image in tqdm(frame_pairs, total=len(frames) - 1):
            predicted_flow = self.get_flow_map(from_image, to_image)

            # Update dot positions by accumulating optical flows
            points = self.add_optical_flow(points, predicted_flow)

            # The updated positions live in to_image, so that is where they are drawn.
            animation_frames.append(self._draw_points_on_image(to_image, points))

        rp.display_image_slideshow(animation_frames)


def _run_flow_demos(video_path):
    """Load the video at `video_path` and run the dot and animation demos on it."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    optical_flow = RaftOpticalFlow(device)

    frames = rp.load_video(video_path)

    # Demo optical flow with dots
    optical_flow.demo_optical_flow_dots(frames[100], frames[110], num_rows=20, num_cols=30)

    # Demo optical flow animation
    optical_flow.demo_optical_flow_animation(frames[100:200], num_rows=400, num_cols=200)


def demo_flow_anim(video_path=None):
    """Run the optical-flow demos on a video.

    Args:
        video_path: path of the video to use. When omitted, the local
            development video is used if it exists on this machine; otherwise
            the PyTorch basketball sample video is downloaded into a temporary
            directory that is removed once the demos have finished.
    """
    local_video_path = '/root/CleanCode/Projects/deepfloyd_init_test/videos/diffuse_kevin_spin_height1024.mp4'
    video_url = "https://download.pytorch.org/tutorial/pexelscom_pavel_danilyuk_basketball_hd.mp4"

    with torch.no_grad():
        if video_path is None and Path(local_video_path).exists():
            video_path = local_video_path

        if video_path is not None:
            _run_flow_demos(video_path)
            return

        # Only download when there is nothing local to use, and keep the
        # temporary directory alive for as long as the frames may be read.
        with tempfile.TemporaryDirectory() as download_dir:
            downloaded_path = Path(download_dir) / "basketball.mp4"
            urlretrieve(video_url, downloaded_path)
            _run_flow_demos(str(downloaded_path))




"""
TODO:
- Bulk operations: getting flows from a video tensor or iterable and returning a generator (lazy) or a tensor, with a show_progress
- Einops all the way
- Document each func and tensor shapes
- backwards / forwards occlusion detection + demo

For gaussian swarm we need:
def image_to_points(image, mask) --> xy, rgb
def points_to_image(xy, rgb, height, width) --> area_image, rgb_sum_image, xy_sum_image
"""

# Run the demo only when this file is executed directly (a notebook cell also
# runs with __name__ == "__main__"), not when it is imported as a module.
if __name__ == "__main__":
    demo_flow_anim()