In [None]:
import torch
import torchvision.transforms as T
from torchvision.models.optical_flow import raft_large
from torchvision.utils import flow_to_image
from torchvision.io import read_video
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
import numpy as np
import tempfile
from pathlib import Path
from urllib.request import urlretrieve

class RaftOpticalFlow:
    """Optical-flow estimator built on torchvision's pretrained ``raft_large``.

    Every image is resized to 520x960 before inference, so
    ``get_optical_flow`` returns flow at that resolution; the demo rescales
    the flow to the source image size before using it.
    """

    def __init__(self, device):
        """Load the pretrained model onto ``device`` and put it in eval mode."""
        self.device = device
        self.model = raft_large(pretrained=True, progress=False).to(device)
        self.model.eval()

    def preprocess(self, image):
        """Convert ``image`` into a normalized batch of one on ``self.device``.

        Args:
            image: Anything ``T.ToTensor`` accepts (the demo passes PIL images).

        Returns:
            The image resized to 520x960 with a leading batch axis.
        """
        transforms = T.Compose(
            [
                T.ToTensor(),
                T.Normalize(mean=0.5, std=0.5),  # map [0, 1] into [-1, 1]
                T.Resize(size=(520, 960)),
            ]
        )
        return transforms(image).unsqueeze(0).to(self.device)

    def get_optical_flow(self, from_image, to_image):
        """Estimate the flow that maps ``from_image`` onto ``to_image``.

        Returns:
            The last entry of the model's output list for the single batch
            item (presumably the most refined estimate), at the 520x960
            model resolution.
        """
        img1 = self.preprocess(from_image)
        img2 = self.preprocess(to_image)
        # Inference only: no need to record an autograd graph.
        with torch.no_grad():
            list_of_flows = self.model(img1, img2)
        return list_of_flows[-1][0]

    @staticmethod
    def _resize_flow(flow, height, width):
        """Resize a ``(2, H, W)`` flow field to ``(2, height, width)``.

        Flow values are pixel displacements, so after spatial interpolation
        channel 0 (x) is scaled by ``width / W`` and channel 1 (y) by
        ``height / H``; otherwise the vectors keep their old-resolution length.
        """
        src_height, src_width = flow.shape[-2:]
        resized = torch.nn.functional.interpolate(
            flow.unsqueeze(0),
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        ).squeeze(0)
        scale = resized.new_tensor([width / src_width, height / src_height])
        return resized * scale.view(2, 1, 1)

    @staticmethod
    def _draw_dots(image, xs, ys, color):
        """Return a copy of ``image`` with a radius-2 ``color`` dot at each (x, y)."""
        canvas = image.copy()
        draw = ImageDraw.Draw(canvas)
        for x, y in zip(xs, ys):
            draw.ellipse((x - 2, y - 2, x + 2, y + 2), fill=color)
        return canvas

    def demo_optical_flow_dots(self, from_image, to_image, num_dots=50):
        """Show random dots on ``from_image`` and their flow-warped positions.

        Args:
            from_image: PIL image the flow starts from.
            to_image: PIL image the flow points to.
            num_dots: Number of random dots to draw.
        """
        predicted_flow = self.get_optical_flow(from_image, to_image)

        # PIL reports size as (width, height), not (height, width).
        width, height = from_image.size
        # One device-to-host copy instead of a synchronising .item() per dot.
        flow = self._resize_flow(predicted_flow, height, width).detach().cpu().numpy()

        # Generate random dot coordinates
        dots_x = np.random.randint(0, width, num_dots)
        dots_y = np.random.randint(0, height, num_dots)

        from_image_with_dots = self._draw_dots(
            from_image, dots_x.tolist(), dots_y.tolist(), "red"
        )

        # Move each dot by the flow vector found at its starting pixel.
        warped_x = dots_x + flow[0, dots_y, dots_x]
        warped_y = dots_y + flow[1, dots_y, dots_x]
        to_image_with_dots = self._draw_dots(
            to_image, warped_x.tolist(), warped_y.tolist(), "blue"
        )

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
        ax1.imshow(from_image_with_dots)
        ax1.set_title("Input Image with Dots")
        ax1.axis("off")
        ax2.imshow(to_image_with_dots)
        ax2.set_title("Output Image with Warped Dots")
        ax2.axis("off")
        plt.tight_layout()
        plt.show()

# Usage example: run the dot demo on two consecutive frames of a sample clip.
device = "cuda" if torch.cuda.is_available() else "cpu"
optical_flow = RaftOpticalFlow(device)

# Download the sample video into a fresh temporary directory.
video_url = "https://download.pytorch.org/tutorial/pexelscom_pavel_danilyuk_basketball_hd.mp4"
video_path = Path(tempfile.mkdtemp(), "basketball.mp4")
_ = urlretrieve(video_url, video_path)

# Decode the clip, then reorder axes: (N, H, W, C) -> (N, C, H, W).
frames, _, _ = read_video(str(video_path))
frames = frames.permute(0, 3, 1, 2)

# Pick frames 100 and 101 and turn them into PIL images.
img1, img2 = (T.ToPILImage()(frames[index]) for index in (100, 101))

optical_flow.demo_optical_flow_dots(img1, img2, num_dots=100)

In [None]:
import torch
import torchvision.transforms as T
from torchvision.models.optical_flow import raft_large
from torchvision.utils import flow_to_image
from torchvision.io import read_video
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
import numpy as np
import tempfile
from pathlib import Path
from urllib.request import urlretrieve

import rp
import torch
import torchvision.transforms as T
from torchvision.models.optical_flow import raft_large
from torchvision.utils import flow_to_image
from PIL import Image
import matplotlib.pyplot as plt
from icecream import ic

class RaftOpticalFlow:
    """Optical-flow estimator built on torchvision's pretrained ``raft_large``.

    Every image is resized to 520x960 before inference, so
    ``get_optical_flow`` returns flow at that resolution; the dot demo
    rescales the flow to the source image size before using it.
    """

    def __init__(self, device):
        """Load the pretrained model onto ``device`` and put it in eval mode."""
        self.device = device
        self.model = raft_large(pretrained=True, progress=False).to(device)
        self.model.eval()

    def preprocess(self, image):
        """Convert ``image`` into a normalized batch of one on ``self.device``.

        Args:
            image: Anything ``T.ToTensor`` accepts (the demo passes PIL images).

        Returns:
            The image resized to 520x960 with a leading batch axis.
        """
        transforms = T.Compose(
            [
                T.ToTensor(),
                T.Normalize(mean=0.5, std=0.5),  # map [0, 1] into [-1, 1]
                T.Resize(size=(520, 960)),
            ]
        )
        return transforms(image).unsqueeze(0).to(self.device)

    def get_optical_flow(self, from_image, to_image):
        """Estimate the flow that maps ``from_image`` onto ``to_image``.

        Returns:
            The last entry of the model's output list for the single batch
            item (presumably the most refined estimate), at the 520x960
            model resolution.
        """
        img1 = self.preprocess(from_image)
        img2 = self.preprocess(to_image)
        # Inference only: no need to record an autograd graph.
        with torch.no_grad():
            list_of_flows = self.model(img1, img2)
        return list_of_flows[-1][0]

    @staticmethod
    def _resize_flow(flow, height, width):
        """Resize a ``(2, H, W)`` flow field to ``(2, height, width)``.

        Flow values are pixel displacements, so after spatial interpolation
        channel 0 (x) is scaled by ``width / W`` and channel 1 (y) by
        ``height / H``; otherwise the vectors keep their old-resolution length.
        """
        src_height, src_width = flow.shape[-2:]
        resized = torch.nn.functional.interpolate(
            flow.unsqueeze(0),
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        ).squeeze(0)
        scale = resized.new_tensor([width / src_width, height / src_height])
        return resized * scale.view(2, 1, 1)

    @staticmethod
    def _draw_dots(image, xs, ys, color):
        """Return a copy of ``image`` with a radius-2 ``color`` dot at each (x, y)."""
        canvas = image.copy()
        draw = ImageDraw.Draw(canvas)
        for x, y in zip(xs, ys):
            draw.ellipse((x - 2, y - 2, x + 2, y + 2), fill=color)
        return canvas

    @staticmethod
    def _show_pair(left, left_title, right, right_title):
        """Display two images side by side with titles and hidden axes."""
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
        for axis, picture, title in ((ax1, left, left_title), (ax2, right, right_title)):
            axis.imshow(picture)
            axis.set_title(title)
            axis.axis("off")
        plt.tight_layout()
        plt.show()

    def demo_optical_flow(self, from_image, to_image):
        """Show ``from_image`` next to a colour rendering of the predicted flow."""
        predicted_flow = self.get_optical_flow(from_image, to_image)
        flow_img = flow_to_image(predicted_flow)
        self._show_pair(
            from_image,
            "Input Image",
            flow_img.permute(1, 2, 0).cpu().numpy(),  # (C, H, W) -> (H, W, C)
            "Predicted Optical Flow",
        )

    def demo_optical_flow_dots(self, from_image, to_image, num_dots=50):
        """Show random dots on ``from_image`` and their flow-warped positions.

        Args:
            from_image: PIL image the flow starts from.
            to_image: PIL image the flow points to.
            num_dots: Number of random dots to draw.
        """
        predicted_flow = self.get_optical_flow(from_image, to_image)

        # PIL reports size as (width, height), not (height, width).
        width, height = from_image.size
        # One device-to-host copy instead of a synchronising .item() per dot.
        flow = self._resize_flow(predicted_flow, height, width).detach().cpu().numpy()

        # Generate random dot coordinates
        dots_x = np.random.randint(0, width, num_dots)
        dots_y = np.random.randint(0, height, num_dots)

        from_image_with_dots = self._draw_dots(
            from_image, dots_x.tolist(), dots_y.tolist(), "red"
        )

        # Move each dot by the flow vector found at its starting pixel.
        warped_x = dots_x + flow[0, dots_y, dots_x]
        warped_y = dots_y + flow[1, dots_y, dots_x]
        to_image_with_dots = self._draw_dots(
            to_image, warped_x.tolist(), warped_y.tolist(), "blue"
        )

        self._show_pair(
            from_image_with_dots,
            "Input Image with Dots",
            to_image_with_dots,
            "Output Image with Warped Dots",
        )

# Usage example: run both demos on two consecutive frames of a sample clip.
device = "cuda" if torch.cuda.is_available() else "cpu"
optical_flow = RaftOpticalFlow(device)

# Download the sample video into a fresh temporary directory.
video_url = "https://download.pytorch.org/tutorial/pexelscom_pavel_danilyuk_basketball_hd.mp4"
video_path = Path(tempfile.mkdtemp(), "basketball.mp4")
_ = urlretrieve(video_url, video_path)

# Decode the clip, then reorder axes: (N, H, W, C) -> (N, C, H, W).
frames, _, _ = read_video(str(video_path))
frames = frames.permute(0, 3, 1, 2)

# Pick frames 100 and 101 and turn them into PIL images.
img1, img2 = (T.ToPILImage()(frames[index]) for index in (100, 101))

optical_flow.demo_optical_flow(img1, img2)
optical_flow.demo_optical_flow_dots(img1, img2)

In [None]:
# Compute the flow between the two previously selected frames for inspection.
predicted_flow = optical_flow.get_optical_flow(img1, img2)

In [None]:
# Inspect the shape of the flow tensor computed earlier.
predicted_flow.shape

In [None]:
import rp
# Inspect the array shape of img1 for comparison with the flow tensor's shape.
rp.as_numpy_image(img1).shape

In [None]:
import torch
import torchvision.transforms as T
from torchvision.models.optical_flow import raft_large
from torchvision.utils import flow_to_image
from torchvision.io import read_video
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
import numpy as np
import tempfile
from pathlib import Path
from urllib.request import urlretrieve

class RaftOpticalFlow:
    """Optical-flow estimator built on torchvision's pretrained ``raft_large``.

    Images are shrunk to the nearest multiple-of-8 size for the model and the
    predicted flow is mapped back to the original image resolution.
    """

    def __init__(self, device):
        """Load the pretrained model onto ``device`` and put it in eval mode."""
        self.device = device
        self.model = raft_large(pretrained=True, progress=False).to(device)
        self.model.eval()

    def preprocess(self, image, multiple_of=8):
        """Convert a PIL image into a normalized batch of one on ``self.device``.

        Args:
            image: PIL image (its ``size`` is read as ``(width, height)``).
            multiple_of: Both output dimensions are rounded down to a
                multiple of this value.

        Returns:
            The normalized, resized image with a leading batch axis.
        """
        # Ensure image dimensions are divisible by `multiple_of`
        width, height = image.size
        new_height = (height // multiple_of) * multiple_of
        new_width = (width // multiple_of) * multiple_of

        transforms = T.Compose(
            [
                T.ToTensor(),
                T.Normalize(mean=0.5, std=0.5),  # map [0, 1] into [-1, 1]
                T.Resize(size=(new_height, new_width)),
            ]
        )
        return transforms(image).unsqueeze(0).to(self.device)

    @staticmethod
    def _resize_flow(flow, height, width):
        """Resize a ``(2, H, W)`` flow field to ``(2, height, width)``.

        Flow values are pixel displacements, so after spatial interpolation
        channel 0 (x) is scaled by ``width / W`` and channel 1 (y) by
        ``height / H``; otherwise the vectors keep their old-resolution length.
        """
        src_height, src_width = flow.shape[-2:]
        resized = torch.nn.functional.interpolate(
            flow.unsqueeze(0),
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        ).squeeze(0)
        scale = resized.new_tensor([width / src_width, height / src_height])
        return resized * scale.view(2, 1, 1)

    def get_optical_flow(self, from_image, to_image):
        """Estimate the flow from ``from_image`` to ``to_image``.

        Args:
            from_image: PIL image the flow starts from.
            to_image: PIL image of the same size the flow points to.

        Returns:
            A two-channel flow tensor at the original image resolution, on
            ``self.device``; channel 0 is used as x and channel 1 as y.

        Raises:
            AssertionError: If the two images differ in size.
        """
        assert from_image.size == to_image.size, "Input images must have the same size"

        img1 = self.preprocess(from_image)
        img2 = self.preprocess(to_image)
        # Inference only: no need to record an autograd graph.
        with torch.no_grad():
            list_of_flows = self.model(img1, img2)
        output_flow = list_of_flows[-1][0]

        # Resize the predicted flow back to the original image size
        width, height = from_image.size
        return self._resize_flow(output_flow, height, width)

    @staticmethod
    def _draw_dots(image, xs, ys, color):
        """Return a copy of ``image`` with a radius-2 ``color`` dot at each (x, y)."""
        canvas = image.copy()
        draw = ImageDraw.Draw(canvas)
        for x, y in zip(xs, ys):
            draw.ellipse((x - 2, y - 2, x + 2, y + 2), fill=color)
        return canvas

    @staticmethod
    def _show_pair(left, left_title, right, right_title):
        """Display two images side by side with titles and hidden axes."""
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
        for axis, picture, title in ((ax1, left, left_title), (ax2, right, right_title)):
            axis.imshow(picture)
            axis.set_title(title)
            axis.axis("off")
        plt.tight_layout()
        plt.show()

    def demo_optical_flow(self, from_image, to_image):
        """Show ``from_image`` next to a colour rendering of the predicted flow."""
        predicted_flow = self.get_optical_flow(from_image, to_image)
        flow_img = flow_to_image(predicted_flow)
        self._show_pair(
            from_image,
            "Input Image",
            flow_img.permute(1, 2, 0).cpu().numpy(),  # (C, H, W) -> (H, W, C)
            "Predicted Optical Flow",
        )

    def demo_optical_flow_dots(self, from_image, to_image, num_dots=50):
        """Show random dots on ``from_image`` and their flow-warped positions.

        Args:
            from_image: PIL image the flow starts from.
            to_image: PIL image the flow points to.
            num_dots: Number of random dots to draw.
        """
        # One device-to-host copy instead of a synchronising .item() per dot.
        flow = self.get_optical_flow(from_image, to_image).detach().cpu().numpy()

        # Generate random dot coordinates
        width, height = from_image.size
        dots_x = np.random.randint(0, width, num_dots)
        dots_y = np.random.randint(0, height, num_dots)

        from_image_with_dots = self._draw_dots(
            from_image, dots_x.tolist(), dots_y.tolist(), "red"
        )

        # Move each dot by the flow vector found at its starting pixel.
        warped_x = dots_x + flow[0, dots_y, dots_x]
        warped_y = dots_y + flow[1, dots_y, dots_x]
        to_image_with_dots = self._draw_dots(
            to_image, warped_x.tolist(), warped_y.tolist(), "blue"
        )

        self._show_pair(
            from_image_with_dots,
            "Input Image with Dots",
            to_image_with_dots,
            "Output Image with Warped Dots",
        )

# Usage example: run both demos on frames 100 and 110 of a sample clip.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Only build the estimator when no `optical_flow` name exists yet; an object
# created by an earlier cell is reused as-is.
if 'optical_flow' not in vars():
    optical_flow = RaftOpticalFlow(device)

# Download the sample video into a fresh temporary directory.
video_url = "https://download.pytorch.org/tutorial/pexelscom_pavel_danilyuk_basketball_hd.mp4"
video_path = Path(tempfile.mkdtemp(), "basketball.mp4")
_ = urlretrieve(video_url, video_path)

# Decode the clip, then reorder axes: (N, H, W, C) -> (N, C, H, W).
frames, _, _ = read_video(str(video_path))
frames = frames.permute(0, 3, 1, 2)

# Pick frames 100 and 110 and turn them into PIL images.
img1, img2 = (T.ToPILImage()(frames[index]) for index in (100, 110))

optical_flow.demo_optical_flow(img1, img2)
optical_flow.demo_optical_flow_dots(img1, img2)

In [None]:
import tempfile
from pathlib import Path
from urllib.request import urlretrieve

import matplotlib.pyplot as plt
import numpy as np
import rp
import torch
import torchvision.transforms as T
from icecream import ic
from PIL import Image, ImageDraw
from torchvision.io import read_video
from torchvision.models.optical_flow import raft_large
from torchvision.utils import flow_to_image


class RaftOpticalFlow:
    """Optical-flow estimator built on torchvision's pretrained ``raft_large``.

    Images are shrunk to the nearest multiple-of-8 size for the model and the
    predicted flow is mapped back to the original image resolution.
    """

    def __init__(self, device):
        """Load the pretrained model onto ``device`` and put it in eval mode."""
        self.device = device
        self.model = raft_large(pretrained=True, progress=False).to(device)
        self.model.eval()

    def preprocess(self, image, multiple_of=8):
        """Convert a PIL image into a normalized batch of one on ``self.device``.

        Args:
            image: PIL image (its ``size`` is read as ``(width, height)``).
            multiple_of: Both output dimensions are rounded down to a
                multiple of this value.

        Returns:
            The normalized, resized image with a leading batch axis.
        """
        # Ensure image dimensions are divisible by `multiple_of`
        width, height = image.size
        new_height = (height // multiple_of) * multiple_of
        new_width = (width // multiple_of) * multiple_of

        transforms = T.Compose(
            [
                T.ToTensor(),
                T.Normalize(mean=0.5, std=0.5),  # map [0, 1] into [-1, 1]
                T.Resize(size=(new_height, new_width)),
            ]
        )
        return transforms(image).unsqueeze(0).to(self.device)

    @staticmethod
    def _resize_flow(flow, height, width):
        """Resize a ``(2, H, W)`` flow field to ``(2, height, width)``.

        Flow values are pixel displacements, so after spatial interpolation
        channel 0 (x) is scaled by ``width / W`` and channel 1 (y) by
        ``height / H``; otherwise the vectors keep their old-resolution length.
        """
        src_height, src_width = flow.shape[-2:]
        resized = torch.nn.functional.interpolate(
            flow.unsqueeze(0),
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        ).squeeze(0)
        scale = resized.new_tensor([width / src_width, height / src_height])
        return resized * scale.view(2, 1, 1)

    @staticmethod
    def _dot_grid(width, height, num_rows, num_cols):
        """Return flattened x and y pixel coordinates of an evenly spaced grid.

        At most ``num_cols`` columns and ``num_rows`` rows are produced, all
        strictly inside a ``width`` x ``height`` image. The step is clamped to
        at least one pixel so very dense requests cannot yield a zero step.
        """
        x_step = max(1, width // (num_cols + 1))
        y_step = max(1, height // (num_rows + 1))
        xs = np.arange(1, num_cols + 1) * x_step
        ys = np.arange(1, num_rows + 1) * y_step
        grid_x, grid_y = np.meshgrid(xs[xs < width], ys[ys < height])
        return grid_x.flatten(), grid_y.flatten()

    def get_optical_flow(self, from_image, to_image):
        """Estimate the flow from ``from_image`` to ``to_image``.

        Args:
            from_image: PIL image the flow starts from.
            to_image: PIL image of the same size the flow points to.

        Returns:
            A two-channel flow tensor at the original image resolution, on
            ``self.device``; channel 0 is used as x and channel 1 as y.

        Raises:
            AssertionError: If the two images differ in size.
        """
        assert from_image.size == to_image.size, "Input images must have the same size"

        img1 = self.preprocess(from_image)
        img2 = self.preprocess(to_image)
        # Inference only: no need to record an autograd graph.
        with torch.no_grad():
            list_of_flows = self.model(img1, img2)
        output_flow = list_of_flows[-1][0]

        # Resize the predicted flow back to the original image size
        width, height = from_image.size
        return self._resize_flow(output_flow, height, width)

    @staticmethod
    def _draw_dots(image, xs, ys, color):
        """Return a copy of ``image`` with a radius-2 ``color`` dot at each (x, y)."""
        canvas = image.copy()
        draw = ImageDraw.Draw(canvas)
        for x, y in zip(xs, ys):
            draw.ellipse((x - 2, y - 2, x + 2, y + 2), fill=color)
        return canvas

    @staticmethod
    def _show_pair(left, left_title, right, right_title):
        """Display two images side by side with titles and hidden axes."""
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
        for axis, picture, title in ((ax1, left, left_title), (ax2, right, right_title)):
            axis.imshow(picture)
            axis.set_title(title)
            axis.axis("off")
        plt.tight_layout()
        plt.show()

    def demo_optical_flow(self, from_image, to_image):
        """Show ``from_image`` next to a colour rendering of the predicted flow."""
        predicted_flow = self.get_optical_flow(from_image, to_image)
        flow_img = flow_to_image(predicted_flow)
        self._show_pair(
            from_image,
            "Input Image",
            flow_img.permute(1, 2, 0).cpu().numpy(),  # (C, H, W) -> (H, W, C)
            "Predicted Optical Flow",
        )

    def demo_optical_flow_dots(self, from_image, to_image, num_rows=10, num_cols=10):
        """Show a grid of dots on ``from_image`` and their flow-warped positions.

        Args:
            from_image: PIL image the flow starts from.
            to_image: PIL image the flow points to.
            num_rows: Number of dot rows in the grid.
            num_cols: Number of dot columns in the grid.
        """
        # One device-to-host copy instead of a synchronising .item() per dot.
        flow = self.get_optical_flow(from_image, to_image).detach().cpu().numpy()

        width, height = from_image.size
        dots_x, dots_y = self._dot_grid(width, height, num_rows, num_cols)

        from_image_with_dots = self._draw_dots(
            from_image, dots_x.tolist(), dots_y.tolist(), "red"
        )

        # Move each dot by the flow vector found at its starting pixel.
        warped_x = dots_x + flow[0, dots_y, dots_x]
        warped_y = dots_y + flow[1, dots_y, dots_x]
        to_image_with_dots = self._draw_dots(
            to_image, warped_x.tolist(), warped_y.tolist(), "blue"
        )

        self._show_pair(
            from_image_with_dots,
            "Input Image with Dots",
            to_image_with_dots,
            "Output Image with Warped Dots",
        )

    def demo_optical_flow_animation(self, frames, num_rows=10, num_cols=10):
        """Show a slideshow of per-frame flow vectors drawn on a fixed dot grid.

        For every consecutive frame pair, the earlier frame gets red dots at
        the grid points plus a blue line and dot marking where the flow moves
        each point.

        Args:
            frames: Indexable sequence of same-sized PIL images.
            num_rows: Number of dot rows in the grid.
            num_cols: Number of dot columns in the grid.
        """
        width, height = frames[0].size
        dots_x, dots_y = self._dot_grid(width, height, num_rows, num_cols)
        start_x, start_y = dots_x.tolist(), dots_y.tolist()

        animation_frames = []

        for frame_idx in range(len(frames) - 1):
            from_image = frames[frame_idx]
            to_image = frames[frame_idx + 1]

            flow = self.get_optical_flow(from_image, to_image).detach().cpu().numpy()
            warped_x = (dots_x + flow[0, dots_y, dots_x]).tolist()
            warped_y = (dots_y + flow[1, dots_y, dots_x]).tolist()

            # Red dots mark the grid; blue lines and dots mark the displacement.
            from_image_with_dots = self._draw_dots(from_image, start_x, start_y, "red")
            draw = ImageDraw.Draw(from_image_with_dots)
            for x, y, wx, wy in zip(start_x, start_y, warped_x, warped_y):
                draw.line([(x, y), (wx, wy)], fill="blue", width=1)
                draw.ellipse((wx - 2, wy - 2, wx + 2, wy + 2), fill="blue")

            animation_frames.append(from_image_with_dots)

        rp.display_image_slideshow(animation_frames)

# Usage example: dot-grid demo and flow animation on a sample clip.
device = "cuda" if torch.cuda.is_available() else "cpu"
optical_flow = RaftOpticalFlow(device)

# Download the sample video into a fresh temporary directory.
video_url = "https://download.pytorch.org/tutorial/pexelscom_pavel_danilyuk_basketball_hd.mp4"
video_path = Path(tempfile.mkdtemp(), "basketball.mp4")
_ = urlretrieve(video_url, video_path)

# Decode the clip, then reorder axes: (N, H, W, C) -> (N, C, H, W).
frames, _, _ = read_video(str(video_path))
frames = frames.permute(0, 3, 1, 2)

# Every frame becomes a PIL image.
pil_frames = list(map(T.ToPILImage(), frames))

# Dot-grid demo across a 10-frame gap.
optical_flow.demo_optical_flow_dots(pil_frames[100], pil_frames[110], num_rows=20, num_cols=30)

# Frame-by-frame flow animation over frames 100..119.
optical_flow.demo_optical_flow_animation(pil_frames[100:120], num_rows=20, num_cols=30)

In [None]:
import tempfile
from pathlib import Path
from urllib.request import urlretrieve

import matplotlib.pyplot as plt
import numpy as np
import rp
import torch
import torchvision.transforms as T
from icecream import ic
from PIL import Image, ImageDraw
from torchvision.io import read_video
from torchvision.models.optical_flow import raft_large
from torchvision.utils import flow_to_image

import torch
import torchvision.transforms as T
from torchvision.models.optical_flow import raft_large
from torchvision.utils import flow_to_image
from torchvision.io import read_video
from PIL import Image, ImageDraw
import numpy as np
import tempfile
from pathlib import Path
from urllib.request import urlretrieve
from tqdm import tqdm

import rp

class RaftOpticalFlow:
    """Optical-flow estimator built on torchvision's pretrained ``raft_large``.

    Images are shrunk to the nearest multiple-of-8 size for the model and the
    predicted flow is mapped back to the original image resolution. Besides
    the static demos, the class can track a grid of points through a frame
    sequence by accumulating per-frame flow.
    """

    def __init__(self, device):
        """Load the pretrained model onto ``device`` and put it in eval mode."""
        self.device = device
        self.model = raft_large(pretrained=True, progress=False).to(device)
        self.model.eval()

    def preprocess(self, image, multiple_of=8):
        """Convert a PIL image into a normalized batch of one on ``self.device``.

        Args:
            image: PIL image (its ``size`` is read as ``(width, height)``).
            multiple_of: Both output dimensions are rounded down to a
                multiple of this value.

        Returns:
            The normalized, resized image with a leading batch axis.
        """
        # Ensure image dimensions are divisible by `multiple_of`
        width, height = image.size
        new_height = (height // multiple_of) * multiple_of
        new_width = (width // multiple_of) * multiple_of

        transforms = T.Compose(
            [
                T.ToTensor(),
                T.Normalize(mean=0.5, std=0.5),  # map [0, 1] into [-1, 1]
                T.Resize(size=(new_height, new_width)),
            ]
        )
        return transforms(image).unsqueeze(0).to(self.device)

    @staticmethod
    def _resize_flow(flow, height, width):
        """Resize a ``(2, H, W)`` flow field to ``(2, height, width)``.

        Flow values are pixel displacements, so after spatial interpolation
        channel 0 (x) is scaled by ``width / W`` and channel 1 (y) by
        ``height / H``; otherwise the vectors keep their old-resolution length.
        """
        src_height, src_width = flow.shape[-2:]
        resized = torch.nn.functional.interpolate(
            flow.unsqueeze(0),
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        ).squeeze(0)
        scale = resized.new_tensor([width / src_width, height / src_height])
        return resized * scale.view(2, 1, 1)

    @staticmethod
    def _dot_grid(width, height, num_rows, num_cols):
        """Return flattened x and y pixel coordinates of an evenly spaced grid.

        At most ``num_cols`` columns and ``num_rows`` rows are produced, all
        strictly inside a ``width`` x ``height`` image. The step is clamped to
        at least one pixel so very dense requests cannot yield a zero step.
        """
        x_step = max(1, width // (num_cols + 1))
        y_step = max(1, height // (num_rows + 1))
        xs = np.arange(1, num_cols + 1) * x_step
        ys = np.arange(1, num_rows + 1) * y_step
        grid_x, grid_y = np.meshgrid(xs[xs < width], ys[ys < height])
        return grid_x.flatten(), grid_y.flatten()

    @staticmethod
    def _valid_pixel_positions(positions, width, height):
        """Convert ``(N, 2)`` float (x, y) positions to in-frame integer pixels.

        Rows containing NaN/inf and rows falling outside a ``width`` x
        ``height`` frame are dropped, so the result is always safe to use as
        array indices.
        """
        finite = positions[np.isfinite(positions).all(axis=1)]
        floored = np.floor(finite)
        inside = (
            (floored[:, 0] >= 0)
            & (floored[:, 0] < width)
            & (floored[:, 1] >= 0)
            & (floored[:, 1] < height)
        )
        return floored[inside].astype(np.int32)

    def get_optical_flow(self, from_image, to_image):
        """Estimate the flow from ``from_image`` to ``to_image``.

        Args:
            from_image: PIL image the flow starts from.
            to_image: PIL image of the same size the flow points to.

        Returns:
            A two-channel flow tensor at the original image resolution, on
            ``self.device``; channel 0 is used as x and channel 1 as y.

        Raises:
            AssertionError: If the two images differ in size.
        """
        assert from_image.size == to_image.size, "Input images must have the same size"

        img1 = self.preprocess(from_image)
        img2 = self.preprocess(to_image)
        # Inference only: without this, accumulated coordinates would keep
        # every frame's autograd graph alive.
        with torch.no_grad():
            list_of_flows = self.model(img1, img2)
        output_flow = list_of_flows[-1][0]

        # Resize the predicted flow back to the original image size
        width, height = from_image.size
        return self._resize_flow(output_flow, height, width)

    @staticmethod
    def _draw_dots(image, xs, ys, color):
        """Return a copy of ``image`` with a radius-2 ``color`` dot at each (x, y)."""
        canvas = image.copy()
        draw = ImageDraw.Draw(canvas)
        for x, y in zip(xs, ys):
            draw.ellipse((x - 2, y - 2, x + 2, y + 2), fill=color)
        return canvas

    @staticmethod
    def _show_pair(left, left_title, right, right_title):
        """Display two images side by side with titles and hidden axes."""
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
        for axis, picture, title in ((ax1, left, left_title), (ax2, right, right_title)):
            axis.imshow(picture)
            axis.set_title(title)
            axis.axis("off")
        plt.tight_layout()
        plt.show()

    def demo_optical_flow(self, from_image, to_image):
        """Show ``from_image`` next to a colour rendering of the predicted flow."""
        predicted_flow = self.get_optical_flow(from_image, to_image)
        flow_img = flow_to_image(predicted_flow)
        self._show_pair(
            from_image,
            "Input Image",
            flow_img.permute(1, 2, 0).cpu().numpy(),  # (C, H, W) -> (H, W, C)
            "Predicted Optical Flow",
        )

    def demo_optical_flow_dots(self, from_image, to_image, num_rows=10, num_cols=10):
        """Show a grid of dots on ``from_image`` and their flow-warped positions.

        Args:
            from_image: PIL image the flow starts from.
            to_image: PIL image the flow points to.
            num_rows: Number of dot rows in the grid.
            num_cols: Number of dot columns in the grid.
        """
        # One device-to-host copy instead of a synchronising .item() per dot.
        flow = self.get_optical_flow(from_image, to_image).detach().cpu().numpy()

        width, height = from_image.size
        dots_x, dots_y = self._dot_grid(width, height, num_rows, num_cols)

        from_image_with_dots = self._draw_dots(
            from_image, dots_x.tolist(), dots_y.tolist(), "red"
        )

        # Move each dot by the flow vector found at its starting pixel.
        warped_x = dots_x + flow[0, dots_y, dots_x]
        warped_y = dots_y + flow[1, dots_y, dots_x]
        to_image_with_dots = self._draw_dots(
            to_image, warped_x.tolist(), warped_y.tolist(), "blue"
        )

        self._show_pair(
            from_image_with_dots,
            "Input Image with Dots",
            to_image_with_dots,
            "Output Image with Warped Dots",
        )

    def query_optical_flow(self, coordinates, flow_map):
        """Bilinearly sample ``flow_map`` at sub-pixel positions.

        Args:
            coordinates: Float tensor of shape ``(2, N)``; row 0 holds x and
                row 1 holds y, in pixels of ``flow_map``.
            flow_map: Flow tensor of shape ``(2, H, W)``.

        Returns:
            A ``(2, N)`` tensor of flow vectors, one column per position.
            Positions outside the map are handled by ``grid_sample``'s
            default padding.
        """
        height, width = flow_map.shape[1:]
        # grid_sample expects (x, y) pairs normalized to [-1, 1].
        grid = torch.stack(
            [
                (coordinates[0] / (width - 1)) * 2 - 1,
                (coordinates[1] / (height - 1)) * 2 - 1,
            ],
            dim=-1,
        ).reshape(1, 1, -1, 2)

        deltas = torch.nn.functional.grid_sample(
            flow_map.unsqueeze(0), grid, mode='bilinear', align_corners=True
        )
        # (1, 2, 1, N) -> (2, N), matching the layout of `coordinates`.
        return deltas[0, :, 0, :]

    def add_optical_flow(self, coordinates, flow_map):
        """Return ``coordinates`` advanced by the flow sampled at each position."""
        deltas = self.query_optical_flow(coordinates, flow_map)
        return coordinates + deltas

    def demo_optical_flow_animation(self, frames, num_rows=100, num_cols=100):
        """Show a slideshow of grid points carried along by accumulated flow.

        Points start on a regular grid and are advanced by each frame pair's
        flow; every output frame has the current points painted as single red
        pixels. Points that drift outside the frame are simply not drawn.

        Args:
            frames: Indexable sequence of same-sized PIL images.
            num_rows: Number of dot rows in the starting grid.
            num_cols: Number of dot columns in the starting grid.
        """
        width, height = frames[0].size
        dots_x, dots_y = self._dot_grid(width, height, num_rows, num_cols)
        coordinates = torch.stack(
            [torch.from_numpy(dots_x).float(), torch.from_numpy(dots_y).float()]
        ).to(self.device)

        animation_frames = []

        for frame_idx in tqdm(range(len(frames) - 1)):
            from_image = frames[frame_idx]
            to_image = frames[frame_idx + 1]

            predicted_flow = self.get_optical_flow(from_image, to_image)

            # Update dot positions by accumulating optical flows
            coordinates = self.add_optical_flow(coordinates, predicted_flow)

            # Paint directly on an RGB array; no OpenCV round trip is needed
            # to set individual pixels.
            frame_array = np.array(from_image.convert("RGB"))
            frame_height, frame_width = frame_array.shape[:2]
            pixels = self._valid_pixel_positions(
                coordinates.detach().cpu().numpy().T, frame_width, frame_height
            )
            frame_array[pixels[:, 1], pixels[:, 0]] = (255, 0, 0)  # red dots

            animation_frames.append(Image.fromarray(frame_array))

        rp.display_image_slideshow(animation_frames)

def demo_flow_anim():
    """Download the sample basketball clip and run the dot and animation demos.

    The clip is fetched into a temporary directory that is removed once the
    frames have been read, and only frames 100..199 (the ones the demos use)
    are converted to PIL images. All model work runs under ``torch.no_grad``.
    """
    with torch.no_grad():

        # Usage example
        device = "cuda" if torch.cuda.is_available() else "cpu"
        optical_flow = RaftOpticalFlow(device)

        # Load video; the directory is deleted after read_video has returned
        # the decoded frames.
        video_url = "https://download.pytorch.org/tutorial/pexelscom_pavel_danilyuk_basketball_hd.mp4"
        with tempfile.TemporaryDirectory() as download_dir:
            video_path = Path(download_dir) / "basketball.mp4"
            urlretrieve(video_url, video_path)
            frames, _, _ = read_video(str(video_path))

        # Keep only the frames the demos need, then (N, H, W, C) -> (N, C, H, W)
        frames = frames[100:200].permute(0, 3, 1, 2)

        # Convert frames to PIL images
        to_pil = T.ToPILImage()
        pil_frames = [to_pil(frame) for frame in frames]

        # Demo optical flow with dots (original frames 100 and 110)
        optical_flow.demo_optical_flow_dots(pil_frames[0], pil_frames[10], num_rows=20, num_cols=30)

        # Demo optical flow animation (original frames 100..199)
        optical_flow.demo_optical_flow_animation(pil_frames, num_rows=20, num_cols=30)

In [None]:
import tempfile
from pathlib import Path
from urllib.request import urlretrieve

import matplotlib.pyplot as plt
import numpy as np
import rp
import torch
import torchvision.transforms as T
from einops import rearrange
from PIL import Image, ImageDraw
from torchvision.io import read_video
from torchvision.models.optical_flow import raft_large
from torchvision.utils import flow_to_image


class RaftOpticalFlow:
    """Optical-flow estimator built on torchvision's pretrained ``raft_large``.

    Images are shrunk to the nearest multiple-of-8 size for the model and the
    predicted flow is mapped back to the original image resolution. The class
    can also track a grid of points through a frame sequence by accumulating
    per-frame flow.
    """

    def __init__(self, device):
        """Load the pretrained model onto ``device`` and put it in eval mode."""
        self.device = device
        self.model = raft_large(pretrained=True, progress=False).to(device)
        self.model.eval()

    def preprocess(self, image, multiple_of=8):
        """Convert a PIL image into a normalized batch of one on ``self.device``.

        Args:
            image: PIL image (its ``size`` is read as ``(width, height)``).
            multiple_of: Both output dimensions are rounded down to a
                multiple of this value.

        Returns:
            The normalized, resized image with a leading batch axis.
        """
        # Ensure image dimensions are divisible by `multiple_of`
        width, height = image.size
        new_height = (height // multiple_of) * multiple_of
        new_width = (width // multiple_of) * multiple_of

        transforms = T.Compose(
            [
                T.ToTensor(),
                T.Normalize(mean=0.5, std=0.5),  # map [0, 1] into [-1, 1]
                T.Resize(size=(new_height, new_width)),
            ]
        )
        return transforms(image).unsqueeze(0).to(self.device)

    @staticmethod
    def _resize_flow(flow, height, width):
        """Resize a ``(2, H, W)`` flow field to ``(2, height, width)``.

        Flow values are pixel displacements, so after spatial interpolation
        channel 0 (x) is scaled by ``width / W`` and channel 1 (y) by
        ``height / H``; otherwise the vectors keep their old-resolution length.
        """
        src_height, src_width = flow.shape[-2:]
        resized = torch.nn.functional.interpolate(
            flow.unsqueeze(0),
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        ).squeeze(0)
        scale = resized.new_tensor([width / src_width, height / src_height])
        return resized * scale.view(2, 1, 1)

    @staticmethod
    def _dot_grid(width, height, num_rows, num_cols):
        """Return flattened x and y pixel coordinates of an evenly spaced grid.

        At most ``num_cols`` columns and ``num_rows`` rows are produced, all
        strictly inside a ``width`` x ``height`` image. The step is clamped to
        at least one pixel so very dense requests cannot yield a zero step.
        """
        x_step = max(1, width // (num_cols + 1))
        y_step = max(1, height // (num_rows + 1))
        xs = np.arange(1, num_cols + 1) * x_step
        ys = np.arange(1, num_rows + 1) * y_step
        grid_x, grid_y = np.meshgrid(xs[xs < width], ys[ys < height])
        return grid_x.flatten(), grid_y.flatten()

    @staticmethod
    def _draw_dots(image, xs, ys, color):
        """Return a copy of ``image`` with a radius-2 ``color`` dot at each (x, y)."""
        canvas = image.copy()
        draw = ImageDraw.Draw(canvas)
        for x, y in zip(xs, ys):
            draw.ellipse((x - 2, y - 2, x + 2, y + 2), fill=color)
        return canvas

    def get_optical_flow(self, from_image, to_image):
        """Estimate the flow from ``from_image`` to ``to_image``.

        Args:
            from_image: PIL image the flow starts from.
            to_image: PIL image of the same size the flow points to.

        Returns:
            A two-channel flow tensor at the original image resolution, on
            ``self.device``; channel 0 is used as x and channel 1 as y.

        Raises:
            AssertionError: If the two images differ in size.
        """
        assert from_image.size == to_image.size, "Input images must have the same size"

        img1 = self.preprocess(from_image)
        img2 = self.preprocess(to_image)
        # Inference only: without this, accumulated coordinates would keep
        # every frame's autograd graph alive.
        with torch.no_grad():
            list_of_flows = self.model(img1, img2)
        output_flow = list_of_flows[-1][0]

        # Resize the predicted flow back to the original image size
        width, height = from_image.size
        return self._resize_flow(output_flow, height, width)

    def query_optical_flow(self, coordinates, flow_map):
        """Bilinearly sample ``flow_map`` at sub-pixel positions.

        Args:
            coordinates: Float tensor of shape ``(2, N)``; row 0 holds x and
                row 1 holds y, in pixels of ``flow_map``.
            flow_map: Flow tensor of shape ``(2, H, W)``.

        Returns:
            A ``(2, N)`` tensor of flow vectors, one column per position, so
            it can be added directly to ``coordinates``. Positions outside
            the map are handled by ``grid_sample``'s default padding.
        """
        height, width = flow_map.shape[1:]
        # grid_sample expects (x, y) pairs normalized to [-1, 1]; plain tensor
        # ops keep the (2, N) <-> (1, 1, N, 2) layout change explicit.
        grid = torch.stack(
            [
                (coordinates[0] / (width - 1)) * 2 - 1,
                (coordinates[1] / (height - 1)) * 2 - 1,
            ],
            dim=-1,
        ).reshape(1, 1, -1, 2)

        deltas = torch.nn.functional.grid_sample(
            flow_map.unsqueeze(0), grid, mode='bilinear', align_corners=True
        )
        # (1, 2, 1, N) -> (2, N). Returning (N, 2) here would not broadcast
        # against the (2, N) coordinates in add_optical_flow.
        return deltas[0, :, 0, :]

    def add_optical_flow(self, coordinates, flow_map):
        """Return ``coordinates`` advanced by the flow sampled at each position."""
        deltas = self.query_optical_flow(coordinates, flow_map)
        return coordinates + deltas

    def demo_optical_flow_dots(self, from_image, to_image, num_rows=10, num_cols=10):
        """Show a grid of dots on ``from_image`` and their flow-warped positions.

        Args:
            from_image: PIL image the flow starts from.
            to_image: PIL image the flow points to.
            num_rows: Number of dot rows in the grid.
            num_cols: Number of dot columns in the grid.
        """
        flow = self.get_optical_flow(from_image, to_image).detach().cpu().numpy()

        width, height = from_image.size
        dots_x, dots_y = self._dot_grid(width, height, num_rows, num_cols)

        from_image_with_dots = self._draw_dots(
            from_image, dots_x.tolist(), dots_y.tolist(), "red"
        )

        # Move each dot by the flow vector found at its starting pixel.
        warped_x = dots_x + flow[0, dots_y, dots_x]
        warped_y = dots_y + flow[1, dots_y, dots_x]
        to_image_with_dots = self._draw_dots(
            to_image, warped_x.tolist(), warped_y.tolist(), "blue"
        )

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
        ax1.imshow(from_image_with_dots)
        ax1.set_title("Input Image with Dots")
        ax1.axis("off")
        ax2.imshow(to_image_with_dots)
        ax2.set_title("Output Image with Warped Dots")
        ax2.axis("off")
        plt.tight_layout()
        plt.show()

    def demo_optical_flow_animation(self, frames, num_rows=10, num_cols=10):
        """Show a slideshow of grid points carried along by accumulated flow.

        Points start on a regular grid and are advanced by each frame pair's
        flow; every output frame has the current points drawn as red dots.

        Args:
            frames: Indexable sequence of same-sized PIL images.
            num_rows: Number of dot rows in the starting grid.
            num_cols: Number of dot columns in the starting grid.
        """
        width, height = frames[0].size
        dots_x, dots_y = self._dot_grid(width, height, num_rows, num_cols)
        coordinates = torch.stack(
            [torch.from_numpy(dots_x).float(), torch.from_numpy(dots_y).float()]
        ).to(self.device)

        animation_frames = []

        for frame_idx in range(len(frames) - 1):
            from_image = frames[frame_idx]
            to_image = frames[frame_idx + 1]

            predicted_flow = self.get_optical_flow(from_image, to_image)

            # Update dot positions by accumulating optical flows
            coordinates = self.add_optical_flow(coordinates, predicted_flow)

            # One host copy per frame; PIL then receives plain Python floats
            # rather than tensor scalars.
            points = coordinates.detach().cpu().numpy()
            keep = ~np.isnan(points).any(axis=0)
            from_image_with_dots = self._draw_dots(
                from_image, points[0, keep].tolist(), points[1, keep].tolist(), "red"
            )

            animation_frames.append(from_image_with_dots)

        rp.display_image_slideshow(animation_frames)



In [None]:

def demo_flow_anim():
    """Download the sample basketball clip and run both optical-flow demos.

    Frames 100-199 of the clip are used: the dots demo compares frame 100 with
    frame 110, and the animation demo tracks a dense grid through the whole
    range. Everything runs under ``torch.no_grad()`` because this is inference
    only.
    """
    video_url = "https://download.pytorch.org/tutorial/pexelscom_pavel_danilyuk_basketball_hd.mp4"
    first_frame, last_frame = 100, 200

    with torch.no_grad():

        # Usage example
        device = "cuda" if torch.cuda.is_available() else "cpu"
        optical_flow = RaftOpticalFlow(device)

        # Load video. read_video decodes the clip into memory, so the download
        # directory can be removed as soon as it returns instead of being
        # left behind on every run.
        with tempfile.TemporaryDirectory() as download_dir:
            video_path = Path(download_dir) / "basketball.mp4"
            _ = urlretrieve(video_url, video_path)
            frames, _, _ = read_video(str(video_path))

        frames = frames.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        # Convert only the frames the demos use to PIL images
        to_pil = T.ToPILImage()
        pil_frames = [to_pil(frame) for frame in frames[first_frame:last_frame]]

        # Demo optical flow with dots (original frames 100 and 110)
        optical_flow.demo_optical_flow_dots(pil_frames[0], pil_frames[10], num_rows=20, num_cols=30)

        # Demo optical flow animation (original frames 100-199)
        optical_flow.demo_optical_flow_animation(pil_frames, num_rows=200, num_cols=300)

In [None]:
# Run the end-to-end demo: downloads the sample clip, then shows the dots
# figure and the dot-tracking slideshow.
demo_flow_anim()

In [None]:
import tempfile
from pathlib import Path
from urllib.request import urlretrieve

import matplotlib.pyplot as plt
import numpy as np
import rp
import torch
import torchvision.transforms as T
from icecream import ic
from PIL import Image, ImageDraw
from torchvision.io import read_video
from torchvision.models.optical_flow import raft_large
from torchvision.utils import flow_to_image
from tqdm import tqdm

class RaftOpticalFlow:
    """Pretrained RAFT-large wrapper with helpers for tracking points in video.

    Conventions shared by the methods below:
      * images are PIL images;
      * a flow map is a ``(2, H, W)`` tensor whose channel 0 is the x
        displacement and channel 1 the y displacement;
      * a point set is a ``(2, N)`` float tensor, row 0 = x and row 1 = y.
    """

    def __init__(self, device):
        """Load the pretrained RAFT-large model onto ``device`` in eval mode."""
        self.device = device
        # NOTE(review): newer torchvision releases prefer ``weights=`` over
        # ``pretrained=`` -- confirm against the installed version.
        self.model = raft_large(pretrained=True, progress=False).to(device)
        self.model.eval()

    def preprocess(self, image, multiple_of=8):
        """Turn a PIL image into a normalised ``(1, C, H', W')`` model input.

        Args:
            image: PIL image.
            multiple_of: ``H'`` and ``W'`` are the image's height and width
                rounded down to a multiple of this value.

        Returns:
            Tensor on ``self.device`` with values mapped from [0, 1] to [-1, 1].
        """
        # Ensure image dimensions are divisible by `multiple_of`
        width, height = image.size
        new_height = (height // multiple_of) * multiple_of
        new_width = (width // multiple_of) * multiple_of

        transforms = T.Compose(
            [
                T.ToTensor(),
                T.Normalize(mean=0.5, std=0.5),  # map [0, 1] into [-1, 1]
                T.Resize(size=(new_height, new_width)),
            ]
        )
        return transforms(image).unsqueeze(0).to(self.device)

    def get_optical_flow(self, from_image, to_image):
        """Estimate the flow from ``from_image`` to ``to_image``.

        Args:
            from_image: PIL image the flow starts at.
            to_image: PIL image the flow points to; must have the same size.

        Returns:
            ``(2, H, W)`` flow tensor at the original image resolution, with
            displacements expressed in pixels of that resolution.

        Raises:
            AssertionError: If the two images differ in size.
        """
        assert from_image.size == to_image.size, "Input images must have the same size"

        # Inference only: without no_grad every flow -- and every point
        # position accumulated from it -- would keep its autograd graph alive.
        with torch.no_grad():
            img1 = self.preprocess(from_image)
            img2 = self.preprocess(to_image)
            list_of_flows = self.model(img1, img2)
            # Last entry of the returned list, first (only) batch element.
            output_flow = list_of_flows[-1][0]

        # Resize the predicted flow back to the original image size
        width, height = from_image.size
        flow_height, flow_width = output_flow.shape[-2:]
        output_flow = T.Resize((height, width))(output_flow.unsqueeze(0)).squeeze(0)

        # Resizing only resamples the field; the displacement values are still
        # measured in pixels of the model's input grid, so rescale them too.
        if (flow_height, flow_width) != (height, width):
            scale = output_flow.new_tensor([width / flow_width, height / flow_height])
            output_flow = output_flow * scale.view(2, 1, 1)

        return output_flow

    def query_optical_flow(self, coordinates, flow_map):
        """Bilinearly sample ``flow_map`` at sub-pixel ``coordinates``.

        Args:
            coordinates: ``(2, N)`` tensor of pixel positions (row 0 = x,
                row 1 = y). It is not modified.
            flow_map: ``(2, H, W)`` flow tensor.

        Returns:
            ``(2, N)`` tensor with the flow sampled at every point.
        """
        # Normalize coordinates to [-1, 1] range. The spans are clamped to 1
        # so a 1-pixel-wide or 1-pixel-tall map does not divide by zero.
        height, width = flow_map.shape[1:]
        x_span = max(width - 1, 1)
        y_span = max(height - 1, 1)
        grid = torch.stack(
            [
                (coordinates[0] / x_span) * 2 - 1,
                (coordinates[1] / y_span) * 2 - 1,
            ],
            dim=-1,
        )  # (N, 2), last axis ordered (x, y) as grid_sample expects
        grid = grid.reshape(1, 1, -1, 2)

        # Perform bilinear interpolation using grid_sample
        sampled = torch.nn.functional.grid_sample(
            flow_map.unsqueeze(0), grid, mode='bilinear', align_corners=True
        )  # (1, 2, 1, N)

        return sampled[0, :, 0, :]

    def add_optical_flow(self, coordinates, flow_map):
        """Return ``coordinates`` moved by the flow sampled at each point."""
        deltas = self.query_optical_flow(coordinates, flow_map)
        return coordinates + deltas

    @staticmethod
    def _draw_dots(image, coordinates, radius=1):
        """Return a copy of ``image`` with a red dot at every non-NaN point."""
        canvas = image.copy()
        draw = ImageDraw.Draw(canvas)
        keep = ~torch.isnan(coordinates).any(dim=0)
        # One device->host transfer per frame instead of one per dot.
        for x, y in coordinates[:, keep].t().detach().cpu().tolist():
            draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill="red")
        return canvas

    def demo_optical_flow_animation(self, frames, num_rows=100, num_cols=100):
        """Track a regular grid of dots through ``frames`` and show a slideshow.

        Each frame is rendered with the dot positions that belong to it, so
        the slideshow has one image per input frame.

        Args:
            frames: Sequence of equally sized PIL images (at least one).
            num_rows: Requested number of grid rows; sets the vertical spacing.
            num_cols: Requested number of grid columns; sets the horizontal
                spacing.

        Returns:
            None. The rendered frames are passed to
            ``rp.display_image_slideshow``.
        """
        width, height = frames[0].size
        # Clamp to 1: a grid denser than the image would otherwise give a zero
        # step, which np.arange rejects.
        x_step = max(width // (num_cols + 1), 1)
        y_step = max(height // (num_rows + 1), 1)

        dots_x, dots_y = np.meshgrid(np.arange(x_step, width, x_step), np.arange(y_step, height, y_step))
        dots_x = torch.from_numpy(dots_x.flatten()).float().to(self.device)
        dots_y = torch.from_numpy(dots_y.flatten()).float().to(self.device)

        coordinates = torch.stack([dots_x, dots_y])

        # The untouched grid belongs to the first frame.
        animation_frames = [self._draw_dots(frames[0], coordinates)]

        for frame_idx in tqdm(range(len(frames) - 1)):
            from_image = frames[frame_idx]
            to_image = frames[frame_idx + 1]

            predicted_flow = self.get_optical_flow(from_image, to_image)

            # Update dot positions by accumulating optical flows
            coordinates = self.add_optical_flow(coordinates, predicted_flow)

            # After the update the dots live in to_image's frame of reference,
            # so that is the image they must be drawn on.
            animation_frames.append(self._draw_dots(to_image, coordinates))

        rp.display_image_slideshow(animation_frames)

# Usage example: download the sample clip and track a dense grid of dots
# through frames 100-199.
device = "cuda" if torch.cuda.is_available() else "cpu"
optical_flow = RaftOpticalFlow(device)

# Load video (the temporary download directory is not removed automatically)
video_url = "https://download.pytorch.org/tutorial/pexelscom_pavel_danilyuk_basketball_hd.mp4"
video_path = Path(tempfile.mkdtemp()) / "basketball.mp4"
_ = urlretrieve(video_url, video_path)

frames, _, _ = read_video(str(video_path))
frames = frames.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

# Convert frames to PIL images
to_pil = T.ToPILImage()
pil_frames = [to_pil(frame) for frame in frames]

# Demo optical flow animation. This is inference only, so run it without
# autograd: otherwise the dot positions accumulated over ~100 frames would
# keep every intermediate flow's graph alive.
with torch.no_grad():
    optical_flow.demo_optical_flow_animation(pil_frames[100:200], num_rows=200, num_cols=300)

In [None]:
import tempfile
from pathlib import Path
from urllib.request import urlretrieve

import cv2
import matplotlib.pyplot as plt
import numpy as np
import rp
import torch
import torchvision.transforms as T
from icecream import ic
from PIL import Image, ImageDraw
from torchvision.io import read_video
from torchvision.models.optical_flow import raft_large
from torchvision.utils import flow_to_image
from tqdm import tqdm


class RaftOpticalFlow:
    """Pretrained RAFT-large wrapper with flow, dot and tracking demos.

    Conventions shared by the methods below:
      * images are PIL images;
      * a flow map is a ``(2, H, W)`` tensor whose channel 0 is the x
        displacement and channel 1 the y displacement;
      * a point set is a ``(2, N)`` float tensor, row 0 = x and row 1 = y.
    """

    def __init__(self, device):
        """Load the pretrained RAFT-large model onto ``device`` in eval mode."""
        self.device = device
        # NOTE(review): newer torchvision releases prefer ``weights=`` over
        # ``pretrained=`` -- confirm against the installed version.
        self.model = raft_large(pretrained=True, progress=False).to(device)
        self.model.eval()

    def preprocess(self, image, multiple_of=8):
        """Turn a PIL image into a normalised ``(1, C, H', W')`` model input.

        Args:
            image: PIL image.
            multiple_of: ``H'`` and ``W'`` are the image's height and width
                rounded down to a multiple of this value.

        Returns:
            Tensor on ``self.device`` with values mapped from [0, 1] to [-1, 1].
        """
        # Ensure image dimensions are divisible by `multiple_of`
        width, height = image.size
        new_height = (height // multiple_of) * multiple_of
        new_width = (width // multiple_of) * multiple_of

        transforms = T.Compose(
            [
                T.ToTensor(),
                T.Normalize(mean=0.5, std=0.5),  # map [0, 1] into [-1, 1]
                T.Resize(size=(new_height, new_width)),
            ]
        )
        return transforms(image).unsqueeze(0).to(self.device)

    def get_optical_flow(self, from_image, to_image):
        """Estimate the flow from ``from_image`` to ``to_image``.

        Args:
            from_image: PIL image the flow starts at.
            to_image: PIL image the flow points to; must have the same size.

        Returns:
            ``(2, H, W)`` flow tensor at the original image resolution, with
            displacements expressed in pixels of that resolution.

        Raises:
            AssertionError: If the two images differ in size.
        """
        assert from_image.size == to_image.size, "Input images must have the same size"

        # Inference only: without no_grad every flow -- and every point
        # position accumulated from it -- would keep its autograd graph alive.
        with torch.no_grad():
            img1 = self.preprocess(from_image)
            img2 = self.preprocess(to_image)
            list_of_flows = self.model(img1, img2)
            # Last entry of the returned list, first (only) batch element.
            output_flow = list_of_flows[-1][0]

        # Resize the predicted flow back to the original image size
        width, height = from_image.size
        flow_height, flow_width = output_flow.shape[-2:]
        output_flow = T.Resize((height, width))(output_flow.unsqueeze(0)).squeeze(0)

        # Resizing only resamples the field; the displacement values are still
        # measured in pixels of the model's input grid, so rescale them too.
        if (flow_height, flow_width) != (height, width):
            scale = output_flow.new_tensor([width / flow_width, height / flow_height])
            output_flow = output_flow * scale.view(2, 1, 1)

        return output_flow

    @staticmethod
    def _grid_points(width, height, num_rows, num_cols):
        """Return flattened integer x and y arrays of a regular dot grid."""
        # Clamp to 1: a grid denser than the image would otherwise give a zero
        # step, which np.arange rejects.
        x_step = max(width // (num_cols + 1), 1)
        y_step = max(height // (num_rows + 1), 1)
        dots_x, dots_y = np.meshgrid(np.arange(x_step, width, x_step), np.arange(y_step, height, y_step))
        return dots_x.flatten(), dots_y.flatten()

    @staticmethod
    def _show_side_by_side(left, left_title, right, right_title):
        """Show two images next to each other in one matplotlib figure."""
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
        ax1.imshow(left)
        ax1.set_title(left_title)
        ax1.axis("off")
        ax2.imshow(right)
        ax2.set_title(right_title)
        ax2.axis("off")
        plt.tight_layout()
        plt.show()

    def demo_optical_flow(self, from_image, to_image):
        """Show ``from_image`` next to a colour rendering of the predicted flow."""
        predicted_flow = self.get_optical_flow(from_image, to_image)
        flow_img = flow_to_image(predicted_flow)

        self._show_side_by_side(
            from_image,
            "Input Image",
            flow_img.permute(1, 2, 0).cpu().numpy(),  # (C, H, W) -> (H, W, C)
            "Predicted Optical Flow",
        )

    def demo_optical_flow_dots(self, from_image, to_image, num_rows=10, num_cols=10):
        """Show a dot grid on ``from_image`` and its warped copy on ``to_image``.

        Red dots mark the grid on the first image; blue dots mark where the
        predicted flow moves each grid point in the second image.

        Args:
            from_image: PIL image the flow starts at.
            to_image: PIL image the flow points to (same size).
            num_rows: Requested number of grid rows; sets the vertical spacing.
            num_cols: Requested number of grid columns; sets the horizontal
                spacing.
        """
        predicted_flow = self.get_optical_flow(from_image, to_image)

        width, height = from_image.size
        dots_x, dots_y = self._grid_points(width, height, num_rows, num_cols)

        # Draw dots on the from_image
        from_image_with_dots = from_image.copy()
        draw = ImageDraw.Draw(from_image_with_dots)
        for x, y in zip(dots_x, dots_y):
            draw.ellipse((x-2, y-2, x+2, y+2), fill="red")

        # Look up every dot's flow with a single device->host copy instead of
        # one synchronising .item() call per dot.
        flow = predicted_flow.detach().cpu().numpy()
        flow_x = flow[0, dots_y, dots_x]
        flow_y = flow[1, dots_y, dots_x]

        # Warp the dots to the to_image based on the predicted flow
        to_image_with_dots = to_image.copy()
        draw = ImageDraw.Draw(to_image_with_dots)
        for x, y, dx, dy in zip(dots_x, dots_y, flow_x, flow_y):
            warped_x = float(x + dx)
            warped_y = float(y + dy)
            draw.ellipse((warped_x-2, warped_y-2, warped_x+2, warped_y+2), fill="blue")

        self._show_side_by_side(
            from_image_with_dots,
            "Input Image with Dots",
            to_image_with_dots,
            "Output Image with Warped Dots",
        )

    def query_optical_flow(self, coordinates, flow_map):
        """Bilinearly sample ``flow_map`` at sub-pixel ``coordinates``.

        Args:
            coordinates: ``(2, N)`` tensor of pixel positions (row 0 = x,
                row 1 = y). It is not modified.
            flow_map: ``(2, H, W)`` flow tensor.

        Returns:
            ``(2, N)`` tensor with the flow sampled at every point.
        """
        # Normalize coordinates to [-1, 1] range. The spans are clamped to 1
        # so a 1-pixel-wide or 1-pixel-tall map does not divide by zero.
        height, width = flow_map.shape[1:]
        x_span = max(width - 1, 1)
        y_span = max(height - 1, 1)
        grid = torch.stack(
            [
                (coordinates[0] / x_span) * 2 - 1,
                (coordinates[1] / y_span) * 2 - 1,
            ],
            dim=-1,
        )  # (N, 2), last axis ordered (x, y) as grid_sample expects
        grid = grid.reshape(1, 1, -1, 2)

        # Perform bilinear interpolation using grid_sample
        sampled = torch.nn.functional.grid_sample(
            flow_map.unsqueeze(0), grid, mode='bilinear', align_corners=True
        )  # (1, 2, 1, N)

        return sampled[0, :, 0, :]

    def add_optical_flow(self, coordinates, flow_map):
        """Return ``coordinates`` moved by the flow sampled at each point."""
        deltas = self.query_optical_flow(coordinates, flow_map)
        return coordinates + deltas

    @staticmethod
    def _paint_dots(image, coordinates):
        """Return a copy of ``image`` with one red pixel per tracked point.

        Points containing NaN are dropped; the rest are truncated to integer
        pixel positions and clamped to the image bounds.
        """
        # np.array copies, so the caller's image stays untouched. Painting
        # (255, 0, 0) directly on the RGB array gives the same pixels as a
        # BGR round-trip through OpenCV, without two full-frame conversions.
        canvas = np.array(image)

        dot_positions = coordinates.detach().cpu().numpy().T
        dot_positions = dot_positions[~np.isnan(dot_positions).any(axis=1)]  # Remove NaN values
        dot_positions = dot_positions.astype(np.int32)

        rows = np.clip(dot_positions[:, 1], 0, canvas.shape[0] - 1)
        cols = np.clip(dot_positions[:, 0], 0, canvas.shape[1] - 1)
        canvas[rows, cols] = (255, 0, 0)  # Draw red dots

        return Image.fromarray(canvas)

    def demo_optical_flow_animation(self, frames, num_rows=100, num_cols=100):
        """Track a regular grid of dots through ``frames`` and show a slideshow.

        Each frame is rendered with the dot positions that belong to it, so
        the slideshow has one image per input frame.

        Args:
            frames: Sequence of equally sized PIL images (at least one).
            num_rows: Requested number of grid rows; sets the vertical spacing.
            num_cols: Requested number of grid columns; sets the horizontal
                spacing.

        Returns:
            None. The rendered frames are passed to
            ``rp.display_image_slideshow``.
        """
        width, height = frames[0].size
        dots_x, dots_y = self._grid_points(width, height, num_rows, num_cols)
        dots_x = torch.from_numpy(dots_x).float().to(self.device)
        dots_y = torch.from_numpy(dots_y).float().to(self.device)

        coordinates = torch.stack([dots_x, dots_y])

        # The untouched grid belongs to the first frame.
        animation_frames = [self._paint_dots(frames[0], coordinates)]

        for frame_idx in tqdm(range(len(frames) - 1)):
            from_image = frames[frame_idx]
            to_image = frames[frame_idx + 1]

            predicted_flow = self.get_optical_flow(from_image, to_image)

            # Update dot positions by accumulating optical flows
            coordinates = self.add_optical_flow(coordinates, predicted_flow)

            # After the update the dots live in to_image's frame of reference,
            # so that is the image they must be painted on.
            animation_frames.append(self._paint_dots(to_image, coordinates))

        rp.display_image_slideshow(animation_frames)


def demo_flow_anim():
    """Download the sample basketball clip and run both optical-flow demos.

    Frames 100-199 of the clip are used: the dots demo compares frame 100 with
    frame 110, and the animation demo tracks a dense grid through the whole
    range. Everything runs under ``torch.no_grad()`` because this is inference
    only.
    """
    video_url = "https://download.pytorch.org/tutorial/pexelscom_pavel_danilyuk_basketball_hd.mp4"
    first_frame, last_frame = 100, 200

    with torch.no_grad():

        # Usage example
        device = "cuda" if torch.cuda.is_available() else "cpu"
        optical_flow = RaftOpticalFlow(device)

        # Load video. read_video decodes the clip into memory, so the download
        # directory can be removed as soon as it returns instead of being
        # left behind on every run.
        with tempfile.TemporaryDirectory() as download_dir:
            video_path = Path(download_dir) / "basketball.mp4"
            _ = urlretrieve(video_url, video_path)
            frames, _, _ = read_video(str(video_path))

        frames = frames.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        # Convert only the frames the demos use to PIL images
        to_pil = T.ToPILImage()
        pil_frames = [to_pil(frame) for frame in frames[first_frame:last_frame]]

        # Demo optical flow with dots (original frames 100 and 110)
        optical_flow.demo_optical_flow_dots(pil_frames[0], pil_frames[10], num_rows=20, num_cols=30)

        # Demo optical flow animation (original frames 100-199)
        optical_flow.demo_optical_flow_animation(pil_frames, num_rows=200, num_cols=300)



"""
TODO:
- Bulk operations: compute flows for a whole video tensor or iterable of frames, returning either a lazy generator or a stacked tensor, with a show_progress option
- Einops all the way
- Document each func and tensor shapes
- backwards / forwards occlusion detection + demo

"""

In [None]:
# Run the end-to-end demo: downloads the sample clip, then shows the dots
# figure and the dot-tracking slideshow.
demo_flow_anim()